redis/redis.go

389 lines
8.2 KiB
Go
Raw Permalink Normal View History

package redis
import (
"encoding/json"
"errors"
"io"
"net"
"reflect"
"strconv"
"strings"
"sync"
"syscall"
"time"
"apigo.cc/go/cast"
"apigo.cc/go/config"
"apigo.cc/go/log"
"apigo.cc/go/safe"
"github.com/gomodule/redigo/redis"
)
// Redis is a handle to a Redis server backed by a redigo connection pool.
// Handles returned by GetRedis / CopyByLogger share the pool and Config of the
// cached instance but each carries its own logger.
type Redis struct {
// name is the key under which GetRedis cached this instance in redisInstances
// ("" for handles created directly with NewRedis).
name string
// pool is the redigo connection pool; shared by all copies of a handle.
pool *redis.Pool
// Config holds the connection settings; GetRedis moves the password out of
// Config.Password into an internal safe buffer before building the pool.
Config *Config
// logger receives dial errors, command errors and slow-query logs.
logger *log.Logger
// NOTE(review): not assigned anywhere in the code visible here — confirm whether it is still used.
Error error
// subConn is a pub/sub connection; presumably managed by subscription code
// outside this view — TODO confirm.
subConn *redis.PubSubConn
// subLock presumably guards the subscription state below; it is not copied by CopyByLogger.
subLock sync.RWMutex
// subStopChan is not copied by CopyByLogger; its use is outside this view.
subStopChan chan bool
// subs maps a subscription key (presumably the channel name — confirm) to its callbacks.
subs map[string]*SubCallbacks
// SubRunning is copied by value in CopyByLogger, so copies see a snapshot.
SubRunning bool
}
// SubCallbacks groups the handlers registered for one subscription.
type SubCallbacks struct {
// received is presumably called with each message payload — TODO confirm against the subscription loop.
received func([]byte)
// reset is presumably called when the subscription is re-established — TODO confirm.
reset func()
}
// redisInstances caches the *Redis created by GetRedis for each name or DSN;
// redisInstancesLock guards every read and write of that map.
var (
	redisInstances     = make(map[string]*Redis)
	redisInstancesLock sync.RWMutex
)
// GetRedis returns a handle to the Redis instance identified by name, creating
// and caching its connection pool on first use.
//
// name is either a "redis://..." DSN, parsed with Config.ConfigureBy, or a
// name resolved by parseByName against the "redis" configuration. That
// configuration is loaded with config.Load the first time no configurations
// are known. Unset settings fall back to defaults: host 127.0.0.1:6379,
// MaxIdle 20, MaxActive 100, 10s connect/read/write timeouts and a 100ms
// slow-query threshold.
//
// The password is decrypted with confAes when possible; otherwise it is used
// as-is. It is then stored in a safe buffer and Config.Password is cleared.
//
// Every call returns a fresh copy (see CopyByLogger). The copy shares the
// cached pool but logs through logger (log.DefaultLogger when nil). Concurrent
// first calls for the same name end up sharing a single pool.
func GetRedis(name string, logger *log.Logger) *Redis {
	if logger == nil {
		logger = log.DefaultLogger
	}

	redisInstancesLock.RLock()
	cached := redisInstances[name]
	redisInstancesLock.RUnlock()
	if cached != nil {
		return cached.CopyByLogger(logger)
	}

	// Check and load under the write lock, so concurrent first calls don't
	// both write redisConfigs.
	redisConfigsLock.Lock()
	if len(redisConfigs) == 0 {
		_ = config.Load("redis", &redisConfigs) // best effort: load errors are ignored
	}
	redisConfigsLock.Unlock()

	var conf *Config
	if strings.HasPrefix(name, "redis://") {
		conf = new(Config)
		conf.logger = logger
		conf.ConfigureBy(name)
	} else {
		// NOTE(review): if parseByName returns a *Config shared with redisConfigs,
		// the password handling below mutates that shared value — confirm.
		conf = parseByName(name)
	}

	// Prefer the AES-decrypted password; fall back to the literal value.
	if pwd, err := confAes.Decrypt([]byte(conf.Password)); err == nil {
		conf.pwd = pwd
	} else {
		conf.pwd = safe.NewSafeBuf([]byte(conf.Password))
	}
	conf.Password = "" // keep the password out of the exported field

	if conf.Host == "" {
		conf.Host = "127.0.0.1:6379"
	}
	if conf.MaxIdle == 0 {
		conf.MaxIdle = 20
	}
	if conf.MaxActive == 0 {
		conf.MaxActive = 100
	}
	if conf.ConnectTimeout == 0 {
		conf.ConnectTimeout = 10 * time.Second
	}
	if conf.ReadTimeout == 0 {
		conf.ReadTimeout = 10 * time.Second
	}
	if conf.WriteTimeout == 0 {
		conf.WriteTimeout = 10 * time.Second
	}
	if conf.LogSlow == 0 {
		conf.LogSlow = 100 * time.Millisecond
	}

	rd := NewRedis(conf, logger)
	rd.name = name

	// Another goroutine may have registered the same name while this one was
	// building its pool. Keep the first registration so only one pool exists
	// per name; otherwise the overwritten pool would be orphaned and never closed.
	redisInstancesLock.Lock()
	if existing := redisInstances[name]; existing != nil {
		redisInstancesLock.Unlock()
		_ = rd.pool.Close() // our pool was never used; discard it
		return existing.CopyByLogger(logger)
	}
	redisInstances[name] = rd
	redisInstancesLock.Unlock()
	return rd.CopyByLogger(logger)
}
// NewRedis builds a Redis handle around a new connection pool configured from
// conf. The pool's dial function connects over TCP to conf.Host using conf's
// database, timeouts and (when conf.pwd is set) password. Dial failures are
// logged to logger, which defaults to log.DefaultLogger when nil. The handle
// is not registered in the GetRedis cache and has no name.
func NewRedis(conf *Config, logger *log.Logger) *Redis {
	if logger == nil {
		logger = log.DefaultLogger
	}

	dial := func() (redis.Conn, error) {
		options := []redis.DialOption{
			redis.DialConnectTimeout(conf.ConnectTimeout),
			redis.DialReadTimeout(conf.ReadTimeout),
			redis.DialWriteTimeout(conf.WriteTimeout),
			redis.DialDatabase(conf.DB),
		}
		if conf.pwd != nil {
			// The password buffer is opened for this dial only and closed when it returns.
			plain := conf.pwd.Open()
			defer plain.Close()
			options = append(options, redis.DialPassword(plain.String()))
		}
		conn, err := redis.Dial("tcp", conf.Host, options...)
		if err != nil {
			logger.Error(err.Error(), "type", "redis", "dsn", conf.Dsn())
			return nil, err
		}
		return conn, nil
	}

	return &Redis{
		pool: &redis.Pool{
			MaxIdle:     conf.MaxIdle,
			MaxActive:   conf.MaxActive,
			IdleTimeout: conf.IdleTimeout,
			Dial:        dial,
		},
		Config: conf,
		logger: logger,
	}
}
// CopyByLogger returns a new handle that shares rd's name, pool, Config and
// subscription connection and map, but logs through logger (log.DefaultLogger
// when nil). SubRunning is copied as a snapshot. The lock, the stop channel
// and the Error field are left at their zero values.
func (rd *Redis) CopyByLogger(logger *log.Logger) *Redis {
	if logger == nil {
		logger = log.DefaultLogger
	}
	return &Redis{
		name:       rd.name,
		pool:       rd.pool,
		Config:     rd.Config,
		logger:     logger,
		subConn:    rd.subConn,
		subs:       rd.subs,
		SubRunning: rd.SubRunning,
	}
}
// SetLogger replaces the logger this handle uses for errors and slow-query
// logs. A nil logger selects log.DefaultLogger, as in NewRedis and
// CopyByLogger, so a nil argument cannot leave the handle without a logger.
func (rd *Redis) SetLogger(logger *log.Logger) {
	if logger == nil {
		logger = log.DefaultLogger
	}
	rd.logger = logger
}

