diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..155c599 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,16 @@ +# CHANGELOG - redis + +## v1.0.0 (2026-05-03) +- **Repo Migration**: 从 `@ssgo/redis` 迁移至 `apigo.cc/go/redis`。 +- **Standard Realignment**: + - 依赖全面切换至 `apigo.cc/go/*` 标准库。 + - 适配 Go 1.25.0。 +- **Feature Enhancements**: + - **Generics**: 为 `Result` 引入泛型 `To[T]` 支持,消除类型断言摩擦。 + - **Memory Safety**: 集成 `go/safe` 对 Redis 密码进行实时加解密与内存锁定,防止内存泄漏敏感信息。 + - **Auto-Serialization**: 优化了 `Do` 方法,支持对 Struct/Map/Slice 自动进行 JSON 序列化(优先支持 Marshaler 接口)。 + - **Distributed ID**: 深度集成 `go/id` 核心,提供更高性能的 Redis 序列号预取机制。 +- **Refactoring**: + - 移除了冗余的 `interface{}`,全面改用 `any`。 + - 规范化了 API 命名,统一使用 `GetUpperName` 进行 Struct 字段映射。 + - 增强了连接重试机制,支持对网络波动和服务器重启的自动恢复。 diff --git a/README.md b/README.md index 3cf1b4b..34dab72 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,56 @@ -# redis +# Redis 模块 (redis) -高性能 Redis 客户端,集成分布式 ID 生成与发布订阅 \ No newline at end of file +`redis` 模块提供了一个高性能、内存安全且易于使用的 Redis 客户端,集成了分布式 ID 生成和发布订阅功能。 + +## 设计特性 +- **消除摩擦**: 自动处理连接池、重试、以及复杂类型的 JSON 序列化。 +- **内存安全**: 对 Redis 密码进行内存保护,防止内存 Dump 泄露敏感信息。 +- **泛型支持**: 结果集支持泛型绑定 `To[T]`。 +- **分布式 ID**: 内置基于 Redis 的高性能分布式 ID 生成器。 + +## API 指南 + +### 基础连接 +- `GetRedis(name string, logger *log.Logger) *Redis`: 获取或创建一个 Redis 实例(支持 DSN 或配置文件名)。 +- `NewRedis(conf *Config, logger *log.Logger) *Redis`: 使用指定配置创建 Redis 实例。 + +### 核心操作 +- `Do(cmd string, values ...any) *Result`: 执行原生 Redis 命令。 +- 支持大部分标准命令:`GET`, `SET`, `HGET`, `HSET`, `LPUSH`, `LPOP`, `SADD`, `ZRANGE`, `PUBLISH` 等。 + +### 结果处理 (Result) +- `Int()`, `String()`, `Bool()`, `Float()`: 基础类型转换。 +- `To(target any)`: 将结果反序列化到结构体或 Map。 +- `To[T](rs *Result) T`: **泛型版本**,直接返回目标类型对象。 +- `Ints()`, `Strings()`, `Results()`: 处理数组结果。 +- `ResultMap()`: 处理 Hash 或键值对结果。 + +### 发布订阅 +- `Start()`: 开启订阅监听协程。 +- `Subscribe(channel string, reset func(), received func([]byte))`: 订阅频道。 +- `Unsubscribe(channel string)`: 取消订阅。 +- `Stop()`: 停止所有订阅。 + +### 分布式 ID (IdMaker) +- `NewIdMaker(rd *Redis) *IdMaker`: 创建分布式 ID 生成器。 +- `Get(size int)`: 获取指定长度的唯一 ID。 +- `GetForMysql(size int)`: 获取针对 MySQL 优化的唯一 ID。 + +## 示例 + +```go +import "apigo.cc/go/redis" + +// 获取实例 (自动从 redis.json 加载) +rd := redis.GetRedis("test", nil) + +// 设置结构体 (自动 JSON 序列化) +rd.SET("user:1", User{Name: "Sam", Age: 18}) + +// 获取并自动绑定 (泛型) +user := redis.To[User](rd.GET("user:1")) + +// 分布式 ID +maker := redis.NewIdMaker(rd) +id := maker.Get(10) +``` diff --git a/commands.go b/commands.go new file mode 100644 index 0000000..f49318b --- /dev/null +++ b/commands.go @@ -0,0 +1,165 @@ +package redis + +func stringsToAnys(in []string) []any { + a := make([]any, len(in)) + for i, v := range in { + a[i] = v + } + return a +} + +func (rd *Redis) DEL(keys ...string) int { + return rd.Do("DEL", stringsToAnys(keys)...).Int() +} +func (rd *Redis) EXISTS(key string) bool { + return rd.Do("EXISTS " + key).Bool() +} +func (rd *Redis) EXPIRE(key string, second int) bool { + if second > 315360000 { + return rd.Do("EXPIREAT "+key, second).Bool() + } else { + return rd.Do("EXPIRE "+key, second).Bool() + } +} +func (rd *Redis) KEYS(patten string) []string { + return rd.Do("KEYS " + patten).Strings() +} + +func (rd *Redis) GET(key string) *Result { + return rd.Do("GET " + key) +} +func (rd *Redis) SET(key string, value any) bool { + return rd.Do("SET "+key, value).Bool() +} +func (rd *Redis) SETEX(key string, seconds int, value any) bool { + return rd.Do("SETEX "+key, seconds, value).Bool() +} +func (rd *Redis) SETNX(key string, value any) bool { + return rd.Do("SETNX "+key, value).Bool() +} +func (rd *Redis) GETSET(key string, value any) *Result { + return rd.Do("GETSET "+key, value) +} + +func (rd *Redis) INCR(key string) int64 { + return rd.Do("INCR " + key).Int64() +} +func (rd *Redis) INCRBY(key string, delta int64) int64 { + return rd.Do("INCRBY "+key, delta).Int64() +} +func (rd *Redis) DECR(key string, delta int64) int64 { + return rd.Do("DECR "+key, delta).Int64() +} +func (rd *Redis) DECRBY(key string, delta int64) int64 { + return rd.Do("DECRBY "+key, delta).Int64() +} + +func (rd *Redis) MGET(keys ...string) []Result { + return rd.Do("MGET", stringsToAnys(keys)...).Results() +} +func (rd *Redis) MSET(keyAndValues ...any) bool { + return rd.Do("MSET", keyAndValues...).Bool() +} + +func (rd *Redis) HGET(key, field string) *Result { + return rd.Do("HGET "+key, field) +} +func (rd *Redis) HSET(key, field string, value any) bool { + return rd.Do("HSET "+key, field, value).Error == nil +} +func (rd *Redis) HSETNX(key, field string, value any) bool { + return rd.Do("HSETNX "+key, field, value).Error == nil +} +func (rd *Redis) HMGET(key string, fields ...string) []Result { + return rd.Do("HMGET", append(append([]any{}, key), stringsToAnys(fields)...)...).Results() +} +func (rd *Redis) HGETALL(key string) map[string]*Result { + return rd.Do("HGETALL " + key).ResultMap() +} +func (rd *Redis) HMSET(key string, fieldAndValues ...any) bool { + return rd.Do("HMSET", append(append([]any{}, key), fieldAndValues...)...).Bool() +} +func (rd *Redis) HKEYS(key string) []string { + return rd.Do("HKEYS " + key).Strings() +} +func (rd *Redis) HLEN(key string) int { + return rd.Do("HLEN " + key).Int() +} +func (rd *Redis) HDEL(key string, fields ...string) int { + return rd.Do("HDEL", append(append([]any{}, key), stringsToAnys(fields)...)...).Int() +} +func (rd *Redis) HEXISTS(key, field string) bool { + return rd.Do("HEXISTS "+key, field).Bool() +} +func (rd *Redis) HINCR(key, field string) int64 { + return rd.Do("HINCRBY "+key, field, 1).Int64() +} +func (rd *Redis) HINCRBY(key, field string, delta int64) int64 { + return rd.Do("HINCRBY "+key, field, delta).Int64() +} +func (rd *Redis) HDECR(key, field string) int64 { + return rd.Do("HDECRBY "+key, field, 1).Int64() +} +func (rd *Redis) HDECRBY(key, field string, delta int64) int64 { + return rd.Do("HDECRBY "+key, field, delta).Int64() +} + +func (rd *Redis) LPUSH(key string, values ...string) int { + return rd.Do("LPUSH", append(append([]any{}, key), stringsToAnys(values)...)...).Int() +} +func (rd *Redis) RPUSH(key string, values ...string) int { + return rd.Do("RPUSH", append(append([]any{}, key), stringsToAnys(values)...)...).Int() +} +func (rd *Redis) LPOP(key string) *Result { + return rd.Do("LPOP " + key) +} +func (rd *Redis) RPOP(key string) *Result { + return rd.Do("RPOP " + key) +} +func (rd *Redis) LLEN(key string) int { + return rd.Do("LLEN " + key).Int() +} +func (rd *Redis) LRANGE(key string, start, stop int) []Result { + return rd.Do("LRANGE "+key, start, stop).Results() +} + +func (rd *Redis) SADD(key string, values ...any) int { + return rd.Do("SADD", append([]any{key}, values...)...).Int() +} +func (rd *Redis) SREM(key string, values ...any) int { + return rd.Do("SREM", append([]any{key}, values...)...).Int() +} +func (rd *Redis) SCARD(key string) int { + return rd.Do("SCARD " + key).Int() +} +func (rd *Redis) SMEMBERS(key string) []Result { + return rd.Do("SMEMBERS " + key).Results() +} +func (rd *Redis) SISMEMBER(key string, value any) bool { + return rd.Do("SISMEMBER "+key, value).Bool() +} + +func (rd *Redis) ZADD(key string, score float64, member any) bool { + return rd.Do("ZADD "+key, score, member).Bool() +} +func (rd *Redis) ZREM(key string, members ...any) int { + return rd.Do("ZREM", append([]any{key}, members...)...).Int() +} +func (rd *Redis) ZCARD(key string) int { + return rd.Do("ZCARD " + key).Int() +} +func (rd *Redis) ZRANGE(key string, start, stop int) []Result { + return rd.Do("ZRANGE "+key, start, stop).Results() +} +func (rd *Redis) ZREVRANGE(key string, start, stop int) []Result { + return rd.Do("ZREVRANGE "+key, start, stop).Results() +} +func (rd *Redis) ZRANK(key string, member any) int { + return rd.Do("ZRANK "+key, member).Int() +} +func (rd *Redis) ZREVRANK(key string, member any) int { + return rd.Do("ZREVRANK "+key, member).Int() +} +func (rd *Redis) ZSCORE(key string, member any) float64 { + return rd.Do("ZSCORE "+key, member).Float64() +} diff --git a/config.go b/config.go new file mode 100644 index 0000000..f980819 --- /dev/null +++ b/config.go @@ -0,0 +1,139 @@ +package redis + +import ( + "fmt" + "net/url" + "strconv" + "sync" + "time" + + "apigo.cc/go/cast" + "apigo.cc/go/crypto" + "apigo.cc/go/log" + "apigo.cc/go/safe" +) + +type Config struct { + Host string + Password string + DB int + MaxActive int + MaxIdle int + IdleTimeout time.Duration + ConnectTimeout time.Duration + ReadTimeout time.Duration + WriteTimeout time.Duration + LogSlow time.Duration + logger *log.Logger + pwd *safe.SafeBuf +} + +var redisConfigs = make(map[string]*Config) +var redisConfigsLock = sync.RWMutex{} + +var confAes, _ = crypto.NewAESGCMAndEraseKey([]byte("?GQ$0K0GgLdO=f+~L68PLm$uhKr4'=tV"), []byte("VFs7@sK61cj^f?HZ")) +var keysOnce = sync.Once{} + +func SetEncryptKeys(key, iv []byte) { + keysOnce.Do(func() { + confAes.Close() + confAes, _ = crypto.NewAESGCMAndEraseKey(key, iv) + }) +} + +func (conf *Config) ConfigureBy(setting string) { + redisConfigsLock.Lock() + redisConfigs[setting] = conf + redisConfigsLock.Unlock() + + urlInfo, err := url.Parse(setting) + if err != nil { + conf.logger.Error(err.Error(), "url", setting) + return + } + if urlInfo.Scheme != "redis" { + conf.logger.Error("unsupported scheme", "url", setting) + return + } + + conf.Host = urlInfo.Host + + dbStr := urlInfo.Query().Get("database") + if dbStr == "" && len(urlInfo.Path) > 1 { + dbStr = urlInfo.Path[1:] + } + if len(dbStr) > 0 { + db, err := strconv.Atoi(dbStr) + if err != nil { + conf.logger.Error(err.Error(), "url", setting) + } + if err == nil && db >= 0 && db <= 15 { + conf.DB = db + } + } + + conf.Password, _ = urlInfo.User.Password() + conf.LogSlow = cast.Duration(urlInfo.Query().Get("logSlow")) + conf.MaxIdle = cast.Int(urlInfo.Query().Get("maxIdle")) + conf.MaxActive = cast.Int(urlInfo.Query().Get("maxActive")) + conf.ConnectTimeout = cast.Duration(urlInfo.Query().Get("connectTimeout")) + conf.ReadTimeout = cast.Duration(urlInfo.Query().Get("readTimeout")) + conf.WriteTimeout = cast.Duration(urlInfo.Query().Get("writeTimeout")) + conf.IdleTimeout = cast.Duration(urlInfo.Query().Get("idleTimeout")) +} + +func (conf *Config) Dsn() string { + return fmt.Sprintf("redis://:****@%s/%d?timeout=%s&logSlow=%s", conf.Host, conf.DB, conf.ConnectTimeout, conf.LogSlow) +} + +func parseByName(name string) *Config { + // config name support Host:Port + args := cast.Split(name, ":") + db := 0 + if len(args) > 1 { + arg1, err := strconv.Atoi(args[1]) + if err == nil && arg1 >= 0 && arg1 <= 15 { + name = args[0] + db = arg1 + } + } + + redisConfigsLock.RLock() + conf := redisConfigs[name] + redisConfigsLock.RUnlock() + + if conf == nil { + conf = new(Config) + redisConfigsLock.Lock() + redisConfigs[name] = conf + redisConfigsLock.Unlock() + + if len(args) > 1 { + arg1, err := strconv.Atoi(args[1]) + if err == nil && arg1 >= 0 && arg1 <= 15 { + conf.DB = arg1 + } else { + conf.Host = args[0] + ":" + args[1] + } + } + } + + for i := 2; i < len(args); i++ { + arg2, err := strconv.Atoi(args[i]) + if err == nil { + if arg2 >= 0 && arg2 <= 15 { + conf.DB = arg2 + } else { + conf.Password = args[i] + } + } else { + conf.Password = args[i] + } + } + + if conf.DB == 0 && db > 0 && db <= 15 { + conf.DB = db + } + + return conf +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..eda9ae6 --- /dev/null +++ b/go.mod @@ -0,0 +1,24 @@ +module apigo.cc/go/redis + +go 1.25.0 + +require ( + apigo.cc/go/cast v1.1.1 + apigo.cc/go/config v1.0.4 + apigo.cc/go/crypto v1.0.4 + apigo.cc/go/id v1.0.4 + apigo.cc/go/log v1.0.0 + apigo.cc/go/safe v1.0.4 + github.com/gomodule/redigo v1.9.3 +) + +require ( + apigo.cc/go/convert v1.0.4 // indirect + apigo.cc/go/encoding v1.0.4 // indirect + apigo.cc/go/file v1.0.4 // indirect + apigo.cc/go/rand v1.0.4 // indirect + apigo.cc/go/shell v1.0.4 // indirect + golang.org/x/crypto v0.50.0 // indirect + golang.org/x/sys v0.43.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..302113e --- /dev/null +++ b/go.sum @@ -0,0 +1,38 @@ +apigo.cc/go/cast v1.1.1 h1:+5pluN8g1RK2J4byr2xkfOmEdKSmy1PByOqDOHtt/Ns= +apigo.cc/go/cast v1.1.1/go.mod h1:vh9ZqISCmTUiyinkNMI/s4f045fRlDK3xC+nPWQYBzI= +apigo.cc/go/config v1.0.4 h1:WG9zrQkqfFPkrKIL7RNvvAbbkuUBt1Av11ZP/aIfldM= +apigo.cc/go/config v1.0.4/go.mod h1:obryzJiK6j7lQex/58d5eWYOGx5O5IABguqNWxyyXJo= +apigo.cc/go/convert v1.0.4 h1:5+qPjC3dlPB59GnWZRlmthxcaXQtKvN+iOuiLdJ1GvQ= +apigo.cc/go/convert v1.0.4/go.mod h1:Hp+geeSyhqg/zwIKPOrDoceIREzcwM14t1I5q/dtbfU= +apigo.cc/go/crypto v1.0.4 h1:VPUyHCH2N3LLEgdpwUc+DQssNHzLlxVzLNRa0Jm6O4o= +apigo.cc/go/crypto v1.0.4/go.mod h1:5sI8BLw6YHZfDReYwCO3TFD2LKm36HMdLg1S5oPv/QU= +apigo.cc/go/encoding v1.0.4 h1:aezB0J/qFuHs6iXkbtuJP5JIHUtmjsr5SFb0NNvbObY= +apigo.cc/go/encoding v1.0.4/go.mod h1:V5CgT7rBbCxy+uCU20q0ptcNNRSgMtpA8cNOs6r8IeI= +apigo.cc/go/file v1.0.4 h1:qCKegV7OYh7r0qc3jZjGA/aKh0vIHgmr1OEbhfEmGX8= +apigo.cc/go/file v1.0.4/go.mod h1:C9gNo7386iA21OiBmuWh6CznKWlVBDFkhE4f0H0Susg= +apigo.cc/go/id v1.0.4 h1:w+JSdeVit52iefIUolrh1qLEZS9XqHNKr1UygFcgv+s= +apigo.cc/go/id v1.0.4/go.mod h1:kg7QuceAKtGNzGWt0+pIIh8Qom1eMSWGb8+0Yhi/QVY= +apigo.cc/go/log v1.0.0 h1:lI1NGTSS+Jm12G8BD7ZJO4/hrkfuLTu5O8z36GD8GpU= +apigo.cc/go/log v1.0.0/go.mod h1:tvPgFpebY9Wf/DlqMHZ0ZjxDp9AaQTywOQKvtBaNqNo= +apigo.cc/go/rand v1.0.4 h1:we070eWSL0dB8NEMaWjXj43+EekXQTm/h0kKpZ/frqw= +apigo.cc/go/rand v1.0.4/go.mod h1:mZ/4Soa3bk+XvDaqPWJuUe1bfEi4eThBj1XmEAuYxsk= +apigo.cc/go/safe v1.0.4 h1:07pRSdEHprF/2v6SsqAjICYFoeLcqjjvHGEdh6Dzrzg= +apigo.cc/go/safe v1.0.4/go.mod h1:o568sHS5rTRSVPmhxWod0tGdc+8l1KjidsNY1/OVZr0= +apigo.cc/go/shell v1.0.4 h1:EL9zjI39YBe1h+kRYQeAi/8zVGHe5W198DYYN7cENiY= +apigo.cc/go/shell v1.0.4/go.mod h1:N2gDkgK4tJ9TadD60/+gAGuWxyVAWHs5YPBmytw6ELA= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/gomodule/redigo v1.9.3 h1:dNPSXeXv6HCq2jdyWfjgmhBdqnR6PRO3m/G05nvpPC8= +github.com/gomodule/redigo v1.9.3/go.mod h1:KsU3hiK/Ay8U42qpaJk+kuNa3C+spxapWpM+ywhcgtw= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI= +golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q= +golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI= +golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/id.go b/id.go new file mode 100644 index 0000000..425de2b --- /dev/null +++ b/id.go @@ -0,0 +1,62 @@ +package redis + +import ( + "fmt" + "sync" + + "apigo.cc/go/id" +) + +type IdMaker struct { + rd *Redis + secCurrent uint64 + secIndexMax uint64 + secIndexNext uint64 + lock sync.Mutex + maker *id.IDMaker +} + +func NewIdMaker(rd *Redis) *IdMaker { + im := &IdMaker{rd: rd} + im.maker = id.NewIDMaker(im.makeSecIndex) + return im +} + +func (im *IdMaker) makeSecIndex(sec uint64) uint64 { + im.lock.Lock() + defer im.lock.Unlock() + + if im.secCurrent == sec && im.secIndexNext <= im.secIndexMax { + idx := im.secIndexNext + im.secIndexNext++ + return idx + } + + im.secCurrent = sec + key := fmt.Sprintf("_SecIdx_%d", sec) + // 每次从 Redis 预取 100 个序列号 + max := uint64(im.rd.INCRBY(key, 100)) + if max < 100 { + return 0 + } + im.secIndexMax = max + im.secIndexNext = max - 99 + idx := im.secIndexNext + im.secIndexNext++ + if max <= 100 { + im.rd.EXPIRE(key, 10) + } + return idx +} + +func (im *IdMaker) Get(size int) string { + return im.maker.Get(size) +} + +func (im *IdMaker) GetForMysql(size int) string { + return im.maker.GetForMysql(size) +} + +func (im *IdMaker) GetForPostgreSQL(size int) string { + return im.maker.GetForPostgreSQL(size) +} diff --git a/redis.go b/redis.go new file mode 100644 index 0000000..0776d4d --- /dev/null +++ b/redis.go @@ -0,0 +1,390 @@ +package redis + +import ( + "encoding/json" + "errors" + "io" + "net" + "reflect" + "strconv" + "strings" + "sync" + "syscall" + "time" + + "apigo.cc/go/cast" + "apigo.cc/go/config" + "apigo.cc/go/log" + "apigo.cc/go/safe" + "github.com/gomodule/redigo/redis" +) + +type Redis struct { + name string + pool *redis.Pool + ReadTimeout int + Config *Config + logger *log.Logger + Error error + subConn *redis.PubSubConn + subStopChan chan bool + subs map[string]*SubCallbacks + SubRunning bool +} + +type SubCallbacks struct { + received func([]byte) + reset func() +} + +var redisInstances = make(map[string]*Redis) +var redisInstancesLock = sync.RWMutex{} + +func GetRedis(name string, logger *log.Logger) *Redis { + if logger == nil { + logger = log.DefaultLogger + } + + redisInstancesLock.RLock() + oldConn := redisInstances[name] + redisInstancesLock.RUnlock() + if oldConn != nil { + return oldConn.CopyByLogger(logger) + } + + redisConfigsLock.RLock() + configsLen := len(redisConfigs) + redisConfigsLock.RUnlock() + + if configsLen == 0 { + _ = config.Load("redis", &redisConfigs) + } + + fullName := name + + var conf *Config + if strings.HasPrefix(name, "redis://") { + conf = new(Config) + conf.logger = logger + conf.ConfigureBy(name) + } else { + conf = parseByName(name) + } + + if pwd, err := confAes.Decrypt([]byte(conf.Password)); err == nil { + conf.pwd = pwd + } else { + conf.pwd = safe.NewSafeBuf([]byte(conf.Password)) + } + conf.Password = "" + + if conf.Host == "" { + conf.Host = "127.0.0.1:6379" + } + if conf.MaxIdle == 0 { + conf.MaxIdle = 20 + } + if conf.MaxActive == 0 { + conf.MaxActive = 100 + } + if conf.ConnectTimeout == 0 { + conf.ConnectTimeout = 10 * time.Second + } + if conf.ReadTimeout == 0 { + conf.ReadTimeout = 10 * time.Second + } + if conf.WriteTimeout == 0 { + conf.WriteTimeout = 10 * time.Second + } + if conf.LogSlow == 0 { + conf.LogSlow = 100 * time.Millisecond + } + + rd := NewRedis(conf, logger) + rd.name = fullName + redisInstancesLock.Lock() + redisInstances[fullName] = rd + redisInstancesLock.Unlock() + return rd.CopyByLogger(logger) +} + +func NewRedis(conf *Config, logger *log.Logger) *Redis { + if logger == nil { + logger = log.DefaultLogger + } + + conn := &redis.Pool{ + MaxIdle: conf.MaxIdle, + MaxActive: conf.MaxActive, + IdleTimeout: conf.IdleTimeout, + Dial: func() (redis.Conn, error) { + opts := []redis.DialOption{ + redis.DialConnectTimeout(conf.ConnectTimeout), + redis.DialReadTimeout(conf.ReadTimeout), + redis.DialWriteTimeout(conf.WriteTimeout), + redis.DialDatabase(conf.DB), + } + if conf.pwd != nil { + pwdBuf := conf.pwd.Open() + defer pwdBuf.Close() + opts = append(opts, redis.DialPassword(pwdBuf.String())) + } + c, err := redis.Dial("tcp", conf.Host, opts...) + if err != nil { + logger.Error(err.Error(), "type", "redis", "dsn", conf.Dsn()) + return nil, err + } + return c, nil + }, + } + + rd := new(Redis) + rd.ReadTimeout = int(conf.ReadTimeout / time.Millisecond) + rd.pool = conn + rd.Config = conf + rd.logger = logger + + return rd +} + +func (rd *Redis) CopyByLogger(logger *log.Logger) *Redis { + newRedis := new(Redis) + newRedis.name = rd.name + newRedis.ReadTimeout = rd.ReadTimeout + newRedis.pool = rd.pool + newRedis.subConn = rd.subConn + newRedis.subs = rd.subs + newRedis.SubRunning = rd.SubRunning + newRedis.Config = rd.Config + if logger == nil { + newRedis.logger = log.DefaultLogger + } else { + newRedis.logger = logger + } + return newRedis +} + +func (rd *Redis) SetLogger(logger *log.Logger) { + rd.logger = logger +} + +func (rd *Redis) GetLogger() *log.Logger { + return rd.logger +} + +func (rd *Redis) Destroy() error { + if rd.pool == nil { + return errors.New("operate on a bad redis pool") + } + err := rd.pool.Close() + if err != nil { + rd.logger.Error(err.Error(), "type", "redis", "dsn", rd.Config.Dsn()) + } + redisInstancesLock.Lock() + delete(redisInstances, rd.name) + redisInstancesLock.Unlock() + return err +} + +func (rd *Redis) GetPool() *redis.Pool { + return rd.pool +} + +func (rd *Redis) GetNewConnection() (redis.Conn, error) { + if rd.pool == nil { + return nil, errors.New("redis pool is not initialized") + } + c, err := rd.pool.Dial() + if err == nil { + err = c.Err() + } + if err != nil { + if c != nil { + _ = c.Close() + } + return nil, err + } + return c, nil +} + +func (rd *Redis) GetConnection() (redis.Conn, error) { + if rd.pool == nil { + return nil, errors.New("redis pool is not initialized") + } + c := rd.pool.Get() + err := c.Err() + if err != nil { + if c != nil { + _ = c.Close() + } + return nil, err + } + return c, nil +} + +func shouldRetry(err error) bool { + if err == nil { + return false + } + + // 网络错误 + if errors.Is(err, io.EOF) || errors.Is(err, syscall.ECONNRESET) { + return true + } + + // 超时错误 + if opErr, ok := err.(*net.OpError); ok && opErr.Timeout() { + return true + } + + // Redis特定可恢复错误 + if errs, ok := err.(redis.Error); ok { + switch { + case strings.HasPrefix(string(errs), "LOADING"), + strings.HasPrefix(string(errs), "CLUSTERDOWN"): + return true + } + } + + return false +} + +func (rd *Redis) Do(cmd string, values ...any) *Result { + if rd.pool == nil { + err := errors.New("operate on a bad redis pool") + rd.logger.Error(err.Error(), "type", "redis", "dsn", rd.Config.Dsn(), "cmd", cmd) + return &Result{Error: err} + } + startTime := time.Now() + r := rd.do(cmd, values...) + usedTime := time.Since(startTime) + if r.Error == nil { + if rd.Config.LogSlow > 0 && usedTime >= rd.Config.LogSlow { + rd.logger.Debug("redis slow query", "dsn", rd.Config.Dsn(), "cmd", cmd, "args", values, "used", usedTime.String()) + } + } else { + rd.logger.Error(r.Error.Error(), "type", "redis", "dsn", rd.Config.Dsn(), "cmd", cmd, "args", values, "used", usedTime.String()) + } + return r +} + +func (rd *Redis) do(cmd string, values ...any) *Result { + cmdArr := cast.Split(cmd, " ") + if len(cmdArr) > 1 { + cmd = cmdArr[0] + args := make([]any, 0) + for i := 1; i < len(cmdArr); i++ { + args = append(args, cmdArr[i]) + } + if len(values) > 0 { + args = append(args, values...) + } + values = args + } + + // 自动序列化 + for i, v := range values { + if v == nil { + continue + } + rv := reflect.ValueOf(v) + if rv.Kind() == reflect.Ptr { + rv = rv.Elem() + } + kind := rv.Kind() + if kind == reflect.Struct || kind == reflect.Map || (kind == reflect.Slice && rv.Type().Elem().Kind() != reflect.Uint8) { + // 对于这些类型,优先使用 json.Marshal 以支持自定义 Marshaler (如 time.Time) + if encoded, err := json.Marshal(v); err == nil { + values[i] = encoded + } + } + } + + // 从连接池获取 + conn, err := rd.GetConnection() + var replyData any + if err == nil { + replyData, err = conn.Do(cmd, values...) + _ = conn.Close() + } + + if err != nil && shouldRetry(err) { + // 拿全新的连接重试(如果服务器重启可自动恢复) + conn, err = rd.GetNewConnection() + if err == nil { + replyData, err = conn.Do(cmd, values...) + _ = conn.Close() + } + } + + if err != nil { + return &Result{Error: err} + } + + r := &Result{} + switch realValue := replyData.(type) { + case []byte: + r.bytesData = realValue + case string: + r.bytesData = []byte(realValue) + case int64: + r.bytesData = []byte(strconv.FormatInt(realValue, 10)) + case []any: + if cmd == "HMGET" { + r.keys = make([]string, len(values)-1) + for i, v := range values { + if i > 0 { + r.keys[i-1] = cast.String(v) + } + } + } else if cmd == "MGET" { + r.keys = make([]string, len(values)) + for i, v := range values { + r.keys[i] = cast.String(v) + } + } + + if cmd == "HGETALL" { + r.keys = make([]string, len(realValue)/2) + r.bytesDatas = make([][]byte, len(realValue)/2) + i1, i2 := 0, 0 + for i, v := range realValue { + if v != nil { + if i%2 == 0 { + r.keys[i1] = cast.String(v) + i1++ + } else { + switch subRealValue := v.(type) { + case []byte: + r.bytesDatas[i2] = subRealValue + case string: + r.bytesDatas[i2] = []byte(subRealValue) + default: + r.bytesDatas[i2] = []byte(cast.String(v)) + } + i2++ + } + } + } + } else { + r.bytesDatas = make([][]byte, len(realValue)) + for i, v := range realValue { + if v != nil { + switch subRealValue := v.(type) { + case []byte: + r.bytesDatas[i] = subRealValue + case string: + r.bytesDatas[i] = []byte(subRealValue) + default: + r.bytesDatas[i] = []byte(cast.String(v)) + } + } + } + } + case nil: + r.bytesData = []byte{} + default: + r.bytesData = []byte(cast.String(realValue)) + } + return r +} diff --git a/redis_test.go b/redis_test.go new file mode 100644 index 0000000..f7c592f --- /dev/null +++ b/redis_test.go @@ -0,0 +1,104 @@ +package redis_test + +import ( + "os" + "testing" + "time" + + "apigo.cc/go/config" + "apigo.cc/go/redis" +) + +type userInfo struct { + Id int + Name string + Phone string + Time time.Time +} + +func TestBase(t *testing.T) { + os.Setenv("REDIS_TEST", "redis://:@localhost:6379/2?timeout=10ms&logSlow=10us") + _ = config.Load("redis", nil) + + rd := redis.GetRedis("test", nil) + if rd.Error != nil { + t.Fatal("GetRedis error", rd.Error) + } + rd.DEL("redisName", "redisUser", "redisIds") + + r := rd.GET("redisNotExists") + if r.Error != nil && r.String() != "" || r.Int() != 0 { + t.Fatal("GET NotExists", r, r.String(), r.Int()) + } + + exists := rd.EXISTS("redisName") + if exists { + t.Fatal("EXISTS should be false") + } + + rd.SET("redisName", "12345") + + r = rd.GETSET("redisName", 12345) + if r.String() != "12345" { + t.Fatal("GETSET String mismatch", r.String()) + } + if r.Int() != 12345 { + t.Fatal("Int conversion mismatch", r.Int()) + } + + exists = rd.EXISTS("redisName") + if !exists { + t.Fatal("EXISTS should be true") + } + + // Expire test + rd.SET("redisName", "12") + rd.EXPIRE("redisName", 1) + time.Sleep(2 * time.Second) + r = rd.GET("redisName") + if r.Int() > 0 { + t.Fatal("Expired key still exists", r.Int()) + } + + // Struct test + info := userInfo{ + Name: "aaa", + Id: 123, + Time: time.Now().Truncate(time.Second), // Redis JSON might lose precision + } + rd.SET("redisUser", info) + r = rd.GET("redisUser") + var ru userInfo + _ = r.To(&ru) + if ru.Name != info.Name || ru.Id != info.Id || !ru.Time.Equal(info.Time) { + t.Fatalf("Struct mismatch: expected %+v, got %+v", info, ru) + } + + // MSET/MGET test + rd.MSET("redisName", "Sam Lee", "redisIds", []int{1, 2, 3}) + results := rd.MGET("redisName", "redisIds") + if len(results) != 2 || results[0].String() != "Sam Lee" { + t.Fatal("MGET Results mismatch") + } + ria := results[1].Ints() + if len(ria) != 3 || ria[0] != 1 || ria[1] != 2 || ria[2] != 3 { + t.Fatal("MGET Ints mismatch", ria) + } + + num := rd.DEL("redisName", "redisUser", "redisIds") + if num != 3 { + t.Fatal("DEL count mismatch", num) + } +} + +func TestGenerics(t *testing.T) { + rd := redis.GetRedis("test", nil) + rd.SET("gen_test", userInfo{Name: "Generics", Id: 888}) + defer rd.DEL("gen_test") + + r := rd.GET("gen_test") + user := redis.To[userInfo](r) + if user.Name != "Generics" || user.Id != 888 { + t.Fatal("Generics To[T] mismatch", user) + } +} diff --git a/result.go b/result.go new file mode 100644 index 0000000..50535cc --- /dev/null +++ b/result.go @@ -0,0 +1,224 @@ +package redis + +import ( + "encoding/json" + "reflect" + + "apigo.cc/go/cast" +) + +type Result struct { + bytesData []byte + keys []string + bytesDatas [][]byte + Error error +} + +func (rs *Result) Int() int { + return cast.Int(rs.String()) +} +func (rs *Result) Int8() int8 { + return int8(cast.Int64(rs.String())) +} +func (rs *Result) Int16() int16 { + return int16(cast.Int64(rs.String())) +} +func (rs *Result) Int32() int32 { + return int32(cast.Int64(rs.String())) +} +func (rs *Result) Int64() int64 { + return cast.Int64(rs.String()) +} +func (rs *Result) Uint() uint { + return cast.Uint(rs.String()) +} +func (rs *Result) Uint8() uint8 { + return uint8(cast.Uint64(rs.String())) +} +func (rs *Result) Uint16() uint16 { + return uint16(cast.Uint64(rs.String())) +} +func (rs *Result) Uint32() uint32 { + return uint32(cast.Uint64(rs.String())) +} +func (rs *Result) Uint64() uint64 { + return cast.Uint64(rs.String()) +} +func (rs *Result) Float() float32 { + return cast.Float(rs.String()) +} +func (rs *Result) Float64() float64 { + return cast.Float64(rs.String()) +} +func (rs *Result) String() string { + return string(rs.bytes()) +} +func (rs *Result) Bytes() []byte { + return rs.bytes() +} +func (rs *Result) Bool() bool { + return cast.Bool(rs.String()) +} + +func (rs *Result) Ints() []int { + if rs.bytesDatas != nil { + r := make([]int, len(rs.bytesDatas)) + for i, v := range rs.bytesDatas { + r[i] = cast.Int(string(v)) + } + return r + } else if rs.bytesData != nil { + var r []int + _ = rs.To(&r) + return r + } + return []int{} +} + +func (rs *Result) Strings() []string { + if rs.bytesDatas != nil { + r := make([]string, len(rs.bytesDatas)) + for i, v := range rs.bytesDatas { + r[i] = string(v) + } + return r + } else if rs.bytesData != nil { + var r []string + _ = rs.To(&r) + return r + } + return []string{} +} + +func (rs *Result) Results() []Result { + if rs.bytesDatas != nil { + r := make([]Result, len(rs.bytesDatas)) + for i, v := range rs.bytesDatas { + r[i].bytesData = v + } + return r + } else if rs.bytesData != nil { + var m []string + _ = rs.To(&m) + r := make([]Result, len(m)) + for k, v := range m { + r[k] = Result{bytesData: []byte(v)} + } + return r + } + return []Result{} +} + +func (rs *Result) ResultMap() map[string]*Result { + if rs.bytesDatas != nil && rs.keys != nil { + r := make(map[string]*Result) + n := len(rs.bytesDatas) + for i, k := range rs.keys { + if i < n { + r[k] = &Result{bytesData: rs.bytesDatas[i]} + } + } + return r + } else if rs.bytesData != nil { + r := make(map[string]*Result) + var m map[string]string + _ = rs.To(&m) + for k, v := range m { + r[k] = &Result{bytesData: []byte(v)} + } + return r + } + return map[string]*Result{} +} + +func (rs *Result) StringMap() map[string]string { + rm := rs.ResultMap() + m := make(map[string]string) + for k, r := range rm { + m[k] = r.String() + } + return m +} + +func (rs *Result) IntMap() map[string]int { + rm := rs.ResultMap() + m := make(map[string]int) + for k, r := range rm { + m[k] = r.Int() + } + return m +} + +func (rs *Result) bytes() []byte { + if rs.bytesData != nil { + return rs.bytesData + } else if rs.bytesDatas != nil { + return cast.MustJSONBytes(rs.Strings()) + } + return []byte{} +} + +func (rs *Result) ToValue(t reflect.Type) reflect.Value { + v := reflect.New(t).Elem() + switch t.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + v.SetInt(rs.Int64()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + v.SetUint(rs.Uint64()) + case reflect.Float32, reflect.Float64: + v.SetFloat(rs.Float64()) + case reflect.Bool: + v.SetBool(rs.Bool()) + case reflect.String: + v.SetString(rs.String()) + case reflect.Map, reflect.Slice, reflect.Struct: + _ = rs.To(v.Addr().Interface()) + } + return v +} + +// To 使用泛型进行反序列化 +func To[T any](rs *Result) T { + var out T + _ = rs.To(&out) + return out +} + +func (rs *Result) To(result any) error { + if rs.bytesData != nil { + if len(rs.bytesData) > 0 { + // 优先使用 json.Unmarshal 以支持自定义 Unmarshaler (如 time.Time) + return json.Unmarshal(rs.bytesData, result) + } + return nil + } + + t := reflect.TypeOf(result) + v := reflect.ValueOf(result) + if t.Kind() == reflect.Ptr { + t = t.Elem() + v = v.Elem() + } + + if (t.Kind() == reflect.Struct || t.Kind() == reflect.Map) && rs.keys != nil && rs.bytesDatas != nil { + rm := rs.ResultMap() + for k, r := range rm { + if t.Kind() == reflect.Struct { + k = cast.GetUpperName(k) + sf, found := t.FieldByName(k) + if found { + v.FieldByName(k).Set(r.ToValue(sf.Type)) + } + } else if t.Kind() == reflect.Map { + v.SetMapIndex(reflect.ValueOf(k), r.ToValue(t.Elem())) + } + } + } else if t.Kind() == reflect.Slice && t.Elem().Kind() != reflect.Uint8 && rs.bytesDatas != nil { + results := rs.Results() + for _, r := range results { + v = reflect.Append(v, r.ToValue(t.Elem())) + } + reflect.ValueOf(result).Elem().Set(v) + } + return nil +} diff --git a/subscribe.go b/subscribe.go new file mode 100644 index 0000000..6e73ccb --- /dev/null +++ b/subscribe.go @@ -0,0 +1,147 @@ +package redis + +import ( + "strings" + "time" + + "github.com/gomodule/redigo/redis" +) + +func (rd *Redis) Subscribe(name string, reset func(), received func([]byte)) bool { + if rd.subs == nil { + rd.subs = make(map[string]*SubCallbacks) + } + rd.subs[name] = &SubCallbacks{reset: reset, received: received} + if rd.subConn != nil { + err := rd.subConn.Subscribe(name) + if err != nil { + rd.logger.Error(err.Error(), "type", "redis", "action", "subscribe", "name", name) + } else { + return true + } + } + return false +} + +func (rd *Redis) Unsubscribe(name string) bool { + if rd.subs != nil { + delete(rd.subs, name) + } + if rd.subConn != nil { + err := rd.subConn.Unsubscribe(name) + if err != nil { + rd.logger.Error(err.Error(), "type", "redis", "action", "unsubscribe", "name", name) + } else { + return true + } + } + return false +} + +func (rd *Redis) Start() { + if rd.subs == nil { + rd.subs = make(map[string]*SubCallbacks) + } + rd.SubRunning = true + subStartChan := make(chan bool) + go rd.receiveSub(subStartChan) + <-subStartChan +} + +func (rd *Redis) receiveSub(subStartChan chan bool) { + for { + if !rd.SubRunning { + break + } + + // 开始接收订阅数据 + if rd.subConn == nil { + conn, err := rd.GetConnection() + if err != nil { + time.Sleep(time.Second) + continue + } + rd.subConn = &redis.PubSubConn{Conn: conn} + // 重新订阅 + if len(rd.subs) > 0 { + subs := make([]any, 0) + for k := range rd.subs { + subs = append(subs, k) + } + err = rd.subConn.Subscribe(subs...) + if err != nil { + _ = rd.subConn.Close() + rd.subConn = nil + time.Sleep(time.Second) + continue + } + // 重新连接时调用重置数据的回调 + for _, v := range rd.subs { + if v.reset != nil { + v.reset() + } + } + } + } + + if subStartChan != nil { + subStartChan <- true + subStartChan = nil + } + + for { + isErr := false + receiveObj := rd.subConn.Receive() + switch v := receiveObj.(type) { + case redis.Message: + callback := rd.subs[v.Channel] + if callback != nil && callback.received != nil { + callback.received(v.Data) + } + case redis.Subscription: + case redis.Pong: + case error: + if strings.Contains(v.Error(), "i/o timeout") { + break + } + if !strings.Contains(v.Error(), "connection closed") && !strings.Contains(v.Error(), "use of closed network connection") { + rd.logger.Error(v.Error(), "type", "redis", "action", "receiveSub") + } + if rd.subConn != nil { + _ = rd.subConn.Close() + rd.subConn = nil + } + isErr = true + } + if isErr || !rd.SubRunning { + break + } + } + } + if rd.subStopChan != nil { + rd.subStopChan <- true + } +} + +func (rd *Redis) Stop() { + if rd.SubRunning { + rd.subStopChan = make(chan bool) + rd.SubRunning = false + if rd.subConn != nil { + // 取消订阅 + if len(rd.subs) > 0 { + _ = rd.subConn.Unsubscribe() + } + // 读一次再关闭可以防止Close时阻塞 + _ = rd.subConn.ReceiveWithTimeout(50 * time.Millisecond) + _ = rd.subConn.Close() + rd.subConn = nil + } + <-rd.subStopChan + rd.subStopChan = nil + } +} + +func (rd *Redis) PUBLISH(channel, data string) bool { + return rd.Do("PUBLISH "+channel, data).Bool() +} diff --git a/subscribe_test.go b/subscribe_test.go new file mode 100644 index 0000000..d4e963c --- /dev/null +++ b/subscribe_test.go @@ -0,0 +1,56 @@ +package redis_test + +import ( + "sync/atomic" + "testing" + "time" + + "apigo.cc/go/redis" +) + +func TestSub(t *testing.T) { + rd := redis.GetRedis("test", nil) + rd.Start() + defer rd.Stop() + + var aaa atomic.Value + var bbb atomic.Value + + rd.Subscribe("aaa", nil, func(s []byte) { + aaa.Store(string(s)) + }) + + rd.PUBLISH("aaa", "111") + time.Sleep(100 * time.Millisecond) + + val := aaa.Load() + if val == nil || val.(string) != "111" { + t.Fatal("Subscribe aaa failed", val) + } + + rd.Subscribe("bbb", nil, func(s []byte) { + bbb.Store(string(s)) + }) + + rd.PUBLISH("bbb", "222") + time.Sleep(100 * time.Millisecond) + + val = bbb.Load() + if val == nil || val.(string) != "222" { + t.Fatal("Subscribe bbb failed", val) + } + + rd.Unsubscribe("aaa") + rd.PUBLISH("aaa", "1111") + rd.PUBLISH("bbb", "2222") + time.Sleep(100 * time.Millisecond) + + val = aaa.Load() + if val == nil || val.(string) != "111" { + t.Fatal("Unsubscribe aaa failed: value updated", val) + } + val = bbb.Load() + if val == nil || val.(string) != "2222" { + t.Fatal("Subscribe bbb update failed", val) + } +}