package task import ( "context" "encoding/hex" "encoding/json" "errors" "fmt" "math/rand" "sync" "time" logpkg "dbview/service/internal/common/logger" "github.com/gorilla/websocket" "go.uber.org/zap" ) // Manager 任务管理器,负责任务的创建、执行和状态跟踪 type Manager struct { tasks map[string]*Task // 任务ID到任务的映射 mutex sync.RWMutex // 保护tasks的互斥锁 workerPool chan struct{} // 工作池,控制并发数量 wg sync.WaitGroup // 等待所有任务完成 ctx context.Context // 上下文,用于取消所有任务 cancel context.CancelFunc // 取消函数 // WebSocket 订阅管理:taskID -> set of wrapped websocket connections wsMu sync.Mutex wsSubs map[string]map[*wsConn]struct{} connMap map[*websocket.Conn]*wsConn // 原始连接到包装的映射,用于查找 } // NewManager 创建新的任务管理器 // maxConcurrent 最大并发任务数 func NewManager(ctx context.Context, maxConcurrent int) *Manager { ctx, cancel := context.WithCancel(ctx) return &Manager{ tasks: make(map[string]*Task), workerPool: make(chan struct{}, maxConcurrent), ctx: ctx, cancel: cancel, wsSubs: make(map[string]map[*wsConn]struct{}), connMap: make(map[*websocket.Conn]*wsConn), } } // wsConn 是对 websocket.Conn 的包装结构。 // 目的: // - 将原始的 *websocket.Conn 包装为带有发送/接收缓冲的本地结构, // 避免直接并发调用 websocket.Conn 的写方法(导致 concurrent write panic)。 // - 提供独立的发送通道 `send`,由单独的 write pump 序列化所有写入操作; // 以及接收通道 `inbound`,由 read pump 将来自客户端的消息入队,供上层读取。 // // 字段说明: // - conn: 底层的 *websocket.Conn 实例。 // - send: 发送缓冲通道(消息由服务端入队,此通道由 write pump 串行写入 websocket)。 // 缓冲大小应有限制以防止单一慢客户端耗尽内存(目前使用 256 条消息)。 // - inbound: 接收缓冲通道(来自客户端的消息由 read pump 入队,上层通过 Register 获取此通道)。 // inbound 的缓冲大小也有限制(目前使用 16 条消息),超过则丢弃客户端消息以保护服务器。 // - closeCh: 一个用于通知 write pump 退出的信号通道(由 Unsubscribe 或清理逻辑关闭)。 type wsConn struct { conn *websocket.Conn send chan []byte inbound chan []byte closeCh chan struct{} } const ( writeWait = 10 * time.Second pongWait = 60 * time.Second pingPeriod = (pongWait * 9) / 10 maxMessageSize = 1024 * 8 ) // GenerateTaskID 生成唯一任务ID func (m *Manager) GenerateTaskID() string { bytes := make([]byte, 8) _, err := rand.Read(bytes) if err != nil { // 作为 fallback,使用时间戳+随机数 return fmt.Sprintf("%d%04d", time.Now().UnixNano(), rand.Intn(10000)) } return hex.EncodeToString(bytes) } // CreateTask 创建新任务但不执行 func (m *Manager) CreateTask(taskType string, opts ...TaskOption) *Task { taskID := m.GenerateTaskID() now := time.Now() task := &Task{ ID: taskID, Type: taskType, Status: TaskPending, Progress: 0, CreatedAt: now, } // 应用选项 for _, opt := range opts { opt(task) } m.mutex.Lock() m.tasks[taskID] = task m.mutex.Unlock() return task } // ExecuteTask 执行任务 func (m *Manager) ExecuteTask(task *Task, taskFunc TaskFunc) { m.wg.Add(1) go func() { defer m.wg.Done() // 等待工作池有空位或上下文被取消 select { case m.workerPool <- struct{}{}: defer func() { <-m.workerPool }() case <-m.ctx.Done(): m.updateTaskStatus(task.ID, TaskCancelled, 0, nil, m.ctx.Err().Error()) return } // 更新任务为运行中状态 startTime := time.Now() m.mutex.Lock() task.Status = TaskRunning task.StartedAt = &startTime m.mutex.Unlock() // 通知订阅者任务已开始 m.publishEvent(task.ID, map[string]interface{}{"taskId": task.ID, "status": string(TaskRunning), "progress": task.Progress}) // 执行任务函数 result, err := taskFunc(m.ctx, func(progress int) { // 确保进度在0-100之间 if progress < 0 { progress = 0 } else if progress > 100 { progress = 100 } m.updateTaskProgress(task.ID, progress) }) // 更新任务完成状态 completedTime := time.Now() if err != nil { m.updateTaskStatus(task.ID, TaskFailed, 100, nil, err.Error()) } else { m.updateTaskStatus(task.ID, TaskSuccess, 100, result, "") } m.mutex.Lock() task.CompletedAt = &completedTime m.mutex.Unlock() // 通知订阅者任务已完成(包含最终状态和结果/错误) m.publishEvent(task.ID, map[string]interface{}{"taskId": task.ID, "status": string(task.Status), "progress": task.Progress, "result": task.Result, "error": task.Error}) }() } // SubmitTask 创建并执行任务 func (m *Manager) SubmitTask(taskType string, taskFunc TaskFunc, opts ...TaskOption) *Task { task := m.CreateTask(taskType, opts...) m.ExecuteTask(task, taskFunc) return task } // GetTaskStatus 获取任务状态 func (m *Manager) GetTaskStatus(taskID string) (*Task, error) { m.mutex.RLock() defer m.mutex.RUnlock() task, exists := m.tasks[taskID] if !exists { return nil, errors.New("任务不存在") } // 返回任务的副本,避免外部修改 return &Task{ ID: task.ID, Type: task.Type, Source: task.Source, Status: task.Status, Progress: task.Progress, Result: task.Result, Error: task.Error, CreatedAt: task.CreatedAt, StartedAt: copyTime(task.StartedAt), CompletedAt: copyTime(task.CompletedAt), }, nil } // CancelTask 取消任务 func (m *Manager) CancelTask(taskID string) error { m.mutex.Lock() defer m.mutex.Unlock() task, exists := m.tasks[taskID] if !exists { return errors.New("任务不存在") } // 只有等待中或运行中的任务可以被取消 if task.Status != TaskPending && task.Status != TaskRunning { return errors.New("任务无法取消") } task.Status = TaskCancelled completedTime := time.Now() task.CompletedAt = &completedTime task.Error = "任务已被取消" return nil } // ListTasks 列出指定状态的任务 func (m *Manager) ListTasks(status TaskStatus) []*Task { m.mutex.RLock() defer m.mutex.RUnlock() var result []*Task for _, task := range m.tasks { if status == "" || task.Status == status { // 返回副本 result = append(result, &Task{ ID: task.ID, Type: task.Type, Source: task.Source, Status: task.Status, Progress: task.Progress, Result: task.Result, Error: task.Error, CreatedAt: task.CreatedAt, StartedAt: copyTime(task.StartedAt), CompletedAt: copyTime(task.CompletedAt), }) } } return result } // Shutdown 关闭任务管理器,等待所有任务完成 func (m *Manager) Shutdown() { m.cancel() m.wg.Wait() } // 内部方法:更新任务进度 func (m *Manager) updateTaskProgress(taskID string, progress int) { m.mutex.Lock() defer m.mutex.Unlock() if task, exists := m.tasks[taskID]; exists { task.Progress = progress // 发布进度更新给订阅者 m.publishEvent(taskID, map[string]interface{}{"taskId": taskID, "status": string(task.Status), "progress": task.Progress}) } } // 内部方法:更新任务状态 func (m *Manager) updateTaskStatus(taskID string, status TaskStatus, progress int, result interface{}, errMsg string) { m.mutex.Lock() defer m.mutex.Unlock() if task, exists := m.tasks[taskID]; exists { task.Status = status task.Progress = progress task.Result = result task.Error = errMsg // 发布状态变更事件(非阻塞) go m.publishEvent(taskID, map[string]interface{}{"taskId": taskID, "status": string(status), "progress": progress, "result": result, "error": errMsg}) } } // Subscribe 将 websocket 连接注册到指定 taskID 的订阅列表 func (m *Manager) Subscribe(conn *websocket.Conn, taskID string) { m.wsMu.Lock() defer m.wsMu.Unlock() // 获取或创建包装 w := m.connMap[conn] if w == nil { // 如果尚未注册读写泵,注册一个 w = &wsConn{conn: conn} m.connMap[conn] = w // lazy init channels to avoid nil panics w.send = make(chan []byte, 256) w.inbound = make(chan []byte, 16) w.closeCh = make(chan struct{}) m.startPumps(w) } subs, ok := m.wsSubs[taskID] if !ok { subs = make(map[*wsConn]struct{}) m.wsSubs[taskID] = subs } subs[w] = struct{}{} // log subscribe lg := logpkg.FromContext(context.Background()) lg.Info("ws subscribe", zap.String("taskId", taskID), zap.String("remote", conn.RemoteAddr().String())) } // Register 注册 websocket 连接并启动对应的读/写 pump。 // 返回值: // - inbound 通道:上层(handler)应从该通道读取客户端发来的消息(订阅/退订等命令)。 // // 使用说明: // - 如果同一个 *websocket.Conn 重复调用 Register,会返回已存在的 inbound 通道(幂等)。 // - Register 会创建并启动两个 goroutine:read pump(将客户端消息写入 inbound)和 // write pump(从 send 通道中读取消息并序列化写入底层连接,还负责心跳 ping)。 // // 并发/安全性: // - Register 在内部持有 wsMu 锁以保证 connMap 的并发安全。 // - write pump 是唯一会直接调用底层 conn.Write* 的协程,从而避免 concurrent write panic。 func (m *Manager) Register(conn *websocket.Conn) chan []byte { m.wsMu.Lock() defer m.wsMu.Unlock() if existing, ok := m.connMap[conn]; ok { return existing.inbound } w := &wsConn{ conn: conn, send: make(chan []byte, 256), inbound: make(chan []byte, 16), closeCh: make(chan struct{}), } m.connMap[conn] = w m.startPumps(w) lg := logpkg.FromContext(context.Background()) lg.Info("ws conn registered", zap.String("remote", conn.RemoteAddr().String())) return w.inbound } // Unsubscribe 从指定 taskID 的订阅列表移除连接。 // 行为说明: // - 仅从该 task 的订阅集合中移除包装 wsConn,如果该连接不再订阅任何任务,则完全移除并关闭关联资源。 // - 当连接不再被任何 task 使用时,函数会从 connMap 中删除对应条目,关闭 w.closeCh(通知 write pump 退出), // 并记录连接被移除的日志。注意:read pump 也会在检测到连接断开时做相似清理(互补)。 // // 并发性: // - 本方法持有 wsMu 锁以保护 wsSubs 与 connMap 的一致性。 func (m *Manager) Unsubscribe(conn *websocket.Conn, taskID string) { m.wsMu.Lock() defer m.wsMu.Unlock() w := m.connMap[conn] if w == nil { return } if subs, ok := m.wsSubs[taskID]; ok { delete(subs, w) if len(subs) == 0 { delete(m.wsSubs, taskID) } } // 如果该连接不再订阅任何 task,则从 connMap 移除 stillUsed := false for _, subs := range m.wsSubs { if _, ok := subs[w]; ok { stillUsed = true break } } if !stillUsed { delete(m.connMap, conn) // close pumps close(w.closeCh) lg := logpkg.FromContext(context.Background()) lg.Info("ws conn removed", zap.String("remote", conn.RemoteAddr().String())) } } // publishEvent 将事件推送给订阅 taskID 的所有 websocket 连接(非阻塞入队到各自 send 通道)。 // 设计要点: // - publishEvent 不直接往 websocket 连接写入数据,而是将序列化后的消息入队到每个订阅者的 `send` 缓冲通道, // 由各自的 write pump 负责实际的网络写入,从而避免并发写冲突。 // - 为了保护服务器稳定性,如果某个连接的 send 通道已满(可能由慢客户端导致),当前实现会记录告警并取消该连接的订阅, // 并尝试关闭底层连接。该策略可以避免单个慢客户端或恶意客户端耗尽服务器资源。 // - 注意:publishEvent 在跨进程/多实例部署下仅能推送到本实例内的订阅者,若需要跨实例广播需引入外部 pub/sub(如 Redis)。 func (m *Manager) publishEvent(taskID string, event interface{}) { m.wsMu.Lock() subs := m.wsSubs[taskID] m.wsMu.Unlock() if subs == nil || len(subs) == 0 { return } b, err := json.Marshal(event) if err != nil { return } lg := logpkg.FromContext(context.Background()) for w := range subs { select { case w.send <- b: // enqueued default: // send buffer full, drop and unregister lg.Warn("ws send buffer full, unsubscribing", zap.String("taskId", taskID), zap.String("remote", w.conn.RemoteAddr().String())) m.Unsubscribe(w.conn, taskID) _ = w.conn.Close() } } } // WriteJSON 将 v 序列化为 JSON 并尝试安全地发送到指定的 websocket 连接。 // 如果该连接由 manager 包装(存在于 connMap),则消息会被入队到其 send 通道,由 write pump 串行写出; // 否则会直接使用 conn.WriteJSON 写入(不推荐,因为可能造成并发写)。 // 返回错误场景: // - JSON 序列化失败 -> 返回序列化错误 // - send 通道已满 -> 返回错误并记录警告(调用方可选择重试或忽略) // // 说明:推荐上层总是通过 Manager.WriteJSON 或 publishEvent 让 write pump 负责写操作,避免 concurrent write 问题。 func (m *Manager) WriteJSON(conn *websocket.Conn, v interface{}) error { m.wsMu.Lock() w := m.connMap[conn] m.wsMu.Unlock() if w == nil { // 如果连接未包装,直接写(但这不推荐) return conn.WriteJSON(v) } b, err := json.Marshal(v) if err != nil { return err } select { case w.send <- b: return nil default: // buffer full lg := logpkg.FromContext(context.Background()) lg.Warn("ws writejson buffer full", zap.String("remote", conn.RemoteAddr().String())) return fmt.Errorf("ws send buffer full") } } // 辅助函数:复制时间指针 func copyTime(t *time.Time) *time.Time { if t == nil { return nil } newTime := *t return &newTime } // startPumps 启动 read/write pumps,使用 wsConn 的 channels func (m *Manager) startPumps(w *wsConn) { // read pump go func() { conn := w.conn conn.SetReadLimit(maxMessageSize) conn.SetReadDeadline(time.Now().Add(pongWait)) conn.SetPongHandler(func(string) error { conn.SetReadDeadline(time.Now().Add(pongWait)) return nil }) for { _, message, err := conn.ReadMessage() if err != nil { // read error or closed break } select { case w.inbound <- message: default: // drop if inbound full } } // cleanup on exit m.wsMu.Lock() delete(m.connMap, conn) // remove from all subscriptions for tid, subs := range m.wsSubs { if _, ok := subs[w]; ok { delete(subs, w) if len(subs) == 0 { delete(m.wsSubs, tid) } } } m.wsMu.Unlock() close(w.send) close(w.inbound) _ = conn.Close() }() // write pump go func() { conn := w.conn ticker := time.NewTicker(pingPeriod) defer func() { ticker.Stop() _ = conn.Close() }() for { select { case message, ok := <-w.send: conn.SetWriteDeadline(time.Now().Add(writeWait)) if !ok { // channel closed _ = conn.WriteMessage(websocket.CloseMessage, []byte{}) return } if err := conn.WriteMessage(websocket.TextMessage, message); err != nil { return } case <-ticker.C: conn.SetWriteDeadline(time.Now().Add(writeWait)) if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil { return } case <-w.closeCh: return } } }() }