| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335 |
- package bootstrap
- import (
- "bytes"
- "context"
- "encoding/json"
- "fmt"
- "net/http"
- "time"
- "dbview/service/internal/common/databases/meta"
- "dbview/service/internal/common/logger"
- "dbview/service/internal/common/manager/connection"
- "dbview/service/internal/common/manager/storage"
- "dbview/service/internal/common/manager/storage/types"
- "dbview/service/internal/common/manager/task"
- "dbview/service/internal/common/mcp"
- "dbview/service/internal/config"
- dq "dbview/service/internal/modules/data_query/service"
- mcpgo "github.com/mark3labs/mcp-go/mcp"
- "go.uber.org/zap"
- )
- // initializeMCP creates and starts the MCP component when enabled in config.
- // It registers a few useful tools backed by existing storage and query services.
- func initializeMCP(ctx context.Context, cfg *config.AppConfig, log logger.Logger, storageMgr storage.StorageInterface, pool *connection.ConnectionPool, taskMgr *task.Manager) *mcp.Component {
- if cfg == nil || !cfg.MCP.Enable {
- if log != nil {
- log.Info("MCP 组件未启用,跳过初始化", zap.Bool("enabled", cfg != nil && cfg.MCP.Enable))
- }
- return nil
- }
- comp := mcp.NewComponent(mcp.Config{
- Enable: cfg.MCP.Enable,
- ServerName: cfg.MCP.ServerName,
- ServerVersion: cfg.MCP.ServerVersion,
- }, log)
- // 基础工具
- comp.RegisterHealthTool()
- comp.RegisterEchoTool()
- // 与当前业务相关的工具
- registerConnectionTools(comp, storageMgr, log)
- registerExecuteSQLTool(comp, pool, taskMgr, log)
- registerAIChatTool(comp, cfg.AI, storageMgr, log)
- go func() {
- if err := comp.ServeStdio(ctx); err != nil && log != nil {
- log.Error("MCP ServeStdio 启动失败", zap.Error(err))
- }
- }()
- if log != nil {
- log.Info("MCP 组件已启用", zap.String("server_name", cfg.MCP.ServerName), zap.String("version", cfg.MCP.ServerVersion))
- }
- return comp
- }
- // registerConnectionTools exposes a simple connection listing backed by storage.
- func registerConnectionTools(comp *mcp.Component, storageMgr storage.StorageInterface, log logger.Logger) {
- if storageMgr == nil {
- if log != nil {
- log.Warn("未找到存储管理器,跳过 MCP 连接工具注册")
- }
- return
- }
- tool := mcpgo.NewTool("connections_list",
- mcpgo.WithDescription("列出已存储的连接(不含敏感字段)"),
- )
- comp.Server().AddTool(tool, func(ctx context.Context, req mcpgo.CallToolRequest) (*mcpgo.CallToolResult, error) {
- conns, err := storageMgr.GetAllConnections()
- if err != nil {
- return nil, err
- }
- summaries := make([]connectionSummary, 0, len(conns))
- for _, c := range conns {
- summaries = append(summaries, summarizeConnection(c))
- }
- payload, err := json.MarshalIndent(summaries, "", " ")
- if err != nil {
- return nil, err
- }
- return &mcpgo.CallToolResult{Content: []mcpgo.Content{mcpgo.TextContent{Text: string(payload)}}}, nil
- })
- }
- // registerExecuteSQLTool exposes a synchronous SQL execution tool backed by the connection pool.
- func registerExecuteSQLTool(comp *mcp.Component, pool *connection.ConnectionPool, taskMgr *task.Manager, log logger.Logger) {
- if pool == nil {
- if log != nil {
- log.Warn("未找到连接池,跳过 MCP SQL 执行工具注册")
- }
- return
- }
- dataSvc := dq.NewDataService(
- dq.WithConnectionPool(pool),
- dq.WithTaskManager(taskMgr),
- )
- tool := mcpgo.NewTool("execute_sql",
- mcpgo.WithDescription("在指定连接上同步执行 SQL"),
- mcpgo.WithString("connection_id", mcpgo.Description("连接ID,需已在连接池中就绪")),
- mcpgo.WithString("sql", mcpgo.Description("要执行的 SQL 语句")),
- )
- comp.Server().AddTool(tool, func(ctx context.Context, req mcpgo.CallToolRequest) (*mcpgo.CallToolResult, error) {
- args, _ := req.Params.Arguments.(map[string]any)
- connID, _ := args["connection_id"].(string)
- sqlText, _ := args["sql"].(string)
- if connID == "" || sqlText == "" {
- return nil, fmt.Errorf("connection_id 与 sql 不能为空")
- }
- result, err := dataSvc.ExecuteSQL(ctx, connID, "", meta.ObjectPath{}, sqlText, nil, false, false)
- if err != nil {
- return nil, err
- }
- execRes, ok := result.(meta.ExecuteResult)
- if !ok {
- return nil, fmt.Errorf("unexpected execute result type %T", result)
- }
- payload, err := json.MarshalIndent(execRes, "", " ")
- if err != nil {
- return nil, err
- }
- return &mcpgo.CallToolResult{Content: []mcpgo.Content{mcpgo.TextContent{Text: string(payload)}}}, nil
- })
- }
- // registerAIChatTool exposes a simple AI chat tool backed by OpenAI HTTP API.
- func registerAIChatTool(comp *mcp.Component, aiCfg config.AIConfig, storageMgr storage.StorageInterface, log logger.Logger) {
- if !aiCfg.Enable {
- if log != nil {
- log.Info("AI 交互未启用,跳过 MCP ai_chat 注册")
- }
- return
- }
- if storageMgr == nil {
- if log != nil {
- log.Warn("AI 交互已启用但存储未初始化,跳过 MCP ai_chat 注册")
- }
- return
- }
- aiSettings, err := storageMgr.GetAIConfig()
- if err != nil {
- if log != nil {
- log.Warn("读取 AI 配置失败,跳过 MCP ai_chat 注册", zap.Error(err))
- }
- return
- }
- provider := aiSettings.Provider
- if provider == "" {
- provider = "openai"
- }
- baseURL := aiSettings.BaseURL
- if baseURL == "" {
- baseURL = "https://api.openai.com/v1/chat/completions"
- }
- defaultModel := aiSettings.Model
- if defaultModel == "" {
- defaultModel = "gpt-4o-mini"
- }
- if aiSettings.APIKey == "" {
- if log != nil {
- log.Warn("AI 交互已启用但未配置 api_key")
- }
- return
- }
- tool := mcpgo.NewTool("ai_chat",
- mcpgo.WithDescription("调用外部 AI 模型生成回复"),
- mcpgo.WithString("prompt", mcpgo.Description("用户提示,必填")),
- mcpgo.WithString("system", mcpgo.Description("系统指令,可选")),
- mcpgo.WithString("model", mcpgo.Description("覆盖默认模型,可选")),
- )
- comp.Server().AddTool(tool, func(ctx context.Context, req mcpgo.CallToolRequest) (*mcpgo.CallToolResult, error) {
- args, _ := req.Params.Arguments.(map[string]any)
- prompt, _ := args["prompt"].(string)
- if prompt == "" {
- return nil, fmt.Errorf("prompt 不能为空")
- }
- systemPrompt, _ := args["system"].(string)
- model := defaultModel
- if override, ok := args["model"].(string); ok && override != "" {
- model = override
- }
- if p, ok := args["provider"].(string); ok && p != "" && p != provider {
- return nil, fmt.Errorf("不支持请求覆盖 provider")
- }
- callCtx, cancel := context.WithTimeout(ctx, 45*time.Second)
- defer cancel()
- respText, err := callOpenAICompatible(callCtx, baseURL, aiSettings.APIKey, model, systemPrompt, prompt, provider)
- if err != nil {
- return nil, err
- }
- return &mcpgo.CallToolResult{Content: []mcpgo.Content{mcpgo.TextContent{Text: respText}}}, nil
- })
- if log != nil {
- log.Info("MCP ai_chat 工具已注册", zap.String("model", defaultModel), zap.String("provider", provider))
- }
- }
- // callOpenAICompatible performs a chat completion request against an OpenAI-compatible endpoint.
- func callOpenAICompatible(ctx context.Context, baseURL, apiKey, model, systemPrompt, userPrompt, provider string) (string, error) {
- if apiKey == "" {
- return "", fmt.Errorf("missing api key")
- }
- messages := []openAIMessage{}
- if systemPrompt != "" {
- messages = append(messages, openAIMessage{Role: "system", Content: systemPrompt})
- }
- messages = append(messages, openAIMessage{Role: "user", Content: userPrompt})
- body := openAIChatRequest{Model: model, Messages: messages}
- b, err := json.Marshal(body)
- if err != nil {
- return "", err
- }
- req, err := http.NewRequestWithContext(ctx, http.MethodPost, baseURL, bytes.NewReader(b))
- if err != nil {
- return "", err
- }
- req.Header.Set("Authorization", "Bearer "+apiKey)
- req.Header.Set("Content-Type", "application/json")
- resp, err := http.DefaultClient.Do(req)
- if err != nil {
- return "", err
- }
- defer resp.Body.Close()
- var out openAIChatResponse
- if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
- return "", err
- }
- if resp.StatusCode >= 300 {
- if out.Error != nil && out.Error.Message != "" {
- return "", fmt.Errorf("%s error: %s", provider, out.Error.Message)
- }
- return "", fmt.Errorf("%s request failed: status %d", provider, resp.StatusCode)
- }
- if len(out.Choices) == 0 {
- return "", fmt.Errorf("%s 返回为空", provider)
- }
- return out.Choices[0].Message.Content, nil
- }
// Wire types for the OpenAI chat-completions API as used by
// callOpenAICompatible. Only the fields this file sends or reads are modeled.
type (
	// openAIChatRequest is the POST body: the model name and the ordered
	// conversation messages.
	openAIChatRequest struct {
		Model    string          `json:"model"`
		Messages []openAIMessage `json:"messages"`
	}

	// openAIMessage is a single chat message. Requests built in this file use
	// the roles "system" and "user"; responses carry whatever role the
	// provider returns.
	openAIMessage struct {
		Role    string `json:"role"`
		Content string `json:"content"`
	}

	// openAIChatResponse holds the generated choices and, when the provider
	// reports a failure, its error message. Either part may be absent.
	openAIChatResponse struct {
		Choices []struct {
			Message openAIMessage `json:"message"`
		} `json:"choices"`
		Error *struct {
			Message string `json:"message"`
		} `json:"error,omitempty"`
	}
)
- // connectionSummary holds non-sensitive fields for MCP responses.
- type connectionSummary struct {
- ID string `json:"id"`
- Name string `json:"name"`
- Description string `json:"description,omitempty"`
- Kind string `json:"kind"`
- Type string `json:"type,omitempty"`
- Version string `json:"version,omitempty"`
- Server string `json:"server,omitempty"`
- Port int `json:"port,omitempty"`
- Database string `json:"database,omitempty"`
- Color string `json:"color,omitempty"`
- AutoConnect bool `json:"auto_connect"`
- }
- func summarizeConnection(c types.ConnectionWithDetails) connectionSummary {
- summary := connectionSummary{
- ID: c.ID,
- Name: c.Name,
- Description: c.Description,
- Kind: c.Kind,
- Color: c.Color,
- AutoConnect: c.AutoConnect,
- }
- if c.DBDetail != nil {
- summary.Type = c.DBDetail.Type
- summary.Version = c.DBDetail.Version
- summary.Server = c.DBDetail.Server
- summary.Port = c.DBDetail.Port
- summary.Database = c.DBDetail.DatabaseName
- }
- if c.ServerDetail != nil {
- summary.Type = c.ServerDetail.Type
- summary.Version = c.ServerDetail.Version
- summary.Server = c.ServerDetail.Server
- summary.Port = c.ServerDetail.Port
- }
- return summary
- }
|