config.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396
  1. package config
  2. import (
  3. "fmt"
  4. "os"
  5. "path/filepath"
  6. "github.com/pelletier/go-toml/v2"
  7. "go.uber.org/zap/zapcore"
  8. )
  9. // AppConfig 应用配置结构
  10. type AppConfig struct {
  11. Server ServerConfig `toml:"server"`
  12. Log LogConfig `toml:"log"`
  13. Audit AuditConfig `toml:"audit"`
  14. Storage StorageConfig `toml:"storage"`
  15. MCP MCPConfig `toml:"mcp"`
  16. AI AIConfig `toml:"ai"`
  17. }
  18. // ServerConfig 服务器配置
  19. type ServerConfig struct {
  20. Ip string `toml:"ip"`
  21. Port int `toml:"port"`
  22. }
  23. // LogConfig 日志配置
  24. type LogConfig struct {
  25. Level string `toml:"level"` // 日志级别: debug, info, warn, error, fatal
  26. Development bool `toml:"development"` // 开发模式
  27. Encoding string `toml:"encoding"` // 编码格式: json, console
  28. OutputPaths []string `toml:"output_paths"` // 输出路径
  29. LogDir string `toml:"log_dir"` // 日志目录
  30. }
  31. // AuditConfig 审计配置
  32. type AuditConfig struct {
  33. Enabled bool `toml:"enabled"` // 是否启用审计
  34. DatabasePath string `toml:"database_path"` // 数据库路径
  35. RetentionDays int `toml:"retention_days"` // 保留天数
  36. BufferSize int `toml:"buffer_size"` // 缓冲区大小
  37. }
  38. // StorageConfig 存储配置
  39. type StorageConfig struct {
  40. // 存储类型: "file" 或 "db"
  41. Type string `toml:"type"`
  42. // 文件存储配置
  43. ConfigPath string `toml:"config_path"` // 配置文件路径 (用于文件存储)
  44. SQLBaseDir string `toml:"sql_base_dir"` // 脚本(Script)基础目录 (用于文件存储)
  45. // 数据库存储配置
  46. DatabasePath string `toml:"database_path"` // 数据库文件路径 (用于数据库存储)
  47. // 通用配置
  48. DataFile string `toml:"data_file"` // 数据文件路径(兼容旧配置 key `data_file`,推荐使用 `config_path`)
  49. }
  50. // MCPConfig MCP 组件配置
  51. type MCPConfig struct {
  52. Enable bool `toml:"enable"`
  53. ServerName string `toml:"server_name"`
  54. ServerVersion string `toml:"server_version"`
  55. }
  56. // AIConfig AI 交互配置
  57. type AIConfig struct {
  58. Enable bool `toml:"enable"`
  59. Provider string `toml:"provider"` // 目前支持 openai
  60. Model string `toml:"model"` // 默认模型
  61. APIKeyEnv string `toml:"api_key_env"` // 从哪个环境变量读取 API Key
  62. }
  63. // LoadConfig 加载配置文件
  64. func LoadConfig(path string) (*AppConfig, error) {
  65. // 检查文件是否存在
  66. if _, err := os.Stat(path); os.IsNotExist(err) {
  67. // 如果配置文件不存在,创建默认配置
  68. if err := CreateDefaultConfig(path); err != nil {
  69. return nil, fmt.Errorf("创建默认配置文件失败: %w", err)
  70. }
  71. }
  72. // 使用 os.ReadFile (Go 1.16+)
  73. data, err := os.ReadFile(path)
  74. if err != nil {
  75. return nil, err
  76. }
  77. var cfg AppConfig
  78. if err := toml.Unmarshal(data, &cfg); err != nil {
  79. return nil, err
  80. }
  81. return &cfg, nil
  82. }
  83. // CreateDefaultConfig 创建默认配置文件
  84. func CreateDefaultConfig(configPath string) error {
  85. // 创建默认配置
  86. defaultConfig := &AppConfig{
  87. Server: ServerConfig{
  88. Ip: "127.0.0.1",
  89. Port: 8080,
  90. },
  91. Log: LogConfig{
  92. Level: "info",
  93. Development: false,
  94. Encoding: "json",
  95. OutputPaths: []string{"stdout", "app.log"},
  96. LogDir: "./DBconfig/logs",
  97. },
  98. Audit: AuditConfig{
  99. Enabled: true,
  100. DatabasePath: "./DBconfig/audit.db",
  101. RetentionDays: 90,
  102. BufferSize: 1000,
  103. },
  104. Storage: StorageConfig{
  105. Type: "db",
  106. ConfigPath: "./DBconfig/data.toml",
  107. SQLBaseDir: "./DBconfig/sqlfiles",
  108. DatabasePath: "./DBconfig/storage.db",
  109. DataFile: "./DBconfig/data.toml",
  110. },
  111. MCP: MCPConfig{
  112. Enable: false,
  113. ServerName: "dbview-mcp",
  114. ServerVersion: "0.1.0",
  115. },
  116. AI: AIConfig{
  117. Enable: false,
  118. Provider: "openai",
  119. Model: "gpt-4o-mini",
  120. APIKeyEnv: "OPENAI_API_KEY",
  121. },
  122. }
  123. // 确保配置目录存在
  124. configDir := filepath.Dir(configPath)
  125. if err := os.MkdirAll(configDir, 0755); err != nil {
  126. return fmt.Errorf("创建配置目录失败: %w", err)
  127. }
  128. // 创建必要的目录
  129. if err := defaultConfig.CreateDirectories(); err != nil {
  130. return fmt.Errorf("创建必要目录失败: %w", err)
  131. }
  132. // 将配置序列化为TOML格式
  133. data, err := toml.Marshal(defaultConfig)
  134. if err != nil {
  135. return fmt.Errorf("序列化配置失败: %w", err)
  136. }
  137. // 写入配置文件
  138. if err := os.WriteFile(configPath, data, 0644); err != nil {
  139. return fmt.Errorf("写入配置文件失败: %w", err)
  140. }
  141. return nil
  142. }
  143. // LoadConfigFromFile 从文件句柄加载配置
  144. func LoadConfigFromFile(file *os.File) (*AppConfig, error) {
  145. var cfg AppConfig
  146. if err := toml.NewDecoder(file).Decode(&cfg); err != nil {
  147. return nil, err
  148. }
  149. return &cfg, nil
  150. }
  151. // SetDefaults 设置默认值
  152. func (cfg *AppConfig) SetDefaults() {
  153. // 服务器默认值
  154. if cfg.Server.Port == 0 {
  155. cfg.Server.Port = 8080
  156. }
  157. // 日志默认值
  158. if cfg.Log.Level == "" {
  159. cfg.Log.Level = "info"
  160. }
  161. if cfg.Log.Encoding == "" {
  162. cfg.Log.Encoding = "json"
  163. }
  164. if len(cfg.Log.OutputPaths) == 0 {
  165. cfg.Log.OutputPaths = []string{"stdout"}
  166. }
  167. if cfg.Log.LogDir == "" {
  168. cfg.Log.LogDir = "./DBconfig/logs"
  169. }
  170. // 审计默认值
  171. if cfg.Audit.DatabasePath == "" {
  172. cfg.Audit.DatabasePath = "./DBconfig/audit.db"
  173. }
  174. if cfg.Audit.RetentionDays == 0 {
  175. cfg.Audit.RetentionDays = 90
  176. }
  177. if cfg.Audit.BufferSize == 0 {
  178. cfg.Audit.BufferSize = 1000
  179. }
  180. // 注意:Audit.Enabled 默认为 false,需要明确启用
  181. // 存储默认值
  182. if cfg.Storage.Type == "" {
  183. cfg.Storage.Type = "db" // 默认使用数据库存储
  184. }
  185. if cfg.Storage.ConfigPath == "" {
  186. cfg.Storage.ConfigPath = "./DBconfig/data.toml"
  187. }
  188. if cfg.Storage.SQLBaseDir == "" {
  189. cfg.Storage.SQLBaseDir = "./DBconfig/sqlfiles"
  190. }
  191. if cfg.Storage.DatabasePath == "" {
  192. cfg.Storage.DatabasePath = "./DBconfig/storage.db"
  193. }
  194. if cfg.Storage.DataFile == "" {
  195. cfg.Storage.DataFile = "./DBconfig/data.toml" // 兼容旧配置 key `data_file`(推荐使用 `config_path`)
  196. }
  197. // MCP 默认值
  198. if cfg.MCP.ServerName == "" {
  199. cfg.MCP.ServerName = "dbview-mcp"
  200. }
  201. if cfg.MCP.ServerVersion == "" {
  202. cfg.MCP.ServerVersion = "0.1.0"
  203. }
  204. // AI 默认值
  205. if cfg.AI.Provider == "" {
  206. cfg.AI.Provider = "openai"
  207. }
  208. if cfg.AI.Model == "" {
  209. cfg.AI.Model = "gpt-4o-mini"
  210. }
  211. if cfg.AI.APIKeyEnv == "" {
  212. cfg.AI.APIKeyEnv = "OPENAI_API_KEY"
  213. }
  214. }
  215. // GetLogLevel 将字符串日志级别转换为 zapcore.Level
  216. func (lc *LogConfig) GetLogLevel() zapcore.Level {
  217. switch lc.Level {
  218. case "debug":
  219. return zapcore.DebugLevel
  220. case "info":
  221. return zapcore.InfoLevel
  222. case "warn", "warning":
  223. return zapcore.WarnLevel
  224. case "error":
  225. return zapcore.ErrorLevel
  226. case "fatal":
  227. return zapcore.FatalLevel
  228. default:
  229. return zapcore.InfoLevel
  230. }
  231. }
  232. // GetOutputPaths 获取完整的输出路径列表
  233. func (lc *LogConfig) GetOutputPaths() []string {
  234. paths := make([]string, 0, len(lc.OutputPaths))
  235. for _, path := range lc.OutputPaths {
  236. if path == "stdout" {
  237. paths = append(paths, "stdout")
  238. } else if path == "stderr" {
  239. paths = append(paths, "stderr")
  240. } else {
  241. // 如果是相对路径,转换为绝对路径
  242. if !filepath.IsAbs(path) {
  243. absPath := filepath.Join(lc.LogDir, path)
  244. paths = append(paths, absPath)
  245. } else {
  246. paths = append(paths, path)
  247. }
  248. }
  249. }
  250. return paths
  251. }
  252. // Validate 验证配置的有效性
  253. func (cfg *AppConfig) Validate() error {
  254. // 验证服务器配置
  255. if cfg.Server.Port <= 0 || cfg.Server.Port > 65535 {
  256. return fmt.Errorf("invalid server port: %d", cfg.Server.Port)
  257. }
  258. // 验证日志配置
  259. validLevels := map[string]bool{
  260. "debug": true, "info": true, "warn": true, "warning": true,
  261. "error": true, "fatal": true,
  262. }
  263. if !validLevels[cfg.Log.Level] {
  264. return fmt.Errorf("invalid log level: %s", cfg.Log.Level)
  265. }
  266. validEncodings := map[string]bool{
  267. "json": true, "console": true,
  268. }
  269. if !validEncodings[cfg.Log.Encoding] {
  270. return fmt.Errorf("invalid log encoding: %s", cfg.Log.Encoding)
  271. }
  272. // 验证审计配置
  273. if cfg.Audit.RetentionDays < 0 {
  274. return fmt.Errorf("invalid audit retention days: %d", cfg.Audit.RetentionDays)
  275. }
  276. if cfg.Audit.BufferSize < 0 {
  277. return fmt.Errorf("invalid audit buffer size: %d", cfg.Audit.BufferSize)
  278. }
  279. // 验证存储配置
  280. validStorageTypes := map[string]bool{
  281. "file": true, "db": true,
  282. }
  283. if !validStorageTypes[cfg.Storage.Type] {
  284. return fmt.Errorf("invalid storage type: %s", cfg.Storage.Type)
  285. }
  286. // 验证 AI 配置(仅在启用时检查)
  287. if cfg.AI.Enable {
  288. validProviders := map[string]bool{"openai": true}
  289. if !validProviders[cfg.AI.Provider] {
  290. return fmt.Errorf("invalid ai provider: %s", cfg.AI.Provider)
  291. }
  292. if cfg.AI.Model == "" {
  293. return fmt.Errorf("ai model required when ai enabled")
  294. }
  295. if cfg.AI.APIKeyEnv == "" {
  296. return fmt.Errorf("ai api_key_env required when ai enabled")
  297. }
  298. }
  299. return nil
  300. }
  301. // CreateDirectories 创建必要的目录
  302. func (cfg *AppConfig) CreateDirectories() error {
  303. dirs := []string{
  304. cfg.Log.LogDir,
  305. filepath.Dir(cfg.Audit.DatabasePath),
  306. cfg.Storage.SQLBaseDir,
  307. filepath.Dir(cfg.Storage.DataFile),
  308. }
  309. // 根据存储类型添加额外的目录
  310. if cfg.Storage.Type == "db" {
  311. dirs = append(dirs, filepath.Dir(cfg.Storage.DatabasePath))
  312. }
  313. for _, dir := range dirs {
  314. if dir == "" {
  315. continue
  316. }
  317. if err := os.MkdirAll(dir, 0755); err != nil {
  318. return fmt.Errorf("failed to create directory %s: %w", dir, err)
  319. }
  320. }
  321. return nil
  322. }
  323. // GetStorageManager 根据配置创建存储管理器
  324. func (sc *StorageConfig) GetStorageManager() (interface{}, error) {
  325. switch sc.Type {
  326. case "file":
  327. // 这里需要导入storage包,但为了避免循环依赖,我们返回配置信息
  328. return map[string]string{
  329. "type": "file",
  330. "config_path": sc.ConfigPath,
  331. "sql_base_dir": sc.SQLBaseDir,
  332. }, nil
  333. case "db":
  334. return map[string]string{
  335. "type": "db",
  336. "database_path": sc.DatabasePath,
  337. }, nil
  338. default:
  339. return nil, fmt.Errorf("unsupported storage type: %s", sc.Type)
  340. }
  341. }
  342. // IsFileStorage 是否使用文件存储
  343. func (sc *StorageConfig) IsFileStorage() bool {
  344. return sc.Type == "file"
  345. }
  346. // IsDBStorage 是否使用数据库存储
  347. func (sc *StorageConfig) IsDBStorage() bool {
  348. return sc.Type == "db"
  349. }