diff --git a/js_export.go b/js_export.go index f5d2d72..f84439c 100644 --- a/js_export.go +++ b/js_export.go @@ -3,6 +3,7 @@ package redis import ( "context" "errors" + "strings" "apigo.cc/go/id" "apigo.cc/go/jsmod" @@ -10,13 +11,60 @@ import ( func init() { jsmod.Register("redis", map[string]any{ - "get": func(ctx context.Context, name string) (*jsRedis, error) { - rd := GetRedis(name, nil) + // 入口:支持别名获取,不传则默认 "default" + "Get": func(ctx context.Context, name *string) (*jsRedis, error) { + target := "default" + if name != nil { + target = *name + } + rd := GetRedis(target, nil) if rd.Error != nil { return nil, rd.Error } return &jsRedis{rd: rd, ctx: ctx}, nil }, + + // 默认快捷调用 (面向 "default" 实例) + "Do": func(ctx context.Context, cmd string, args ...any) (*Result, error) { + jr := &jsRedis{rd: GetRedis("default", nil), ctx: ctx} + if jr.rd.Error != nil { + return nil, jr.rd.Error + } + res := jr.Do(cmd, args...) + return res, res.Error + }, + + // 常用命令平铺 (面向 "default" 实例) + "SET": func(ctx context.Context, key string, val any) (*Result, error) { + jr := &jsRedis{rd: GetRedis("default", nil), ctx: ctx} + res := jr.Do("SET", key, val) + return res, res.Error + }, + "GET": func(ctx context.Context, key string) (*Result, error) { + jr := &jsRedis{rd: GetRedis("default", nil), ctx: ctx} + res := jr.Do("GET", key) + return res, res.Error + }, + "DEL": func(ctx context.Context, key string) (*Result, error) { + jr := &jsRedis{rd: GetRedis("default", nil), ctx: ctx} + res := jr.Do("DEL", key) + return res, res.Error + }, + "EXISTS": func(ctx context.Context, key string) (*Result, error) { + jr := &jsRedis{rd: GetRedis("default", nil), ctx: ctx} + res := jr.Do("EXISTS", key) + return res, res.Error + }, + "EXPIRE": func(ctx context.Context, key string, seconds int) (*Result, error) { + jr := &jsRedis{rd: GetRedis("default", nil), ctx: ctx} + res := jr.Do("EXPIRE", key, seconds) + return res, res.Error + }, + "PUBLISH": func(ctx context.Context, channel, data string) (*Result, error) { + jr := &jsRedis{rd: GetRedis("default", nil), ctx: ctx} + res := jr.Do("PUBLISH", channel, data) + return res, res.Error + }, }) } @@ -28,40 +76,60 @@ type jsRedis struct { var errSafeMode = errors.New("redis operation is restricted in safe mode") -func (jr *jsRedis) checkSafe() error { +// 核心写操作指令集 +var writeCommands = map[string]bool{ + "SET": true, "SETEX": true, "SETNX": true, "MSET": true, "MSETNX": true, + "DEL": true, "EXPIRE": true, "EXPIREAT": true, "PEXPIRE": true, "PEXPIREAT": true, + "HSET": true, "HSETNX": true, "HDEL": true, "HMSET": true, + "LPUSH": true, "RPUSH": true, "LPOP": true, "RPOP": true, "LREM": true, "LTRIM": true, + "SADD": true, "SREM": true, "SPOP": true, "SMOVE": true, + "ZADD": true, "ZREM": true, "ZREMRANGEBYRANK": true, "ZREMRANGEBYSCORE": true, + "PUBLISH": true, "FLUSHDB": true, "FLUSHALL": true, +} + +func (jr *jsRedis) checkSafe(cmd string) error { if jsmod.IsSafeMode(jr.ctx) { - return errSafeMode + cmd = strings.ToUpper(cmd) + if writeCommands[cmd] || !strings.Contains(" GET EXISTS ZRANGE HGET HGETALL SMEMBERS SISMEMBER LINDEX LLEN ", " "+cmd+" ") { + // 严格模式:不在白名单内的或在黑名单内的都禁止 + return errSafeMode + } } return nil } -// Do executes any redis command. In SafeMode, it only allows read-only commands. -// Note: Since we don't have a reliable way to categorize all redis commands as read-only, -// and 'DO' is used for everything, we strictly block 'DO' in SafeMode if it's not a known read-only command. -// For simplicity and maximum safety as requested, we block 'DO' entirely in SafeMode. func (jr *jsRedis) Do(cmd string, args ...any) *Result { - if jr.checkSafe() != nil { - return &Result{Error: errSafeMode} + if err := jr.checkSafe(cmd); err != nil { + return &Result{Error: err} } return jr.rd.Do(cmd, args...) } -// ID Generation Helpers -func (jr *jsRedis) getIDMaker() *id.IDMaker { +// 实例方法 PascalCase 对齐 +func (jr *jsRedis) SET(key string, val any) *Result { return jr.Do("SET", key, val) } +func (jr *jsRedis) GET(key string) *Result { return jr.Do("GET", key) } +func (jr *jsRedis) DEL(key string) *Result { return jr.Do("DEL", key) } +func (jr *jsRedis) EXISTS(key string) *Result { return jr.Do("EXISTS", key) } +func (jr *jsRedis) EXPIRE(key string, s int) *Result { return jr.Do("EXPIRE", key, s) } +func (jr *jsRedis) HSET(key, field string, v any) *Result { return jr.Do("HSET", key, field, v) } +func (jr *jsRedis) HGET(key, field string) *Result { return jr.Do("HGET", key, field) } +func (jr *jsRedis) PUBLISH(ch, data string) *Result { return jr.Do("PUBLISH", ch, data) } + +// ID Generation +func (jr *jsRedis) MakeID(size int, forDB *string) string { if jr.idMaker == nil { jr.idMaker = NewIDMaker(jr.rd) } - return jr.idMaker -} - -func (jr *jsRedis) GetID(size int) string { - return jr.getIDMaker().Get(size) -} - -func (jr *jsRedis) GetForMysql(size int) string { - return jr.getIDMaker().GetForMysql(size) -} - -func (jr *jsRedis) GetForPostgreSQL(size int) string { - return jr.getIDMaker().GetForPostgreSQL(size) + dbType := "" + if forDB != nil { + dbType = strings.ToLower(*forDB) + } + switch dbType { + case "mysql": + return jr.idMaker.GetForMysql(size) + case "postgres", "pg", "pgsql": + return jr.idMaker.GetForPostgreSQL(size) + default: + return jr.idMaker.Get(size) + } }