feat: migrate redis module from ssgo to apigo.cc/go standard

This commit is contained in:
AI Engineer 2026-05-03 08:43:23 +08:00
parent c32edec7b8
commit 69324d566e
12 changed files with 1420 additions and 2 deletions

16
CHANGELOG.md Normal file
View File

@ -0,0 +1,16 @@
# CHANGELOG - redis
## v1.0.0 (2026-05-03)
- **Repo Migration**: 从 `@ssgo/redis` 迁移至 `apigo.cc/go/redis`
- **Standard Realignment**:
- 依赖全面切换至 `apigo.cc/go/*` 标准库。
- 适配 Go 1.25.0。
- **Feature Enhancements**:
- **Generics**: 为 `Result` 引入泛型 `To[T]` 支持,消除类型断言摩擦。
- **Memory Safety**: 集成 `go/safe` 对 Redis 密码进行实时加解密与内存锁定,防止内存泄漏敏感信息。
- **Auto-Serialization**: 优化了 `Do` 方法,支持对 Struct/Map/Slice 自动进行 JSON 序列化(优先支持 Marshaler 接口)。
- **Distributed ID**: 深度集成 `go/id` 核心,提供更高性能的 Redis 序列号预取机制。
- **Refactoring**:
- 移除了冗余的 `interface{}`,全面改用 `any`
- 规范化了 API 命名,统一使用 `GetUpperName` 进行 Struct 字段映射。
- 增强了连接重试机制,支持对网络波动和服务器重启的自动恢复。

View File

@ -1,3 +1,56 @@
# redis
# Redis 模块 (redis)
高性能 Redis 客户端,集成分布式 ID 生成与发布订阅
`redis` 模块提供了一个高性能、内存安全且易于使用的 Redis 客户端,集成了分布式 ID 生成和发布订阅功能。
## 设计特性
- **消除摩擦**: 自动处理连接池、重试、以及复杂类型的 JSON 序列化。
- **内存安全**: 对 Redis 密码进行内存保护,防止内存 Dump 泄露敏感信息。
- **泛型支持**: 结果集支持泛型绑定 `To[T]`
- **分布式 ID**: 内置基于 Redis 的高性能分布式 ID 生成器。
## API 指南
### 基础连接
- `GetRedis(name string, logger *log.Logger) *Redis`: 获取或创建一个 Redis 实例(支持 DSN 或配置文件名)。
- `NewRedis(conf *Config, logger *log.Logger) *Redis`: 使用指定配置创建 Redis 实例。
### 核心操作
- `Do(cmd string, values ...any) *Result`: 执行原生 Redis 命令。
- 支持大部分标准命令:`GET`, `SET`, `HGET`, `HSET`, `LPUSH`, `LPOP`, `SADD`, `ZRANGE`, `PUBLISH` 等。
### 结果处理 (Result)
- `Int()`, `String()`, `Bool()`, `Float()`: 基础类型转换。
- `To(target any)`: 将结果反序列化到结构体或 Map。
- `To[T](rs *Result) T`: **泛型版本**,直接返回目标类型对象。
- `Ints()`, `Strings()`, `Results()`: 处理数组结果。
- `ResultMap()`: 处理 Hash 或键值对结果。
### 发布订阅
- `Start()`: 开启订阅监听协程。
- `Subscribe(channel string, reset func(), received func([]byte))`: 订阅频道。
- `Unsubscribe(channel string)`: 取消订阅。
- `Stop()`: 停止所有订阅。
### 分布式 ID (IdMaker)
- `NewIdMaker(rd *Redis) *IdMaker`: 创建分布式 ID 生成器。
- `Get(size int)`: 获取指定长度的唯一 ID。
- `GetForMysql(size int)`: 获取针对 MySQL 优化的唯一 ID。
## 示例
```go
import "apigo.cc/go/redis"
// 获取实例 (自动从 redis.json 加载)
rd := redis.GetRedis("test", nil)
// 设置结构体 (自动 JSON 序列化)
rd.SET("user:1", User{Name: "Sam", Age: 18})
// 获取并自动绑定 (泛型)
user := redis.To[User](rd.GET("user:1"))
// 分布式 ID
maker := redis.NewIdMaker(rd)
id := maker.Get(10)
```

165
commands.go Normal file
View File

