config.go 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. package config
  2. import (
  3. "fmt"
  4. "regexp"
  5. "github.com/spf13/viper"
  6. )
  7. /*
  8. Config 配置结构体,包含导入和导出相关的配置
  9. */
  10. type Config struct {
  11. ImportConfig ImportConfig // 导入配置
  12. ExportConfig ExportConfig // 导出配置
  13. }
  14. /*
  15. ImportConfig 导入配置结构体
  16. 包含数据库配置、导入目录、分隔符、分页大小以及文件名匹配模式等配置
  17. */
  18. type ImportConfig struct {
  19. DBConfig DBConfig // 数据库配置
  20. Delimiter rune // 导入文件的分隔符
  21. PageSize int // 分页读取的大小
  22. FilenamePattern *regexp.Regexp // 文件名匹配的正则表达式
  23. }
  24. /*
  25. ExportConfig 导出配置结构体
  26. 包含数据库配置、导出表名、文件路径、分隔符、文件大小上限以及分页大小等配置
  27. */
  28. type ExportConfig struct {
  29. DBConfig DBConfig // 数据库配置
  30. BaseFilePath string // 导出文件的基本路径
  31. Delimiter rune // CSV 的分隔符
  32. MaxFileSize int64 // 单个文件的最大大小 (字节)
  33. PageSize int // 每次分页读取的行数
  34. }
  35. /*
  36. DBConfig 数据库配置结构体
  37. 包含数据库类型、名称、用户名、密码、IP 地址和端口等信息
  38. */
  39. type DBConfig struct {
  40. DBType string // 数据库类型 (如 MySQL, PostgreSQL 等)
  41. DBName string // 数据库名称
  42. User string // 数据库用户名
  43. Password string // 数据库密码
  44. IP string // 数据库 IP 地址
  45. Port int // 数据库端口号
  46. }
  47. func LoadConfig() (*Config, error) {
  48. viper.SetConfigName("config") // 配置文件名(不带扩展名)
  49. viper.SetConfigType("toml") // 配置文件类型
  50. viper.AddConfigPath(".") // 查找配置文件的路径
  51. if err := viper.ReadInConfig(); err != nil {
  52. return nil, fmt.Errorf("error reading config file: %v", err)
  53. }
  54. // 解析 Import 配置
  55. importPattern, err := regexp.Compile(viper.GetString("import.cfg.filename_pattern"))
  56. if err != nil {
  57. return nil, fmt.Errorf("error compiling filename pattern: %v", err)
  58. }
  59. importConfig := ImportConfig{
  60. DBConfig: DBConfig{
  61. DBType: viper.GetString("import.db_type"),
  62. DBName: viper.GetString("import.dbname"),
  63. User: viper.GetString("import.user"),
  64. Password: viper.GetString("import.password"),
  65. IP: viper.GetString("import.ip"),
  66. Port: viper.GetInt("import.port"),
  67. },
  68. Delimiter: rune(viper.GetString("import.cfg.delimiter")[0]),
  69. PageSize: viper.GetInt("import.cfg.page_size"),
  70. FilenamePattern: importPattern,
  71. }
  72. // 解析 Export 配置
  73. exportConfig := ExportConfig{
  74. DBConfig: DBConfig{
  75. DBType: viper.GetString("export.db_type"),
  76. DBName: viper.GetString("export.dbname"),
  77. User: viper.GetString("export.user"),
  78. Password: viper.GetString("export.password"),
  79. IP: viper.GetString("export.ip"),
  80. Port: viper.GetInt("export.port"),
  81. },
  82. BaseFilePath: viper.GetString("export.cfg.base_file_path"),
  83. Delimiter: rune(viper.GetString("export.cfg.delimiter")[0]),
  84. MaxFileSize: viper.GetInt64("export.cfg.max_file_size"),
  85. PageSize: viper.GetInt("export.cfg.page_size"),
  86. }
  87. return &Config{
  88. ImportConfig: importConfig,
  89. ExportConfig: exportConfig,
  90. }, nil
  91. }
  92. func (dbConfig *DBConfig) GetDSN() string {
  93. switch dbConfig.DBType {
  94. case "mysql":
  95. return fmt.Sprintf("%s:%s@tcp(%s:%d)/%s", dbConfig.User, dbConfig.Password, dbConfig.IP, dbConfig.Port, dbConfig.DBName)
  96. case "xugu":
  97. return fmt.Sprintf("%s:%s@tcp(%s:%d)/%s", dbConfig.User, dbConfig.Password, dbConfig.IP, dbConfig.Port, dbConfig.DBName)
  98. }
  99. return ""
  100. }