package db import ( "database/sql" "errors" "fmt" "time" ) type Tx struct { conn *sql.Tx lastSql *string lastArgs []any Error error logger *dbLogger logSlow time.Duration isCommittedOrRollbacked bool QuoteTag string } func (tx *Tx) Quote(text string) string { return quote(tx.QuoteTag, text) } func (tx *Tx) Quotes(texts []string) string { return quotes(tx.QuoteTag, texts) } func (tx *Tx) Commit() error { if tx.isCommittedOrRollbacked { return nil } if tx.conn == nil { return errors.New("operate on a bad connection") } err := tx.conn.Commit() if err != nil { tx.logger.LogQueryError(err.Error(), *tx.lastSql, tx.lastArgs, -1) } else { tx.isCommittedOrRollbacked = true } return err } func (tx *Tx) Rollback() error { if tx.isCommittedOrRollbacked { return nil } if tx.conn == nil { return errors.New("operate on a bad connection") } err := tx.conn.Rollback() if err != nil { tx.logger.LogQueryError(err.Error(), *tx.lastSql, tx.lastArgs, -1) } else { tx.isCommittedOrRollbacked = true } return err } func (tx *Tx) Finish(ok bool) error { if tx.isCommittedOrRollbacked { return nil } if ok { return tx.Commit() } return tx.Rollback() } func (tx *Tx) CheckFinished() error { if tx.isCommittedOrRollbacked { return nil } return tx.Rollback() } func (tx *Tx) Prepare(query string) *Stmt { tx.lastSql = &query r := basePrepare(nil, tx.conn, query) r.logger = tx.logger if r.Error != nil { tx.logger.LogQueryError(r.Error.Error(), *tx.lastSql, tx.lastArgs, -1) } return r } func (tx *Tx) Exec(query string, args ...any) *ExecResult { tx.lastSql = &query tx.lastArgs = args r := baseExec(nil, tx.conn, query, args...) r.logger = tx.logger if r.Error != nil { tx.logger.LogQueryError(r.Error.Error(), *tx.lastSql, tx.lastArgs, r.usedTime) } else { if tx.logSlow > 0 && r.usedTime >= float32(tx.logSlow/time.Millisecond) { tx.logger.LogQuery(*tx.lastSql, tx.lastArgs, r.usedTime) } } return r } func (tx *Tx) Query(query string, args ...any) *QueryResult { tx.lastSql = &query tx.lastArgs = args r := baseQuery(nil, tx.conn, query, args...) r.logger = tx.logger if r.Error != nil { tx.logger.LogQueryError(r.Error.Error(), *tx.lastSql, tx.lastArgs, r.usedTime) } else { if tx.logSlow > 0 && r.usedTime >= float32(tx.logSlow/time.Millisecond) { tx.logger.LogQuery(*tx.lastSql, tx.lastArgs, r.usedTime) } } return r } func (tx *Tx) Insert(table string, data any) *ExecResult { query, values := tx.MakeInsertSql(table, data, false) tx.lastSql = &query tx.lastArgs = values r := baseExec(nil, tx.conn, query, values...) r.logger = tx.logger if r.Error != nil { tx.logger.LogQueryError(r.Error.Error(), *tx.lastSql, tx.lastArgs, r.usedTime) } else { if tx.logSlow > 0 && r.usedTime >= float32(tx.logSlow/time.Millisecond) { tx.logger.LogQuery(*tx.lastSql, tx.lastArgs, r.usedTime) } } return r } func (tx *Tx) Replace(table string, data any) *ExecResult { query, values := tx.MakeInsertSql(table, data, true) tx.lastSql = &query tx.lastArgs = values r := baseExec(nil, tx.conn, query, values...) r.logger = tx.logger if r.Error != nil { tx.logger.LogQueryError(r.Error.Error(), *tx.lastSql, tx.lastArgs, r.usedTime) } else { if tx.logSlow > 0 && r.usedTime >= float32(tx.logSlow/time.Millisecond) { tx.logger.LogQuery(*tx.lastSql, tx.lastArgs, r.usedTime) } } return r } func (tx *Tx) Update(table string, data any, conditions string, args ...any) *ExecResult { query, values := tx.MakeUpdateSql(table, data, conditions, args...) tx.lastSql = &query tx.lastArgs = values r := baseExec(nil, tx.conn, query, values...) r.logger = tx.logger if r.Error != nil { tx.logger.LogQueryError(r.Error.Error(), *tx.lastSql, tx.lastArgs, r.usedTime) } else { if tx.logSlow > 0 && r.usedTime >= float32(tx.logSlow/time.Millisecond) { tx.logger.LogQuery(*tx.lastSql, tx.lastArgs, r.usedTime) } } return r } func (tx *Tx) Delete(table string, conditions string, args ...any) *ExecResult { if conditions != "" { conditions = " where " + conditions } query := fmt.Sprintf("delete from %s%s", tx.Quote(table), conditions) tx.lastSql = &query tx.lastArgs = args r := baseExec(nil, tx.conn, query, args...) r.logger = tx.logger if r.Error != nil { tx.logger.LogQueryError(r.Error.Error(), *tx.lastSql, tx.lastArgs, r.usedTime) } else { if tx.logSlow > 0 && r.usedTime >= float32(tx.logSlow/time.Millisecond) { tx.logger.LogQuery(*tx.lastSql, tx.lastArgs, r.usedTime) } } return r }