providers.go 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. package service
  2. import (
  3. "bytes"
  4. "context"
  5. "encoding/json"
  6. "fmt"
  7. "io"
  8. "net/http"
  9. "strings"
  10. "dbview/service/internal/common/manager/storage/types"
  11. )
  12. // ProviderClient 抽象不同 AI 提供商的调用方式
  13. type ProviderClient interface {
  14. Chat(ctx context.Context, hc *http.Client, cfg types.AISettings, model, system, prompt string) (string, error)
  15. }
  16. // provider registry
  17. var providerRegistry = map[string]ProviderClient{}
  18. func registerProvider(name string, c ProviderClient) {
  19. providerRegistry[strings.ToLower(strings.TrimSpace(name))] = c
  20. }
  21. func getProviderClient(name string) ProviderClient {
  22. return providerRegistry[strings.ToLower(strings.TrimSpace(name))]
  23. }
  24. // OpenAIProvider 实现 OpenAI 兼容的 Chat 调用,但读取 cfg 中的 auth/header 设定
  25. type OpenAIProvider struct{}
  26. func (p *OpenAIProvider) Chat(ctx context.Context, hc *http.Client, cfg types.AISettings, model, system, prompt string) (string, error) {
  27. messages := []openAIMessage{}
  28. if system != "" {
  29. messages = append(messages, openAIMessage{Role: "system", Content: system})
  30. }
  31. messages = append(messages, openAIMessage{Role: "user", Content: prompt})
  32. body := openAIChatRequest{Model: model, Messages: messages}
  33. b, err := json.Marshal(body)
  34. if err != nil {
  35. return "", err
  36. }
  37. req, err := http.NewRequestWithContext(ctx, http.MethodPost, cfg.BaseURL, bytes.NewReader(b))
  38. if err != nil {
  39. return "", err
  40. }
  41. headerName := cfg.AuthHeader
  42. if headerName == "" {
  43. headerName = "Authorization"
  44. }
  45. if cfg.AuthScheme != "" {
  46. req.Header.Set(headerName, cfg.AuthScheme+" "+cfg.APIKey)
  47. } else {
  48. req.Header.Set(headerName, cfg.APIKey)
  49. }
  50. req.Header.Set("Content-Type", "application/json")
  51. // extra headers JSON
  52. if cfg.ExtraHeaders != "" {
  53. var extra map[string]string
  54. if err := json.Unmarshal([]byte(cfg.ExtraHeaders), &extra); err == nil {
  55. for k, v := range extra {
  56. if k != "Content-Type" {
  57. req.Header.Set(k, v)
  58. }
  59. }
  60. }
  61. }
  62. resp, err := hc.Do(req)
  63. if err != nil {
  64. return "", err
  65. }
  66. defer resp.Body.Close()
  67. raw, err := io.ReadAll(resp.Body)
  68. if err != nil {
  69. return "", err
  70. }
  71. var out openAIChatResponse
  72. _ = json.Unmarshal(raw, &out)
  73. if resp.StatusCode >= 300 {
  74. if out.Error != nil && out.Error.Message != "" {
  75. return "", fmt.Errorf("%s error: %s (status=%d)", cfg.RequestType, out.Error.Message, resp.StatusCode)
  76. }
  77. snippet := string(raw)
  78. if len(snippet) > 500 {
  79. snippet = snippet[:500] + "..."
  80. }
  81. return "", fmt.Errorf("%s request failed: status=%d, body=%s", cfg.RequestType, resp.StatusCode, snippet)
  82. }
  83. if len(out.Choices) == 0 {
  84. return "", fmt.Errorf("%s 返回为空(status=%d)", cfg.RequestType, resp.StatusCode)
  85. }
  86. return out.Choices[0].Message.Content, nil
  87. }
  88. // MimoProvider 简单实现(当作示例):使用 body {"input": <prompt>},并支持自定义 header 名称/无 scheme
  89. type MimoProvider struct{}
  90. func (p *MimoProvider) Chat(ctx context.Context, hc *http.Client, cfg types.AISettings, model, system, prompt string) (string, error) {
  91. // 简单 body:{"input":"..."}
  92. payload := map[string]any{"input": prompt}
  93. b, err := json.Marshal(payload)
  94. if err != nil {
  95. return "", err
  96. }
  97. req, err := http.NewRequestWithContext(ctx, http.MethodPost, cfg.BaseURL, bytes.NewReader(b))
  98. if err != nil {
  99. return "", err
  100. }
  101. headerName := cfg.AuthHeader
  102. if headerName == "" {
  103. headerName = "x-api-key"
  104. }
  105. if cfg.AuthScheme != "" {
  106. req.Header.Set(headerName, cfg.AuthScheme+" "+cfg.APIKey)
  107. } else {
  108. req.Header.Set(headerName, cfg.APIKey)
  109. }
  110. req.Header.Set("Content-Type", "application/json")
  111. if cfg.ExtraHeaders != "" {
  112. var extra map[string]string
  113. if err := json.Unmarshal([]byte(cfg.ExtraHeaders), &extra); err == nil {
  114. for k, v := range extra {
  115. if k != "Content-Type" {
  116. req.Header.Set(k, v)
  117. }
  118. }
  119. }
  120. }
  121. resp, err := hc.Do(req)
  122. if err != nil {
  123. return "", err
  124. }
  125. defer resp.Body.Close()
  126. raw, err := io.ReadAll(resp.Body)
  127. if err != nil {
  128. return "", err
  129. }
  130. // 尝试解析常见字段,否则返回原始 body 文本
  131. var parsed map[string]any
  132. if err := json.Unmarshal(raw, &parsed); err == nil {
  133. // 常见字段:output / answer / data
  134. if v, ok := parsed["output"]; ok {
  135. return fmt.Sprint(v), nil
  136. }
  137. if v, ok := parsed["answer"]; ok {
  138. return fmt.Sprint(v), nil
  139. }
  140. if v, ok := parsed["data"]; ok {
  141. return fmt.Sprint(v), nil
  142. }
  143. }
  144. // fallback to raw string
  145. return strings.TrimSpace(string(raw)), nil
  146. }
  147. func init() {
  148. registerProvider("openai", &OpenAIProvider{})
  149. registerProvider("mimo", &MimoProvider{})
  150. }