package bootstrap import ( "bytes" "context" "encoding/json" "fmt" "net/http" "time" "dbview/service/internal/common/databases/meta" "dbview/service/internal/common/logger" "dbview/service/internal/common/manager/connection" "dbview/service/internal/common/manager/storage" "dbview/service/internal/common/manager/storage/types" "dbview/service/internal/common/manager/task" "dbview/service/internal/common/mcp" "dbview/service/internal/config" dq "dbview/service/internal/modules/data_query/service" mcpgo "github.com/mark3labs/mcp-go/mcp" "go.uber.org/zap" ) // initializeMCP creates and starts the MCP component when enabled in config. // It registers a few useful tools backed by existing storage and query services. func initializeMCP(ctx context.Context, cfg *config.AppConfig, log logger.Logger, storageMgr storage.StorageInterface, pool *connection.ConnectionPool, taskMgr *task.Manager) *mcp.Component { if cfg == nil || !cfg.MCP.Enable { if log != nil { log.Info("MCP 组件未启用,跳过初始化", zap.Bool("enabled", cfg != nil && cfg.MCP.Enable)) } return nil } comp := mcp.NewComponent(mcp.Config{ Enable: cfg.MCP.Enable, ServerName: cfg.MCP.ServerName, ServerVersion: cfg.MCP.ServerVersion, }, log) // 基础工具 comp.RegisterHealthTool() comp.RegisterEchoTool() // 与当前业务相关的工具 registerConnectionTools(comp, storageMgr, log) registerExecuteSQLTool(comp, pool, taskMgr, log) registerAIChatTool(comp, cfg.AI, storageMgr, log) go func() { if err := comp.ServeStdio(ctx); err != nil && log != nil { log.Error("MCP ServeStdio 启动失败", zap.Error(err)) } }() if log != nil { log.Info("MCP 组件已启用", zap.String("server_name", cfg.MCP.ServerName), zap.String("version", cfg.MCP.ServerVersion)) } return comp } // registerConnectionTools exposes a simple connection listing backed by storage. func registerConnectionTools(comp *mcp.Component, storageMgr storage.StorageInterface, log logger.Logger) { if storageMgr == nil { if log != nil { log.Warn("未找到存储管理器,跳过 MCP 连接工具注册") } return } tool := mcpgo.NewTool("connections_list", mcpgo.WithDescription("列出已存储的连接(不含敏感字段)"), ) comp.Server().AddTool(tool, func(ctx context.Context, req mcpgo.CallToolRequest) (*mcpgo.CallToolResult, error) { conns, err := storageMgr.GetAllConnections() if err != nil { return nil, err } summaries := make([]connectionSummary, 0, len(conns)) for _, c := range conns { summaries = append(summaries, summarizeConnection(c)) } payload, err := json.MarshalIndent(summaries, "", " ") if err != nil { return nil, err } return &mcpgo.CallToolResult{Content: []mcpgo.Content{mcpgo.TextContent{Text: string(payload)}}}, nil }) } // registerExecuteSQLTool exposes a synchronous SQL execution tool backed by the connection pool. func registerExecuteSQLTool(comp *mcp.Component, pool *connection.ConnectionPool, taskMgr *task.Manager, log logger.Logger) { if pool == nil { if log != nil { log.Warn("未找到连接池,跳过 MCP SQL 执行工具注册") } return } dataSvc := dq.NewDataService( dq.WithConnectionPool(pool), dq.WithTaskManager(taskMgr), ) tool := mcpgo.NewTool("execute_sql", mcpgo.WithDescription("在指定连接上同步执行 SQL"), mcpgo.WithString("connection_id", mcpgo.Description("连接ID,需已在连接池中就绪")), mcpgo.WithString("sql", mcpgo.Description("要执行的 SQL 语句")), ) comp.Server().AddTool(tool, func(ctx context.Context, req mcpgo.CallToolRequest) (*mcpgo.CallToolResult, error) { args, _ := req.Params.Arguments.(map[string]any) connID, _ := args["connection_id"].(string) sqlText, _ := args["sql"].(string) if connID == "" || sqlText == "" { return nil, fmt.Errorf("connection_id 与 sql 不能为空") } result, err := dataSvc.ExecuteSQL(ctx, connID, "", meta.ObjectPath{}, sqlText, nil, false, false) if err != nil { return nil, err } execRes, ok := result.(meta.ExecuteResult) if !ok { return nil, fmt.Errorf("unexpected execute result type %T", result) } payload, err := json.MarshalIndent(execRes, "", " ") if err != nil { return nil, err } return &mcpgo.CallToolResult{Content: []mcpgo.Content{mcpgo.TextContent{Text: string(payload)}}}, nil }) } // registerAIChatTool exposes a simple AI chat tool backed by OpenAI HTTP API. func registerAIChatTool(comp *mcp.Component, aiCfg config.AIConfig, storageMgr storage.StorageInterface, log logger.Logger) { if !aiCfg.Enable { if log != nil { log.Info("AI 交互未启用,跳过 MCP ai_chat 注册") } return } if storageMgr == nil { if log != nil { log.Warn("AI 交互已启用但存储未初始化,跳过 MCP ai_chat 注册") } return } aiSettings, err := storageMgr.GetAIConfig() if err != nil { if log != nil { log.Warn("读取 AI 配置失败,跳过 MCP ai_chat 注册", zap.Error(err)) } return } provider := aiSettings.Provider if provider == "" { provider = "openai" } baseURL := aiSettings.BaseURL if baseURL == "" { baseURL = "https://api.openai.com/v1/chat/completions" } defaultModel := aiSettings.Model if defaultModel == "" { defaultModel = "gpt-4o-mini" } if aiSettings.APIKey == "" { if log != nil { log.Warn("AI 交互已启用但未配置 api_key") } return } tool := mcpgo.NewTool("ai_chat", mcpgo.WithDescription("调用外部 AI 模型生成回复"), mcpgo.WithString("prompt", mcpgo.Description("用户提示,必填")), mcpgo.WithString("system", mcpgo.Description("系统指令,可选")), mcpgo.WithString("model", mcpgo.Description("覆盖默认模型,可选")), ) comp.Server().AddTool(tool, func(ctx context.Context, req mcpgo.CallToolRequest) (*mcpgo.CallToolResult, error) { args, _ := req.Params.Arguments.(map[string]any) prompt, _ := args["prompt"].(string) if prompt == "" { return nil, fmt.Errorf("prompt 不能为空") } systemPrompt, _ := args["system"].(string) model := defaultModel if override, ok := args["model"].(string); ok && override != "" { model = override } if p, ok := args["provider"].(string); ok && p != "" && p != provider { return nil, fmt.Errorf("不支持请求覆盖 provider") } callCtx, cancel := context.WithTimeout(ctx, 45*time.Second) defer cancel() respText, err := callOpenAICompatible(callCtx, baseURL, aiSettings.APIKey, model, systemPrompt, prompt, provider) if err != nil { return nil, err } return &mcpgo.CallToolResult{Content: []mcpgo.Content{mcpgo.TextContent{Text: respText}}}, nil }) if log != nil { log.Info("MCP ai_chat 工具已注册", zap.String("model", defaultModel), zap.String("provider", provider)) } } // callOpenAICompatible performs a chat completion request against an OpenAI-compatible endpoint. func callOpenAICompatible(ctx context.Context, baseURL, apiKey, model, systemPrompt, userPrompt, provider string) (string, error) { if apiKey == "" { return "", fmt.Errorf("missing api key") } messages := []openAIMessage{} if systemPrompt != "" { messages = append(messages, openAIMessage{Role: "system", Content: systemPrompt}) } messages = append(messages, openAIMessage{Role: "user", Content: userPrompt}) body := openAIChatRequest{Model: model, Messages: messages} b, err := json.Marshal(body) if err != nil { return "", err } req, err := http.NewRequestWithContext(ctx, http.MethodPost, baseURL, bytes.NewReader(b)) if err != nil { return "", err } req.Header.Set("Authorization", "Bearer "+apiKey) req.Header.Set("Content-Type", "application/json") resp, err := http.DefaultClient.Do(req) if err != nil { return "", err } defer resp.Body.Close() var out openAIChatResponse if err := json.NewDecoder(resp.Body).Decode(&out); err != nil { return "", err } if resp.StatusCode >= 300 { if out.Error != nil && out.Error.Message != "" { return "", fmt.Errorf("%s error: %s", provider, out.Error.Message) } return "", fmt.Errorf("%s request failed: status %d", provider, resp.StatusCode) } if len(out.Choices) == 0 { return "", fmt.Errorf("%s 返回为空", provider) } return out.Choices[0].Message.Content, nil } type openAIChatRequest struct { Model string `json:"model"` Messages []openAIMessage `json:"messages"` } type openAIMessage struct { Role string `json:"role"` Content string `json:"content"` } type openAIChatResponse struct { Choices []struct { Message openAIMessage `json:"message"` } `json:"choices"` Error *struct { Message string `json:"message"` } `json:"error,omitempty"` } // connectionSummary holds non-sensitive fields for MCP responses. type connectionSummary struct { ID string `json:"id"` Name string `json:"name"` Description string `json:"description,omitempty"` Kind string `json:"kind"` Type string `json:"type,omitempty"` Version string `json:"version,omitempty"` Server string `json:"server,omitempty"` Port int `json:"port,omitempty"` Database string `json:"database,omitempty"` Color string `json:"color,omitempty"` AutoConnect bool `json:"auto_connect"` } func summarizeConnection(c types.ConnectionWithDetails) connectionSummary { summary := connectionSummary{ ID: c.ID, Name: c.Name, Description: c.Description, Kind: c.Kind, Color: c.Color, AutoConnect: c.AutoConnect, } if c.DBDetail != nil { summary.Type = c.DBDetail.Type summary.Version = c.DBDetail.Version summary.Server = c.DBDetail.Server summary.Port = c.DBDetail.Port summary.Database = c.DBDetail.DatabaseName } if c.ServerDetail != nil { summary.Type = c.ServerDetail.Type summary.Version = c.ServerDetail.Version summary.Server = c.ServerDetail.Server summary.Port = c.ServerDetail.Port } return summary }