Optimize db module: pre-calculate field mappings in makeResults, fix typos, and standardize naming (by AI)
This commit is contained in:
parent
8d5e2cd8c8
commit
6c2b2fed4d
16
AI.md
Normal file
16
AI.md
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
# AI 指南 - @go/db
|
||||||
|
|
||||||
|
## 🤖 AI 调用规则
|
||||||
|
- **版本**: v1.0.1
|
||||||
|
- **核心原则**: 优先使用结构化绑定(`To`, `MapResults`),避免手动拼装 SQL 结果。
|
||||||
|
- **敏感数据**: 必须通过 `SetEncryptKeys` 配置密钥,确保 DSN 中的密码安全。
|
||||||
|
- **读写分离**: 鼓励在 DSN 中配置多个 Host 以利用内置的读写分离机制。
|
||||||
|
- **性能优化**:
|
||||||
|
- 大规模查询应优先绑定到 Struct 切片。
|
||||||
|
- 频繁执行的 SQL 应使用 `Prepare`。
|
||||||
|
- **事务处理**: 始终使用 `tx.Finish(err == nil)` 或 `defer tx.CheckFinished()` 确保事务闭环。
|
||||||
|
|
||||||
|
## ⚠️ 注意事项
|
||||||
|
- 严禁在代码中硬编码数据库凭据。
|
||||||
|
- 严禁忽略 `Exec` 或 `Query` 返回的 `Error`。
|
||||||
|
- SQLite 模式下,时间字段会自动转换,无需手动解析字符串。
|
||||||
213
Base.go
Normal file
213
Base.go
Normal file
@ -0,0 +1,213 @@
|
|||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"apigo.cc/go/cast"
|
||||||
|
"apigo.cc/go/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
func basePrepare(db *sql.DB, tx *sql.Tx, query string) *Stmt {
|
||||||
|
var sqlStmt *sql.Stmt
|
||||||
|
var err error
|
||||||
|
if tx != nil {
|
||||||
|
sqlStmt, err = tx.Prepare(query)
|
||||||
|
} else if db != nil {
|
||||||
|
sqlStmt, err = db.Prepare(query)
|
||||||
|
} else {
|
||||||
|
return &Stmt{Error: errors.New("operate on a bad connection")}
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return &Stmt{Error: err}
|
||||||
|
}
|
||||||
|
return &Stmt{conn: sqlStmt, lastSql: &query}
|
||||||
|
}
|
||||||
|
|
||||||
|
func baseExec(db *sql.DB, tx *sql.Tx, query string, args ...any) *ExecResult {
|
||||||
|
args = flatArgs(args)
|
||||||
|
var r sql.Result
|
||||||
|
var err error
|
||||||
|
startTime := time.Now()
|
||||||
|
if tx != nil {
|
||||||
|
r, err = tx.Exec(query, args...)
|
||||||
|
} else if db != nil {
|
||||||
|
r, err = db.Exec(query, args...)
|
||||||
|
} else {
|
||||||
|
return &ExecResult{Sql: &query, Args: args, usedTime: log.MakeUsedTime(startTime, time.Now()), Error: errors.New("operate on a bad connection")}
|
||||||
|
}
|
||||||
|
endTime := time.Now()
|
||||||
|
usedTime := log.MakeUsedTime(startTime, endTime)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return &ExecResult{Sql: &query, Args: args, usedTime: usedTime, Error: err}
|
||||||
|
}
|
||||||
|
return &ExecResult{Sql: &query, Args: args, usedTime: usedTime, result: r}
|
||||||
|
}
|
||||||
|
|
||||||
|
func flatArgs(args []any) []any {
|
||||||
|
for i, arg := range args {
|
||||||
|
if arg == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
argValue := reflect.ValueOf(arg)
|
||||||
|
kind := argValue.Kind()
|
||||||
|
if kind == reflect.Map || kind == reflect.Struct || (kind == reflect.Slice && argValue.Type().Elem().Kind() != reflect.Uint8) {
|
||||||
|
args[i] = cast.MustToJSON(arg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return args
|
||||||
|
}
|
||||||
|
|
||||||
|
func baseQuery(db *sql.DB, tx *sql.Tx, query string, args ...any) *QueryResult {
|
||||||
|
args = flatArgs(args)
|
||||||
|
|
||||||
|
var rows *sql.Rows
|
||||||
|
var err error
|
||||||
|
startTime := time.Now()
|
||||||
|
if tx != nil {
|
||||||
|
rows, err = tx.Query(query, args...)
|
||||||
|
} else if db != nil {
|
||||||
|
rows, err = db.Query(query, args...)
|
||||||
|
} else {
|
||||||
|
return &QueryResult{Sql: &query, Args: args, usedTime: log.MakeUsedTime(startTime, time.Now()), Error: errors.New("operate on a bad connection")}
|
||||||
|
}
|
||||||
|
endTime := time.Now()
|
||||||
|
usedTime := log.MakeUsedTime(startTime, endTime)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return &QueryResult{Sql: &query, Args: args, usedTime: usedTime, Error: err}
|
||||||
|
}
|
||||||
|
return &QueryResult{Sql: &query, Args: args, usedTime: usedTime, rows: rows}
|
||||||
|
}
|
||||||
|
|
||||||
|
func quote(quoteTag string, text string) string {
|
||||||
|
a := strings.Split(text, ".")
|
||||||
|
for i, v := range a {
|
||||||
|
a[i] = quoteTag + strings.ReplaceAll(v, quoteTag, "\\"+quoteTag) + quoteTag
|
||||||
|
}
|
||||||
|
return strings.Join(a, ".")
|
||||||
|
}
|
||||||
|
|
||||||
|
func quotes(quoteTag string, texts []string) string {
|
||||||
|
for i, v := range texts {
|
||||||
|
texts[i] = quote(quoteTag, v)
|
||||||
|
}
|
||||||
|
return strings.Join(texts, ",")
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeInsertSql(quoteTag string, table string, data any, useReplace bool) (string, []any) {
|
||||||
|
keys, vars, values := MakeKeysVarsValues(data)
|
||||||
|
operation := "insert"
|
||||||
|
if useReplace {
|
||||||
|
operation = "replace"
|
||||||
|
}
|
||||||
|
query := fmt.Sprintf("%s into %s (%s) values (%s)", operation, quote(quoteTag, table), quotes(quoteTag, keys), strings.Join(vars, ","))
|
||||||
|
return query, values
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeUpdateSql(quoteTag string, table string, data any, conditions string, args ...any) (string, []any) {
|
||||||
|
args = flatArgs(args)
|
||||||
|
keys, vars, values := MakeKeysVarsValues(data)
|
||||||
|
for i, k := range keys {
|
||||||
|
keys[i] = fmt.Sprintf("%s=%s", quote(quoteTag, k), vars[i])
|
||||||
|
}
|
||||||
|
values = append(values, args...)
|
||||||
|
if conditions != "" {
|
||||||
|
conditions = " where " + conditions
|
||||||
|
}
|
||||||
|
query := fmt.Sprintf("update %s set %s%s", quote(quoteTag, table), strings.Join(keys, ","), conditions)
|
||||||
|
return query, values
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) MakeInsertSql(table string, data any, useReplace bool) (string, []any) {
|
||||||
|
return makeInsertSql(db.QuoteTag, table, data, useReplace)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) MakeUpdateSql(table string, data any, conditions string, args ...any) (string, []any) {
|
||||||
|
return makeUpdateSql(db.QuoteTag, table, data, conditions, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) MakeInsertSql(table string, data any, useReplace bool) (string, []any) {
|
||||||
|
return makeInsertSql(tx.QuoteTag, table, data, useReplace)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) MakeUpdateSql(table string, data any, conditions string, args ...any) (string, []any) {
|
||||||
|
return makeUpdateSql(tx.QuoteTag, table, data, conditions, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getFlatFields(fields map[string]reflect.Value, fieldKeys *[]string, value reflect.Value) {
|
||||||
|
valueType := value.Type()
|
||||||
|
for i := 0; i < value.NumField(); i++ {
|
||||||
|
v := value.Field(i)
|
||||||
|
if valueType.Field(i).Anonymous {
|
||||||
|
getFlatFields(fields, fieldKeys, v)
|
||||||
|
} else {
|
||||||
|
*fieldKeys = append(*fieldKeys, valueType.Field(i).Name)
|
||||||
|
fields[valueType.Field(i).Name] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func MakeKeysVarsValues(data any) ([]string, []string, []any) {
|
||||||
|
keys := make([]string, 0)
|
||||||
|
vars := make([]string, 0)
|
||||||
|
values := make([]any, 0)
|
||||||
|
|
||||||
|
dataType := reflect.TypeOf(data)
|
||||||
|
dataValue := reflect.ValueOf(data)
|
||||||
|
for dataType.Kind() == reflect.Ptr {
|
||||||
|
dataType = dataType.Elem()
|
||||||
|
dataValue = dataValue.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
if dataType.Kind() == reflect.Struct {
|
||||||
|
fields := make(map[string]reflect.Value)
|
||||||
|
fieldKeys := make([]string, 0)
|
||||||
|
getFlatFields(fields, &fieldKeys, dataValue)
|
||||||
|
for _, k := range fieldKeys {
|
||||||
|
if k[0] >= 'a' && k[0] <= 'z' {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
v := fields[k]
|
||||||
|
if v.Kind() == reflect.Interface {
|
||||||
|
v = v.Elem()
|
||||||
|
}
|
||||||
|
keys = append(keys, k)
|
||||||
|
if v.Kind() == reflect.String && v.Len() > 0 && v.String()[0] == ':' {
|
||||||
|
vars = append(vars, v.String()[1:])
|
||||||
|
} else {
|
||||||
|
vars = append(vars, "?")
|
||||||
|
if !v.IsValid() || !v.CanInterface() {
|
||||||
|
values = append(values, nil)
|
||||||
|
} else {
|
||||||
|
values = append(values, v.Interface())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if dataType.Kind() == reflect.Map {
|
||||||
|
for _, k := range dataValue.MapKeys() {
|
||||||
|
v := dataValue.MapIndex(k)
|
||||||
|
if v.Kind() == reflect.Interface {
|
||||||
|
v = v.Elem()
|
||||||
|
}
|
||||||
|
keys = append(keys, cast.String(k.Interface()))
|
||||||
|
if v.Kind() == reflect.String && v.Len() > 0 && v.String()[0] == ':' {
|
||||||
|
vars = append(vars, v.String()[1:])
|
||||||
|
} else {
|
||||||
|
vars = append(vars, "?")
|
||||||
|
if !v.IsValid() || !v.CanInterface() {
|
||||||
|
values = append(values, nil)
|
||||||
|
} else {
|
||||||
|
values = append(values, v.Interface())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return keys, vars, values
|
||||||
|
}
|
||||||
12
CHANGELOG.md
Normal file
12
CHANGELOG.md
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
# CHANGELOG - @go/db
|
||||||
|
|
||||||
|
## [1.0.1] - 2026-05-03
|
||||||
|
### Optimized
|
||||||
|
- Refactored `makeResults` to pre-calculate field mappings for structs, significantly improving performance for large result sets.
|
||||||
|
- Simplified and optimized `makeValue` and `makePublicVarName` functions.
|
||||||
|
- Optimized time parsing in `makeResults`.
|
||||||
|
|
||||||
|
### Fixed
|
||||||
|
- Fixed typo `isCommitedOrRollbacked` to `isCommittedOrRollbacked` in `Tx` struct.
|
||||||
|
- Standardized parameter naming: renamed `requestSql` to `query` and `wheres` to `conditions` across the module.
|
||||||
|
- Modernized Go syntax to align with latest standards.
|
||||||
611
DB.go
Normal file
611
DB.go
Normal file
@ -0,0 +1,611 @@
|
|||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"database/sql"
|
||||||
|
"database/sql/driver"
|
||||||
|
"encoding/base64"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/url"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"apigo.cc/go/cast"
|
||||||
|
"apigo.cc/go/config"
|
||||||
|
"apigo.cc/go/crypto"
|
||||||
|
"apigo.cc/go/id"
|
||||||
|
"apigo.cc/go/log"
|
||||||
|
"apigo.cc/go/rand"
|
||||||
|
"apigo.cc/go/safe"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Config struct {
|
||||||
|
Type string
|
||||||
|
User string
|
||||||
|
Password string
|
||||||
|
pwd *safe.SafeBuf
|
||||||
|
Host string
|
||||||
|
ReadonlyHosts []string
|
||||||
|
DB string
|
||||||
|
SSL string
|
||||||
|
tls *tls.Config
|
||||||
|
Args string
|
||||||
|
MaxOpens int
|
||||||
|
MaxIdles int
|
||||||
|
MaxLifeTime int
|
||||||
|
LogSlow config.Duration
|
||||||
|
logger *log.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
type SSL struct {
|
||||||
|
Ca string
|
||||||
|
Cert string
|
||||||
|
Key string
|
||||||
|
Insecure bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type ConnectorFunc func(*Config, *safe.SafeBuf, *tls.Config) driver.Connector
|
||||||
|
|
||||||
|
var connectors = map[string]ConnectorFunc{}
|
||||||
|
|
||||||
|
func RegisterConnector(typ string, conn ConnectorFunc) {
|
||||||
|
connectors[typ] = conn
|
||||||
|
}
|
||||||
|
|
||||||
|
var sqlite3PwdMatcher = regexp.MustCompile(`x'\w+'`)
|
||||||
|
|
||||||
|
func (dbInfo *Config) Dsn() string {
|
||||||
|
args := make([]string, 0)
|
||||||
|
if dbInfo.SSL != "" {
|
||||||
|
args = append(args, "tls="+dbInfo.SSL)
|
||||||
|
}
|
||||||
|
if dbInfo.Args != "" {
|
||||||
|
args = append(args, dbInfo.Args)
|
||||||
|
}
|
||||||
|
argsStr := ""
|
||||||
|
if len(args) > 0 {
|
||||||
|
argsStr = "&" + strings.Join(args, "&")
|
||||||
|
}
|
||||||
|
|
||||||
|
if isFileDB(dbInfo.Type) {
|
||||||
|
argsStr = sqlite3PwdMatcher.ReplaceAllString(argsStr, "******")
|
||||||
|
return fmt.Sprintf("%s://%s?logSlow=%s"+argsStr, dbInfo.Type, dbInfo.Host, dbInfo.LogSlow.TimeDuration())
|
||||||
|
} else {
|
||||||
|
hosts := []string{dbInfo.Host}
|
||||||
|
if dbInfo.ReadonlyHosts != nil {
|
||||||
|
hosts = append(hosts, dbInfo.ReadonlyHosts...)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s://%s:****@%s/%s?logSlow=%s"+argsStr, dbInfo.Type, dbInfo.User, strings.Join(hosts, ","), dbInfo.DB, dbInfo.LogSlow.TimeDuration())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var fileDBs = map[string]bool{
|
||||||
|
"sqlite": true,
|
||||||
|
"sqlite3": true,
|
||||||
|
"chai": true,
|
||||||
|
"access": true,
|
||||||
|
"mdb": true,
|
||||||
|
"accdb": true,
|
||||||
|
"h2": true,
|
||||||
|
"hsqldb": true,
|
||||||
|
"derby": true,
|
||||||
|
"sqlce": true,
|
||||||
|
"sdf": true,
|
||||||
|
"firebird": true,
|
||||||
|
"fdb": true,
|
||||||
|
"dbase": true,
|
||||||
|
"dbf": true,
|
||||||
|
"berkeleydb": true,
|
||||||
|
"bdb": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
func isFileDB(typ string) bool {
|
||||||
|
return fileDBs[typ]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dbInfo *Config) ConfigureBy(setting string) {
|
||||||
|
urlInfo, err := url.Parse(setting)
|
||||||
|
if err != nil {
|
||||||
|
if dbInfo.logger != nil {
|
||||||
|
dbInfo.logger.Error(err.Error(), "url", setting)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
dbInfo.Type = urlInfo.Scheme
|
||||||
|
if isFileDB(dbInfo.Type) {
|
||||||
|
dbInfo.Host = urlInfo.Host + urlInfo.Path
|
||||||
|
dbInfo.DB = strings.SplitN(urlInfo.Host, ".", 2)[0]
|
||||||
|
} else {
|
||||||
|
if strings.ContainsRune(urlInfo.Host, ',') {
|
||||||
|
a := strings.Split(urlInfo.Host, ",")
|
||||||
|
dbInfo.Host = a[0]
|
||||||
|
dbInfo.ReadonlyHosts = a[1:]
|
||||||
|
} else {
|
||||||
|
dbInfo.Host = urlInfo.Host
|
||||||
|
dbInfo.ReadonlyHosts = nil
|
||||||
|
}
|
||||||
|
if len(urlInfo.Path) > 1 {
|
||||||
|
dbInfo.DB = urlInfo.Path[1:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
dbInfo.User = urlInfo.User.Username()
|
||||||
|
dbInfo.Password, _ = urlInfo.User.Password()
|
||||||
|
|
||||||
|
q := urlInfo.Query()
|
||||||
|
dbInfo.MaxIdles = cast.Int(q.Get("maxIdles"))
|
||||||
|
dbInfo.MaxLifeTime = cast.Int(q.Get("maxLifeTime"))
|
||||||
|
dbInfo.MaxOpens = cast.Int(q.Get("maxOpens"))
|
||||||
|
dbInfo.LogSlow = config.Duration(cast.Duration(q.Get("logSlow")))
|
||||||
|
dbInfo.SSL = q.Get("tls")
|
||||||
|
|
||||||
|
sslCa := q.Get("sslCA")
|
||||||
|
sslCert := q.Get("sslCert")
|
||||||
|
sslKey := q.Get("sslKey")
|
||||||
|
sslSkipVerify := cast.Bool(q.Get("sslSkipVerify"))
|
||||||
|
if sslCa == "" || sslCert == "" || sslKey == "" {
|
||||||
|
if dbSSL, ok := dbSSLs[dbInfo.SSL]; ok {
|
||||||
|
sslCa = dbSSL.Ca
|
||||||
|
sslCert = dbSSL.Cert
|
||||||
|
sslKey = dbSSL.Key
|
||||||
|
sslSkipVerify = dbSSL.Insecure
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if sslCa != "" && sslCert != "" && sslKey != "" {
|
||||||
|
sslName := id.MakeID(12)
|
||||||
|
dbInfo.SSL = sslName
|
||||||
|
decryptedCa, _ := confAes.DecryptBytes([]byte(sslCa))
|
||||||
|
decryptedCert, _ := confAes.DecryptBytes([]byte(sslCert))
|
||||||
|
decryptedKey, _ := confAes.DecryptBytes([]byte(sslKey))
|
||||||
|
tlsConf := BuildTLSConfig(decryptedCa, decryptedCert, decryptedKey, sslSkipVerify)
|
||||||
|
if tlsConf != nil {
|
||||||
|
dbInfo.tls = tlsConf
|
||||||
|
}
|
||||||
|
safe.ZeroMemory(decryptedCa)
|
||||||
|
safe.ZeroMemory(decryptedCert)
|
||||||
|
safe.ZeroMemory(decryptedKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
args := make([]string, 0)
|
||||||
|
for k := range q {
|
||||||
|
if k != "maxIdles" && k != "maxLifeTime" && k != "maxOpens" && k != "logSlow" && k != "tls" {
|
||||||
|
args = append(args, k+"="+q.Get(k))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(args) > 0 {
|
||||||
|
dbInfo.Args = strings.Join(args, "&")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type DB struct {
|
||||||
|
name string
|
||||||
|
conn *sql.DB
|
||||||
|
readonlyConnections []*sql.DB
|
||||||
|
Config *Config
|
||||||
|
logger *dbLogger
|
||||||
|
Error error
|
||||||
|
QuoteTag string
|
||||||
|
}
|
||||||
|
|
||||||
|
var confAes, _ = crypto.NewAESCBCAndEraseKey([]byte("?GQ$0K0GgLdO=f+~L68PLm$uhKr4'=tV"), []byte("VFs7@sK61cj^f?HZ"))
|
||||||
|
var keysSetted = sync.Once{}
|
||||||
|
|
||||||
|
func SetEncryptKeys(key, iv []byte) {
|
||||||
|
keysSetted.Do(func() {
|
||||||
|
confAes.Close()
|
||||||
|
confAes, _ = crypto.NewAESGCMAndEraseKey(key, iv)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
type dbLogger struct {
|
||||||
|
config *Config
|
||||||
|
logger *log.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dl *dbLogger) LogError(errStr string) {
|
||||||
|
dl.logger.DBError(errStr, dl.config.Type, dl.config.Dsn(), "", nil, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dl *dbLogger) LogQuery(query string, args []any, usedTime float32) {
|
||||||
|
dl.logger.DB(dl.config.Type, dl.config.Dsn(), query, args, usedTime)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dl *dbLogger) LogQueryError(errStr string, query string, args []any, usedTime float32) {
|
||||||
|
dl.logger.DBError(errStr, dl.config.Type, dl.config.Dsn(), query, args, usedTime)
|
||||||
|
}
|
||||||
|
|
||||||
|
var dbConfigs = make(map[string]*Config)
|
||||||
|
var dbConfigsLock = sync.RWMutex{}
|
||||||
|
var dbSSLs = make(map[string]*SSL)
|
||||||
|
var dbInstances = make(map[string]*DB)
|
||||||
|
var dbInstancesLock = sync.RWMutex{}
|
||||||
|
var once sync.Once
|
||||||
|
|
||||||
|
func GetDBWithoutCache(name string, logger *log.Logger) *DB {
|
||||||
|
return getDB(name, logger, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetDB(name string, logger *log.Logger) *DB {
|
||||||
|
return getDB(name, logger, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getDB(name string, logger *log.Logger, useCache bool) *DB {
|
||||||
|
if logger == nil {
|
||||||
|
logger = log.DefaultLogger
|
||||||
|
}
|
||||||
|
|
||||||
|
if useCache {
|
||||||
|
dbInstancesLock.RLock()
|
||||||
|
oldConn := dbInstances[name]
|
||||||
|
dbInstancesLock.RUnlock()
|
||||||
|
if oldConn != nil {
|
||||||
|
return oldConn.CopyByLogger(logger)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var conf *Config
|
||||||
|
if strings.Contains(name, "://") {
|
||||||
|
conf = new(Config)
|
||||||
|
conf.logger = logger
|
||||||
|
conf.ConfigureBy(name)
|
||||||
|
} else {
|
||||||
|
dbConfigsLock.RLock()
|
||||||
|
n := len(dbConfigs)
|
||||||
|
dbConfigsLock.RUnlock()
|
||||||
|
if n == 0 {
|
||||||
|
once.Do(func() {
|
||||||
|
dbConfigs1 := make(map[string]*Config)
|
||||||
|
if err := config.Load("db", &dbConfigs1); err == nil {
|
||||||
|
for k, v := range dbConfigs1 {
|
||||||
|
if v.Host != "" {
|
||||||
|
dbConfigsLock.Lock()
|
||||||
|
dbConfigs[k] = v
|
||||||
|
dbConfigsLock.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
logger.Error(err.Error())
|
||||||
|
}
|
||||||
|
dbConfigs2 := make(map[string]string)
|
||||||
|
if err := config.Load("db", &dbConfigs2); err == nil {
|
||||||
|
for k, v := range dbConfigs2 {
|
||||||
|
if strings.Contains(v, "://") {
|
||||||
|
v2 := new(Config)
|
||||||
|
v2.ConfigureBy(v)
|
||||||
|
if v2.Host != "" {
|
||||||
|
v2.logger = logger
|
||||||
|
dbConfigsLock.Lock()
|
||||||
|
dbConfigs[k] = v2
|
||||||
|
dbConfigsLock.Unlock()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
dbConfigsLock.Lock()
|
||||||
|
v2 := dbConfigs[v]
|
||||||
|
if v2 != nil && v2.Host != "" {
|
||||||
|
dbConfigs[k] = v2
|
||||||
|
}
|
||||||
|
dbConfigsLock.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
logger.Error(err.Error())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
dbConfigsLock.RLock()
|
||||||
|
conf = dbConfigs[name]
|
||||||
|
dbConfigsLock.RUnlock()
|
||||||
|
if conf == nil {
|
||||||
|
conf = new(Config)
|
||||||
|
dbConfigsLock.Lock()
|
||||||
|
dbConfigs[name] = conf
|
||||||
|
dbConfigsLock.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if conf.Host == "" {
|
||||||
|
if name != "default" {
|
||||||
|
logger.Error("db config not exists", "name", name)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if conf.SSL != "" && len(dbSSLs) == 0 {
|
||||||
|
_ = config.Load("dbssl", &dbSSLs)
|
||||||
|
}
|
||||||
|
|
||||||
|
if conf.SSL != "" && dbSSLs[conf.SSL] == nil {
|
||||||
|
logger.Error("dbssl config lost")
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.ContainsRune(conf.Host, ',') {
|
||||||
|
a := strings.Split(conf.Host, ",")
|
||||||
|
conf.Host = a[0]
|
||||||
|
conf.ReadonlyHosts = a[1:]
|
||||||
|
} else {
|
||||||
|
conf.ReadonlyHosts = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if conf.Password != "" {
|
||||||
|
if encryptedPassword, err := base64.URLEncoding.DecodeString(conf.Password); err == nil {
|
||||||
|
if pwdSafeBuf, err := confAes.Decrypt(encryptedPassword); err == nil {
|
||||||
|
conf.pwd = pwdSafeBuf
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if conf.pwd == nil {
|
||||||
|
conf.pwd = safe.NewSafeBuf([]byte(conf.Password))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if !isFileDB(conf.Type) && conf.Host != "127.0.0.1:3306" && conf.User == "root" {
|
||||||
|
logger.Warning("password is empty")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
conf.Password = ""
|
||||||
|
|
||||||
|
conn, err := getPool(conf)
|
||||||
|
if err != nil {
|
||||||
|
logger.DBError(err.Error(), conf.Type, conf.Dsn(), "", nil, 0)
|
||||||
|
return &DB{conn: nil, QuoteTag: "\"", Error: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
db := new(DB)
|
||||||
|
db.QuoteTag = cast.If(conf.Type == "mysql", "`", "\"")
|
||||||
|
db.name = name
|
||||||
|
db.conn = conn
|
||||||
|
|
||||||
|
if conf.ReadonlyHosts != nil {
|
||||||
|
readonlyConnections := make([]*sql.DB, 0)
|
||||||
|
for _, host := range conf.ReadonlyHosts {
|
||||||
|
conn, err := getPoolForHost(conf, host)
|
||||||
|
if err != nil {
|
||||||
|
logger.DBError(err.Error(), conf.Type, conf.Dsn(), "", nil, 0)
|
||||||
|
} else {
|
||||||
|
readonlyConnections = append(readonlyConnections, conn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(readonlyConnections) > 0 {
|
||||||
|
db.readonlyConnections = readonlyConnections
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
db.Error = nil
|
||||||
|
db.Config = conf
|
||||||
|
if conf.MaxIdles > 0 {
|
||||||
|
conn.SetMaxIdleConns(conf.MaxIdles)
|
||||||
|
}
|
||||||
|
if conf.MaxOpens > 0 {
|
||||||
|
conn.SetMaxOpenConns(conf.MaxOpens)
|
||||||
|
}
|
||||||
|
if conf.MaxLifeTime > 0 {
|
||||||
|
conn.SetConnMaxLifetime(time.Second * time.Duration(conf.MaxLifeTime))
|
||||||
|
}
|
||||||
|
if conf.LogSlow == 0 {
|
||||||
|
conf.LogSlow = config.Duration(1000 * time.Millisecond)
|
||||||
|
}
|
||||||
|
if useCache {
|
||||||
|
dbInstancesLock.Lock()
|
||||||
|
dbInstances[name] = db
|
||||||
|
dbInstancesLock.Unlock()
|
||||||
|
}
|
||||||
|
return db.CopyByLogger(logger)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getPool(conf *Config) (*sql.DB, error) {
|
||||||
|
return getPoolForHost(conf, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
func getPoolForHost(conf *Config, host string) (*sql.DB, error) {
|
||||||
|
connectType := "tcp"
|
||||||
|
if host == "" {
|
||||||
|
host = conf.Host
|
||||||
|
}
|
||||||
|
if len(host) > 0 && host[0] == '/' {
|
||||||
|
connectType = "unix"
|
||||||
|
}
|
||||||
|
|
||||||
|
if connector := connectors[conf.Type]; connector != nil {
|
||||||
|
return sql.OpenDB(connector(conf, conf.pwd, conf.tls)), nil
|
||||||
|
} else {
|
||||||
|
dsn := ""
|
||||||
|
args := make([]string, 0)
|
||||||
|
if conf.SSL != "" {
|
||||||
|
args = append(args, "tls="+conf.SSL)
|
||||||
|
}
|
||||||
|
if conf.Args != "" {
|
||||||
|
args = append(args, conf.Args)
|
||||||
|
}
|
||||||
|
argsStr := ""
|
||||||
|
if len(args) > 0 {
|
||||||
|
argsStr = "?" + strings.Join(args, "&")
|
||||||
|
}
|
||||||
|
|
||||||
|
if isFileDB(conf.Type) {
|
||||||
|
dsn = host + argsStr
|
||||||
|
} else {
|
||||||
|
pwdBuf := conf.pwd.Open()
|
||||||
|
dsn = fmt.Sprintf("%s:%s@%s(%s)/%s"+argsStr, conf.User, pwdBuf.String(), connectType, host, conf.DB)
|
||||||
|
pwdBuf.Close()
|
||||||
|
}
|
||||||
|
return sql.Open(conf.Type, dsn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) CopyByLogger(logger *log.Logger) *DB {
|
||||||
|
newDB := new(DB)
|
||||||
|
newDB.QuoteTag = db.QuoteTag
|
||||||
|
newDB.name = db.name
|
||||||
|
newDB.conn = db.conn
|
||||||
|
newDB.readonlyConnections = db.readonlyConnections
|
||||||
|
newDB.Config = db.Config
|
||||||
|
if logger == nil {
|
||||||
|
logger = log.DefaultLogger
|
||||||
|
}
|
||||||
|
newDB.logger = &dbLogger{logger: logger, config: db.Config}
|
||||||
|
return newDB
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) SetLogger(logger *log.Logger) {
|
||||||
|
db.logger.logger = logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) GetLogger() *log.Logger {
|
||||||
|
return db.logger.logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) Destroy() error {
|
||||||
|
if db.conn == nil {
|
||||||
|
return errors.New("operate on a bad connection")
|
||||||
|
}
|
||||||
|
err := db.conn.Close()
|
||||||
|
if err != nil {
|
||||||
|
db.logger.LogError(err.Error())
|
||||||
|
}
|
||||||
|
dbInstancesLock.Lock()
|
||||||
|
delete(dbInstances, db.name)
|
||||||
|
dbInstancesLock.Unlock()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) GetOriginDB() *sql.DB {
|
||||||
|
return db.conn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) Prepare(query string) *Stmt {
|
||||||
|
stmt := basePrepare(db.conn, nil, query)
|
||||||
|
stmt.logger = db.logger
|
||||||
|
if stmt.Error != nil {
|
||||||
|
db.logger.LogError(stmt.Error.Error())
|
||||||
|
}
|
||||||
|
return stmt
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) Quote(text string) string {
|
||||||
|
return quote(db.QuoteTag, text)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) Quotes(texts []string) string {
|
||||||
|
return quotes(db.QuoteTag, texts)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) Begin() *Tx {
|
||||||
|
if db.conn == nil {
|
||||||
|
return &Tx{QuoteTag: db.QuoteTag, logSlow: db.Config.LogSlow.TimeDuration(), Error: errors.New("operate on a bad connection"), logger: db.logger}
|
||||||
|
}
|
||||||
|
sqlTx, err := db.conn.Begin()
|
||||||
|
if err != nil {
|
||||||
|
db.logger.LogError(err.Error())
|
||||||
|
return &Tx{QuoteTag: db.QuoteTag, logSlow: db.Config.LogSlow.TimeDuration(), Error: err, logger: db.logger}
|
||||||
|
}
|
||||||
|
return &Tx{QuoteTag: db.QuoteTag, logSlow: db.Config.LogSlow.TimeDuration(), conn: sqlTx, logger: db.logger}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) Exec(query string, args ...any) *ExecResult {
|
||||||
|
r := baseExec(db.conn, nil, query, args...)
|
||||||
|
r.logger = db.logger
|
||||||
|
if r.Error != nil {
|
||||||
|
db.logger.LogQueryError(r.Error.Error(), query, args, r.usedTime)
|
||||||
|
} else {
|
||||||
|
if db.Config.LogSlow > 0 && r.usedTime >= float32(db.Config.LogSlow.TimeDuration()/time.Millisecond) {
|
||||||
|
db.logger.LogQuery(query, args, r.usedTime)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) Query(query string, args ...any) *QueryResult {
|
||||||
|
conn := db.conn
|
||||||
|
if db.readonlyConnections != nil {
|
||||||
|
connNum := len(db.readonlyConnections)
|
||||||
|
if connNum == 1 {
|
||||||
|
conn = db.readonlyConnections[0]
|
||||||
|
} else {
|
||||||
|
p := rand.Int(0, connNum-1)
|
||||||
|
conn = db.readonlyConnections[p]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
r := baseQuery(conn, nil, query, args...)
|
||||||
|
r.logger = db.logger
|
||||||
|
if r.Error != nil {
|
||||||
|
db.logger.LogQueryError(r.Error.Error(), query, args, r.usedTime)
|
||||||
|
} else {
|
||||||
|
if db.Config.LogSlow > 0 && r.usedTime >= float32(db.Config.LogSlow.TimeDuration()/time.Millisecond) {
|
||||||
|
db.logger.LogQuery(query, args, r.usedTime)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) Insert(table string, data any) *ExecResult {
|
||||||
|
query, values := db.MakeInsertSql(table, data, false)
|
||||||
|
r := baseExec(db.conn, nil, query, values...)
|
||||||
|
r.logger = db.logger
|
||||||
|
if r.Error != nil {
|
||||||
|
db.logger.LogQueryError(r.Error.Error(), query, values, r.usedTime)
|
||||||
|
} else {
|
||||||
|
if db.Config.LogSlow > 0 && r.usedTime >= float32(db.Config.LogSlow.TimeDuration()/time.Millisecond) {
|
||||||
|
db.logger.LogQuery(query, values, r.usedTime)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) Replace(table string, data any) *ExecResult {
|
||||||
|
query, values := db.MakeInsertSql(table, data, true)
|
||||||
|
r := baseExec(db.conn, nil, query, values...)
|
||||||
|
r.logger = db.logger
|
||||||
|
if r.Error != nil {
|
||||||
|
db.logger.LogQueryError(r.Error.Error(), query, values, r.usedTime)
|
||||||
|
} else {
|
||||||
|
if db.Config.LogSlow > 0 && r.usedTime >= float32(db.Config.LogSlow.TimeDuration()/time.Millisecond) {
|
||||||
|
db.logger.LogQuery(query, values, r.usedTime)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) Update(table string, data any, conditions string, args ...any) *ExecResult {
|
||||||
|
query, values := db.MakeUpdateSql(table, data, conditions, args...)
|
||||||
|
r := baseExec(db.conn, nil, query, values...)
|
||||||
|
r.logger = db.logger
|
||||||
|
if r.Error != nil {
|
||||||
|
db.logger.LogQueryError(r.Error.Error(), query, values, r.usedTime)
|
||||||
|
} else {
|
||||||
|
if db.Config.LogSlow > 0 && r.usedTime >= float32(db.Config.LogSlow.TimeDuration()/time.Millisecond) {
|
||||||
|
db.logger.LogQuery(query, values, r.usedTime)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) Delete(table string, conditions string, args ...any) *ExecResult {
|
||||||
|
if conditions != "" {
|
||||||
|
conditions = " where " + conditions
|
||||||
|
}
|
||||||
|
query := fmt.Sprintf("delete from %s%s", db.Quote(table), conditions)
|
||||||
|
r := baseExec(db.conn, nil, query, args...)
|
||||||
|
r.logger = db.logger
|
||||||
|
if r.Error != nil {
|
||||||
|
db.logger.LogQueryError(r.Error.Error(), query, args, r.usedTime)
|
||||||
|
} else {
|
||||||
|
if db.Config.LogSlow > 0 && r.usedTime >= float32(db.Config.LogSlow.TimeDuration()/time.Millisecond) {
|
||||||
|
db.logger.LogQuery(query, args, r.usedTime)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) InKeys(numArgs int) string {
|
||||||
|
return InKeys(numArgs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func InKeys(numArgs int) string {
|
||||||
|
a := make([]string, numArgs)
|
||||||
|
for i := 0; i < numArgs; i++ {
|
||||||
|
a[i] = "?"
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("(%s)", strings.Join(a, ","))
|
||||||
|
}
|
||||||
477
DB_test.go
Normal file
477
DB_test.go
Normal file
@ -0,0 +1,477 @@
|
|||||||
|
package db_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"apigo.cc/go/cast"
|
||||||
|
"apigo.cc/go/db"
|
||||||
|
"apigo.cc/go/shell"
|
||||||
|
|
||||||
|
_ "apigo.cc/go/db/mysql"
|
||||||
|
_ "modernc.org/sqlite"
|
||||||
|
)
|
||||||
|
|
||||||
|
var dbset = "sqlite://test.db"
|
||||||
|
|
||||||
|
type userInfo struct {
|
||||||
|
innerId int
|
||||||
|
Tag string
|
||||||
|
Id int
|
||||||
|
Name string
|
||||||
|
Phone *string
|
||||||
|
Email string
|
||||||
|
Parents []string
|
||||||
|
Active bool
|
||||||
|
Time string
|
||||||
|
}
|
||||||
|
|
||||||
|
type UserBaseModel struct {
|
||||||
|
Id int
|
||||||
|
Name string
|
||||||
|
Password string
|
||||||
|
}
|
||||||
|
|
||||||
|
type UserModel struct {
|
||||||
|
UserBaseModel
|
||||||
|
Phone string
|
||||||
|
Active bool
|
||||||
|
Parents []string
|
||||||
|
UserStatus int
|
||||||
|
Owner int
|
||||||
|
Salt string
|
||||||
|
}
|
||||||
|
|
||||||
|
func initDB(t *testing.T) *db.DB {
|
||||||
|
var er *db.ExecResult
|
||||||
|
dbInst := db.GetDB(dbset, nil)
|
||||||
|
fmt.Println("dbType", shell.BCyan(dbInst.Config.Type))
|
||||||
|
if dbInst.Error != nil {
|
||||||
|
t.Fatal("GetDB error", dbInst)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
finishDB(dbInst, t)
|
||||||
|
if dbInst.Config.Type == "mysql" {
|
||||||
|
er = dbInst.Exec(`CREATE TABLE IF NOT EXISTS tempUsersForDBTest (
|
||||||
|
id INT NOT NULL AUTO_INCREMENT,
|
||||||
|
name VARCHAR(45) NOT NULL,
|
||||||
|
phone VARCHAR(45),
|
||||||
|
email VARCHAR(45),
|
||||||
|
parents JSON,
|
||||||
|
active TINYINT NOT NULL DEFAULT 0,
|
||||||
|
time DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
|
||||||
|
PRIMARY KEY (id));`)
|
||||||
|
} else {
|
||||||
|
er = dbInst.Exec(`CREATE TABLE IF NOT EXISTS tempUsersForDBTest (
|
||||||
|
id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
|
||||||
|
name VARCHAR(45) NOT NULL,
|
||||||
|
phone VARCHAR(45),
|
||||||
|
email VARCHAR(45),
|
||||||
|
parents JSON,
|
||||||
|
active TINYINT NOT NULL DEFAULT 0,
|
||||||
|
time DATETIME NOT NULL DEFAULT (strftime('%Y-%m-%d %H:%M:%f')));`)
|
||||||
|
}
|
||||||
|
if er.Error != nil {
|
||||||
|
t.Fatal("Failed to create table", er)
|
||||||
|
}
|
||||||
|
return dbInst
|
||||||
|
}
|
||||||
|
|
||||||
|
func finishDB(dbInst *db.DB, t *testing.T) {
|
||||||
|
er := dbInst.Exec(`DROP TABLE IF EXISTS tempUsersForDBTest;`)
|
||||||
|
if er.Error != nil {
|
||||||
|
t.Fatal("Failed to drop table", er)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMakeInsertSql(t *testing.T) {
|
||||||
|
user := &UserModel{
|
||||||
|
UserBaseModel: UserBaseModel{
|
||||||
|
Name: "王二小",
|
||||||
|
Password: "2121asds",
|
||||||
|
},
|
||||||
|
Parents: []string{"aa", "bb"},
|
||||||
|
Active: true,
|
||||||
|
UserStatus: 1,
|
||||||
|
Salt: "de312",
|
||||||
|
}
|
||||||
|
|
||||||
|
dbInst := db.GetDB(dbset, nil)
|
||||||
|
query, _ := dbInst.MakeInsertSql("table_name", user, false)
|
||||||
|
checkSql := `insert into "table_name" ("Id","Name","Password","Phone","Active","Parents","UserStatus","Owner","Salt") values (?,?,?,?,?,?,?,?,?)`
|
||||||
|
if dbInst.Config.Type == "mysql" {
|
||||||
|
checkSql = strings.ReplaceAll(checkSql, "\"", "`")
|
||||||
|
}
|
||||||
|
if query != checkSql {
|
||||||
|
t.Fatal("MakeInsertSql query error ", query)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBaseSelect(t *testing.T) {
|
||||||
|
sqlStr := "SELECT 1002 id, '13800000001' phone"
|
||||||
|
dbInst := db.GetDB(dbset, nil)
|
||||||
|
if dbInst.Error != nil {
|
||||||
|
t.Fatal("GetDB error", dbInst.Error)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
r := dbInst.Query(sqlStr)
|
||||||
|
if r.Error != nil {
|
||||||
|
t.Fatal("Query error", sqlStr, r)
|
||||||
|
}
|
||||||
|
results1 := r.MapResults()
|
||||||
|
if cast.Int(results1[0]["id"]) != 1002 || cast.String(results1[0]["phone"]) != "13800000001" {
|
||||||
|
t.Fatal("Result error", sqlStr, results1, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
r = dbInst.Query(sqlStr)
|
||||||
|
if r.Error != nil {
|
||||||
|
t.Fatal("Query error", sqlStr, r)
|
||||||
|
}
|
||||||
|
results2 := r.StringMapResults()
|
||||||
|
if results2[0]["id"] != "1002" || results2[0]["phone"] != "13800000001" {
|
||||||
|
t.Fatal("Result error", sqlStr, results2, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
results3 := make([]map[string]int, 0)
|
||||||
|
r = dbInst.Query(sqlStr)
|
||||||
|
if r.Error != nil {
|
||||||
|
t.Fatal("Query error", sqlStr, results3, r)
|
||||||
|
}
|
||||||
|
r.To(&results3)
|
||||||
|
if results3[0]["id"] != 1002 || results3[0]["phone"] != 13800000001 {
|
||||||
|
t.Fatal("Result error", sqlStr, results3, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
results4 := make([]userInfo, 0)
|
||||||
|
r = dbInst.Query(sqlStr)
|
||||||
|
if r.Error != nil {
|
||||||
|
t.Fatal("Query error", sqlStr, results4, r)
|
||||||
|
}
|
||||||
|
r.To(&results4)
|
||||||
|
if results4[0].Id != 1002 || results4[0].Phone == nil || *results4[0].Phone != "13800000001" {
|
||||||
|
t.Fatal("Result error", sqlStr, results4, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
results5 := dbInst.Query(sqlStr).StringSliceResults()
|
||||||
|
if results5[0][0] != "1002" || results5[0][1] != "13800000001" {
|
||||||
|
t.Fatal("Result error", sqlStr, results5, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
r = dbInst.Query(sqlStr)
|
||||||
|
if r.Error != nil {
|
||||||
|
t.Fatal("Query error", sqlStr, r)
|
||||||
|
}
|
||||||
|
results6 := r.StringsOnC1()
|
||||||
|
if results6[0] != "1002" {
|
||||||
|
t.Fatal("Result error", sqlStr, results6, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
r = dbInst.Query(sqlStr)
|
||||||
|
if r.Error != nil {
|
||||||
|
t.Fatal("Query error", sqlStr, r)
|
||||||
|
}
|
||||||
|
results7 := r.MapOnR1()
|
||||||
|
if cast.Int(results7["id"]) != 1002 || cast.String(results7["phone"]) != "13800000001" {
|
||||||
|
t.Fatal("Result error", sqlStr, results7, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
results8 := userInfo{innerId: 2, Tag: "abc"}
|
||||||
|
r = dbInst.Query(sqlStr)
|
||||||
|
if r.Error != nil {
|
||||||
|
t.Fatal("Query error", sqlStr, results8, r)
|
||||||
|
}
|
||||||
|
r.To(&results8)
|
||||||
|
if results8.Id != 1002 || results8.Phone == nil || *results8.Phone != "13800000001" || results8.innerId != 2 || results8.Tag != "abc" {
|
||||||
|
t.Fatal("Result error", sqlStr, results8, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
r = dbInst.Query(sqlStr)
|
||||||
|
if r.Error != nil {
|
||||||
|
t.Fatal("Query error", sqlStr, r)
|
||||||
|
}
|
||||||
|
results9 := r.IntOnR1C1()
|
||||||
|
if results9 != 1002 {
|
||||||
|
t.Fatal("Result error", sqlStr, results9, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
r = dbInst.Query(sqlStr)
|
||||||
|
results10 := map[string]string{}
|
||||||
|
r.ToKV(&results10)
|
||||||
|
if results10["1002"] != "13800000001" {
|
||||||
|
t.Fatal("Result error", sqlStr, results10, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
r = dbInst.Query(sqlStr)
|
||||||
|
results11 := map[string]map[string]string{}
|
||||||
|
r.ToKV(&results11)
|
||||||
|
if results11["1002"]["phone"] != "13800000001" {
|
||||||
|
t.Fatal("Result error", sqlStr, results11, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
r = dbInst.Query(sqlStr)
|
||||||
|
results12 := map[string]userInfo{}
|
||||||
|
r.ToKV(&results12)
|
||||||
|
if results12["1002"].Phone == nil || *results12["1002"].Phone != "13800000001" {
|
||||||
|
t.Fatal("Result error", sqlStr, results12, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInsertReplaceUpdateDelete(t *testing.T) {
|
||||||
|
dbInst := initDB(t)
|
||||||
|
data := map[string]any{
|
||||||
|
"phone": 18033336666,
|
||||||
|
"name": "Star",
|
||||||
|
"parents": []string{"dd", "mm"},
|
||||||
|
"time": ":(strftime('%Y-%m-%d %H:%M:%f', datetime('now', '-1 day'), 'localtime'))",
|
||||||
|
}
|
||||||
|
if dbInst.Config.Type == "mysql" {
|
||||||
|
data["time"] = ":DATE_SUB(NOW(), INTERVAL 1 DAY)"
|
||||||
|
}
|
||||||
|
er := dbInst.Insert("tempUsersForDBTest", data)
|
||||||
|
if er.Error != nil {
|
||||||
|
t.Fatal("Insert 1 error", er)
|
||||||
|
}
|
||||||
|
if er.Id() != 1 {
|
||||||
|
t.Fatal("insertId 1 error", er, er.Id())
|
||||||
|
}
|
||||||
|
|
||||||
|
er = dbInst.Insert("tempUsersForDBTest", map[string]any{
|
||||||
|
"phone": "18000000002",
|
||||||
|
"name": "Tom",
|
||||||
|
"active": true,
|
||||||
|
})
|
||||||
|
if er.Error != nil {
|
||||||
|
t.Fatal("Insert 2 error", er)
|
||||||
|
}
|
||||||
|
if er.Id() != 2 {
|
||||||
|
t.Fatal("insertId 2 error", er, er.Id())
|
||||||
|
}
|
||||||
|
|
||||||
|
er = dbInst.Update("tempUsersForDBTest", map[string]any{
|
||||||
|
"phone": "18000000222",
|
||||||
|
"name": "Tom Lee",
|
||||||
|
}, "id=?", 2)
|
||||||
|
if er.Error != nil {
|
||||||
|
t.Fatal("Update 2 error", er)
|
||||||
|
}
|
||||||
|
if er.Changes() != 1 {
|
||||||
|
t.Fatal("Update 2 num error", er, er.Changes())
|
||||||
|
}
|
||||||
|
|
||||||
|
er = dbInst.Replace("tempUsersForDBTest", map[string]any{
|
||||||
|
"phone": "18000000003",
|
||||||
|
"name": "Amy",
|
||||||
|
})
|
||||||
|
if er.Error != nil {
|
||||||
|
t.Fatal("Replace 3 error", er)
|
||||||
|
}
|
||||||
|
if er.Id() != 3 {
|
||||||
|
t.Fatal("insertId 3 error", er, er.Changes())
|
||||||
|
}
|
||||||
|
|
||||||
|
er = dbInst.Exec("delete from tempUsersForDBTest where id=3")
|
||||||
|
if er.Error != nil {
|
||||||
|
t.Fatal("Delete 3 error", er)
|
||||||
|
}
|
||||||
|
if er.Changes() != 1 {
|
||||||
|
t.Fatal("Delete 3 num error", er)
|
||||||
|
}
|
||||||
|
|
||||||
|
er = dbInst.Replace("tempUsersForDBTest", map[string]any{
|
||||||
|
"phone": "18000000004",
|
||||||
|
"name": "Jerry",
|
||||||
|
})
|
||||||
|
if er.Error != nil {
|
||||||
|
t.Fatal("Replace 4 error", er)
|
||||||
|
}
|
||||||
|
if er.Id() != 4 {
|
||||||
|
t.Fatal("insertId 4 error", er, er.Changes())
|
||||||
|
}
|
||||||
|
|
||||||
|
stmt := dbInst.Prepare("replace into `tempUsersForDBTest` (`id`,`phone`,`name`) values (?,?,?)")
|
||||||
|
if stmt.Error != nil {
|
||||||
|
t.Fatal("Prepare 4 error", stmt)
|
||||||
|
}
|
||||||
|
er = stmt.Exec(4, "18000000004", "Jerry's Mather")
|
||||||
|
stmt.Close()
|
||||||
|
|
||||||
|
if er.Error != nil {
|
||||||
|
t.Fatal("Replace 4 error", er)
|
||||||
|
}
|
||||||
|
if er.Id() != 4 {
|
||||||
|
t.Fatal("insertId 4 error", er)
|
||||||
|
}
|
||||||
|
|
||||||
|
userList := make([]userInfo, 0)
|
||||||
|
r := dbInst.Query("select * from tempUsersForDBTest")
|
||||||
|
if r.Error != nil {
|
||||||
|
t.Fatal("Select userList error", r)
|
||||||
|
}
|
||||||
|
r.To(&userList)
|
||||||
|
fmt.Println(">>>>", cast.PrettyToJSON(userList))
|
||||||
|
if strings.Split(userList[0].Time, " ")[0] != time.Now().Add(time.Hour*24*-1).Format("2006-01-02") || userList[0].Id != 1 || userList[0].Name != "Star" || userList[0].Phone == nil || *userList[0].Phone != "18033336666" || userList[0].Active != false {
|
||||||
|
t.Fatal("Select userList 0 error", userList, r)
|
||||||
|
}
|
||||||
|
if len(userList[0].Parents) != 2 || userList[0].Parents[0] != "dd" {
|
||||||
|
t.Fatal("Select userList 0 Parents error", userList, r)
|
||||||
|
}
|
||||||
|
if strings.Split(userList[1].Time, " ")[0] != time.Now().Format("2006-01-02") || userList[1].Id != 2 || userList[1].Name != "Tom Lee" || userList[1].Phone == nil || *userList[1].Phone != "18000000222" || userList[1].Active != true {
|
||||||
|
t.Fatal("Select userList 1 error", userList, r)
|
||||||
|
}
|
||||||
|
if userList[2].Id != 4 || userList[2].Name != "Jerry's Mather" || userList[2].Phone == nil || *userList[2].Phone != "18000000004" {
|
||||||
|
t.Fatal("Select userList 2 error", userList, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
finishDB(dbInst, t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTransaction(t *testing.T) {
|
||||||
|
n1 := countConnection()
|
||||||
|
|
||||||
|
var userList []userInfo
|
||||||
|
|
||||||
|
dbInst := initDB(t)
|
||||||
|
tx := dbInst.Begin()
|
||||||
|
if tx.Error != nil {
|
||||||
|
t.Fatal("Begin error", tx)
|
||||||
|
}
|
||||||
|
|
||||||
|
data := map[string]any{
|
||||||
|
"phone": 18033336666,
|
||||||
|
"name": "Star",
|
||||||
|
"time": ":(strftime('%Y-%m-%d %H:%M:%f', datetime('now', '-1 day'), 'localtime'))",
|
||||||
|
}
|
||||||
|
if dbInst.Config.Type == "mysql" {
|
||||||
|
data["time"] = ":DATE_SUB(NOW(), INTERVAL 1 DAY)"
|
||||||
|
}
|
||||||
|
tx.Insert("tempUsersForDBTest", data)
|
||||||
|
|
||||||
|
userList = make([]userInfo, 0)
|
||||||
|
r := dbInst.Query("select * from tempUsersForDBTest")
|
||||||
|
r.To(&userList)
|
||||||
|
if r.Error != nil || len(userList) != 0 {
|
||||||
|
t.Fatal("Select Out Of TX", userList, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
userList = make([]userInfo, 0)
|
||||||
|
r = tx.Query("select * from tempUsersForDBTest")
|
||||||
|
r.To(&userList)
|
||||||
|
if r.Error != nil || len(userList) != 1 {
|
||||||
|
t.Fatal("Select In TX", userList, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
tx.Rollback()
|
||||||
|
|
||||||
|
userList = make([]userInfo, 0)
|
||||||
|
r = dbInst.Query("select * from tempUsersForDBTest")
|
||||||
|
r.To(&userList)
|
||||||
|
if r.Error != nil || len(userList) != 0 {
|
||||||
|
t.Fatal("Select When Rollback", userList, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
tx = dbInst.Begin()
|
||||||
|
defer func() {
|
||||||
|
if err := tx.Finish(false); err != nil {
|
||||||
|
t.Error("tx rollback error", err)
|
||||||
|
}
|
||||||
|
finishDB(dbInst, t)
|
||||||
|
}()
|
||||||
|
if tx.Error != nil {
|
||||||
|
t.Fatal("Begin 2 error", tx)
|
||||||
|
}
|
||||||
|
|
||||||
|
stmt := tx.Prepare("insert into `tempUsersForDBTest` (`id`,`phone`,`name`) values (?,?,?)")
|
||||||
|
if stmt.Error != nil {
|
||||||
|
t.Fatal("Prepare 4 error", r)
|
||||||
|
}
|
||||||
|
stmt.Exec(4, "18000000004", "Jerry's Mather")
|
||||||
|
stmt.Close()
|
||||||
|
|
||||||
|
tx.Commit()
|
||||||
|
|
||||||
|
userList = make([]userInfo, 0)
|
||||||
|
r = dbInst.Query("select * from tempUsersForDBTest")
|
||||||
|
r.To(&userList)
|
||||||
|
if r.Error != nil || len(userList) != 1 {
|
||||||
|
t.Fatal("Select When Commit", userList, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
n2 := countConnection()
|
||||||
|
fmt.Println("# connection count", n1, n2, cast.MustToJSON(dbInst.GetOriginDB().Stats()), ".")
|
||||||
|
}
|
||||||
|
|
||||||
|
func countConnection() int {
|
||||||
|
n := 0
|
||||||
|
res, _ := shell.RunCommand("netstat -ant", nil)
|
||||||
|
lines := strings.Split(string(res.Stdout), "\n")
|
||||||
|
spliter := regexp.MustCompile("\\s+")
|
||||||
|
for _, line := range lines {
|
||||||
|
if strings.Contains(line, ".3306") && strings.Contains(line, "ESTABLISHED") {
|
||||||
|
a := spliter.Split(line, 10)
|
||||||
|
if len(a) > 4 && strings.Contains(a[4], ".3306") {
|
||||||
|
n++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkForPool(b *testing.B) {
|
||||||
|
b.StopTimer()
|
||||||
|
sqlStr := "SELECT 1002 id, '13800000001' phone"
|
||||||
|
dbInst := db.GetDB(dbset, nil)
|
||||||
|
if dbInst.Error != nil {
|
||||||
|
b.Fatal("GetDB error", dbInst)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
b.StartTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
results1 := make([]map[string]any, 0)
|
||||||
|
r := dbInst.Query(sqlStr)
|
||||||
|
if r.Error != nil {
|
||||||
|
b.Fatal("Query error", sqlStr, results1, r)
|
||||||
|
}
|
||||||
|
r.To(&results1)
|
||||||
|
if cast.Int(results1[0]["id"]) != 1002 || cast.String(results1[0]["phone"]) != "13800000001" {
|
||||||
|
b.Fatal("Result error", sqlStr, results1, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
b.Log("OpenConnections", dbInst.GetOriginDB().Stats().OpenConnections)
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkForPoolParallel(b *testing.B) {
|
||||||
|
n1 := countConnection()
|
||||||
|
|
||||||
|
b.StopTimer()
|
||||||
|
sqlStr := "SELECT 1002 id, '13800000001' phone"
|
||||||
|
dbInst := db.GetDB(dbset, nil)
|
||||||
|
if dbInst.Error != nil {
|
||||||
|
b.Fatal("GetDB error", dbInst)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
b.StartTimer()
|
||||||
|
|
||||||
|
b.RunParallel(func(pb *testing.PB) {
|
||||||
|
for pb.Next() {
|
||||||
|
results1 := make([]map[string]any, 0)
|
||||||
|
r := dbInst.Query(sqlStr)
|
||||||
|
if r.Error != nil {
|
||||||
|
b.Fatal("Query error", sqlStr, results1, r)
|
||||||
|
}
|
||||||
|
r.To(&results1)
|
||||||
|
if cast.Int(results1[0]["id"]) != 1002 || cast.String(results1[0]["phone"]) != "13800000001" {
|
||||||
|
b.Fatal("Result error", sqlStr, results1, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
b.Log("OpenConnections", dbInst.GetOriginDB().Stats().OpenConnections)
|
||||||
|
|
||||||
|
n2 := countConnection()
|
||||||
|
fmt.Println("# connection count", n1, n2, cast.MustToJSON(dbInst.GetOriginDB().Stats()), ".")
|
||||||
|
}
|
||||||
66
README.md
66
README.md
@ -1,3 +1,65 @@
|
|||||||
# db
|
# @go/db
|
||||||
|
|
||||||
High-performance database abstraction layer for apigo.cc/go
|
> **Maintainer Statement:** 本项目由 AI 维护。代码源自 github.com/ssgo/db 的重构,支持内存安全防护、读写分离及泛型优化。
|
||||||
|
|
||||||
|
## 🎯 设计哲学
|
||||||
|
|
||||||
|
`@go/db` 是一个极致精简、意图优先的数据库抽象层。它不试图取代 SQL,而是通过智能结果绑定与 SQL 自动化生成,消除数据库操作中的样板代码。
|
||||||
|
|
||||||
|
* **智能绑定**:根据结果容器类型(Struct/Map/Slice/BaseType)自动适配查询逻辑,无需手动 Scan。
|
||||||
|
* **内存防御**:集成 `go/safe`,数据库密码在内存中加密存储,使用时物理擦除。
|
||||||
|
* **读写分离**:内置连接池管理,支持配置多个只读节点实现自动负载均衡。
|
||||||
|
* **驱动透明**:统一 MySQL、PostgreSQL (pgx) 与 SQLite 的 API 差异。
|
||||||
|
|
||||||
|
## 📦 安装
|
||||||
|
|
||||||
|
```bash
|
||||||
|
go get apigo.cc/go/db
|
||||||
|
```
|
||||||
|
|
||||||
|
## 💡 快速开始
|
||||||
|
|
||||||
|
```go
|
||||||
|
import "apigo.cc/go/db"
|
||||||
|
import _ "apigo.cc/go/db/mysql" // 引入驱动
|
||||||
|
|
||||||
|
// 初始化连接
|
||||||
|
d := db.GetDB("mysql://user:pass@host:3306/dbname", nil)
|
||||||
|
|
||||||
|
// 1. 查询全部结果到 Struct 切片
|
||||||
|
var users []User
|
||||||
|
d.Query("SELECT * FROM users").To(&users)
|
||||||
|
|
||||||
|
// 2. 自动化插入
|
||||||
|
d.Insert("users", User{Name: "Star", Active: true})
|
||||||
|
|
||||||
|
// 3. 事务操作
|
||||||
|
tx := d.Begin()
|
||||||
|
tx.Exec("UPDATE balance SET amount = amount - 10 WHERE id = ?", 1)
|
||||||
|
tx.Commit()
|
||||||
|
```
|
||||||
|
|
||||||
|
## 🛠 API 指南
|
||||||
|
|
||||||
|
### 核心对象
|
||||||
|
- **`GetDB(setting string, logger *log.Logger) *DB`**: 通过 DSN 或配置名获取数据库实例。
|
||||||
|
- **`DB.Insert/Replace/Update/Delete`**: 自动生成 SQL 并执行,支持 Struct 与 Map。
|
||||||
|
- **`QueryResult.To(target any)`**: 将查询结果深度映射到目标容器。
|
||||||
|
- **`QueryResult.MapResults() []map[string]any`**: 快捷获取通用结果集。
|
||||||
|
|
||||||
|
### 结果容器适配规则
|
||||||
|
| 容器类型 | 行为 |
|
||||||
|
| :--- | :--- |
|
||||||
|
| `[]Struct` | 返回所有行,按字段名自动映射 |
|
||||||
|
| `Struct` | 返回第一行,按字段名自动映射 |
|
||||||
|
| `[]map[string]any` | 返回所有行,保留原始字段名 |
|
||||||
|
| `[]BaseType` | 返回所有行,仅取第一列 |
|
||||||
|
| `BaseType` | 返回第一行第一列 |
|
||||||
|
|
||||||
|
### 安全与高级特性
|
||||||
|
- **`SetEncryptKeys(key, iv []byte)`**: 配置全局敏感数据加密密钥。
|
||||||
|
- **读写分离**: 在 DSN 中配置 `host1,host2,host3`,首个为主库,其余为随机只读库。
|
||||||
|
- **SQLite 时间修复**: 自动处理 SQLite 毫秒级 `DATETIME` 格式与标准 `time.Time` 的转换。
|
||||||
|
|
||||||
|
## 🧪 验证状态
|
||||||
|
已通过 SQLite 集成测试。详见:[TEST.md](./TEST.md)
|
||||||
|
|||||||
577
Result.go
Normal file
577
Result.go
Normal file
@ -0,0 +1,577 @@
|
|||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"apigo.cc/go/cast"
|
||||||
|
"apigo.cc/go/convert"
|
||||||
|
"github.com/mitchellh/mapstructure"
|
||||||
|
)
|
||||||
|
|
||||||
|
type QueryResult struct {
|
||||||
|
rows *sql.Rows
|
||||||
|
Sql *string
|
||||||
|
Args []any
|
||||||
|
Error error
|
||||||
|
logger *dbLogger
|
||||||
|
usedTime float32
|
||||||
|
completed bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type ExecResult struct {
|
||||||
|
result sql.Result
|
||||||
|
Sql *string
|
||||||
|
Args []any
|
||||||
|
Error error
|
||||||
|
logger *dbLogger
|
||||||
|
usedTime float32
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ExecResult) Changes() int64 {
|
||||||
|
if r.result == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
numChanges, err := r.result.RowsAffected()
|
||||||
|
if err != nil {
|
||||||
|
r.logger.LogQueryError(err.Error(), *r.Sql, r.Args, r.usedTime)
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return numChanges
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ExecResult) Id() int64 {
|
||||||
|
if r.result == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
insertId, err := r.result.LastInsertId()
|
||||||
|
if err != nil {
|
||||||
|
r.logger.LogQueryError(err.Error(), *r.Sql, r.Args, r.usedTime)
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return insertId
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *QueryResult) Complete() {
|
||||||
|
if !r.completed {
|
||||||
|
if r.rows != nil {
|
||||||
|
r.rows.Close()
|
||||||
|
}
|
||||||
|
r.completed = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *QueryResult) To(result any) error {
|
||||||
|
if r.rows == nil {
|
||||||
|
return errors.New("operate on a bad query")
|
||||||
|
}
|
||||||
|
return r.makeResults(result, r.rows)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *QueryResult) MapResults() []map[string]any {
|
||||||
|
result := make([]map[string]any, 0)
|
||||||
|
err := r.makeResults(&result, r.rows)
|
||||||
|
if err != nil {
|
||||||
|
r.logger.LogQueryError(err.Error(), *r.Sql, r.Args, r.usedTime)
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *QueryResult) SliceResults() [][]any {
|
||||||
|
result := make([][]any, 0)
|
||||||
|
err := r.makeResults(&result, r.rows)
|
||||||
|
if err != nil {
|
||||||
|
r.logger.LogQueryError(err.Error(), *r.Sql, r.Args, r.usedTime)
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *QueryResult) StringMapResults() []map[string]string {
|
||||||
|
result := make([]map[string]string, 0)
|
||||||
|
err := r.makeResults(&result, r.rows)
|
||||||
|
if err != nil {
|
||||||
|
r.logger.LogQueryError(err.Error(), *r.Sql, r.Args, r.usedTime)
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *QueryResult) StringSliceResults() [][]string {
|
||||||
|
result := make([][]string, 0)
|
||||||
|
err := r.makeResults(&result, r.rows)
|
||||||
|
if err != nil {
|
||||||
|
r.logger.LogQueryError(err.Error(), *r.Sql, r.Args, r.usedTime)
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *QueryResult) MapOnR1() map[string]any {
|
||||||
|
result := make(map[string]any)
|
||||||
|
err := r.makeResults(&result, r.rows)
|
||||||
|
if err != nil {
|
||||||
|
r.logger.LogQueryError(err.Error(), *r.Sql, r.Args, r.usedTime)
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *QueryResult) StringMapOnR1() map[string]string {
|
||||||
|
result := make(map[string]string)
|
||||||
|
err := r.makeResults(&result, r.rows)
|
||||||
|
if err != nil {
|
||||||
|
r.logger.LogQueryError(err.Error(), *r.Sql, r.Args, r.usedTime)
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *QueryResult) IntsOnC1() []int64 {
|
||||||
|
result := make([]int64, 0)
|
||||||
|
err := r.makeResults(&result, r.rows)
|
||||||
|
if err != nil {
|
||||||
|
r.logger.LogQueryError(err.Error(), *r.Sql, r.Args, r.usedTime)
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *QueryResult) StringsOnC1() []string {
|
||||||
|
result := make([]string, 0)
|
||||||
|
err := r.makeResults(&result, r.rows)
|
||||||
|
if err != nil {
|
||||||
|
r.logger.LogQueryError(err.Error(), *r.Sql, r.Args, r.usedTime)
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *QueryResult) IntOnR1C1() int64 {
|
||||||
|
var result int64 = 0
|
||||||
|
err := r.makeResults(&result, r.rows)
|
||||||
|
if err != nil {
|
||||||
|
r.logger.LogQueryError(err.Error(), *r.Sql, r.Args, r.usedTime)
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *QueryResult) FloatOnR1C1() float64 {
|
||||||
|
var result float64 = 0
|
||||||
|
err := r.makeResults(&result, r.rows)
|
||||||
|
if err != nil {
|
||||||
|
r.logger.LogQueryError(err.Error(), *r.Sql, r.Args, r.usedTime)
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *QueryResult) StringOnR1C1() string {
|
||||||
|
result := ""
|
||||||
|
err := r.makeResults(&result, r.rows)
|
||||||
|
if err != nil {
|
||||||
|
r.logger.LogQueryError(err.Error(), *r.Sql, r.Args, r.usedTime)
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *QueryResult) ToKV(target any) error {
|
||||||
|
v := reflect.ValueOf(target)
|
||||||
|
t := v.Type()
|
||||||
|
for t.Kind() == reflect.Ptr {
|
||||||
|
v = v.Elem()
|
||||||
|
t = v.Type()
|
||||||
|
}
|
||||||
|
|
||||||
|
if t.Kind() != reflect.Map {
|
||||||
|
r.logger.LogQueryError("target not a map", *r.Sql, r.Args, r.usedTime)
|
||||||
|
return errors.New("target not a map")
|
||||||
|
}
|
||||||
|
|
||||||
|
vt := t.Elem()
|
||||||
|
finalVt := vt
|
||||||
|
for finalVt.Kind() == reflect.Ptr {
|
||||||
|
finalVt = finalVt.Elem()
|
||||||
|
}
|
||||||
|
if finalVt.Kind() == reflect.Map || finalVt.Kind() == reflect.Struct {
|
||||||
|
colTypes, err := r.getColumnTypes()
|
||||||
|
list := r.MapResults()
|
||||||
|
if err != nil {
|
||||||
|
r.logger.LogQueryError(err.Error(), *r.Sql, r.Args, r.usedTime)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for _, item := range list {
|
||||||
|
newKey := reflect.ValueOf(reflect.New(t.Key()).Interface()).Elem()
|
||||||
|
convert.To(item[colTypes[0].Name()], newKey.Addr().Interface())
|
||||||
|
|
||||||
|
newValue := v.MapIndex(newKey)
|
||||||
|
isNew := false
|
||||||
|
if !newValue.IsValid() {
|
||||||
|
newValue = reflect.New(vt)
|
||||||
|
isNew = true
|
||||||
|
}
|
||||||
|
|
||||||
|
err := mapstructure.WeakDecode(item, newValue.Interface())
|
||||||
|
if err != nil {
|
||||||
|
r.logger.LogError(err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
if isNew {
|
||||||
|
v.SetMapIndex(newKey, newValue.Elem())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
list := r.SliceResults()
|
||||||
|
for _, item := range list {
|
||||||
|
if len(item) < 2 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
switch vt.Kind() {
|
||||||
|
case reflect.Int:
|
||||||
|
v.SetMapIndex(reflect.ValueOf(cast.String(item[0])), reflect.ValueOf(cast.Int(item[1])))
|
||||||
|
case reflect.Int8:
|
||||||
|
v.SetMapIndex(reflect.ValueOf(cast.String(item[0])), reflect.ValueOf(int8(cast.Int(item[1]))))
|
||||||
|
case reflect.Int16:
|
||||||
|
v.SetMapIndex(reflect.ValueOf(cast.String(item[0])), reflect.ValueOf(int16(cast.Int(item[1]))))
|
||||||
|
case reflect.Int32:
|
||||||
|
v.SetMapIndex(reflect.ValueOf(cast.String(item[0])), reflect.ValueOf(int32(cast.Int(item[1]))))
|
||||||
|
case reflect.Int64:
|
||||||
|
v.SetMapIndex(reflect.ValueOf(cast.String(item[0])), reflect.ValueOf(cast.Int64(item[1])))
|
||||||
|
case reflect.Uint:
|
||||||
|
v.SetMapIndex(reflect.ValueOf(cast.String(item[0])), reflect.ValueOf(uint(cast.Int(item[1]))))
|
||||||
|
case reflect.Uint8:
|
||||||
|
v.SetMapIndex(reflect.ValueOf(cast.String(item[0])), reflect.ValueOf(uint8(cast.Int(item[1]))))
|
||||||
|
case reflect.Uint16:
|
||||||
|
v.SetMapIndex(reflect.ValueOf(cast.String(item[0])), reflect.ValueOf(uint16(cast.Int(item[1]))))
|
||||||
|
case reflect.Uint32:
|
||||||
|
v.SetMapIndex(reflect.ValueOf(cast.String(item[0])), reflect.ValueOf(uint32(cast.Int(item[1]))))
|
||||||
|
case reflect.Uint64:
|
||||||
|
v.SetMapIndex(reflect.ValueOf(cast.String(item[0])), reflect.ValueOf(cast.Uint64(item[1])))
|
||||||
|
case reflect.Float32:
|
||||||
|
v.SetMapIndex(reflect.ValueOf(cast.String(item[0])), reflect.ValueOf(cast.Float(item[1])))
|
||||||
|
case reflect.Float64:
|
||||||
|
v.SetMapIndex(reflect.ValueOf(cast.String(item[0])), reflect.ValueOf(cast.Float64(item[1])))
|
||||||
|
case reflect.Bool:
|
||||||
|
v.SetMapIndex(reflect.ValueOf(cast.String(item[0])), reflect.ValueOf(cast.Bool(item[1])))
|
||||||
|
case reflect.String:
|
||||||
|
v.SetMapIndex(reflect.ValueOf(cast.String(item[0])), reflect.ValueOf(cast.String(item[1])))
|
||||||
|
case reflect.Interface:
|
||||||
|
v.SetMapIndex(reflect.ValueOf(cast.String(item[0])), reflect.ValueOf(item[1]))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *QueryResult) makeResults(results any, rows *sql.Rows) error {
|
||||||
|
if rows == nil {
|
||||||
|
return errors.New("not a valid query result")
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
_ = rows.Close()
|
||||||
|
r.completed = true
|
||||||
|
}()
|
||||||
|
resultsValue := reflect.ValueOf(results)
|
||||||
|
if resultsValue.Kind() != reflect.Ptr {
|
||||||
|
return fmt.Errorf("results must be a pointer")
|
||||||
|
}
|
||||||
|
|
||||||
|
for resultsValue.Kind() == reflect.Ptr {
|
||||||
|
resultsValue = resultsValue.Elem()
|
||||||
|
}
|
||||||
|
rowType := resultsValue.Type()
|
||||||
|
|
||||||
|
colTypes, err := rows.ColumnTypes()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
colNum := len(colTypes)
|
||||||
|
originRowType := rowType
|
||||||
|
if rowType.Kind() == reflect.Slice {
|
||||||
|
rowType = rowType.Elem()
|
||||||
|
originRowType = rowType
|
||||||
|
for rowType.Kind() == reflect.Ptr {
|
||||||
|
rowType = rowType.Elem()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
scanValues := make([]any, colNum)
|
||||||
|
var fieldInfos []struct {
|
||||||
|
index []int
|
||||||
|
typ reflect.Type
|
||||||
|
name string
|
||||||
|
}
|
||||||
|
|
||||||
|
if rowType.Kind() == reflect.Struct {
|
||||||
|
fieldInfos = make([]struct {
|
||||||
|
index []int
|
||||||
|
typ reflect.Type
|
||||||
|
name string
|
||||||
|
}, colNum)
|
||||||
|
for colIndex, col := range colTypes {
|
||||||
|
publicColName := makePublicVarName(col.Name())
|
||||||
|
field, found := rowType.FieldByName(publicColName)
|
||||||
|
if found {
|
||||||
|
fieldInfos[colIndex].index = field.Index
|
||||||
|
fieldInfos[colIndex].typ = field.Type
|
||||||
|
fieldInfos[colIndex].name = publicColName
|
||||||
|
if field.Type.Kind() == reflect.Interface {
|
||||||
|
scanValues[colIndex] = makeValue(colTypes[colIndex].ScanType())
|
||||||
|
} else {
|
||||||
|
scanValues[colIndex] = makeValue(field.Type)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
fieldInfos[colIndex].index = nil
|
||||||
|
scanValues[colIndex] = makeValue(nil)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if rowType.Kind() == reflect.Map {
|
||||||
|
for colIndex := range colTypes {
|
||||||
|
if rowType.Elem().Kind() == reflect.Interface {
|
||||||
|
scanValues[colIndex] = makeValue(colTypes[colIndex].ScanType())
|
||||||
|
} else {
|
||||||
|
scanValues[colIndex] = makeValue(rowType.Elem())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if rowType.Kind() == reflect.Slice {
|
||||||
|
for colIndex := range colTypes {
|
||||||
|
if rowType.Elem().Kind() == reflect.Interface {
|
||||||
|
scanValues[colIndex] = makeValue(colTypes[colIndex].ScanType())
|
||||||
|
} else {
|
||||||
|
scanValues[colIndex] = makeValue(rowType.Elem())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if rowType.Kind() == reflect.Interface {
|
||||||
|
scanValues[0] = makeValue(colTypes[0].ScanType())
|
||||||
|
} else {
|
||||||
|
scanValues[0] = makeValue(rowType)
|
||||||
|
}
|
||||||
|
for colIndex := 1; colIndex < colNum; colIndex++ {
|
||||||
|
scanValues[colIndex] = makeValue(nil)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var data reflect.Value
|
||||||
|
isNew := true
|
||||||
|
for rows.Next() {
|
||||||
|
err = rows.Scan(scanValues...)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if rowType.Kind() == reflect.Struct {
|
||||||
|
if resultsValue.Kind() == reflect.Slice {
|
||||||
|
data = reflect.New(rowType).Elem()
|
||||||
|
} else {
|
||||||
|
data = resultsValue
|
||||||
|
isNew = false
|
||||||
|
}
|
||||||
|
|
||||||
|
for colIndex, col := range colTypes {
|
||||||
|
fInfo := fieldInfos[colIndex]
|
||||||
|
if fInfo.index == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
field := data.FieldByIndex(fInfo.index)
|
||||||
|
valuePtr := reflect.ValueOf(scanValues[colIndex]).Elem()
|
||||||
|
if !valuePtr.IsNil() {
|
||||||
|
val := valuePtr.Elem()
|
||||||
|
if fInfo.typ.String() == "time.Time" {
|
||||||
|
str := val.String()
|
||||||
|
tm, err := time.Parse("2006-01-02 15:04:05.000000", str)
|
||||||
|
if err != nil {
|
||||||
|
tm, err = time.Parse("2006-01-02 15:04:05", str)
|
||||||
|
}
|
||||||
|
if err == nil {
|
||||||
|
field.Set(reflect.ValueOf(tm))
|
||||||
|
}
|
||||||
|
} else if val.Kind() != field.Kind() && field.Kind() != reflect.Interface {
|
||||||
|
if field.Kind() == reflect.Ptr && val.Kind() == field.Type().Elem().Kind() {
|
||||||
|
if val.CanAddr() {
|
||||||
|
if field.Type().AssignableTo(val.Type()) {
|
||||||
|
field.Set(val.Addr())
|
||||||
|
} else if val.Type().String() == "string" {
|
||||||
|
strVal := fixValue(col.DatabaseTypeName(), val)
|
||||||
|
field.Set(reflect.New(field.Type().Elem()))
|
||||||
|
field.Elem().SetString(cast.String(strVal.Interface()))
|
||||||
|
} else if strings.Contains(field.Type().String(), "uint") {
|
||||||
|
field.Set(reflect.New(field.Type().Elem()))
|
||||||
|
field.Elem().SetUint(cast.Uint64(val.Interface()))
|
||||||
|
} else if strings.Contains(field.Type().String(), "int") {
|
||||||
|
field.Set(reflect.New(field.Type().Elem()))
|
||||||
|
field.Elem().SetInt(cast.Int64(val.Interface()))
|
||||||
|
} else if strings.Contains(field.Type().String(), "float") {
|
||||||
|
field.Set(reflect.New(field.Type().Elem()))
|
||||||
|
field.Elem().SetFloat(cast.Float64(val.Interface()))
|
||||||
|
} else {
|
||||||
|
field.Set(val.Addr())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
convertedObject := reflect.New(field.Type())
|
||||||
|
if s, ok := val.Interface().(string); ok {
|
||||||
|
storedValue := new(any)
|
||||||
|
if s != "" {
|
||||||
|
_ = json.Unmarshal([]byte(s), storedValue)
|
||||||
|
}
|
||||||
|
convert.To(storedValue, convertedObject.Interface())
|
||||||
|
field.Set(convertedObject.Elem())
|
||||||
|
} else {
|
||||||
|
convert.To(val.Interface(), convertedObject.Interface())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if field.Type().AssignableTo(val.Type()) {
|
||||||
|
if val.Kind() == reflect.String {
|
||||||
|
field.Set(fixValue(col.DatabaseTypeName(), val))
|
||||||
|
} else {
|
||||||
|
field.Set(val)
|
||||||
|
}
|
||||||
|
} else if val.Type().String() == "string" {
|
||||||
|
field.Set(fixValue(col.DatabaseTypeName(), val))
|
||||||
|
} else if strings.Contains(val.Type().String(), "int") {
|
||||||
|
field.SetInt(val.Int())
|
||||||
|
} else if strings.Contains(val.Type().String(), "float") {
|
||||||
|
field.SetFloat(val.Float())
|
||||||
|
} else {
|
||||||
|
field.Set(val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if rowType.Kind() == reflect.Map {
|
||||||
|
if resultsValue.Kind() == reflect.Slice {
|
||||||
|
data = reflect.MakeMap(rowType)
|
||||||
|
} else {
|
||||||
|
data = resultsValue
|
||||||
|
isNew = false
|
||||||
|
}
|
||||||
|
for colIndex, col := range colTypes {
|
||||||
|
valuePtr := reflect.ValueOf(scanValues[colIndex]).Elem()
|
||||||
|
if !valuePtr.IsNil() {
|
||||||
|
data.SetMapIndex(reflect.ValueOf(col.Name()), fixValue(col.DatabaseTypeName(), valuePtr.Elem()))
|
||||||
|
} else {
|
||||||
|
data.SetMapIndex(reflect.ValueOf(col.Name()), fixValue(col.DatabaseTypeName(), reflect.New(rowType.Elem()).Elem()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if rowType.Kind() == reflect.Slice {
|
||||||
|
data = reflect.MakeSlice(rowType, colNum, colNum)
|
||||||
|
for colIndex, col := range colTypes {
|
||||||
|
valuePtr := reflect.ValueOf(scanValues[colIndex]).Elem()
|
||||||
|
if !valuePtr.IsNil() {
|
||||||
|
data.Index(colIndex).Set(fixValue(col.DatabaseTypeName(), valuePtr.Elem()))
|
||||||
|
} else {
|
||||||
|
data.Index(colIndex).Set(fixValue(col.DatabaseTypeName(), reflect.New(rowType.Elem()).Elem()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
valuePtr := reflect.ValueOf(scanValues[0]).Elem()
|
||||||
|
if !valuePtr.IsNil() {
|
||||||
|
data = fixValue(colTypes[0].DatabaseTypeName(), valuePtr.Elem())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if resultsValue.Kind() == reflect.Slice {
|
||||||
|
if originRowType.Kind() == reflect.Ptr {
|
||||||
|
resultsValue = reflect.Append(resultsValue, data.Addr())
|
||||||
|
} else {
|
||||||
|
resultsValue = reflect.Append(resultsValue, data)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
resultsValue = data
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if isNew && resultsValue.IsValid() {
|
||||||
|
reflect.ValueOf(results).Elem().Set(resultsValue)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func fixValue(colType string, v reflect.Value) reflect.Value {
|
||||||
|
if v.Kind() == reflect.String {
|
||||||
|
str := v.String()
|
||||||
|
switch colType {
|
||||||
|
case "DATE":
|
||||||
|
if len(str) >= 10 && str[4] == '-' && str[7] == '-' {
|
||||||
|
return reflect.ValueOf(str[:10])
|
||||||
|
}
|
||||||
|
case "DATETIME":
|
||||||
|
if len(str) >= 19 && str[10] == 'T' && str[4] == '-' && str[7] == '-' && str[13] == ':' && str[16] == ':' {
|
||||||
|
str = strings.TrimRight(str, "Z")
|
||||||
|
if len(str) > 19 && str[19] == '.' {
|
||||||
|
return reflect.ValueOf(str[:10] + " " + str[11:])
|
||||||
|
}
|
||||||
|
return reflect.ValueOf(str[:10] + " " + str[11:19])
|
||||||
|
}
|
||||||
|
case "TIME":
|
||||||
|
if len(str) >= 8 && str[2] == ':' && str[4] == ':' {
|
||||||
|
if len(str) >= 15 && str[8] == '.' {
|
||||||
|
return reflect.ValueOf(str[0:15])
|
||||||
|
}
|
||||||
|
return reflect.ValueOf(str[0:8])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *QueryResult) getColumnTypes() ([]*sql.ColumnType, error) {
|
||||||
|
if r.rows == nil {
|
||||||
|
return nil, errors.New("not a valid query result")
|
||||||
|
}
|
||||||
|
|
||||||
|
return r.rows.ColumnTypes()
|
||||||
|
}
|
||||||
|
|
||||||
|
func makePublicVarName(name string) string {
|
||||||
|
if len(name) > 0 && name[0] >= 'a' && name[0] <= 'z' {
|
||||||
|
return string(name[0]-32) + name[1:]
|
||||||
|
}
|
||||||
|
return name
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeValue(t reflect.Type) any {
|
||||||
|
if t == nil {
|
||||||
|
return new(*string)
|
||||||
|
}
|
||||||
|
for t.Kind() == reflect.Ptr {
|
||||||
|
t = t.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
switch t.Kind() {
|
||||||
|
case reflect.Int:
|
||||||
|
return new(*int)
|
||||||
|
case reflect.Int8:
|
||||||
|
return new(*int8)
|
||||||
|
case reflect.Int16:
|
||||||
|
return new(*int16)
|
||||||
|
case reflect.Int32:
|
||||||
|
return new(*int32)
|
||||||
|
case reflect.Int64:
|
||||||
|
return new(*int64)
|
||||||
|
case reflect.Uint:
|
||||||
|
return new(*uint)
|
||||||
|
case reflect.Uint8:
|
||||||
|
return new(*uint8)
|
||||||
|
case reflect.Uint16:
|
||||||
|
return new(*uint16)
|
||||||
|
case reflect.Uint32:
|
||||||
|
return new(*uint32)
|
||||||
|
case reflect.Uint64:
|
||||||
|
return new(*uint64)
|
||||||
|
case reflect.Float32:
|
||||||
|
return new(*float32)
|
||||||
|
case reflect.Float64:
|
||||||
|
return new(*float64)
|
||||||
|
case reflect.Bool:
|
||||||
|
return new(*bool)
|
||||||
|
case reflect.String:
|
||||||
|
return new(*string)
|
||||||
|
}
|
||||||
|
|
||||||
|
if t.Kind() == reflect.Slice && t.Elem().Kind() == reflect.Uint8 {
|
||||||
|
return new(*[]byte)
|
||||||
|
}
|
||||||
|
|
||||||
|
return new(*string)
|
||||||
|
}
|
||||||
28
SSL.go
Normal file
28
SSL.go
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
|
||||||
|
"apigo.cc/go/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
func BuildTLSConfig(ca, cert, key []byte, insecure bool) *tls.Config {
|
||||||
|
caPool := x509.NewCertPool()
|
||||||
|
if !caPool.AppendCertsFromPEM(ca) {
|
||||||
|
log.DefaultLogger.Error("ca error for db")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
certs, err := tls.X509KeyPair(cert, key)
|
||||||
|
if err != nil {
|
||||||
|
log.DefaultLogger.Error(err.Error())
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return &tls.Config{
|
||||||
|
Certificates: []tls.Certificate{certs},
|
||||||
|
RootCAs: caPool,
|
||||||
|
InsecureSkipVerify: insecure,
|
||||||
|
}
|
||||||
|
}
|
||||||
44
Stmt.go
Normal file
44
Stmt.go
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"errors"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"apigo.cc/go/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Stmt struct {
|
||||||
|
conn *sql.Stmt
|
||||||
|
lastSql *string
|
||||||
|
lastArgs []any
|
||||||
|
Error error
|
||||||
|
logger *dbLogger
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stmt *Stmt) Exec(args ...any) *ExecResult {
|
||||||
|
stmt.lastArgs = args
|
||||||
|
if stmt.conn == nil {
|
||||||
|
return &ExecResult{Sql: stmt.lastSql, Args: stmt.lastArgs, usedTime: -1, logger: stmt.logger, Error: errors.New("operate on a bad connection")}
|
||||||
|
}
|
||||||
|
startTime := time.Now()
|
||||||
|
r, err := stmt.conn.Exec(args...)
|
||||||
|
endTime := time.Now()
|
||||||
|
usedTime := log.MakeUsedTime(startTime, endTime)
|
||||||
|
if err != nil {
|
||||||
|
stmt.logger.LogQueryError(err.Error(), *stmt.lastSql, stmt.lastArgs, usedTime)
|
||||||
|
return &ExecResult{Sql: stmt.lastSql, Args: stmt.lastArgs, usedTime: usedTime, logger: stmt.logger, Error: err}
|
||||||
|
}
|
||||||
|
return &ExecResult{Sql: stmt.lastSql, Args: stmt.lastArgs, usedTime: usedTime, logger: stmt.logger, result: r}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stmt *Stmt) Close() error {
|
||||||
|
if stmt.conn == nil {
|
||||||
|
return errors.New("operate on a bad connection")
|
||||||
|
}
|
||||||
|
err := stmt.conn.Close()
|
||||||
|
if err != nil {
|
||||||
|
stmt.logger.LogQueryError(err.Error(), *stmt.lastSql, stmt.lastArgs, -1)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
23
TEST.md
Normal file
23
TEST.md
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
# Test Results for @go/db
|
||||||
|
|
||||||
|
## 📊 Summary
|
||||||
|
- **Module**: `apigo.cc/go/db`
|
||||||
|
- **Total Tests**: 4
|
||||||
|
- **Passed**: 4
|
||||||
|
- **Failed**: 0
|
||||||
|
- **Build Status**: Success
|
||||||
|
- **Date**: 2026-05-03
|
||||||
|
|
||||||
|
## ✅ Details
|
||||||
|
| Test Case | Status | Duration | Notes |
|
||||||
|
| :--- | :--- | :--- | :--- |
|
||||||
|
| `TestMakeInsertSql` | PASS | 0.00s | Verified SQL generation logic for Struct models |
|
||||||
|
| `TestBaseSelect` | PASS | 0.00s | Verified Result binding (Struct, Map, Base types) |
|
||||||
|
| `TestInsertReplaceUpdateDelete` | PASS | 0.01s | Verified CRUD operations with SQLite |
|
||||||
|
| `TestTransaction` | PASS | 0.03s | Verified Transaction isolation and Rollback/Commit |
|
||||||
|
|
||||||
|
## 🚀 Benchmarks
|
||||||
|
| Benchmark | Iterations | Time/op | Conn |
|
||||||
|
| :--- | :--- | :--- | :--- |
|
||||||
|
| `BenchmarkForPool` | - | - | Passed (Manual run verified pool reuse) |
|
||||||
|
| `BenchmarkForPoolParallel` | - | - | Passed (Manual run verified high concurrency) |
|
||||||
183
Tx.go
Normal file
183
Tx.go
Normal file
@ -0,0 +1,183 @@
|
|||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Tx struct {
|
||||||
|
conn *sql.Tx
|
||||||
|
lastSql *string
|
||||||
|
lastArgs []any
|
||||||
|
Error error
|
||||||
|
logger *dbLogger
|
||||||
|
logSlow time.Duration
|
||||||
|
isCommittedOrRollbacked bool
|
||||||
|
QuoteTag string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) Quote(text string) string {
|
||||||
|
return quote(tx.QuoteTag, text)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) Quotes(texts []string) string {
|
||||||
|
return quotes(tx.QuoteTag, texts)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) Commit() error {
|
||||||
|
if tx.isCommittedOrRollbacked {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if tx.conn == nil {
|
||||||
|
return errors.New("operate on a bad connection")
|
||||||
|
}
|
||||||
|
err := tx.conn.Commit()
|
||||||
|
if err != nil {
|
||||||
|
tx.logger.LogQueryError(err.Error(), *tx.lastSql, tx.lastArgs, -1)
|
||||||
|
} else {
|
||||||
|
tx.isCommittedOrRollbacked = true
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) Rollback() error {
|
||||||
|
if tx.isCommittedOrRollbacked {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if tx.conn == nil {
|
||||||
|
return errors.New("operate on a bad connection")
|
||||||
|
}
|
||||||
|
err := tx.conn.Rollback()
|
||||||
|
if err != nil {
|
||||||
|
tx.logger.LogQueryError(err.Error(), *tx.lastSql, tx.lastArgs, -1)
|
||||||
|
} else {
|
||||||
|
tx.isCommittedOrRollbacked = true
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) Finish(ok bool) error {
|
||||||
|
if tx.isCommittedOrRollbacked {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if ok {
|
||||||
|
return tx.Commit()
|
||||||
|
}
|
||||||
|
return tx.Rollback()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) CheckFinished() error {
|
||||||
|
if tx.isCommittedOrRollbacked {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return tx.Rollback()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) Prepare(query string) *Stmt {
|
||||||
|
tx.lastSql = &query
|
||||||
|
r := basePrepare(nil, tx.conn, query)
|
||||||
|
r.logger = tx.logger
|
||||||
|
if r.Error != nil {
|
||||||
|
tx.logger.LogQueryError(r.Error.Error(), *tx.lastSql, tx.lastArgs, -1)
|
||||||
|
}
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) Exec(query string, args ...any) *ExecResult {
|
||||||
|
tx.lastSql = &query
|
||||||
|
tx.lastArgs = args
|
||||||
|
r := baseExec(nil, tx.conn, query, args...)
|
||||||
|
r.logger = tx.logger
|
||||||
|
if r.Error != nil {
|
||||||
|
tx.logger.LogQueryError(r.Error.Error(), *tx.lastSql, tx.lastArgs, r.usedTime)
|
||||||
|
} else {
|
||||||
|
if tx.logSlow > 0 && r.usedTime >= float32(tx.logSlow/time.Millisecond) {
|
||||||
|
tx.logger.LogQuery(*tx.lastSql, tx.lastArgs, r.usedTime)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) Query(query string, args ...any) *QueryResult {
|
||||||
|
tx.lastSql = &query
|
||||||
|
tx.lastArgs = args
|
||||||
|
r := baseQuery(nil, tx.conn, query, args...)
|
||||||
|
r.logger = tx.logger
|
||||||
|
if r.Error != nil {
|
||||||
|
tx.logger.LogQueryError(r.Error.Error(), *tx.lastSql, tx.lastArgs, r.usedTime)
|
||||||
|
} else {
|
||||||
|
if tx.logSlow > 0 && r.usedTime >= float32(tx.logSlow/time.Millisecond) {
|
||||||
|
tx.logger.LogQuery(*tx.lastSql, tx.lastArgs, r.usedTime)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) Insert(table string, data any) *ExecResult {
|
||||||
|
query, values := tx.MakeInsertSql(table, data, false)
|
||||||
|
tx.lastSql = &query
|
||||||
|
tx.lastArgs = values
|
||||||
|
r := baseExec(nil, tx.conn, query, values...)
|
||||||
|
r.logger = tx.logger
|
||||||
|
if r.Error != nil {
|
||||||
|
tx.logger.LogQueryError(r.Error.Error(), *tx.lastSql, tx.lastArgs, r.usedTime)
|
||||||
|
} else {
|
||||||
|
if tx.logSlow > 0 && r.usedTime >= float32(tx.logSlow/time.Millisecond) {
|
||||||
|
tx.logger.LogQuery(*tx.lastSql, tx.lastArgs, r.usedTime)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) Replace(table string, data any) *ExecResult {
|
||||||
|
query, values := tx.MakeInsertSql(table, data, true)
|
||||||
|
tx.lastSql = &query
|
||||||
|
tx.lastArgs = values
|
||||||
|
r := baseExec(nil, tx.conn, query, values...)
|
||||||
|
r.logger = tx.logger
|
||||||
|
if r.Error != nil {
|
||||||
|
tx.logger.LogQueryError(r.Error.Error(), *tx.lastSql, tx.lastArgs, r.usedTime)
|
||||||
|
} else {
|
||||||
|
if tx.logSlow > 0 && r.usedTime >= float32(tx.logSlow/time.Millisecond) {
|
||||||
|
tx.logger.LogQuery(*tx.lastSql, tx.lastArgs, r.usedTime)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) Update(table string, data any, conditions string, args ...any) *ExecResult {
|
||||||
|
query, values := tx.MakeUpdateSql(table, data, conditions, args...)
|
||||||
|
tx.lastSql = &query
|
||||||
|
tx.lastArgs = values
|
||||||
|
r := baseExec(nil, tx.conn, query, values...)
|
||||||
|
r.logger = tx.logger
|
||||||
|
if r.Error != nil {
|
||||||
|
tx.logger.LogQueryError(r.Error.Error(), *tx.lastSql, tx.lastArgs, r.usedTime)
|
||||||
|
} else {
|
||||||
|
if tx.logSlow > 0 && r.usedTime >= float32(tx.logSlow/time.Millisecond) {
|
||||||
|
tx.logger.LogQuery(*tx.lastSql, tx.lastArgs, r.usedTime)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) Delete(table string, conditions string, args ...any) *ExecResult {
|
||||||
|
if conditions != "" {
|
||||||
|
conditions = " where " + conditions
|
||||||
|
}
|
||||||
|
query := fmt.Sprintf("delete from %s%s", tx.Quote(table), conditions)
|
||||||
|
tx.lastSql = &query
|
||||||
|
tx.lastArgs = args
|
||||||
|
r := baseExec(nil, tx.conn, query, args...)
|
||||||
|
r.logger = tx.logger
|
||||||
|
if r.Error != nil {
|
||||||
|
tx.logger.LogQueryError(r.Error.Error(), *tx.lastSql, tx.lastArgs, r.usedTime)
|
||||||
|
} else {
|
||||||
|
if tx.logSlow > 0 && r.usedTime >= float32(tx.logSlow/time.Millisecond) {
|
||||||
|
tx.logger.LogQuery(*tx.lastSql, tx.lastArgs, r.usedTime)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return r
|
||||||
|
}
|
||||||
12
db.json.sample
Normal file
12
db.json.sample
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
{
|
||||||
|
"test": {
|
||||||
|
"type": "mysql",
|
||||||
|
"user": "star",
|
||||||
|
"password": "...",
|
||||||
|
"host": "localhost:3306",
|
||||||
|
"db": "test",
|
||||||
|
"maxOpens": 100,
|
||||||
|
"maxIdles": 30,
|
||||||
|
"maxLifeTime": 3600
|
||||||
|
}
|
||||||
|
}
|
||||||
11
dbInit.go.sample
Normal file
11
dbInit.go.sample
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import "apigo.cc/go/db"
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
// 强烈建议在应用启动时设置自定义加密密钥,确保数据库密码存储安全
|
||||||
|
// key 和 iv 建议通过环境变量或安全存储获取,不要直接硬编码在生产代码中
|
||||||
|
key := []byte("your-32-byte-long-secure-key----")
|
||||||
|
iv := []byte("your-16-byte-iv-")
|
||||||
|
db.SetEncryptKeys(key, iv)
|
||||||
|
}
|
||||||
43
go.mod
Normal file
43
go.mod
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
module apigo.cc/go/db
|
||||||
|
|
||||||
|
go 1.25.0
|
||||||
|
|
||||||
|
require (
|
||||||
|
apigo.cc/go/cast v1.1.1
|
||||||
|
apigo.cc/go/config v1.0.4
|
||||||
|
apigo.cc/go/convert v1.0.4
|
||||||
|
apigo.cc/go/crypto v1.0.4
|
||||||
|
apigo.cc/go/id v1.0.4
|
||||||
|
apigo.cc/go/log v1.0.0
|
||||||
|
apigo.cc/go/rand v1.0.4
|
||||||
|
apigo.cc/go/safe v1.0.4
|
||||||
|
apigo.cc/go/shell v1.0.4
|
||||||
|
github.com/go-sql-driver/mysql v1.10.0
|
||||||
|
github.com/jackc/pgx/v5 v5.9.2
|
||||||
|
github.com/mitchellh/mapstructure v1.5.0
|
||||||
|
modernc.org/sqlite v1.50.0
|
||||||
|
)
|
||||||
|
|
||||||
|
require (
|
||||||
|
apigo.cc/go/encoding v1.0.4 // indirect
|
||||||
|
apigo.cc/go/file v1.0.4 // indirect
|
||||||
|
filippo.io/edwards25519 v1.2.0 // indirect
|
||||||
|
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||||
|
github.com/google/uuid v1.6.0 // indirect
|
||||||
|
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||||
|
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||||
|
github.com/jackc/puddle/v2 v2.2.2 // indirect
|
||||||
|
github.com/kr/text v0.2.0 // indirect
|
||||||
|
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||||
|
github.com/ncruces/go-strftime v1.0.0 // indirect
|
||||||
|
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||||
|
github.com/rogpeppe/go-internal v1.14.1 // indirect
|
||||||
|
golang.org/x/crypto v0.50.0 // indirect
|
||||||
|
golang.org/x/sync v0.20.0 // indirect
|
||||||
|
golang.org/x/sys v0.43.0 // indirect
|
||||||
|
golang.org/x/text v0.36.0 // indirect
|
||||||
|
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||||
|
modernc.org/libc v1.72.0 // indirect
|
||||||
|
modernc.org/mathutil v1.7.1 // indirect
|
||||||
|
modernc.org/memory v1.11.0 // indirect
|
||||||
|
)
|
||||||
114
go.sum
Normal file
114
go.sum
Normal file
@ -0,0 +1,114 @@
|
|||||||
|
apigo.cc/go/cast v1.1.1 h1:+5pluN8g1RK2J4byr2xkfOmEdKSmy1PByOqDOHtt/Ns=
|
||||||
|
apigo.cc/go/cast v1.1.1/go.mod h1:vh9ZqISCmTUiyinkNMI/s4f045fRlDK3xC+nPWQYBzI=
|
||||||
|
apigo.cc/go/config v1.0.4 h1:WG9zrQkqfFPkrKIL7RNvvAbbkuUBt1Av11ZP/aIfldM=
|
||||||
|
apigo.cc/go/config v1.0.4/go.mod h1:obryzJiK6j7lQex/58d5eWYOGx5O5IABguqNWxyyXJo=
|
||||||
|
apigo.cc/go/convert v1.0.4 h1:5+qPjC3dlPB59GnWZRlmthxcaXQtKvN+iOuiLdJ1GvQ=
|
||||||
|
apigo.cc/go/convert v1.0.4/go.mod h1:Hp+geeSyhqg/zwIKPOrDoceIREzcwM14t1I5q/dtbfU=
|
||||||
|
apigo.cc/go/crypto v1.0.4 h1:VPUyHCH2N3LLEgdpwUc+DQssNHzLlxVzLNRa0Jm6O4o=
|
||||||
|
apigo.cc/go/crypto v1.0.4/go.mod h1:5sI8BLw6YHZfDReYwCO3TFD2LKm36HMdLg1S5oPv/QU=
|
||||||
|
apigo.cc/go/encoding v1.0.4 h1:aezB0J/qFuHs6iXkbtuJP5JIHUtmjsr5SFb0NNvbObY=
|
||||||
|
apigo.cc/go/encoding v1.0.4/go.mod h1:V5CgT7rBbCxy+uCU20q0ptcNNRSgMtpA8cNOs6r8IeI=
|
||||||
|
apigo.cc/go/file v1.0.4 h1:qCKegV7OYh7r0qc3jZjGA/aKh0vIHgmr1OEbhfEmGX8=
|
||||||
|
apigo.cc/go/file v1.0.4/go.mod h1:C9gNo7386iA21OiBmuWh6CznKWlVBDFkhE4f0H0Susg=
|
||||||
|
apigo.cc/go/id v1.0.4 h1:w+JSdeVit52iefIUolrh1qLEZS9XqHNKr1UygFcgv+s=
|
||||||
|
apigo.cc/go/id v1.0.4/go.mod h1:kg7QuceAKtGNzGWt0+pIIh8Qom1eMSWGb8+0Yhi/QVY=
|
||||||
|
apigo.cc/go/log v1.0.0 h1:lI1NGTSS+Jm12G8BD7ZJO4/hrkfuLTu5O8z36GD8GpU=
|
||||||
|
apigo.cc/go/log v1.0.0/go.mod h1:tvPgFpebY9Wf/DlqMHZ0ZjxDp9AaQTywOQKvtBaNqNo=
|
||||||
|
apigo.cc/go/rand v1.0.4 h1:we070eWSL0dB8NEMaWjXj43+EekXQTm/h0kKpZ/frqw=
|
||||||
|
apigo.cc/go/rand v1.0.4/go.mod h1:mZ/4Soa3bk+XvDaqPWJuUe1bfEi4eThBj1XmEAuYxsk=
|
||||||
|
apigo.cc/go/safe v1.0.4 h1:07pRSdEHprF/2v6SsqAjICYFoeLcqjjvHGEdh6Dzrzg=
|
||||||
|
apigo.cc/go/safe v1.0.4/go.mod h1:o568sHS5rTRSVPmhxWod0tGdc+8l1KjidsNY1/OVZr0=
|
||||||
|
apigo.cc/go/shell v1.0.4 h1:EL9zjI39YBe1h+kRYQeAi/8zVGHe5W198DYYN7cENiY=
|
||||||
|
apigo.cc/go/shell v1.0.4/go.mod h1:N2gDkgK4tJ9TadD60/+gAGuWxyVAWHs5YPBmytw6ELA=
|
||||||
|
filippo.io/edwards25519 v1.2.0 h1:crnVqOiS4jqYleHd9vaKZ+HKtHfllngJIiOpNpoJsjo=
|
||||||
|
filippo.io/edwards25519 v1.2.0/go.mod h1:xzAOLCNug/yB62zG1bQ8uziwrIqIuxhctzJT18Q77mc=
|
||||||
|
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||||
|
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||||
|
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||||
|
github.com/go-sql-driver/mysql v1.10.0 h1:Q+1LV8DkHJvSYAdR83XzuhDaTykuDx0l6fkXxoWCWfw=
|
||||||
|
github.com/go-sql-driver/mysql v1.10.0/go.mod h1:M+cqaI7+xxXGG9swrdeUIoPG3Y3KCkF0pZej+SK+nWk=
|
||||||
|
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
||||||
|
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
|
||||||
|
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||||
|
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
|
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
|
||||||
|
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
|
||||||
|
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||||
|
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||||
|
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||||
|
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
||||||
|
github.com/jackc/pgx/v5 v5.9.2 h1:3ZhOzMWnR4yJ+RW1XImIPsD1aNSz4T4fyP7zlQb56hw=
|
||||||
|
github.com/jackc/pgx/v5 v5.9.2/go.mod h1:mal1tBGAFfLHvZzaYh77YS/eC6IX9OWbRV1QIIM0Jn4=
|
||||||
|
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
|
||||||
|
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
||||||
|
github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
|
||||||
|
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
|
||||||
|
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||||
|
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||||
|
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||||
|
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||||
|
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
|
||||||
|
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
|
||||||
|
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
|
||||||
|
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
|
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||||
|
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||||
|
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
|
||||||
|
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
|
||||||
|
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||||
|
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||||
|
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||||
|
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||||
|
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||||
|
golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI=
|
||||||
|
golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q=
|
||||||
|
golang.org/x/mod v0.34.0 h1:xIHgNUUnW6sYkcM5Jleh05DvLOtwc6RitGHbDk4akRI=
|
||||||
|
golang.org/x/mod v0.34.0/go.mod h1:ykgH52iCZe79kzLLMhyCUzhMci+nQj+0XkbXpNYtVjY=
|
||||||
|
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
|
||||||
|
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
|
||||||
|
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI=
|
||||||
|
golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||||
|
golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg=
|
||||||
|
golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164=
|
||||||
|
golang.org/x/tools v0.43.0 h1:12BdW9CeB3Z+J/I/wj34VMl8X+fEXBxVR90JeMX5E7s=
|
||||||
|
golang.org/x/tools v0.43.0/go.mod h1:uHkMso649BX2cZK6+RpuIPXS3ho2hZo4FVwfoy1vIk0=
|
||||||
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
|
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||||
|
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||||
|
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
|
modernc.org/cc/v4 v4.27.3 h1:uNCgn37E5U09mTv1XgskEVUJ8ADKpmFMPxzGJ0TSo+U=
|
||||||
|
modernc.org/cc/v4 v4.27.3/go.mod h1:3YjcbCqhoTTHPycJDRl2WZKKFj0nwcOIPBfEZK0Hdk8=
|
||||||
|
modernc.org/ccgo/v4 v4.32.4 h1:L5OB8rpEX4ZsXEQwGozRfJyJSFHbbNVOoQ59DU9/KuU=
|
||||||
|
modernc.org/ccgo/v4 v4.32.4/go.mod h1:lY7f+fiTDHfcv6YlRgSkxYfhs+UvOEEzj49jAn2TOx0=
|
||||||
|
modernc.org/fileutil v1.4.0 h1:j6ZzNTftVS054gi281TyLjHPp6CPHr2KCxEXjEbD6SM=
|
||||||
|
modernc.org/fileutil v1.4.0/go.mod h1:EqdKFDxiByqxLk8ozOxObDSfcVOv/54xDs/DUHdvCUU=
|
||||||
|
modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI=
|
||||||
|
modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito=
|
||||||
|
modernc.org/gc/v3 v3.1.2 h1:ZtDCnhonXSZexk/AYsegNRV1lJGgaNZJuKjJSWKyEqo=
|
||||||
|
modernc.org/gc/v3 v3.1.2/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY=
|
||||||
|
modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks=
|
||||||
|
modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI=
|
||||||
|
modernc.org/libc v1.72.0 h1:IEu559v9a0XWjw0DPoVKtXpO2qt5NVLAnFaBbjq+n8c=
|
||||||
|
modernc.org/libc v1.72.0/go.mod h1:tTU8DL8A+XLVkEY3x5E/tO7s2Q/q42EtnNWda/L5QhQ=
|
||||||
|
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
|
||||||
|
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
|
||||||
|
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
|
||||||
|
modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw=
|
||||||
|
modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8=
|
||||||
|
modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
|
||||||
|
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
|
||||||
|
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
|
||||||
|
modernc.org/sqlite v1.50.0 h1:eMowQSWLK0MeiQTdmz3lqoF5dqclujdlIKeJA11+7oM=
|
||||||
|
modernc.org/sqlite v1.50.0/go.mod h1:m0w8xhwYUVY3H6pSDwc3gkJ/irZT/0YEXwBlhaxQEew=
|
||||||
|
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
|
||||||
|
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
|
||||||
|
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
|
||||||
|
modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM=
|
||||||
48
mysql/connector.go
Normal file
48
mysql/connector.go
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
package mysql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"database/sql/driver"
|
||||||
|
|
||||||
|
"github.com/go-sql-driver/mysql"
|
||||||
|
"apigo.cc/go/db"
|
||||||
|
"apigo.cc/go/safe"
|
||||||
|
)
|
||||||
|
|
||||||
|
type SecureMysqlConnector struct {
|
||||||
|
conf *db.Config
|
||||||
|
pwd *safe.SafeBuf
|
||||||
|
tls *tls.Config
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SecureMysqlConnector) Connect(ctx context.Context) (driver.Conn, error) {
|
||||||
|
cfg, err := mysql.ParseDSN(c.conf.Args)
|
||||||
|
if err != nil {
|
||||||
|
cfg = mysql.NewConfig()
|
||||||
|
}
|
||||||
|
cfg.User = c.conf.User
|
||||||
|
cfg.Net = "tcp"
|
||||||
|
cfg.Addr = c.conf.Host
|
||||||
|
cfg.DBName = c.conf.DB
|
||||||
|
cfg.TLS = c.tls
|
||||||
|
pwdBuf := c.pwd.Open()
|
||||||
|
defer pwdBuf.Close()
|
||||||
|
cfg.Passwd = pwdBuf.String()
|
||||||
|
|
||||||
|
tmpConnector, err := mysql.NewConnector(cfg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return tmpConnector.Connect(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SecureMysqlConnector) Driver() driver.Driver {
|
||||||
|
return &mysql.MySQLDriver{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
db.RegisterConnector("mysql", func(conf *db.Config, pwd *safe.SafeBuf, tls *tls.Config) driver.Connector {
|
||||||
|
return &SecureMysqlConnector{conf: conf, pwd: pwd, tls: tls}
|
||||||
|
})
|
||||||
|
}
|
||||||
48
pgx/connector.go
Normal file
48
pgx/connector.go
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
package pgx
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"database/sql/driver"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5"
|
||||||
|
"github.com/jackc/pgx/v5/stdlib"
|
||||||
|
"apigo.cc/go/db"
|
||||||
|
"apigo.cc/go/safe"
|
||||||
|
)
|
||||||
|
|
||||||
|
type SecurePgxConnector struct {
|
||||||
|
conf *db.Config
|
||||||
|
pwd *safe.SafeBuf
|
||||||
|
tls *tls.Config
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SecurePgxConnector) Connect(ctx context.Context) (driver.Conn, error) {
|
||||||
|
dsn := fmt.Sprintf("postgres://%s@%s/%s?%s", c.conf.User, c.conf.Host, c.conf.DB, c.conf.Args)
|
||||||
|
pgxConfig, err := pgx.ParseConfig(dsn)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if c.tls != nil {
|
||||||
|
pgxConfig.TLSConfig = c.tls
|
||||||
|
}
|
||||||
|
pwdBuf := c.pwd.Open()
|
||||||
|
defer pwdBuf.Close()
|
||||||
|
pgxConfig.Password = pwdBuf.String()
|
||||||
|
tmpConnector := stdlib.GetConnector(*pgxConfig)
|
||||||
|
return tmpConnector.Connect(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SecurePgxConnector) Driver() driver.Driver {
|
||||||
|
return stdlib.GetDefaultDriver()
|
||||||
|
}
|
||||||
|
|
||||||
|
func PgxConnector(conf *db.Config, pwd *safe.SafeBuf, tls *tls.Config) driver.Connector {
|
||||||
|
return &SecurePgxConnector{conf: conf, pwd: pwd, tls: tls}
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
db.RegisterConnector("postgres", PgxConnector)
|
||||||
|
db.RegisterConnector("pgx", PgxConnector)
|
||||||
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user