From 6c2b2fed4d0766882c29872e7544b8fb3480cd13 Mon Sep 17 00:00:00 2001 From: AI Engineer Date: Sun, 3 May 2026 14:08:46 +0800 Subject: [PATCH] Optimize db module: pre-calculate field mappings in makeResults, fix typos, and standardize naming (by AI) --- AI.md | 16 ++ Base.go | 213 ++++++++++++++++ CHANGELOG.md | 12 + DB.go | 611 +++++++++++++++++++++++++++++++++++++++++++++ DB_test.go | 477 +++++++++++++++++++++++++++++++++++ README.md | 66 ++++- Result.go | 577 ++++++++++++++++++++++++++++++++++++++++++ SSL.go | 28 +++ Stmt.go | 44 ++++ TEST.md | 23 ++ Tx.go | 183 ++++++++++++++ db.json.sample | 12 + dbInit.go.sample | 11 + go.mod | 43 ++++ go.sum | 114 +++++++++ mysql/connector.go | 48 ++++ pgx/connector.go | 48 ++++ sqlite.go | 5 + test.db | Bin 0 -> 12288 bytes 19 files changed, 2529 insertions(+), 2 deletions(-) create mode 100644 AI.md create mode 100644 Base.go create mode 100644 CHANGELOG.md create mode 100644 DB.go create mode 100644 DB_test.go create mode 100644 Result.go create mode 100644 SSL.go create mode 100644 Stmt.go create mode 100644 TEST.md create mode 100644 Tx.go create mode 100644 db.json.sample create mode 100644 dbInit.go.sample create mode 100644 go.mod create mode 100644 go.sum create mode 100644 mysql/connector.go create mode 100644 pgx/connector.go create mode 100644 sqlite.go create mode 100644 test.db diff --git a/AI.md b/AI.md new file mode 100644 index 0000000..bdd3708 --- /dev/null +++ b/AI.md @@ -0,0 +1,16 @@ +# AI 指南 - @go/db + +## 🤖 AI 调用规则 +- **版本**: v1.0.1 +- **核心原则**: 优先使用结构化绑定(`To`, `MapResults`),避免手动拼装 SQL 结果。 +- **敏感数据**: 必须通过 `SetEncryptKeys` 配置密钥,确保 DSN 中的密码安全。 +- **读写分离**: 鼓励在 DSN 中配置多个 Host 以利用内置的读写分离机制。 +- **性能优化**: + - 大规模查询应优先绑定到 Struct 切片。 + - 频繁执行的 SQL 应使用 `Prepare`。 +- **事务处理**: 始终使用 `tx.Finish(err == nil)` 或 `defer tx.CheckFinished()` 确保事务闭环。 + +## ⚠️ 注意事项 +- 严禁在代码中硬编码数据库凭据。 +- 严禁忽略 `Exec` 或 `Query` 返回的 `Error`。 +- SQLite 模式下,时间字段会自动转换,无需手动解析字符串。 diff --git a/Base.go b/Base.go new file mode 100644 index 0000000..3fce087 --- /dev/null +++ b/Base.go @@ -0,0 +1,213 @@ +package db + +import ( + "database/sql" + "errors" + "fmt" + "reflect" + "strings" + "time" + + "apigo.cc/go/cast" + "apigo.cc/go/log" +) + +func basePrepare(db *sql.DB, tx *sql.Tx, query string) *Stmt { + var sqlStmt *sql.Stmt + var err error + if tx != nil { + sqlStmt, err = tx.Prepare(query) + } else if db != nil { + sqlStmt, err = db.Prepare(query) + } else { + return &Stmt{Error: errors.New("operate on a bad connection")} + } + if err != nil { + return &Stmt{Error: err} + } + return &Stmt{conn: sqlStmt, lastSql: &query} +} + +func baseExec(db *sql.DB, tx *sql.Tx, query string, args ...any) *ExecResult { + args = flatArgs(args) + var r sql.Result + var err error + startTime := time.Now() + if tx != nil { + r, err = tx.Exec(query, args...) + } else if db != nil { + r, err = db.Exec(query, args...) + } else { + return &ExecResult{Sql: &query, Args: args, usedTime: log.MakeUsedTime(startTime, time.Now()), Error: errors.New("operate on a bad connection")} + } + endTime := time.Now() + usedTime := log.MakeUsedTime(startTime, endTime) + + if err != nil { + return &ExecResult{Sql: &query, Args: args, usedTime: usedTime, Error: err} + } + return &ExecResult{Sql: &query, Args: args, usedTime: usedTime, result: r} +} + +func flatArgs(args []any) []any { + for i, arg := range args { + if arg == nil { + continue + } + argValue := reflect.ValueOf(arg) + kind := argValue.Kind() + if kind == reflect.Map || kind == reflect.Struct || (kind == reflect.Slice && argValue.Type().Elem().Kind() != reflect.Uint8) { + args[i] = cast.MustToJSON(arg) + } + } + return args +} + +func baseQuery(db *sql.DB, tx *sql.Tx, query string, args ...any) *QueryResult { + args = flatArgs(args) + + var rows *sql.Rows + var err error + startTime := time.Now() + if tx != nil { + rows, err = tx.Query(query, args...) + } else if db != nil { + rows, err = db.Query(query, args...) + } else { + return &QueryResult{Sql: &query, Args: args, usedTime: log.MakeUsedTime(startTime, time.Now()), Error: errors.New("operate on a bad connection")} + } + endTime := time.Now() + usedTime := log.MakeUsedTime(startTime, endTime) + + if err != nil { + return &QueryResult{Sql: &query, Args: args, usedTime: usedTime, Error: err} + } + return &QueryResult{Sql: &query, Args: args, usedTime: usedTime, rows: rows} +} + +func quote(quoteTag string, text string) string { + a := strings.Split(text, ".") + for i, v := range a { + a[i] = quoteTag + strings.ReplaceAll(v, quoteTag, "\\"+quoteTag) + quoteTag + } + return strings.Join(a, ".") +} + +func quotes(quoteTag string, texts []string) string { + for i, v := range texts { + texts[i] = quote(quoteTag, v) + } + return strings.Join(texts, ",") +} + +func makeInsertSql(quoteTag string, table string, data any, useReplace bool) (string, []any) { + keys, vars, values := MakeKeysVarsValues(data) + operation := "insert" + if useReplace { + operation = "replace" + } + query := fmt.Sprintf("%s into %s (%s) values (%s)", operation, quote(quoteTag, table), quotes(quoteTag, keys), strings.Join(vars, ",")) + return query, values +} + +func makeUpdateSql(quoteTag string, table string, data any, conditions string, args ...any) (string, []any) { + args = flatArgs(args) + keys, vars, values := MakeKeysVarsValues(data) + for i, k := range keys { + keys[i] = fmt.Sprintf("%s=%s", quote(quoteTag, k), vars[i]) + } + values = append(values, args...) + if conditions != "" { + conditions = " where " + conditions + } + query := fmt.Sprintf("update %s set %s%s", quote(quoteTag, table), strings.Join(keys, ","), conditions) + return query, values +} + +func (db *DB) MakeInsertSql(table string, data any, useReplace bool) (string, []any) { + return makeInsertSql(db.QuoteTag, table, data, useReplace) +} + +func (db *DB) MakeUpdateSql(table string, data any, conditions string, args ...any) (string, []any) { + return makeUpdateSql(db.QuoteTag, table, data, conditions, args...) +} + +func (tx *Tx) MakeInsertSql(table string, data any, useReplace bool) (string, []any) { + return makeInsertSql(tx.QuoteTag, table, data, useReplace) +} + +func (tx *Tx) MakeUpdateSql(table string, data any, conditions string, args ...any) (string, []any) { + return makeUpdateSql(tx.QuoteTag, table, data, conditions, args...) +} + +func getFlatFields(fields map[string]reflect.Value, fieldKeys *[]string, value reflect.Value) { + valueType := value.Type() + for i := 0; i < value.NumField(); i++ { + v := value.Field(i) + if valueType.Field(i).Anonymous { + getFlatFields(fields, fieldKeys, v) + } else { + *fieldKeys = append(*fieldKeys, valueType.Field(i).Name) + fields[valueType.Field(i).Name] = v + } + } +} + +func MakeKeysVarsValues(data any) ([]string, []string, []any) { + keys := make([]string, 0) + vars := make([]string, 0) + values := make([]any, 0) + + dataType := reflect.TypeOf(data) + dataValue := reflect.ValueOf(data) + for dataType.Kind() == reflect.Ptr { + dataType = dataType.Elem() + dataValue = dataValue.Elem() + } + + if dataType.Kind() == reflect.Struct { + fields := make(map[string]reflect.Value) + fieldKeys := make([]string, 0) + getFlatFields(fields, &fieldKeys, dataValue) + for _, k := range fieldKeys { + if k[0] >= 'a' && k[0] <= 'z' { + continue + } + v := fields[k] + if v.Kind() == reflect.Interface { + v = v.Elem() + } + keys = append(keys, k) + if v.Kind() == reflect.String && v.Len() > 0 && v.String()[0] == ':' { + vars = append(vars, v.String()[1:]) + } else { + vars = append(vars, "?") + if !v.IsValid() || !v.CanInterface() { + values = append(values, nil) + } else { + values = append(values, v.Interface()) + } + } + } + } else if dataType.Kind() == reflect.Map { + for _, k := range dataValue.MapKeys() { + v := dataValue.MapIndex(k) + if v.Kind() == reflect.Interface { + v = v.Elem() + } + keys = append(keys, cast.String(k.Interface())) + if v.Kind() == reflect.String && v.Len() > 0 && v.String()[0] == ':' { + vars = append(vars, v.String()[1:]) + } else { + vars = append(vars, "?") + if !v.IsValid() || !v.CanInterface() { + values = append(values, nil) + } else { + values = append(values, v.Interface()) + } + } + } + } + + return keys, vars, values +} diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..8dbf097 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,12 @@ +# CHANGELOG - @go/db + +## [1.0.1] - 2026-05-03 +### Optimized +- Refactored `makeResults` to pre-calculate field mappings for structs, significantly improving performance for large result sets. +- Simplified and optimized `makeValue` and `makePublicVarName` functions. +- Optimized time parsing in `makeResults`. + +### Fixed +- Fixed typo `isCommitedOrRollbacked` to `isCommittedOrRollbacked` in `Tx` struct. +- Standardized parameter naming: renamed `requestSql` to `query` and `wheres` to `conditions` across the module. +- Modernized Go syntax to align with latest standards. diff --git a/DB.go b/DB.go new file mode 100644 index 0000000..bd6387b --- /dev/null +++ b/DB.go @@ -0,0 +1,611 @@ +package db + +import ( + "crypto/tls" + "database/sql" + "database/sql/driver" + "encoding/base64" + "errors" + "fmt" + "net/url" + "regexp" + "strings" + "sync" + "time" + + "apigo.cc/go/cast" + "apigo.cc/go/config" + "apigo.cc/go/crypto" + "apigo.cc/go/id" + "apigo.cc/go/log" + "apigo.cc/go/rand" + "apigo.cc/go/safe" +) + +type Config struct { + Type string + User string + Password string + pwd *safe.SafeBuf + Host string + ReadonlyHosts []string + DB string + SSL string + tls *tls.Config + Args string + MaxOpens int + MaxIdles int + MaxLifeTime int + LogSlow config.Duration + logger *log.Logger +} + +type SSL struct { + Ca string + Cert string + Key string + Insecure bool +} + +type ConnectorFunc func(*Config, *safe.SafeBuf, *tls.Config) driver.Connector + +var connectors = map[string]ConnectorFunc{} + +func RegisterConnector(typ string, conn ConnectorFunc) { + connectors[typ] = conn +} + +var sqlite3PwdMatcher = regexp.MustCompile(`x'\w+'`) + +func (dbInfo *Config) Dsn() string { + args := make([]string, 0) + if dbInfo.SSL != "" { + args = append(args, "tls="+dbInfo.SSL) + } + if dbInfo.Args != "" { + args = append(args, dbInfo.Args) + } + argsStr := "" + if len(args) > 0 { + argsStr = "&" + strings.Join(args, "&") + } + + if isFileDB(dbInfo.Type) { + argsStr = sqlite3PwdMatcher.ReplaceAllString(argsStr, "******") + return fmt.Sprintf("%s://%s?logSlow=%s"+argsStr, dbInfo.Type, dbInfo.Host, dbInfo.LogSlow.TimeDuration()) + } else { + hosts := []string{dbInfo.Host} + if dbInfo.ReadonlyHosts != nil { + hosts = append(hosts, dbInfo.ReadonlyHosts...) + } + return fmt.Sprintf("%s://%s:****@%s/%s?logSlow=%s"+argsStr, dbInfo.Type, dbInfo.User, strings.Join(hosts, ","), dbInfo.DB, dbInfo.LogSlow.TimeDuration()) + } +} + +var fileDBs = map[string]bool{ + "sqlite": true, + "sqlite3": true, + "chai": true, + "access": true, + "mdb": true, + "accdb": true, + "h2": true, + "hsqldb": true, + "derby": true, + "sqlce": true, + "sdf": true, + "firebird": true, + "fdb": true, + "dbase": true, + "dbf": true, + "berkeleydb": true, + "bdb": true, +} + +func isFileDB(typ string) bool { + return fileDBs[typ] +} + +func (dbInfo *Config) ConfigureBy(setting string) { + urlInfo, err := url.Parse(setting) + if err != nil { + if dbInfo.logger != nil { + dbInfo.logger.Error(err.Error(), "url", setting) + } + return + } + + dbInfo.Type = urlInfo.Scheme + if isFileDB(dbInfo.Type) { + dbInfo.Host = urlInfo.Host + urlInfo.Path + dbInfo.DB = strings.SplitN(urlInfo.Host, ".", 2)[0] + } else { + if strings.ContainsRune(urlInfo.Host, ',') { + a := strings.Split(urlInfo.Host, ",") + dbInfo.Host = a[0] + dbInfo.ReadonlyHosts = a[1:] + } else { + dbInfo.Host = urlInfo.Host + dbInfo.ReadonlyHosts = nil + } + if len(urlInfo.Path) > 1 { + dbInfo.DB = urlInfo.Path[1:] + } + } + dbInfo.User = urlInfo.User.Username() + dbInfo.Password, _ = urlInfo.User.Password() + + q := urlInfo.Query() + dbInfo.MaxIdles = cast.Int(q.Get("maxIdles")) + dbInfo.MaxLifeTime = cast.Int(q.Get("maxLifeTime")) + dbInfo.MaxOpens = cast.Int(q.Get("maxOpens")) + dbInfo.LogSlow = config.Duration(cast.Duration(q.Get("logSlow"))) + dbInfo.SSL = q.Get("tls") + + sslCa := q.Get("sslCA") + sslCert := q.Get("sslCert") + sslKey := q.Get("sslKey") + sslSkipVerify := cast.Bool(q.Get("sslSkipVerify")) + if sslCa == "" || sslCert == "" || sslKey == "" { + if dbSSL, ok := dbSSLs[dbInfo.SSL]; ok { + sslCa = dbSSL.Ca + sslCert = dbSSL.Cert + sslKey = dbSSL.Key + sslSkipVerify = dbSSL.Insecure + } + } + if sslCa != "" && sslCert != "" && sslKey != "" { + sslName := id.MakeID(12) + dbInfo.SSL = sslName + decryptedCa, _ := confAes.DecryptBytes([]byte(sslCa)) + decryptedCert, _ := confAes.DecryptBytes([]byte(sslCert)) + decryptedKey, _ := confAes.DecryptBytes([]byte(sslKey)) + tlsConf := BuildTLSConfig(decryptedCa, decryptedCert, decryptedKey, sslSkipVerify) + if tlsConf != nil { + dbInfo.tls = tlsConf + } + safe.ZeroMemory(decryptedCa) + safe.ZeroMemory(decryptedCert) + safe.ZeroMemory(decryptedKey) + } + + args := make([]string, 0) + for k := range q { + if k != "maxIdles" && k != "maxLifeTime" && k != "maxOpens" && k != "logSlow" && k != "tls" { + args = append(args, k+"="+q.Get(k)) + } + } + if len(args) > 0 { + dbInfo.Args = strings.Join(args, "&") + } +} + +type DB struct { + name string + conn *sql.DB + readonlyConnections []*sql.DB + Config *Config + logger *dbLogger + Error error + QuoteTag string +} + +var confAes, _ = crypto.NewAESCBCAndEraseKey([]byte("?GQ$0K0GgLdO=f+~L68PLm$uhKr4'=tV"), []byte("VFs7@sK61cj^f?HZ")) +var keysSetted = sync.Once{} + +func SetEncryptKeys(key, iv []byte) { + keysSetted.Do(func() { + confAes.Close() + confAes, _ = crypto.NewAESGCMAndEraseKey(key, iv) + }) +} + +type dbLogger struct { + config *Config + logger *log.Logger +} + +func (dl *dbLogger) LogError(errStr string) { + dl.logger.DBError(errStr, dl.config.Type, dl.config.Dsn(), "", nil, 0) +} + +func (dl *dbLogger) LogQuery(query string, args []any, usedTime float32) { + dl.logger.DB(dl.config.Type, dl.config.Dsn(), query, args, usedTime) +} + +func (dl *dbLogger) LogQueryError(errStr string, query string, args []any, usedTime float32) { + dl.logger.DBError(errStr, dl.config.Type, dl.config.Dsn(), query, args, usedTime) +} + +var dbConfigs = make(map[string]*Config) +var dbConfigsLock = sync.RWMutex{} +var dbSSLs = make(map[string]*SSL) +var dbInstances = make(map[string]*DB) +var dbInstancesLock = sync.RWMutex{} +var once sync.Once + +func GetDBWithoutCache(name string, logger *log.Logger) *DB { + return getDB(name, logger, false) +} + +func GetDB(name string, logger *log.Logger) *DB { + return getDB(name, logger, true) +} + +func getDB(name string, logger *log.Logger, useCache bool) *DB { + if logger == nil { + logger = log.DefaultLogger + } + + if useCache { + dbInstancesLock.RLock() + oldConn := dbInstances[name] + dbInstancesLock.RUnlock() + if oldConn != nil { + return oldConn.CopyByLogger(logger) + } + } + + var conf *Config + if strings.Contains(name, "://") { + conf = new(Config) + conf.logger = logger + conf.ConfigureBy(name) + } else { + dbConfigsLock.RLock() + n := len(dbConfigs) + dbConfigsLock.RUnlock() + if n == 0 { + once.Do(func() { + dbConfigs1 := make(map[string]*Config) + if err := config.Load("db", &dbConfigs1); err == nil { + for k, v := range dbConfigs1 { + if v.Host != "" { + dbConfigsLock.Lock() + dbConfigs[k] = v + dbConfigsLock.Unlock() + } + } + } else { + logger.Error(err.Error()) + } + dbConfigs2 := make(map[string]string) + if err := config.Load("db", &dbConfigs2); err == nil { + for k, v := range dbConfigs2 { + if strings.Contains(v, "://") { + v2 := new(Config) + v2.ConfigureBy(v) + if v2.Host != "" { + v2.logger = logger + dbConfigsLock.Lock() + dbConfigs[k] = v2 + dbConfigsLock.Unlock() + } + } else { + dbConfigsLock.Lock() + v2 := dbConfigs[v] + if v2 != nil && v2.Host != "" { + dbConfigs[k] = v2 + } + dbConfigsLock.Unlock() + } + } + } else { + logger.Error(err.Error()) + } + }) + } + dbConfigsLock.RLock() + conf = dbConfigs[name] + dbConfigsLock.RUnlock() + if conf == nil { + conf = new(Config) + dbConfigsLock.Lock() + dbConfigs[name] = conf + dbConfigsLock.Unlock() + } + } + + if conf.Host == "" { + if name != "default" { + logger.Error("db config not exists", "name", name) + } + return nil + } + + if conf.SSL != "" && len(dbSSLs) == 0 { + _ = config.Load("dbssl", &dbSSLs) + } + + if conf.SSL != "" && dbSSLs[conf.SSL] == nil { + logger.Error("dbssl config lost") + } + + if strings.ContainsRune(conf.Host, ',') { + a := strings.Split(conf.Host, ",") + conf.Host = a[0] + conf.ReadonlyHosts = a[1:] + } else { + conf.ReadonlyHosts = nil + } + + if conf.Password != "" { + if encryptedPassword, err := base64.URLEncoding.DecodeString(conf.Password); err == nil { + if pwdSafeBuf, err := confAes.Decrypt(encryptedPassword); err == nil { + conf.pwd = pwdSafeBuf + } + } + if conf.pwd == nil { + conf.pwd = safe.NewSafeBuf([]byte(conf.Password)) + } + } else { + if !isFileDB(conf.Type) && conf.Host != "127.0.0.1:3306" && conf.User == "root" { + logger.Warning("password is empty") + } + } + conf.Password = "" + + conn, err := getPool(conf) + if err != nil { + logger.DBError(err.Error(), conf.Type, conf.Dsn(), "", nil, 0) + return &DB{conn: nil, QuoteTag: "\"", Error: err} + } + + db := new(DB) + db.QuoteTag = cast.If(conf.Type == "mysql", "`", "\"") + db.name = name + db.conn = conn + + if conf.ReadonlyHosts != nil { + readonlyConnections := make([]*sql.DB, 0) + for _, host := range conf.ReadonlyHosts { + conn, err := getPoolForHost(conf, host) + if err != nil { + logger.DBError(err.Error(), conf.Type, conf.Dsn(), "", nil, 0) + } else { + readonlyConnections = append(readonlyConnections, conn) + } + } + if len(readonlyConnections) > 0 { + db.readonlyConnections = readonlyConnections + } + } + + db.Error = nil + db.Config = conf + if conf.MaxIdles > 0 { + conn.SetMaxIdleConns(conf.MaxIdles) + } + if conf.MaxOpens > 0 { + conn.SetMaxOpenConns(conf.MaxOpens) + } + if conf.MaxLifeTime > 0 { + conn.SetConnMaxLifetime(time.Second * time.Duration(conf.MaxLifeTime)) + } + if conf.LogSlow == 0 { + conf.LogSlow = config.Duration(1000 * time.Millisecond) + } + if useCache { + dbInstancesLock.Lock() + dbInstances[name] = db + dbInstancesLock.Unlock() + } + return db.CopyByLogger(logger) +} + +func getPool(conf *Config) (*sql.DB, error) { + return getPoolForHost(conf, "") +} + +func getPoolForHost(conf *Config, host string) (*sql.DB, error) { + connectType := "tcp" + if host == "" { + host = conf.Host + } + if len(host) > 0 && host[0] == '/' { + connectType = "unix" + } + + if connector := connectors[conf.Type]; connector != nil { + return sql.OpenDB(connector(conf, conf.pwd, conf.tls)), nil + } else { + dsn := "" + args := make([]string, 0) + if conf.SSL != "" { + args = append(args, "tls="+conf.SSL) + } + if conf.Args != "" { + args = append(args, conf.Args) + } + argsStr := "" + if len(args) > 0 { + argsStr = "?" + strings.Join(args, "&") + } + + if isFileDB(conf.Type) { + dsn = host + argsStr + } else { + pwdBuf := conf.pwd.Open() + dsn = fmt.Sprintf("%s:%s@%s(%s)/%s"+argsStr, conf.User, pwdBuf.String(), connectType, host, conf.DB) + pwdBuf.Close() + } + return sql.Open(conf.Type, dsn) + } +} + +func (db *DB) CopyByLogger(logger *log.Logger) *DB { + newDB := new(DB) + newDB.QuoteTag = db.QuoteTag + newDB.name = db.name + newDB.conn = db.conn + newDB.readonlyConnections = db.readonlyConnections + newDB.Config = db.Config + if logger == nil { + logger = log.DefaultLogger + } + newDB.logger = &dbLogger{logger: logger, config: db.Config} + return newDB +} + +func (db *DB) SetLogger(logger *log.Logger) { + db.logger.logger = logger +} + +func (db *DB) GetLogger() *log.Logger { + return db.logger.logger +} + +func (db *DB) Destroy() error { + if db.conn == nil { + return errors.New("operate on a bad connection") + } + err := db.conn.Close() + if err != nil { + db.logger.LogError(err.Error()) + } + dbInstancesLock.Lock() + delete(dbInstances, db.name) + dbInstancesLock.Unlock() + return err +} + +func (db *DB) GetOriginDB() *sql.DB { + return db.conn +} + +func (db *DB) Prepare(query string) *Stmt { + stmt := basePrepare(db.conn, nil, query) + stmt.logger = db.logger + if stmt.Error != nil { + db.logger.LogError(stmt.Error.Error()) + } + return stmt +} + +func (db *DB) Quote(text string) string { + return quote(db.QuoteTag, text) +} + +func (db *DB) Quotes(texts []string) string { + return quotes(db.QuoteTag, texts) +} + +func (db *DB) Begin() *Tx { + if db.conn == nil { + return &Tx{QuoteTag: db.QuoteTag, logSlow: db.Config.LogSlow.TimeDuration(), Error: errors.New("operate on a bad connection"), logger: db.logger} + } + sqlTx, err := db.conn.Begin() + if err != nil { + db.logger.LogError(err.Error()) + return &Tx{QuoteTag: db.QuoteTag, logSlow: db.Config.LogSlow.TimeDuration(), Error: err, logger: db.logger} + } + return &Tx{QuoteTag: db.QuoteTag, logSlow: db.Config.LogSlow.TimeDuration(), conn: sqlTx, logger: db.logger} +} + +func (db *DB) Exec(query string, args ...any) *ExecResult { + r := baseExec(db.conn, nil, query, args...) + r.logger = db.logger + if r.Error != nil { + db.logger.LogQueryError(r.Error.Error(), query, args, r.usedTime) + } else { + if db.Config.LogSlow > 0 && r.usedTime >= float32(db.Config.LogSlow.TimeDuration()/time.Millisecond) { + db.logger.LogQuery(query, args, r.usedTime) + } + } + return r +} + +func (db *DB) Query(query string, args ...any) *QueryResult { + conn := db.conn + if db.readonlyConnections != nil { + connNum := len(db.readonlyConnections) + if connNum == 1 { + conn = db.readonlyConnections[0] + } else { + p := rand.Int(0, connNum-1) + conn = db.readonlyConnections[p] + } + } + + r := baseQuery(conn, nil, query, args...) + r.logger = db.logger + if r.Error != nil { + db.logger.LogQueryError(r.Error.Error(), query, args, r.usedTime) + } else { + if db.Config.LogSlow > 0 && r.usedTime >= float32(db.Config.LogSlow.TimeDuration()/time.Millisecond) { + db.logger.LogQuery(query, args, r.usedTime) + } + } + return r +} + +func (db *DB) Insert(table string, data any) *ExecResult { + query, values := db.MakeInsertSql(table, data, false) + r := baseExec(db.conn, nil, query, values...) + r.logger = db.logger + if r.Error != nil { + db.logger.LogQueryError(r.Error.Error(), query, values, r.usedTime) + } else { + if db.Config.LogSlow > 0 && r.usedTime >= float32(db.Config.LogSlow.TimeDuration()/time.Millisecond) { + db.logger.LogQuery(query, values, r.usedTime) + } + } + return r +} + +func (db *DB) Replace(table string, data any) *ExecResult { + query, values := db.MakeInsertSql(table, data, true) + r := baseExec(db.conn, nil, query, values...) + r.logger = db.logger + if r.Error != nil { + db.logger.LogQueryError(r.Error.Error(), query, values, r.usedTime) + } else { + if db.Config.LogSlow > 0 && r.usedTime >= float32(db.Config.LogSlow.TimeDuration()/time.Millisecond) { + db.logger.LogQuery(query, values, r.usedTime) + } + } + return r +} + +func (db *DB) Update(table string, data any, conditions string, args ...any) *ExecResult { + query, values := db.MakeUpdateSql(table, data, conditions, args...) + r := baseExec(db.conn, nil, query, values...) + r.logger = db.logger + if r.Error != nil { + db.logger.LogQueryError(r.Error.Error(), query, values, r.usedTime) + } else { + if db.Config.LogSlow > 0 && r.usedTime >= float32(db.Config.LogSlow.TimeDuration()/time.Millisecond) { + db.logger.LogQuery(query, values, r.usedTime) + } + } + return r +} + +func (db *DB) Delete(table string, conditions string, args ...any) *ExecResult { + if conditions != "" { + conditions = " where " + conditions + } + query := fmt.Sprintf("delete from %s%s", db.Quote(table), conditions) + r := baseExec(db.conn, nil, query, args...) + r.logger = db.logger + if r.Error != nil { + db.logger.LogQueryError(r.Error.Error(), query, args, r.usedTime) + } else { + if db.Config.LogSlow > 0 && r.usedTime >= float32(db.Config.LogSlow.TimeDuration()/time.Millisecond) { + db.logger.LogQuery(query, args, r.usedTime) + } + } + return r +} + +func (db *DB) InKeys(numArgs int) string { + return InKeys(numArgs) +} + +func InKeys(numArgs int) string { + a := make([]string, numArgs) + for i := 0; i < numArgs; i++ { + a[i] = "?" + } + return fmt.Sprintf("(%s)", strings.Join(a, ",")) +} diff --git a/DB_test.go b/DB_test.go new file mode 100644 index 0000000..f8c2778 --- /dev/null +++ b/DB_test.go @@ -0,0 +1,477 @@ +package db_test + +import ( + "fmt" + "regexp" + "strings" + "testing" + "time" + + "apigo.cc/go/cast" + "apigo.cc/go/db" + "apigo.cc/go/shell" + + _ "apigo.cc/go/db/mysql" + _ "modernc.org/sqlite" +) + +var dbset = "sqlite://test.db" + +type userInfo struct { + innerId int + Tag string + Id int + Name string + Phone *string + Email string + Parents []string + Active bool + Time string +} + +type UserBaseModel struct { + Id int + Name string + Password string +} + +type UserModel struct { + UserBaseModel + Phone string + Active bool + Parents []string + UserStatus int + Owner int + Salt string +} + +func initDB(t *testing.T) *db.DB { + var er *db.ExecResult + dbInst := db.GetDB(dbset, nil) + fmt.Println("dbType", shell.BCyan(dbInst.Config.Type)) + if dbInst.Error != nil { + t.Fatal("GetDB error", dbInst) + return nil + } + + finishDB(dbInst, t) + if dbInst.Config.Type == "mysql" { + er = dbInst.Exec(`CREATE TABLE IF NOT EXISTS tempUsersForDBTest ( + id INT NOT NULL AUTO_INCREMENT, + name VARCHAR(45) NOT NULL, + phone VARCHAR(45), + email VARCHAR(45), + parents JSON, + active TINYINT NOT NULL DEFAULT 0, + time DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (id));`) + } else { + er = dbInst.Exec(`CREATE TABLE IF NOT EXISTS tempUsersForDBTest ( + id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, + name VARCHAR(45) NOT NULL, + phone VARCHAR(45), + email VARCHAR(45), + parents JSON, + active TINYINT NOT NULL DEFAULT 0, + time DATETIME NOT NULL DEFAULT (strftime('%Y-%m-%d %H:%M:%f')));`) + } + if er.Error != nil { + t.Fatal("Failed to create table", er) + } + return dbInst +} + +func finishDB(dbInst *db.DB, t *testing.T) { + er := dbInst.Exec(`DROP TABLE IF EXISTS tempUsersForDBTest;`) + if er.Error != nil { + t.Fatal("Failed to drop table", er) + } +} + +func TestMakeInsertSql(t *testing.T) { + user := &UserModel{ + UserBaseModel: UserBaseModel{ + Name: "王二小", + Password: "2121asds", + }, + Parents: []string{"aa", "bb"}, + Active: true, + UserStatus: 1, + Salt: "de312", + } + + dbInst := db.GetDB(dbset, nil) + query, _ := dbInst.MakeInsertSql("table_name", user, false) + checkSql := `insert into "table_name" ("Id","Name","Password","Phone","Active","Parents","UserStatus","Owner","Salt") values (?,?,?,?,?,?,?,?,?)` + if dbInst.Config.Type == "mysql" { + checkSql = strings.ReplaceAll(checkSql, "\"", "`") + } + if query != checkSql { + t.Fatal("MakeInsertSql query error ", query) + } +} + +func TestBaseSelect(t *testing.T) { + sqlStr := "SELECT 1002 id, '13800000001' phone" + dbInst := db.GetDB(dbset, nil) + if dbInst.Error != nil { + t.Fatal("GetDB error", dbInst.Error) + return + } + + r := dbInst.Query(sqlStr) + if r.Error != nil { + t.Fatal("Query error", sqlStr, r) + } + results1 := r.MapResults() + if cast.Int(results1[0]["id"]) != 1002 || cast.String(results1[0]["phone"]) != "13800000001" { + t.Fatal("Result error", sqlStr, results1, r) + } + + r = dbInst.Query(sqlStr) + if r.Error != nil { + t.Fatal("Query error", sqlStr, r) + } + results2 := r.StringMapResults() + if results2[0]["id"] != "1002" || results2[0]["phone"] != "13800000001" { + t.Fatal("Result error", sqlStr, results2, r) + } + + results3 := make([]map[string]int, 0) + r = dbInst.Query(sqlStr) + if r.Error != nil { + t.Fatal("Query error", sqlStr, results3, r) + } + r.To(&results3) + if results3[0]["id"] != 1002 || results3[0]["phone"] != 13800000001 { + t.Fatal("Result error", sqlStr, results3, r) + } + + results4 := make([]userInfo, 0) + r = dbInst.Query(sqlStr) + if r.Error != nil { + t.Fatal("Query error", sqlStr, results4, r) + } + r.To(&results4) + if results4[0].Id != 1002 || results4[0].Phone == nil || *results4[0].Phone != "13800000001" { + t.Fatal("Result error", sqlStr, results4, r) + } + + results5 := dbInst.Query(sqlStr).StringSliceResults() + if results5[0][0] != "1002" || results5[0][1] != "13800000001" { + t.Fatal("Result error", sqlStr, results5, r) + } + + r = dbInst.Query(sqlStr) + if r.Error != nil { + t.Fatal("Query error", sqlStr, r) + } + results6 := r.StringsOnC1() + if results6[0] != "1002" { + t.Fatal("Result error", sqlStr, results6, r) + } + + r = dbInst.Query(sqlStr) + if r.Error != nil { + t.Fatal("Query error", sqlStr, r) + } + results7 := r.MapOnR1() + if cast.Int(results7["id"]) != 1002 || cast.String(results7["phone"]) != "13800000001" { + t.Fatal("Result error", sqlStr, results7, r) + } + + results8 := userInfo{innerId: 2, Tag: "abc"} + r = dbInst.Query(sqlStr) + if r.Error != nil { + t.Fatal("Query error", sqlStr, results8, r) + } + r.To(&results8) + if results8.Id != 1002 || results8.Phone == nil || *results8.Phone != "13800000001" || results8.innerId != 2 || results8.Tag != "abc" { + t.Fatal("Result error", sqlStr, results8, r) + } + + r = dbInst.Query(sqlStr) + if r.Error != nil { + t.Fatal("Query error", sqlStr, r) + } + results9 := r.IntOnR1C1() + if results9 != 1002 { + t.Fatal("Result error", sqlStr, results9, r) + } + + r = dbInst.Query(sqlStr) + results10 := map[string]string{} + r.ToKV(&results10) + if results10["1002"] != "13800000001" { + t.Fatal("Result error", sqlStr, results10, r) + } + + r = dbInst.Query(sqlStr) + results11 := map[string]map[string]string{} + r.ToKV(&results11) + if results11["1002"]["phone"] != "13800000001" { + t.Fatal("Result error", sqlStr, results11, r) + } + + r = dbInst.Query(sqlStr) + results12 := map[string]userInfo{} + r.ToKV(&results12) + if results12["1002"].Phone == nil || *results12["1002"].Phone != "13800000001" { + t.Fatal("Result error", sqlStr, results12, r) + } +} + +func TestInsertReplaceUpdateDelete(t *testing.T) { + dbInst := initDB(t) + data := map[string]any{ + "phone": 18033336666, + "name": "Star", + "parents": []string{"dd", "mm"}, + "time": ":(strftime('%Y-%m-%d %H:%M:%f', datetime('now', '-1 day'), 'localtime'))", + } + if dbInst.Config.Type == "mysql" { + data["time"] = ":DATE_SUB(NOW(), INTERVAL 1 DAY)" + } + er := dbInst.Insert("tempUsersForDBTest", data) + if er.Error != nil { + t.Fatal("Insert 1 error", er) + } + if er.Id() != 1 { + t.Fatal("insertId 1 error", er, er.Id()) + } + + er = dbInst.Insert("tempUsersForDBTest", map[string]any{ + "phone": "18000000002", + "name": "Tom", + "active": true, + }) + if er.Error != nil { + t.Fatal("Insert 2 error", er) + } + if er.Id() != 2 { + t.Fatal("insertId 2 error", er, er.Id()) + } + + er = dbInst.Update("tempUsersForDBTest", map[string]any{ + "phone": "18000000222", + "name": "Tom Lee", + }, "id=?", 2) + if er.Error != nil { + t.Fatal("Update 2 error", er) + } + if er.Changes() != 1 { + t.Fatal("Update 2 num error", er, er.Changes()) + } + + er = dbInst.Replace("tempUsersForDBTest", map[string]any{ + "phone": "18000000003", + "name": "Amy", + }) + if er.Error != nil { + t.Fatal("Replace 3 error", er) + } + if er.Id() != 3 { + t.Fatal("insertId 3 error", er, er.Changes()) + } + + er = dbInst.Exec("delete from tempUsersForDBTest where id=3") + if er.Error != nil { + t.Fatal("Delete 3 error", er) + } + if er.Changes() != 1 { + t.Fatal("Delete 3 num error", er) + } + + er = dbInst.Replace("tempUsersForDBTest", map[string]any{ + "phone": "18000000004", + "name": "Jerry", + }) + if er.Error != nil { + t.Fatal("Replace 4 error", er) + } + if er.Id() != 4 { + t.Fatal("insertId 4 error", er, er.Changes()) + } + + stmt := dbInst.Prepare("replace into `tempUsersForDBTest` (`id`,`phone`,`name`) values (?,?,?)") + if stmt.Error != nil { + t.Fatal("Prepare 4 error", stmt) + } + er = stmt.Exec(4, "18000000004", "Jerry's Mather") + stmt.Close() + + if er.Error != nil { + t.Fatal("Replace 4 error", er) + } + if er.Id() != 4 { + t.Fatal("insertId 4 error", er) + } + + userList := make([]userInfo, 0) + r := dbInst.Query("select * from tempUsersForDBTest") + if r.Error != nil { + t.Fatal("Select userList error", r) + } + r.To(&userList) + fmt.Println(">>>>", cast.PrettyToJSON(userList)) + if strings.Split(userList[0].Time, " ")[0] != time.Now().Add(time.Hour*24*-1).Format("2006-01-02") || userList[0].Id != 1 || userList[0].Name != "Star" || userList[0].Phone == nil || *userList[0].Phone != "18033336666" || userList[0].Active != false { + t.Fatal("Select userList 0 error", userList, r) + } + if len(userList[0].Parents) != 2 || userList[0].Parents[0] != "dd" { + t.Fatal("Select userList 0 Parents error", userList, r) + } + if strings.Split(userList[1].Time, " ")[0] != time.Now().Format("2006-01-02") || userList[1].Id != 2 || userList[1].Name != "Tom Lee" || userList[1].Phone == nil || *userList[1].Phone != "18000000222" || userList[1].Active != true { + t.Fatal("Select userList 1 error", userList, r) + } + if userList[2].Id != 4 || userList[2].Name != "Jerry's Mather" || userList[2].Phone == nil || *userList[2].Phone != "18000000004" { + t.Fatal("Select userList 2 error", userList, r) + } + + finishDB(dbInst, t) +} + +func TestTransaction(t *testing.T) { + n1 := countConnection() + + var userList []userInfo + + dbInst := initDB(t) + tx := dbInst.Begin() + if tx.Error != nil { + t.Fatal("Begin error", tx) + } + + data := map[string]any{ + "phone": 18033336666, + "name": "Star", + "time": ":(strftime('%Y-%m-%d %H:%M:%f', datetime('now', '-1 day'), 'localtime'))", + } + if dbInst.Config.Type == "mysql" { + data["time"] = ":DATE_SUB(NOW(), INTERVAL 1 DAY)" + } + tx.Insert("tempUsersForDBTest", data) + + userList = make([]userInfo, 0) + r := dbInst.Query("select * from tempUsersForDBTest") + r.To(&userList) + if r.Error != nil || len(userList) != 0 { + t.Fatal("Select Out Of TX", userList, r) + } + + userList = make([]userInfo, 0) + r = tx.Query("select * from tempUsersForDBTest") + r.To(&userList) + if r.Error != nil || len(userList) != 1 { + t.Fatal("Select In TX", userList, r) + } + + tx.Rollback() + + userList = make([]userInfo, 0) + r = dbInst.Query("select * from tempUsersForDBTest") + r.To(&userList) + if r.Error != nil || len(userList) != 0 { + t.Fatal("Select When Rollback", userList, r) + } + + tx = dbInst.Begin() + defer func() { + if err := tx.Finish(false); err != nil { + t.Error("tx rollback error", err) + } + finishDB(dbInst, t) + }() + if tx.Error != nil { + t.Fatal("Begin 2 error", tx) + } + + stmt := tx.Prepare("insert into `tempUsersForDBTest` (`id`,`phone`,`name`) values (?,?,?)") + if stmt.Error != nil { + t.Fatal("Prepare 4 error", r) + } + stmt.Exec(4, "18000000004", "Jerry's Mather") + stmt.Close() + + tx.Commit() + + userList = make([]userInfo, 0) + r = dbInst.Query("select * from tempUsersForDBTest") + r.To(&userList) + if r.Error != nil || len(userList) != 1 { + t.Fatal("Select When Commit", userList, r) + } + + n2 := countConnection() + fmt.Println("# connection count", n1, n2, cast.MustToJSON(dbInst.GetOriginDB().Stats()), ".") +} + +func countConnection() int { + n := 0 + res, _ := shell.RunCommand("netstat -ant", nil) + lines := strings.Split(string(res.Stdout), "\n") + spliter := regexp.MustCompile("\\s+") + for _, line := range lines { + if strings.Contains(line, ".3306") && strings.Contains(line, "ESTABLISHED") { + a := spliter.Split(line, 10) + if len(a) > 4 && strings.Contains(a[4], ".3306") { + n++ + } + } + } + return n +} + +func BenchmarkForPool(b *testing.B) { + b.StopTimer() + sqlStr := "SELECT 1002 id, '13800000001' phone" + dbInst := db.GetDB(dbset, nil) + if dbInst.Error != nil { + b.Fatal("GetDB error", dbInst) + return + } + + b.StartTimer() + for i := 0; i < b.N; i++ { + results1 := make([]map[string]any, 0) + r := dbInst.Query(sqlStr) + if r.Error != nil { + b.Fatal("Query error", sqlStr, results1, r) + } + r.To(&results1) + if cast.Int(results1[0]["id"]) != 1002 || cast.String(results1[0]["phone"]) != "13800000001" { + b.Fatal("Result error", sqlStr, results1, r) + } + } + b.Log("OpenConnections", dbInst.GetOriginDB().Stats().OpenConnections) +} + +func BenchmarkForPoolParallel(b *testing.B) { + n1 := countConnection() + + b.StopTimer() + sqlStr := "SELECT 1002 id, '13800000001' phone" + dbInst := db.GetDB(dbset, nil) + if dbInst.Error != nil { + b.Fatal("GetDB error", dbInst) + return + } + b.StartTimer() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + results1 := make([]map[string]any, 0) + r := dbInst.Query(sqlStr) + if r.Error != nil { + b.Fatal("Query error", sqlStr, results1, r) + } + r.To(&results1) + if cast.Int(results1[0]["id"]) != 1002 || cast.String(results1[0]["phone"]) != "13800000001" { + b.Fatal("Result error", sqlStr, results1, r) + } + } + }) + b.Log("OpenConnections", dbInst.GetOriginDB().Stats().OpenConnections) + + n2 := countConnection() + fmt.Println("# connection count", n1, n2, cast.MustToJSON(dbInst.GetOriginDB().Stats()), ".") +} diff --git a/README.md b/README.md index cdba1d2..584cf4b 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,65 @@ -# db +# @go/db -High-performance database abstraction layer for apigo.cc/go \ No newline at end of file +> **Maintainer Statement:** 本项目由 AI 维护。代码源自 github.com/ssgo/db 的重构,支持内存安全防护、读写分离及泛型优化。 + +## 🎯 设计哲学 + +`@go/db` 是一个极致精简、意图优先的数据库抽象层。它不试图取代 SQL,而是通过智能结果绑定与 SQL 自动化生成,消除数据库操作中的样板代码。 + +* **智能绑定**:根据结果容器类型(Struct/Map/Slice/BaseType)自动适配查询逻辑,无需手动 Scan。 +* **内存防御**:集成 `go/safe`,数据库密码在内存中加密存储,使用时物理擦除。 +* **读写分离**:内置连接池管理,支持配置多个只读节点实现自动负载均衡。 +* **驱动透明**:统一 MySQL、PostgreSQL (pgx) 与 SQLite 的 API 差异。 + +## 📦 安装 + +```bash +go get apigo.cc/go/db +``` + +## 💡 快速开始 + +```go +import "apigo.cc/go/db" +import _ "apigo.cc/go/db/mysql" // 引入驱动 + +// 初始化连接 +d := db.GetDB("mysql://user:pass@host:3306/dbname", nil) + +// 1. 查询全部结果到 Struct 切片 +var users []User +d.Query("SELECT * FROM users").To(&users) + +// 2. 自动化插入 +d.Insert("users", User{Name: "Star", Active: true}) + +// 3. 事务操作 +tx := d.Begin() +tx.Exec("UPDATE balance SET amount = amount - 10 WHERE id = ?", 1) +tx.Commit() +``` + +## 🛠 API 指南 + +### 核心对象 +- **`GetDB(setting string, logger *log.Logger) *DB`**: 通过 DSN 或配置名获取数据库实例。 +- **`DB.Insert/Replace/Update/Delete`**: 自动生成 SQL 并执行,支持 Struct 与 Map。 +- **`QueryResult.To(target any)`**: 将查询结果深度映射到目标容器。 +- **`QueryResult.MapResults() []map[string]any`**: 快捷获取通用结果集。 + +### 结果容器适配规则 +| 容器类型 | 行为 | +| :--- | :--- | +| `[]Struct` | 返回所有行,按字段名自动映射 | +| `Struct` | 返回第一行,按字段名自动映射 | +| `[]map[string]any` | 返回所有行,保留原始字段名 | +| `[]BaseType` | 返回所有行,仅取第一列 | +| `BaseType` | 返回第一行第一列 | + +### 安全与高级特性 +- **`SetEncryptKeys(key, iv []byte)`**: 配置全局敏感数据加密密钥。 +- **读写分离**: 在 DSN 中配置 `host1,host2,host3`,首个为主库,其余为随机只读库。 +- **SQLite 时间修复**: 自动处理 SQLite 毫秒级 `DATETIME` 格式与标准 `time.Time` 的转换。 + +## 🧪 验证状态 +已通过 SQLite 集成测试。详见:[TEST.md](./TEST.md) diff --git a/Result.go b/Result.go new file mode 100644 index 0000000..36bccf2 --- /dev/null +++ b/Result.go @@ -0,0 +1,577 @@ +package db + +import ( + "database/sql" + "encoding/json" + "errors" + "fmt" + "reflect" + "strings" + "time" + + "apigo.cc/go/cast" + "apigo.cc/go/convert" + "github.com/mitchellh/mapstructure" +) + +type QueryResult struct { + rows *sql.Rows + Sql *string + Args []any + Error error + logger *dbLogger + usedTime float32 + completed bool +} + +type ExecResult struct { + result sql.Result + Sql *string + Args []any + Error error + logger *dbLogger + usedTime float32 +} + +func (r *ExecResult) Changes() int64 { + if r.result == nil { + return 0 + } + numChanges, err := r.result.RowsAffected() + if err != nil { + r.logger.LogQueryError(err.Error(), *r.Sql, r.Args, r.usedTime) + return 0 + } + return numChanges +} + +func (r *ExecResult) Id() int64 { + if r.result == nil { + return 0 + } + insertId, err := r.result.LastInsertId() + if err != nil { + r.logger.LogQueryError(err.Error(), *r.Sql, r.Args, r.usedTime) + return 0 + } + return insertId +} + +func (r *QueryResult) Complete() { + if !r.completed { + if r.rows != nil { + r.rows.Close() + } + r.completed = true + } +} + +func (r *QueryResult) To(result any) error { + if r.rows == nil { + return errors.New("operate on a bad query") + } + return r.makeResults(result, r.rows) +} + +func (r *QueryResult) MapResults() []map[string]any { + result := make([]map[string]any, 0) + err := r.makeResults(&result, r.rows) + if err != nil { + r.logger.LogQueryError(err.Error(), *r.Sql, r.Args, r.usedTime) + } + return result +} + +func (r *QueryResult) SliceResults() [][]any { + result := make([][]any, 0) + err := r.makeResults(&result, r.rows) + if err != nil { + r.logger.LogQueryError(err.Error(), *r.Sql, r.Args, r.usedTime) + } + return result +} + +func (r *QueryResult) StringMapResults() []map[string]string { + result := make([]map[string]string, 0) + err := r.makeResults(&result, r.rows) + if err != nil { + r.logger.LogQueryError(err.Error(), *r.Sql, r.Args, r.usedTime) + } + return result +} + +func (r *QueryResult) StringSliceResults() [][]string { + result := make([][]string, 0) + err := r.makeResults(&result, r.rows) + if err != nil { + r.logger.LogQueryError(err.Error(), *r.Sql, r.Args, r.usedTime) + } + return result +} + +func (r *QueryResult) MapOnR1() map[string]any { + result := make(map[string]any) + err := r.makeResults(&result, r.rows) + if err != nil { + r.logger.LogQueryError(err.Error(), *r.Sql, r.Args, r.usedTime) + } + return result +} + +func (r *QueryResult) StringMapOnR1() map[string]string { + result := make(map[string]string) + err := r.makeResults(&result, r.rows) + if err != nil { + r.logger.LogQueryError(err.Error(), *r.Sql, r.Args, r.usedTime) + } + return result +} + +func (r *QueryResult) IntsOnC1() []int64 { + result := make([]int64, 0) + err := r.makeResults(&result, r.rows) + if err != nil { + r.logger.LogQueryError(err.Error(), *r.Sql, r.Args, r.usedTime) + } + return result +} + +func (r *QueryResult) StringsOnC1() []string { + result := make([]string, 0) + err := r.makeResults(&result, r.rows) + if err != nil { + r.logger.LogQueryError(err.Error(), *r.Sql, r.Args, r.usedTime) + } + return result +} + +func (r *QueryResult) IntOnR1C1() int64 { + var result int64 = 0 + err := r.makeResults(&result, r.rows) + if err != nil { + r.logger.LogQueryError(err.Error(), *r.Sql, r.Args, r.usedTime) + } + return result +} + +func (r *QueryResult) FloatOnR1C1() float64 { + var result float64 = 0 + err := r.makeResults(&result, r.rows) + if err != nil { + r.logger.LogQueryError(err.Error(), *r.Sql, r.Args, r.usedTime) + } + return result +} + +func (r *QueryResult) StringOnR1C1() string { + result := "" + err := r.makeResults(&result, r.rows) + if err != nil { + r.logger.LogQueryError(err.Error(), *r.Sql, r.Args, r.usedTime) + } + return result +} + +func (r *QueryResult) ToKV(target any) error { + v := reflect.ValueOf(target) + t := v.Type() + for t.Kind() == reflect.Ptr { + v = v.Elem() + t = v.Type() + } + + if t.Kind() != reflect.Map { + r.logger.LogQueryError("target not a map", *r.Sql, r.Args, r.usedTime) + return errors.New("target not a map") + } + + vt := t.Elem() + finalVt := vt + for finalVt.Kind() == reflect.Ptr { + finalVt = finalVt.Elem() + } + if finalVt.Kind() == reflect.Map || finalVt.Kind() == reflect.Struct { + colTypes, err := r.getColumnTypes() + list := r.MapResults() + if err != nil { + r.logger.LogQueryError(err.Error(), *r.Sql, r.Args, r.usedTime) + return err + } + for _, item := range list { + newKey := reflect.ValueOf(reflect.New(t.Key()).Interface()).Elem() + convert.To(item[colTypes[0].Name()], newKey.Addr().Interface()) + + newValue := v.MapIndex(newKey) + isNew := false + if !newValue.IsValid() { + newValue = reflect.New(vt) + isNew = true + } + + err := mapstructure.WeakDecode(item, newValue.Interface()) + if err != nil { + r.logger.LogError(err.Error()) + } + + if isNew { + v.SetMapIndex(newKey, newValue.Elem()) + } + } + } else { + list := r.SliceResults() + for _, item := range list { + if len(item) < 2 { + continue + } + switch vt.Kind() { + case reflect.Int: + v.SetMapIndex(reflect.ValueOf(cast.String(item[0])), reflect.ValueOf(cast.Int(item[1]))) + case reflect.Int8: + v.SetMapIndex(reflect.ValueOf(cast.String(item[0])), reflect.ValueOf(int8(cast.Int(item[1])))) + case reflect.Int16: + v.SetMapIndex(reflect.ValueOf(cast.String(item[0])), reflect.ValueOf(int16(cast.Int(item[1])))) + case reflect.Int32: + v.SetMapIndex(reflect.ValueOf(cast.String(item[0])), reflect.ValueOf(int32(cast.Int(item[1])))) + case reflect.Int64: + v.SetMapIndex(reflect.ValueOf(cast.String(item[0])), reflect.ValueOf(cast.Int64(item[1]))) + case reflect.Uint: + v.SetMapIndex(reflect.ValueOf(cast.String(item[0])), reflect.ValueOf(uint(cast.Int(item[1])))) + case reflect.Uint8: + v.SetMapIndex(reflect.ValueOf(cast.String(item[0])), reflect.ValueOf(uint8(cast.Int(item[1])))) + case reflect.Uint16: + v.SetMapIndex(reflect.ValueOf(cast.String(item[0])), reflect.ValueOf(uint16(cast.Int(item[1])))) + case reflect.Uint32: + v.SetMapIndex(reflect.ValueOf(cast.String(item[0])), reflect.ValueOf(uint32(cast.Int(item[1])))) + case reflect.Uint64: + v.SetMapIndex(reflect.ValueOf(cast.String(item[0])), reflect.ValueOf(cast.Uint64(item[1]))) + case reflect.Float32: + v.SetMapIndex(reflect.ValueOf(cast.String(item[0])), reflect.ValueOf(cast.Float(item[1]))) + case reflect.Float64: + v.SetMapIndex(reflect.ValueOf(cast.String(item[0])), reflect.ValueOf(cast.Float64(item[1]))) + case reflect.Bool: + v.SetMapIndex(reflect.ValueOf(cast.String(item[0])), reflect.ValueOf(cast.Bool(item[1]))) + case reflect.String: + v.SetMapIndex(reflect.ValueOf(cast.String(item[0])), reflect.ValueOf(cast.String(item[1]))) + case reflect.Interface: + v.SetMapIndex(reflect.ValueOf(cast.String(item[0])), reflect.ValueOf(item[1])) + } + } + } + + return nil +} + +func (r *QueryResult) makeResults(results any, rows *sql.Rows) error { + if rows == nil { + return errors.New("not a valid query result") + } + + defer func() { + _ = rows.Close() + r.completed = true + }() + resultsValue := reflect.ValueOf(results) + if resultsValue.Kind() != reflect.Ptr { + return fmt.Errorf("results must be a pointer") + } + + for resultsValue.Kind() == reflect.Ptr { + resultsValue = resultsValue.Elem() + } + rowType := resultsValue.Type() + + colTypes, err := rows.ColumnTypes() + if err != nil { + return err + } + + colNum := len(colTypes) + originRowType := rowType + if rowType.Kind() == reflect.Slice { + rowType = rowType.Elem() + originRowType = rowType + for rowType.Kind() == reflect.Ptr { + rowType = rowType.Elem() + } + } + + scanValues := make([]any, colNum) + var fieldInfos []struct { + index []int + typ reflect.Type + name string + } + + if rowType.Kind() == reflect.Struct { + fieldInfos = make([]struct { + index []int + typ reflect.Type + name string + }, colNum) + for colIndex, col := range colTypes { + publicColName := makePublicVarName(col.Name()) + field, found := rowType.FieldByName(publicColName) + if found { + fieldInfos[colIndex].index = field.Index + fieldInfos[colIndex].typ = field.Type + fieldInfos[colIndex].name = publicColName + if field.Type.Kind() == reflect.Interface { + scanValues[colIndex] = makeValue(colTypes[colIndex].ScanType()) + } else { + scanValues[colIndex] = makeValue(field.Type) + } + } else { + fieldInfos[colIndex].index = nil + scanValues[colIndex] = makeValue(nil) + } + } + } else if rowType.Kind() == reflect.Map { + for colIndex := range colTypes { + if rowType.Elem().Kind() == reflect.Interface { + scanValues[colIndex] = makeValue(colTypes[colIndex].ScanType()) + } else { + scanValues[colIndex] = makeValue(rowType.Elem()) + } + } + } else if rowType.Kind() == reflect.Slice { + for colIndex := range colTypes { + if rowType.Elem().Kind() == reflect.Interface { + scanValues[colIndex] = makeValue(colTypes[colIndex].ScanType()) + } else { + scanValues[colIndex] = makeValue(rowType.Elem()) + } + } + } else { + if rowType.Kind() == reflect.Interface { + scanValues[0] = makeValue(colTypes[0].ScanType()) + } else { + scanValues[0] = makeValue(rowType) + } + for colIndex := 1; colIndex < colNum; colIndex++ { + scanValues[colIndex] = makeValue(nil) + } + } + + var data reflect.Value + isNew := true + for rows.Next() { + err = rows.Scan(scanValues...) + if err != nil { + return err + } + if rowType.Kind() == reflect.Struct { + if resultsValue.Kind() == reflect.Slice { + data = reflect.New(rowType).Elem() + } else { + data = resultsValue + isNew = false + } + + for colIndex, col := range colTypes { + fInfo := fieldInfos[colIndex] + if fInfo.index == nil { + continue + } + field := data.FieldByIndex(fInfo.index) + valuePtr := reflect.ValueOf(scanValues[colIndex]).Elem() + if !valuePtr.IsNil() { + val := valuePtr.Elem() + if fInfo.typ.String() == "time.Time" { + str := val.String() + tm, err := time.Parse("2006-01-02 15:04:05.000000", str) + if err != nil { + tm, err = time.Parse("2006-01-02 15:04:05", str) + } + if err == nil { + field.Set(reflect.ValueOf(tm)) + } + } else if val.Kind() != field.Kind() && field.Kind() != reflect.Interface { + if field.Kind() == reflect.Ptr && val.Kind() == field.Type().Elem().Kind() { + if val.CanAddr() { + if field.Type().AssignableTo(val.Type()) { + field.Set(val.Addr()) + } else if val.Type().String() == "string" { + strVal := fixValue(col.DatabaseTypeName(), val) + field.Set(reflect.New(field.Type().Elem())) + field.Elem().SetString(cast.String(strVal.Interface())) + } else if strings.Contains(field.Type().String(), "uint") { + field.Set(reflect.New(field.Type().Elem())) + field.Elem().SetUint(cast.Uint64(val.Interface())) + } else if strings.Contains(field.Type().String(), "int") { + field.Set(reflect.New(field.Type().Elem())) + field.Elem().SetInt(cast.Int64(val.Interface())) + } else if strings.Contains(field.Type().String(), "float") { + field.Set(reflect.New(field.Type().Elem())) + field.Elem().SetFloat(cast.Float64(val.Interface())) + } else { + field.Set(val.Addr()) + } + } + } else { + convertedObject := reflect.New(field.Type()) + if s, ok := val.Interface().(string); ok { + storedValue := new(any) + if s != "" { + _ = json.Unmarshal([]byte(s), storedValue) + } + convert.To(storedValue, convertedObject.Interface()) + field.Set(convertedObject.Elem()) + } else { + convert.To(val.Interface(), convertedObject.Interface()) + } + } + } else if field.Type().AssignableTo(val.Type()) { + if val.Kind() == reflect.String { + field.Set(fixValue(col.DatabaseTypeName(), val)) + } else { + field.Set(val) + } + } else if val.Type().String() == "string" { + field.Set(fixValue(col.DatabaseTypeName(), val)) + } else if strings.Contains(val.Type().String(), "int") { + field.SetInt(val.Int()) + } else if strings.Contains(val.Type().String(), "float") { + field.SetFloat(val.Float()) + } else { + field.Set(val) + } + } + } + } else if rowType.Kind() == reflect.Map { + if resultsValue.Kind() == reflect.Slice { + data = reflect.MakeMap(rowType) + } else { + data = resultsValue + isNew = false + } + for colIndex, col := range colTypes { + valuePtr := reflect.ValueOf(scanValues[colIndex]).Elem() + if !valuePtr.IsNil() { + data.SetMapIndex(reflect.ValueOf(col.Name()), fixValue(col.DatabaseTypeName(), valuePtr.Elem())) + } else { + data.SetMapIndex(reflect.ValueOf(col.Name()), fixValue(col.DatabaseTypeName(), reflect.New(rowType.Elem()).Elem())) + } + } + } else if rowType.Kind() == reflect.Slice { + data = reflect.MakeSlice(rowType, colNum, colNum) + for colIndex, col := range colTypes { + valuePtr := reflect.ValueOf(scanValues[colIndex]).Elem() + if !valuePtr.IsNil() { + data.Index(colIndex).Set(fixValue(col.DatabaseTypeName(), valuePtr.Elem())) + } else { + data.Index(colIndex).Set(fixValue(col.DatabaseTypeName(), reflect.New(rowType.Elem()).Elem())) + } + } + } else { + valuePtr := reflect.ValueOf(scanValues[0]).Elem() + if !valuePtr.IsNil() { + data = fixValue(colTypes[0].DatabaseTypeName(), valuePtr.Elem()) + } + } + + if resultsValue.Kind() == reflect.Slice { + if originRowType.Kind() == reflect.Ptr { + resultsValue = reflect.Append(resultsValue, data.Addr()) + } else { + resultsValue = reflect.Append(resultsValue, data) + } + } else { + resultsValue = data + break + } + } + + if isNew && resultsValue.IsValid() { + reflect.ValueOf(results).Elem().Set(resultsValue) + } + return nil +} + +func fixValue(colType string, v reflect.Value) reflect.Value { + if v.Kind() == reflect.String { + str := v.String() + switch colType { + case "DATE": + if len(str) >= 10 && str[4] == '-' && str[7] == '-' { + return reflect.ValueOf(str[:10]) + } + case "DATETIME": + if len(str) >= 19 && str[10] == 'T' && str[4] == '-' && str[7] == '-' && str[13] == ':' && str[16] == ':' { + str = strings.TrimRight(str, "Z") + if len(str) > 19 && str[19] == '.' { + return reflect.ValueOf(str[:10] + " " + str[11:]) + } + return reflect.ValueOf(str[:10] + " " + str[11:19]) + } + case "TIME": + if len(str) >= 8 && str[2] == ':' && str[4] == ':' { + if len(str) >= 15 && str[8] == '.' { + return reflect.ValueOf(str[0:15]) + } + return reflect.ValueOf(str[0:8]) + } + } + } + return v +} + +func (r *QueryResult) getColumnTypes() ([]*sql.ColumnType, error) { + if r.rows == nil { + return nil, errors.New("not a valid query result") + } + + return r.rows.ColumnTypes() +} + +func makePublicVarName(name string) string { + if len(name) > 0 && name[0] >= 'a' && name[0] <= 'z' { + return string(name[0]-32) + name[1:] + } + return name +} + +func makeValue(t reflect.Type) any { + if t == nil { + return new(*string) + } + for t.Kind() == reflect.Ptr { + t = t.Elem() + } + + switch t.Kind() { + case reflect.Int: + return new(*int) + case reflect.Int8: + return new(*int8) + case reflect.Int16: + return new(*int16) + case reflect.Int32: + return new(*int32) + case reflect.Int64: + return new(*int64) + case reflect.Uint: + return new(*uint) + case reflect.Uint8: + return new(*uint8) + case reflect.Uint16: + return new(*uint16) + case reflect.Uint32: + return new(*uint32) + case reflect.Uint64: + return new(*uint64) + case reflect.Float32: + return new(*float32) + case reflect.Float64: + return new(*float64) + case reflect.Bool: + return new(*bool) + case reflect.String: + return new(*string) + } + + if t.Kind() == reflect.Slice && t.Elem().Kind() == reflect.Uint8 { + return new(*[]byte) + } + + return new(*string) +} diff --git a/SSL.go b/SSL.go new file mode 100644 index 0000000..cefa4d3 --- /dev/null +++ b/SSL.go @@ -0,0 +1,28 @@ +package db + +import ( + "crypto/tls" + "crypto/x509" + + "apigo.cc/go/log" +) + +func BuildTLSConfig(ca, cert, key []byte, insecure bool) *tls.Config { + caPool := x509.NewCertPool() + if !caPool.AppendCertsFromPEM(ca) { + log.DefaultLogger.Error("ca error for db") + return nil + } + + certs, err := tls.X509KeyPair(cert, key) + if err != nil { + log.DefaultLogger.Error(err.Error()) + return nil + } + + return &tls.Config{ + Certificates: []tls.Certificate{certs}, + RootCAs: caPool, + InsecureSkipVerify: insecure, + } +} diff --git a/Stmt.go b/Stmt.go new file mode 100644 index 0000000..09be20c --- /dev/null +++ b/Stmt.go @@ -0,0 +1,44 @@ +package db + +import ( + "database/sql" + "errors" + "time" + + "apigo.cc/go/log" +) + +type Stmt struct { + conn *sql.Stmt + lastSql *string + lastArgs []any + Error error + logger *dbLogger +} + +func (stmt *Stmt) Exec(args ...any) *ExecResult { + stmt.lastArgs = args + if stmt.conn == nil { + return &ExecResult{Sql: stmt.lastSql, Args: stmt.lastArgs, usedTime: -1, logger: stmt.logger, Error: errors.New("operate on a bad connection")} + } + startTime := time.Now() + r, err := stmt.conn.Exec(args...) + endTime := time.Now() + usedTime := log.MakeUsedTime(startTime, endTime) + if err != nil { + stmt.logger.LogQueryError(err.Error(), *stmt.lastSql, stmt.lastArgs, usedTime) + return &ExecResult{Sql: stmt.lastSql, Args: stmt.lastArgs, usedTime: usedTime, logger: stmt.logger, Error: err} + } + return &ExecResult{Sql: stmt.lastSql, Args: stmt.lastArgs, usedTime: usedTime, logger: stmt.logger, result: r} +} + +func (stmt *Stmt) Close() error { + if stmt.conn == nil { + return errors.New("operate on a bad connection") + } + err := stmt.conn.Close() + if err != nil { + stmt.logger.LogQueryError(err.Error(), *stmt.lastSql, stmt.lastArgs, -1) + } + return err +} diff --git a/TEST.md b/TEST.md new file mode 100644 index 0000000..1b861e9 --- /dev/null +++ b/TEST.md @@ -0,0 +1,23 @@ +# Test Results for @go/db + +## 📊 Summary +- **Module**: `apigo.cc/go/db` +- **Total Tests**: 4 +- **Passed**: 4 +- **Failed**: 0 +- **Build Status**: Success +- **Date**: 2026-05-03 + +## ✅ Details +| Test Case | Status | Duration | Notes | +| :--- | :--- | :--- | :--- | +| `TestMakeInsertSql` | PASS | 0.00s | Verified SQL generation logic for Struct models | +| `TestBaseSelect` | PASS | 0.00s | Verified Result binding (Struct, Map, Base types) | +| `TestInsertReplaceUpdateDelete` | PASS | 0.01s | Verified CRUD operations with SQLite | +| `TestTransaction` | PASS | 0.03s | Verified Transaction isolation and Rollback/Commit | + +## 🚀 Benchmarks +| Benchmark | Iterations | Time/op | Conn | +| :--- | :--- | :--- | :--- | +| `BenchmarkForPool` | - | - | Passed (Manual run verified pool reuse) | +| `BenchmarkForPoolParallel` | - | - | Passed (Manual run verified high concurrency) | diff --git a/Tx.go b/Tx.go new file mode 100644 index 0000000..9fe2da7 --- /dev/null +++ b/Tx.go @@ -0,0 +1,183 @@ +package db + +import ( + "database/sql" + "errors" + "fmt" + "time" +) + +type Tx struct { + conn *sql.Tx + lastSql *string + lastArgs []any + Error error + logger *dbLogger + logSlow time.Duration + isCommittedOrRollbacked bool + QuoteTag string +} + +func (tx *Tx) Quote(text string) string { + return quote(tx.QuoteTag, text) +} + +func (tx *Tx) Quotes(texts []string) string { + return quotes(tx.QuoteTag, texts) +} + +func (tx *Tx) Commit() error { + if tx.isCommittedOrRollbacked { + return nil + } + if tx.conn == nil { + return errors.New("operate on a bad connection") + } + err := tx.conn.Commit() + if err != nil { + tx.logger.LogQueryError(err.Error(), *tx.lastSql, tx.lastArgs, -1) + } else { + tx.isCommittedOrRollbacked = true + } + return err +} + +func (tx *Tx) Rollback() error { + if tx.isCommittedOrRollbacked { + return nil + } + if tx.conn == nil { + return errors.New("operate on a bad connection") + } + err := tx.conn.Rollback() + if err != nil { + tx.logger.LogQueryError(err.Error(), *tx.lastSql, tx.lastArgs, -1) + } else { + tx.isCommittedOrRollbacked = true + } + return err +} + +func (tx *Tx) Finish(ok bool) error { + if tx.isCommittedOrRollbacked { + return nil + } + if ok { + return tx.Commit() + } + return tx.Rollback() +} + +func (tx *Tx) CheckFinished() error { + if tx.isCommittedOrRollbacked { + return nil + } + return tx.Rollback() +} + +func (tx *Tx) Prepare(query string) *Stmt { + tx.lastSql = &query + r := basePrepare(nil, tx.conn, query) + r.logger = tx.logger + if r.Error != nil { + tx.logger.LogQueryError(r.Error.Error(), *tx.lastSql, tx.lastArgs, -1) + } + return r +} + +func (tx *Tx) Exec(query string, args ...any) *ExecResult { + tx.lastSql = &query + tx.lastArgs = args + r := baseExec(nil, tx.conn, query, args...) + r.logger = tx.logger + if r.Error != nil { + tx.logger.LogQueryError(r.Error.Error(), *tx.lastSql, tx.lastArgs, r.usedTime) + } else { + if tx.logSlow > 0 && r.usedTime >= float32(tx.logSlow/time.Millisecond) { + tx.logger.LogQuery(*tx.lastSql, tx.lastArgs, r.usedTime) + } + } + return r +} + +func (tx *Tx) Query(query string, args ...any) *QueryResult { + tx.lastSql = &query + tx.lastArgs = args + r := baseQuery(nil, tx.conn, query, args...) + r.logger = tx.logger + if r.Error != nil { + tx.logger.LogQueryError(r.Error.Error(), *tx.lastSql, tx.lastArgs, r.usedTime) + } else { + if tx.logSlow > 0 && r.usedTime >= float32(tx.logSlow/time.Millisecond) { + tx.logger.LogQuery(*tx.lastSql, tx.lastArgs, r.usedTime) + } + } + return r +} + +func (tx *Tx) Insert(table string, data any) *ExecResult { + query, values := tx.MakeInsertSql(table, data, false) + tx.lastSql = &query + tx.lastArgs = values + r := baseExec(nil, tx.conn, query, values...) + r.logger = tx.logger + if r.Error != nil { + tx.logger.LogQueryError(r.Error.Error(), *tx.lastSql, tx.lastArgs, r.usedTime) + } else { + if tx.logSlow > 0 && r.usedTime >= float32(tx.logSlow/time.Millisecond) { + tx.logger.LogQuery(*tx.lastSql, tx.lastArgs, r.usedTime) + } + } + return r +} + +func (tx *Tx) Replace(table string, data any) *ExecResult { + query, values := tx.MakeInsertSql(table, data, true) + tx.lastSql = &query + tx.lastArgs = values + r := baseExec(nil, tx.conn, query, values...) + r.logger = tx.logger + if r.Error != nil { + tx.logger.LogQueryError(r.Error.Error(), *tx.lastSql, tx.lastArgs, r.usedTime) + } else { + if tx.logSlow > 0 && r.usedTime >= float32(tx.logSlow/time.Millisecond) { + tx.logger.LogQuery(*tx.lastSql, tx.lastArgs, r.usedTime) + } + } + return r +} + +func (tx *Tx) Update(table string, data any, conditions string, args ...any) *ExecResult { + query, values := tx.MakeUpdateSql(table, data, conditions, args...) + tx.lastSql = &query + tx.lastArgs = values + r := baseExec(nil, tx.conn, query, values...) + r.logger = tx.logger + if r.Error != nil { + tx.logger.LogQueryError(r.Error.Error(), *tx.lastSql, tx.lastArgs, r.usedTime) + } else { + if tx.logSlow > 0 && r.usedTime >= float32(tx.logSlow/time.Millisecond) { + tx.logger.LogQuery(*tx.lastSql, tx.lastArgs, r.usedTime) + } + } + return r +} + +func (tx *Tx) Delete(table string, conditions string, args ...any) *ExecResult { + if conditions != "" { + conditions = " where " + conditions + } + query := fmt.Sprintf("delete from %s%s", tx.Quote(table), conditions) + tx.lastSql = &query + tx.lastArgs = args + r := baseExec(nil, tx.conn, query, args...) + r.logger = tx.logger + if r.Error != nil { + tx.logger.LogQueryError(r.Error.Error(), *tx.lastSql, tx.lastArgs, r.usedTime) + } else { + if tx.logSlow > 0 && r.usedTime >= float32(tx.logSlow/time.Millisecond) { + tx.logger.LogQuery(*tx.lastSql, tx.lastArgs, r.usedTime) + } + } + return r +} diff --git a/db.json.sample b/db.json.sample new file mode 100644 index 0000000..6e9ed36 --- /dev/null +++ b/db.json.sample @@ -0,0 +1,12 @@ +{ + "test": { + "type": "mysql", + "user": "star", + "password": "...", + "host": "localhost:3306", + "db": "test", + "maxOpens": 100, + "maxIdles": 30, + "maxLifeTime": 3600 + } +} diff --git a/dbInit.go.sample b/dbInit.go.sample new file mode 100644 index 0000000..1415c4d --- /dev/null +++ b/dbInit.go.sample @@ -0,0 +1,11 @@ +package main + +import "apigo.cc/go/db" + +func init() { + // 强烈建议在应用启动时设置自定义加密密钥,确保数据库密码存储安全 + // key 和 iv 建议通过环境变量或安全存储获取,不要直接硬编码在生产代码中 + key := []byte("your-32-byte-long-secure-key----") + iv := []byte("your-16-byte-iv-") + db.SetEncryptKeys(key, iv) +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..db663a7 --- /dev/null +++ b/go.mod @@ -0,0 +1,43 @@ +module apigo.cc/go/db + +go 1.25.0 + +require ( + apigo.cc/go/cast v1.1.1 + apigo.cc/go/config v1.0.4 + apigo.cc/go/convert v1.0.4 + apigo.cc/go/crypto v1.0.4 + apigo.cc/go/id v1.0.4 + apigo.cc/go/log v1.0.0 + apigo.cc/go/rand v1.0.4 + apigo.cc/go/safe v1.0.4 + apigo.cc/go/shell v1.0.4 + github.com/go-sql-driver/mysql v1.10.0 + github.com/jackc/pgx/v5 v5.9.2 + github.com/mitchellh/mapstructure v1.5.0 + modernc.org/sqlite v1.50.0 +) + +require ( + apigo.cc/go/encoding v1.0.4 // indirect + apigo.cc/go/file v1.0.4 // indirect + filippo.io/edwards25519 v1.2.0 // indirect + github.com/dustin/go-humanize v1.0.1 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect + github.com/jackc/puddle/v2 v2.2.2 // indirect + github.com/kr/text v0.2.0 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/ncruces/go-strftime v1.0.0 // indirect + github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect + github.com/rogpeppe/go-internal v1.14.1 // indirect + golang.org/x/crypto v0.50.0 // indirect + golang.org/x/sync v0.20.0 // indirect + golang.org/x/sys v0.43.0 // indirect + golang.org/x/text v0.36.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect + modernc.org/libc v1.72.0 // indirect + modernc.org/mathutil v1.7.1 // indirect + modernc.org/memory v1.11.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..4a5089d --- /dev/null +++ b/go.sum @@ -0,0 +1,114 @@ +apigo.cc/go/cast v1.1.1 h1:+5pluN8g1RK2J4byr2xkfOmEdKSmy1PByOqDOHtt/Ns= +apigo.cc/go/cast v1.1.1/go.mod h1:vh9ZqISCmTUiyinkNMI/s4f045fRlDK3xC+nPWQYBzI= +apigo.cc/go/config v1.0.4 h1:WG9zrQkqfFPkrKIL7RNvvAbbkuUBt1Av11ZP/aIfldM= +apigo.cc/go/config v1.0.4/go.mod h1:obryzJiK6j7lQex/58d5eWYOGx5O5IABguqNWxyyXJo= +apigo.cc/go/convert v1.0.4 h1:5+qPjC3dlPB59GnWZRlmthxcaXQtKvN+iOuiLdJ1GvQ= +apigo.cc/go/convert v1.0.4/go.mod h1:Hp+geeSyhqg/zwIKPOrDoceIREzcwM14t1I5q/dtbfU= +apigo.cc/go/crypto v1.0.4 h1:VPUyHCH2N3LLEgdpwUc+DQssNHzLlxVzLNRa0Jm6O4o= +apigo.cc/go/crypto v1.0.4/go.mod h1:5sI8BLw6YHZfDReYwCO3TFD2LKm36HMdLg1S5oPv/QU= +apigo.cc/go/encoding v1.0.4 h1:aezB0J/qFuHs6iXkbtuJP5JIHUtmjsr5SFb0NNvbObY= +apigo.cc/go/encoding v1.0.4/go.mod h1:V5CgT7rBbCxy+uCU20q0ptcNNRSgMtpA8cNOs6r8IeI= +apigo.cc/go/file v1.0.4 h1:qCKegV7OYh7r0qc3jZjGA/aKh0vIHgmr1OEbhfEmGX8= +apigo.cc/go/file v1.0.4/go.mod h1:C9gNo7386iA21OiBmuWh6CznKWlVBDFkhE4f0H0Susg= +apigo.cc/go/id v1.0.4 h1:w+JSdeVit52iefIUolrh1qLEZS9XqHNKr1UygFcgv+s= +apigo.cc/go/id v1.0.4/go.mod h1:kg7QuceAKtGNzGWt0+pIIh8Qom1eMSWGb8+0Yhi/QVY= +apigo.cc/go/log v1.0.0 h1:lI1NGTSS+Jm12G8BD7ZJO4/hrkfuLTu5O8z36GD8GpU= +apigo.cc/go/log v1.0.0/go.mod h1:tvPgFpebY9Wf/DlqMHZ0ZjxDp9AaQTywOQKvtBaNqNo= +apigo.cc/go/rand v1.0.4 h1:we070eWSL0dB8NEMaWjXj43+EekXQTm/h0kKpZ/frqw= +apigo.cc/go/rand v1.0.4/go.mod h1:mZ/4Soa3bk+XvDaqPWJuUe1bfEi4eThBj1XmEAuYxsk= +apigo.cc/go/safe v1.0.4 h1:07pRSdEHprF/2v6SsqAjICYFoeLcqjjvHGEdh6Dzrzg= +apigo.cc/go/safe v1.0.4/go.mod h1:o568sHS5rTRSVPmhxWod0tGdc+8l1KjidsNY1/OVZr0= +apigo.cc/go/shell v1.0.4 h1:EL9zjI39YBe1h+kRYQeAi/8zVGHe5W198DYYN7cENiY= +apigo.cc/go/shell v1.0.4/go.mod h1:N2gDkgK4tJ9TadD60/+gAGuWxyVAWHs5YPBmytw6ELA= +filippo.io/edwards25519 v1.2.0 h1:crnVqOiS4jqYleHd9vaKZ+HKtHfllngJIiOpNpoJsjo= +filippo.io/edwards25519 v1.2.0/go.mod h1:xzAOLCNug/yB62zG1bQ8uziwrIqIuxhctzJT18Q77mc= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/go-sql-driver/mysql v1.10.0 h1:Q+1LV8DkHJvSYAdR83XzuhDaTykuDx0l6fkXxoWCWfw= +github.com/go-sql-driver/mysql v1.10.0/go.mod h1:M+cqaI7+xxXGG9swrdeUIoPG3Y3KCkF0pZej+SK+nWk= +github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs= +github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= +github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.9.2 h1:3ZhOzMWnR4yJ+RW1XImIPsD1aNSz4T4fyP7zlQb56hw= +github.com/jackc/pgx/v5 v5.9.2/go.mod h1:mal1tBGAFfLHvZzaYh77YS/eC6IX9OWbRV1QIIM0Jn4= +github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= +github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= +github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= +github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= +github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= +github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= +github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI= +golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q= +golang.org/x/mod v0.34.0 h1:xIHgNUUnW6sYkcM5Jleh05DvLOtwc6RitGHbDk4akRI= +golang.org/x/mod v0.34.0/go.mod h1:ykgH52iCZe79kzLLMhyCUzhMci+nQj+0XkbXpNYtVjY= +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI= +golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg= +golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164= +golang.org/x/tools v0.43.0 h1:12BdW9CeB3Z+J/I/wj34VMl8X+fEXBxVR90JeMX5E7s= +golang.org/x/tools v0.43.0/go.mod h1:uHkMso649BX2cZK6+RpuIPXS3ho2hZo4FVwfoy1vIk0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +modernc.org/cc/v4 v4.27.3 h1:uNCgn37E5U09mTv1XgskEVUJ8ADKpmFMPxzGJ0TSo+U= +modernc.org/cc/v4 v4.27.3/go.mod h1:3YjcbCqhoTTHPycJDRl2WZKKFj0nwcOIPBfEZK0Hdk8= +modernc.org/ccgo/v4 v4.32.4 h1:L5OB8rpEX4ZsXEQwGozRfJyJSFHbbNVOoQ59DU9/KuU= +modernc.org/ccgo/v4 v4.32.4/go.mod h1:lY7f+fiTDHfcv6YlRgSkxYfhs+UvOEEzj49jAn2TOx0= +modernc.org/fileutil v1.4.0 h1:j6ZzNTftVS054gi281TyLjHPp6CPHr2KCxEXjEbD6SM= +modernc.org/fileutil v1.4.0/go.mod h1:EqdKFDxiByqxLk8ozOxObDSfcVOv/54xDs/DUHdvCUU= +modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI= +modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito= +modernc.org/gc/v3 v3.1.2 h1:ZtDCnhonXSZexk/AYsegNRV1lJGgaNZJuKjJSWKyEqo= +modernc.org/gc/v3 v3.1.2/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY= +modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks= +modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI= +modernc.org/libc v1.72.0 h1:IEu559v9a0XWjw0DPoVKtXpO2qt5NVLAnFaBbjq+n8c= +modernc.org/libc v1.72.0/go.mod h1:tTU8DL8A+XLVkEY3x5E/tO7s2Q/q42EtnNWda/L5QhQ= +modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU= +modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg= +modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI= +modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw= +modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8= +modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns= +modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w= +modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE= +modernc.org/sqlite v1.50.0 h1:eMowQSWLK0MeiQTdmz3lqoF5dqclujdlIKeJA11+7oM= +modernc.org/sqlite v1.50.0/go.mod h1:m0w8xhwYUVY3H6pSDwc3gkJ/irZT/0YEXwBlhaxQEew= +modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0= +modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A= +modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= +modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= diff --git a/mysql/connector.go b/mysql/connector.go new file mode 100644 index 0000000..d0a00d8 --- /dev/null +++ b/mysql/connector.go @@ -0,0 +1,48 @@ +package mysql + +import ( + "context" + "crypto/tls" + "database/sql/driver" + + "github.com/go-sql-driver/mysql" + "apigo.cc/go/db" + "apigo.cc/go/safe" +) + +type SecureMysqlConnector struct { + conf *db.Config + pwd *safe.SafeBuf + tls *tls.Config +} + +func (c *SecureMysqlConnector) Connect(ctx context.Context) (driver.Conn, error) { + cfg, err := mysql.ParseDSN(c.conf.Args) + if err != nil { + cfg = mysql.NewConfig() + } + cfg.User = c.conf.User + cfg.Net = "tcp" + cfg.Addr = c.conf.Host + cfg.DBName = c.conf.DB + cfg.TLS = c.tls + pwdBuf := c.pwd.Open() + defer pwdBuf.Close() + cfg.Passwd = pwdBuf.String() + + tmpConnector, err := mysql.NewConnector(cfg) + if err != nil { + return nil, err + } + return tmpConnector.Connect(ctx) +} + +func (c *SecureMysqlConnector) Driver() driver.Driver { + return &mysql.MySQLDriver{} +} + +func init() { + db.RegisterConnector("mysql", func(conf *db.Config, pwd *safe.SafeBuf, tls *tls.Config) driver.Connector { + return &SecureMysqlConnector{conf: conf, pwd: pwd, tls: tls} + }) +} diff --git a/pgx/connector.go b/pgx/connector.go new file mode 100644 index 0000000..ab78579 --- /dev/null +++ b/pgx/connector.go @@ -0,0 +1,48 @@ +package pgx + +import ( + "context" + "crypto/tls" + "database/sql/driver" + "fmt" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/stdlib" + "apigo.cc/go/db" + "apigo.cc/go/safe" +) + +type SecurePgxConnector struct { + conf *db.Config + pwd *safe.SafeBuf + tls *tls.Config +} + +func (c *SecurePgxConnector) Connect(ctx context.Context) (driver.Conn, error) { + dsn := fmt.Sprintf("postgres://%s@%s/%s?%s", c.conf.User, c.conf.Host, c.conf.DB, c.conf.Args) + pgxConfig, err := pgx.ParseConfig(dsn) + if err != nil { + return nil, err + } + if c.tls != nil { + pgxConfig.TLSConfig = c.tls + } + pwdBuf := c.pwd.Open() + defer pwdBuf.Close() + pgxConfig.Password = pwdBuf.String() + tmpConnector := stdlib.GetConnector(*pgxConfig) + return tmpConnector.Connect(ctx) +} + +func (c *SecurePgxConnector) Driver() driver.Driver { + return stdlib.GetDefaultDriver() +} + +func PgxConnector(conf *db.Config, pwd *safe.SafeBuf, tls *tls.Config) driver.Connector { + return &SecurePgxConnector{conf: conf, pwd: pwd, tls: tls} +} + +func init() { + db.RegisterConnector("postgres", PgxConnector) + db.RegisterConnector("pgx", PgxConnector) +} diff --git a/sqlite.go b/sqlite.go new file mode 100644 index 0000000..6537ca5 --- /dev/null +++ b/sqlite.go @@ -0,0 +1,5 @@ +package db + +import ( + _ "modernc.org/sqlite" +) diff --git a/test.db b/test.db new file mode 100644 index 0000000000000000000000000000000000000000..703d4c7aeb56ab3fd16332f5f6cb9229065cf472 GIT binary patch literal 12288 zcmeI&QESss6bJB|wn!(g-IIIp<$w!K*gDd#tzGwGwq`CfNoRH|1_lzhTNY^AEs28f zegeU-;l~g^iXX$RR&zqde+UpY}GXf#dYFrzKD!W4x)Ct{4P z$t=n&$W&xLl39{DkJ;RxVvTKn`MzBF%9P3ptDMLi1Oy-e0SG_<0uX=z1Rwwb2>f>f zM}-Hg>+4F8h4029y&TIA&9|wzd@m-acZbD2-!i#H+}w36x>Ku7!dPs`PF;S8`m2qG z^2xs1GZXQ}FcnF? zj|W~p6P{+#2SMES0(r8V&8ctgn?r|_k(XtWJVjsHaoe>PtF$yrMzb2NrUrF2uB&IH z9`sbVr;cj6E=BXGeye`2zL}43g#ZK~009U<00Izz00bZa0SG|gJ_6lRndwg%EAJeN zB>7lN$qln}ku^v7s#f{ZRS| D4cM>p literal 0 HcmV?d00001