// Listing metadata: 578 lines, 15 KiB, Go.

package db
|
|
|
|
import (
|
|
"database/sql"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"reflect"
|
|
"strings"
|
|
"time"
|
|
|
|
"apigo.cc/go/cast"
|
|
"apigo.cc/go/convert"
|
|
"github.com/mitchellh/mapstructure"
|
|
)
|
|
|
|
type QueryResult struct {
|
|
rows *sql.Rows
|
|
Sql *string
|
|
Args []any
|
|
Error error
|
|
logger *dbLogger
|
|
usedTime float32
|
|
completed bool
|
|
}
|
|
|
|
type ExecResult struct {
|
|
result sql.Result
|
|
Sql *string
|
|
Args []any
|
|
Error error
|
|
logger *dbLogger
|
|
usedTime float32
|
|
}
|
|
|
|
func (r *ExecResult) Changes() int64 {
|
|
if r.result == nil {
|
|
return 0
|
|
}
|
|
numChanges, err := r.result.RowsAffected()
|
|
if err != nil {
|
|
r.logger.LogQueryError(err.Error(), *r.Sql, r.Args, r.usedTime)
|
|
return 0
|
|
}
|
|
return numChanges
|
|
}
|
|
|
|
func (r *ExecResult) Id() int64 {
|
|
if r.result == nil {
|
|
return 0
|
|
}
|
|
insertId, err := r.result.LastInsertId()
|
|
if err != nil {
|
|
r.logger.LogQueryError(err.Error(), *r.Sql, r.Args, r.usedTime)
|
|
return 0
|
|
}
|
|
return insertId
|
|
}
|
|
|
|
func (r *QueryResult) Complete() {
|
|
if !r.completed {
|
|
if r.rows != nil {
|
|
r.rows.Close()
|
|
}
|
|
r.completed = true
|
|
}
|
|
}
|
|
|
|
func (r *QueryResult) To(result any) error {
|
|
if r.rows == nil {
|
|
return errors.New("operate on a bad query")
|
|
}
|
|
return r.makeResults(result, r.rows)
|
|
}
|
|
|
|
func (r *QueryResult) MapResults() []map[string]any {
|
|
result := make([]map[string]any, 0)
|
|
err := r.makeResults(&result, r.rows)
|
|
if err != nil {
|
|
r.logger.LogQueryError(err.Error(), *r.Sql, r.Args, r.usedTime)
|
|
}
|
|
return result
|
|
}
|
|
|
|
func (r *QueryResult) SliceResults() [][]any {
|
|
result := make([][]any, 0)
|
|
err := r.makeResults(&result, r.rows)
|
|
if err != nil {
|
|
r.logger.LogQueryError(err.Error(), *r.Sql, r.Args, r.usedTime)
|
|
}
|
|
return result
|
|
}
|
|
|
|
func (r *QueryResult) StringMapResults() []map[string]string {
|
|
result := make([]map[string]string, 0)
|
|
err := r.makeResults(&result, r.rows)
|
|
if err != nil {
|
|
r.logger.LogQueryError(err.Error(), *r.Sql, r.Args, r.usedTime)
|
|
}
|
|
return result
|
|
}
|
|
|
|
func (r *QueryResult) StringSliceResults() [][]string {
|
|
result := make([][]string, 0)
|
|
err := r.makeResults(&result, r.rows)
|
|
if err != nil {
|
|
r.logger.LogQueryError(err.Error(), *r.Sql, r.Args, r.usedTime)
|
|
}
|
|
return result
|
|
}
|
|
|
|
func (r *QueryResult) MapOnR1() map[string]any {
|
|
result := make(map[string]any)
|
|
err := r.makeResults(&result, r.rows)
|
|
if err != nil {
|
|
r.logger.LogQueryError(err.Error(), *r.Sql, r.Args, r.usedTime)
|
|
}
|
|
return result
|
|
}
|
|
|
|
func (r *QueryResult) StringMapOnR1() map[string]string {
|
|
result := make(map[string]string)
|
|
err := r.makeResults(&result, r.rows)
|
|
if err != nil {
|
|
r.logger.LogQueryError(err.Error(), *r.Sql, r.Args, r.usedTime)
|
|
}
|
|
return result
|
|
}
|
|
|
|
func (r *QueryResult) IntsOnC1() []int64 {
|
|
result := make([]int64, 0)
|
|
err := r.makeResults(&result, r.rows)
|
|
if err != nil {
|
|
r.logger.LogQueryError(err.Error(), *r.Sql, r.Args, r.usedTime)
|
|
}
|
|
return result
|
|
}
|
|
|
|
func (r *QueryResult) StringsOnC1() []string {
|
|
result := make([]string, 0)
|
|
err := r.makeResults(&result, r.rows)
|
|
if err != nil {
|
|
r.logger.LogQueryError(err.Error(), *r.Sql, r.Args, r.usedTime)
|
|
}
|
|
return result
|
|
}
|
|
|
|
func (r *QueryResult) IntOnR1C1() int64 {
|
|
var result int64 = 0
|
|
err := r.makeResults(&result, r.rows)
|
|
if err != nil {
|
|
r.logger.LogQueryError(err.Error(), *r.Sql, r.Args, r.usedTime)
|
|
}
|
|
return result
|
|
}
|
|
|
|
func (r *QueryResult) FloatOnR1C1() float64 {
|
|
var result float64 = 0
|
|
err := r.makeResults(&result, r.rows)
|
|
if err != nil {
|
|
r.logger.LogQueryError(err.Error(), *r.Sql, r.Args, r.usedTime)
|
|
}
|
|
return result
|
|
}
|
|
|
|
func (r *QueryResult) StringOnR1C1() string {
|
|
result := ""
|
|
err := r.makeResults(&result, r.rows)
|
|
if err != nil {
|
|
r.logger.LogQueryError(err.Error(), *r.Sql, r.Args, r.usedTime)
|
|
}
|
|
return result
|
|
}
|
|
|
|
func (r *QueryResult) ToKV(target any) error {
|
|
v := reflect.ValueOf(target)
|
|
t := v.Type()
|
|
for t.Kind() == reflect.Ptr {
|
|
v = v.Elem()
|
|
t = v.Type()
|
|
}
|
|
|
|
if t.Kind() != reflect.Map {
|
|
r.logger.LogQueryError("target not a map", *r.Sql, r.Args, r.usedTime)
|
|
return errors.New("target not a map")
|
|
}
|
|
|
|
vt := t.Elem()
|
|
finalVt := vt
|
|
for finalVt.Kind() == reflect.Ptr {
|
|
finalVt = finalVt.Elem()
|
|
}
|
|
if finalVt.Kind() == reflect.Map || finalVt.Kind() == reflect.Struct {
|
|
colTypes, err := r.getColumnTypes()
|
|
list := r.MapResults()
|
|
if err != nil {
|
|
r.logger.LogQueryError(err.Error(), *r.Sql, r.Args, r.usedTime)
|
|
return err
|
|
}
|
|
for _, item := range list {
|
|
newKey := reflect.ValueOf(reflect.New(t.Key()).Interface()).Elem()
|
|
convert.To(item[colTypes[0].Name()], newKey.Addr().Interface())
|
|
|
|
newValue := v.MapIndex(newKey)
|
|
isNew := false
|
|
if !newValue.IsValid() {
|
|
newValue = reflect.New(vt)
|
|
isNew = true
|
|
}
|
|
|
|
err := mapstructure.WeakDecode(item, newValue.Interface())
|
|
if err != nil {
|
|
r.logger.LogError(err.Error())
|
|
}
|
|
|
|
if isNew {
|
|
v.SetMapIndex(newKey, newValue.Elem())
|
|
}
|
|
}
|
|
} else {
|
|
list := r.SliceResults()
|
|
for _, item := range list {
|
|
if len(item) < 2 {
|
|
continue
|
|
}
|
|
switch vt.Kind() {
|
|
case reflect.Int:
|
|
v.SetMapIndex(reflect.ValueOf(cast.String(item[0])), reflect.ValueOf(cast.Int(item[1])))
|
|
case reflect.Int8:
|
|
v.SetMapIndex(reflect.ValueOf(cast.String(item[0])), reflect.ValueOf(int8(cast.Int(item[1]))))
|
|
case reflect.Int16:
|
|
v.SetMapIndex(reflect.ValueOf(cast.String(item[0])), reflect.ValueOf(int16(cast.Int(item[1]))))
|
|
case reflect.Int32:
|
|
v.SetMapIndex(reflect.ValueOf(cast.String(item[0])), reflect.ValueOf(int32(cast.Int(item[1]))))
|
|
case reflect.Int64:
|
|
v.SetMapIndex(reflect.ValueOf(cast.String(item[0])), reflect.ValueOf(cast.Int64(item[1])))
|
|
case reflect.Uint:
|
|
v.SetMapIndex(reflect.ValueOf(cast.String(item[0])), reflect.ValueOf(uint(cast.Int(item[1]))))
|
|
case reflect.Uint8:
|
|
v.SetMapIndex(reflect.ValueOf(cast.String(item[0])), reflect.ValueOf(uint8(cast.Int(item[1]))))
|
|
case reflect.Uint16:
|
|
v.SetMapIndex(reflect.ValueOf(cast.String(item[0])), reflect.ValueOf(uint16(cast.Int(item[1]))))
|
|
case reflect.Uint32:
|
|
v.SetMapIndex(reflect.ValueOf(cast.String(item[0])), reflect.ValueOf(uint32(cast.Int(item[1]))))
|
|
case reflect.Uint64:
|
|
v.SetMapIndex(reflect.ValueOf(cast.String(item[0])), reflect.ValueOf(cast.Uint64(item[1])))
|
|
case reflect.Float32:
|
|
v.SetMapIndex(reflect.ValueOf(cast.String(item[0])), reflect.ValueOf(cast.Float(item[1])))
|
|
case reflect.Float64:
|
|
v.SetMapIndex(reflect.ValueOf(cast.String(item[0])), reflect.ValueOf(cast.Float64(item[1])))
|
|
case reflect.Bool:
|
|
v.SetMapIndex(reflect.ValueOf(cast.String(item[0])), reflect.ValueOf(cast.Bool(item[1])))
|
|
case reflect.String:
|
|
v.SetMapIndex(reflect.ValueOf(cast.String(item[0])), reflect.ValueOf(cast.String(item[1])))
|
|
case reflect.Interface:
|
|
v.SetMapIndex(reflect.ValueOf(cast.String(item[0])), reflect.ValueOf(item[1]))
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (r *QueryResult) makeResults(results any, rows *sql.Rows) error {
|
|
if rows == nil {
|
|
return errors.New("not a valid query result")
|
|
}
|
|
|
|
defer func() {
|
|
_ = rows.Close()
|
|
r.completed = true
|
|
}()
|
|
resultsValue := reflect.ValueOf(results)
|
|
if resultsValue.Kind() != reflect.Ptr {
|
|
return fmt.Errorf("results must be a pointer")
|
|
}
|
|
|
|
for resultsValue.Kind() == reflect.Ptr {
|
|
resultsValue = resultsValue.Elem()
|
|
}
|
|
rowType := resultsValue.Type()
|
|
|
|
colTypes, err := rows.ColumnTypes()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
colNum := len(colTypes)
|
|
originRowType := rowType
|
|
if rowType.Kind() == reflect.Slice {
|
|
rowType = rowType.Elem()
|
|
originRowType = rowType
|
|
for rowType.Kind() == reflect.Ptr {
|
|
rowType = rowType.Elem()
|
|
}
|
|
}
|
|
|
|
scanValues := make([]any, colNum)
|
|
var fieldInfos []struct {
|
|
index []int
|
|
typ reflect.Type
|
|
name string
|
|
}
|
|
|
|
if rowType.Kind() == reflect.Struct {
|
|
fieldInfos = make([]struct {
|
|
index []int
|
|
typ reflect.Type
|
|
name string
|
|
}, colNum)
|
|
for colIndex, col := range colTypes {
|
|
publicColName := makePublicVarName(col.Name())
|
|
field, found := rowType.FieldByName(publicColName)
|
|
if found {
|
|
fieldInfos[colIndex].index = field.Index
|
|
fieldInfos[colIndex].typ = field.Type
|
|
fieldInfos[colIndex].name = publicColName
|
|
if field.Type.Kind() == reflect.Interface {
|
|
scanValues[colIndex] = makeValue(colTypes[colIndex].ScanType())
|
|
} else {
|
|
scanValues[colIndex] = makeValue(field.Type)
|
|
}
|
|
} else {
|
|
fieldInfos[colIndex].index = nil
|
|
scanValues[colIndex] = makeValue(nil)
|
|
}
|
|
}
|
|
} else if rowType.Kind() == reflect.Map {
|
|
for colIndex := range colTypes {
|
|
if rowType.Elem().Kind() == reflect.Interface {
|
|
scanValues[colIndex] = makeValue(colTypes[colIndex].ScanType())
|
|
} else {
|
|
scanValues[colIndex] = makeValue(rowType.Elem())
|
|
}
|
|
}
|
|
} else if rowType.Kind() == reflect.Slice {
|
|
for colIndex := range colTypes {
|
|
if rowType.Elem().Kind() == reflect.Interface {
|
|
scanValues[colIndex] = makeValue(colTypes[colIndex].ScanType())
|
|
} else {
|
|
scanValues[colIndex] = makeValue(rowType.Elem())
|
|
}
|
|
}
|
|
} else {
|
|
if rowType.Kind() == reflect.Interface {
|
|
scanValues[0] = makeValue(colTypes[0].ScanType())
|
|
} else {
|
|
scanValues[0] = makeValue(rowType)
|
|
}
|
|
for colIndex := 1; colIndex < colNum; colIndex++ {
|
|
scanValues[colIndex] = makeValue(nil)
|
|
}
|
|
}
|
|
|
|
var data reflect.Value
|
|
isNew := true
|
|
for rows.Next() {
|
|
err = rows.Scan(scanValues...)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if rowType.Kind() == reflect.Struct {
|
|
if resultsValue.Kind() == reflect.Slice {
|
|
data = reflect.New(rowType).Elem()
|
|
} else {
|
|
data = resultsValue
|
|
isNew = false
|
|
}
|
|
|
|
for colIndex, col := range colTypes {
|
|
fInfo := fieldInfos[colIndex]
|
|
if fInfo.index == nil {
|
|
continue
|
|
}
|
|
field := data.FieldByIndex(fInfo.index)
|
|
valuePtr := reflect.ValueOf(scanValues[colIndex]).Elem()
|
|
if !valuePtr.IsNil() {
|
|
val := valuePtr.Elem()
|
|
if fInfo.typ.String() == "time.Time" {
|
|
str := val.String()
|
|
tm, err := time.Parse("2006-01-02 15:04:05.000000", str)
|
|
if err != nil {
|
|
tm, err = time.Parse("2006-01-02 15:04:05", str)
|
|
}
|
|
if err == nil {
|
|
field.Set(reflect.ValueOf(tm))
|
|
}
|
|
} else if val.Kind() != field.Kind() && field.Kind() != reflect.Interface {
|
|
if field.Kind() == reflect.Ptr && val.Kind() == field.Type().Elem().Kind() {
|
|
if val.CanAddr() {
|
|
if field.Type().AssignableTo(val.Type()) {
|
|
field.Set(val.Addr())
|
|
} else if val.Type().String() == "string" {
|
|
strVal := fixValue(col.DatabaseTypeName(), val)
|
|
field.Set(reflect.New(field.Type().Elem()))
|
|
field.Elem().SetString(cast.String(strVal.Interface()))
|
|
} else if strings.Contains(field.Type().String(), "uint") {
|
|
field.Set(reflect.New(field.Type().Elem()))
|
|
field.Elem().SetUint(cast.Uint64(val.Interface()))
|
|
} else if strings.Contains(field.Type().String(), "int") {
|
|
field.Set(reflect.New(field.Type().Elem()))
|
|
field.Elem().SetInt(cast.Int64(val.Interface()))
|
|
} else if strings.Contains(field.Type().String(), "float") {
|
|
field.Set(reflect.New(field.Type().Elem()))
|
|
field.Elem().SetFloat(cast.Float64(val.Interface()))
|
|
} else {
|
|
field.Set(val.Addr())
|
|
}
|
|
}
|
|
} else {
|
|
convertedObject := reflect.New(field.Type())
|
|
if s, ok := val.Interface().(string); ok {
|
|
storedValue := new(any)
|
|
if s != "" {
|
|
_ = json.Unmarshal([]byte(s), storedValue)
|
|
}
|
|
convert.To(storedValue, convertedObject.Interface())
|
|
field.Set(convertedObject.Elem())
|
|
} else {
|
|
convert.To(val.Interface(), convertedObject.Interface())
|
|
}
|
|
}
|
|
} else if field.Type().AssignableTo(val.Type()) {
|
|
if val.Kind() == reflect.String {
|
|
field.Set(fixValue(col.DatabaseTypeName(), val))
|
|
} else {
|
|
field.Set(val)
|
|
}
|
|
} else if val.Type().String() == "string" {
|
|
field.Set(fixValue(col.DatabaseTypeName(), val))
|
|
} else if strings.Contains(val.Type().String(), "int") {
|
|
field.SetInt(val.Int())
|
|
} else if strings.Contains(val.Type().String(), "float") {
|
|
field.SetFloat(val.Float())
|
|
} else {
|
|
field.Set(val)
|
|
}
|
|
}
|
|
}
|
|
} else if rowType.Kind() == reflect.Map {
|
|
if resultsValue.Kind() == reflect.Slice {
|
|
data = reflect.MakeMap(rowType)
|
|
} else {
|
|
data = resultsValue
|
|
isNew = false
|
|
}
|
|
for colIndex, col := range colTypes {
|
|
valuePtr := reflect.ValueOf(scanValues[colIndex]).Elem()
|
|
if !valuePtr.IsNil() {
|
|
data.SetMapIndex(reflect.ValueOf(col.Name()), fixValue(col.DatabaseTypeName(), valuePtr.Elem()))
|
|
} else {
|
|
data.SetMapIndex(reflect.ValueOf(col.Name()), fixValue(col.DatabaseTypeName(), reflect.New(rowType.Elem()).Elem()))
|
|
}
|
|
}
|
|
} else if rowType.Kind() == reflect.Slice {
|
|
data = reflect.MakeSlice(rowType, colNum, colNum)
|
|
for colIndex, col := range colTypes {
|
|
valuePtr := reflect.ValueOf(scanValues[colIndex]).Elem()
|
|
if !valuePtr.IsNil() {
|
|
data.Index(colIndex).Set(fixValue(col.DatabaseTypeName(), valuePtr.Elem()))
|
|
} else {
|
|
data.Index(colIndex).Set(fixValue(col.DatabaseTypeName(), reflect.New(rowType.Elem()).Elem()))
|
|
}
|
|
}
|
|
} else {
|
|
valuePtr := reflect.ValueOf(scanValues[0]).Elem()
|
|
if !valuePtr.IsNil() {
|
|
data = fixValue(colTypes[0].DatabaseTypeName(), valuePtr.Elem())
|
|
}
|
|
}
|
|
|
|
if resultsValue.Kind() == reflect.Slice {
|
|
if originRowType.Kind() == reflect.Ptr {
|
|
resultsValue = reflect.Append(resultsValue, data.Addr())
|
|
} else {
|
|
resultsValue = reflect.Append(resultsValue, data)
|
|
}
|
|
} else {
|
|
resultsValue = data
|
|
break
|
|
}
|
|
}
|
|
|
|
if isNew && resultsValue.IsValid() {
|
|
reflect.ValueOf(results).Elem().Set(resultsValue)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// fixValue normalises date/time strings according to the column's database
// type name and returns every other value unchanged.
//
//   - "DATE":     "YYYY-MM-DD..." is cut to its first 10 characters.
//   - "DATETIME": "YYYY-MM-DDTHH:MM:SS[.frac][Z]" becomes
//     "YYYY-MM-DD HH:MM:SS[.frac]".
//   - "TIME":     "HH:MM:SS[.frac]..." keeps the time and at most six
//     fractional digits; anything after that is dropped.
//
// Strings that do not match the expected layout, non-string values and other
// column types are returned as v itself.
func fixValue(colType string, v reflect.Value) reflect.Value {
	if v.Kind() != reflect.String {
		return v
	}

	str := v.String()
	switch colType {
	case "DATE":
		if len(str) >= 10 && str[4] == '-' && str[7] == '-' {
			return reflect.ValueOf(str[:10])
		}
	case "DATETIME":
		if len(str) >= 19 && str[10] == 'T' && str[4] == '-' && str[7] == '-' && str[13] == ':' && str[16] == ':' {
			str = strings.TrimRight(str, "Z")
			if len(str) > 19 && str[19] == '.' {
				return reflect.ValueOf(str[:10] + " " + str[11:])
			}
			return reflect.ValueOf(str[:10] + " " + str[11:19])
		}
	case "TIME":
		// In "HH:MM:SS" the separators are at index 2 and 5.
		if len(str) >= 8 && str[2] == ':' && str[5] == ':' {
			end := 8
			if len(str) > 8 && str[8] == '.' {
				// Keep up to six fractional digits (microseconds).
				frac := 9
				for frac < len(str) && frac < 15 && str[frac] >= '0' && str[frac] <= '9' {
					frac++
				}
				if frac > 9 {
					end = frac
				}
			}
			if end == len(str) {
				// Nothing to trim: hand back the original value.
				return v
			}
			return reflect.ValueOf(str[:end])
		}
	}
	return v
}
|
|
|
|
func (r *QueryResult) getColumnTypes() ([]*sql.ColumnType, error) {
|
|
if r.rows == nil {
|
|
return nil, errors.New("not a valid query result")
|
|
}
|
|
|
|
return r.rows.ColumnTypes()
|
|
}
|
|
|
|
// makePublicVarName upper-cases the first byte of name when it is an ASCII
// lower-case letter, turning a column name into an exported identifier.
// Any other input, including the empty string, is returned unchanged.
func makePublicVarName(name string) string {
	if name == "" {
		return name
	}

	first := name[0]
	if first < 'a' || first > 'z' {
		return name
	}
	return string(first-('a'-'A')) + name[1:]
}
|
|
|
|
// makeValue returns a scan destination for a column that should be read as
// type t. The destination is always a pointer to a pointer (for example
// **int64) so that a NULL column leaves the inner pointer nil.
//
// Pointer types are unwrapped first. Numeric, bool and string kinds map to
// the matching built-in type and byte slices map to []byte; a nil type and
// every other kind fall back to a string destination.
func makeValue(t reflect.Type) any {
	if t != nil {
		for t.Kind() == reflect.Ptr {
			t = t.Elem()
		}

		switch t.Kind() {
		case reflect.Int:
			return new(*int)
		case reflect.Int8:
			return new(*int8)
		case reflect.Int16:
			return new(*int16)
		case reflect.Int32:
			return new(*int32)
		case reflect.Int64:
			return new(*int64)
		case reflect.Uint:
			return new(*uint)
		case reflect.Uint8:
			return new(*uint8)
		case reflect.Uint16:
			return new(*uint16)
		case reflect.Uint32:
			return new(*uint32)
		case reflect.Uint64:
			return new(*uint64)
		case reflect.Float32:
			return new(*float32)
		case reflect.Float64:
			return new(*float64)
		case reflect.Bool:
			return new(*bool)
		case reflect.Slice:
			if t.Elem().Kind() == reflect.Uint8 {
				return new(*[]byte)
			}
		}
	}

	return new(*string)
}
|