db/Tx.go

209 lines
5.3 KiB
Go
Raw Permalink Normal View History

package db
import (
"database/sql"
"errors"
"fmt"
"strings"
"time"
)
// Tx wraps a *sql.Tx. It adds query-error logging, slow-query logging and
// identifier quoting (see Quote/Quotes) on top of the standard transaction.
type Tx struct {
// conn is the underlying database/sql transaction; Commit/Rollback return
// "operate on a bad connection" when it is nil.
conn *sql.Tx
// db is the owning DB handle; Delete uses it to look up table metadata.
db *DB
// lastSql is the most recent statement run through this Tx. It is used as
// context in error and slow-query logs and is nil until a statement runs.
lastSql *string
// lastArgs holds the arguments that accompanied lastSql, for logging.
lastArgs []any
// Error is an exported error slot. NOTE(review): not assigned by the methods
// visible in this file; confirm where it is set.
Error error
// logger receives query-error and slow-query log entries.
logger *dbLogger
// logSlow is the slow-query threshold. Successful statements at or above it
// are logged; zero disables slow-query logging.
logSlow time.Duration
// isCommittedOrRollbacked is set after a successful Commit or Rollback so that
// later Commit/Rollback/Finish/CheckFinished calls become no-ops.
isCommittedOrRollbacked bool
// QuoteTag is passed to the quote/quotes helpers by Quote and Quotes.
QuoteTag string
}
// Quote returns text quoted with the transaction's QuoteTag by delegating to
// the package-level quote helper (used here for table and column names).
func (tx *Tx) Quote(text string) string {
	tag := tx.QuoteTag
	return quote(tag, text)
}
// Quotes passes every entry of texts, together with the transaction's
// QuoteTag, to the package-level quotes helper and returns its result.
func (tx *Tx) Quotes(texts []string) string {
	tag := tx.QuoteTag
	return quotes(tag, texts)
}
// Commit commits the underlying transaction.
//
// It returns nil without doing anything if the transaction was already
// committed or rolled back. It returns an error when there is no underlying
// connection. A commit failure is logged together with the last statement
// run on this Tx and then returned. On success the Tx is marked as finished.
func (tx *Tx) Commit() error {
	if tx.isCommittedOrRollbacked {
		return nil
	}
	if tx.conn == nil {
		return errors.New("operate on a bad connection")
	}
	if err := tx.conn.Commit(); err != nil {
		// lastSql is nil when no statement ran in this transaction; guard the
		// dereference so a failing Commit reports its error instead of panicking.
		lastSql := ""
		if tx.lastSql != nil {
			lastSql = *tx.lastSql
		}
		tx.logger.LogQueryError(err.Error(), lastSql, tx.lastArgs, -1)
		return err
	}
	tx.isCommittedOrRollbacked = true
	return nil
}
// Rollback aborts the underlying transaction.
//
// It returns nil without doing anything if the transaction was already
// committed or rolled back. It returns an error when there is no underlying
// connection. A rollback failure is logged together with the last statement
// run on this Tx and then returned. On success the Tx is marked as finished.
func (tx *Tx) Rollback() error {
	if tx.isCommittedOrRollbacked {
		return nil
	}
	if tx.conn == nil {
		return errors.New("operate on a bad connection")
	}
	if err := tx.conn.Rollback(); err != nil {
		// lastSql is nil when no statement ran in this transaction; guard the
		// dereference so a failing Rollback reports its error instead of panicking.
		lastSql := ""
		if tx.lastSql != nil {
			lastSql = *tx.lastSql
		}
		tx.logger.LogQueryError(err.Error(), lastSql, tx.lastArgs, -1)
		return err
	}
	tx.isCommittedOrRollbacked = true
	return nil
}
// Finish ends the transaction. It commits when ok is true and rolls back
// otherwise. A transaction that is already committed or rolled back is left
// alone and nil is returned.
func (tx *Tx) Finish(ok bool) error {
	switch {
	case tx.isCommittedOrRollbacked:
		return nil
	case ok:
		return tx.Commit()
	default:
		return tx.Rollback()
	}
}
// CheckFinished rolls the transaction back unless it has already been
// committed or rolled back. It is suitable as a deferred safety net; an
// already-finished transaction yields nil.
func (tx *Tx) CheckFinished() error {
	if !tx.isCommittedOrRollbacked {
		return tx.Rollback()
	}
	return nil
}
// Prepare prepares query on the transaction's connection and returns the
// resulting Stmt, which shares the transaction's logger.
//
// query becomes the transaction's last statement for later error logs. A
// preparation error is logged here and stays available on the returned
// Stmt's Error field.
func (tx *Tx) Prepare(query string) *Stmt {
	tx.lastSql = &query
	// Preparing binds no arguments; drop the previous statement's args so the
	// logs never pair this query with unrelated values.
	tx.lastArgs = nil
	r := basePrepare(nil, tx.conn, query)
	r.logger = tx.logger
	if r.Error != nil {
		tx.logger.LogQueryError(r.Error.Error(), query, tx.lastArgs, -1)
	}
	return r
}
// Exec executes query with args inside the transaction and returns its
// ExecResult, which shares the transaction's logger.
//
// query/args become the transaction's last statement. Failures are logged
// as query errors. When logSlow > 0, successful statements whose usedTime
// (milliseconds) reaches logSlow are logged as slow queries.
func (tx *Tx) Exec(query string, args ...any) *ExecResult {
	tx.lastSql = &query
	tx.lastArgs = args
	r := baseExec(nil, tx.conn, query, args...)
	r.logger = tx.logger
	// Float division: integer Duration division truncated sub-millisecond
	// thresholds to 0ms, which logged every statement as slow.
	if r.Error != nil {
		tx.logger.LogQueryError(r.Error.Error(), query, args, r.usedTime)
	} else if tx.logSlow > 0 && r.usedTime >= float32(tx.logSlow)/float32(time.Millisecond) {
		tx.logger.LogQuery(query, args, r.usedTime)
	}
	return r
}
// Query runs query with args inside the transaction and returns its
// QueryResult, which shares the transaction's logger.
//
// query/args become the transaction's last statement. Failures are logged
// as query errors. When logSlow > 0, successful queries whose usedTime
// (milliseconds) reaches logSlow are logged as slow queries.
func (tx *Tx) Query(query string, args ...any) *QueryResult {
	tx.lastSql = &query
	tx.lastArgs = args
	r := baseQuery(nil, tx.conn, query, args...)
	r.logger = tx.logger
	// Float division: integer Duration division truncated sub-millisecond
	// thresholds to 0ms, which logged every query as slow.
	if r.Error != nil {
		tx.logger.LogQueryError(r.Error.Error(), query, args, r.usedTime)
	} else if tx.logSlow > 0 && r.usedTime >= float32(tx.logSlow)/float32(time.Millisecond) {
		tx.logger.LogQuery(query, args, r.usedTime)
	}
	return r
}
// Insert builds an insert statement for table from data via MakeInsertSql
// (replace=false) and executes it inside the transaction.
//
// The generated SQL and values become the transaction's last statement.
// Failures are logged as query errors. When logSlow > 0, successful inserts
// whose usedTime (milliseconds) reaches logSlow are logged as slow queries.
func (tx *Tx) Insert(table string, data any) *ExecResult {
	query, values := tx.MakeInsertSql(table, data, false)
	tx.lastSql = &query
	tx.lastArgs = values
	r := baseExec(nil, tx.conn, query, values...)
	r.logger = tx.logger
	// Float division: integer Duration division truncated sub-millisecond
	// thresholds to 0ms, which logged every statement as slow.
	if r.Error != nil {
		tx.logger.LogQueryError(r.Error.Error(), query, values, r.usedTime)
	} else if tx.logSlow > 0 && r.usedTime >= float32(tx.logSlow)/float32(time.Millisecond) {
		tx.logger.LogQuery(query, values, r.usedTime)
	}
	return r
}
// Replace builds a replace-style insert for table from data via
// MakeInsertSql (replace=true) and executes it inside the transaction.
//
// The generated SQL and values become the transaction's last statement.
// Failures are logged as query errors. When logSlow > 0, successful
// statements whose usedTime (milliseconds) reaches logSlow are logged as
// slow queries.
func (tx *Tx) Replace(table string, data any) *ExecResult {
	query, values := tx.MakeInsertSql(table, data, true)
	tx.lastSql = &query
	tx.lastArgs = values
	r := baseExec(nil, tx.conn, query, values...)
	r.logger = tx.logger
	// Float division: integer Duration division truncated sub-millisecond
	// thresholds to 0ms, which logged every statement as slow.
	if r.Error != nil {
		tx.logger.LogQueryError(r.Error.Error(), query, values, r.usedTime)
	} else if tx.logSlow > 0 && r.usedTime >= float32(tx.logSlow)/float32(time.Millisecond) {
		tx.logger.LogQuery(query, values, r.usedTime)
	}
	return r
}
// Update builds an update statement for table from data, restricted by
// conditions/args, via MakeUpdateSql and executes it inside the transaction.
//
// The generated SQL and values become the transaction's last statement.
// Failures are logged as query errors. When logSlow > 0, successful updates
// whose usedTime (milliseconds) reaches logSlow are logged as slow queries.
func (tx *Tx) Update(table string, data any, conditions string, args ...any) *ExecResult {
	query, values := tx.MakeUpdateSql(table, data, conditions, args...)
	tx.lastSql = &query
	tx.lastArgs = values
	r := baseExec(nil, tx.conn, query, values...)
	r.logger = tx.logger
	// Float division: integer Duration division truncated sub-millisecond
	// thresholds to 0ms, which logged every statement as slow.
	if r.Error != nil {
		tx.logger.LogQueryError(r.Error.Error(), query, values, r.usedTime)
	} else if tx.logSlow > 0 && r.usedTime >= float32(tx.logSlow)/float32(time.Millisecond) {
		tx.logger.LogQuery(query, values, r.usedTime)
	}
	return r
}
// Delete removes the rows of table matching conditions, a raw SQL WHERE
// fragment bound with args. An empty conditions string deletes every row.
//
// When the table metadata reports HasShadowTable, the matching rows are first
// copied into "<table>_deleted". The copy uses the known column list, or
// "select *" when none is known. If that copy fails, the delete is skipped
// and the copy's result is returned. Each executed statement becomes the
// transaction's last statement. Failures are logged as query errors, and
// successful statements reaching logSlow (milliseconds) are logged as slow.
func (tx *Tx) Delete(table string, conditions string, args ...any) *ExecResult {
	ts := tx.db.getTable(table)
	where := ""
	if conditions != "" {
		where = " where " + conditions
	}
	// Float division: integer Duration division truncated sub-millisecond
	// thresholds to 0ms, which logged every statement as slow.
	slowThreshold := float32(tx.logSlow) / float32(time.Millisecond)

	if ts.HasShadowTable {
		// Move the doomed rows to the shadow table before deleting them.
		selectPart := " select *"
		if len(ts.Columns) > 0 {
			quotedCols := make([]string, len(ts.Columns))
			for i, c := range ts.Columns {
				quotedCols[i] = tx.Quote(c)
			}
			cols := strings.Join(quotedCols, ",")
			selectPart = fmt.Sprintf(" (%s) select %s", cols, cols)
		}
		moveQuery := fmt.Sprintf("insert into %s%s from %s%s", tx.Quote(table+"_deleted"), selectPart, tx.Quote(table), where)
		tx.lastSql = &moveQuery
		tx.lastArgs = args
		mr := baseExec(nil, tx.conn, moveQuery, args...)
		// Callers may receive this result directly, so it needs the logger too.
		mr.logger = tx.logger
		if mr.Error != nil {
			tx.logger.LogQueryError(mr.Error.Error(), moveQuery, args, mr.usedTime)
			return mr
		}
		if tx.logSlow > 0 && mr.usedTime >= slowThreshold {
			tx.logger.LogQuery(moveQuery, args, mr.usedTime)
		}
	}

	query := fmt.Sprintf("delete from %s%s", tx.Quote(table), where)
	tx.lastSql = &query
	tx.lastArgs = args
	r := baseExec(nil, tx.conn, query, args...)
	r.logger = tx.logger
	if r.Error != nil {
		tx.logger.LogQueryError(r.Error.Error(), query, args, r.usedTime)
	} else if tx.logSlow > 0 && r.usedTime >= slowThreshold {
		tx.logger.LogQuery(query, args, r.usedTime)
	}
	return r
}