From 69324d566ea9f4c355174c036d6288800999762e Mon Sep 17 00:00:00 2001 From: AI Engineer Date: Sun, 3 May 2026 08:43:23 +0800 Subject: [PATCH] feat: migrate redis module from ssgo to apigo.cc/go standard --- CHANGELOG.md | 16 ++ README.md | 57 ++++++- commands.go | 165 ++++++++++++++++++++ config.go | 139 +++++++++++++++++ go.mod | 24 +++ go.sum | 38 +++++ id.go | 62 ++++++++ redis.go | 390 ++++++++++++++++++++++++++++++++++++++++++++++ redis_test.go | 104 +++++++++++++ result.go | 224 ++++++++++++++++++++++++++ subscribe.go | 147 +++++++++++++++++ subscribe_test.go | 56 +++++++ 12 files changed, 1420 insertions(+), 2 deletions(-) create mode 100644 CHANGELOG.md create mode 100644 commands.go create mode 100644 config.go create mode 100644 go.mod create mode 100644 go.sum create mode 100644 id.go create mode 100644 redis.go create mode 100644 redis_test.go create mode 100644 result.go create mode 100644 subscribe.go create mode 100644 subscribe_test.go diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..155c599 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,16 @@ +# CHANGELOG - redis + +## v1.0.0 (2026-05-03) +- **Repo Migration**: 从 `@ssgo/redis` 迁移至 `apigo.cc/go/redis`。 +- **Standard Realignment**: + - 依赖全面切换至 `apigo.cc/go/*` 标准库。 + - 适配 Go 1.25.0。 +- **Feature Enhancements**: + - **Generics**: 为 `Result` 引入泛型 `To[T]` 支持,消除类型断言摩擦。 + - **Memory Safety**: 集成 `go/safe` 对 Redis 密码进行实时加解密与内存锁定,防止内存泄漏敏感信息。 + - **Auto-Serialization**: 优化了 `Do` 方法,支持对 Struct/Map/Slice 自动进行 JSON 序列化(优先支持 Marshaler 接口)。 + - **Distributed ID**: 深度集成 `go/id` 核心,提供更高性能的 Redis 序列号预取机制。 +- **Refactoring**: + - 移除了冗余的 `interface{}`,全面改用 `any`。 + - 规范化了 API 命名,统一使用 `GetUpperName` 进行 Struct 字段映射。 + - 增强了连接重试机制,支持对网络波动和服务器重启的自动恢复。 diff --git a/README.md b/README.md index 3cf1b4b..34dab72 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,56 @@ -# redis +# Redis 模块 (redis) -高性能 Redis 客户端,集成分布式 ID 生成与发布订阅 \ No newline at end of file +`redis` 模块提供了一个高性能、内存安全且易于使用的 Redis 客户端,集成了分布式 ID 生成和发布订阅功能。 + +## 设计特性 +- **消除摩擦**: 自动处理连接池、重试、以及复杂类型的 JSON 序列化。 +- **内存安全**: 对 Redis 密码进行内存保护,防止内存 Dump 泄露敏感信息。 +- **泛型支持**: 结果集支持泛型绑定 `To[T]`。 +- **分布式 ID**: 内置基于 Redis 的高性能分布式 ID 生成器。 + +## API 指南 + +### 基础连接 +- `GetRedis(name string, logger *log.Logger) *Redis`: 获取或创建一个 Redis 实例(支持 DSN 或配置文件名)。 +- `NewRedis(conf *Config, logger *log.Logger) *Redis`: 使用指定配置创建 Redis 实例。 + +### 核心操作 +- `Do(cmd string, values ...any) *Result`: 执行原生 Redis 命令。 +- 支持大部分标准命令:`GET`, `SET`, `HGET`, `HSET`, `LPUSH`, `LPOP`, `SADD`, `ZRANGE`, `PUBLISH` 等。 + +### 结果处理 (Result) +- `Int()`, `String()`, `Bool()`, `Float()`: 基础类型转换。 +- `To(target any)`: 将结果反序列化到结构体或 Map。 +- `To[T](rs *Result) T`: **泛型版本**,直接返回目标类型对象。 +- `Ints()`, `Strings()`, `Results()`: 处理数组结果。 +- `ResultMap()`: 处理 Hash 或键值对结果。 + +### 发布订阅 +- `Start()`: 开启订阅监听协程。 +- `Subscribe(channel string, reset func(), received func([]byte))`: 订阅频道。 +- `Unsubscribe(channel string)`: 取消订阅。 +- `Stop()`: 停止所有订阅。 + +### 分布式 ID (IdMaker) +- `NewIdMaker(rd *Redis) *IdMaker`: 创建分布式 ID 生成器。 +- `Get(size int)`: 获取指定长度的唯一 ID。 +- `GetForMysql(size int)`: 获取针对 MySQL 优化的唯一 ID。 + +## 示例 + +```go +import "apigo.cc/go/redis" + +// 获取实例 (自动从 redis.json 加载) +rd := redis.GetRedis("test", nil) + +// 设置结构体 (自动 JSON 序列化) +rd.SET("user:1", User{Name: "Sam", Age: 18}) + +// 获取并自动绑定 (泛型) +user := redis.To[User](rd.GET("user:1")) + +// 分布式 ID +maker := redis.NewIdMaker(rd) +id := maker.Get(10) +``` diff --git a/commands.go b/commands.go new file mode 100644 index 0000000..f49318b --- /dev/null +++ b/commands.go @@ -0,0 +1,165 @@ +package redis + +func stringsToAnys(in []string) []any { + a := make([]any, len(in)) + for i, v := range in { + a[i] = v + } + return a +} + +func (rd *Redis) DEL(keys ...string) int { + return rd.Do("DEL", stringsToAnys(keys)...).Int() +} +func (rd *Redis) EXISTS(key string) bool { + return rd.Do("EXISTS " + key).Bool() +} +func (rd *Redis) EXPIRE(key string, second int) bool { + if second > 315360000 { + return rd.Do("EXPIREAT "+key, second).Bool() + } else { + return rd.Do("EXPIRE "+key, second).Bool() + } +} +func (rd *Redis) KEYS(patten string) []string { + return rd.Do("KEYS " + patten).Strings() +} + +func (rd *Redis) GET(key string) *Result { + return rd.Do("GET " + key) +} +func (rd *Redis) SET(key string, value any) bool { + return rd.Do("SET "+key, value).Bool() +} +func (rd *Redis) SETEX(key string, seconds int, value any) bool { + return rd.Do("SETEX "+key, seconds, value).Bool() +} +func (rd *Redis) SETNX(key string, value any) bool { + return rd.Do("SETNX "+key, value).Bool() +} +func (rd *Redis) GETSET(key string, value any) *Result { + return rd.Do("GETSET "+key, value) +} + +func (rd *Redis) INCR(key string) int64 { + return rd.Do("INCR " + key).Int64() +} +func (rd *Redis) INCRBY(key string, delta int64) int64 { + return rd.Do("INCRBY "+key, delta).Int64() +} +func (rd *Redis) DECR(key string, delta int64) int64 { + return rd.Do("DECR "+key, delta).Int64() +} +func (rd *Redis) DECRBY(key string, delta int64) int64 { + return rd.Do("DECRBY "+key, delta).Int64() +} + +func (rd *Redis) MGET(keys ...string) []Result { + return rd.Do("MGET", stringsToAnys(keys)...).Results() +} +func (rd *Redis) MSET(keyAndValues ...any) bool { + return rd.Do("MSET", keyAndValues...).Bool() +} + +func (rd *Redis) HGET(key, field string) *Result { + return rd.Do("HGET "+key, field) +} +func (rd *Redis) HSET(key, field string, value any) bool { + return rd.Do("HSET "+key, field, value).Error == nil +} +func (rd *Redis) HSETNX(key, field string, value any) bool { + return rd.Do("HSETNX "+key, field, value).Error == nil +} +func (rd *Redis) HMGET(key string, fields ...string) []Result { + return rd.Do("HMGET", append(append([]any{}, key), stringsToAnys(fields)...)...).Results() +} +func (rd *Redis) HGETALL(key string) map[string]*Result { + return rd.Do("HGETALL " + key).ResultMap() +} +func (rd *Redis) HMSET(key string, fieldAndValues ...any) bool { + return rd.Do("HMSET", append(append([]any{}, key), fieldAndValues...)...).Bool() +} +func (rd *Redis) HKEYS(key string) []string { + return rd.Do("HKEYS " + key).Strings() +} +func (rd *Redis) HLEN(key string) int { + return rd.Do("HLEN " + key).Int() +} +func (rd *Redis) HDEL(key string, fields ...string) int { + return rd.Do("HDEL", append(append([]any{}, key), stringsToAnys(fields)...)...).Int() +} +func (rd *Redis) HEXISTS(key, field string) bool { + return rd.Do("HEXISTS "+key, field).Bool() +} +func (rd *Redis) HINCR(key, field string) int64 { + return rd.Do("HINCRBY "+key, field, 1).Int64() +} +func (rd *Redis) HINCRBY(key, field string, delta int64) int64 { + return rd.Do("HINCRBY "+key, field, delta).Int64() +} +func (rd *Redis) HDECR(key, field string) int64 { + return rd.Do("HDECRBY "+key, field, 1).Int64() +} +func (rd *Redis) HDECRBY(key, field string, delta int64) int64 { + return rd.Do("HDECRBY "+key, field, delta).Int64() +} + +func (rd *Redis) LPUSH(key string, values ...string) int { + return rd.Do("LPUSH", append(append([]any{}, key), stringsToAnys(values)...)...).Int() +} +func (rd *Redis) RPUSH(key string, values ...string) int { + return rd.Do("RPUSH", append(append([]any{}, key), stringsToAnys(values)...)...).Int() +} +func (rd *Redis) LPOP(key string) *Result { + return rd.Do("LPOP " + key) +} +func (rd *Redis) RPOP(key string) *Result { + return rd.Do("RPOP " + key) +} +func (rd *Redis) LLEN(key string) int { + return rd.Do("LLEN " + key).Int() +} +func (rd *Redis) LRANGE(key string, start, stop int) []Result { + return rd.Do("LRANGE "+key, start, stop).Results() +} + +func (rd *Redis) SADD(key string, values ...any) int { + return rd.Do("SADD", append([]any{key}, values...)...).Int() +} +func (rd *Redis) SREM(key string, values ...any) int { + return rd.Do("SREM", append([]any{key}, values...)...).Int() +} +func (rd *Redis) SCARD(key string) int { + return rd.Do("SCARD " + key).Int() +} +func (rd *Redis) SMEMBERS(key string) []Result { + return rd.Do("SMEMBERS " + key).Results() +} +func (rd *Redis) SISMEMBER(key string, value any) bool { + return rd.Do("SISMEMBER "+key, value).Bool() +} + +func (rd *Redis) ZADD(key string, score float64, member any) bool { + return rd.Do("ZADD "+key, score, member).Bool() +} +func (rd *Redis) ZREM(key string, members ...any) int { + return rd.Do("ZREM", append([]any{key}, members...)...).Int() +} +func (rd *Redis) ZCARD(key string) int { + return rd.Do("ZCARD " + key).Int() +} +func (rd *Redis) ZRANGE(key string, start, stop int) []Result { + return rd.Do("ZRANGE "+key, start, stop).Results() +} +func (rd *Redis) ZREVRANGE(key string, start, stop int) []Result { + return rd.Do("ZREVRANGE "+key, start, stop).Results() +} +func (rd *Redis) ZRANK(key string, member any) int { + return rd.Do("ZRANK "+key, member).Int() +} +func (rd *Redis) ZREVRANK(key string, member any) int { + return rd.Do("ZREVRANK "+key, member).Int() +} +func (rd *Redis) ZSCORE(key string, member any) float64 { + return rd.Do("ZSCORE "+key, member).Float64() +} diff --git a/config.go b/config.go new file mode 100644 index 0000000..f980819 --- /dev/null +++ b/config.go @@ -0,0 +1,139 @@ +package redis + +import ( + "fmt" + "net/url" + "strconv" + "sync" + "time" + + "apigo.cc/go/cast" + "apigo.cc/go/crypto" + "apigo.cc/go/log" + "apigo.cc/go/safe" +) + +type Config struct { + Host string + Password string + DB int + MaxActive int + MaxIdle int + IdleTimeout time.Duration + ConnectTimeout time.Duration + ReadTimeout time.Duration + WriteTimeout time.Duration + LogSlow time.Duration + logger *log.Logger + pwd *safe.SafeBuf +} + +var redisConfigs = make(map[string]*Config) +var redisConfigsLock = sync.RWMutex{} + +var confAes, _ = crypto.NewAESGCMAndEraseKey([]byte("?GQ$0K0GgLdO=f+~L68PLm$uhKr4'=tV"), []byte("VFs7@sK61cj^f?HZ")) +var keysOnce = sync.Once{} + +func SetEncryptKeys(key, iv []byte) { + keysOnce.Do(func() { + confAes.Close() + confAes, _ = crypto.NewAESGCMAndEraseKey(key, iv) + }) +} + +func (conf *Config) ConfigureBy(setting string) { + redisConfigsLock.Lock() + redisConfigs[setting] = conf + redisConfigsLock.Unlock() + + urlInfo, err := url.Parse(setting) + if err != nil { + conf.logger.Error(err.Error(), "url", setting) + return + } + if urlInfo.Scheme != "redis" { + conf.logger.Error("unsupported scheme", "url", setting) + return + } + + conf.Host = urlInfo.Host + + dbStr := urlInfo.Query().Get("database") + if dbStr == "" && len(urlInfo.Path) > 1 { + dbStr = urlInfo.Path[1:] + } + if len(dbStr) > 0 { + db, err := strconv.Atoi(dbStr) + if err != nil { + conf.logger.Error(err.Error(), "url", setting) + } + if err == nil && db >= 0 && db <= 15 { + conf.DB = db + } + } + + conf.Password, _ = urlInfo.User.Password() + conf.LogSlow = cast.Duration(urlInfo.Query().Get("logSlow")) + conf.MaxIdle = cast.Int(urlInfo.Query().Get("maxIdle")) + conf.MaxActive = cast.Int(urlInfo.Query().Get("maxActive")) + conf.ConnectTimeout = cast.Duration(urlInfo.Query().Get("connectTimeout")) + conf.ReadTimeout = cast.Duration(urlInfo.Query().Get("readTimeout")) + conf.WriteTimeout = cast.Duration(urlInfo.Query().Get("writeTimeout")) + conf.IdleTimeout = cast.Duration(urlInfo.Query().Get("idleTimeout")) +} + +func (conf *Config) Dsn() string { + return fmt.Sprintf("redis://:****@%s/%d?timeout=%s&logSlow=%s", conf.Host, conf.DB, conf.ConnectTimeout, conf.LogSlow) +} + +func parseByName(name string) *Config { + // config name support Host:Port + args := cast.Split(name, ":") + db := 0 + if len(args) > 1 { + arg1, err := strconv.Atoi(args[1]) + if err == nil && arg1 >= 0 && arg1 <= 15 { + name = args[0] + db = arg1 + } + } + + redisConfigsLock.RLock() + conf := redisConfigs[name] + redisConfigsLock.RUnlock() + + if conf == nil { + conf = new(Config) + redisConfigsLock.Lock() + redisConfigs[name] = conf + redisConfigsLock.Unlock() + + if len(args) > 1 { + arg1, err := strconv.Atoi(args[1]) + if err == nil && arg1 >= 0 && arg1 <= 15 { + conf.DB = arg1 + } else { + conf.Host = args[0] + ":" + args[1] + } + } + } + + for i := 2; i < len(args); i++ { + arg2, err := strconv.Atoi(args[i]) + if err == nil { + if arg2 >= 0 && arg2 <= 15 { + conf.DB = arg2 + } else { + conf.Password = args[i] + } + } else { + conf.Password = args[i] + } + } + + if conf.DB == 0 && db > 0 && db <= 15 { + conf.DB = db + } + + return conf +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..eda9ae6 --- /dev/null +++ b/go.mod @@ -0,0 +1,24 @@ +module apigo.cc/go/redis + +go 1.25.0 + +require ( + apigo.cc/go/cast v1.1.1 + apigo.cc/go/config v1.0.4 + apigo.cc/go/crypto v1.0.4 + apigo.cc/go/id v1.0.4 + apigo.cc/go/log v1.0.0 + apigo.cc/go/safe v1.0.4 + github.com/gomodule/redigo v1.9.3 +) + +require ( + apigo.cc/go/convert v1.0.4 // indirect + apigo.cc/go/encoding v1.0.4 // indirect + apigo.cc/go/file v1.0.4 // indirect + apigo.cc/go/rand v1.0.4 // indirect + apigo.cc/go/shell v1.0.4 // indirect + golang.org/x/crypto v0.50.0 // indirect + golang.org/x/sys v0.43.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..302113e --- /dev/null +++ b/go.sum @@ -0,0 +1,38 @@ +apigo.cc/go/cast v1.1.1 h1:+5pluN8g1RK2J4byr2xkfOmEdKSmy1PByOqDOHtt/Ns= +apigo.cc/go/cast v1.1.1/go.mod h1:vh9ZqISCmTUiyinkNMI/s4f045fRlDK3xC+nPWQYBzI= +apigo.cc/go/config v1.0.4 h1:WG9zrQkqfFPkrKIL7RNvvAbbkuUBt1Av11ZP/aIfldM= +apigo.cc/go/config v1.0.4/go.mod h1:obryzJiK6j7lQex/58d5eWYOGx5O5IABguqNWxyyXJo= +apigo.cc/go/convert v1.0.4 h1:5+qPjC3dlPB59GnWZRlmthxcaXQtKvN+iOuiLdJ1GvQ= +apigo.cc/go/convert v1.0.4/go.mod h1:Hp+geeSyhqg/zwIKPOrDoceIREzcwM14t1I5q/dtbfU= +apigo.cc/go/crypto v1.0.4 h1:VPUyHCH2N3LLEgdpwUc+DQssNHzLlxVzLNRa0Jm6O4o= +apigo.cc/go/crypto v1.0.4/go.mod h1:5sI8BLw6YHZfDReYwCO3TFD2LKm36HMdLg1S5oPv/QU= +apigo.cc/go/encoding v1.0.4 h1:aezB0J/qFuHs6iXkbtuJP5JIHUtmjsr5SFb0NNvbObY= +apigo.cc/go/encoding v1.0.4/go.mod h1:V5CgT7rBbCxy+uCU20q0ptcNNRSgMtpA8cNOs6r8IeI= +apigo.cc/go/file v1.0.4 h1:qCKegV7OYh7r0qc3jZjGA/aKh0vIHgmr1OEbhfEmGX8= +apigo.cc/go/file v1.0.4/go.mod h1:C9gNo7386iA21OiBmuWh6CznKWlVBDFkhE4f0H0Susg= +apigo.cc/go/id v1.0.4 h1:w+JSdeVit52iefIUolrh1qLEZS9XqHNKr1UygFcgv+s= +apigo.cc/go/id v1.0.4/go.mod h1:kg7QuceAKtGNzGWt0+pIIh8Qom1eMSWGb8+0Yhi/QVY= +apigo.cc/go/log v1.0.0 h1:lI1NGTSS+Jm12G8BD7ZJO4/hrkfuLTu5O8z36GD8GpU= +apigo.cc/go/log v1.0.0/go.mod h1:tvPgFpebY9Wf/DlqMHZ0ZjxDp9AaQTywOQKvtBaNqNo= +apigo.cc/go/rand v1.0.4 h1:we070eWSL0dB8NEMaWjXj43+EekXQTm/h0kKpZ/frqw= +apigo.cc/go/rand v1.0.4/go.mod h1:mZ/4Soa3bk+XvDaqPWJuUe1bfEi4eThBj1XmEAuYxsk= +apigo.cc/go/safe v1.0.4 h1:07pRSdEHprF/2v6SsqAjICYFoeLcqjjvHGEdh6Dzrzg= +apigo.cc/go/safe v1.0.4/go.mod h1:o568sHS5rTRSVPmhxWod0tGdc+8l1KjidsNY1/OVZr0= +apigo.cc/go/shell v1.0.4 h1:EL9zjI39YBe1h+kRYQeAi/8zVGHe5W198DYYN7cENiY= +apigo.cc/go/shell v1.0.4/go.mod h1:N2gDkgK4tJ9TadD60/+gAGuWxyVAWHs5YPBmytw6ELA= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/gomodule/redigo v1.9.3 h1:dNPSXeXv6HCq2jdyWfjgmhBdqnR6PRO3m/G05nvpPC8= +github.com/gomodule/redigo v1.9.3/go.mod h1:KsU3hiK/Ay8U42qpaJk+kuNa3C+spxapWpM+ywhcgtw= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI= +golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q= +golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI= +golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/id.go b/id.go new file mode 100644 index 0000000..425de2b --- /dev/null +++ b/id.go @@ -0,0 +1,62 @@ +package redis + +import ( + "fmt" + "sync" + + "apigo.cc/go/id" +) + +type IdMaker struct { + rd *Redis + secCurrent uint64 + secIndexMax uint64 + secIndexNext uint64 + lock sync.Mutex + maker *id.IDMaker +} + +func NewIdMaker(rd *Redis) *IdMaker { + im := &IdMaker{rd: rd} + im.maker = id.NewIDMaker(im.makeSecIndex) + return im +} + +func (im *IdMaker) makeSecIndex(sec uint64) uint64 { + im.lock.Lock() + defer im.lock.Unlock() + + if im.secCurrent == sec && im.secIndexNext <= im.secIndexMax { + idx := im.secIndexNext + im.secIndexNext++ + return idx + } + + im.secCurrent = sec + key := fmt.Sprintf("_SecIdx_%d", sec) + // 每次从 Redis 预取 100 个序列号 + max := uint64(im.rd.INCRBY(key, 100)) + if max < 100 { + return 0 + } + im.secIndexMax = max + im.secIndexNext = max - 99 + idx := im.secIndexNext + im.secIndexNext++ + if max <= 100 { + im.rd.EXPIRE(key, 10) + } + return idx +} + +func (im *IdMaker) Get(size int) string { + return im.maker.Get(size) +} + +func (im *IdMaker) GetForMysql(size int) string { + return im.maker.GetForMysql(size) +} + +func (im *IdMaker) GetForPostgreSQL(size int) string { + return im.maker.GetForPostgreSQL(size) +} diff --git a/redis.go b/redis.go new file mode 100644 index 0000000..0776d4d --- /dev/null +++ b/redis.go @@ -0,0 +1,390 @@ +package redis + +import ( + "encoding/json" + "errors" + "io" + "net" + "reflect" + "strconv" + "strings" + "sync" + "syscall" + "time" + + "apigo.cc/go/cast" + "apigo.cc/go/config" + "apigo.cc/go/log" + "apigo.cc/go/safe" + "github.com/gomodule/redigo/redis" +) + +type Redis struct { + name string + pool *redis.Pool + ReadTimeout int + Config *Config + logger *log.Logger + Error error + subConn *redis.PubSubConn + subStopChan chan bool + subs map[string]*SubCallbacks + SubRunning bool +} + +type SubCallbacks struct { + received func([]byte) + reset func() +} + +var redisInstances = make(map[string]*Redis) +var redisInstancesLock = sync.RWMutex{} + +func GetRedis(name string, logger *log.Logger) *Redis { + if logger == nil { + logger = log.DefaultLogger + } + + redisInstancesLock.RLock() + oldConn := redisInstances[name] + redisInstancesLock.RUnlock() + if oldConn != nil { + return oldConn.CopyByLogger(logger) + } + + redisConfigsLock.RLock() + configsLen := len(redisConfigs) + redisConfigsLock.RUnlock() + + if configsLen == 0 { + _ = config.Load("redis", &redisConfigs) + } + + fullName := name + + var conf *Config + if strings.HasPrefix(name, "redis://") { + conf = new(Config) + conf.logger = logger + conf.ConfigureBy(name) + } else { + conf = parseByName(name) + } + + if pwd, err := confAes.Decrypt([]byte(conf.Password)); err == nil { + conf.pwd = pwd + } else { + conf.pwd = safe.NewSafeBuf([]byte(conf.Password)) + } + conf.Password = "" + + if conf.Host == "" { + conf.Host = "127.0.0.1:6379" + } + if conf.MaxIdle == 0 { + conf.MaxIdle = 20 + } + if conf.MaxActive == 0 { + conf.MaxActive = 100 + } + if conf.ConnectTimeout == 0 { + conf.ConnectTimeout = 10 * time.Second + } + if conf.ReadTimeout == 0 { + conf.ReadTimeout = 10 * time.Second + } + if conf.WriteTimeout == 0 { + conf.WriteTimeout = 10 * time.Second + } + if conf.LogSlow == 0 { + conf.LogSlow = 100 * time.Millisecond + } + + rd := NewRedis(conf, logger) + rd.name = fullName + redisInstancesLock.Lock() + redisInstances[fullName] = rd + redisInstancesLock.Unlock() + return rd.CopyByLogger(logger) +} + +func NewRedis(conf *Config, logger *log.Logger) *Redis { + if logger == nil { + logger = log.DefaultLogger + } + + conn := &redis.Pool{ + MaxIdle: conf.MaxIdle, + MaxActive: conf.MaxActive, + IdleTimeout: conf.IdleTimeout, + Dial: func() (redis.Conn, error) { + opts := []redis.DialOption{ + redis.DialConnectTimeout(conf.ConnectTimeout), + redis.DialReadTimeout(conf.ReadTimeout), + redis.DialWriteTimeout(conf.WriteTimeout), + redis.DialDatabase(conf.DB), + } + if conf.pwd != nil { + pwdBuf := conf.pwd.Open() + defer pwdBuf.Close() + opts = append(opts, redis.DialPassword(pwdBuf.String())) + } + c, err := redis.Dial("tcp", conf.Host, opts...) + if err != nil { + logger.Error(err.Error(), "type", "redis", "dsn", conf.Dsn()) + return nil, err + } + return c, nil + }, + } + + rd := new(Redis) + rd.ReadTimeout = int(conf.ReadTimeout / time.Millisecond) + rd.pool = conn + rd.Config = conf + rd.logger = logger + + return rd +} + +func (rd *Redis) CopyByLogger(logger *log.Logger) *Redis { + newRedis := new(Redis) + newRedis.name = rd.name + newRedis.ReadTimeout = rd.ReadTimeout + newRedis.pool = rd.pool + newRedis.subConn = rd.subConn + newRedis.subs = rd.subs + newRedis.SubRunning = rd.SubRunning + newRedis.Config = rd.Config + if logger == nil { + newRedis.logger = log.DefaultLogger + } else { + newRedis.logger = logger + } + return newRedis +} + +func (rd *Redis) SetLogger(logger *log.Logger) { + rd.logger = logger +} + +func (rd *Redis) GetLogger() *log.Logger { + return rd.logger +} + +func (rd *Redis) Destroy() error { + if rd.pool == nil { + return errors.New("operate on a bad redis pool") + } + err := rd.pool.Close() + if err != nil { + rd.logger.Error(err.Error(), "type", "redis", "dsn", rd.Config.Dsn()) + } + redisInstancesLock.Lock() + delete(redisInstances, rd.name) + redisInstancesLock.Unlock() + return err +} + +func (rd *Redis) GetPool() *redis.Pool { + return rd.pool +} + +func (rd *Redis) GetNewConnection() (redis.Conn, error) { + if rd.pool == nil { + return nil, errors.New("redis pool is not initialized") + } + c, err := rd.pool.Dial() + if err == nil { + err = c.Err() + } + if err != nil { + if c != nil { + _ = c.Close() + } + return nil, err + } + return c, nil +} + +func (rd *Redis) GetConnection() (redis.Conn, error) { + if rd.pool == nil { + return nil, errors.New("redis pool is not initialized") + } + c := rd.pool.Get() + err := c.Err() + if err != nil { + if c != nil { + _ = c.Close() + } + return nil, err + } + return c, nil +} + +func shouldRetry(err error) bool { + if err == nil { + return false + } + + // 网络错误 + if errors.Is(err, io.EOF) || errors.Is(err, syscall.ECONNRESET) { + return true + } + + // 超时错误 + if opErr, ok := err.(*net.OpError); ok && opErr.Timeout() { + return true + } + + // Redis特定可恢复错误 + if errs, ok := err.(redis.Error); ok { + switch { + case strings.HasPrefix(string(errs), "LOADING"), + strings.HasPrefix(string(errs), "CLUSTERDOWN"): + return true + } + } + + return false +} + +func (rd *Redis) Do(cmd string, values ...any) *Result { + if rd.pool == nil { + err := errors.New("operate on a bad redis pool") + rd.logger.Error(err.Error(), "type", "redis", "dsn", rd.Config.Dsn(), "cmd", cmd) + return &Result{Error: err} + } + startTime := time.Now() + r := rd.do(cmd, values...) + usedTime := time.Since(startTime) + if r.Error == nil { + if rd.Config.LogSlow > 0 && usedTime >= rd.Config.LogSlow { + rd.logger.Debug("redis slow query", "dsn", rd.Config.Dsn(), "cmd", cmd, "args", values, "used", usedTime.String()) + } + } else { + rd.logger.Error(r.Error.Error(), "type", "redis", "dsn", rd.Config.Dsn(), "cmd", cmd, "args", values, "used", usedTime.String()) + } + return r +} + +func (rd *Redis) do(cmd string, values ...any) *Result { + cmdArr := cast.Split(cmd, " ") + if len(cmdArr) > 1 { + cmd = cmdArr[0] + args := make([]any, 0) + for i := 1; i < len(cmdArr); i++ { + args = append(args, cmdArr[i]) + } + if len(values) > 0 { + args = append(args, values...) + } + values = args + } + + // 自动序列化 + for i, v := range values { + if v == nil { + continue + } + rv := reflect.ValueOf(v) + if rv.Kind() == reflect.Ptr { + rv = rv.Elem() + } + kind := rv.Kind() + if kind == reflect.Struct || kind == reflect.Map || (kind == reflect.Slice && rv.Type().Elem().Kind() != reflect.Uint8) { + // 对于这些类型,优先使用 json.Marshal 以支持自定义 Marshaler (如 time.Time) + if encoded, err := json.Marshal(v); err == nil { + values[i] = encoded + } + } + } + + // 从连接池获取 + conn, err := rd.GetConnection() + var replyData any + if err == nil { + replyData, err = conn.Do(cmd, values...) + _ = conn.Close() + } + + if err != nil && shouldRetry(err) { + // 拿全新的连接重试(如果服务器重启可自动恢复) + conn, err = rd.GetNewConnection() + if err == nil { + replyData, err = conn.Do(cmd, values...) + _ = conn.Close() + } + } + + if err != nil { + return &Result{Error: err} + } + + r := &Result{} + switch realValue := replyData.(type) { + case []byte: + r.bytesData = realValue + case string: + r.bytesData = []byte(realValue) + case int64: + r.bytesData = []byte(strconv.FormatInt(realValue, 10)) + case []any: + if cmd == "HMGET" { + r.keys = make([]string, len(values)-1) + for i, v := range values { + if i > 0 { + r.keys[i-1] = cast.String(v) + } + } + } else if cmd == "MGET" { + r.keys = make([]string, len(values)) + for i, v := range values { + r.keys[i] = cast.String(v) + } + } + + if cmd == "HGETALL" { + r.keys = make([]string, len(realValue)/2) + r.bytesDatas = make([][]byte, len(realValue)/2) + i1, i2 := 0, 0 + for i, v := range realValue { + if v != nil { + if i%2 == 0 { + r.keys[i1] = cast.String(v) + i1++ + } else { + switch subRealValue := v.(type) { + case []byte: + r.bytesDatas[i2] = subRealValue + case string: + r.bytesDatas[i2] = []byte(subRealValue) + default: + r.bytesDatas[i2] = []byte(cast.String(v)) + } + i2++ + } + } + } + } else { + r.bytesDatas = make([][]byte, len(realValue)) + for i, v := range realValue { + if v != nil { + switch subRealValue := v.(type) { + case []byte: + r.bytesDatas[i] = subRealValue + case string: + r.bytesDatas[i] = []byte(subRealValue) + default: + r.bytesDatas[i] = []byte(cast.String(v)) + } + } + } + } + case nil: + r.bytesData = []byte{} + default: + r.bytesData = []byte(cast.String(realValue)) + } + return r +} diff --git a/redis_test.go b/redis_test.go new file mode 100644 index 0000000..f7c592f --- /dev/null +++ b/redis_test.go @@ -0,0 +1,104 @@ +package redis_test + +import ( + "os" + "testing" + "time" + + "apigo.cc/go/config" + "apigo.cc/go/redis" +) + +type userInfo struct { + Id int + Name string + Phone string + Time time.Time +} + +func TestBase(t *testing.T) { + os.Setenv("REDIS_TEST", "redis://:@localhost:6379/2?timeout=10ms&logSlow=10us") + _ = config.Load("redis", nil) + + rd := redis.GetRedis("test", nil) + if rd.Error != nil { + t.Fatal("GetRedis error", rd.Error) + } + rd.DEL("redisName", "redisUser", "redisIds") + + r := rd.GET("redisNotExists") + if r.Error != nil && r.String() != "" || r.Int() != 0 { + t.Fatal("GET NotExists", r, r.String(), r.Int()) + } + + exists := rd.EXISTS("redisName") + if exists { + t.Fatal("EXISTS should be false") + } + + rd.SET("redisName", "12345") + + r = rd.GETSET("redisName", 12345) + if r.String() != "12345" { + t.Fatal("GETSET String mismatch", r.String()) + } + if r.Int() != 12345 { + t.Fatal("Int conversion mismatch", r.Int()) + } + + exists = rd.EXISTS("redisName") + if !exists { + t.Fatal("EXISTS should be true") + } + + // Expire test + rd.SET("redisName", "12") + rd.EXPIRE("redisName", 1) + time.Sleep(2 * time.Second) + r = rd.GET("redisName") + if r.Int() > 0 { + t.Fatal("Expired key still exists", r.Int()) + } + + // Struct test + info := userInfo{ + Name: "aaa", + Id: 123, + Time: time.Now().Truncate(time.Second), // Redis JSON might lose precision + } + rd.SET("redisUser", info) + r = rd.GET("redisUser") + var ru userInfo + _ = r.To(&ru) + if ru.Name != info.Name || ru.Id != info.Id || !ru.Time.Equal(info.Time) { + t.Fatalf("Struct mismatch: expected %+v, got %+v", info, ru) + } + + // MSET/MGET test + rd.MSET("redisName", "Sam Lee", "redisIds", []int{1, 2, 3}) + results := rd.MGET("redisName", "redisIds") + if len(results) != 2 || results[0].String() != "Sam Lee" { + t.Fatal("MGET Results mismatch") + } + ria := results[1].Ints() + if len(ria) != 3 || ria[0] != 1 || ria[1] != 2 || ria[2] != 3 { + t.Fatal("MGET Ints mismatch", ria) + } + + num := rd.DEL("redisName", "redisUser", "redisIds") + if num != 3 { + t.Fatal("DEL count mismatch", num) + } +} + +func TestGenerics(t *testing.T) { + rd := redis.GetRedis("test", nil) + rd.SET("gen_test", userInfo{Name: "Generics", Id: 888}) + defer rd.DEL("gen_test") + + r := rd.GET("gen_test") + user := redis.To[userInfo](r) + if user.Name != "Generics" || user.Id != 888 { + t.Fatal("Generics To[T] mismatch", user) + } +} diff --git a/result.go b/result.go new file mode 100644 index 0000000..50535cc --- /dev/null +++ b/result.go @@ -0,0 +1,224 @@ +package redis + +import ( + "encoding/json" + "reflect" + + "apigo.cc/go/cast" +) + +type Result struct { + bytesData []byte + keys []string + bytesDatas [][]byte + Error error +} + +func (rs *Result) Int() int { + return cast.Int(rs.String()) +} +func (rs *Result) Int8() int8 { + return int8(cast.Int64(rs.String())) +} +func (rs *Result) Int16() int16 { + return int16(cast.Int64(rs.String())) +} +func (rs *Result) Int32() int32 { + return int32(cast.Int64(rs.String())) +} +func (rs *Result) Int64() int64 { + return cast.Int64(rs.String()) +} +func (rs *Result) Uint() uint { + return cast.Uint(rs.String()) +} +func (rs *Result) Uint8() uint8 { + return uint8(cast.Uint64(rs.String())) +} +func (rs *Result) Uint16() uint16 { + return uint16(cast.Uint64(rs.String())) +} +func (rs *Result) Uint32() uint32 { + return uint32(cast.Uint64(rs.String())) +} +func (rs *Result) Uint64() uint64 { + return cast.Uint64(rs.String()) +} +func (rs *Result) Float() float32 { + return cast.Float(rs.String()) +} +func (rs *Result) Float64() float64 { + return cast.Float64(rs.String()) +} +func (rs *Result) String() string { + return string(rs.bytes()) +} +func (rs *Result) Bytes() []byte { + return rs.bytes() +} +func (rs *Result) Bool() bool { + return cast.Bool(rs.String()) +} + +func (rs *Result) Ints() []int { + if rs.bytesDatas != nil { + r := make([]int, len(rs.bytesDatas)) + for i, v := range rs.bytesDatas { + r[i] = cast.Int(string(v)) + } + return r + } else if rs.bytesData != nil { + var r []int + _ = rs.To(&r) + return r + } + return []int{} +} + +func (rs *Result) Strings() []string { + if rs.bytesDatas != nil { + r := make([]string, len(rs.bytesDatas)) + for i, v := range rs.bytesDatas { + r[i] = string(v) + } + return r + } else if rs.bytesData != nil { + var r []string + _ = rs.To(&r) + return r + } + return []string{} +} + +func (rs *Result) Results() []Result { + if rs.bytesDatas != nil { + r := make([]Result, len(rs.bytesDatas)) + for i, v := range rs.bytesDatas { + r[i].bytesData = v + } + return r + } else if rs.bytesData != nil { + var m []string + _ = rs.To(&m) + r := make([]Result, len(m)) + for k, v := range m { + r[k] = Result{bytesData: []byte(v)} + } + return r + } + return []Result{} +} + +func (rs *Result) ResultMap() map[string]*Result { + if rs.bytesDatas != nil && rs.keys != nil { + r := make(map[string]*Result) + n := len(rs.bytesDatas) + for i, k := range rs.keys { + if i < n { + r[k] = &Result{bytesData: rs.bytesDatas[i]} + } + } + return r + } else if rs.bytesData != nil { + r := make(map[string]*Result) + var m map[string]string + _ = rs.To(&m) + for k, v := range m { + r[k] = &Result{bytesData: []byte(v)} + } + return r + } + return map[string]*Result{} +} + +func (rs *Result) StringMap() map[string]string { + rm := rs.ResultMap() + m := make(map[string]string) + for k, r := range rm { + m[k] = r.String() + } + return m +} + +func (rs *Result) IntMap() map[string]int { + rm := rs.ResultMap() + m := make(map[string]int) + for k, r := range rm { + m[k] = r.Int() + } + return m +} + +func (rs *Result) bytes() []byte { + if rs.bytesData != nil { + return rs.bytesData + } else if rs.bytesDatas != nil { + return cast.MustJSONBytes(rs.Strings()) + } + return []byte{} +} + +func (rs *Result) ToValue(t reflect.Type) reflect.Value { + v := reflect.New(t).Elem() + switch t.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + v.SetInt(rs.Int64()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + v.SetUint(rs.Uint64()) + case reflect.Float32, reflect.Float64: + v.SetFloat(rs.Float64()) + case reflect.Bool: + v.SetBool(rs.Bool()) + case reflect.String: + v.SetString(rs.String()) + case reflect.Map, reflect.Slice, reflect.Struct: + _ = rs.To(v.Addr().Interface()) + } + return v +} + +// To 使用泛型进行反序列化 +func To[T any](rs *Result) T { + var out T + _ = rs.To(&out) + return out +} + +func (rs *Result) To(result any) error { + if rs.bytesData != nil { + if len(rs.bytesData) > 0 { + // 优先使用 json.Unmarshal 以支持自定义 Unmarshaler (如 time.Time) + return json.Unmarshal(rs.bytesData, result) + } + return nil + } + + t := reflect.TypeOf(result) + v := reflect.ValueOf(result) + if t.Kind() == reflect.Ptr { + t = t.Elem() + v = v.Elem() + } + + if (t.Kind() == reflect.Struct || t.Kind() == reflect.Map) && rs.keys != nil && rs.bytesDatas != nil { + rm := rs.ResultMap() + for k, r := range rm { + if t.Kind() == reflect.Struct { + k = cast.GetUpperName(k) + sf, found := t.FieldByName(k) + if found { + v.FieldByName(k).Set(r.ToValue(sf.Type)) + } + } else if t.Kind() == reflect.Map { + v.SetMapIndex(reflect.ValueOf(k), r.ToValue(t.Elem())) + } + } + } else if t.Kind() == reflect.Slice && t.Elem().Kind() != reflect.Uint8 && rs.bytesDatas != nil { + results := rs.Results() + for _, r := range results { + v = reflect.Append(v, r.ToValue(t.Elem())) + } + reflect.ValueOf(result).Elem().Set(v) + } + return nil +} diff --git a/subscribe.go b/subscribe.go new file mode 100644 index 0000000..6e73ccb --- /dev/null +++ b/subscribe.go @@ -0,0 +1,147 @@ +package redis + +import ( + "strings" + "time" + + "github.com/gomodule/redigo/redis" +) + +func (rd *Redis) Subscribe(name string, reset func(), received func([]byte)) bool { + if rd.subs == nil { + rd.subs = make(map[string]*SubCallbacks) + } + rd.subs[name] = &SubCallbacks{reset: reset, received: received} + if rd.subConn != nil { + err := rd.subConn.Subscribe(name) + if err != nil { + rd.logger.Error(err.Error(), "type", "redis", "action", "subscribe", "name", name) + } else { + return true + } + } + return false +} + +func (rd *Redis) Unsubscribe(name string) bool { + if rd.subs != nil { + delete(rd.subs, name) + } + if rd.subConn != nil { + err := rd.subConn.Unsubscribe(name) + if err != nil { + rd.logger.Error(err.Error(), "type", "redis", "action", "unsubscribe", "name", name) + } else { + return true + } + } + return false +} + +func (rd *Redis) Start() { + if rd.subs == nil { + rd.subs = make(map[string]*SubCallbacks) + } + rd.SubRunning = true + subStartChan := make(chan bool) + go rd.receiveSub(subStartChan) + <-subStartChan +} + +func (rd *Redis) receiveSub(subStartChan chan bool) { + for { + if !rd.SubRunning { + break + } + + // 开始接收订阅数据 + if rd.subConn == nil { + conn, err := rd.GetConnection() + if err != nil { + time.Sleep(time.Second) + continue + } + rd.subConn = &redis.PubSubConn{Conn: conn} + // 重新订阅 + if len(rd.subs) > 0 { + subs := make([]any, 0) + for k := range rd.subs { + subs = append(subs, k) + } + err = rd.subConn.Subscribe(subs...) + if err != nil { + _ = rd.subConn.Close() + rd.subConn = nil + time.Sleep(time.Second) + continue + } + // 重新连接时调用重置数据的回调 + for _, v := range rd.subs { + if v.reset != nil { + v.reset() + } + } + } + } + + if subStartChan != nil { + subStartChan <- true + subStartChan = nil + } + + for { + isErr := false + receiveObj := rd.subConn.Receive() + switch v := receiveObj.(type) { + case redis.Message: + callback := rd.subs[v.Channel] + if callback != nil && callback.received != nil { + callback.received(v.Data) + } + case redis.Subscription: + case redis.Pong: + case error: + if strings.Contains(v.Error(), "i/o timeout") { + break + } + if !strings.Contains(v.Error(), "connection closed") && !strings.Contains(v.Error(), "use of closed network connection") { + rd.logger.Error(v.Error(), "type", "redis", "action", "receiveSub") + } + if rd.subConn != nil { + _ = rd.subConn.Close() + rd.subConn = nil + } + isErr = true + } + if isErr || !rd.SubRunning { + break + } + } + } + if rd.subStopChan != nil { + rd.subStopChan <- true + } +} + +func (rd *Redis) Stop() { + if rd.SubRunning { + rd.subStopChan = make(chan bool) + rd.SubRunning = false + if rd.subConn != nil { + // 取消订阅 + if len(rd.subs) > 0 { + _ = rd.subConn.Unsubscribe() + } + // 读一次再关闭可以防止Close时阻塞 + _ = rd.subConn.ReceiveWithTimeout(50 * time.Millisecond) + _ = rd.subConn.Close() + rd.subConn = nil + } + <-rd.subStopChan + rd.subStopChan = nil + } +} + +func (rd *Redis) PUBLISH(channel, data string) bool { + return rd.Do("PUBLISH "+channel, data).Bool() +} diff --git a/subscribe_test.go b/subscribe_test.go new file mode 100644 index 0000000..d4e963c --- /dev/null +++ b/subscribe_test.go @@ -0,0 +1,56 @@ +package redis_test + +import ( + "sync/atomic" + "testing" + "time" + + "apigo.cc/go/redis" +) + +func TestSub(t *testing.T) { + rd := redis.GetRedis("test", nil) + rd.Start() + defer rd.Stop() + + var aaa atomic.Value + var bbb atomic.Value + + rd.Subscribe("aaa", nil, func(s []byte) { + aaa.Store(string(s)) + }) + + rd.PUBLISH("aaa", "111") + time.Sleep(100 * time.Millisecond) + + val := aaa.Load() + if val == nil || val.(string) != "111" { + t.Fatal("Subscribe aaa failed", val) + } + + rd.Subscribe("bbb", nil, func(s []byte) { + bbb.Store(string(s)) + }) + + rd.PUBLISH("bbb", "222") + time.Sleep(100 * time.Millisecond) + + val = bbb.Load() + if val == nil || val.(string) != "222" { + t.Fatal("Subscribe bbb failed", val) + } + + rd.Unsubscribe("aaa") + rd.PUBLISH("aaa", "1111") + rd.PUBLISH("bbb", "2222") + time.Sleep(100 * time.Millisecond) + + val = aaa.Load() + if val == nil || val.(string) != "111" { + t.Fatal("Unsubscribe aaa failed: value updated", val) + } + val = bbb.Load() + if val == nil || val.(string) != "2222" { + t.Fatal("Subscribe bbb update failed", val) + } +}