@ -0,0 +1,165 @@
package redis
func stringsToAnys(in []string) []any {
a := make([]any, len(in))
for i, v := range in {
a[i] = v
}
return a
}
func (rd *Redis) DEL(keys ...string) int {
return rd.Do("DEL", stringsToAnys(keys)...).Int()
}
func (rd *Redis) EXISTS(key string) bool {
return rd.Do("EXISTS " + key).Bool()
}
func (rd *Redis) EXPIRE(key string, second int) bool {
if second > 315360000 {
return rd.Do("EXPIREAT "+key, second).Bool()
} else {
return rd.Do("EXPIRE "+key, second).Bool()
}
}
func (rd *Redis) KEYS(patten string) []string {
return rd.Do("KEYS " + patten).Strings()
}
func (rd *Redis) GET(key string) *Result {
return rd.Do("GET " + key)
}
func (rd *Redis) SET(key string, value any) bool {
return rd.Do("SET "+key, value).Bool()
}
func (rd *Redis) SETEX(key string, seconds int, value any) bool {
return rd.Do("SETEX "+key, seconds, value).Bool()
}
func (rd *Redis) SETNX(key string, value any) bool {
return rd.Do("SETNX "+key, value).Bool()
}
func (rd *Redis) GETSET(key string, value any) *Result {
return rd.Do("GETSET "+key, value)
}
func (rd *Redis) INCR(key string) int64 {
return rd.Do("INCR " + key).Int64()
}
func (rd *Redis) INCRBY(key string, delta int64) int64 {
return rd.Do("INCRBY "+key, delta).Int64()
}
func (rd *Redis) DECR(key string, delta int64) int64 {
return rd.Do("DECR "+key, delta).Int64()
}
func (rd *Redis) DECRBY(key string, delta int64) int64 {
return rd.Do("DECRBY "+key, delta).Int64()
}
func (rd *Redis) MGET(keys ...string) []Result {
return rd.Do("MGET", stringsToAnys(keys)...).Results()
}
func (rd *Redis) MSET(keyAndValues ...any) bool {
return rd.Do("MSET", keyAndValues...).Bool()
}
func (rd *Redis) HGET(key, field string) *Result {
return rd.Do("HGET "+key, field)
}
func (rd *Redis) HSET(key, field string, value any) bool {
return rd.Do("HSET "+key, field, value).Error == nil
}
func (rd *Redis) HSETNX(key, field string, value any) bool {
return rd.Do("HSETNX "+key, field, value).Error == nil
}
func (rd *Redis) HMGET(key string, fields ...string) []Result {
return rd.Do("HMGET", append(append([]any{}, key), stringsToAnys(fields)...)...).Results()
}
func (rd *Redis) HGETALL(key string) map[string]*Result {
return rd.Do("HGETALL " + key).ResultMap()
}
func (rd *Redis) HMSET(key string, fieldAndValues ...any) bool {
return rd.Do("HMSET", append(append([]any{}, key), fieldAndValues...)...).Bool()
}
func (rd *Redis) HKEYS(key string) []string {
return rd.Do("HKEYS " + key).Strings()
}
func (rd *Redis) HLEN(key string) int {
return rd.Do("HLEN " + key).Int()
}
func (rd *Redis) HDEL(key string, fields ...string) int {
return rd.Do("HDEL", append(append([]any{}, key), stringsToAnys(fields)...)...).Int()
}
func (rd *Redis) HEXISTS(key, field string) bool {
return rd.Do("HEXISTS "+key, field).Bool()
}
func (rd *Redis) HINCR(key, field string) int64 {
return rd.Do("HINCRBY "+key, field, 1).Int64()
}
func (rd *Redis) HINCRBY(key, field string, delta int64) int64 {
return rd.Do("HINCRBY "+key, field, delta).Int64()
}
func (rd *Redis) HDECR(key, field string) int64 {
return rd.Do("HDECRBY "+key, field, 1).Int64()
}
func (rd *Redis) HDECRBY(key, field string, delta int64) int64 {
return rd.Do("HDECRBY "+key, field, delta).Int64()
}
func (rd *Redis) LPUSH(key string, values ...string) int {
return rd.Do("LPUSH", append(append([]any{}, key), stringsToAnys(values)...)...).Int()
}
func (rd *Redis) RPUSH(key string, values ...string) int {
return rd.Do("RPUSH", append(append([]any{}, key), stringsToAnys(values)...)...).Int()
}
func (rd *Redis) LPOP(key string) *Result {
return rd.Do("LPOP " + key)
}
func (rd *Redis) RPOP(key string) *Result {
return rd.Do("RPOP " + key)
}
func (rd *Redis) LLEN(key string) int {
return rd.Do("LLEN " + key).Int()
}
func (rd *Redis) LRANGE(key string, start, stop int) []Result {
return rd.Do("LRANGE "+key, start, stop).Results()
}
func (rd *Redis) SADD(key string, values ...any) int {
return rd.Do("SADD", append([]any{key}, values...)...).Int()
}
func (rd *Redis) SREM(key string, values ...any) int {
return rd.Do("SREM", append([]any{key}, values...)...).Int()
}
func (rd *Redis) SCARD(key string) int {
return rd.Do("SCARD " + key).Int()
}
func (rd *Redis) SMEMBERS(key string) []Result {
return rd.Do("SMEMBERS " + key).Results()
}
func (rd *Redis) SISMEMBER(key string, value any) bool {
return rd.Do("SISMEMBER "+key, value).Bool()
}
func (rd *Redis) ZADD(key string, score float64, member any) bool {
return rd.Do("ZADD "+key, score, member).Bool()
}
func (rd *Redis) ZREM(key string, members ...any) int {
return rd.Do("ZREM", append([]any{key}, members...)...).Int()
}
func (rd *Redis) ZCARD(key string) int {
return rd.Do("ZCARD " + key).Int()
}
func (rd *Redis) ZRANGE(key string, start, stop int) []Result {
return rd.Do("ZRANGE "+key, start, stop).Results()
}
func (rd *Redis) ZREVRANGE(key string, start, stop int) []Result {
return rd.Do("ZREVRANGE "+key, start, stop).Results()
}
func (rd *Redis) ZRANK(key string, member any) int {
return rd.Do("ZRANK "+key, member).Int()
}
func (rd *Redis) ZREVRANK(key string, member any) int {
return rd.Do("ZREVRANK "+key, member).Int()
}
func (rd *Redis) ZSCORE(key string, member any) float64 {
return rd.Do("ZSCORE "+key, member).Float64()
}

139
config.go Normal file
View File

