mcp.go 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335
  1. package bootstrap
  2. import (
  3. "bytes"
  4. "context"
  5. "encoding/json"
  6. "fmt"
  7. "net/http"
  8. "time"
  9. "dbview/service/internal/common/databases/meta"
  10. "dbview/service/internal/common/logger"
  11. "dbview/service/internal/common/manager/connection"
  12. "dbview/service/internal/common/manager/storage"
  13. "dbview/service/internal/common/manager/storage/types"
  14. "dbview/service/internal/common/manager/task"
  15. "dbview/service/internal/common/mcp"
  16. "dbview/service/internal/config"
  17. dq "dbview/service/internal/modules/data_query/service"
  18. mcpgo "github.com/mark3labs/mcp-go/mcp"
  19. "go.uber.org/zap"
  20. )
  21. // initializeMCP creates and starts the MCP component when enabled in config.
  22. // It registers a few useful tools backed by existing storage and query services.
  23. func initializeMCP(ctx context.Context, cfg *config.AppConfig, log logger.Logger, storageMgr storage.StorageInterface, pool *connection.ConnectionPool, taskMgr *task.Manager) *mcp.Component {
  24. if cfg == nil || !cfg.MCP.Enable {
  25. if log != nil {
  26. log.Info("MCP 组件未启用,跳过初始化", zap.Bool("enabled", cfg != nil && cfg.MCP.Enable))
  27. }
  28. return nil
  29. }
  30. comp := mcp.NewComponent(mcp.Config{
  31. Enable: cfg.MCP.Enable,
  32. ServerName: cfg.MCP.ServerName,
  33. ServerVersion: cfg.MCP.ServerVersion,
  34. }, log)
  35. // 基础工具
  36. comp.RegisterHealthTool()
  37. comp.RegisterEchoTool()
  38. // 与当前业务相关的工具
  39. registerConnectionTools(comp, storageMgr, log)
  40. registerExecuteSQLTool(comp, pool, taskMgr, log)
  41. registerAIChatTool(comp, cfg.AI, storageMgr, log)
  42. go func() {
  43. if err := comp.ServeStdio(ctx); err != nil && log != nil {
  44. log.Error("MCP ServeStdio 启动失败", zap.Error(err))
  45. }
  46. }()
  47. if log != nil {
  48. log.Info("MCP 组件已启用", zap.String("server_name", cfg.MCP.ServerName), zap.String("version", cfg.MCP.ServerVersion))
  49. }
  50. return comp
  51. }
  52. // registerConnectionTools exposes a simple connection listing backed by storage.
  53. func registerConnectionTools(comp *mcp.Component, storageMgr storage.StorageInterface, log logger.Logger) {
  54. if storageMgr == nil {
  55. if log != nil {
  56. log.Warn("未找到存储管理器,跳过 MCP 连接工具注册")
  57. }
  58. return
  59. }
  60. tool := mcpgo.NewTool("connections_list",
  61. mcpgo.WithDescription("列出已存储的连接(不含敏感字段)"),
  62. )
  63. comp.Server().AddTool(tool, func(ctx context.Context, req mcpgo.CallToolRequest) (*mcpgo.CallToolResult, error) {
  64. conns, err := storageMgr.GetAllConnections()
  65. if err != nil {
  66. return nil, err
  67. }
  68. summaries := make([]connectionSummary, 0, len(conns))
  69. for _, c := range conns {
  70. summaries = append(summaries, summarizeConnection(c))
  71. }
  72. payload, err := json.MarshalIndent(summaries, "", " ")
  73. if err != nil {
  74. return nil, err
  75. }
  76. return &mcpgo.CallToolResult{Content: []mcpgo.Content{mcpgo.TextContent{Text: string(payload)}}}, nil
  77. })
  78. }
  79. // registerExecuteSQLTool exposes a synchronous SQL execution tool backed by the connection pool.
  80. func registerExecuteSQLTool(comp *mcp.Component, pool *connection.ConnectionPool, taskMgr *task.Manager, log logger.Logger) {
  81. if pool == nil {
  82. if log != nil {
  83. log.Warn("未找到连接池,跳过 MCP SQL 执行工具注册")
  84. }
  85. return
  86. }
  87. dataSvc := dq.NewDataService(
  88. dq.WithConnectionPool(pool),
  89. dq.WithTaskManager(taskMgr),
  90. )
  91. tool := mcpgo.NewTool("execute_sql",
  92. mcpgo.WithDescription("在指定连接上同步执行 SQL"),
  93. mcpgo.WithString("connection_id", mcpgo.Description("连接ID,需已在连接池中就绪")),
  94. mcpgo.WithString("sql", mcpgo.Description("要执行的 SQL 语句")),
  95. )
  96. comp.Server().AddTool(tool, func(ctx context.Context, req mcpgo.CallToolRequest) (*mcpgo.CallToolResult, error) {
  97. args, _ := req.Params.Arguments.(map[string]any)
  98. connID, _ := args["connection_id"].(string)
  99. sqlText, _ := args["sql"].(string)
  100. if connID == "" || sqlText == "" {
  101. return nil, fmt.Errorf("connection_id 与 sql 不能为空")
  102. }
  103. result, err := dataSvc.ExecuteSQL(ctx, connID, "", meta.ObjectPath{}, sqlText, nil, false, false)
  104. if err != nil {
  105. return nil, err
  106. }
  107. execRes, ok := result.(meta.ExecuteResult)
  108. if !ok {
  109. return nil, fmt.Errorf("unexpected execute result type %T", result)
  110. }
  111. payload, err := json.MarshalIndent(execRes, "", " ")
  112. if err != nil {
  113. return nil, err
  114. }
  115. return &mcpgo.CallToolResult{Content: []mcpgo.Content{mcpgo.TextContent{Text: string(payload)}}}, nil
  116. })
  117. }
  118. // registerAIChatTool exposes a simple AI chat tool backed by OpenAI HTTP API.
  119. func registerAIChatTool(comp *mcp.Component, aiCfg config.AIConfig, storageMgr storage.StorageInterface, log logger.Logger) {
  120. if !aiCfg.Enable {
  121. if log != nil {
  122. log.Info("AI 交互未启用,跳过 MCP ai_chat 注册")
  123. }
  124. return
  125. }
  126. if storageMgr == nil {
  127. if log != nil {
  128. log.Warn("AI 交互已启用但存储未初始化,跳过 MCP ai_chat 注册")
  129. }
  130. return
  131. }
  132. aiSettings, err := storageMgr.GetAIConfig()
  133. if err != nil {
  134. if log != nil {
  135. log.Warn("读取 AI 配置失败,跳过 MCP ai_chat 注册", zap.Error(err))
  136. }
  137. return
  138. }
  139. provider := aiSettings.Provider
  140. if provider == "" {
  141. provider = "openai"
  142. }
  143. baseURL := aiSettings.BaseURL
  144. if baseURL == "" {
  145. baseURL = "https://api.openai.com/v1/chat/completions"
  146. }
  147. defaultModel := aiSettings.Model
  148. if defaultModel == "" {
  149. defaultModel = "gpt-4o-mini"
  150. }
  151. if aiSettings.APIKey == "" {
  152. if log != nil {
  153. log.Warn("AI 交互已启用但未配置 api_key")
  154. }
  155. return
  156. }
  157. tool := mcpgo.NewTool("ai_chat",
  158. mcpgo.WithDescription("调用外部 AI 模型生成回复"),
  159. mcpgo.WithString("prompt", mcpgo.Description("用户提示,必填")),
  160. mcpgo.WithString("system", mcpgo.Description("系统指令,可选")),
  161. mcpgo.WithString("model", mcpgo.Description("覆盖默认模型,可选")),
  162. )
  163. comp.Server().AddTool(tool, func(ctx context.Context, req mcpgo.CallToolRequest) (*mcpgo.CallToolResult, error) {
  164. args, _ := req.Params.Arguments.(map[string]any)
  165. prompt, _ := args["prompt"].(string)
  166. if prompt == "" {
  167. return nil, fmt.Errorf("prompt 不能为空")
  168. }
  169. systemPrompt, _ := args["system"].(string)
  170. model := defaultModel
  171. if override, ok := args["model"].(string); ok && override != "" {
  172. model = override
  173. }
  174. if p, ok := args["provider"].(string); ok && p != "" && p != provider {
  175. return nil, fmt.Errorf("不支持请求覆盖 provider")
  176. }
  177. callCtx, cancel := context.WithTimeout(ctx, 45*time.Second)
  178. defer cancel()
  179. respText, err := callOpenAICompatible(callCtx, baseURL, aiSettings.APIKey, model, systemPrompt, prompt, provider)
  180. if err != nil {
  181. return nil, err
  182. }
  183. return &mcpgo.CallToolResult{Content: []mcpgo.Content{mcpgo.TextContent{Text: respText}}}, nil
  184. })
  185. if log != nil {
  186. log.Info("MCP ai_chat 工具已注册", zap.String("model", defaultModel), zap.String("provider", provider))
  187. }
  188. }
  189. // callOpenAICompatible performs a chat completion request against an OpenAI-compatible endpoint.
  190. func callOpenAICompatible(ctx context.Context, baseURL, apiKey, model, systemPrompt, userPrompt, provider string) (string, error) {
  191. if apiKey == "" {
  192. return "", fmt.Errorf("missing api key")
  193. }
  194. messages := []openAIMessage{}
  195. if systemPrompt != "" {
  196. messages = append(messages, openAIMessage{Role: "system", Content: systemPrompt})
  197. }
  198. messages = append(messages, openAIMessage{Role: "user", Content: userPrompt})
  199. body := openAIChatRequest{Model: model, Messages: messages}
  200. b, err := json.Marshal(body)
  201. if err != nil {
  202. return "", err
  203. }
  204. req, err := http.NewRequestWithContext(ctx, http.MethodPost, baseURL, bytes.NewReader(b))
  205. if err != nil {
  206. return "", err
  207. }
  208. req.Header.Set("Authorization", "Bearer "+apiKey)
  209. req.Header.Set("Content-Type", "application/json")
  210. resp, err := http.DefaultClient.Do(req)
  211. if err != nil {
  212. return "", err
  213. }
  214. defer resp.Body.Close()
  215. var out openAIChatResponse
  216. if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
  217. return "", err
  218. }
  219. if resp.StatusCode >= 300 {
  220. if out.Error != nil && out.Error.Message != "" {
  221. return "", fmt.Errorf("%s error: %s", provider, out.Error.Message)
  222. }
  223. return "", fmt.Errorf("%s request failed: status %d", provider, resp.StatusCode)
  224. }
  225. if len(out.Choices) == 0 {
  226. return "", fmt.Errorf("%s 返回为空", provider)
  227. }
  228. return out.Choices[0].Message.Content, nil
  229. }
  230. type openAIChatRequest struct {
  231. Model string `json:"model"`
  232. Messages []openAIMessage `json:"messages"`
  233. }
  234. type openAIMessage struct {
  235. Role string `json:"role"`
  236. Content string `json:"content"`
  237. }
  238. type openAIChatResponse struct {
  239. Choices []struct {
  240. Message openAIMessage `json:"message"`
  241. } `json:"choices"`
  242. Error *struct {
  243. Message string `json:"message"`
  244. } `json:"error,omitempty"`
  245. }
  246. // connectionSummary holds non-sensitive fields for MCP responses.
  247. type connectionSummary struct {
  248. ID string `json:"id"`
  249. Name string `json:"name"`
  250. Description string `json:"description,omitempty"`
  251. Kind string `json:"kind"`
  252. Type string `json:"type,omitempty"`
  253. Version string `json:"version,omitempty"`
  254. Server string `json:"server,omitempty"`
  255. Port int `json:"port,omitempty"`
  256. Database string `json:"database,omitempty"`
  257. Color string `json:"color,omitempty"`
  258. AutoConnect bool `json:"auto_connect"`
  259. }
  260. func summarizeConnection(c types.ConnectionWithDetails) connectionSummary {
  261. summary := connectionSummary{
  262. ID: c.ID,
  263. Name: c.Name,
  264. Description: c.Description,
  265. Kind: c.Kind,
  266. Color: c.Color,
  267. AutoConnect: c.AutoConnect,
  268. }
  269. if c.DBDetail != nil {
  270. summary.Type = c.DBDetail.Type
  271. summary.Version = c.DBDetail.Version
  272. summary.Server = c.DBDetail.Server
  273. summary.Port = c.DBDetail.Port
  274. summary.Database = c.DBDetail.DatabaseName
  275. }
  276. if c.ServerDetail != nil {
  277. summary.Type = c.ServerDetail.Type
  278. summary.Version = c.ServerDetail.Version
  279. summary.Server = c.ServerDetail.Server
  280. summary.Port = c.ServerDetail.Port
  281. }
  282. return summary
  283. }