// GetLogger returns the logger currently used by this handle.
func (rd *Redis) GetLogger() *log.Logger {
	return rd.logger
}
// Destroy closes the connection pool and removes this instance from the
// GetRedis cache, so a later GetRedis with the same name builds a new pool.
// Other copies made by GetRedis or CopyByLogger share the same pool, which is
// closed as well.
//
// It returns an error when the handle has no pool. Otherwise it returns the
// error from closing the pool, which is also logged.
func (rd *Redis) Destroy() error {
	if rd.pool == nil {
		return errors.New("operate on a bad redis pool")
	}

	// Unregister first so GetRedis stops handing out a pool that is being
	// closed. Only drop the entry if it is backed by this very pool: after a
	// Destroy + GetRedis cycle the name may map to a newer pool, and
	// destroying a stale handle again must not orphan that one.
	redisInstancesLock.Lock()
	if cached, ok := redisInstances[rd.name]; ok && cached.pool == rd.pool {
		delete(redisInstances, rd.name)
	}
	redisInstancesLock.Unlock()

	err := rd.pool.Close()
	if err != nil {
		rd.logger.Error(err.Error(), "type", "redis", "dsn", rd.Config.Dsn())
	}
	return err
}
// GetPool exposes the redigo pool behind this handle. The result is nil for a
// handle that was never given a pool (e.g. a zero-value Redis).
func (rd *Redis) GetPool() *redis.Pool {
	pool := rd.pool
	return pool
}
// GetNewConnection dials a brand-new connection with the pool's Dial function,
// bypassing any idle connections the pool holds. The connection is not managed
// by the pool; the caller must Close it. If dialing fails, or the new
// connection already reports an error, it is closed and the error is returned.
func (rd *Redis) GetNewConnection() (redis.Conn, error) {
	if rd.pool == nil {
		return nil, errors.New("redis pool is not initialized")
	}
	conn, err := rd.pool.Dial()
	if err != nil {
		if conn != nil {
			_ = conn.Close()
		}
		return nil, err
	}
	if err = conn.Err(); err != nil {
		_ = conn.Close()
		return nil, err
	}
	return conn, nil
}
// GetConnection takes a connection from the pool; the caller must Close it to
// hand it back. If the pool returns a connection that already reports an
// error, that connection is closed and the error is returned instead.
func (rd *Redis) GetConnection() (redis.Conn, error) {
	if rd.pool == nil {
		return nil, errors.New("redis pool is not initialized")
	}
	conn := rd.pool.Get()
	err := conn.Err()
	if err == nil {
		return conn, nil
	}
	if conn != nil {
		_ = conn.Close()
	}
	return nil, err
}
// shouldRetry reports whether err looks transient enough to retry the command
// once on a freshly dialed connection. A typical case is a restarted server
// that has left pooled connections stale.
//
// Retryable errors, matched even when wrapped:
//   - broken or closed connections: io.EOF, io.ErrUnexpectedEOF, ECONNRESET,
//     EPIPE (write to a connection the peer already closed) and net.ErrClosed;
//   - network timeouts: any net.Error whose Timeout() reports true;
//   - Redis error replies starting with LOADING or CLUSTERDOWN.
func shouldRetry(err error) bool {
	if err == nil {
		return false
	}
	// Network / connection errors.
	if errors.Is(err, io.EOF) ||
		errors.Is(err, io.ErrUnexpectedEOF) ||
		errors.Is(err, syscall.ECONNRESET) ||
		errors.Is(err, syscall.EPIPE) ||
		errors.Is(err, net.ErrClosed) {
		return true
	}
	// Timeouts; errors.As also catches a *net.OpError wrapped by another error.
	var netErr net.Error
	if errors.As(err, &netErr) && netErr.Timeout() {
		return true
	}
	// Recoverable Redis-specific states.
	var redisErr redis.Error
	if errors.As(err, &redisErr) {
		msg := string(redisErr)
		if strings.HasPrefix(msg, "LOADING") || strings.HasPrefix(msg, "CLUSTERDOWN") {
			return true
		}
	}
	return false
}
// Do runs a Redis command through the pool and returns its wrapped reply.
//
// cmd may include leading arguments separated by spaces; values follow them
// (see do). A failed command is logged as an error with its command,
// arguments and elapsed time. A successful command whose duration reaches
// Config.LogSlow (when positive) is logged at debug level as a slow query.
func (rd *Redis) Do(cmd string, values ...any) *Result {
	if rd.pool == nil {
		badPool := errors.New("operate on a bad redis pool")
		rd.logger.Error(badPool.Error(), "type", "redis", "dsn", rd.Config.Dsn(), "cmd", cmd)
		return &Result{Error: badPool}
	}

	began := time.Now()
	res := rd.do(cmd, values...)
	elapsed := time.Since(began)

	switch {
	case res.Error != nil:
		rd.logger.Error(res.Error.Error(), "type", "redis", "dsn", rd.Config.Dsn(), "cmd", cmd, "args", values, "used", elapsed.String())
	case rd.Config.LogSlow > 0 && elapsed >= rd.Config.LogSlow:
		rd.logger.Debug("redis slow query", "dsn", rd.Config.Dsn(), "cmd", cmd, "args", values, "used", elapsed.String())
	}
	return res
}
// do is the unlogged core of Do. It normalizes the arguments, runs the command
// and converts the reply into a *Result. If shouldRetry accepts the error, the
// command is retried once on a freshly dialed connection.
//
// Arguments:
//   - cmd may carry leading arguments separated by spaces ("HGET myhash");
//     they are split off and placed before values.
//   - Struct, map and non-[]byte slice arguments, or pointers to them, are
//     JSON-encoded so custom json.Marshaler implementations (e.g. time.Time)
//     are honored. If encoding fails, the value is passed through unchanged.
//   - The caller's values slice is never modified; all rewriting happens on a
//     private copy.
//
// Reply:
//   - Scalar replies go to bytesData; a nil reply becomes an empty byte slice.
//   - Array replies go to bytesDatas. For HGETALL, field names go to keys and
//     values to bytesDatas. For HMGET and MGET, keys holds the requested field
//     or key names.
//   - Command names are matched case-insensitively, as Redis does.
func (rd *Redis) do(cmd string, values ...any) *Result {
	cmdArr := cast.Split(cmd, " ")
	// Work on a private argument list. values may be the caller's own slice
	// (rd.Do("MSET", args...)), and the JSON step below replaces elements.
	args := make([]any, 0, len(cmdArr)+len(values))
	if len(cmdArr) > 1 {
		cmd = cmdArr[0]
		for _, a := range cmdArr[1:] {
			args = append(args, a)
		}
	}
	args = append(args, values...)

	// Automatic serialization.
	for i, v := range args {
		if v == nil {
			continue
		}
		rv := reflect.ValueOf(v)
		if rv.Kind() == reflect.Ptr {
			rv = rv.Elem()
		}
		kind := rv.Kind()
		if kind == reflect.Struct || kind == reflect.Map || (kind == reflect.Slice && rv.Type().Elem().Kind() != reflect.Uint8) {
			// Use json.Marshal so custom Marshalers (such as time.Time) are honored.
			if encoded, err := json.Marshal(v); err == nil {
				args[i] = encoded
			}
		}
	}

	// First attempt: a connection from the pool.
	conn, err := rd.GetConnection()
	var replyData any
	if err == nil {
		replyData, err = conn.Do(cmd, args...)
		_ = conn.Close() // hand the connection back to the pool
	}
	if err != nil && shouldRetry(err) {
		// Retry on a brand-new connection, so the call can recover automatically
		// after a server restart.
		conn, err = rd.GetNewConnection()
		if err == nil {
			replyData, err = conn.Do(cmd, args...)
			_ = conn.Close()
		}
	}
	if err != nil {
		return &Result{Error: err}
	}

	// toBytes converts one array-reply element to its byte representation.
	toBytes := func(v any) []byte {
		switch tv := v.(type) {
		case []byte:
			return tv
		case string:
			return []byte(tv)
		default:
			return []byte(cast.String(v))
		}
	}

	r := &Result{}
	switch realValue := replyData.(type) {
	case []byte:
		r.bytesData = realValue
	case string:
		r.bytesData = []byte(realValue)
	case int64:
		r.bytesData = []byte(strconv.FormatInt(realValue, 10))
	case []any:
		switch {
		case strings.EqualFold(cmd, "HMGET") && len(args) > 0:
			// args[0] is the hash key; the rest are the requested fields.
			r.keys = make([]string, len(args)-1)
			for i, v := range args[1:] {
				r.keys[i] = cast.String(v)
			}
		case strings.EqualFold(cmd, "MGET"):
			r.keys = make([]string, len(args))
			for i, v := range args {
				r.keys[i] = cast.String(v)
			}
		}
		if strings.EqualFold(cmd, "HGETALL") {
			// The reply alternates field, value, field, value, ...
			r.keys = make([]string, len(realValue)/2)
			r.bytesDatas = make([][]byte, len(realValue)/2)
			i1, i2 := 0, 0
			for i, v := range realValue {
				if v == nil {
					continue
				}
				if i%2 == 0 {
					r.keys[i1] = cast.String(v)
					i1++
				} else {
					r.bytesDatas[i2] = toBytes(v)
					i2++
				}
			}
		} else {
			r.bytesDatas = make([][]byte, len(realValue))
			for i, v := range realValue {
				if v != nil {
					r.bytesDatas[i] = toBytes(v)
				}
			}
		}
	case nil:
		r.bytesData = []byte{}
	default:
		r.bytesData = []byte(cast.String(realValue))
	}
	return r
}