package redis import ( "encoding/json" "errors" "io" "net" "reflect" "strconv" "strings" "sync" "syscall" "time" "apigo.cc/go/cast" "apigo.cc/go/config" "apigo.cc/go/log" "apigo.cc/go/safe" "github.com/gomodule/redigo/redis" ) type Redis struct { name string pool *redis.Pool Config *Config logger *log.Logger Error error subConn *redis.PubSubConn subLock sync.RWMutex subStopChan chan bool subs map[string]*SubCallbacks SubRunning bool } type SubCallbacks struct { received func([]byte) reset func() } var redisInstances = make(map[string]*Redis) var redisInstancesLock = sync.RWMutex{} func GetRedis(name string, logger *log.Logger) *Redis { if logger == nil { logger = log.DefaultLogger } redisInstancesLock.RLock() oldConn := redisInstances[name] redisInstancesLock.RUnlock() if oldConn != nil { return oldConn.CopyByLogger(logger) } redisConfigsLock.RLock() configsLen := len(redisConfigs) redisConfigsLock.RUnlock() if configsLen == 0 { _ = config.Load("redis", &redisConfigs) } fullName := name var conf *Config if strings.HasPrefix(name, "redis://") { conf = new(Config) conf.logger = logger conf.ConfigureBy(name) } else { conf = parseByName(name) } if pwd, err := confAes.Decrypt([]byte(conf.Password)); err == nil { conf.pwd = pwd } else { conf.pwd = safe.NewSafeBuf([]byte(conf.Password)) } conf.Password = "" if conf.Host == "" { conf.Host = "127.0.0.1:6379" } if conf.MaxIdle == 0 { conf.MaxIdle = 20 } if conf.MaxActive == 0 { conf.MaxActive = 100 } if conf.ConnectTimeout == 0 { conf.ConnectTimeout = 10 * time.Second } if conf.ReadTimeout == 0 { conf.ReadTimeout = 10 * time.Second } if conf.WriteTimeout == 0 { conf.WriteTimeout = 10 * time.Second } if conf.LogSlow == 0 { conf.LogSlow = 100 * time.Millisecond } rd := NewRedis(conf, logger) rd.name = fullName redisInstancesLock.Lock() redisInstances[fullName] = rd redisInstancesLock.Unlock() return rd.CopyByLogger(logger) } func NewRedis(conf *Config, logger *log.Logger) *Redis { if logger == nil { logger = log.DefaultLogger } conn := &redis.Pool{ MaxIdle: conf.MaxIdle, MaxActive: conf.MaxActive, IdleTimeout: conf.IdleTimeout, Dial: func() (redis.Conn, error) { opts := []redis.DialOption{ redis.DialConnectTimeout(conf.ConnectTimeout), redis.DialReadTimeout(conf.ReadTimeout), redis.DialWriteTimeout(conf.WriteTimeout), redis.DialDatabase(conf.DB), } if conf.pwd != nil { pwdBuf := conf.pwd.Open() defer pwdBuf.Close() opts = append(opts, redis.DialPassword(pwdBuf.String())) } c, err := redis.Dial("tcp", conf.Host, opts...) if err != nil { logger.Error(err.Error(), "type", "redis", "dsn", conf.Dsn()) return nil, err } return c, nil }, } rd := new(Redis) rd.pool = conn rd.Config = conf rd.logger = logger return rd } func (rd *Redis) CopyByLogger(logger *log.Logger) *Redis { newRedis := new(Redis) newRedis.name = rd.name newRedis.pool = rd.pool newRedis.subConn = rd.subConn newRedis.subs = rd.subs newRedis.SubRunning = rd.SubRunning newRedis.Config = rd.Config if logger == nil { newRedis.logger = log.DefaultLogger } else { newRedis.logger = logger } return newRedis } func (rd *Redis) SetLogger(logger *log.Logger) { rd.logger = logger } func (rd *Redis) GetLogger() *log.Logger { return rd.logger } func (rd *Redis) Destroy() error { if rd.pool == nil { return errors.New("operate on a bad redis pool") } err := rd.pool.Close() if err != nil { rd.logger.Error(err.Error(), "type", "redis", "dsn", rd.Config.Dsn()) } redisInstancesLock.Lock() delete(redisInstances, rd.name) redisInstancesLock.Unlock() return err } func (rd *Redis) GetPool() *redis.Pool { return rd.pool } func (rd *Redis) GetNewConnection() (redis.Conn, error) { if rd.pool == nil { return nil, errors.New("redis pool is not initialized") } c, err := rd.pool.Dial() if err == nil { err = c.Err() } if err != nil { if c != nil { _ = c.Close() } return nil, err } return c, nil } func (rd *Redis) GetConnection() (redis.Conn, error) { if rd.pool == nil { return nil, errors.New("redis pool is not initialized") } c := rd.pool.Get() err := c.Err() if err != nil { if c != nil { _ = c.Close() } return nil, err } return c, nil } func shouldRetry(err error) bool { if err == nil { return false } // 网络错误 if errors.Is(err, io.EOF) || errors.Is(err, syscall.ECONNRESET) { return true } // 超时错误 var opErr *net.OpError if errors.As(err, &opErr) && opErr.Timeout() { return true } // Redis特定可恢复错误 var redisErr redis.Error if errors.As(err, &redisErr) { errMsg := string(redisErr) if strings.HasPrefix(errMsg, "LOADING") || strings.HasPrefix(errMsg, "CLUSTERDOWN") { return true } } return false } func (rd *Redis) Do(cmd string, values ...any) *Result { if rd.pool == nil { err := errors.New("operate on a bad redis pool") rd.logger.Error(err.Error(), "type", "redis", "dsn", rd.Config.Dsn(), "cmd", cmd) return &Result{Error: err} } startTime := time.Now() r := rd.do(cmd, values...) usedTime := time.Since(startTime) if r.Error == nil { if rd.Config.LogSlow > 0 && usedTime >= rd.Config.LogSlow { rd.logger.Debug("redis slow query", "dsn", rd.Config.Dsn(), "cmd", cmd, "args", values, "used", usedTime.String()) } } else { rd.logger.Error(r.Error.Error(), "type", "redis", "dsn", rd.Config.Dsn(), "cmd", cmd, "args", values, "used", usedTime.String()) } return r } func (rd *Redis) do(cmd string, values ...any) *Result { if strings.Contains(cmd, " ") { cmdArr := cast.Split(cmd, " ") if len(cmdArr) > 1 { cmd = cmdArr[0] args := make([]any, 0, len(cmdArr)-1+len(values)) for i := 1; i < len(cmdArr); i++ { args = append(args, cmdArr[i]) } if len(values) > 0 { args = append(args, values...) } values = args } } // 自动序列化 for i, v := range values { if v == nil { continue } rv := reflect.ValueOf(v) if rv.Kind() == reflect.Ptr { if rv.IsNil() { continue } rv = rv.Elem() } kind := rv.Kind() if kind == reflect.Struct || kind == reflect.Map || (kind == reflect.Slice && rv.Type().Elem().Kind() != reflect.Uint8) { // 对于这些类型,优先使用 json.Marshal 以支持自定义 Marshaler (如 time.Time) if encoded, err := json.Marshal(v); err == nil { values[i] = encoded } } } // 从连接池获取 conn, err := rd.GetConnection() var replyData any if err == nil { replyData, err = conn.Do(cmd, values...) _ = conn.Close() } if err != nil && shouldRetry(err) { // 拿全新的连接重试(如果服务器重启可自动恢复) conn, err = rd.GetNewConnection() if err == nil { replyData, err = conn.Do(cmd, values...) _ = conn.Close() } } if err != nil { return &Result{Error: err} } r := &Result{} switch realValue := replyData.(type) { case []byte: r.bytesData = realValue case string: r.bytesData = []byte(realValue) case int64: r.bytesData = []byte(strconv.FormatInt(realValue, 10)) case []any: if cmd == "HMGET" { r.keys = make([]string, len(values)-1) for i, v := range values { if i > 0 { r.keys[i-1] = cast.String(v) } } } else if cmd == "MGET" { r.keys = make([]string, len(values)) for i, v := range values { r.keys[i] = cast.String(v) } } if cmd == "HGETALL" { r.keys = make([]string, len(realValue)/2) r.bytesDatas = make([][]byte, len(realValue)/2) i1, i2 := 0, 0 for i, v := range realValue { if v != nil { if i%2 == 0 { r.keys[i1] = cast.String(v) i1++ } else { switch subRealValue := v.(type) { case []byte: r.bytesDatas[i2] = subRealValue case string: r.bytesDatas[i2] = []byte(subRealValue) default: r.bytesDatas[i2] = []byte(cast.String(v)) } i2++ } } } } else { r.bytesDatas = make([][]byte, len(realValue)) for i, v := range realValue { if v != nil { switch subRealValue := v.(type) { case []byte: r.bytesDatas[i] = subRealValue case string: r.bytesDatas[i] = []byte(subRealValue) default: r.bytesDatas[i] = []byte(cast.String(v)) } } } } case nil: r.bytesData = []byte{} default: r.bytesData = []byte(cast.String(realValue)) } return r }