mcp.go 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323
  1. package bootstrap
  2. import (
  3. "bytes"
  4. "context"
  5. "encoding/json"
  6. "fmt"
  7. "net/http"
  8. "os"
  9. "time"
  10. "dbview/service/internal/common/databases/meta"
  11. "dbview/service/internal/common/logger"
  12. "dbview/service/internal/common/manager/connection"
  13. "dbview/service/internal/common/manager/storage"
  14. "dbview/service/internal/common/manager/storage/types"
  15. "dbview/service/internal/common/manager/task"
  16. "dbview/service/internal/common/mcp"
  17. "dbview/service/internal/config"
  18. dq "dbview/service/internal/modules/data_query/service"
  19. mcpgo "github.com/mark3labs/mcp-go/mcp"
  20. "go.uber.org/zap"
  21. )
  22. // initializeMCP creates and starts the MCP component when enabled in config.
  23. // It registers a few useful tools backed by existing storage and query services.
  24. func initializeMCP(ctx context.Context, cfg *config.AppConfig, log logger.Logger, storageMgr storage.StorageInterface, pool *connection.ConnectionPool, taskMgr *task.Manager) *mcp.Component {
  25. if cfg == nil || !cfg.MCP.Enable {
  26. if log != nil {
  27. log.Info("MCP 组件未启用,跳过初始化", zap.Bool("enabled", cfg != nil && cfg.MCP.Enable))
  28. }
  29. return nil
  30. }
  31. comp := mcp.NewComponent(mcp.Config{
  32. Enable: cfg.MCP.Enable,
  33. ServerName: cfg.MCP.ServerName,
  34. ServerVersion: cfg.MCP.ServerVersion,
  35. }, log)
  36. // 基础工具
  37. comp.RegisterHealthTool()
  38. comp.RegisterEchoTool()
  39. // 与当前业务相关的工具
  40. registerConnectionTools(comp, storageMgr, log)
  41. registerExecuteSQLTool(comp, pool, taskMgr, log)
  42. registerAIChatTool(comp, cfg.AI, log)
  43. go func() {
  44. if err := comp.ServeStdio(ctx); err != nil && log != nil {
  45. log.Error("MCP ServeStdio 启动失败", zap.Error(err))
  46. }
  47. }()
  48. if log != nil {
  49. log.Info("MCP 组件已启用", zap.String("server_name", cfg.MCP.ServerName), zap.String("version", cfg.MCP.ServerVersion))
  50. }
  51. return comp
  52. }
  53. // registerConnectionTools exposes a simple connection listing backed by storage.
  54. func registerConnectionTools(comp *mcp.Component, storageMgr storage.StorageInterface, log logger.Logger) {
  55. if storageMgr == nil {
  56. if log != nil {
  57. log.Warn("未找到存储管理器,跳过 MCP 连接工具注册")
  58. }
  59. return
  60. }
  61. tool := mcpgo.NewTool("connections_list",
  62. mcpgo.WithDescription("列出已存储的连接(不含敏感字段)"),
  63. )
  64. comp.Server().AddTool(tool, func(ctx context.Context, req mcpgo.CallToolRequest) (*mcpgo.CallToolResult, error) {
  65. conns, err := storageMgr.GetAllConnections()
  66. if err != nil {
  67. return nil, err
  68. }
  69. summaries := make([]connectionSummary, 0, len(conns))
  70. for _, c := range conns {
  71. summaries = append(summaries, summarizeConnection(c))
  72. }
  73. payload, err := json.MarshalIndent(summaries, "", " ")
  74. if err != nil {
  75. return nil, err
  76. }
  77. return &mcpgo.CallToolResult{Content: []mcpgo.Content{mcpgo.TextContent{Text: string(payload)}}}, nil
  78. })
  79. }
  80. // registerExecuteSQLTool exposes a synchronous SQL execution tool backed by the connection pool.
  81. func registerExecuteSQLTool(comp *mcp.Component, pool *connection.ConnectionPool, taskMgr *task.Manager, log logger.Logger) {
  82. if pool == nil {
  83. if log != nil {
  84. log.Warn("未找到连接池,跳过 MCP SQL 执行工具注册")
  85. }
  86. return
  87. }
  88. dataSvc := dq.NewDataService(
  89. dq.WithConnectionPool(pool),
  90. dq.WithTaskManager(taskMgr),
  91. )
  92. tool := mcpgo.NewTool("execute_sql",
  93. mcpgo.WithDescription("在指定连接上同步执行 SQL"),
  94. mcpgo.WithString("connection_id", mcpgo.Description("连接ID,需已在连接池中就绪")),
  95. mcpgo.WithString("sql", mcpgo.Description("要执行的 SQL 语句")),
  96. )
  97. comp.Server().AddTool(tool, func(ctx context.Context, req mcpgo.CallToolRequest) (*mcpgo.CallToolResult, error) {
  98. args, _ := req.Params.Arguments.(map[string]any)
  99. connID, _ := args["connection_id"].(string)
  100. sqlText, _ := args["sql"].(string)
  101. if connID == "" || sqlText == "" {
  102. return nil, fmt.Errorf("connection_id 与 sql 不能为空")
  103. }
  104. result, err := dataSvc.ExecuteSQL(ctx, connID, "", meta.ObjectPath{}, sqlText, nil, false, false)
  105. if err != nil {
  106. return nil, err
  107. }
  108. execRes, ok := result.(meta.ExecuteResult)
  109. if !ok {
  110. return nil, fmt.Errorf("unexpected execute result type %T", result)
  111. }
  112. payload, err := json.MarshalIndent(execRes, "", " ")
  113. if err != nil {
  114. return nil, err
  115. }
  116. return &mcpgo.CallToolResult{Content: []mcpgo.Content{mcpgo.TextContent{Text: string(payload)}}}, nil
  117. })
  118. }
  119. // registerAIChatTool exposes a simple AI chat tool backed by OpenAI HTTP API.
  120. func registerAIChatTool(comp *mcp.Component, aiCfg config.AIConfig, log logger.Logger) {
  121. if !aiCfg.Enable {
  122. if log != nil {
  123. log.Info("AI 交互未启用,跳过 MCP ai_chat 注册")
  124. }
  125. return
  126. }
  127. if aiCfg.Provider != "openai" {
  128. if log != nil {
  129. log.Warn("AI Provider 未支持,跳过 MCP ai_chat 注册", zap.String("provider", aiCfg.Provider))
  130. }
  131. return
  132. }
  133. apiKeyEnv := aiCfg.APIKeyEnv
  134. if apiKeyEnv == "" {
  135. apiKeyEnv = "OPENAI_API_KEY"
  136. }
  137. apiKey := os.Getenv(apiKeyEnv)
  138. if apiKey == "" {
  139. if log != nil {
  140. log.Warn("AI 交互已启用但未找到 API Key 环境变量", zap.String("env", apiKeyEnv))
  141. }
  142. return
  143. }
  144. defaultModel := aiCfg.Model
  145. if defaultModel == "" {
  146. defaultModel = "gpt-4o-mini"
  147. }
  148. tool := mcpgo.NewTool("ai_chat",
  149. mcpgo.WithDescription("调用外部 AI 模型生成回复 (OpenAI)"),
  150. mcpgo.WithString("prompt", mcpgo.Description("用户提示,必填")),
  151. mcpgo.WithString("system", mcpgo.Description("系统指令,可选")),
  152. mcpgo.WithString("model", mcpgo.Description("覆盖默认模型,可选")),
  153. )
  154. comp.Server().AddTool(tool, func(ctx context.Context, req mcpgo.CallToolRequest) (*mcpgo.CallToolResult, error) {
  155. args, _ := req.Params.Arguments.(map[string]any)
  156. prompt, _ := args["prompt"].(string)
  157. if prompt == "" {
  158. return nil, fmt.Errorf("prompt 不能为空")
  159. }
  160. systemPrompt, _ := args["system"].(string)
  161. model := defaultModel
  162. if override, ok := args["model"].(string); ok && override != "" {
  163. model = override
  164. }
  165. callCtx, cancel := context.WithTimeout(ctx, 45*time.Second)
  166. defer cancel()
  167. respText, err := callOpenAIChat(callCtx, apiKey, model, systemPrompt, prompt)
  168. if err != nil {
  169. return nil, err
  170. }
  171. return &mcpgo.CallToolResult{Content: []mcpgo.Content{mcpgo.TextContent{Text: respText}}}, nil
  172. })
  173. if log != nil {
  174. log.Info("MCP ai_chat 工具已注册", zap.String("model", defaultModel), zap.String("provider", aiCfg.Provider))
  175. }
  176. }
  177. // callOpenAIChat performs a minimal chat completion request.
  178. func callOpenAIChat(ctx context.Context, apiKey, model, systemPrompt, userPrompt string) (string, error) {
  179. if apiKey == "" {
  180. return "", fmt.Errorf("missing api key")
  181. }
  182. messages := []openAIMessage{}
  183. if systemPrompt != "" {
  184. messages = append(messages, openAIMessage{Role: "system", Content: systemPrompt})
  185. }
  186. messages = append(messages, openAIMessage{Role: "user", Content: userPrompt})
  187. body := openAIChatRequest{Model: model, Messages: messages}
  188. b, err := json.Marshal(body)
  189. if err != nil {
  190. return "", err
  191. }
  192. req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader(b))
  193. if err != nil {
  194. return "", err
  195. }
  196. req.Header.Set("Authorization", "Bearer "+apiKey)
  197. req.Header.Set("Content-Type", "application/json")
  198. resp, err := http.DefaultClient.Do(req)
  199. if err != nil {
  200. return "", err
  201. }
  202. defer resp.Body.Close()
  203. var out openAIChatResponse
  204. if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
  205. return "", err
  206. }
  207. if resp.StatusCode >= 300 {
  208. if out.Error != nil && out.Error.Message != "" {
  209. return "", fmt.Errorf("openai error: %s", out.Error.Message)
  210. }
  211. return "", fmt.Errorf("openai request failed: status %d", resp.StatusCode)
  212. }
  213. if len(out.Choices) == 0 {
  214. return "", fmt.Errorf("openai 返回为空")
  215. }
  216. return out.Choices[0].Message.Content, nil
  217. }
  218. type openAIChatRequest struct {
  219. Model string `json:"model"`
  220. Messages []openAIMessage `json:"messages"`
  221. }
  222. type openAIMessage struct {
  223. Role string `json:"role"`
  224. Content string `json:"content"`
  225. }
  226. type openAIChatResponse struct {
  227. Choices []struct {
  228. Message openAIMessage `json:"message"`
  229. } `json:"choices"`
  230. Error *struct {
  231. Message string `json:"message"`
  232. } `json:"error,omitempty"`
  233. }
  234. // connectionSummary holds non-sensitive fields for MCP responses.
  235. type connectionSummary struct {
  236. ID string `json:"id"`
  237. Name string `json:"name"`
  238. Description string `json:"description,omitempty"`
  239. Kind string `json:"kind"`
  240. Type string `json:"type,omitempty"`
  241. Version string `json:"version,omitempty"`
  242. Server string `json:"server,omitempty"`
  243. Port int `json:"port,omitempty"`
  244. Database string `json:"database,omitempty"`
  245. Color string `json:"color,omitempty"`
  246. AutoConnect bool `json:"auto_connect"`
  247. }
  248. func summarizeConnection(c types.ConnectionWithDetails) connectionSummary {
  249. summary := connectionSummary{
  250. ID: c.ID,
  251. Name: c.Name,
  252. Description: c.Description,
  253. Kind: c.Kind,
  254. Color: c.Color,
  255. AutoConnect: c.AutoConnect,
  256. }
  257. if c.DBDetail != nil {
  258. summary.Type = c.DBDetail.Type
  259. summary.Version = c.DBDetail.Version
  260. summary.Server = c.DBDetail.Server
  261. summary.Port = c.DBDetail.Port
  262. summary.Database = c.DBDetail.DatabaseName
  263. }
  264. if c.ServerDetail != nil {
  265. summary.Type = c.ServerDetail.Type
  266. summary.Version = c.ServerDetail.Version
  267. summary.Server = c.ServerDetail.Server
  268. summary.Port = c.ServerDetail.Port
  269. }
  270. return summary
  271. }