391 lines
8.2 KiB
Go
391 lines
8.2 KiB
Go
|
|
package redis
|
||
|
|
|
||
|
|
import (
|
||
|
|
"encoding/json"
|
||
|
|
"errors"
|
||
|
|
"io"
|
||
|
|
"net"
|
||
|
|
"reflect"
|
||
|
|
"strconv"
|
||
|
|
"strings"
|
||
|
|
"sync"
|
||
|
|
"syscall"
|
||
|
|
"time"
|
||
|
|
|
||
|
|
"apigo.cc/go/cast"
|
||
|
|
"apigo.cc/go/config"
|
||
|
|
"apigo.cc/go/log"
|
||
|
|
"apigo.cc/go/safe"
|
||
|
|
"github.com/gomodule/redigo/redis"
|
||
|
|
)
|
||
|
|
|
||
|
|
// Redis is a client handle wrapping a redigo connection pool.
//
// GetRedis caches one instance per name and hands out shallow copies (see
// CopyByLogger). The copies share pool, Config, subConn and subs with the
// cached instance.
type Redis struct {
	// name is the key this instance is registered under in redisInstances
	// (set by GetRedis; empty for instances built directly with NewRedis).
	name string
	// pool is the redigo pool; copies made by CopyByLogger share it.
	pool *redis.Pool
	// ReadTimeout is Config.ReadTimeout expressed in milliseconds (set by NewRedis).
	ReadTimeout int
	// Config is the effective connection configuration.
	Config *Config
	// logger receives error and slow-query log entries.
	logger *log.Logger
	// NOTE(review): Error is not written by any code in this section; confirm its purpose.
	Error error
	// Pub/sub state, presumably managed by subscription code outside this view.
	// CopyByLogger copies subConn, subs and SubRunning but not subStopChan.
	subConn     *redis.PubSubConn
	subStopChan chan bool
	subs        map[string]*SubCallbacks
	SubRunning  bool
}
|
||
|
|
|
||
|
|
// SubCallbacks holds the callbacks registered for one subscription.
type SubCallbacks struct {
	// received is presumably called with each message payload — confirm in the subscribe code.
	received func([]byte)
	// reset is presumably called when the subscription is re-established — TODO confirm.
	reset func()
}
|
||
|
|
|
||
|
|
var redisInstances = make(map[string]*Redis)
|
||
|
|
var redisInstancesLock = sync.RWMutex{}
|
||
|
|
|
||
|
|
func GetRedis(name string, logger *log.Logger) *Redis {
|
||
|
|
if logger == nil {
|
||
|
|
logger = log.DefaultLogger
|
||
|
|
}
|
||
|
|
|
||
|
|
redisInstancesLock.RLock()
|
||
|
|
oldConn := redisInstances[name]
|
||
|
|
redisInstancesLock.RUnlock()
|
||
|
|
if oldConn != nil {
|
||
|
|
return oldConn.CopyByLogger(logger)
|
||
|
|
}
|
||
|
|
|
||
|
|
redisConfigsLock.RLock()
|
||
|
|
configsLen := len(redisConfigs)
|
||
|
|
redisConfigsLock.RUnlock()
|
||
|
|
|
||
|
|
if configsLen == 0 {
|
||
|
|
_ = config.Load("redis", &redisConfigs)
|
||
|
|
}
|
||
|
|
|
||
|
|
fullName := name
|
||
|
|
|
||
|
|
var conf *Config
|
||
|
|
if strings.HasPrefix(name, "redis://") {
|
||
|
|
conf = new(Config)
|
||
|
|
conf.logger = logger
|
||
|
|
conf.ConfigureBy(name)
|
||
|
|
} else {
|
||
|
|
conf = parseByName(name)
|
||
|
|
}
|
||
|
|
|
||
|
|
if pwd, err := confAes.Decrypt([]byte(conf.Password)); err == nil {
|
||
|
|
conf.pwd = pwd
|
||
|
|
} else {
|
||
|
|
conf.pwd = safe.NewSafeBuf([]byte(conf.Password))
|
||
|
|
}
|
||
|
|
conf.Password = ""
|
||
|
|
|
||
|
|
if conf.Host == "" {
|
||
|
|
conf.Host = "127.0.0.1:6379"
|
||
|
|
}
|
||
|
|
if conf.MaxIdle == 0 {
|
||
|
|
conf.MaxIdle = 20
|
||
|
|
}
|
||
|
|
if conf.MaxActive == 0 {
|
||
|
|
conf.MaxActive = 100
|
||
|
|
}
|
||
|
|
if conf.ConnectTimeout == 0 {
|
||
|
|
conf.ConnectTimeout = 10 * time.Second
|
||
|
|
}
|
||
|
|
if conf.ReadTimeout == 0 {
|
||
|
|
conf.ReadTimeout = 10 * time.Second
|
||
|
|
}
|
||
|
|
if conf.WriteTimeout == 0 {
|
||
|
|
conf.WriteTimeout = 10 * time.Second
|
||
|
|
}
|
||
|
|
if conf.LogSlow == 0 {
|
||
|
|
conf.LogSlow = 100 * time.Millisecond
|
||
|
|
}
|
||
|
|
|
||
|
|
rd := NewRedis(conf, logger)
|
||
|
|
rd.name = fullName
|
||
|
|
redisInstancesLock.Lock()
|
||
|
|
redisInstances[fullName] = rd
|
||
|
|
redisInstancesLock.Unlock()
|
||
|
|
return rd.CopyByLogger(logger)
|
||
|
|
}
|
||
|
|
|
||
|
|
func NewRedis(conf *Config, logger *log.Logger) *Redis {
|
||
|
|
if logger == nil {
|
||
|
|
logger = log.DefaultLogger
|
||
|
|
}
|
||
|
|
|
||
|
|
conn := &redis.Pool{
|
||
|
|
MaxIdle: conf.MaxIdle,
|
||
|
|
MaxActive: conf.MaxActive,
|
||
|
|
IdleTimeout: conf.IdleTimeout,
|
||
|
|
Dial: func() (redis.Conn, error) {
|
||
|
|
opts := []redis.DialOption{
|
||
|
|
redis.DialConnectTimeout(conf.ConnectTimeout),
|
||
|
|
redis.DialReadTimeout(conf.ReadTimeout),
|
||
|
|
redis.DialWriteTimeout(conf.WriteTimeout),
|
||
|
|
redis.DialDatabase(conf.DB),
|
||
|
|
}
|
||
|
|
if conf.pwd != nil {
|
||
|
|
pwdBuf := conf.pwd.Open()
|
||
|
|
defer pwdBuf.Close()
|
||
|
|
opts = append(opts, redis.DialPassword(pwdBuf.String()))
|
||
|
|
}
|
||
|
|
c, err := redis.Dial("tcp", conf.Host, opts...)
|
||
|
|
if err != nil {
|
||
|
|
logger.Error(err.Error(), "type", "redis", "dsn", conf.Dsn())
|
||
|
|
return nil, err
|
||
|
|
}
|
||
|
|
return c, nil
|
||
|
|
},
|
||
|
|
}
|
||
|
|
|
||
|
|
rd := new(Redis)
|
||
|
|
rd.ReadTimeout = int(conf.ReadTimeout / time.Millisecond)
|
||
|
|
rd.pool = conn
|
||
|
|
rd.Config = conf
|
||
|
|
rd.logger = logger
|
||
|
|
|
||
|
|
return rd
|
||
|
|
}
|
||
|
|
|
||
|
|
func (rd *Redis) CopyByLogger(logger *log.Logger) *Redis {
|
||
|
|
newRedis := new(Redis)
|
||
|
|
newRedis.name = rd.name
|
||
|
|
newRedis.ReadTimeout = rd.ReadTimeout
|
||
|
|
newRedis.pool = rd.pool
|
||
|
|
newRedis.subConn = rd.subConn
|
||
|
|
newRedis.subs = rd.subs
|
||
|
|
newRedis.SubRunning = rd.SubRunning
|
||
|
|
newRedis.Config = rd.Config
|
||
|
|
if logger == nil {
|
||
|
|
newRedis.logger = log.DefaultLogger
|
||
|
|
} else {
|
||
|
|
newRedis.logger = logger
|
||
|
|
}
|
||
|
|
return newRedis
|
||
|
|
}
|
||
|
|
|
||
|
|
func (rd *Redis) SetLogger(logger *log.Logger) {
|
||
|
|
rd.logger = logger
|
||
|
|
}
|
||
|
|
|
||
|
|
func (rd *Redis) GetLogger() *log.Logger {
|
||
|
|
return rd.logger
|
||
|
|
}
|
||
|
|
|
||
|
|
func (rd *Redis) Destroy() error {
|
||
|
|
if rd.pool == nil {
|
||
|
|
return errors.New("operate on a bad redis pool")
|
||
|
|
}
|
||
|
|
err := rd.pool.Close()
|
||
|
|
if err != nil {
|
||
|
|
rd.logger.Error(err.Error(), "type", "redis", "dsn", rd.Config.Dsn())
|
||
|
|
}
|
||
|
|
redisInstancesLock.Lock()
|
||
|
|
delete(redisInstances, rd.name)
|
||
|
|
redisInstancesLock.Unlock()
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
|
||
|
|
func (rd *Redis) GetPool() *redis.Pool {
|
||
|
|
return rd.pool
|
||
|
|
}
|
||
|
|
|
||
|
|
func (rd *Redis) GetNewConnection() (redis.Conn, error) {
|
||
|
|
if rd.pool == nil {
|
||
|
|
return nil, errors.New("redis pool is not initialized")
|
||
|
|
}
|
||
|
|
c, err := rd.pool.Dial()
|
||
|
|
if err == nil {
|
||
|
|
err = c.Err()
|
||
|
|
}
|
||
|
|
if err != nil {
|
||
|
|
if c != nil {
|
||
|
|
_ = c.Close()
|
||
|
|
}
|
||
|
|
return nil, err
|
||
|
|
}
|
||
|
|
return c, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func (rd *Redis) GetConnection() (redis.Conn, error) {
|
||
|
|
if rd.pool == nil {
|
||
|
|
return nil, errors.New("redis pool is not initialized")
|
||
|
|
}
|
||
|
|
c := rd.pool.Get()
|
||
|
|
err := c.Err()
|
||
|
|
if err != nil {
|
||
|
|
if c != nil {
|
||
|
|
_ = c.Close()
|
||
|
|
}
|
||
|
|
return nil, err
|
||
|
|
}
|
||
|
|
return c, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func shouldRetry(err error) bool {
|
||
|
|
if err == nil {
|
||
|
|
return false
|
||
|
|
}
|
||
|
|
|
||
|
|
// 网络错误
|
||
|
|
if errors.Is(err, io.EOF) || errors.Is(err, syscall.ECONNRESET) {
|
||
|
|
return true
|
||
|
|
}
|
||
|
|
|
||
|
|
// 超时错误
|
||
|
|
if opErr, ok := err.(*net.OpError); ok && opErr.Timeout() {
|
||
|
|
return true
|
||
|
|
}
|
||
|
|
|
||
|
|
// Redis特定可恢复错误
|
||
|
|
if errs, ok := err.(redis.Error); ok {
|
||
|
|
switch {
|
||
|
|
case strings.HasPrefix(string(errs), "LOADING"),
|
||
|
|
strings.HasPrefix(string(errs), "CLUSTERDOWN"):
|
||
|
|
return true
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
return false
|
||
|
|
}
|
||
|
|
|
||
|
|
func (rd *Redis) Do(cmd string, values ...any) *Result {
|
||
|
|
if rd.pool == nil {
|
||
|
|
err := errors.New("operate on a bad redis pool")
|
||
|
|
rd.logger.Error(err.Error(), "type", "redis", "dsn", rd.Config.Dsn(), "cmd", cmd)
|
||
|
|
return &Result{Error: err}
|
||
|
|
}
|
||
|
|
startTime := time.Now()
|
||
|
|
r := rd.do(cmd, values...)
|
||
|
|
usedTime := time.Since(startTime)
|
||
|
|
if r.Error == nil {
|
||
|
|
if rd.Config.LogSlow > 0 && usedTime >= rd.Config.LogSlow {
|
||
|
|
rd.logger.Debug("redis slow query", "dsn", rd.Config.Dsn(), "cmd", cmd, "args", values, "used", usedTime.String())
|
||
|
|
}
|
||
|
|
} else {
|
||
|
|
rd.logger.Error(r.Error.Error(), "type", "redis", "dsn", rd.Config.Dsn(), "cmd", cmd, "args", values, "used", usedTime.String())
|
||
|
|
}
|
||
|
|
return r
|
||
|
|
}
|
||
|
|
|
||
|
|
func (rd *Redis) do(cmd string, values ...any) *Result {
|
||
|
|
cmdArr := cast.Split(cmd, " ")
|
||
|
|
if len(cmdArr) > 1 {
|
||
|
|
cmd = cmdArr[0]
|
||
|
|
args := make([]any, 0)
|
||
|
|
for i := 1; i < len(cmdArr); i++ {
|
||
|
|
args = append(args, cmdArr[i])
|
||
|
|
}
|
||
|
|
if len(values) > 0 {
|
||
|
|
args = append(args, values...)
|
||
|
|
}
|
||
|
|
values = args
|
||
|
|
}
|
||
|
|
|
||
|
|
// 自动序列化
|
||
|
|
for i, v := range values {
|
||
|
|
if v == nil {
|
||
|
|
continue
|
||
|
|
}
|
||
|
|
rv := reflect.ValueOf(v)
|
||
|
|
if rv.Kind() == reflect.Ptr {
|
||
|
|
rv = rv.Elem()
|
||
|
|
}
|
||
|
|
kind := rv.Kind()
|
||
|
|
if kind == reflect.Struct || kind == reflect.Map || (kind == reflect.Slice && rv.Type().Elem().Kind() != reflect.Uint8) {
|
||
|
|
// 对于这些类型,优先使用 json.Marshal 以支持自定义 Marshaler (如 time.Time)
|
||
|
|
if encoded, err := json.Marshal(v); err == nil {
|
||
|
|
values[i] = encoded
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// 从连接池获取
|
||
|
|
conn, err := rd.GetConnection()
|
||
|
|
var replyData any
|
||
|
|
if err == nil {
|
||
|
|
replyData, err = conn.Do(cmd, values...)
|
||
|
|
_ = conn.Close()
|
||
|
|
}
|
||
|
|
|
||
|
|
if err != nil && shouldRetry(err) {
|
||
|
|
// 拿全新的连接重试(如果服务器重启可自动恢复)
|
||
|
|
conn, err = rd.GetNewConnection()
|
||
|
|
if err == nil {
|
||
|
|
replyData, err = conn.Do(cmd, values...)
|
||
|
|
_ = conn.Close()
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
if err != nil {
|
||
|
|
return &Result{Error: err}
|
||
|
|
}
|
||
|
|
|
||
|
|
r := &Result{}
|
||
|
|
switch realValue := replyData.(type) {
|
||
|
|
case []byte:
|
||
|
|
r.bytesData = realValue
|
||
|
|
case string:
|
||
|
|
r.bytesData = []byte(realValue)
|
||
|
|
case int64:
|
||
|
|
r.bytesData = []byte(strconv.FormatInt(realValue, 10))
|
||
|
|
case []any:
|
||
|
|
if cmd == "HMGET" {
|
||
|
|
r.keys = make([]string, len(values)-1)
|
||
|
|
for i, v := range values {
|
||
|
|
if i > 0 {
|
||
|
|
r.keys[i-1] = cast.String(v)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
} else if cmd == "MGET" {
|
||
|
|
r.keys = make([]string, len(values))
|
||
|
|
for i, v := range values {
|
||
|
|
r.keys[i] = cast.String(v)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
if cmd == "HGETALL" {
|
||
|
|
r.keys = make([]string, len(realValue)/2)
|
||
|
|
r.bytesDatas = make([][]byte, len(realValue)/2)
|
||
|
|
i1, i2 := 0, 0
|
||
|
|
for i, v := range realValue {
|
||
|
|
if v != nil {
|
||
|
|
if i%2 == 0 {
|
||
|
|
r.keys[i1] = cast.String(v)
|
||
|
|
i1++
|
||
|
|
} else {
|
||
|
|
switch subRealValue := v.(type) {
|
||
|
|
case []byte:
|
||
|
|
r.bytesDatas[i2] = subRealValue
|
||
|
|
case string:
|
||
|
|
r.bytesDatas[i2] = []byte(subRealValue)
|
||
|
|
default:
|
||
|
|
r.bytesDatas[i2] = []byte(cast.String(v))
|
||
|
|
}
|
||
|
|
i2++
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
} else {
|
||
|
|
r.bytesDatas = make([][]byte, len(realValue))
|
||
|
|
for i, v := range realValue {
|
||
|
|
if v != nil {
|
||
|
|
switch subRealValue := v.(type) {
|
||
|
|
case []byte:
|
||
|
|
r.bytesDatas[i] = subRealValue
|
||
|
|
case string:
|
||
|
|
r.bytesDatas[i] = []byte(subRealValue)
|
||
|
|
default:
|
||
|
|
r.bytesDatas[i] = []byte(cast.String(v))
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
case nil:
|
||
|
|
r.bytesData = []byte{}
|
||
|
|
default:
|
||
|
|
r.bytesData = []byte(cast.String(realValue))
|
||
|
|
}
|
||
|
|
return r
|
||
|
|
}
|