@ -0,0 +1,139 @@
package redis
import (
"fmt"
"net/url"
"strconv"
"sync"
"time"
"apigo.cc/go/cast"
"apigo.cc/go/crypto"
"apigo.cc/go/log"
"apigo.cc/go/safe"
)
type Config struct {
Host string
Password string
DB int
MaxActive int
MaxIdle int
IdleTimeout time.Duration
ConnectTimeout time.Duration
ReadTimeout time.Duration
WriteTimeout time.Duration
LogSlow time.Duration
logger *log.Logger
pwd *safe.SafeBuf
}
var redisConfigs = make(map[string]*Config)
var redisConfigsLock = sync.RWMutex{}
var confAes, _ = crypto.NewAESGCMAndEraseKey([]byte("?GQ$0K0GgLdO=f+~L68PLm$uhKr4'=tV"), []byte("VFs7@sK61cj^f?HZ"))
var keysOnce = sync.Once{}
func SetEncryptKeys(key, iv []byte) {
keysOnce.Do(func() {
confAes.Close()
confAes, _ = crypto.NewAESGCMAndEraseKey(key, iv)
})
}
func (conf *Config) ConfigureBy(setting string) {
redisConfigsLock.Lock()
redisConfigs[setting] = conf
redisConfigsLock.Unlock()
urlInfo, err := url.Parse(setting)
if err != nil {
conf.logger.Error(err.Error(), "url", setting)
return
}
if urlInfo.Scheme != "redis" {
conf.logger.Error("unsupported scheme", "url", setting)
return
}
conf.Host = urlInfo.Host
dbStr := urlInfo.Query().Get("database")
if dbStr == "" && len(urlInfo.Path) > 1 {
dbStr = urlInfo.Path[1:]
}
if len(dbStr) > 0 {
db, err := strconv.Atoi(dbStr)
if err != nil {
conf.logger.Error(err.Error(), "url", setting)
}
if err == nil && db >= 0 && db <= 15 {
conf.DB = db
}
}
conf.Password, _ = urlInfo.User.Password()
conf.LogSlow = cast.Duration(urlInfo.Query().Get("logSlow"))
conf.MaxIdle = cast.Int(urlInfo.Query().Get("maxIdle"))
conf.MaxActive = cast.Int(urlInfo.Query().Get("maxActive"))
conf.ConnectTimeout = cast.Duration(urlInfo.Query().Get("connectTimeout"))
conf.ReadTimeout = cast.Duration(urlInfo.Query().Get("readTimeout"))
conf.WriteTimeout = cast.Duration(urlInfo.Query().Get("writeTimeout"))
conf.IdleTimeout = cast.Duration(urlInfo.Query().Get("idleTimeout"))
}
func (conf *Config) Dsn() string {
return fmt.Sprintf("redis://:****@%s/%d?timeout=%s&logSlow=%s", conf.Host, conf.DB, conf.ConnectTimeout, conf.LogSlow)
}
func parseByName(name string) *Config {
// config name support Host:Port
args := cast.Split(name, ":")
db := 0
if len(args) > 1 {
arg1, err := strconv.Atoi(args[1])
if err == nil && arg1 >= 0 && arg1 <= 15 {
name = args[0]
db = arg1
}
}
redisConfigsLock.RLock()
conf := redisConfigs[name]
redisConfigsLock.RUnlock()
if conf == nil {
conf = new(Config)
redisConfigsLock.Lock()
redisConfigs[name] = conf
redisConfigsLock.Unlock()
if len(args) > 1 {
arg1, err := strconv.Atoi(args[1])
if err == nil && arg1 >= 0 && arg1 <= 15 {
conf.DB = arg1
} else {
conf.Host = args[0] + ":" + args[1]
}
}
}
for i := 2; i < len(args); i++ {
arg2, err := strconv.Atoi(args[i])
if err == nil {
if arg2 >= 0 && arg2 <= 15 {
conf.DB = arg2
} else {
conf.Password = args[i]
}
} else {
conf.Password = args[i]
}
}
if conf.DB == 0 && db > 0 && db <= 15 {
conf.DB = db
}
return conf
}

24
go.mod Normal file
View File

