package dbBase

import (
	"C"
	"database/sql"
	"fmt"
	"log"
	"os"
	"strings"
	_ "xgAutoTest/pkg/go-driver-xugusql"
)
import "time"

// db is the package-level connection pool. It is assigned by InitDb and
// returned by GetDb; it stays nil until InitDb has run.
var db *sql.DB

// InitDb opens the package-level connection pool with the "xugusql" driver and
// verifies connectivity with a Ping.
//
// The connection string is built from ip, dbBase, user, pwd and port, with
// auto-commit enabled and a UTF8 character set. Any failure terminates the
// process via log.Fatalf, and the underlying driver error is included so the
// cause is visible. The DSN itself is never logged because it contains the
// password.
func InitDb(ip string, port string, dbBase string, user string, pwd string) {
	dbLink := fmt.Sprintf("IP=%s;DB=%s;User=%s;PWD=%s;Port=%s;AUTO_COMMIT=on;CHAR_SET=UTF8", ip, dbBase, user, pwd, port)

	// Plain assignment (not :=) so the package-level db is set, not a local.
	var err error
	db, err = sql.Open("xugusql", dbLink)
	if err != nil {
		log.Fatalf("db open fail: %v", err)
	}
	// sql.Open does not connect; Ping forces a real connection attempt.
	if err = db.Ping(); err != nil {
		log.Fatalf("db Ping fail: %v", err)
	}
	log.Printf("db Ping ok")
}

// GetDb hands back the package-level connection pool that InitDb assigns.
// The result is nil when InitDb has not been called yet.
func GetDb() (handle *sql.DB) {
	handle = db
	return handle
}

// QueryString runs a row-returning statement on db and renders the result set
// as plain text.
//
// The output has these parts, in order:
//   - an echo line ("SQL> <sql>") followed by a blank line;
//   - a header of column names, each followed by " | ";
//   - a dashed separator line;
//   - one line per row, with each value followed by " | ";
//   - a "Total N records." summary;
//   - the elapsed time in milliseconds, floored to 1 ms.
//
// Any query, scan or iteration error terminates the process via log.Fatal.
func QueryString(db *sql.DB, sql string) string {
	fmt.Println("执行查询sql: ", sql)
	// Timing covers both the query and the iteration over every row.
	start := time.Now()
	rows, err := db.Query(sql)
	if err != nil {
		log.Fatal(err)
	}
	defer rows.Close()

	cols, err := rows.Columns()
	if err != nil {
		log.Fatal(err)
	}

	var sb strings.Builder
	sb.WriteString("SQL> " + sql + "\n\n")
	for _, c := range cols {
		sb.WriteString(c + " | ")
	}
	sb.WriteString("\n------------------------------------------------------------------------------\n")

	// Every column is scanned into a []byte. Scan stores a fresh copy each
	// time, so one set of slots can be reused across rows.
	vals := make([][]byte, len(cols))
	dest := make([]interface{}, len(cols))
	for i := range vals {
		dest[i] = &vals[i]
	}

	records := 0
	for rows.Next() {
		if err := rows.Scan(dest...); err != nil {
			log.Fatal(err)
		}
		for _, v := range vals {
			// NULL scans to an empty slice. An empty non-NULL value is also
			// shown as <NULL>; this keeps the established output format.
			if len(v) == 0 {
				sb.WriteString("<NULL> | ")
			} else {
				sb.WriteString(string(v) + " | ")
			}
		}
		records++
		sb.WriteString("\n")
	}
	// rows.Next returns false on failure as well as at the end of the result
	// set; without this check a mid-stream error would look like fewer rows.
	if err := rows.Err(); err != nil {
		log.Fatal(err)
	}

	// Sub-millisecond queries are reported as 1 ms rather than 0.
	durationMs := time.Since(start).Milliseconds()
	if durationMs == 0 {
		durationMs = 1
	}

	fmt.Fprintf(&sb, "\nTotal %d records.\n", records)
	fmt.Fprintf(&sb, "\nUse time:%d ms.", durationMs)
	return sb.String()
}

// ExecString executes a statement that returns no rows (INSERT/UPDATE/DDL,
// etc.) on db and returns a textual report.
//
// The report contains three parts: an echo of the statement, the number of
// affected rows, and the elapsed execution time in milliseconds (at least
// 1 ms). An Exec or RowsAffected error terminates the process via log.Fatal.
func ExecString(db *sql.DB, sql string) string {
	fmt.Println("执行插入sql: ", sql)

	began := time.Now()
	res, execErr := db.Exec(sql)
	if execErr != nil {
		log.Fatal("Exe查询错误", execErr)
	}
	// Only the Exec call itself is timed.
	elapsedMs := time.Since(began).Milliseconds()

	affected, affErr := res.RowsAffected()
	if affErr != nil {
		log.Fatal(affErr)
	}

	// Sub-millisecond statements are reported as 1 ms rather than 0.
	if elapsedMs == 0 {
		elapsedMs = 1
	}

	return "SQL> " + sql + "\n\n" +
		fmt.Sprintf("Total %d records effected.\n", affected) +
		fmt.Sprintf("\nUse time:%d ms.", elapsedMs)
}

// ExecPrepareString reads every file listed in fileLocal and executes sql as a
// prepared statement. The file contents are bound, in order, as the
// statement's parameters.
//
// All spaces are removed from each path before it is read, so entries such as
// " a.jpg" are accepted. An unreadable file, a prepare error or an exec error
// terminates the process via log.Fatal.
//
// db1 is the connection pool the statement runs on. For backward
// compatibility, a nil db1 falls back to the package-level pool set by
// InitDb; previously db1 was ignored and that pool was always used.
func ExecPrepareString(db1 *sql.DB, sql string, fileLocal []string) {
	conn := db1
	if conn == nil {
		conn = db
	}

	fmt.Println("fileLocal的数量: ", len(fileLocal))
	args := make([]interface{}, 0, len(fileLocal))
	for _, path := range fileLocal {
		// Read the (image) file after stripping every space from its path.
		path = strings.ReplaceAll(path, " ", "")
		fmt.Println("文件地址:", path)
		imageData, err := os.ReadFile(path)
		if err != nil {
			log.Fatal(err)
		}
		// Bound as *[]byte. database/sql's default converter dereferences
		// pointers, but a driver-specific converter might treat the two forms
		// differently, so the pointer form is kept.
		args = append(args, &imageData)
	}

	fmt.Println("sql查询:", sql)
	stmt, err := conn.Prepare(sql)
	if err != nil {
		log.Fatal(err)
	}
	// Release the prepared statement once it has been executed.
	defer stmt.Close()

	if _, err := stmt.Exec(args...); err != nil {
		log.Fatal(err)
	}
}