package redis import ( "context" "errors" "strings" "apigo.cc/go/id" "apigo.cc/go/jsmod" "apigo.cc/go/log" ) func init() { jsmod.Register("redis", map[string]any{ // 入口:支持别名获取,不传则默认 "default" "Get": jsGet, // 默认快捷调用 (面向 "default" 实例) "Do": jsDo, "MakeID": jsMakeID, }) } func jsGet(ctx context.Context, name *string) (*jsRedis, error) { target := "default" if name != nil { target = *name } rd := GetRedis(target, nil) if rd.Error != nil { return nil, jsmod.MakeError(rd.Error) } return &jsRedis{rd: rd, ctx: ctx}, nil } func jsDo(ctx context.Context, cmd string, args ...any) (*Result, error) { jr := getDefaultRedisForJS(ctx) if jr.rd.Error != nil { return nil, jsmod.MakeError(jr.rd.Error) } res := jr.Do(cmd, args...) if res != nil && res.Error != nil { return res, jsmod.MakeError(res.Error) } return res, nil } func jsMakeID(ctx context.Context, size int, forDB *string) string { jr := getDefaultRedisForJS(ctx) if jr.rd.Error != nil { return id.MakeID(size) } return jr.MakeID(size, forDB) } type jsRedis struct { rd *Redis ctx context.Context idMaker *id.IDMaker } var defaultRedisForJS *jsRedis func getDefaultRedisForJS(ctx context.Context) *jsRedis { if defaultRedisForJS == nil { defaultRedisForJS = &jsRedis{rd: GetRedis("default", ctx.Value("Logger").(*log.Logger)), ctx: ctx} } return defaultRedisForJS } var errSafeMode = errors.New("redis operation is restricted in safe mode") // 核心写操作指令集 var writeCommands = map[string]bool{ "SET": true, "SETEX": true, "SETNX": true, "MSET": true, "MSETNX": true, "DEL": true, "EXPIRE": true, "EXPIREAT": true, "PEXPIRE": true, "PEXPIREAT": true, "HSET": true, "HSETNX": true, "HDEL": true, "HMSET": true, "LPUSH": true, "RPUSH": true, "LPOP": true, "RPOP": true, "LREM": true, "LTRIM": true, "SADD": true, "SREM": true, "SPOP": true, "SMOVE": true, "ZADD": true, "ZREM": true, "ZREMRANGEBYRANK": true, "ZREMRANGEBYSCORE": true, "PUBLISH": true, "FLUSHDB": true, "FLUSHALL": true, } func (jr *jsRedis) checkSafe(cmd string) error { if jsmod.IsSafeMode(jr.ctx) { cmd = strings.ToUpper(cmd) if writeCommands[cmd] || !strings.Contains(" GET EXISTS ZRANGE HGET HGETALL SMEMBERS SISMEMBER LINDEX LLEN ", " "+cmd+" ") { // 严格模式:不在白名单内的或在黑名单内的都禁止 return errSafeMode } } return nil } func (jr *jsRedis) Do(cmd string, args ...any) *Result { if err := jr.checkSafe(cmd); err != nil { return &Result{Error: jsmod.MakeError(err)} } res := jr.rd.Do(cmd, args...) if res != nil && res.Error != nil { resCopy := *res resCopy.Error = jsmod.MakeError(res.Error) return &resCopy } return res } // ID Generation func (jr *jsRedis) MakeID(size int, forDB *string) string { if jr.idMaker == nil { jr.idMaker = NewIDMaker(jr.rd) } dbType := "" if forDB != nil { dbType = strings.ToLower(*forDB) } switch dbType { case "mysql": return jr.idMaker.GetForMysql(size) case "postgres", "pg", "pgsql": return jr.idMaker.GetForPostgreSQL(size) default: return jr.idMaker.Get(size) } }