| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517 |
- package task
- import (
- "context"
- "encoding/hex"
- "encoding/json"
- "errors"
- "fmt"
- "math/rand"
- "sync"
- "time"
- logpkg "dbview/service/internal/common/logger"
- "github.com/gorilla/websocket"
- "go.uber.org/zap"
- )
- // Manager 任务管理器,负责任务的创建、执行和状态跟踪
- type Manager struct {
- tasks map[string]*Task // 任务ID到任务的映射
- mutex sync.RWMutex // 保护tasks的互斥锁
- workerPool chan struct{} // 工作池,控制并发数量
- wg sync.WaitGroup // 等待所有任务完成
- ctx context.Context // 上下文,用于取消所有任务
- cancel context.CancelFunc // 取消函数
- // WebSocket 订阅管理:taskID -> set of wrapped websocket connections
- wsMu sync.Mutex
- wsSubs map[string]map[*wsConn]struct{}
- connMap map[*websocket.Conn]*wsConn // 原始连接到包装的映射,用于查找
- }
- // NewManager 创建新的任务管理器
- // maxConcurrent 最大并发任务数
- func NewManager(ctx context.Context, maxConcurrent int) *Manager {
- ctx, cancel := context.WithCancel(ctx)
- return &Manager{
- tasks: make(map[string]*Task),
- workerPool: make(chan struct{}, maxConcurrent),
- ctx: ctx,
- cancel: cancel,
- wsSubs: make(map[string]map[*wsConn]struct{}),
- connMap: make(map[*websocket.Conn]*wsConn),
- }
- }
- // wsConn 是对 websocket.Conn 的包装结构。
- // 目的:
- // - 将原始的 *websocket.Conn 包装为带有发送/接收缓冲的本地结构,
- // 避免直接并发调用 websocket.Conn 的写方法(导致 concurrent write panic)。
- // - 提供独立的发送通道 `send`,由单独的 write pump 序列化所有写入操作;
- // 以及接收通道 `inbound`,由 read pump 将来自客户端的消息入队,供上层读取。
- //
- // 字段说明:
- // - conn: 底层的 *websocket.Conn 实例。
- // - send: 发送缓冲通道(消息由服务端入队,此通道由 write pump 串行写入 websocket)。
- // 缓冲大小应有限制以防止单一慢客户端耗尽内存(目前使用 256 条消息)。
- // - inbound: 接收缓冲通道(来自客户端的消息由 read pump 入队,上层通过 Register 获取此通道)。
- // inbound 的缓冲大小也有限制(目前使用 16 条消息),超过则丢弃客户端消息以保护服务器。
- // - closeCh: 一个用于通知 write pump 退出的信号通道(由 Unsubscribe 或清理逻辑关闭)。
- type wsConn struct {
- conn *websocket.Conn
- send chan []byte
- inbound chan []byte
- closeCh chan struct{}
- }
- const (
- writeWait = 10 * time.Second
- pongWait = 60 * time.Second
- pingPeriod = (pongWait * 9) / 10
- maxMessageSize = 1024 * 8
- )
- // GenerateTaskID 生成唯一任务ID
- func (m *Manager) GenerateTaskID() string {
- bytes := make([]byte, 8)
- _, err := rand.Read(bytes)
- if err != nil {
- // 作为 fallback,使用时间戳+随机数
- return fmt.Sprintf("%d%04d", time.Now().UnixNano(), rand.Intn(10000))
- }
- return hex.EncodeToString(bytes)
- }
- // CreateTask 创建新任务但不执行
- func (m *Manager) CreateTask(taskType string, opts ...TaskOption) *Task {
- taskID := m.GenerateTaskID()
- now := time.Now()
- task := &Task{
- ID: taskID,
- Type: taskType,
- Status: TaskPending,
- Progress: 0,
- CreatedAt: now,
- }
- // 应用选项
- for _, opt := range opts {
- opt(task)
- }
- m.mutex.Lock()
- m.tasks[taskID] = task
- m.mutex.Unlock()
- return task
- }
- // ExecuteTask 执行任务
- func (m *Manager) ExecuteTask(task *Task, taskFunc TaskFunc) {
- m.wg.Add(1)
- go func() {
- defer m.wg.Done()
- // 等待工作池有空位或上下文被取消
- select {
- case m.workerPool <- struct{}{}:
- defer func() { <-m.workerPool }()
- case <-m.ctx.Done():
- m.updateTaskStatus(task.ID, TaskCancelled, 0, nil, m.ctx.Err().Error())
- return
- }
- // 更新任务为运行中状态
- startTime := time.Now()
- m.mutex.Lock()
- task.Status = TaskRunning
- task.StartedAt = &startTime
- m.mutex.Unlock()
- // 通知订阅者任务已开始
- m.publishEvent(task.ID, map[string]interface{}{"taskId": task.ID, "status": string(TaskRunning), "progress": task.Progress})
- // 执行任务函数
- result, err := taskFunc(m.ctx, func(progress int) {
- // 确保进度在0-100之间
- if progress < 0 {
- progress = 0
- } else if progress > 100 {
- progress = 100
- }
- m.updateTaskProgress(task.ID, progress)
- })
- // 更新任务完成状态
- completedTime := time.Now()
- if err != nil {
- m.updateTaskStatus(task.ID, TaskFailed, 100, nil, err.Error())
- } else {
- m.updateTaskStatus(task.ID, TaskSuccess, 100, result, "")
- }
- m.mutex.Lock()
- task.CompletedAt = &completedTime
- m.mutex.Unlock()
- // 通知订阅者任务已完成(包含最终状态和结果/错误)
- m.publishEvent(task.ID, map[string]interface{}{"taskId": task.ID, "status": string(task.Status), "progress": task.Progress, "result": task.Result, "error": task.Error})
- }()
- }
- // SubmitTask 创建并执行任务
- func (m *Manager) SubmitTask(taskType string, taskFunc TaskFunc, opts ...TaskOption) *Task {
- task := m.CreateTask(taskType, opts...)
- m.ExecuteTask(task, taskFunc)
- return task
- }
- // GetTaskStatus 获取任务状态
- func (m *Manager) GetTaskStatus(taskID string) (*Task, error) {
- m.mutex.RLock()
- defer m.mutex.RUnlock()
- task, exists := m.tasks[taskID]
- if !exists {
- return nil, errors.New("任务不存在")
- }
- // 返回任务的副本,避免外部修改
- return &Task{
- ID: task.ID,
- Type: task.Type,
- Source: task.Source,
- Status: task.Status,
- Progress: task.Progress,
- Result: task.Result,
- Error: task.Error,
- CreatedAt: task.CreatedAt,
- StartedAt: copyTime(task.StartedAt),
- CompletedAt: copyTime(task.CompletedAt),
- }, nil
- }
- // CancelTask 取消任务
- func (m *Manager) CancelTask(taskID string) error {
- m.mutex.Lock()
- defer m.mutex.Unlock()
- task, exists := m.tasks[taskID]
- if !exists {
- return errors.New("任务不存在")
- }
- // 只有等待中或运行中的任务可以被取消
- if task.Status != TaskPending && task.Status != TaskRunning {
- return errors.New("任务无法取消")
- }
- task.Status = TaskCancelled
- completedTime := time.Now()
- task.CompletedAt = &completedTime
- task.Error = "任务已被取消"
- return nil
- }
- // ListTasks 列出指定状态的任务
- func (m *Manager) ListTasks(status TaskStatus) []*Task {
- m.mutex.RLock()
- defer m.mutex.RUnlock()
- var result []*Task
- for _, task := range m.tasks {
- if status == "" || task.Status == status {
- // 返回副本
- result = append(result, &Task{
- ID: task.ID,
- Type: task.Type,
- Source: task.Source,
- Status: task.Status,
- Progress: task.Progress,
- Result: task.Result,
- Error: task.Error,
- CreatedAt: task.CreatedAt,
- StartedAt: copyTime(task.StartedAt),
- CompletedAt: copyTime(task.CompletedAt),
- })
- }
- }
- return result
- }
- // Shutdown 关闭任务管理器,等待所有任务完成
- func (m *Manager) Shutdown() {
- m.cancel()
- m.wg.Wait()
- }
- // 内部方法:更新任务进度
- func (m *Manager) updateTaskProgress(taskID string, progress int) {
- m.mutex.Lock()
- defer m.mutex.Unlock()
- if task, exists := m.tasks[taskID]; exists {
- task.Progress = progress
- // 发布进度更新给订阅者
- m.publishEvent(taskID, map[string]interface{}{"taskId": taskID, "status": string(task.Status), "progress": task.Progress})
- }
- }
- // 内部方法:更新任务状态
- func (m *Manager) updateTaskStatus(taskID string, status TaskStatus, progress int, result interface{}, errMsg string) {
- m.mutex.Lock()
- defer m.mutex.Unlock()
- if task, exists := m.tasks[taskID]; exists {
- task.Status = status
- task.Progress = progress
- task.Result = result
- task.Error = errMsg
- // 发布状态变更事件(非阻塞)
- go m.publishEvent(taskID, map[string]interface{}{"taskId": taskID, "status": string(status), "progress": progress, "result": result, "error": errMsg})
- }
- }
- // Subscribe 将 websocket 连接注册到指定 taskID 的订阅列表
- func (m *Manager) Subscribe(conn *websocket.Conn, taskID string) {
- m.wsMu.Lock()
- defer m.wsMu.Unlock()
- // 获取或创建包装
- w := m.connMap[conn]
- if w == nil {
- // 如果尚未注册读写泵,注册一个
- w = &wsConn{conn: conn}
- m.connMap[conn] = w
- // lazy init channels to avoid nil panics
- w.send = make(chan []byte, 256)
- w.inbound = make(chan []byte, 16)
- w.closeCh = make(chan struct{})
- m.startPumps(w)
- }
- subs, ok := m.wsSubs[taskID]
- if !ok {
- subs = make(map[*wsConn]struct{})
- m.wsSubs[taskID] = subs
- }
- subs[w] = struct{}{}
- // log subscribe
- lg := logpkg.FromContext(context.Background())
- lg.Info("ws subscribe", zap.String("taskId", taskID), zap.String("remote", conn.RemoteAddr().String()))
- }
- // Register 注册 websocket 连接并启动对应的读/写 pump。
- // 返回值:
- // - inbound 通道:上层(handler)应从该通道读取客户端发来的消息(订阅/退订等命令)。
- //
- // 使用说明:
- // - 如果同一个 *websocket.Conn 重复调用 Register,会返回已存在的 inbound 通道(幂等)。
- // - Register 会创建并启动两个 goroutine:read pump(将客户端消息写入 inbound)和
- // write pump(从 send 通道中读取消息并序列化写入底层连接,还负责心跳 ping)。
- //
- // 并发/安全性:
- // - Register 在内部持有 wsMu 锁以保证 connMap 的并发安全。
- // - write pump 是唯一会直接调用底层 conn.Write* 的协程,从而避免 concurrent write panic。
- func (m *Manager) Register(conn *websocket.Conn) chan []byte {
- m.wsMu.Lock()
- defer m.wsMu.Unlock()
- if existing, ok := m.connMap[conn]; ok {
- return existing.inbound
- }
- w := &wsConn{
- conn: conn,
- send: make(chan []byte, 256),
- inbound: make(chan []byte, 16),
- closeCh: make(chan struct{}),
- }
- m.connMap[conn] = w
- m.startPumps(w)
- lg := logpkg.FromContext(context.Background())
- lg.Info("ws conn registered", zap.String("remote", conn.RemoteAddr().String()))
- return w.inbound
- }
- // Unsubscribe 从指定 taskID 的订阅列表移除连接。
- // 行为说明:
- // - 仅从该 task 的订阅集合中移除包装 wsConn,如果该连接不再订阅任何任务,则完全移除并关闭关联资源。
- // - 当连接不再被任何 task 使用时,函数会从 connMap 中删除对应条目,关闭 w.closeCh(通知 write pump 退出),
- // 并记录连接被移除的日志。注意:read pump 也会在检测到连接断开时做相似清理(互补)。
- //
- // 并发性:
- // - 本方法持有 wsMu 锁以保护 wsSubs 与 connMap 的一致性。
- func (m *Manager) Unsubscribe(conn *websocket.Conn, taskID string) {
- m.wsMu.Lock()
- defer m.wsMu.Unlock()
- w := m.connMap[conn]
- if w == nil {
- return
- }
- if subs, ok := m.wsSubs[taskID]; ok {
- delete(subs, w)
- if len(subs) == 0 {
- delete(m.wsSubs, taskID)
- }
- }
- // 如果该连接不再订阅任何 task,则从 connMap 移除
- stillUsed := false
- for _, subs := range m.wsSubs {
- if _, ok := subs[w]; ok {
- stillUsed = true
- break
- }
- }
- if !stillUsed {
- delete(m.connMap, conn)
- // close pumps
- close(w.closeCh)
- lg := logpkg.FromContext(context.Background())
- lg.Info("ws conn removed", zap.String("remote", conn.RemoteAddr().String()))
- }
- }
- // publishEvent 将事件推送给订阅 taskID 的所有 websocket 连接(非阻塞入队到各自 send 通道)。
- // 设计要点:
- // - publishEvent 不直接往 websocket 连接写入数据,而是将序列化后的消息入队到每个订阅者的 `send` 缓冲通道,
- // 由各自的 write pump 负责实际的网络写入,从而避免并发写冲突。
- // - 为了保护服务器稳定性,如果某个连接的 send 通道已满(可能由慢客户端导致),当前实现会记录告警并取消该连接的订阅,
- // 并尝试关闭底层连接。该策略可以避免单个慢客户端或恶意客户端耗尽服务器资源。
- // - 注意:publishEvent 在跨进程/多实例部署下仅能推送到本实例内的订阅者,若需要跨实例广播需引入外部 pub/sub(如 Redis)。
- func (m *Manager) publishEvent(taskID string, event interface{}) {
- m.wsMu.Lock()
- subs := m.wsSubs[taskID]
- m.wsMu.Unlock()
- if subs == nil || len(subs) == 0 {
- return
- }
- b, err := json.Marshal(event)
- if err != nil {
- return
- }
- lg := logpkg.FromContext(context.Background())
- for w := range subs {
- select {
- case w.send <- b:
- // enqueued
- default:
- // send buffer full, drop and unregister
- lg.Warn("ws send buffer full, unsubscribing", zap.String("taskId", taskID), zap.String("remote", w.conn.RemoteAddr().String()))
- m.Unsubscribe(w.conn, taskID)
- _ = w.conn.Close()
- }
- }
- }
- // WriteJSON 将 v 序列化为 JSON 并尝试安全地发送到指定的 websocket 连接。
- // 如果该连接由 manager 包装(存在于 connMap),则消息会被入队到其 send 通道,由 write pump 串行写出;
- // 否则会直接使用 conn.WriteJSON 写入(不推荐,因为可能造成并发写)。
- // 返回错误场景:
- // - JSON 序列化失败 -> 返回序列化错误
- // - send 通道已满 -> 返回错误并记录警告(调用方可选择重试或忽略)
- //
- // 说明:推荐上层总是通过 Manager.WriteJSON 或 publishEvent 让 write pump 负责写操作,避免 concurrent write 问题。
- func (m *Manager) WriteJSON(conn *websocket.Conn, v interface{}) error {
- m.wsMu.Lock()
- w := m.connMap[conn]
- m.wsMu.Unlock()
- if w == nil {
- // 如果连接未包装,直接写(但这不推荐)
- return conn.WriteJSON(v)
- }
- b, err := json.Marshal(v)
- if err != nil {
- return err
- }
- select {
- case w.send <- b:
- return nil
- default:
- // buffer full
- lg := logpkg.FromContext(context.Background())
- lg.Warn("ws writejson buffer full", zap.String("remote", conn.RemoteAddr().String()))
- return fmt.Errorf("ws send buffer full")
- }
- }
- // 辅助函数:复制时间指针
- func copyTime(t *time.Time) *time.Time {
- if t == nil {
- return nil
- }
- newTime := *t
- return &newTime
- }
- // startPumps 启动 read/write pumps,使用 wsConn 的 channels
- func (m *Manager) startPumps(w *wsConn) {
- // read pump
- go func() {
- conn := w.conn
- conn.SetReadLimit(maxMessageSize)
- conn.SetReadDeadline(time.Now().Add(pongWait))
- conn.SetPongHandler(func(string) error {
- conn.SetReadDeadline(time.Now().Add(pongWait))
- return nil
- })
- for {
- _, message, err := conn.ReadMessage()
- if err != nil {
- // read error or closed
- break
- }
- select {
- case w.inbound <- message:
- default:
- // drop if inbound full
- }
- }
- // cleanup on exit
- m.wsMu.Lock()
- delete(m.connMap, conn)
- // remove from all subscriptions
- for tid, subs := range m.wsSubs {
- if _, ok := subs[w]; ok {
- delete(subs, w)
- if len(subs) == 0 {
- delete(m.wsSubs, tid)
- }
- }
- }
- m.wsMu.Unlock()
- close(w.send)
- close(w.inbound)
- _ = conn.Close()
- }()
- // write pump
- go func() {
- conn := w.conn
- ticker := time.NewTicker(pingPeriod)
- defer func() {
- ticker.Stop()
- _ = conn.Close()
- }()
- for {
- select {
- case message, ok := <-w.send:
- conn.SetWriteDeadline(time.Now().Add(writeWait))
- if !ok {
- // channel closed
- _ = conn.WriteMessage(websocket.CloseMessage, []byte{})
- return
- }
- if err := conn.WriteMessage(websocket.TextMessage, message); err != nil {
- return
- }
- case <-ticker.C:
- conn.SetWriteDeadline(time.Now().Add(writeWait))
- if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil {
- return
- }
- case <-w.closeCh:
- return
- }
- }
- }()
- }
|