@ -0,0 +1,24 @@
module apigo.cc/go/redis
go 1.25.0
require (
apigo.cc/go/cast v1.1.1
apigo.cc/go/config v1.0.4
apigo.cc/go/crypto v1.0.4
apigo.cc/go/id v1.0.4
apigo.cc/go/log v1.0.0
apigo.cc/go/safe v1.0.4
github.com/gomodule/redigo v1.9.3
)
require (
apigo.cc/go/convert v1.0.4 // indirect
apigo.cc/go/encoding v1.0.4 // indirect
apigo.cc/go/file v1.0.4 // indirect
apigo.cc/go/rand v1.0.4 // indirect
apigo.cc/go/shell v1.0.4 // indirect
golang.org/x/crypto v0.50.0 // indirect
golang.org/x/sys v0.43.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

38
go.sum Normal file
View File

@ -0,0 +1,38 @@
apigo.cc/go/cast v1.1.1 h1:+5pluN8g1RK2J4byr2xkfOmEdKSmy1PByOqDOHtt/Ns=
apigo.cc/go/cast v1.1.1/go.mod h1:vh9ZqISCmTUiyinkNMI/s4f045fRlDK3xC+nPWQYBzI=
apigo.cc/go/config v1.0.4 h1:WG9zrQkqfFPkrKIL7RNvvAbbkuUBt1Av11ZP/aIfldM=
apigo.cc/go/config v1.0.4/go.mod h1:obryzJiK6j7lQex/58d5eWYOGx5O5IABguqNWxyyXJo=
apigo.cc/go/convert v1.0.4 h1:5+qPjC3dlPB59GnWZRlmthxcaXQtKvN+iOuiLdJ1GvQ=
apigo.cc/go/convert v1.0.4/go.mod h1:Hp+geeSyhqg/zwIKPOrDoceIREzcwM14t1I5q/dtbfU=
apigo.cc/go/crypto v1.0.4 h1:VPUyHCH2N3LLEgdpwUc+DQssNHzLlxVzLNRa0Jm6O4o=
apigo.cc/go/crypto v1.0.4/go.mod h1:5sI8BLw6YHZfDReYwCO3TFD2LKm36HMdLg1S5oPv/QU=
apigo.cc/go/encoding v1.0.4 h1:aezB0J/qFuHs6iXkbtuJP5JIHUtmjsr5SFb0NNvbObY=
apigo.cc/go/encoding v1.0.4/go.mod h1:V5CgT7rBbCxy+uCU20q0ptcNNRSgMtpA8cNOs6r8IeI=
apigo.cc/go/file v1.0.4 h1:qCKegV7OYh7r0qc3jZjGA/aKh0vIHgmr1OEbhfEmGX8=
apigo.cc/go/file v1.0.4/go.mod h1:C9gNo7386iA21OiBmuWh6CznKWlVBDFkhE4f0H0Susg=
apigo.cc/go/id v1.0.4 h1:w+JSdeVit52iefIUolrh1qLEZS9XqHNKr1UygFcgv+s=
apigo.cc/go/id v1.0.4/go.mod h1:kg7QuceAKtGNzGWt0+pIIh8Qom1eMSWGb8+0Yhi/QVY=
apigo.cc/go/log v1.0.0 h1:lI1NGTSS+Jm12G8BD7ZJO4/hrkfuLTu5O8z36GD8GpU=
apigo.cc/go/log v1.0.0/go.mod h1:tvPgFpebY9Wf/DlqMHZ0ZjxDp9AaQTywOQKvtBaNqNo=
apigo.cc/go/rand v1.0.4 h1:we070eWSL0dB8NEMaWjXj43+EekXQTm/h0kKpZ/frqw=
apigo.cc/go/rand v1.0.4/go.mod h1:mZ/4Soa3bk+XvDaqPWJuUe1bfEi4eThBj1XmEAuYxsk=
apigo.cc/go/safe v1.0.4 h1:07pRSdEHprF/2v6SsqAjICYFoeLcqjjvHGEdh6Dzrzg=
apigo.cc/go/safe v1.0.4/go.mod h1:o568sHS5rTRSVPmhxWod0tGdc+8l1KjidsNY1/OVZr0=
apigo.cc/go/shell v1.0.4 h1:EL9zjI39YBe1h+kRYQeAi/8zVGHe5W198DYYN7cENiY=
apigo.cc/go/shell v1.0.4/go.mod h1:N2gDkgK4tJ9TadD60/+gAGuWxyVAWHs5YPBmytw6ELA=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/gomodule/redigo v1.9.3 h1:dNPSXeXv6HCq2jdyWfjgmhBdqnR6PRO3m/G05nvpPC8=
github.com/gomodule/redigo v1.9.3/go.mod h1:KsU3hiK/Ay8U42qpaJk+kuNa3C+spxapWpM+ywhcgtw=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI=
golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q=
golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI=
golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

62
id.go Normal file
View File

@ -0,0 +1,62 @@
package redis
import (
"fmt"
"sync"
"apigo.cc/go/id"
)
type IdMaker struct {
rd *Redis
secCurrent uint64
secIndexMax uint64
secIndexNext uint64
lock sync.Mutex
maker *id.IDMaker
}
func NewIdMaker(rd *Redis) *IdMaker {
im := &IdMaker{rd: rd}
im.maker = id.NewIDMaker(im.makeSecIndex)
return im
}
func (im *IdMaker) makeSecIndex(sec uint64) uint64 {
im.lock.Lock()
defer im.lock.Unlock()
if im.secCurrent == sec && im.secIndexNext <= im.secIndexMax {
idx := im.secIndexNext
im.secIndexNext++
return idx
}
im.secCurrent = sec
key := fmt.Sprintf("_SecIdx_%d", sec)
// 每次从 Redis 预取 100 个序列号
max := uint64(im.rd.INCRBY(key, 100))
if max < 100 {
return 0
}
im.secIndexMax = max
im.secIndexNext = max - 99
idx := im.secIndexNext
im.secIndexNext++
if max <= 100 {
im.rd.EXPIRE(key, 10)
}
return idx
}
func (im *IdMaker) Get(size int) string {
return im.maker.Get(size)
}
func (im *IdMaker) GetForMysql(size int) string {
return im.maker.GetForMysql(size)
}
func (im *IdMaker) GetForPostgreSQL(size int) string {
return im.maker.GetForPostgreSQL(size)
}

390
redis.go Normal file
View File

@ -0,0 +1,390 @@
package redis
import (
"encoding/json"
"errors"
"io"
"net"
"reflect"
"strconv"
"strings"
"sync"
"syscall"
"time"
"apigo.cc/go/cast"
"apigo.cc/go/config"
"apigo.cc/go/log"
"apigo.cc/go/safe"
"github.com/gomodule/redigo/redis"
)
type Redis struct {
name string
pool *redis.Pool
ReadTimeout int
Config *Config
logger *log.Logger
Error error
subConn *redis.PubSubConn
subStopChan chan bool
subs map[string]*SubCallbacks
SubRunning bool
}
type SubCallbacks struct {
received func([]byte)
reset func()
}
var redisInstances = make(map[string]*Redis)
var redisInstancesLock = sync.RWMutex{}
func GetRedis(name string, logger *log.Logger) *Redis {
if logger == nil {
logger = log.DefaultLogger
}
redisInstancesLock.RLock()
oldConn := redisInstances[name]
redisInstancesLock.RUnlock()
if oldConn != nil {
return oldConn.CopyByLogger(logger)
}
redisConfigsLock.RLock()
configsLen := len(redisConfigs)
redisConfigsLock.RUnlock()
if configsLen == 0 {
_ = config.Load("redis", &redisConfigs)
}
fullName := name
var conf *Config
if strings.HasPrefix(name, "redis://") {
conf = new(Config)
conf.logger = logger
conf.ConfigureBy(name)
} else {
conf = parseByName(name)
}
if pwd, err := confAes.Decrypt([]byte(conf.Password)); err == nil {
conf.pwd = pwd
} else {
conf.pwd = safe.NewSafeBuf([]byte(conf.Password))
}
conf.Password = ""
if conf.Host == "" {
conf.Host = "127.0.0.1:6379"
}
if conf.MaxIdle == 0 {
conf.MaxIdle = 20
}
if conf.MaxActive == 0 {
conf.MaxActive = 100
}
if conf.ConnectTimeout == 0 {
conf.ConnectTimeout = 10 * time.Second
}
if conf.ReadTimeout == 0 {
conf.ReadTimeout = 10 * time.Second
}
if conf.WriteTimeout == 0 {
conf.WriteTimeout = 10 * time.Second
}
if conf.LogSlow == 0 {
conf.LogSlow = 100 * time.Millisecond
}
rd := NewRedis(conf, logger)
rd.name = fullName
redisInstancesLock.Lock()
redisInstances[fullName] = rd
redisInstancesLock.Unlock()
return rd.CopyByLogger(logger)
}
func NewRedis(conf *Config, logger *log.Logger) *Redis {
if logger == nil {
logger = log.DefaultLogger
}
conn := &redis.Pool{
MaxIdle: conf.MaxIdle,
MaxActive: conf.MaxActive,
IdleTimeout: conf.IdleTimeout,
Dial: func() (redis.Conn, error) {
opts := []redis.DialOption{
redis.DialConnectTimeout(conf.ConnectTimeout),
redis.DialReadTimeout(conf.ReadTimeout),
redis.DialWriteTimeout(conf.WriteTimeout),
redis.DialDatabase(conf.DB),
}
if conf.pwd != nil {
pwdBuf := conf.pwd.Open()
defer pwdBuf.Close()
opts = append(opts, redis.DialPassword(pwdBuf.String()))
}
c, err := redis.Dial("tcp", conf.Host, opts...)
if err != nil {
logger.Error(err.Error(), "type", "redis", "dsn", conf.Dsn())
return nil, err
}
return c, nil
},
}
rd := new(Redis)
rd.ReadTimeout = int(conf.ReadTimeout / time.Millisecond)
rd.pool = conn
rd.Config = conf
rd.logger = logger
return rd
}
func (rd *Redis) CopyByLogger(logger *log.Logger) *Redis {
newRedis := new(Redis)
newRedis.name = rd.name
newRedis.ReadTimeout = rd.ReadTimeout
newRedis.pool = rd.pool
newRedis.subConn = rd.subConn
newRedis.subs = rd.subs
newRedis.SubRunning = rd.SubRunning
newRedis.Config = rd.Config
if logger == nil {
newRedis.logger = log.DefaultLogger
} else {
newRedis.logger = logger
}
return newRedis
}
func (rd *Redis) SetLogger(logger *log.Logger) {
rd.logger = logger
}
func (rd *Redis) GetLogger() *log.Logger {
return rd.logger
}
func (rd *Redis) Destroy() error {
if rd.pool == nil {
return errors.New("operate on a bad redis pool")
}
err := rd.pool.Close()
if err != nil {
rd.logger.Error(err.Error(), "type", "redis", "dsn", rd.Config.Dsn())
}
redisInstancesLock.Lock()
delete(redisInstances, rd.name)
redisInstancesLock.Unlock()
return err
}
func (rd *Redis) GetPool() *redis.Pool {
return rd.pool
}
func (rd *Redis) GetNewConnection() (redis.Conn, error) {
if rd.pool == nil {
return nil, errors.New("redis pool is not initialized")
}
c, err := rd.pool.Dial()
if err == nil {
err = c.Err()
}
if err != nil {
if c != nil {
_ = c.Close()
}
return nil, err
}
return c, nil
}
func (rd *Redis) GetConnection() (redis.Conn, error) {
if rd.pool == nil {
return nil, errors.New("redis pool is not initialized")
}
c := rd.pool.Get()
err := c.Err()
if err != nil {
if c != nil {
_ = c.Close()
}
return nil, err
}
return c, nil
}
func shouldRetry(err error) bool {
if err == nil {
return false
}
// 网络错误
if errors.Is(err, io.EOF) || errors.Is(err, syscall.ECONNRESET) {
return true
}
// 超时错误
if opErr, ok := err.(*net.OpError); ok && opErr.Timeout() {
return true
}
// Redis特定可恢复错误
if errs, ok := err.(redis.Error); ok {
switch {
case strings.HasPrefix(string(errs), "LOADING"),
strings.HasPrefix(string(errs), "CLUSTERDOWN"):
return true
}
}
return false
}
func (rd *Redis) Do(cmd string, values ...any) *Result {
if rd.pool == nil {
err := errors.New("operate on a bad redis pool")
rd.logger.Error(err.Error(), "type", "redis", "dsn", rd.Config.Dsn(), "cmd", cmd)
return &Result{Error: err}
}
startTime := time.Now()
r := rd.do(cmd, values...)
usedTime := time.Since(startTime)
if r.Error == nil {
if rd.Config.LogSlow > 0 && usedTime >= rd.Config.LogSlow {
rd.logger.Debug("redis slow query", "dsn", rd.Config.Dsn(), "cmd", cmd, "args", values, "used", usedTime.String())
}
} else {
rd.logger.Error(r.Error.Error(), "type", "redis", "dsn", rd.Config.Dsn(), "cmd", cmd, "args", values, "used", usedTime.String())
}
return r
}
func (rd *Redis) do(cmd string, values ...any) *Result {
cmdArr := cast.Split(cmd, " ")
if len(cmdArr) > 1 {
cmd = cmdArr[0]
args := make([]any, 0)
for i := 1; i < len(cmdArr); i++ {
args = append(args, cmdArr[i])
}
if len(values) > 0 {
args = append(args, values...)
}
values = args
}
// 自动序列化
for i, v := range values {
if v == nil {
continue
}
rv := reflect.ValueOf(v)
if rv.Kind() == reflect.Ptr {
rv = rv.Elem()
}
kind := rv.Kind()
if kind == reflect.Struct || kind == reflect.Map || (kind == reflect.Slice && rv.Type().Elem().Kind() != reflect.Uint8) {
// 对于这些类型,优先使用 json.Marshal 以支持自定义 Marshaler (如 time.Time)
if encoded, err := json.Marshal(v); err == nil {
values[i] = encoded
}
}
}
// 从连接池获取
conn, err := rd.GetConnection()
var replyData any
if err == nil {
replyData, err = conn.Do(cmd, values...)
_ = conn.Close()
}
if err != nil && shouldRetry(err) {
// 拿全新的连接重试(如果服务器重启可自动恢复)
conn, err = rd.GetNewConnection()
if err == nil {
replyData, err = conn.Do(cmd, values...)
_ = conn.Close()
}
}
if err != nil {
return &Result{Error: err}
}
r := &Result{}
switch realValue := replyData.(type) {
case []byte:
r.bytesData = realValue
case string:
r.bytesData = []byte(realValue)
case int64:
r.bytesData = []byte(strconv.FormatInt(realValue, 10))
case []any:
if cmd == "HMGET" {
r.keys = make([]string, len(values)-1)
for i, v := range values {
if i > 0 {
r.keys[i-1] = cast.String(v)
}
}
} else if cmd == "MGET" {
r.keys = make([]string, len(values))
for i, v := range values {
r.keys[i] = cast.String(v)
}
}
if cmd == "HGETALL" {
r.keys = make([]string, len(realValue)/2)
r.bytesDatas = make([][]byte, len(realValue)/2)
i1, i2 := 0, 0
for i, v := range realValue {
if v != nil {
if i%2 == 0 {
r.keys[i1] = cast.String(v)
i1++
} else {
switch subRealValue := v.(type) {
case []byte:
r.bytesDatas[i2] = subRealValue
case string:
r.bytesDatas[i2] = []byte(subRealValue)
default:
r.bytesDatas[i2] = []byte(cast.String(v))
}
i2++
}
}
}
} else {
r.bytesDatas = make([][]byte, len(realValue))
for i, v := range realValue {
if v != nil {
switch subRealValue := v.(type) {
case []byte:
r.bytesDatas[i] = subRealValue
case string:
r.bytesDatas[i] = []byte(subRealValue)
default:
r.bytesDatas[i] = []byte(cast.String(v))
}
}
}
}
case nil:
r.bytesData = []byte{}
default:
r.bytesData = []byte(cast.String(realValue))
}
return r
}

104
redis_test.go Normal file
View File

@ -0,0 +1,104 @@
package redis_test
import (
"os"
"testing"
"time"
"apigo.cc/go/config"
"apigo.cc/go/redis"
)
type userInfo struct {
Id int
Name string
Phone string
Time time.Time
}
func TestBase(t *testing.T) {
os.Setenv("REDIS_TEST", "redis://:@localhost:6379/2?timeout=10ms&logSlow=10us")
_ = config.Load("redis", nil)
rd := redis.GetRedis("test", nil)
if rd.Error != nil {
t.Fatal("GetRedis error", rd.Error)
}
rd.DEL("redisName", "redisUser", "redisIds")
r := rd.GET("redisNotExists")
if r.Error != nil && r.String() != "" || r.Int() != 0 {
t.Fatal("GET NotExists", r, r.String(), r.Int())
}
exists := rd.EXISTS("redisName")
if exists {
t.Fatal("EXISTS should be false")
}
rd.SET("redisName", "12345")
r = rd.GETSET("redisName", 12345)
if r.String() != "12345" {
t.Fatal("GETSET String mismatch", r.String())
}
if r.Int() != 12345 {
t.Fatal("Int conversion mismatch", r.Int())
}
exists = rd.EXISTS("redisName")
if !exists {
t.Fatal("EXISTS should be true")
}
// Expire test
rd.SET("redisName", "12")
rd.EXPIRE("redisName", 1)
time.Sleep(2 * time.Second)
r = rd.GET("redisName")
if r.Int() > 0 {
t.Fatal("Expired key still exists", r.Int())
}
// Struct test
info := userInfo{
Name: "aaa",
Id: 123,
Time: time.Now().Truncate(time.Second), // Redis JSON might lose precision
}
rd.SET("redisUser", info)
r = rd.GET("redisUser")
var ru userInfo
_ = r.To(&ru)
if ru.Name != info.Name || ru.Id != info.Id || !ru.Time.Equal(info.Time) {
t.Fatalf("Struct mismatch: expected %+v, got %+v", info, ru)
}
// MSET/MGET test
rd.MSET("redisName", "Sam Lee", "redisIds", []int{1, 2, 3})
results := rd.MGET("redisName", "redisIds")
if len(results) != 2 || results[0].String() != "Sam Lee" {
t.Fatal("MGET Results mismatch")
}
ria := results[1].Ints()
if len(ria) != 3 || ria[0] != 1 || ria[1] != 2 || ria[2] != 3 {
t.Fatal("MGET Ints mismatch", ria)
}
num := rd.DEL("redisName", "redisUser", "redisIds")
if num != 3 {
t.Fatal("DEL count mismatch", num)
}
}
func TestGenerics(t *testing.T) {
rd := redis.GetRedis("test", nil)
rd.SET("gen_test", userInfo{Name: "Generics", Id: 888})
defer rd.DEL("gen_test")
r := rd.GET("gen_test")
user := redis.To[userInfo](r)
if user.Name != "Generics" || user.Id != 888 {
t.Fatal("Generics To[T] mismatch", user)
}
}

224
result.go Normal file
View File

@ -0,0 +1,224 @@
package redis
import (
"encoding/json"
"reflect"
"apigo.cc/go/cast"
)
type Result struct {
bytesData []byte
keys []string
bytesDatas [][]byte
Error error
}
func (rs *Result) Int() int {
return cast.Int(rs.String())
}
func (rs *Result) Int8() int8 {
return int8(cast.Int64(rs.String()))
}
func (rs *Result) Int16() int16 {
return int16(cast.Int64(rs.String()))
}
func (rs *Result) Int32() int32 {
return int32(cast.Int64(rs.String()))
}
func (rs *Result) Int64() int64 {
return cast.Int64(rs.String())
}
func (rs *Result) Uint() uint {
return cast.Uint(rs.String())
}
func (rs *Result) Uint8() uint8 {
return uint8(cast.Uint64(rs.String()))
}
func (rs *Result) Uint16() uint16 {
return uint16(cast.Uint64(rs.String()))
}
func (rs *Result) Uint32() uint32 {
return uint32(cast.Uint64(rs.String()))
}
func (rs *Result) Uint64() uint64 {
return cast.Uint64(rs.String())
}
func (rs *Result) Float() float32 {
return cast.Float(rs.String())
}
func (rs *Result) Float64() float64 {
return cast.Float64(rs.String())
}
func (rs *Result) String() string {
return string(rs.bytes())
}
func (rs *Result) Bytes() []byte {
return rs.bytes()
}
func (rs *Result) Bool() bool {
return cast.Bool(rs.String())
}
func (rs *Result) Ints() []int {
if rs.bytesDatas != nil {
r := make([]int, len(rs.bytesDatas))
for i, v := range rs.bytesDatas {
r[i] = cast.Int(string(v))
}
return r
} else if rs.bytesData != nil {
var r []int
_ = rs.To(&r)
return r
}
return []int{}
}
func (rs *Result) Strings() []string {
if rs.bytesDatas != nil {
r := make([]string, len(rs.bytesDatas))
for i, v := range rs.bytesDatas {
r[i] = string(v)
}
return r
} else if rs.bytesData != nil {
var r []string
_ = rs.To(&r)
return r
}
return []string{}
}
func (rs *Result) Results() []Result {
if rs.bytesDatas != nil {
r := make([]Result, len(rs.bytesDatas))
for i, v := range rs.bytesDatas {
r[i].bytesData = v
}
return r
} else if rs.bytesData != nil {
var m []string
_ = rs.To(&m)
r := make([]Result, len(m))
for k, v := range m {
r[k] = Result{bytesData: []byte(v)}
}
return r
}
return []Result{}
}
func (rs *Result) ResultMap() map[string]*Result {
if rs.bytesDatas != nil && rs.keys != nil {
r := make(map[string]*Result)
n := len(rs.bytesDatas)
for i, k := range rs.keys {
if i < n {
r[k] = &Result{bytesData: rs.bytesDatas[i]}
}
}
return r
} else if rs.bytesData != nil {
r := make(map[string]*Result)
var m map[string]string
_ = rs.To(&m)
for k, v := range m {
r[k] = &Result{bytesData: []byte(v)}
}
return r
}
return map[string]*Result{}
}
func (rs *Result) StringMap() map[string]string {
rm := rs.ResultMap()
m := make(map[string]string)
for k, r := range rm {
m[k] = r.String()
}
return m
}
func (rs *Result) IntMap() map[string]int {
rm := rs.ResultMap()
m := make(map[string]int)
for k, r := range rm {
m[k] = r.Int()
}
return m
}
func (rs *Result) bytes() []byte {
if rs.bytesData != nil {
return rs.bytesData
} else if rs.bytesDatas != nil {
return cast.MustJSONBytes(rs.Strings())
}
return []byte{}
}
func (rs *Result) ToValue(t reflect.Type) reflect.Value {
v := reflect.New(t).Elem()
switch t.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
v.SetInt(rs.Int64())
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
v.SetUint(rs.Uint64())
case reflect.Float32, reflect.Float64:
v.SetFloat(rs.Float64())
case reflect.Bool:
v.SetBool(rs.Bool())
case reflect.String:
v.SetString(rs.String())
case reflect.Map, reflect.Slice, reflect.Struct:
_ = rs.To(v.Addr().Interface())
}
return v
}
// To 使用泛型进行反序列化
func To[T any](rs *Result) T {
var out T
_ = rs.To(&out)
return out
}
func (rs *Result) To(result any) error {
if rs.bytesData != nil {
if len(rs.bytesData) > 0 {
// 优先使用 json.Unmarshal 以支持自定义 Unmarshaler (如 time.Time)
return json.Unmarshal(rs.bytesData, result)
}
return nil
}
t := reflect.TypeOf(result)
v := reflect.ValueOf(result)
if t.Kind() == reflect.Ptr {
t = t.Elem()
v = v.Elem()
}
if (t.Kind() == reflect.Struct || t.Kind() == reflect.Map) && rs.keys != nil && rs.bytesDatas != nil {
rm := rs.ResultMap()
for k, r := range rm {
if t.Kind() == reflect.Struct {
k = cast.GetUpperName(k)
sf, found := t.FieldByName(k)
if found {
v.FieldByName(k).Set(r.ToValue(sf.Type))
}
} else if t.Kind() == reflect.Map {
v.SetMapIndex(reflect.ValueOf(k), r.ToValue(t.Elem()))
}
}
} else if t.Kind() == reflect.Slice && t.Elem().Kind() != reflect.Uint8 && rs.bytesDatas != nil {
results := rs.Results()
for _, r := range results {
v = reflect.Append(v, r.ToValue(t.Elem()))
}
reflect.ValueOf(result).Elem().Set(v)
}
return nil
}

