db/Base.go

317 lines
8.4 KiB
Go
Raw Permalink Normal View History

package db
import (
	"database/sql"
	"database/sql/driver"
	"errors"
	"fmt"
	"reflect"
	"strings"
	"sync"
	"time"

	"apigo.cc/go/cast"
	"apigo.cc/go/log"
)
// structFieldsCache memoizes the field lists built by getStructFields.
// Keys are reflect.Type values, entries are []structFieldInfo.
var structFieldsCache = sync.Map{}
// structFieldInfo describes one struct field collected by flattenFields.
type structFieldInfo struct {
// name is the Go field name; MakeKeysVarsValues uses it as the key.
name string
// index is the field-index path from the root struct, suitable for
// reflect.Value.FieldByIndex.
index []int
}
// getStructFields returns the flattened field list of typ. The list is built
// by flattenFields on the first request for a given type and served from
// structFieldsCache on later requests.
func getStructFields(typ reflect.Type) []structFieldInfo {
	cached, hit := structFieldsCache.Load(typ)
	if hit {
		return cached.([]structFieldInfo)
	}

	var collected []structFieldInfo
	flattenFields(typ, nil, &collected)
	structFieldsCache.Store(typ, collected)
	return collected
}
// flattenFields appends to *fields one entry per exported field of typ,
// descending into embedded (anonymous) struct values so their fields are
// listed as if they were declared on typ itself.
//
// typ may be a struct or a pointer to a struct (one level of indirection is
// removed); any other kind is ignored. index is the field-index path that
// leads from the root struct to typ; it is prepended to every recorded path
// so each entry can be used with reflect.Value.FieldByIndex on the root value.
//
// Embedded pointers to structs are not descended into: they are recorded as a
// single field when exported.
func flattenFields(typ reflect.Type, index []int, fields *[]structFieldInfo) {
	if typ.Kind() == reflect.Ptr {
		typ = typ.Elem()
	}
	if typ.Kind() != reflect.Struct {
		return
	}
	for i := 0; i < typ.NumField(); i++ {
		f := typ.Field(i)

		// Build a fresh path for every field; reusing index's backing array
		// would let sibling fields overwrite each other's paths.
		path := make([]int, 0, len(index)+len(f.Index))
		path = append(path, index...)
		path = append(path, f.Index...)

		if f.Anonymous && f.Type.Kind() == reflect.Struct {
			flattenFields(f.Type, path, fields)
			continue
		}
		// IsExported applies the language's own rule, so exported names that
		// start with a non-ASCII upper-case letter are kept as well.
		if f.IsExported() {
			*fields = append(*fields, structFieldInfo{name: f.Name, index: path})
		}
	}
}
// basePrepare prepares query on tx when tx is non-nil, otherwise on db.
// The returned Stmt carries the prepared statement and a pointer to the query
// text on success, or only Error when preparation fails or when both tx and
// db are nil.
func basePrepare(db *sql.DB, tx *sql.Tx, query string) *Stmt {
	var (
		prepared *sql.Stmt
		prepErr  error
	)
	switch {
	case tx != nil:
		prepared, prepErr = tx.Prepare(query)
	case db != nil:
		prepared, prepErr = db.Prepare(query)
	default:
		return &Stmt{Error: errors.New("operate on a bad connection")}
	}

	if prepErr != nil {
		return &Stmt{Error: prepErr}
	}
	return &Stmt{conn: prepared, lastSql: &query}
}
// baseExec runs query with args on tx when tx is non-nil, otherwise on db.
// args are first normalized by flatArgs. The returned ExecResult always holds
// the query text, the normalized args and the timing produced by
// log.MakeUsedTime; it additionally holds either the sql.Result or the error.
// When both tx and db are nil the error is "operate on a bad connection".
func baseExec(db *sql.DB, tx *sql.Tx, query string, args ...any) *ExecResult {
	args = flatArgs(args)
	out := &ExecResult{Sql: &query, Args: args}

	var (
		sqlResult sql.Result
		execErr   error
	)
	began := time.Now()
	switch {
	case tx != nil:
		sqlResult, execErr = tx.Exec(query, args...)
	case db != nil:
		sqlResult, execErr = db.Exec(query, args...)
	default:
		out.usedTime = log.MakeUsedTime(began, time.Now())
		out.Error = errors.New("operate on a bad connection")
		return out
	}
	finished := time.Now()
	out.usedTime = log.MakeUsedTime(began, finished)

	if execErr != nil {
		out.Error = execErr
		return out
	}
	out.result = sqlResult
	return out
}
// flatArgs prepares query arguments for the driver: maps, structs and slices
// (other than byte slices) are replaced by the result of cast.MustToJSON so
// they can be bound as a single parameter.
//
// Values the driver already handles natively are passed through unchanged:
// nil, time.Time, and anything implementing driver.Valuer (for example
// sql.NullString), which database/sql converts by calling Value itself.
//
// The input slice is never modified. When no argument needs conversion the
// same slice is returned; otherwise a converted copy is returned.
func flatArgs(args []any) []any {
	out := args
	copied := false
	for i, arg := range args {
		if arg == nil {
			continue
		}
		// These are struct kinds too, but serializing them would hide the
		// real value from the driver.
		switch arg.(type) {
		case time.Time, driver.Valuer:
			continue
		}

		v := reflect.ValueOf(arg)
		switch v.Kind() {
		case reflect.Map, reflect.Struct:
			// always serialized
		case reflect.Slice:
			// []byte (and named byte slices) is a native driver value.
			if v.Type().Elem().Kind() == reflect.Uint8 {
				continue
			}
		default:
			continue
		}

		// Copy lazily so callers that spread their own slice into a variadic
		// call do not see their elements replaced.
		if !copied {
			out = make([]any, len(args))
			copy(out, args)
			copied = true
		}
		out[i] = cast.MustToJSON(arg)
	}
	return out
}
// baseQuery runs query with args on tx when tx is non-nil, otherwise on db.
// args are first normalized by flatArgs. The returned QueryResult always
// holds the query text, the normalized args and the timing produced by
// log.MakeUsedTime; it additionally holds either the *sql.Rows or the error.
// When both tx and db are nil the error is "operate on a bad connection".
func baseQuery(db *sql.DB, tx *sql.Tx, query string, args ...any) *QueryResult {
	args = flatArgs(args)
	out := &QueryResult{Sql: &query, Args: args}

	var (
		resultRows *sql.Rows
		queryErr   error
	)
	began := time.Now()
	switch {
	case tx != nil:
		resultRows, queryErr = tx.Query(query, args...)
	case db != nil:
		resultRows, queryErr = db.Query(query, args...)
	default:
		out.usedTime = log.MakeUsedTime(began, time.Now())
		out.Error = errors.New("operate on a bad connection")
		return out
	}
	finished := time.Now()
	out.usedTime = log.MakeUsedTime(began, finished)

	if queryErr != nil {
		out.Error = queryErr
		return out
	}
	out.rows = resultRows
	return out
}
// quote wraps an SQL identifier in quoteTag. text is split on "." and every
// part is quoted separately, so "db.users" becomes `db`.`users` with a
// backtick tag.
//
// An occurrence of quoteTag inside a part is escaped by doubling it, which is
// how quoted identifiers embed their own quote character in SQL; a backslash
// is not an escape there and would let the identifier terminate early.
// An empty quoteTag means "no quoting" and returns text unchanged.
func quote(quoteTag string, text string) string {
	if quoteTag == "" {
		return text
	}
	doubled := quoteTag + quoteTag
	parts := strings.Split(text, ".")
	for i, part := range parts {
		parts[i] = quoteTag + strings.ReplaceAll(part, quoteTag, doubled) + quoteTag
	}
	return strings.Join(parts, ".")
}
// quotes quotes every identifier in texts with quote and joins the results
// with ",". texts itself is left untouched, so callers can keep using the
// unquoted names afterwards.
func quotes(quoteTag string, texts []string) string {
	quoted := make([]string, len(texts))
	for i, text := range texts {
		quoted[i] = quote(quoteTag, text)
	}
	return strings.Join(quoted, ",")
}
// makeInsertSql builds an "insert into" statement (or "replace into" when
// useReplace is set) for table from data and returns the SQL text together
// with its bound arguments.
//
// data is decomposed by MakeKeysVarsValues. When versionField is non-empty
// and data does not already contain it, the column is added with the value
// nextVer. When both idField and nextId are non-empty, nextId is used if data
// has no idField entry, or if that entry is a bound value that cast.String
// converts to the empty string. An idField given as a raw ":expression" is
// left as is.
func makeInsertSql(quoteTag string, table string, data any, useReplace bool, versionField string, nextVer int64, idField string, nextId string) (string, []any) {
	keys, vars, values := MakeKeysVarsValues(data)

	if versionField != "" {
		hasVersion := false
		for _, k := range keys {
			if k == versionField {
				hasVersion = true
				break
			}
		}
		if !hasVersion {
			keys = append(keys, versionField)
			vars = append(vars, "?")
			values = append(values, nextVer)
		}
	}

	if idField != "" && nextId != "" {
		hasID := false
		// values holds one entry per "?" placeholder only; keys that carry a
		// raw expression have none. Track the position in values separately
		// instead of reusing the key index.
		valueIdx := 0
		for i, k := range keys {
			bound := vars[i] == "?"
			if k == idField {
				hasID = true
				if bound && valueIdx < len(values) && cast.String(values[valueIdx]) == "" {
					values[valueIdx] = nextId
				}
				break
			}
			if bound {
				valueIdx++
			}
		}
		if !hasID {
			keys = append(keys, idField)
			vars = append(vars, "?")
			values = append(values, nextId)
		}
	}

	operation := "insert"
	if useReplace {
		operation = "replace"
	}
	query := fmt.Sprintf("%s into %s (%s) values (%s)", operation, quote(quoteTag, table), quotes(quoteTag, keys), strings.Join(vars, ","))
	return query, values
}
// makeUpdateSql builds an "update ... set ..." statement for table from data
// and returns the SQL text together with its bound arguments (set values
// first, then the condition arguments).
//
// data is decomposed by MakeKeysVarsValues; args are normalized by flatArgs
// and belong to the placeholders in conditions. When versionField is
// non-empty it is always set to nextVer. If data contains a bound, non-nil
// value for versionField, that value is not written but appended to the
// conditions as "<versionField>=?", so the update only matches rows that
// still carry that version. A versionField supplied as a raw ":expression"
// has no bound value and is skipped.
func makeUpdateSql(quoteTag string, table string, data any, conditions string, versionField string, nextVer int64, args ...any) (string, []any) {
	args = flatArgs(args)
	keys, vars, values := MakeKeysVarsValues(data)

	setClauses := make([]string, 0, len(keys)+1)
	setValues := make([]any, 0, len(values)+len(args)+2)
	var oldVersion any

	// values holds one entry per "?" placeholder only; keys that carry a raw
	// expression have none. Walk values with its own cursor so keys and bound
	// arguments stay paired.
	valueIdx := 0
	for i, k := range keys {
		var bound any
		hasBound := vars[i] == "?" && valueIdx < len(values)
		if hasBound {
			bound = values[valueIdx]
			valueIdx++
		}

		if versionField != "" && k == versionField {
			if hasBound {
				oldVersion = bound
			}
			continue
		}

		setClauses = append(setClauses, fmt.Sprintf("%s=%s", quote(quoteTag, k), vars[i]))
		if hasBound {
			setValues = append(setValues, bound)
		}
	}

	if versionField != "" {
		setClauses = append(setClauses, fmt.Sprintf("%s=?", quote(quoteTag, versionField)))
		setValues = append(setValues, nextVer)
	}

	// Condition arguments follow the set values. They are appended to
	// setValues directly so the caller's args backing array is never written.
	setValues = append(setValues, args...)
	if oldVersion != nil {
		if conditions != "" {
			conditions = fmt.Sprintf("(%s) and %s=?", conditions, quote(quoteTag, versionField))
		} else {
			conditions = fmt.Sprintf("%s=?", quote(quoteTag, versionField))
		}
		setValues = append(setValues, oldVersion)
	}

	if conditions != "" {
		conditions = " where " + conditions
	}
	query := fmt.Sprintf("update %s set %s%s", quote(quoteTag, table), strings.Join(setClauses, ","), conditions)
	return query, setValues
}
// MakeInsertSql builds the insert statement (replace statement when
// useReplace is set) for table from data and returns the SQL text with its
// bound arguments. If the table definition returned by db.getTable names a
// version field, the next version is taken from db.NextVersion; if it names
// an id field, the next id is taken from db.NextID. Both are handed to
// makeInsertSql together with db.QuoteTag.
func (db *DB) MakeInsertSql(table string, data any, useReplace bool) (string, []any) {
	schema := db.getTable(table)

	var (
		version int64
		id      string
	)
	if schema.VersionField != "" {
		version = db.NextVersion(table)
	}
	if schema.IdField != "" {
		id = db.NextID(table)
	}

	return makeInsertSql(db.QuoteTag, table, data, useReplace, schema.VersionField, version, schema.IdField, id)
}
// MakeUpdateSql builds the update statement for table from data, restricted
// by conditions (whose placeholders are filled from args), and returns the
// SQL text with its bound arguments. If the table definition returned by
// db.getTable names a version field, the next version is taken from
// db.NextVersion and handed to makeUpdateSql together with db.QuoteTag.
func (db *DB) MakeUpdateSql(table string, data any, conditions string, args ...any) (string, []any) {
	schema := db.getTable(table)

	var version int64
	if schema.VersionField != "" {
		version = db.NextVersion(table)
	}

	return makeUpdateSql(db.QuoteTag, table, data, conditions, schema.VersionField, version, args...)
}
// MakeInsertSql builds the insert statement (replace statement when
// useReplace is set) for table from data and returns the SQL text with its
// bound arguments. The table definition, the next version and the next id
// come from the transaction's parent tx.db (getTable, NextVersion, NextID);
// identifiers are quoted with tx.QuoteTag.
func (tx *Tx) MakeInsertSql(table string, data any, useReplace bool) (string, []any) {
	schema := tx.db.getTable(table)

	var (
		version int64
		id      string
	)
	if schema.VersionField != "" {
		version = tx.db.NextVersion(table)
	}
	if schema.IdField != "" {
		id = tx.db.NextID(table)
	}

	return makeInsertSql(tx.QuoteTag, table, data, useReplace, schema.VersionField, version, schema.IdField, id)
}
// MakeUpdateSql builds the update statement for table from data, restricted
// by conditions (whose placeholders are filled from args), and returns the
// SQL text with its bound arguments. The table definition and the next
// version come from the transaction's parent tx.db (getTable, NextVersion);
// identifiers are quoted with tx.QuoteTag.
func (tx *Tx) MakeUpdateSql(table string, data any, conditions string, args ...any) (string, []any) {
	schema := tx.db.getTable(table)

	var version int64
	if schema.VersionField != "" {
		version = tx.db.NextVersion(table)
	}

	return makeUpdateSql(tx.QuoteTag, table, data, conditions, schema.VersionField, version, args...)
}
// MakeKeysVarsValues decomposes data into the pieces needed to build an SQL
// statement.
//
// data may be a struct, a map, or a pointer (of any depth) to one of them.
// Three slices are returned:
//
//   - keys: the struct field names reported by getStructFields, or the map
//     keys converted with cast.String;
//   - vars: one entry per key, either the placeholder "?" or, for string
//     values starting with ':', the rest of the string taken as a raw SQL
//     expression;
//   - values: the bound arguments, one per "?" in vars and in the same order.
//     Keys with a raw expression contribute no value, so values can be
//     shorter than keys.
//
// A nil data, a nil pointer, or a value of any other kind yields three empty,
// non-nil slices. For maps the key order follows Go's map iteration and is
// therefore not stable between calls.
func MakeKeysVarsValues(data any) ([]string, []string, []any) {
	keys := make([]string, 0)
	vars := make([]string, 0)
	values := make([]any, 0)

	// reflect.TypeOf(nil) is nil, so there is nothing to inspect.
	if data == nil {
		return keys, vars, values
	}

	dataValue := reflect.ValueOf(data)
	for dataValue.Kind() == reflect.Ptr {
		// Elem of a nil pointer is the zero Value; field access would panic.
		if dataValue.IsNil() {
			return keys, vars, values
		}
		dataValue = dataValue.Elem()
	}

	// add records one key together with its placeholder and, when the value
	// is bound, its argument. Shared by the struct and map cases.
	add := func(key string, v reflect.Value) {
		if v.Kind() == reflect.Interface {
			v = v.Elem()
		}
		keys = append(keys, key)
		if v.Kind() == reflect.String && v.Len() > 0 && v.String()[0] == ':' {
			vars = append(vars, v.String()[1:])
			return
		}
		vars = append(vars, "?")
		// A nil interface unwraps to the invalid Value; bind it as NULL.
		if !v.IsValid() || !v.CanInterface() {
			values = append(values, nil)
			return
		}
		values = append(values, v.Interface())
	}

	switch dataValue.Kind() {
	case reflect.Struct:
		for _, f := range getStructFields(dataValue.Type()) {
			add(f.name, dataValue.FieldByIndex(f.index))
		}
	case reflect.Map:
		for it := dataValue.MapRange(); it.Next(); {
			add(cast.String(it.Key().Interface()), it.Value())
		}
	}
	return keys, vars, values
}