package services import ( "encoding/csv" "fmt" "os" "path/filepath" "runtime" "sort" "sync" "time" "regexp" "xg_fetl/internal/db_executor" "xg_fetl/internal/models" ) // ReaderMain 主函数,参数化配置,使其更加灵活 func ReaderMain( dbName string, // 数据库名称 dirPath string, // CSV 文件所在的根目录 tableInfos map[string]models.TableInfo, // 要处理的表信息 delimiter rune, // CSV 文件的分隔符 pageSize int, // 分页大小,每次读取的记录数 mysqeExec db_executor.DBExecutor, // 数据库执行器接口 ) { // 创建用于接收每个表处理结果的通道 resultChan := make(chan TableResult) // 创建等待组,用于等待所有表的处理完成 var wg sync.WaitGroup // 遍历每个表进行处理,使用协程 for _, tableName := range sortedTableNames(tableInfos) { wg.Add(1) go func(tableName string) { defer wg.Done() // 开始处理表,记录开始时间 startTime := time.Now() // 初始化 TableResult tableResult := TableResult{ TableName: tableName, Success: false, GoroutineCount: 1, // 只使用一个协程 Logs: []string{}, } // 构建表名对应的文件夹路径,例如 dirPath/tableName tableDirPath := filepath.Join(dirPath, tableName) // 检查表名对应的文件夹是否存在 if _, err := os.Stat(tableDirPath); os.IsNotExist(err) { errMsg := fmt.Sprintf("表 %s 对应的目录不存在: %v", tableName, err) tableResult.Error = fmt.Errorf(errMsg) tableResult.Logs = append(tableResult.Logs, errMsg) resultChan <- tableResult return } // 编译匹配 CSV 文件名的正则表达式,格式为 表名_[序号].csv filenamePattern := fmt.Sprintf(`^%s_\d+\.csv$`, regexp.QuoteMeta(tableName)) re, err := regexp.Compile(filenamePattern) if err != nil { errMsg := fmt.Sprintf("无法编译文件名正则表达式: %v", err) tableResult.Error = fmt.Errorf(errMsg) tableResult.Logs = append(tableResult.Logs, errMsg) resultChan <- tableResult return } // 获取表目录下的所有 CSV 文件路径 files, err := GetCSVFiles(tableDirPath) if err != nil { errMsg := fmt.Sprintf("无法获取 CSV 文件: %v", err) tableResult.Error = fmt.Errorf(errMsg) tableResult.Logs = append(tableResult.Logs, errMsg) resultChan <- tableResult return } // 过滤出符合命名规则的文件列表,并按文件名排序 var matchedFiles []string for _, filePath := range files { filename := filepath.Base(filePath) if re.MatchString(filename) { matchedFiles = append(matchedFiles, filePath) } } if len(matchedFiles) == 0 { errMsg := fmt.Sprintf("表 %s 没有找到符合条件的 CSV 文件", tableName) tableResult.Error = fmt.Errorf(errMsg) tableResult.Logs = append(tableResult.Logs, errMsg) resultChan <- tableResult return } // 按文件名排序,确保顺序一致 sort.Strings(matchedFiles) // 估算可用内存,决定一次加载多少个文件 availableMemory := getAvailableMemory() var batchFiles []string var totalSize int64 = 0 fmt.Println("可用内存:", availableMemory) for idx, filePath := range matchedFiles { fileInfo, err := os.Stat(filePath) if err != nil { tableResult.ErrorCount++ logMsg := fmt.Sprintf("无法获取文件信息 %s: %v", filePath, err) tableResult.Logs = append(tableResult.Logs, logMsg) continue } fileSize := fileInfo.Size() // 判断是否超过可用内存 if totalSize+fileSize > availableMemory && len(batchFiles) > 0 { // 处理当前批次文件 fmt.Println("处理当前批次文件", batchFiles) err = processCSVFiles(batchFiles, delimiter, mysqeExec, tableName, &tableResult) if err != nil { tableResult.ErrorCount++ logMsg := fmt.Sprintf("处理文件批次时出错: %v", err) tableResult.Logs = append(tableResult.Logs, logMsg) } // 重置批次 batchFiles = []string{} fmt.Println("重置批次") totalSize = 0 } batchFiles = append(batchFiles, filePath) totalSize += fileSize // 如果是最后一个文件,处理剩余的批次 if idx == len(matchedFiles)-1 && len(batchFiles) > 0 { fmt.Println(",处理剩余的批次 ", batchFiles) err = processCSVFiles(batchFiles, delimiter, mysqeExec, tableName, &tableResult) if err != nil { tableResult.ErrorCount++ logMsg := fmt.Sprintf("处理文件批次时出错: %v", err) tableResult.Logs = append(tableResult.Logs, logMsg) } } } // 计算导出持续时间 tableResult.ExportDuration = time.Since(startTime) tableResult.Success = tableResult.ErrorCount == 0 resultChan <- tableResult }(tableName) } // 开启一个协程来收集结果并打印 go func() { wg.Wait() close(resultChan) }() // 打印结果 for result := range resultChan { if result.Success { fmt.Printf("表 %s 导入成功,耗时 %v,共插入 %d 行数据。\n", result.TableName, result.ExportDuration, result.TotalRows) } else { fmt.Printf("表 %s 导入完成,存在错误,错误数量 %d,耗时 %v。\n", result.TableName, result.ErrorCount, result.ExportDuration) } // 打印详细日志 for _, logMsg := range result.Logs { fmt.Println(logMsg) } } } // processCSVFiles 处理一批 CSV 文件 func processCSVFiles( files []string, delimiter rune, mysqeExec db_executor.DBExecutor, tableName string, tableResult *TableResult, ) error { for _, filePath := range files { logMsg := fmt.Sprintf("开始读入文件: %s", filePath) fmt.Println(logMsg) tableResult.Logs = append(tableResult.Logs, logMsg) // 读取整个 CSV 文件到内存中 headers, records, err := ReadEntireCSV(filePath, delimiter) if err != nil { tableResult.ErrorCount++ logMsg = fmt.Sprintf("无法读取 CSV 文件 %s: %v", filePath, err) fmt.Println(logMsg) tableResult.Logs = append(tableResult.Logs, logMsg) continue } // 更新总行数和平均行大小 tableResult.TotalRows += len(records) if len(records) > 0 { fileInfo, statErr := os.Stat(filePath) if statErr == nil && fileInfo.Size() > 0 { rowSize := fileInfo.Size() / int64(len(records)) // 计算平均行大小 if tableResult.AverageRowSize == 0 { tableResult.AverageRowSize = rowSize } else { tableResult.AverageRowSize = (tableResult.AverageRowSize + rowSize) / 2 } } } // 插入数据到数据库 err = mysqeExec.InsertRecordsToDB(tableName, headers, records) if err != nil { tableResult.ErrorCount++ logMsg = fmt.Sprintf("无法将记录插入数据库: %v", err) fmt.Println(logMsg) tableResult.Logs = append(tableResult.Logs, logMsg) continue } logMsg = fmt.Sprintf("完成导入文件: %s,共插入 %d 行数据。", filePath, len(records)) fmt.Println(logMsg) tableResult.Logs = append(tableResult.Logs, logMsg) } return nil } // ReadEntireCSV 读取整个 CSV 文件到内存中 func ReadEntireCSV(filePath string, delimiter rune) ([]string, [][]string, error) { // 打开 CSV 文件 file, err := os.Open(filePath) if err != nil { return nil, nil, fmt.Errorf("无法打开文件: %v", err) } defer file.Close() // 创建 CSV Reader,设置分隔符 reader := csv.NewReader(file) reader.Comma = delimiter // 读取所有数据 records, err := reader.ReadAll() if err != nil { return nil, nil, fmt.Errorf("无法读取 CSV 文件: %v", err) } if len(records) == 0 { return nil, nil, fmt.Errorf("CSV 文件为空") } // 第一个记录是表头 headers := records[0] dataRecords := records[1:] // 跳过表头 return headers, dataRecords, nil } // GetCSVFiles 获取指定目录下的所有 CSV 文件路径 func GetCSVFiles(dirPath string) ([]string, error) { var files []string // 遍历目录,查找所有 .csv 文件 entries, err := os.ReadDir(dirPath) if err != nil { return nil, fmt.Errorf("无法读取目录 %s: %v", dirPath, err) } for _, entry := range entries { if !entry.IsDir() && filepath.Ext(entry.Name()) == ".csv" { files = append(files, filepath.Join(dirPath, entry.Name())) } } return files, nil } // getAvailableMemory 获取可用的系统内存 func getAvailableMemory() int64 { var m runtime.MemStats runtime.ReadMemStats(&m) // 假设使用可用内存的 80% availableMemory := int64(m.Sys - m.Alloc) return availableMemory * 80 / 100 }