147
subscribe.go Normal file
View File

@ -0,0 +1,147 @@
package redis
import (
"strings"
"time"
"github.com/gomodule/redigo/redis"
)
func (rd *Redis) Subscribe(name string, reset func(), received func([]byte)) bool {
if rd.subs == nil {
rd.subs = make(map[string]*SubCallbacks)
}
rd.subs[name] = &SubCallbacks{reset: reset, received: received}
if rd.subConn != nil {
err := rd.subConn.Subscribe(name)
if err != nil {
rd.logger.Error(err.Error(), "type", "redis", "action", "subscribe", "name", name)
} else {
return true
}
}
return false
}
func (rd *Redis) Unsubscribe(name string) bool {
if rd.subs != nil {
delete(rd.subs, name)
}
if rd.subConn != nil {
err := rd.subConn.Unsubscribe(name)
if err != nil {
rd.logger.Error(err.Error(), "type", "redis", "action", "unsubscribe", "name", name)
} else {
return true
}
}
return false
}
func (rd *Redis) Start() {
if rd.subs == nil {
rd.subs = make(map[string]*SubCallbacks)
}
rd.SubRunning = true
subStartChan := make(chan bool)
go rd.receiveSub(subStartChan)
<-subStartChan
}
func (rd *Redis) receiveSub(subStartChan chan bool) {
for {
if !rd.SubRunning {
break
}
// 开始接收订阅数据
if rd.subConn == nil {
conn, err := rd.GetConnection()
if err != nil {
time.Sleep(time.Second)
continue
}
rd.subConn = &redis.PubSubConn{Conn: conn}
// 重新订阅
if len(rd.subs) > 0 {
subs := make([]any, 0)
for k := range rd.subs {
subs = append(subs, k)
}
err = rd.subConn.Subscribe(subs...)
if err != nil {
_ = rd.subConn.Close()
rd.subConn = nil
time.Sleep(time.Second)
continue
}
// 重新连接时调用重置数据的回调
for _, v := range rd.subs {
if v.reset != nil {
v.reset()
}
}
}
}
if subStartChan != nil {
subStartChan <- true
subStartChan = nil
}
for {
isErr := false
receiveObj := rd.subConn.Receive()
switch v := receiveObj.(type) {
case redis.Message:
callback := rd.subs[v.Channel]
if callback != nil && callback.received != nil {
callback.received(v.Data)
}
case redis.Subscription:
case redis.Pong:
case error:
if strings.Contains(v.Error(), "i/o timeout") {
break
}
if !strings.Contains(v.Error(), "connection closed") && !strings.Contains(v.Error(), "use of closed network connection") {
rd.logger.Error(v.Error(), "type", "redis", "action", "receiveSub")
}
if rd.subConn != nil {
_ = rd.subConn.Close()
rd.subConn = nil
}
isErr = true
}
if isErr || !rd.SubRunning {
break
}
}
}
if rd.subStopChan != nil {
rd.subStopChan <- true
}
}
func (rd *Redis) Stop() {
if rd.SubRunning {
rd.subStopChan = make(chan bool)
rd.SubRunning = false
if rd.subConn != nil {
// 取消订阅
if len(rd.subs) > 0 {
_ = rd.subConn.Unsubscribe()
}
// 读一次再关闭可以防止Close时阻塞
_ = rd.subConn.ReceiveWithTimeout(50 * time.Millisecond)
_ = rd.subConn.Close()
rd.subConn = nil
}
<-rd.subStopChan
rd.subStopChan = nil
}
}
func (rd *Redis) PUBLISH(channel, data string) bool {
return rd.Do("PUBLISH "+channel, data).Bool()
}

