| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178 |
- package service
- import (
- "bytes"
- "context"
- "encoding/json"
- "fmt"
- "io"
- "net/http"
- "strings"
- "dbview/service/internal/common/manager/storage/types"
- )
- // ProviderClient 抽象不同 AI 提供商的调用方式
- type ProviderClient interface {
- Chat(ctx context.Context, hc *http.Client, cfg types.AISettings, model, system, prompt string) (string, error)
- }
- // provider registry
- var providerRegistry = map[string]ProviderClient{}
- func registerProvider(name string, c ProviderClient) {
- providerRegistry[strings.ToLower(strings.TrimSpace(name))] = c
- }
- func getProviderClient(name string) ProviderClient {
- return providerRegistry[strings.ToLower(strings.TrimSpace(name))]
- }
- // OpenAIProvider 实现 OpenAI 兼容的 Chat 调用,但读取 cfg 中的 auth/header 设定
- type OpenAIProvider struct{}
- func (p *OpenAIProvider) Chat(ctx context.Context, hc *http.Client, cfg types.AISettings, model, system, prompt string) (string, error) {
- messages := []openAIMessage{}
- if system != "" {
- messages = append(messages, openAIMessage{Role: "system", Content: system})
- }
- messages = append(messages, openAIMessage{Role: "user", Content: prompt})
- body := openAIChatRequest{Model: model, Messages: messages}
- b, err := json.Marshal(body)
- if err != nil {
- return "", err
- }
- req, err := http.NewRequestWithContext(ctx, http.MethodPost, cfg.BaseURL, bytes.NewReader(b))
- if err != nil {
- return "", err
- }
- headerName := cfg.AuthHeader
- if headerName == "" {
- headerName = "Authorization"
- }
- if cfg.AuthScheme != "" {
- req.Header.Set(headerName, cfg.AuthScheme+" "+cfg.APIKey)
- } else {
- req.Header.Set(headerName, cfg.APIKey)
- }
- req.Header.Set("Content-Type", "application/json")
- // extra headers JSON
- if cfg.ExtraHeaders != "" {
- var extra map[string]string
- if err := json.Unmarshal([]byte(cfg.ExtraHeaders), &extra); err == nil {
- for k, v := range extra {
- if k != "Content-Type" {
- req.Header.Set(k, v)
- }
- }
- }
- }
- resp, err := hc.Do(req)
- if err != nil {
- return "", err
- }
- defer resp.Body.Close()
- raw, err := io.ReadAll(resp.Body)
- if err != nil {
- return "", err
- }
- var out openAIChatResponse
- _ = json.Unmarshal(raw, &out)
- if resp.StatusCode >= 300 {
- if out.Error != nil && out.Error.Message != "" {
- return "", fmt.Errorf("%s error: %s (status=%d)", cfg.RequestType, out.Error.Message, resp.StatusCode)
- }
- snippet := string(raw)
- if len(snippet) > 500 {
- snippet = snippet[:500] + "..."
- }
- return "", fmt.Errorf("%s request failed: status=%d, body=%s", cfg.RequestType, resp.StatusCode, snippet)
- }
- if len(out.Choices) == 0 {
- return "", fmt.Errorf("%s 返回为空(status=%d)", cfg.RequestType, resp.StatusCode)
- }
- return out.Choices[0].Message.Content, nil
- }
- // MimoProvider 简单实现(当作示例):使用 body {"input": <prompt>},并支持自定义 header 名称/无 scheme
- type MimoProvider struct{}
- func (p *MimoProvider) Chat(ctx context.Context, hc *http.Client, cfg types.AISettings, model, system, prompt string) (string, error) {
- // 简单 body:{"input":"..."}
- payload := map[string]any{"input": prompt}
- b, err := json.Marshal(payload)
- if err != nil {
- return "", err
- }
- req, err := http.NewRequestWithContext(ctx, http.MethodPost, cfg.BaseURL, bytes.NewReader(b))
- if err != nil {
- return "", err
- }
- headerName := cfg.AuthHeader
- if headerName == "" {
- headerName = "x-api-key"
- }
- if cfg.AuthScheme != "" {
- req.Header.Set(headerName, cfg.AuthScheme+" "+cfg.APIKey)
- } else {
- req.Header.Set(headerName, cfg.APIKey)
- }
- req.Header.Set("Content-Type", "application/json")
- if cfg.ExtraHeaders != "" {
- var extra map[string]string
- if err := json.Unmarshal([]byte(cfg.ExtraHeaders), &extra); err == nil {
- for k, v := range extra {
- if k != "Content-Type" {
- req.Header.Set(k, v)
- }
- }
- }
- }
- resp, err := hc.Do(req)
- if err != nil {
- return "", err
- }
- defer resp.Body.Close()
- raw, err := io.ReadAll(resp.Body)
- if err != nil {
- return "", err
- }
- // 尝试解析常见字段,否则返回原始 body 文本
- var parsed map[string]any
- if err := json.Unmarshal(raw, &parsed); err == nil {
- // 常见字段:output / answer / data
- if v, ok := parsed["output"]; ok {
- return fmt.Sprint(v), nil
- }
- if v, ok := parsed["answer"]; ok {
- return fmt.Sprint(v), nil
- }
- if v, ok := parsed["data"]; ok {
- return fmt.Sprint(v), nil
- }
- }
- // fallback to raw string
- return strings.TrimSpace(string(raw)), nil
- }
- func init() {
- registerProvider("openai", &OpenAIProvider{})
- registerProvider("mimo", &MimoProvider{})
- }
|