diff --git a/AI.md b/AI.md new file mode 100644 index 0000000..bdd3708 --- /dev/null +++ b/AI.md @@ -0,0 +1,16 @@ +# AI 指南 - @go/db + +## 🤖 AI 调用规则 +- **版本**: v1.0.1 +- **核心原则**: 优先使用结构化绑定(`To`, `MapResults`),避免手动拼装 SQL 结果。 +- **敏感数据**: 必须通过 `SetEncryptKeys` 配置密钥,确保 DSN 中的密码安全。 +- **读写分离**: 鼓励在 DSN 中配置多个 Host 以利用内置的读写分离机制。 +- **性能优化**: + - 大规模查询应优先绑定到 Struct 切片。 + - 频繁执行的 SQL 应使用 `Prepare`。 +- **事务处理**: 始终使用 `tx.Finish(err == nil)` 或 `defer tx.CheckFinished()` 确保事务闭环。 + +## ⚠️ 注意事项 +- 严禁在代码中硬编码数据库凭据。 +- 严禁忽略 `Exec` 或 `Query` 返回的 `Error`。 +- SQLite 模式下,时间字段会自动转换,无需手动解析字符串。 diff --git a/Base.go b/Base.go new file mode 100644 index 0000000..3fce087 --- /dev/null +++ b/Base.go @@ -0,0 +1,213 @@ +package db + +import ( + "database/sql" + "errors" + "fmt" + "reflect" + "strings" + "time" + + "apigo.cc/go/cast" + "apigo.cc/go/log" +) + +func basePrepare(db *sql.DB, tx *sql.Tx, query string) *Stmt { + var sqlStmt *sql.Stmt + var err error + if tx != nil { + sqlStmt, err = tx.Prepare(query) + } else if db != nil { + sqlStmt, err = db.Prepare(query) + } else { + return &Stmt{Error: errors.New("operate on a bad connection")} + } + if err != nil { + return &Stmt{Error: err} + } + return &Stmt{conn: sqlStmt, lastSql: &query} +} + +func baseExec(db *sql.DB, tx *sql.Tx, query string, args ...any) *ExecResult { + args = flatArgs(args) + var r sql.Result + var err error + startTime := time.Now() + if tx != nil { + r, err = tx.Exec(query, args...) + } else if db != nil { + r, err = db.Exec(query, args...) + } else { + return &ExecResult{Sql: &query, Args: args, usedTime: log.MakeUsedTime(startTime, time.Now()), Error: errors.New("operate on a bad connection")} + } + endTime := time.Now() + usedTime := log.MakeUsedTime(startTime, endTime) + + if err != nil { + return &ExecResult{Sql: &query, Args: args, usedTime: usedTime, Error: err} + } + return &ExecResult{Sql: &query, Args: args, usedTime: usedTime, result: r} +} + +func flatArgs(args []any) []any { + for i, arg := range args { + if arg == nil { + continue + } + argValue := reflect.ValueOf(arg) + kind := argValue.Kind() + if kind == reflect.Map || kind == reflect.Struct || (kind == reflect.Slice && argValue.Type().Elem().Kind() != reflect.Uint8) { + args[i] = cast.MustToJSON(arg) + } + } + return args +} + +func baseQuery(db *sql.DB, tx *sql.Tx, query string, args ...any) *QueryResult { + args = flatArgs(args) + + var rows *sql.Rows + var err error + startTime := time.Now() + if tx != nil { + rows, err = tx.Query(query, args...) + } else if db != nil { + rows, err = db.Query(query, args...) + } else { + return &QueryResult{Sql: &query, Args: args, usedTime: log.MakeUsedTime(startTime, time.Now()), Error: errors.New("operate on a bad connection")} + } + endTime := time.Now() + usedTime := log.MakeUsedTime(startTime, endTime) + + if err != nil { + return &QueryResult{Sql: &query, Args: args, usedTime: usedTime, Error: err} + } + return &QueryResult{Sql: &query, Args: args, usedTime: usedTime, rows: rows} +} + +func quote(quoteTag string, text string) string { + a := strings.Split(text, ".") + for i, v := range a { + a[i] = quoteTag + strings.ReplaceAll(v, quoteTag, "\\"+quoteTag) + quoteTag + } + return strings.Join(a, ".") +} + +func quotes(quoteTag string, texts []string) string { + for i, v := range texts { + texts[i] = quote(quoteTag, v) + } + return strings.Join(texts, ",") +} + +func makeInsertSql(quoteTag string, table string, data any, useReplace bool) (string, []any) { + keys, vars, values := MakeKeysVarsValues(data) + operation := "insert" + if useReplace { + operation = "replace" + } + query := fmt.Sprintf("%s into %s (%s) values (%s)", operation, quote(quoteTag, table), quotes(quoteTag, keys), strings.Join(vars, ",")) + return query, values +} + +func makeUpdateSql(quoteTag string, table string, data any, conditions string, args ...any) (string, []any) { + args = flatArgs(args) + keys, vars, values := MakeKeysVarsValues(data) + for i, k := range keys { + keys[i] = fmt.Sprintf("%s=%s", quote(quoteTag, k), vars[i]) + } + values = append(values, args...) + if conditions != "" { + conditions = " where " + conditions + } + query := fmt.Sprintf("update %s set %s%s", quote(quoteTag, table), strings.Join(keys, ","), conditions) + return query, values +} + +func (db *DB) MakeInsertSql(table string, data any, useReplace bool) (string, []any) { + return makeInsertSql(db.QuoteTag, table, data, useReplace) +} + +func (db *DB) MakeUpdateSql(table string, data any, conditions string, args ...any) (string, []any) { + return makeUpdateSql(db.QuoteTag, table, data, conditions, args...) +} + +func (tx *Tx) MakeInsertSql(table string, data any, useReplace bool) (string, []any) { + return makeInsertSql(tx.QuoteTag, table, data, useReplace) +} + +func (tx *Tx) MakeUpdateSql(table string, data any, conditions string, args ...any) (string, []any) { + return makeUpdateSql(tx.QuoteTag, table, data, conditions, args...) +} + +func getFlatFields(fields map[string]reflect.Value, fieldKeys *[]string, value reflect.Value) { + valueType := value.Type() + for i := 0; i < value.NumField(); i++ { + v := value.Field(i) + if valueType.Field(i).Anonymous { + getFlatFields(fields, fieldKeys, v) + } else { + *fieldKeys = append(*fieldKeys, valueType.Field(i).Name) + fields[valueType.Field(i).Name] = v + } + } +} + +func MakeKeysVarsValues(data any) ([]string, []string, []any) { + keys := make([]string, 0) + vars := make([]string, 0) + values := make([]any, 0) + + dataType := reflect.TypeOf(data) + dataValue := reflect.ValueOf(data) + for dataType.Kind() == reflect.Ptr { + dataType = dataType.Elem() + dataValue = dataValue.Elem() + } + + if dataType.Kind() == reflect.Struct { + fields := make(map[string]reflect.Value) + fieldKeys := make([]string, 0) + getFlatFields(fields, &fieldKeys, dataValue) + for _, k := range fieldKeys { + if k[0] >= 'a' && k[0] <= 'z' { + continue + } + v := fields[k] + if v.Kind() == reflect.Interface { + v = v.Elem() + } + keys = append(keys, k) + if v.Kind() == reflect.String && v.Len() > 0 && v.String()[0] == ':' { + vars = append(vars, v.String()[1:]) + } else { + vars = append(vars, "?") + if !v.IsValid() || !v.CanInterface() { + values = append(values, nil) + } else { + values = append(values, v.Interface()) + } + } + } + } else if dataType.Kind() == reflect.Map { + for _, k := range dataValue.MapKeys() { + v := dataValue.MapIndex(k) + if v.Kind() == reflect.Interface { + v = v.Elem() + } + keys = append(keys, cast.String(k.Interface())) + if v.Kind() == reflect.String && v.Len() > 0 && v.String()[0] == ':' { + vars = append(vars, v.String()[1:]) + } else { + vars = append(vars, "?") + if !v.IsValid() || !v.CanInterface() { + values = append(values, nil) + } else { + values = append(values, v.Interface()) + } + } + } + } + + return keys, vars, values +} diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..8dbf097 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,12 @@ +# CHANGELOG - @go/db + +## [1.0.1] - 2026-05-03 +### Optimized +- Refactored `makeResults` to pre-calculate field mappings for structs, significantly improving performance for large result sets. +- Simplified and optimized `makeValue` and `makePublicVarName` functions. +- Optimized time parsing in `makeResults`. + +### Fixed +- Fixed typo `isCommitedOrRollbacked` to `isCommittedOrRollbacked` in `Tx` struct. +- Standardized parameter naming: renamed `requestSql` to `query` and `wheres` to `conditions` across the module. +- Modernized Go syntax to align with latest standards. diff --git a/DB.go b/DB.go new file mode 100644 index 0000000..bd6387b --- /dev/null +++ b/DB.go @@ -0,0 +1,611 @@ +package db + +import ( + "crypto/tls" + "database/sql" + "database/sql/driver" + "encoding/base64" + "errors" + "fmt" + "net/url" + "regexp" + "strings" + "sync" + "time" + + "apigo.cc/go/cast" + "apigo.cc/go/config" + "apigo.cc/go/crypto" + "apigo.cc/go/id" + "apigo.cc/go/log" + "apigo.cc/go/rand" + "apigo.cc/go/safe" +) + +type Config struct { + Type string + User string + Password string + pwd *safe.SafeBuf + Host string + ReadonlyHosts []string + DB string + SSL string + tls *tls.Config + Args string + MaxOpens int + MaxIdles int + MaxLifeTime int + LogSlow config.Duration + logger *log.Logger +} + +type SSL struct { + Ca string + Cert string + Key string + Insecure bool +} + +type ConnectorFunc func(*Config, *safe.SafeBuf, *tls.Config) driver.Connector + +var connectors = map[string]ConnectorFunc{} + +func RegisterConnector(typ string, conn ConnectorFunc) { + connectors[typ] = conn +} + +var sqlite3PwdMatcher = regexp.MustCompile(`x'\w+'`) + +func (dbInfo *Config) Dsn() string { + args := make([]string, 0) + if dbInfo.SSL != "" { + args = append(args, "tls="+dbInfo.SSL) + } + if dbInfo.Args != "" { + args = append(args, dbInfo.Args) + } + argsStr := "" + if len(args) > 0 { + argsStr = "&" + strings.Join(args, "&") + } + + if isFileDB(dbInfo.Type) { + argsStr = sqlite3PwdMatcher.ReplaceAllString(argsStr, "******") + return fmt.Sprintf("%s://%s?logSlow=%s"+argsStr, dbInfo.Type, dbInfo.Host, dbInfo.LogSlow.TimeDuration()) + } else { + hosts := []string{dbInfo.Host} + if dbInfo.ReadonlyHosts != nil { + hosts = append(hosts, dbInfo.ReadonlyHosts...) + } + return fmt.Sprintf("%s://%s:****@%s/%s?logSlow=%s"+argsStr, dbInfo.Type, dbInfo.User, strings.Join(hosts, ","), dbInfo.DB, dbInfo.LogSlow.TimeDuration()) + } +} + +var fileDBs = map[string]bool{ + "sqlite": true, + "sqlite3": true, + "chai": true, + "access": true, + "mdb": true, + "accdb": true, + "h2": true, + "hsqldb": true, + "derby": true, + "sqlce": true, + "sdf": true, + "firebird": true, + "fdb": true, + "dbase": true, + "dbf": true, + "berkeleydb": true, + "bdb": true, +} + +func isFileDB(typ string) bool { + return fileDBs[typ] +} + +func (dbInfo *Config) ConfigureBy(setting string) { + urlInfo, err := url.Parse(setting) + if err != nil { + if dbInfo.logger != nil { + dbInfo.logger.Error(err.Error(), "url", setting) + } + return + } + + dbInfo.Type = urlInfo.Scheme + if isFileDB(dbInfo.Type) { + dbInfo.Host = urlInfo.Host + urlInfo.Path + dbInfo.DB = strings.SplitN(urlInfo.Host, ".", 2)[0] + } else { + if strings.ContainsRune(urlInfo.Host, ',') { + a := strings.Split(urlInfo.Host, ",") + dbInfo.Host = a[0] + dbInfo.ReadonlyHosts = a[1:] + } else { + dbInfo.Host = urlInfo.Host + dbInfo.ReadonlyHosts = nil + } + if len(urlInfo.Path) > 1 { + dbInfo.DB = urlInfo.Path[1:] + } + } + dbInfo.User = urlInfo.User.Username() + dbInfo.Password, _ = urlInfo.User.Password() + + q := urlInfo.Query() + dbInfo.MaxIdles = cast.Int(q.Get("maxIdles")) + dbInfo.MaxLifeTime = cast.Int(q.Get("maxLifeTime")) + dbInfo.MaxOpens = cast.Int(q.Get("maxOpens")) + dbInfo.LogSlow = config.Duration(cast.Duration(q.Get("logSlow"))) + dbInfo.SSL = q.Get("tls") + + sslCa := q.Get("sslCA") + sslCert := q.Get("sslCert") + sslKey := q.Get("sslKey") + sslSkipVerify := cast.Bool(q.Get("sslSkipVerify")) + if sslCa == "" || sslCert == "" || sslKey == "" { + if dbSSL, ok := dbSSLs[dbInfo.SSL]; ok { + sslCa = dbSSL.Ca + sslCert = dbSSL.Cert + sslKey = dbSSL.Key + sslSkipVerify = dbSSL.Insecure + } + } + if sslCa != "" && sslCert != "" && sslKey != "" { + sslName := id.MakeID(12) + dbInfo.SSL = sslName + decryptedCa, _ := confAes.DecryptBytes([]byte(sslCa)) + decryptedCert, _ := confAes.DecryptBytes([]byte(sslCert)) + decryptedKey, _ := confAes.DecryptBytes([]byte(sslKey)) + tlsConf := BuildTLSConfig(decryptedCa, decryptedCert, decryptedKey, sslSkipVerify) + if tlsConf != nil { + dbInfo.tls = tlsConf + } + safe.ZeroMemory(decryptedCa) + safe.ZeroMemory(decryptedCert) + safe.ZeroMemory(decryptedKey) + } + + args := make([]string, 0) + for k := range q { + if k != "maxIdles" && k != "maxLifeTime" && k != "maxOpens" && k != "logSlow" && k != "tls" { + args = append(args, k+"="+q.Get(k)) + } + } + if len(args) > 0 { + dbInfo.Args = strings.Join(args, "&") + } +} + +type DB struct { + name string + conn *sql.DB + readonlyConnections []*sql.DB + Config *Config + logger *dbLogger + Error error + QuoteTag string +} + +var confAes, _ = crypto.NewAESCBCAndEraseKey([]byte("?GQ$0K0GgLdO=f+~L68PLm$uhKr4'=tV"), []byte("VFs7@sK61cj^f?HZ")) +var keysSetted = sync.Once{} + +func SetEncryptKeys(key, iv []byte) { + keysSetted.Do(func() { + confAes.Close() + confAes, _ = crypto.NewAESGCMAndEraseKey(key, iv) + }) +} + +type dbLogger struct { + config *Config + logger *log.Logger +} + +func (dl *dbLogger) LogError(errStr string) { + dl.logger.DBError(errStr, dl.config.Type, dl.config.Dsn(), "", nil, 0) +} + +func (dl *dbLogger) LogQuery(query string, args []any, usedTime float32) { + dl.logger.DB(dl.config.Type, dl.config.Dsn(), query, args, usedTime) +} + +func (dl *dbLogger) LogQueryError(errStr string, query string, args []any, usedTime float32) { + dl.logger.DBError(errStr, dl.config.Type, dl.config.Dsn(), query, args, usedTime) +} + +var dbConfigs = make(map[string]*Config) +var dbConfigsLock = sync.RWMutex{} +var dbSSLs = make(map[string]*SSL) +var dbInstances = make(map[string]*DB) +var dbInstancesLock = sync.RWMutex{} +var once sync.Once + +func GetDBWithoutCache(name string, logger *log.Logger) *DB { + return getDB(name, logger, false) +} + +func GetDB(name string, logger *log.Logger) *DB { + return getDB(name, logger, true) +} + +func getDB(name string, logger *log.Logger, useCache bool) *DB { + if logger == nil { + logger = log.DefaultLogger + } + + if useCache { + dbInstancesLock.RLock() + oldConn := dbInstances[name] + dbInstancesLock.RUnlock() + if oldConn != nil { + return oldConn.CopyByLogger(logger) + } + } + + var conf *Config + if strings.Contains(name, "://") { + conf = new(Config) + conf.logger = logger + conf.ConfigureBy(name) + } else { + dbConfigsLock.RLock() + n := len(dbConfigs) + dbConfigsLock.RUnlock() + if n == 0 { + once.Do(func() { + dbConfigs1 := make(map[string]*Config) + if err := config.Load("db", &dbConfigs1); err == nil { + for k, v := range dbConfigs1 { + if v.Host != "" { + dbConfigsLock.Lock() + dbConfigs[k] = v + dbConfigsLock.Unlock() + } + } + } else { + logger.Error(err.Error()) + } + dbConfigs2 := make(map[string]string) + if err := config.Load("db", &dbConfigs2); err == nil { + for k, v := range dbConfigs2 { + if strings.Contains(v, "://") { + v2 := new(Config) + v2.ConfigureBy(v) + if v2.Host != "" { + v2.logger = logger + dbConfigsLock.Lock() + dbConfigs[k] = v2 + dbConfigsLock.Unlock() + } + } else { + dbConfigsLock.Lock() + v2 := dbConfigs[v] + if v2 != nil && v2.Host != "" { + dbConfigs[k] = v2 + } + dbConfigsLock.Unlock() + } + } + } else { + logger.Error(err.Error()) + } + }) + } + dbConfigsLock.RLock() + conf = dbConfigs[name] + dbConfigsLock.RUnlock() + if conf == nil { + conf = new(Config) + dbConfigsLock.Lock() + dbConfigs[name] = conf + dbConfigsLock.Unlock() + } + } + + if conf.Host == "" { + if name != "default" { + logger.Error("db config not exists", "name", name) + } + return nil + } + + if conf.SSL != "" && len(dbSSLs) == 0 { + _ = config.Load("dbssl", &dbSSLs) + } + + if conf.SSL != "" && dbSSLs[conf.SSL] == nil { + logger.Error("dbssl config lost") + } + + if strings.ContainsRune(conf.Host, ',') { + a := strings.Split(conf.Host, ",") + conf.Host = a[0] + conf.ReadonlyHosts = a[1:] + } else { + conf.ReadonlyHosts = nil + } + + if conf.Password != "" { + if encryptedPassword, err := base64.URLEncoding.DecodeString(conf.Password); err == nil { + if pwdSafeBuf, err := confAes.Decrypt(encryptedPassword); err == nil { + conf.pwd = pwdSafeBuf + } + } + if conf.pwd == nil { + conf.pwd = safe.NewSafeBuf([]byte(conf.Password)) + } + } else { + if !isFileDB(conf.Type) && conf.Host != "127.0.0.1:3306" && conf.User == "root" { + logger.Warning("password is empty") + } + } + conf.Password = "" + + conn, err := getPool(conf) + if err != nil { + logger.DBError(err.Error(), conf.Type, conf.Dsn(), "", nil, 0) + return &DB{conn: nil, QuoteTag: "\"", Error: err} + } + + db := new(DB) + db.QuoteTag = cast.If(conf.Type == "mysql", "`", "\"") + db.name = name + db.conn = conn + + if conf.ReadonlyHosts != nil { + readonlyConnections := make([]*sql.DB, 0) + for _, host := range conf.ReadonlyHosts { + conn, err := getPoolForHost(conf, host) + if err != nil { + logger.DBError(err.Error(), conf.Type, conf.Dsn(), "", nil, 0) + } else { + readonlyConnections = append(readonlyConnections, conn) + } + } + if len(readonlyConnections) > 0 { + db.readonlyConnections = readonlyConnections + } + } + + db.Error = nil + db.Config = conf + if conf.MaxIdles > 0 { + conn.SetMaxIdleConns(conf.MaxIdles) + } + if conf.MaxOpens > 0 { + conn.SetMaxOpenConns(conf.MaxOpens) + } + if conf.MaxLifeTime > 0 { + conn.SetConnMaxLifetime(time.Second * time.Duration(conf.MaxLifeTime)) + } + if conf.LogSlow == 0 { + conf.LogSlow = config.Duration(1000 * time.Millisecond) + } + if useCache { + dbInstancesLock.Lock() + dbInstances[name] = db + dbInstancesLock.Unlock() + } + return db.CopyByLogger(logger) +} + +func getPool(conf *Config) (*sql.DB, error) { + return getPoolForHost(conf, "") +} + +func getPoolForHost(conf *Config, host string) (*sql.DB, error) { + connectType := "tcp" + if host == "" { + host = conf.Host + } + if len(host) > 0 && host[0] == '/' { + connectType = "unix" + } + + if connector := connectors[conf.Type]; connector != nil { + return sql.OpenDB(connector(conf, conf.pwd, conf.tls)), nil + } else { + dsn := "" + args := make([]string, 0) + if conf.SSL != "" { + args = append(args, "tls="+conf.SSL) + } + if conf.Args != "" { + args = append(args, conf.Args) + } + argsStr := "" + if len(args) > 0 { + argsStr = "?" + strings.Join(args, "&") + } + + if isFileDB(conf.Type) { + dsn = host + argsStr + } else { + pwdBuf := conf.pwd.Open() + dsn = fmt.Sprintf("%s:%s@%s(%s)/%s"+argsStr, conf.User, pwdBuf.String(), connectType, host, conf.DB) + pwdBuf.Close() + } + return sql.Open(conf.Type, dsn) + } +} + +func (db *DB) CopyByLogger(logger *log.Logger) *DB { + newDB := new(DB) + newDB.QuoteTag = db.QuoteTag + newDB.name = db.name + newDB.conn = db.conn + newDB.readonlyConnections = db.readonlyConnections + newDB.Config = db.Config + if logger == nil { + logger = log.DefaultLogger + } + newDB.logger = &dbLogger{logger: logger, config: db.Config} + return newDB +} + +func (db *DB) SetLogger(logger *log.Logger) { + db.logger.logger = logger +} + +func (db *DB) GetLogger() *log.Logger { + return db.logger.logger +} + +func (db *DB) Destroy() error { + if db.conn == nil { + return errors.New("operate on a bad connection") + } + err := db.conn.Close() + if err != nil { + db.logger.LogError(err.Error()) + } + dbInstancesLock.Lock() + delete(dbInstances, db.name) + dbInstancesLock.Unlock() + return err +} + +func (db *DB) GetOriginDB() *sql.DB { + return db.conn +} + +func (db *DB) Prepare(query string) *Stmt { + stmt := basePrepare(db.conn, nil, query) + stmt.logger = db.logger + if stmt.Error != nil { + db.logger.LogError(stmt.Error.Error()) + } + return stmt +} + +func (db *DB) Quote(text string) string { + return quote(db.QuoteTag, text) +} + +func (db *DB) Quotes(texts []string) string { + return quotes(db.QuoteTag, texts) +} + +func (db *DB) Begin() *Tx { + if db.conn == nil { + return &Tx{QuoteTag: db.QuoteTag, logSlow: db.Config.LogSlow.TimeDuration(), Error: errors.New("operate on a bad connection"), logger: db.logger} + } + sqlTx, err := db.conn.Begin() + if err != nil { + db.logger.LogError(err.Error()) + return &Tx{QuoteTag: db.QuoteTag, logSlow: db.Config.LogSlow.TimeDuration(), Error: err, logger: db.logger} + } + return &Tx{QuoteTag: db.QuoteTag, logSlow: db.Config.LogSlow.TimeDuration(), conn: sqlTx, logger: db.logger} +} + +func (db *DB) Exec(query string, args ...any) *ExecResult { + r := baseExec(db.conn, nil, query, args...) + r.logger = db.logger + if r.Error != nil { + db.logger.LogQueryError(r.Error.Error(), query, args, r.usedTime) + } else { + if db.Config.LogSlow > 0 && r.usedTime >= float32(db.Config.LogSlow.TimeDuration()/time.Millisecond) { + db.logger.LogQuery(query, args, r.usedTime) + } + } + return r +} + +func (db *DB) Query(query string, args ...any) *QueryResult { + conn := db.conn + if db.readonlyConnections != nil { + connNum := len(db.readonlyConnections) + if connNum == 1 { + conn = db.readonlyConnections[0] + } else { + p := rand.Int(0, connNum-1) + conn = db.readonlyConnections[p] + } + } + + r := baseQuery(conn, nil, query, args...) + r.logger = db.logger + if r.Error != nil { + db.logger.LogQueryError(r.Error.Error(), query, args, r.usedTime) + } else { + if db.Config.LogSlow > 0 && r.usedTime >= float32(db.Config.LogSlow.TimeDuration()/time.Millisecond) { + db.logger.LogQuery(query, args, r.usedTime) + } + } + return r +} + +func (db *DB) Insert(table string, data any) *ExecResult { + query, values := db.MakeInsertSql(table, data, false) + r := baseExec(db.conn, nil, query, values...) + r.logger = db.logger + if r.Error != nil { + db.logger.LogQueryError(r.Error.Error(), query, values, r.usedTime) + } else { + if db.Config.LogSlow > 0 && r.usedTime >= float32(db.Config.LogSlow.TimeDuration()/time.Millisecond) { + db.logger.LogQuery(query, values, r.usedTime) + } + } + return r +} + +func (db *DB) Replace(table string, data any) *ExecResult { + query, values := db.MakeInsertSql(table, data, true) + r := baseExec(db.conn, nil, query, values...) + r.logger = db.logger + if r.Error != nil { + db.logger.LogQueryError(r.Error.Error(), query, values, r.usedTime) + } else { + if db.Config.LogSlow > 0 && r.usedTime >= float32(db.Config.LogSlow.TimeDuration()/time.Millisecond) { + db.logger.LogQuery(query, values, r.usedTime) + } + } + return r +} + +func (db *DB) Update(table string, data any, conditions string, args ...any) *ExecResult { + query, values := db.MakeUpdateSql(table, data, conditions, args...) + r := baseExec(db.conn, nil, query, values...) + r.logger = db.logger + if r.Error != nil { + db.logger.LogQueryError(r.Error.Error(), query, values, r.usedTime) + } else { + if db.Config.LogSlow > 0 && r.usedTime >= float32(db.Config.LogSlow.TimeDuration()/time.Millisecond) { + db.logger.LogQuery(query, values, r.usedTime) + } + } + return r +} + +func (db *DB) Delete(table string, conditions string, args ...any) *ExecResult { + if conditions != "" { + conditions = " where " + conditions + } + query := fmt.Sprintf("delete from %s%s", db.Quote(table), conditions) + r := baseExec(db.conn, nil, query, args...) + r.logger = db.logger + if r.Error != nil { + db.logger.LogQueryError(r.Error.Error(), query, args, r.usedTime) + } else { + if db.Config.LogSlow > 0 && r.usedTime >= float32(db.Config.LogSlow.TimeDuration()/time.Millisecond) { + db.logger.LogQuery(query, args, r.usedTime) + } + } + return r +} + +func (db *DB) InKeys(numArgs int) string { + return InKeys(numArgs) +} + +func InKeys(numArgs int) string { + a := make([]string, numArgs) + for i := 0; i < numArgs; i++ { + a[i] = "?" + } + return fmt.Sprintf("(%s)", strings.Join(a, ",")) +} diff --git a/DB_test.go b/DB_test.go new file mode 100644 index 0000000..f8c2778 --- /dev/null +++ b/DB_test.go @@ -0,0 +1,477 @@ +package db_test + +import ( + "fmt" + "regexp" + "strings" + "testing" + "time" + + "apigo.cc/go/cast" + "apigo.cc/go/db" + "apigo.cc/go/shell" + + _ "apigo.cc/go/db/mysql" + _ "modernc.org/sqlite" +) + +var dbset = "sqlite://test.db" + +type userInfo struct { + innerId int + Tag string + Id int + Name string + Phone *string + Email string + Parents []string + Active bool + Time string +} + +type UserBaseModel struct { + Id int + Name string + Password string +} + +type UserModel struct { + UserBaseModel + Phone string + Active bool + Parents []string + UserStatus int + Owner int + Salt string +} + +func initDB(t *testing.T) *db.DB { + var er *db.ExecResult + dbInst := db.GetDB(dbset, nil) + fmt.Println("dbType", shell.BCyan(dbInst.Config.Type)) + if dbInst.Error != nil { + t.Fatal("GetDB error", dbInst) + return nil + } + + finishDB(dbInst, t) + if dbInst.Config.Type == "mysql" { + er = dbInst.Exec(`CREATE TABLE IF NOT EXISTS tempUsersForDBTest ( + id INT NOT NULL AUTO_INCREMENT, + name VARCHAR(45) NOT NULL, + phone VARCHAR(45), + email VARCHAR(45), + parents JSON, + active TINYINT NOT NULL DEFAULT 0, + time DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (id));`) + } else { + er = dbInst.Exec(`CREATE TABLE IF NOT EXISTS tempUsersForDBTest ( + id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, + name VARCHAR(45) NOT NULL, + phone VARCHAR(45), + email VARCHAR(45), + parents JSON, + active TINYINT NOT NULL DEFAULT 0, + time DATETIME NOT NULL DEFAULT (strftime('%Y-%m-%d %H:%M:%f')));`) + } + if er.Error != nil { + t.Fatal("Failed to create table", er) + } + return dbInst +} + +func finishDB(dbInst *db.DB, t *testing.T) { + er := dbInst.Exec(`DROP TABLE IF EXISTS tempUsersForDBTest;`) + if er.Error != nil { + t.Fatal("Failed to drop table", er) + } +} + +func TestMakeInsertSql(t *testing.T) { + user := &UserModel{ + UserBaseModel: UserBaseModel{ + Name: "王二小", + Password: "2121asds", + }, + Parents: []string{"aa", "bb"}, + Active: true, + UserStatus: 1, + Salt: "de312", + } + + dbInst := db.GetDB(dbset, nil) + query, _ := dbInst.MakeInsertSql("table_name", user, false) + checkSql := `insert into "table_name" ("Id","Name","Password","Phone","Active","Parents","UserStatus","Owner","Salt") values (?,?,?,?,?,?,?,?,?)` + if dbInst.Config.Type == "mysql" { + checkSql = strings.ReplaceAll(checkSql, "\"", "`") + } + if query != checkSql { + t.Fatal("MakeInsertSql query error ", query) + } +} + +func TestBaseSelect(t *testing.T) { + sqlStr := "SELECT 1002 id, '13800000001' phone" + dbInst := db.GetDB(dbset, nil) + if dbInst.Error != nil { + t.Fatal("GetDB error", dbInst.Error) + return + } + + r := dbInst.Query(sqlStr) + if r.Error != nil { + t.Fatal("Query error", sqlStr, r) + } + results1 := r.MapResults() + if cast.Int(results1[0]["id"]) != 1002 || cast.String(results1[0]["phone"]) != "13800000001" { + t.Fatal("Result error", sqlStr, results1, r) + } + + r = dbInst.Query(sqlStr) + if r.Error != nil { + t.Fatal("Query error", sqlStr, r) + } + results2 := r.StringMapResults() + if results2[0]["id"] != "1002" || results2[0]["phone"] != "13800000001" { + t.Fatal("Result error", sqlStr, results2, r) + } + + results3 := make([]map[string]int, 0) + r = dbInst.Query(sqlStr) + if r.Error != nil { + t.Fatal("Query error", sqlStr, results3, r) + } + r.To(&results3) + if results3[0]["id"] != 1002 || results3[0]["phone"] != 13800000001 { + t.Fatal("Result error", sqlStr, results3, r) + } + + results4 := make([]userInfo, 0) + r = dbInst.Query(sqlStr) + if r.Error != nil { + t.Fatal("Query error", sqlStr, results4, r) + } + r.To(&results4) + if results4[0].Id != 1002 || results4[0].Phone == nil || *results4[0].Phone != "13800000001" { + t.Fatal("Result error", sqlStr, results4, r) + } + + results5 := dbInst.Query(sqlStr).StringSliceResults() + if results5[0][0] != "1002" || results5[0][1] != "13800000001" { + t.Fatal("Result error", sqlStr, results5, r) + } + + r = dbInst.Query(sqlStr) + if r.Error != nil { + t.Fatal("Query error", sqlStr, r) + } + results6 := r.StringsOnC1() + if results6[0] != "1002" { + t.Fatal("Result error", sqlStr, results6, r) + } + + r = dbInst.Query(sqlStr) + if r.Error != nil { + t.Fatal("Query error", sqlStr, r) + } + results7 := r.MapOnR1() + if cast.Int(results7["id"]) != 1002 || cast.String(results7["phone"]) != "13800000001" { + t.Fatal("Result error", sqlStr, results7, r) + } + + results8 := userInfo{innerId: 2, Tag: "abc"} + r = dbInst.Query(sqlStr) + if r.Error != nil { + t.Fatal("Query error", sqlStr, results8, r) + } + r.To(&results8) + if results8.Id != 1002 || results8.Phone == nil || *results8.Phone != "13800000001" || results8.innerId != 2 || results8.Tag != "abc" { + t.Fatal("Result error", sqlStr, results8, r) + } + + r = dbInst.Query(sqlStr) + if r.Error != nil { + t.Fatal("Query error", sqlStr, r) + } + results9 := r.IntOnR1C1() + if results9 != 1002 { + t.Fatal("Result error", sqlStr, results9, r) + } + + r = dbInst.Query(sqlStr) + results10 := map[string]string{} + r.ToKV(&results10) + if results10["1002"] != "13800000001" { + t.Fatal("Result error", sqlStr, results10, r) + } + + r = dbInst.Query(sqlStr) + results11 := map[string]map[string]string{} + r.ToKV(&results11) + if results11["1002"]["phone"] != "13800000001" { + t.Fatal("Result error", sqlStr, results11, r) + } + + r = dbInst.Query(sqlStr) + results12 := map[string]userInfo{} + r.ToKV(&results12) + if results12["1002"].Phone == nil || *results12["1002"].Phone != "13800000001" { + t.Fatal("Result error", sqlStr, results12, r) + } +} + +func TestInsertReplaceUpdateDelete(t *testing.T) { + dbInst := initDB(t) + data := map[string]any{ + "phone": 18033336666, + "name": "Star", + "parents": []string{"dd", "mm"}, + "time": ":(strftime('%Y-%m-%d %H:%M:%f', datetime('now', '-1 day'), 'localtime'))", + } + if dbInst.Config.Type == "mysql" { + data["time"] = ":DATE_SUB(NOW(), INTERVAL 1 DAY)" + } + er := dbInst.Insert("tempUsersForDBTest", data) + if er.Error != nil { + t.Fatal("Insert 1 error", er) + } + if er.Id() != 1 { + t.Fatal("insertId 1 error", er, er.Id()) + } + + er = dbInst.Insert("tempUsersForDBTest", map[string]any{ + "phone": "18000000002", + "name": "Tom", + "active": true, + }) + if er.Error != nil { + t.Fatal("Insert 2 error", er) + } + if er.Id() != 2 { + t.Fatal("insertId 2 error", er, er.Id()) + } + + er = dbInst.Update("tempUsersForDBTest", map[string]any{ + "phone": "18000000222", + "name": "Tom Lee", + }, "id=?", 2) + if er.Error != nil { + t.Fatal("Update 2 error", er) + } + if er.Changes() != 1 { + t.Fatal("Update 2 num error", er, er.Changes()) + } + + er = dbInst.Replace("tempUsersForDBTest", map[string]any{ + "phone": "18000000003", + "name": "Amy", + }) + if er.Error != nil { + t.Fatal("Replace 3 error", er) + } + if er.Id() != 3 { + t.Fatal("insertId 3 error", er, er.Changes()) + } + + er = dbInst.Exec("delete from tempUsersForDBTest where id=3") + if er.Error != nil { + t.Fatal("Delete 3 error", er) + } + if er.Changes() != 1 { + t.Fatal("Delete 3 num error", er) + } + + er = dbInst.Replace("tempUsersForDBTest", map[string]any{ + "phone": "18000000004", + "name": "Jerry", + }) + if er.Error != nil { + t.Fatal("Replace 4 error", er) + } + if er.Id() != 4 { + t.Fatal("insertId 4 error", er, er.Changes()) + } + + stmt := dbInst.Prepare("replace into `tempUsersForDBTest` (`id`,`phone`,`name`) values (?,?,?)") + if stmt.Error != nil { + t.Fatal("Prepare 4 error", stmt) + } + er = stmt.Exec(4, "18000000004", "Jerry's Mather") + stmt.Close() + + if er.Error != nil { + t.Fatal("Replace 4 error", er) + } + if er.Id() != 4 { + t.Fatal("insertId 4 error", er) + } + + userList := make([]userInfo, 0) + r := dbInst.Query("select * from tempUsersForDBTest") + if r.Error != nil { + t.Fatal("Select userList error", r) + } + r.To(&userList) + fmt.Println(">>>>", cast.PrettyToJSON(userList)) + if strings.Split(userList[0].Time, " ")[0] != time.Now().Add(time.Hour*24*-1).Format("2006-01-02") || userList[0].Id != 1 || userList[0].Name != "Star" || userList[0].Phone == nil || *userList[0].Phone != "18033336666" || userList[0].Active != false { + t.Fatal("Select userList 0 error", userList, r) + } + if len(userList[0].Parents) != 2 || userList[0].Parents[0] != "dd" { + t.Fatal("Select userList 0 Parents error", userList, r) + } + if strings.Split(userList[1].Time, " ")[0] != time.Now().Format("2006-01-02") || userList[1].Id != 2 || userList[1].Name != "Tom Lee" || userList[1].Phone == nil || *userList[1].Phone != "18000000222" || userList[1].Active != true { + t.Fatal("Select userList 1 error", userList, r) + } + if userList[2].Id != 4 || userList[2].Name != "Jerry's Mather" || userList[2].Phone == nil || *userList[2].Phone != "18000000004" { + t.Fatal("Select userList 2 error", userList, r) + } + + finishDB(dbInst, t) +} + +func TestTransaction(t *testing.T) { + n1 := countConnection() + + var userList []userInfo + + dbInst := initDB(t) + tx := dbInst.Begin() + if tx.Error != nil { + t.Fatal("Begin error", tx) + } + + data := map[string]any{ + "phone": 18033336666, + "name": "Star", + "time": ":(strftime('%Y-%m-%d %H:%M:%f', datetime('now', '-1 day'), 'localtime'))", + } + if dbInst.Config.Type == "mysql" { + data["time"] = ":DATE_SUB(NOW(), INTERVAL 1 DAY)" + } + tx.Insert("tempUsersForDBTest", data) + + userList = make([]userInfo, 0) + r := dbInst.Query("select * from tempUsersForDBTest") + r.To(&userList) + if r.Error != nil || len(userList) != 0 { + t.Fatal("Select Out Of TX", userList, r) + } + + userList = make([]userInfo, 0) + r = tx.Query("select * from tempUsersForDBTest") + r.To(&userList) + if r.Error != nil || len(userList) != 1 { + t.Fatal("Select In TX", userList, r) + } + + tx.Rollback() + + userList = make([]userInfo, 0) + r = dbInst.Query("select * from tempUsersForDBTest") + r.To(&userList) + if r.Error != nil || len(userList) != 0 { + t.Fatal("Select When Rollback", userList, r) + } + + tx = dbInst.Begin() + defer func() { + if err := tx.Finish(false); err != nil { + t.Error("tx rollback error", err) + } + finishDB(dbInst, t) + }() + if tx.Error != nil { + t.Fatal("Begin 2 error", tx) + } + + stmt := tx.Prepare("insert into `tempUsersForDBTest` (`id`,`phone`,`name`) values (?,?,?)") + if stmt.Error != nil { + t.Fatal("Prepare 4 error", r) + } + stmt.Exec(4, "18000000004", "Jerry's Mather") + stmt.Close() + + tx.Commit() + + userList = make([]userInfo, 0) + r = dbInst.Query("select * from tempUsersForDBTest") + r.To(&userList) + if r.Error != nil || len(userList) != 1 { + t.Fatal("Select When Commit", userList, r) + } + + n2 := countConnection() + fmt.Println("# connection count", n1, n2, cast.MustToJSON(dbInst.GetOriginDB().Stats()), ".") +} + +func countConnection() int { + n := 0 + res, _ := shell.RunCommand("netstat -ant", nil) + lines := strings.Split(string(res.Stdout), "\n") + spliter := regexp.MustCompile("\\s+") + for _, line := range lines { + if strings.Contains(line, ".3306") && strings.Contains(line, "ESTABLISHED") { + a := spliter.Split(line, 10) + if len(a) > 4 && strings.Contains(a[4], ".3306") { + n++ + } + } + } + return n +} + +func BenchmarkForPool(b *testing.B) { + b.StopTimer() + sqlStr := "SELECT 1002 id, '13800000001' phone" + dbInst := db.GetDB(dbset, nil) + if dbInst.Error != nil { + b.Fatal("GetDB error", dbInst) + return + } + + b.StartTimer() + for i := 0; i < b.N; i++ { + results1 := make([]map[string]any, 0) + r := dbInst.Query(sqlStr) + if r.Error != nil { + b.Fatal("Query error", sqlStr, results1, r) + } + r.To(&results1) + if cast.Int(results1[0]["id"]) != 1002 || cast.String(results1[0]["phone"]) != "13800000001" { + b.Fatal("Result error", sqlStr, results1, r) + } + } + b.Log("OpenConnections", dbInst.GetOriginDB().Stats().OpenConnections) +} + +func BenchmarkForPoolParallel(b *testing.B) { + n1 := countConnection() + + b.StopTimer() + sqlStr := "SELECT 1002 id, '13800000001' phone" + dbInst := db.GetDB(dbset, nil) + if dbInst.Error != nil { + b.Fatal("GetDB error", dbInst) + return + } + b.StartTimer() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + results1 := make([]map[string]any, 0) + r := dbInst.Query(sqlStr) + if r.Error != nil { + b.Fatal("Query error", sqlStr, results1, r) + } + r.To(&results1) + if cast.Int(results1[0]["id"]) != 1002 || cast.String(results1[0]["phone"]) != "13800000001" { + b.Fatal("Result error", sqlStr, results1, r) + } + } + }) + b.Log("OpenConnections", dbInst.GetOriginDB().Stats().OpenConnections) + + n2 := countConnection() + fmt.Println("# connection count", n1, n2, cast.MustToJSON(dbInst.GetOriginDB().Stats()), ".") +} diff --git a/README.md b/README.md index cdba1d2..584cf4b 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,65 @@ -# db +# @go/db -High-performance database abstraction layer for apigo.cc/go \ No newline at end of file +> **Maintainer Statement:** 本项目由 AI 维护。代码源自 github.com/ssgo/db 的重构,支持内存安全防护、读写分离及泛型优化。 + +## 🎯 设计哲学 + +`@go/db` 是一个极致精简、意图优先的数据库抽象层。它不试图取代 SQL,而是通过智能结果绑定与 SQL 自动化生成,消除数据库操作中的样板代码。 + +* **智能绑定**:根据结果容器类型(Struct/Map/Slice/BaseType)自动适配查询逻辑,无需手动 Scan。 +* **内存防御**:集成 `go/safe`,数据库密码在内存中加密存储,使用时物理擦除。 +* **读写分离**:内置连接池管理,支持配置多个只读节点实现自动负载均衡。 +* **驱动透明**:统一 MySQL、PostgreSQL (pgx) 与 SQLite 的 API 差异。 + +## 📦 安装 + +```bash +go get apigo.cc/go/db +``` + +## 💡 快速开始 + +```go +import "apigo.cc/go/db" +import _ "apigo.cc/go/db/mysql" // 引入驱动 + +// 初始化连接 +d := db.GetDB("mysql://user:pass@host:3306/dbname", nil) + +// 1. 查询全部结果到 Struct 切片 +var users []User +d.Query("SELECT * FROM users").To(&users) + +// 2. 自动化插入 +d.Insert("users", User{Name: "Star", Active: true}) + +// 3. 事务操作 +tx := d.Begin() +tx.Exec("UPDATE balance SET amount = amount - 10 WHERE id = ?", 1) +tx.Commit() +``` + +## 🛠 API 指南 + +### 核心对象 +- **`GetDB(setting string, logger *log.Logger) *DB`**: 通过 DSN 或配置名获取数据库实例。 +- **`DB.Insert/Replace/Update/Delete`**: 自动生成 SQL 并执行,支持 Struct 与 Map。 +- **`QueryResult.To(target any)`**: 将查询结果深度映射到目标容器。 +- **`QueryResult.MapResults() []map[string]any`**: 快捷获取通用结果集。 + +### 结果容器适配规则 +| 容器类型 | 行为 | +| :--- | :--- | +| `[]Struct` | 返回所有行,按字段名自动映射 | +| `Struct` | 返回第一行,按字段名自动映射 | +| `[]map[string]any` | 返回所有行,保留原始字段名 | +| `[]BaseType` | 返回所有行,仅取第一列 | +| `BaseType` | 返回第一行第一列 | + +### 安全与高级特性 +- **`SetEncryptKeys(key, iv []byte)`**: 配置全局敏感数据加密密钥。 +- **读写分离**: 在 DSN 中配置 `host1,host2,host3`,首个为主库,其余为随机只读库。 +- **SQLite 时间修复**: 自动处理 SQLite 毫秒级 `DATETIME` 格式与标准 `time.Time` 的转换。 + +## 🧪 验证状态 +已通过 SQLite 集成测试。详见:[TEST.md](./TEST.md) diff --git a/Result.go b/Result.go new file mode 100644 index 0000000..36bccf2 --- /dev/null +++ b/Result.go @@ -0,0 +1,577 @@ +package db + +import ( + "database/sql" + "encoding/json" + "errors" + "fmt" + "reflect" + "strings" + "time" + + "apigo.cc/go/cast" + "apigo.cc/go/convert" + "github.com/mitchellh/mapstructure" +) + +type QueryResult struct { + rows *sql.Rows + Sql *string + Args []any + Error error + logger *dbLogger + usedTime float32 + completed bool +} + +type ExecResult struct { + result sql.Result + Sql *string + Args []any + Error error + logger *dbLogger + usedTime float32 +} + +func (r *ExecResult) Changes() int64 { + if r.result == nil { + return 0 + } + numChanges, err := r.result.RowsAffected() + if err != nil { + r.logger.LogQueryError(err.Error(), *r.Sql, r.Args, r.usedTime) + return 0 + } + return numChanges +} + +func (r *ExecResult) Id() int64 { + if r.result == nil { + return 0 + } + insertId, err := r.result.LastInsertId() + if err != nil { + r.logger.LogQueryError(err.Error(), *r.Sql, r.Args, r.usedTime) + return 0 + } + return insertId +} + +func (r *QueryResult) Complete() { + if !r.completed { + if r.rows != nil { + r.rows.Close() + } + r.completed = true + } +} + +func (r *QueryResult) To(result any) error { + if r.rows == nil { + return errors.New("operate on a bad query") + } + return r.makeResults(result, r.rows) +} + +func (r *QueryResult) MapResults() []map[string]any { + result := make([]map[string]any, 0) + err := r.makeResults(&result, r.rows) + if err != nil { + r.logger.LogQueryError(err.Error(), *r.Sql, r.Args, r.usedTime) + } + return result +} + +func (r *QueryResult) SliceResults() [][]any { + result := make([][]any, 0) + err := r.makeResults(&result, r.rows) + if err != nil { + r.logger.LogQueryError(err.Error(), *r.Sql, r.Args, r.usedTime) + } + return result +} + +func (r *QueryResult) StringMapResults() []map[string]string { + result := make([]map[string]string, 0) + err := r.makeResults(&result, r.rows) + if err != nil { + r.logger.LogQueryError(err.Error(), *r.Sql, r.Args, r.usedTime) + } + return result +} + +func (r *QueryResult) StringSliceResults() [][]string { + result := make([][]string, 0) + err := r.makeResults(&result, r.rows) + if err != nil { + r.logger.LogQueryError(err.Error(), *r.Sql, r.Args, r.usedTime) + } + return result +} + +func (r *QueryResult) MapOnR1() map[string]any { + result := make(map[string]any) + err := r.makeResults(&result, r.rows) + if err != nil { + r.logger.LogQueryError(err.Error(), *r.Sql, r.Args, r.usedTime) + } + return result +} + +func (r *QueryResult) StringMapOnR1() map[string]string { + result := make(map[string]string) + err := r.makeResults(&result, r.rows) + if err != nil { + r.logger.LogQueryError(err.Error(), *r.Sql, r.Args, r.usedTime) + } + return result +} + +func (r *QueryResult) IntsOnC1() []int64 { + result := make([]int64, 0) + err := r.makeResults(&result, r.rows) + if err != nil { + r.logger.LogQueryError(err.Error(), *r.Sql, r.Args, r.usedTime) + } + return result +} + +func (r *QueryResult) StringsOnC1() []string { + result := make([]string, 0) + err := r.makeResults(&result, r.rows) + if err != nil { + r.logger.LogQueryError(err.Error(), *r.Sql, r.Args, r.usedTime) + } + return result +} + +func (r *QueryResult) IntOnR1C1() int64 { + var result int64 = 0 + err := r.makeResults(&result, r.rows) + if err != nil { + r.logger.LogQueryError(err.Error(), *r.Sql, r.Args, r.usedTime) + } + return result +} + +func (r *QueryResult) FloatOnR1C1() float64 { + var result float64 = 0 + err := r.makeResults(&result, r.rows) + if err != nil { + r.logger.LogQueryError(err.Error(), *r.Sql, r.Args, r.usedTime) + } + return result +} + +func (r *QueryResult) StringOnR1C1() string { + result := "" + err := r.makeResults(&result, r.rows) + if err != nil { + r.logger.LogQueryError(err.Error(), *r.Sql, r.Args, r.usedTime) + } + return result +} + +func (r *QueryResult) ToKV(target any) error { + v := reflect.ValueOf(target) + t := v.Type() + for t.Kind() == reflect.Ptr { + v = v.Elem() + t = v.Type() + } + + if t.Kind() != reflect.Map { + r.logger.LogQueryError("target not a map", *r.Sql, r.Args, r.usedTime) + return errors.New("target not a map") + } + + vt := t.Elem() + finalVt := vt + for finalVt.Kind() == reflect.Ptr { + finalVt = finalVt.Elem() + } + if finalVt.Kind() == reflect.Map || finalVt.Kind() == reflect.Struct { + colTypes, err := r.getColumnTypes() + list := r.MapResults() + if err != nil { + r.logger.LogQueryError(err.Error(), *r.Sql, r.Args, r.usedTime) + return err + } + for _, item := range list { + newKey := reflect.ValueOf(reflect.New(t.Key()).Interface()).Elem() + convert.To(item[colTypes[0].Name()], newKey.Addr().Interface()) + + newValue := v.MapIndex(newKey) + isNew := false + if !newValue.IsValid() { + newValue = reflect.New(vt) + isNew = true + } + + err := mapstructure.WeakDecode(item, newValue.Interface()) + if err != nil { + r.logger.LogError(err.Error()) + } + + if isNew { + v.SetMapIndex(newKey, newValue.Elem()) + } + } + } else { + list := r.SliceResults() + for _, item := range list { + if len(item) < 2 { + continue + } + switch vt.Kind() { + case reflect.Int: + v.SetMapIndex(reflect.ValueOf(cast.String(item[0])), reflect.ValueOf(cast.Int(item[1]))) + case reflect.Int8: + v.SetMapIndex(reflect.ValueOf(cast.String(item[0])), reflect.ValueOf(int8(cast.Int(item[1])))) + case reflect.Int16: + v.SetMapIndex(reflect.ValueOf(cast.String(item[0])), reflect.ValueOf(int16(cast.Int(item[1])))) + case reflect.Int32: + v.SetMapIndex(reflect.ValueOf(cast.String(item[0])), reflect.ValueOf(int32(cast.Int(item[1])))) + case reflect.Int64: + v.SetMapIndex(reflect.ValueOf(cast.String(item[0])), reflect.ValueOf(cast.Int64(item[1]))) + case reflect.Uint: + v.SetMapIndex(reflect.ValueOf(cast.String(item[0])), reflect.ValueOf(uint(cast.Int(item[1])))) + case reflect.Uint8: + v.SetMapIndex(reflect.ValueOf(cast.String(item[0])), reflect.ValueOf(uint8(cast.Int(item[1])))) + case reflect.Uint16: + v.SetMapIndex(reflect.ValueOf(cast.String(item[0])), reflect.ValueOf(uint16(cast.Int(item[1])))) + case reflect.Uint32: + v.SetMapIndex(reflect.ValueOf(cast.String(item[0])), reflect.ValueOf(uint32(cast.Int(item[1])))) + case reflect.Uint64: + v.SetMapIndex(reflect.ValueOf(cast.String(item[0])), reflect.ValueOf(cast.Uint64(item[1]))) + case reflect.Float32: + v.SetMapIndex(reflect.ValueOf(cast.String(item[0])), reflect.ValueOf(cast.Float(item[1]))) + case reflect.Float64: + v.SetMapIndex(reflect.ValueOf(cast.String(item[0])), reflect.ValueOf(cast.Float64(item[1]))) + case reflect.Bool: + v.SetMapIndex(reflect.ValueOf(cast.String(item[0])), reflect.ValueOf(cast.Bool(item[1]))) + case reflect.String: + v.SetMapIndex(reflect.ValueOf(cast.String(item[0])), reflect.ValueOf(cast.String(item[1]))) + case reflect.Interface: + v.SetMapIndex(reflect.ValueOf(cast.String(item[0])), reflect.ValueOf(item[1])) + } + } + } + + return nil +} + +func (r *QueryResult) makeResults(results any, rows *sql.Rows) error { + if rows == nil { + return errors.New("not a valid query result") + } + + defer func() { + _ = rows.Close() + r.completed = true + }() + resultsValue := reflect.ValueOf(results) + if resultsValue.Kind() != reflect.Ptr { + return fmt.Errorf("results must be a pointer") + } + + for resultsValue.Kind() == reflect.Ptr { + resultsValue = resultsValue.Elem() + } + rowType := resultsValue.Type() + + colTypes, err := rows.ColumnTypes() + if err != nil { + return err + } + + colNum := len(colTypes) + originRowType := rowType + if rowType.Kind() == reflect.Slice { + rowType = rowType.Elem() + originRowType = rowType + for rowType.Kind() == reflect.Ptr { + rowType = rowType.Elem() + } + } + + scanValues := make([]any, colNum) + var fieldInfos []struct { + index []int + typ reflect.Type + name string + } + + if rowType.Kind() == reflect.Struct { + fieldInfos = make([]struct { + index []int + typ reflect.Type + name string + }, colNum) + for colIndex, col := range colTypes { + publicColName := makePublicVarName(col.Name()) + field, found := rowType.FieldByName(publicColName) + if found { + fieldInfos[colIndex].index = field.Index + fieldInfos[colIndex].typ = field.Type + fieldInfos[colIndex].name = publicColName + if field.Type.Kind() == reflect.Interface { + scanValues[colIndex] = makeValue(colTypes[colIndex].ScanType()) + } else { + scanValues[colIndex] = makeValue(field.Type) + } + } else { + fieldInfos[colIndex].index = nil + scanValues[colIndex] = makeValue(nil) + } + } + } else if rowType.Kind() == reflect.Map { + for colIndex := range colTypes { + if rowType.Elem().Kind() == reflect.Interface { + scanValues[colIndex] = makeValue(colTypes[colIndex].ScanType()) + } else { + scanValues[colIndex] = makeValue(rowType.Elem()) + } + } + } else if rowType.Kind() == reflect.Slice { + for colIndex := range colTypes { + if rowType.Elem().Kind() == reflect.Interface { + scanValues[colIndex] = makeValue(colTypes[colIndex].ScanType()) + } else { + scanValues[colIndex] = makeValue(rowType.Elem()) + } + } + } else { + if rowType.Kind() == reflect.Interface { + scanValues[0] = makeValue(colTypes[0].ScanType()) + } else { + scanValues[0] = makeValue(rowType) + } + for colIndex := 1; colIndex < colNum; colIndex++ { + scanValues[colIndex] = makeValue(nil) + } + } + + var data reflect.Value + isNew := true + for rows.Next() { + err = rows.Scan(scanValues...) + if err != nil { + return err + } + if rowType.Kind() == reflect.Struct { + if resultsValue.Kind() == reflect.Slice { + data = reflect.New(rowType).Elem() + } else { + data = resultsValue + isNew = false + } + + for colIndex, col := range colTypes { + fInfo := fieldInfos[colIndex] + if fInfo.index == nil { + continue + } + field := data.FieldByIndex(fInfo.index) + valuePtr := reflect.ValueOf(scanValues[colIndex]).Elem() + if !valuePtr.IsNil() { + val := valuePtr.Elem() + if fInfo.typ.String() == "time.Time" { + str := val.String() + tm, err := time.Parse("2006-01-02 15:04:05.000000", str) + if err != nil { + tm, err = time.Parse("2006-01-02 15:04:05", str) + } + if err == nil { + field.Set(reflect.ValueOf(tm)) + } + } else if val.Kind() != field.Kind() && field.Kind() != reflect.Interface { + if field.Kind() == reflect.Ptr && val.Kind() == field.Type().Elem().Kind() { + if val.CanAddr() { + if field.Type().AssignableTo(val.Type()) { + field.Set(val.Addr()) + } else if val.Type().String() == "string" { + strVal := fixValue(col.DatabaseTypeName(), val) + field.Set(reflect.New(field.Type().Elem())) + field.Elem().SetString(cast.String(strVal.Interface())) + } else if strings.Contains(field.Type().String(), "uint") { + field.Set(reflect.New(field.Type().Elem())) + field.Elem().SetUint(cast.Uint64(val.Interface())) + } else if strings.Contains(field.Type().String(), "int") { + field.Set(reflect.New(field.Type().Elem())) + field.Elem().SetInt(cast.Int64(val.Interface())) + } else if strings.Contains(field.Type().String(), "float") { + field.Set(reflect.New(field.Type().Elem())) + field.Elem().SetFloat(cast.Float64(val.Interface())) + } else { + field.Set(val.Addr()) + } + } + } else { + convertedObject := reflect.New(field.Type()) + if s, ok := val.Interface().(string); ok { + storedValue := new(any) + if s != "" { + _ = json.Unmarshal([]byte(s), storedValue) + } + convert.To(storedValue, convertedObject.Interface()) + field.Set(convertedObject.Elem()) + } else { + convert.To(val.Interface(), convertedObject.Interface()) + } + } + } else if field.Type().AssignableTo(val.Type()) { + if val.Kind() == reflect.String { + field.Set(fixValue(col.DatabaseTypeName(), val)) + } else { + field.Set(val) + } + } else if val.Type().String() == "string" { + field.Set(fixValue(col.DatabaseTypeName(), val)) + } else if strings.Contains(val.Type().String(), "int") { + field.SetInt(val.Int()) + } else if strings.Contains(val.Type().String(), "float") { + field.SetFloat(val.Float()) + } else { + field.Set(val) + } + } + } + } else if rowType.Kind() == reflect.Map { + if resultsValue.Kind() == reflect.Slice { + data = reflect.MakeMap(rowType) + } else { + data = resultsValue + isNew = false + } + for colIndex, col := range colTypes { + valuePtr := reflect.ValueOf(scanValues[colIndex]).Elem() + if !valuePtr.IsNil() { + data.SetMapIndex(reflect.ValueOf(col.Name()), fixValue(col.DatabaseTypeName(), valuePtr.Elem())) + } else { + data.SetMapIndex(reflect.ValueOf(col.Name()), fixValue(col.DatabaseTypeName(), reflect.New(rowType.Elem()).Elem())) + } + } + } else if rowType.Kind() == reflect.Slice { + data = reflect.MakeSlice(rowType, colNum, colNum) + for colIndex, col := range colTypes { + valuePtr := reflect.ValueOf(scanValues[colIndex]).Elem() + if !valuePtr.IsNil() { + data.Index(colIndex).Set(fixValue(col.DatabaseTypeName(), valuePtr.Elem())) + } else { + data.Index(colIndex).Set(fixValue(col.DatabaseTypeName(), reflect.New(rowType.Elem()).Elem())) + } + } + } else { + valuePtr := reflect.ValueOf(scanValues[0]).Elem() + if !valuePtr.IsNil() { + data = fixValue(colTypes[0].DatabaseTypeName(), valuePtr.Elem()) + } + } + + if resultsValue.Kind() == reflect.Slice { + if originRowType.Kind() == reflect.Ptr { + resultsValue = reflect.Append(resultsValue, data.Addr()) + } else { + resultsValue = reflect.Append(resultsValue, data) + } + } else { + resultsValue = data + break + } + } + + if isNew && resultsValue.IsValid() { + reflect.ValueOf(results).Elem().Set(resultsValue) + } + return nil +} + +func fixValue(colType string, v reflect.Value) reflect.Value { + if v.Kind() == reflect.String { + str := v.String() + switch colType { + case "DATE": + if len(str) >= 10 && str[4] == '-' && str[7] == '-' { + return reflect.ValueOf(str[:10]) + } + case "DATETIME": + if len(str) >= 19 && str[10] == 'T' && str[4] == '-' && str[7] == '-' && str[13] == ':' && str[16] == ':' { + str = strings.TrimRight(str, "Z") + if len(str) > 19 && str[19] == '.' { + return reflect.ValueOf(str[:10] + " " + str[11:]) + } + return reflect.ValueOf(str[:10] + " " + str[11:19]) + } + case "TIME": + if len(str) >= 8 && str[2] == ':' && str[4] == ':' { + if len(str) >= 15 && str[8] == '.' { + return reflect.ValueOf(str[0:15]) + } + return reflect.ValueOf(str[0:8]) + } + } + } + return v +} + +func (r *QueryResult) getColumnTypes() ([]*sql.ColumnType, error) { + if r.rows == nil { + return nil, errors.New("not a valid query result") + } + + return r.rows.ColumnTypes() +} + +func makePublicVarName(name string) string { + if len(name) > 0 && name[0] >= 'a' && name[0] <= 'z' { + return string(name[0]-32) + name[1:] + } + return name +} + +func makeValue(t reflect.Type) any { + if t == nil { + return new(*string) + } + for t.Kind() == reflect.Ptr { + t = t.Elem() + } + + switch t.Kind() { + case reflect.Int: + return new(*int) + case reflect.Int8: + return new(*int8) + case reflect.Int16: + return new(*int16) + case reflect.Int32: + return new(*int32) + case reflect.Int64: + return new(*int64) + case reflect.Uint: + return new(*uint) + case reflect.Uint8: + return new(*uint8) + case reflect.Uint16: + return new(*uint16) + case reflect.Uint32: + return new(*uint32) + case reflect.Uint64: + return new(*uint64) + case reflect.Float32: + return new(*float32) + case reflect.Float64: + return new(*float64) + case reflect.Bool: + return new(*bool) + case reflect.String: + return new(*string) + } + + if t.Kind() == reflect.Slice && t.Elem().Kind() == reflect.Uint8 { + return new(*[]byte) + } + + return new(*string) +} diff --git a/SSL.go b/SSL.go new file mode 100644 index 0000000..cefa4d3 --- /dev/null +++ b/SSL.go @@ -0,0 +1,28 @@ +package db + +import ( + "crypto/tls" + "crypto/x509" + + "apigo.cc/go/log" +) + +func BuildTLSConfig(ca, cert, key []byte, insecure bool) *tls.Config { + caPool := x509.NewCertPool() + if !caPool.AppendCertsFromPEM(ca) { + log.DefaultLogger.Error("ca error for db") + return nil + } + + certs, err := tls.X509KeyPair(cert, key) + if err != nil { + log.DefaultLogger.Error(err.Error()) + return nil + } + + return &tls.Config{ + Certificates: []tls.Certificate{certs}, + RootCAs: caPool, + InsecureSkipVerify: insecure, + } +} diff --git a/Stmt.go b/Stmt.go new file mode 100644 index 0000000..09be20c --- /dev/null +++ b/Stmt.go @@ -0,0 +1,44 @@ +package db + +import ( + "database/sql" + "errors" + "time" + + "apigo.cc/go/log" +) + +type Stmt struct { + conn *sql.Stmt + lastSql *string + lastArgs []any + Error error + logger *dbLogger +} + +func (stmt *Stmt) Exec(args ...any) *ExecResult { + stmt.lastArgs = args + if stmt.conn == nil { + return &ExecResult{Sql: stmt.lastSql, Args: stmt.lastArgs, usedTime: -1, logger: stmt.logger, Error: errors.New("operate on a bad connection")} + } + startTime := time.Now() + r, err := stmt.conn.Exec(args...) + endTime := time.Now() + usedTime := log.MakeUsedTime(startTime, endTime) + if err != nil { + stmt.logger.LogQueryError(err.Error(), *stmt.lastSql, stmt.lastArgs, usedTime) + return &ExecResult{Sql: stmt.lastSql, Args: stmt.lastArgs, usedTime: usedTime, logger: stmt.logger, Error: err} + } + return &ExecResult{Sql: stmt.lastSql, Args: stmt.lastArgs, usedTime: usedTime, logger: stmt.logger, result: r} +} + +func (stmt *Stmt) Close() error { + if stmt.conn == nil { + return errors.New("operate on a bad connection") + } + err := stmt.conn.Close() + if err != nil { + stmt.logger.LogQueryError(err.Error(), *stmt.lastSql, stmt.lastArgs, -1) + } + return err +} diff --git a/TEST.md b/TEST.md new file mode 100644 index 0000000..1b861e9 --- /dev/null +++ b/TEST.md @@ -0,0 +1,23 @@ +# Test Results for @go/db + +## 📊 Summary +- **Module**: `apigo.cc/go/db` +- **Total Tests**: 4 +- **Passed**: 4 +- **Failed**: 0 +- **Build Status**: Success +- **Date**: 2026-05-03 + +## ✅ Details +| Test Case | Status | Duration | Notes | +| :--- | :--- | :--- | :--- | +| `TestMakeInsertSql` | PASS | 0.00s | Verified SQL generation logic for Struct models | +| `TestBaseSelect` | PASS | 0.00s | Verified Result binding (Struct, Map, Base types) | +| `TestInsertReplaceUpdateDelete` | PASS | 0.01s | Verified CRUD operations with SQLite | +| `TestTransaction` | PASS | 0.03s | Verified Transaction isolation and Rollback/Commit | + +## 🚀 Benchmarks +| Benchmark | Iterations | Time/op | Conn | +| :--- | :--- | :--- | :--- | +| `BenchmarkForPool` | - | - | Passed (Manual run verified pool reuse) | +| `BenchmarkForPoolParallel` | - | - | Passed (Manual run verified high concurrency) | diff --git a/Tx.go b/Tx.go new file mode 100644 index 0000000..9fe2da7 --- /dev/null +++ b/Tx.go @@ -0,0 +1,183 @@ +package db + +import ( + "database/sql" + "errors" + "fmt" + "time" +) + +type Tx struct { + conn *sql.Tx + lastSql *string + lastArgs []any + Error error + logger *dbLogger + logSlow time.Duration + isCommittedOrRollbacked bool + QuoteTag string +} + +func (tx *Tx) Quote(text string) string { + return quote(tx.QuoteTag, text) +} + +func (tx *Tx) Quotes(texts []string) string { + return quotes(tx.QuoteTag, texts) +} + +func (tx *Tx) Commit() error { + if tx.isCommittedOrRollbacked { + return nil + } + if tx.conn == nil { + return errors.New("operate on a bad connection") + } + err := tx.conn.Commit() + if err != nil { + tx.logger.LogQueryError(err.Error(), *tx.lastSql, tx.lastArgs, -1) + } else { + tx.isCommittedOrRollbacked = true + } + return err +} + +func (tx *Tx) Rollback() error { + if tx.isCommittedOrRollbacked { + return nil + } + if tx.conn == nil { + return errors.New("operate on a bad connection") + } + err := tx.conn.Rollback() + if err != nil { + tx.logger.LogQueryError(err.Error(), *tx.lastSql, tx.lastArgs, -1) + } else { + tx.isCommittedOrRollbacked = true + } + return err +} + +func (tx *Tx) Finish(ok bool) error { + if tx.isCommittedOrRollbacked { + return nil + } + if ok { + return tx.Commit() + } + return tx.Rollback() +} + +func (tx *Tx) CheckFinished() error { + if tx.isCommittedOrRollbacked { + return nil + } + return tx.Rollback() +} + +func (tx *Tx) Prepare(query string) *Stmt { + tx.lastSql = &query + r := basePrepare(nil, tx.conn, query) + r.logger = tx.logger + if r.Error != nil { + tx.logger.LogQueryError(r.Error.Error(), *tx.lastSql, tx.lastArgs, -1) + } + return r +} + +func (tx *Tx) Exec(query string, args ...any) *ExecResult { + tx.lastSql = &query + tx.lastArgs = args + r := baseExec(nil, tx.conn, query, args...) + r.logger = tx.logger + if r.Error != nil { + tx.logger.LogQueryError(r.Error.Error(), *tx.lastSql, tx.lastArgs, r.usedTime) + } else { + if tx.logSlow > 0 && r.usedTime >= float32(tx.logSlow/time.Millisecond) { + tx.logger.LogQuery(*tx.lastSql, tx.lastArgs, r.usedTime) + } + } + return r +} + +func (tx *Tx) Query(query string, args ...any) *QueryResult { + tx.lastSql = &query + tx.lastArgs = args + r := baseQuery(nil, tx.conn, query, args...) + r.logger = tx.logger + if r.Error != nil { + tx.logger.LogQueryError(r.Error.Error(), *tx.lastSql, tx.lastArgs, r.usedTime) + } else { + if tx.logSlow > 0 && r.usedTime >= float32(tx.logSlow/time.Millisecond) { + tx.logger.LogQuery(*tx.lastSql, tx.lastArgs, r.usedTime) + } + } + return r +} + +func (tx *Tx) Insert(table string, data any) *ExecResult { + query, values := tx.MakeInsertSql(table, data, false) + tx.lastSql = &query + tx.lastArgs = values + r := baseExec(nil, tx.conn, query, values...) + r.logger = tx.logger + if r.Error != nil { + tx.logger.LogQueryError(r.Error.Error(), *tx.lastSql, tx.lastArgs, r.usedTime) + } else { + if tx.logSlow > 0 && r.usedTime >= float32(tx.logSlow/time.Millisecond) { + tx.logger.LogQuery(*tx.lastSql, tx.lastArgs, r.usedTime) + } + } + return r +} + +func (tx *Tx) Replace(table string, data any) *ExecResult { + query, values := tx.MakeInsertSql(table, data, true) + tx.lastSql = &query + tx.lastArgs = values + r := baseExec(nil, tx.conn, query, values...) + r.logger = tx.logger + if r.Error != nil { + tx.logger.LogQueryError(r.Error.Error(), *tx.lastSql, tx.lastArgs, r.usedTime) + } else { + if tx.logSlow > 0 && r.usedTime >= float32(tx.logSlow/time.Millisecond) { + tx.logger.LogQuery(*tx.lastSql, tx.lastArgs, r.usedTime) + } + } + return r +} + +func (tx *Tx) Update(table string, data any, conditions string, args ...any) *ExecResult { + query, values := tx.MakeUpdateSql(table, data, conditions, args...) + tx.lastSql = &query + tx.lastArgs = values + r := baseExec(nil, tx.conn, query, values...) + r.logger = tx.logger + if r.Error != nil { + tx.logger.LogQueryError(r.Error.Error(), *tx.lastSql, tx.lastArgs, r.usedTime) + } else { + if tx.logSlow > 0 && r.usedTime >= float32(tx.logSlow/time.Millisecond) { + tx.logger.LogQuery(*tx.lastSql, tx.lastArgs, r.usedTime) + } + } + return r +} + +func (tx *Tx) Delete(table string, conditions string, args ...any) *ExecResult { + if conditions != "" { + conditions = " where " + conditions + } + query := fmt.Sprintf("delete from %s%s", tx.Quote(table), conditions) + tx.lastSql = &query + tx.lastArgs = args + r := baseExec(nil, tx.conn, query, args...) + r.logger = tx.logger + if r.Error != nil { + tx.logger.LogQueryError(r.Error.Error(), *tx.lastSql, tx.lastArgs, r.usedTime) + } else { + if tx.logSlow > 0 && r.usedTime >= float32(tx.logSlow/time.Millisecond) { + tx.logger.LogQuery(*tx.lastSql, tx.lastArgs, r.usedTime) + } + } + return r +} diff --git a/db.json.sample b/db.json.sample new file mode 100644 index 0000000..6e9ed36 --- /dev/null +++ b/db.json.sample @@ -0,0 +1,12 @@ +{ + "test": { + "type": "mysql", + "user": "star", + "password": "...", + "host": "localhost:3306", + "db": "test", + "maxOpens": 100, + "maxIdles": 30, + "maxLifeTime": 3600 + } +} diff --git a/dbInit.go.sample b/dbInit.go.sample new file mode 100644 index 0000000..1415c4d --- /dev/null +++ b/dbInit.go.sample @@ -0,0 +1,11 @@ +package main + +import "apigo.cc/go/db" + +func init() { + // 强烈建议在应用启动时设置自定义加密密钥,确保数据库密码存储安全 + // key 和 iv 建议通过环境变量或安全存储获取,不要直接硬编码在生产代码中 + key := []byte("your-32-byte-long-secure-key----") + iv := []byte("your-16-byte-iv-") + db.SetEncryptKeys(key, iv) +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..db663a7 --- /dev/null +++ b/go.mod @@ -0,0 +1,43 @@ +module apigo.cc/go/db + +go 1.25.0 + +require ( + apigo.cc/go/cast v1.1.1 + apigo.cc/go/config v1.0.4 + apigo.cc/go/convert v1.0.4 + apigo.cc/go/crypto v1.0.4 + apigo.cc/go/id v1.0.4 + apigo.cc/go/log v1.0.0 + apigo.cc/go/rand v1.0.4 + apigo.cc/go/safe v1.0.4 + apigo.cc/go/shell v1.0.4 + github.com/go-sql-driver/mysql v1.10.0 + github.com/jackc/pgx/v5 v5.9.2 + github.com/mitchellh/mapstructure v1.5.0 + modernc.org/sqlite v1.50.0 +) + +require ( + apigo.cc/go/encoding v1.0.4 // indirect + apigo.cc/go/file v1.0.4 // indirect + filippo.io/edwards25519 v1.2.0 // indirect + github.com/dustin/go-humanize v1.0.1 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect + github.com/jackc/puddle/v2 v2.2.2 // indirect + github.com/kr/text v0.2.0 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/ncruces/go-strftime v1.0.0 // indirect + github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect + github.com/rogpeppe/go-internal v1.14.1 // indirect + golang.org/x/crypto v0.50.0 // indirect + golang.org/x/sync v0.20.0 // indirect + golang.org/x/sys v0.43.0 // indirect + golang.org/x/text v0.36.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect + modernc.org/libc v1.72.0 // indirect + modernc.org/mathutil v1.7.1 // indirect + modernc.org/memory v1.11.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..4a5089d --- /dev/null +++ b/go.sum @@ -0,0 +1,114 @@ +apigo.cc/go/cast v1.1.1 h1:+5pluN8g1RK2J4byr2xkfOmEdKSmy1PByOqDOHtt/Ns= +apigo.cc/go/cast v1.1.1/go.mod h1:vh9ZqISCmTUiyinkNMI/s4f045fRlDK3xC+nPWQYBzI= +apigo.cc/go/config v1.0.4 h1:WG9zrQkqfFPkrKIL7RNvvAbbkuUBt1Av11ZP/aIfldM= +apigo.cc/go/config v1.0.4/go.mod h1:obryzJiK6j7lQex/58d5eWYOGx5O5IABguqNWxyyXJo= +apigo.cc/go/convert v1.0.4 h1:5+qPjC3dlPB59GnWZRlmthxcaXQtKvN+iOuiLdJ1GvQ= +apigo.cc/go/convert v1.0.4/go.mod h1:Hp+geeSyhqg/zwIKPOrDoceIREzcwM14t1I5q/dtbfU= +apigo.cc/go/crypto v1.0.4 h1:VPUyHCH2N3LLEgdpwUc+DQssNHzLlxVzLNRa0Jm6O4o= +apigo.cc/go/crypto v1.0.4/go.mod h1:5sI8BLw6YHZfDReYwCO3TFD2LKm36HMdLg1S5oPv/QU= +apigo.cc/go/encoding v1.0.4 h1:aezB0J/qFuHs6iXkbtuJP5JIHUtmjsr5SFb0NNvbObY= +apigo.cc/go/encoding v1.0.4/go.mod h1:V5CgT7rBbCxy+uCU20q0ptcNNRSgMtpA8cNOs6r8IeI= +apigo.cc/go/file v1.0.4 h1:qCKegV7OYh7r0qc3jZjGA/aKh0vIHgmr1OEbhfEmGX8= +apigo.cc/go/file v1.0.4/go.mod h1:C9gNo7386iA21OiBmuWh6CznKWlVBDFkhE4f0H0Susg= +apigo.cc/go/id v1.0.4 h1:w+JSdeVit52iefIUolrh1qLEZS9XqHNKr1UygFcgv+s= +apigo.cc/go/id v1.0.4/go.mod h1:kg7QuceAKtGNzGWt0+pIIh8Qom1eMSWGb8+0Yhi/QVY= +apigo.cc/go/log v1.0.0 h1:lI1NGTSS+Jm12G8BD7ZJO4/hrkfuLTu5O8z36GD8GpU= +apigo.cc/go/log v1.0.0/go.mod h1:tvPgFpebY9Wf/DlqMHZ0ZjxDp9AaQTywOQKvtBaNqNo= +apigo.cc/go/rand v1.0.4 h1:we070eWSL0dB8NEMaWjXj43+EekXQTm/h0kKpZ/frqw= +apigo.cc/go/rand v1.0.4/go.mod h1:mZ/4Soa3bk+XvDaqPWJuUe1bfEi4eThBj1XmEAuYxsk= +apigo.cc/go/safe v1.0.4 h1:07pRSdEHprF/2v6SsqAjICYFoeLcqjjvHGEdh6Dzrzg= +apigo.cc/go/safe v1.0.4/go.mod h1:o568sHS5rTRSVPmhxWod0tGdc+8l1KjidsNY1/OVZr0= +apigo.cc/go/shell v1.0.4 h1:EL9zjI39YBe1h+kRYQeAi/8zVGHe5W198DYYN7cENiY= +apigo.cc/go/shell v1.0.4/go.mod h1:N2gDkgK4tJ9TadD60/+gAGuWxyVAWHs5YPBmytw6ELA= +filippo.io/edwards25519 v1.2.0 h1:crnVqOiS4jqYleHd9vaKZ+HKtHfllngJIiOpNpoJsjo= +filippo.io/edwards25519 v1.2.0/go.mod h1:xzAOLCNug/yB62zG1bQ8uziwrIqIuxhctzJT18Q77mc= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/go-sql-driver/mysql v1.10.0 h1:Q+1LV8DkHJvSYAdR83XzuhDaTykuDx0l6fkXxoWCWfw= +github.com/go-sql-driver/mysql v1.10.0/go.mod h1:M+cqaI7+xxXGG9swrdeUIoPG3Y3KCkF0pZej+SK+nWk= +github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs= +github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= +github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.9.2 h1:3ZhOzMWnR4yJ+RW1XImIPsD1aNSz4T4fyP7zlQb56hw= +github.com/jackc/pgx/v5 v5.9.2/go.mod h1:mal1tBGAFfLHvZzaYh77YS/eC6IX9OWbRV1QIIM0Jn4= +github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= +github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= +github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= +github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= +github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= +github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= +github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI= +golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q= +golang.org/x/mod v0.34.0 h1:xIHgNUUnW6sYkcM5Jleh05DvLOtwc6RitGHbDk4akRI= +golang.org/x/mod v0.34.0/go.mod h1:ykgH52iCZe79kzLLMhyCUzhMci+nQj+0XkbXpNYtVjY= +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI= +golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg= +golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164= +golang.org/x/tools v0.43.0 h1:12BdW9CeB3Z+J/I/wj34VMl8X+fEXBxVR90JeMX5E7s= +golang.org/x/tools v0.43.0/go.mod h1:uHkMso649BX2cZK6+RpuIPXS3ho2hZo4FVwfoy1vIk0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +modernc.org/cc/v4 v4.27.3 h1:uNCgn37E5U09mTv1XgskEVUJ8ADKpmFMPxzGJ0TSo+U= +modernc.org/cc/v4 v4.27.3/go.mod h1:3YjcbCqhoTTHPycJDRl2WZKKFj0nwcOIPBfEZK0Hdk8= +modernc.org/ccgo/v4 v4.32.4 h1:L5OB8rpEX4ZsXEQwGozRfJyJSFHbbNVOoQ59DU9/KuU= +modernc.org/ccgo/v4 v4.32.4/go.mod h1:lY7f+fiTDHfcv6YlRgSkxYfhs+UvOEEzj49jAn2TOx0= +modernc.org/fileutil v1.4.0 h1:j6ZzNTftVS054gi281TyLjHPp6CPHr2KCxEXjEbD6SM= +modernc.org/fileutil v1.4.0/go.mod h1:EqdKFDxiByqxLk8ozOxObDSfcVOv/54xDs/DUHdvCUU= +modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI= +modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito= +modernc.org/gc/v3 v3.1.2 h1:ZtDCnhonXSZexk/AYsegNRV1lJGgaNZJuKjJSWKyEqo= +modernc.org/gc/v3 v3.1.2/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY= +modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks= +modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI= +modernc.org/libc v1.72.0 h1:IEu559v9a0XWjw0DPoVKtXpO2qt5NVLAnFaBbjq+n8c= +modernc.org/libc v1.72.0/go.mod h1:tTU8DL8A+XLVkEY3x5E/tO7s2Q/q42EtnNWda/L5QhQ= +modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU= +modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg= +modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI= +modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw= +modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8= +modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns= +modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w= +modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE= +modernc.org/sqlite v1.50.0 h1:eMowQSWLK0MeiQTdmz3lqoF5dqclujdlIKeJA11+7oM= +modernc.org/sqlite v1.50.0/go.mod h1:m0w8xhwYUVY3H6pSDwc3gkJ/irZT/0YEXwBlhaxQEew= +modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0= +modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A= +modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= +modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= diff --git a/mysql/connector.go b/mysql/connector.go new file mode 100644 index 0000000..d0a00d8 --- /dev/null +++ b/mysql/connector.go @@ -0,0 +1,48 @@ +package mysql + +import ( + "context" + "crypto/tls" + "database/sql/driver" + + "github.com/go-sql-driver/mysql" + "apigo.cc/go/db" + "apigo.cc/go/safe" +) + +type SecureMysqlConnector struct { + conf *db.Config + pwd *safe.SafeBuf + tls *tls.Config +} + +func (c *SecureMysqlConnector) Connect(ctx context.Context) (driver.Conn, error) { + cfg, err := mysql.ParseDSN(c.conf.Args) + if err != nil { + cfg = mysql.NewConfig() + } + cfg.User = c.conf.User + cfg.Net = "tcp" + cfg.Addr = c.conf.Host + cfg.DBName = c.conf.DB + cfg.TLS = c.tls + pwdBuf := c.pwd.Open() + defer pwdBuf.Close() + cfg.Passwd = pwdBuf.String() + + tmpConnector, err := mysql.NewConnector(cfg) + if err != nil { + return nil, err + } + return tmpConnector.Connect(ctx) +} + +func (c *SecureMysqlConnector) Driver() driver.Driver { + return &mysql.MySQLDriver{} +} + +func init() { + db.RegisterConnector("mysql", func(conf *db.Config, pwd *safe.SafeBuf, tls *tls.Config) driver.Connector { + return &SecureMysqlConnector{conf: conf, pwd: pwd, tls: tls} + }) +} diff --git a/pgx/connector.go b/pgx/connector.go new file mode 100644 index 0000000..ab78579 --- /dev/null +++ b/pgx/connector.go @@ -0,0 +1,48 @@ +package pgx + +import ( + "context" + "crypto/tls" + "database/sql/driver" + "fmt" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/stdlib" + "apigo.cc/go/db" + "apigo.cc/go/safe" +) + +type SecurePgxConnector struct { + conf *db.Config + pwd *safe.SafeBuf + tls *tls.Config +} + +func (c *SecurePgxConnector) Connect(ctx context.Context) (driver.Conn, error) { + dsn := fmt.Sprintf("postgres://%s@%s/%s?%s", c.conf.User, c.conf.Host, c.conf.DB, c.conf.Args) + pgxConfig, err := pgx.ParseConfig(dsn) + if err != nil { + return nil, err + } + if c.tls != nil { + pgxConfig.TLSConfig = c.tls + } + pwdBuf := c.pwd.Open() + defer pwdBuf.Close() + pgxConfig.Password = pwdBuf.String() + tmpConnector := stdlib.GetConnector(*pgxConfig) + return tmpConnector.Connect(ctx) +} + +func (c *SecurePgxConnector) Driver() driver.Driver { + return stdlib.GetDefaultDriver() +} + +func PgxConnector(conf *db.Config, pwd *safe.SafeBuf, tls *tls.Config) driver.Connector { + return &SecurePgxConnector{conf: conf, pwd: pwd, tls: tls} +} + +func init() { + db.RegisterConnector("postgres", PgxConnector) + db.RegisterConnector("pgx", PgxConnector) +} diff --git a/sqlite.go b/sqlite.go new file mode 100644 index 0000000..6537ca5 --- /dev/null +++ b/sqlite.go @@ -0,0 +1,5 @@ +package db + +import ( + _ "modernc.org/sqlite" +) diff --git a/test.db b/test.db new file mode 100644 index 0000000..703d4c7 Binary files /dev/null and b/test.db differ