56
subscribe_test.go Normal file
View File

@ -0,0 +1,56 @@
package redis_test
import (
"sync/atomic"
"testing"
"time"
"apigo.cc/go/redis"
)
func TestSub(t *testing.T) {
rd := redis.GetRedis("test", nil)
rd.Start()
defer rd.Stop()
var aaa atomic.Value
var bbb atomic.Value
rd.Subscribe("aaa", nil, func(s []byte) {
aaa.Store(string(s))
})
rd.PUBLISH("aaa", "111")
time.Sleep(100 * time.Millisecond)
val := aaa.Load()
if val == nil || val.(string) != "111" {
t.Fatal("Subscribe aaa failed", val)
}
rd.Subscribe("bbb", nil, func(s []byte) {
bbb.Store(string(s))
})
rd.PUBLISH("bbb", "222")
time.Sleep(100 * time.Millisecond)
val = bbb.Load()
if val == nil || val.(string) != "222" {
t.Fatal("Subscribe bbb failed", val)
}
rd.Unsubscribe("aaa")
rd.PUBLISH("aaa", "1111")
rd.PUBLISH("bbb", "2222")
time.Sleep(100 * time.Millisecond)
val = aaa.Load()
if val == nil || val.(string) != "111" {
t.Fatal("Unsubscribe aaa failed: value updated", val)
}
val = bbb.Load()
if val == nil || val.(string) != "2222" {
t.Fatal("Subscribe bbb update failed", val)
}
}