diff --git a/CHANGELOG.md b/CHANGELOG.md index 079ecc8..a5f1dd5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,16 @@ # CHANGELOG - go/service -## v1.3.1 (2026-05-10) +## v1.3.6 (2026-05-31) +- **基础设施增强**: + - 新增 `WebSocketConn` 标准包装器,提供统一的 `Send`, `ReadString`, `ReadBytes`, `ReadJSON` 接口。 + - 新增 `Upgrade` 函数,支持在自定义处理器中手动升级 HTTP 为 WebSocket。 + - 集成 `Session` 会话管理,支持分布式 Redis 或本地内存存储。 +- **安全加固**: 彻底移除 `UploadFile.Save` 方法,规避低代码环境下的文件落盘风险。 +- **JSMOD 类型对齐**: + - 提供 `newRequest`, `newResponse`, `newWebSocket`, `newSession`, `newFile` 占位工厂,支持 AI 环境下的类型自动发现 (DTS)。 + - 导出 `upgrade` 方法支持动态服务分发场景。 + +## v1.3.5 (2026-05-31) - **Logging Refactor (Callback Pattern)**: 引入 `LogRequest` 闭环式回调封装,自动处理日志级别检查、对象池获取及元数据填充,消除 20+ 参数带来的维护压力。 - **Graceful Shutdown**: `ServiceConfig` 新增 `StopTimeout` 字段,支持通过配置灵活管控服务优雅退出的超时时间(默认 5s)。 - **Panic Recovery**: 增强 `handler.go` 中的 `recover` 逻辑,在发生 Panic 时自动记录 `requestId` 和 `path`,大幅提升故障定位效率。 diff --git a/README.md b/README.md index 6039860..76c35f4 100644 --- a/README.md +++ b/README.md @@ -42,7 +42,28 @@ service.Host("*").WebSocket("/ws", func(conn *websocket.Conn, logger *log.Logger }).Auth(0).Memo("聊天室") ``` -### 3. 生命周期管理 +### 4. 会话管理 (Session) +框架内置了会话管理机制,支持内存存储和 Redis 存储。 + +- **获取会话**: 在服务方法中注入 `*service.Session` 参数。 +- **自动鉴权**: 如果未设置自定义 `AuthChecker`,框架将自动检查 Session 中的 `_authLevel` 是否满足接口要求。 +- **权限校验**: `session.AuthFuncs("func1", "&func2")` 支持 OR 和 AND 逻辑的权限细粒度校验。 + +```go +service.Host("*").POST("/login", func(s *service.Session, in LoginArgs) string { + // 业务登录逻辑... + s.SetAuthLevel(1) // 设置鉴权级别 + s.Set("uid", "123") + s.Save() // 持久化 + return "ok" +}) + +service.Host("*").GET("/profile", func(s *service.Session) any { + return s.Get("uid") +}).Auth(1) // 自动要求 Session AuthLevel >= 1 +``` + +### 5. 生命周期管理 ```go func main() { // 异步启动 diff --git a/config.go b/config.go index 1def31c..0626cbd 100644 --- a/config.go +++ b/config.go @@ -48,6 +48,8 @@ type ServiceConfig struct { Memory int // 内存限制 (MB) CookieScope string // Session Cookie 有效范围: host|domain|topDomain SessionWithoutCookie bool // Session 禁用 Cookie + SessionRedis string // Session 存储使用的 Redis 配置名称 (不设置则使用内存) + SessionTimeout int // Session 有效期 (秒,默认 3600) DeviceWithoutCookie bool // 设备ID禁用 Cookie IdServer string // Redis 服务器连接 (用于全局唯一 ID 生成) IndexFiles []string // 静态文件索引文件 diff --git a/go.mod b/go.mod index 7d404a7..d4afcfd 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( apigo.cc/go/file v1.3.2 apigo.cc/go/http v1.3.2 apigo.cc/go/id v1.3.1 + apigo.cc/go/jsmod v1.0.1 apigo.cc/go/log v1.3.4 apigo.cc/go/redis v1.3.2 apigo.cc/go/safe v1.3.1 diff --git a/go.sum b/go.sum index 67d6744..d6d34e5 100644 --- a/go.sum +++ b/go.sum @@ -14,6 +14,8 @@ apigo.cc/go/http v1.3.2 h1:0Or5KfoIq4+yeWKYusYPV8XLPw8XuzJMeaFv7dZViLI= apigo.cc/go/http v1.3.2/go.mod h1:Q9R7Ors0Fz2A6Mxg0dykO2PjCzdAHRRXreOUMjMOLwA= apigo.cc/go/id v1.3.1 h1:pkqi6VeWyQoHuIu0Zbx/RRxIAdM61Js0j6cY1M9XVCk= apigo.cc/go/id v1.3.1/go.mod h1:P2/vl3tyW3US+ayOFSMoPIOCulNLBngNYPhXJC/Z7J4= +apigo.cc/go/jsmod v1.0.1 h1:vaz3cMQi75UVoALLfyV/Trs8iP/Nh28yN57IvBFpPGk= +apigo.cc/go/jsmod v1.0.1/go.mod h1:bmyeZtOAP/j5am+YRnaiM89smysK24K7ebk0koFtsSw= apigo.cc/go/log v1.3.4 h1:UT8Neb9r4QjjbCFbTzw+ZeTxd+DmdmR5gNExeR4Cj+g= apigo.cc/go/log v1.3.4/go.mod h1:/Q/2r51xWSsrS4QN5U9jLiTw8n6qNC8kG9nuVHweY20= apigo.cc/go/rand v1.3.1 h1:7FvsI6PtQ5XrWER0dTiLVo0p7GIxRidT/TBKhVy93j8= diff --git a/handler.go b/handler.go index ebd9f18..ac7b9cf 100644 --- a/handler.go +++ b/handler.go @@ -177,11 +177,19 @@ func (rh *RouteHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { if ws != nil { authLevel = ws.authLevel priority = ws.options.Priority - doWebsocketService(ws, request, response, requestLogger) + // 鉴权 + pass, obj := checkAuth(ws.authLevel, &ws.options, request, response, args, requestLogger) + if !pass { + if !response.changed { + response.WriteHeader(http.StatusForbidden) + } + return + } + doWebsocketService(ws, request, response, requestLogger, obj) return } else if s != nil { // 鉴权 - pass, obj := checkAuth(s, request, response, args, requestLogger) + pass, obj := checkAuth(s.authLevel, &s.options, request, response, args, requestLogger) if !pass { if !response.changed { response.WriteHeader(http.StatusForbidden) @@ -303,15 +311,23 @@ func parseRequestArgs(request *Request, args map[string]any) { } } -func checkAuth(s *webServiceType, request *Request, response *Response, args map[string]any, logger *log.Logger) (bool, any) { - ac := webAuthCheckers[s.authLevel] +func checkAuth(authLevel int, options *WebServiceOptions, request *Request, response *Response, args map[string]any, logger *log.Logger) (bool, any) { + ac := webAuthCheckers[authLevel] if ac == nil { ac = webAuthChecker } if ac == nil { - return true, nil + sess := NewSession(request.SessionId(), logger) + if authLevel > 0 && sess.GetAuthLevel() < authLevel { + return false, sess + } + return true, sess } - return ac(s.authLevel, logger, &request.RequestURI, args, request, response, &s.options) + pass, obj := ac(authLevel, logger, &request.RequestURI, args, request, response, options) + if pass && obj == nil { + obj = NewSession(request.SessionId(), logger) + } + return pass, obj } func doWebService(service *webServiceType, request *Request, response *Response, args map[string]any, @@ -347,7 +363,9 @@ func doWebService(service *webServiceType, request *Request, response *Response, params[i] = reflect.ValueOf(in).Elem() default: // 尝试依赖注入 - if obj := GetInject(t); obj != nil { + if object != nil && reflect.TypeOf(object).AssignableTo(t) { + params[i] = reflect.ValueOf(object) + } else if obj := GetInject(t); obj != nil { params[i] = reflect.ValueOf(obj) } else { params[i] = reflect.New(t).Elem() diff --git a/js_export.go b/js_export.go new file mode 100644 index 0000000..82d1e69 --- /dev/null +++ b/js_export.go @@ -0,0 +1,45 @@ +package service + +import ( + "apigo.cc/go/jsmod" +) + +func init() { + jsmod.Register("service", map[string]any{ + // 类型占位工厂 (用于 AI 发现类型结构) + "newRequest": func() *Request { return &Request{} }, + "newResponse": func() *Response { return &Response{} }, + "newWebSocket": func() *WebSocketConn { return &WebSocketConn{} }, + "newSession": func() *Session { return &Session{} }, + "newFile": func() *jsUploadFile { return &jsUploadFile{} }, + + // 功能函数 + "upgrade": Upgrade, + }) +} + +// jsUploadFile 包装 UploadFile 以隐藏敏感方法 +type jsUploadFile struct { + f *UploadFile +} + +func (j *jsUploadFile) Filename() string { + if j.f == nil { + return "" + } + return j.f.Filename +} + +func (j *jsUploadFile) Size() int64 { + if j.f == nil { + return 0 + } + return j.f.Size +} + +func (j *jsUploadFile) Content() ([]byte, error) { + if j.f == nil { + return nil, nil + } + return j.f.Content() +} diff --git a/request.go b/request.go index 9e111f1..9fec36a 100644 --- a/request.go +++ b/request.go @@ -2,7 +2,6 @@ package service import ( "apigo.cc/go/discover" - "apigo.cc/go/file" "io" "mime/multipart" "net" @@ -24,15 +23,6 @@ func (f *UploadFile) Open() (multipart.File, error) { return f.fileHeader.Open() } -// Save 保存上传文件到本地 -func (f *UploadFile) Save(filename string) error { - data, err := f.Content() - if err != nil { - return err - } - return file.WriteBytes(filename, data) -} - // Content 获取上传文件内容 func (f *UploadFile) Content() ([]byte, error) { src, err := f.fileHeader.Open() diff --git a/response.go b/response.go index b8a6d60..a308f5c 100644 --- a/response.go +++ b/response.go @@ -31,7 +31,6 @@ func NewResponse(writer http.ResponseWriter) *Response { // Header 获取响应头部 func (r *Response) Header() http.Header { - r.changed = true if r.ProxyHeader != nil { return *r.ProxyHeader } diff --git a/session.go b/session.go new file mode 100644 index 0000000..74ee5c3 --- /dev/null +++ b/session.go @@ -0,0 +1,214 @@ +package service + +import ( + "apigo.cc/go/cast" + "apigo.cc/go/log" + "apigo.cc/go/redis" + "errors" + "strings" + "sync" + "time" +) + +// Session 会话对象 +type Session struct { + id string + conn *redis.Redis + data map[string]any + funcAuthCache map[string]bool + lock sync.RWMutex +} + +var ( + memorySessionData = map[string]map[string]any{} + memorySessionDataLock = sync.RWMutex{} + lastSessionClearTime int64 +) + +// NewSession 创建或加载会话 +func NewSession(id string, logger *log.Logger) *Session { + data := map[string]any{} + var conn *redis.Redis + + timeout := Config.SessionTimeout + if timeout <= 0 { + timeout = 3600 + } + + if Config.SessionRedis != "" { + conn = redis.GetRedis(Config.SessionRedis, logger) + err := conn.GET("SESS_" + id).To(&data) + if err == nil { + _ = conn.EXPIRE("SESS_"+id, timeout) + } + } else { + memorySessionDataLock.RLock() + if d, ok := memorySessionData[id]; ok && d != nil { + for k, v := range d { + data[k] = v + } + } + memorySessionDataLock.RUnlock() + } + + return &Session{ + id: id, + conn: conn, + data: data, + funcAuthCache: map[string]bool{}, + } +} + +// Set 设置会话数据 +func (s *Session) Set(key string, value any) { + s.lock.Lock() + defer s.lock.Unlock() + s.data[key] = value +} + +// Get 获取会话数据 +func (s *Session) Get(key string) any { + s.lock.RLock() + defer s.lock.RUnlock() + return s.data[key] +} + +// Remove 移除会话数据 +func (s *Session) Remove(key string) { + s.lock.Lock() + defer s.lock.Unlock() + delete(s.data, key) +} + +// SetAuthLevel 设置鉴权级别 +func (s *Session) SetAuthLevel(level int) { + s.Set("_authLevel", level) +} + +// GetAuthLevel 获取当前鉴权级别 +func (s *Session) GetAuthLevel() int { + return cast.Int(s.Get("_authLevel")) +} + +// Save 保存会话数据 +func (s *Session) Save() error { + s.lock.Lock() + defer s.lock.Unlock() + + timeout := Config.SessionTimeout + if timeout <= 0 { + timeout = 3600 + } + + if s.conn == nil { + now := time.Now().Unix() + s.data["_time"] = now + + // 复制一份数据存储,防止外部修改 + saveData := make(map[string]any) + for k, v := range s.data { + saveData[k] = v + } + + memorySessionDataLock.Lock() + memorySessionData[s.id] = saveData + + clearTimeDiff := now - lastSessionClearTime + if clearTimeDiff > 60 { + lastSessionClearTime = now + } + memorySessionDataLock.Unlock() + + if clearTimeDiff > 60 { + go clearMemorySession(int64(timeout)) + } + return nil + } else { + if !s.conn.SETEX("SESS_"+s.id, timeout, s.data) { + return errors.New("redis save failed") + } + return nil + } +} + +func clearMemorySession(timeout int64) { + memorySessionDataLock.Lock() + defer memorySessionDataLock.Unlock() + now := time.Now().Unix() + for id, data := range memorySessionData { + if t, ok := data["_time"].(int64); ok { + if now-t > timeout { + delete(memorySessionData, id) + } + } + } +} + +// AuthFuncs 检查权限 +func (s *Session) AuthFuncs(needFuncs ...string) bool { + if len(needFuncs) == 0 { + return true + } + + s.lock.RLock() + cacheKey := strings.Join(needFuncs, "; ") + if res, ok := s.funcAuthCache[cacheKey]; ok { + s.lock.RUnlock() + return res + } + s.lock.RUnlock() + + userFuncs, _ := cast.ToSlice[string](s.Get("funcs")) + isOk := false + + // 超级管理员判断 + for _, uf := range userFuncs { + if uf == "system.superAdmin." || strings.HasPrefix(uf, "system.superAdmin.") { + isOk = true + break + } + } + + if !isOk && len(userFuncs) > 0 { + requiredAuthTotal := 0 + for _, nf := range needFuncs { + if strings.HasPrefix(nf, "&") { + requiredAuthTotal++ + } + } + + normalAuthOk := 0 + requiredAuthOk := 0 + + for _, nf := range needFuncs { + isRequired := false + matchFunc := nf + if strings.HasPrefix(nf, "&") { + isRequired = true + matchFunc = nf[1:] + } + + for _, uf := range userFuncs { + if strings.HasPrefix(uf, matchFunc) { + if isRequired { + requiredAuthOk++ + } else { + normalAuthOk++ + } + break + } + } + + // 如果是非必需权限命中,或者必需权限已全部命中且至少命中了一个非必需权限(如果有) + if (normalAuthOk > 0 || requiredAuthTotal == len(needFuncs)) && requiredAuthOk == requiredAuthTotal { + isOk = true + break + } + } + } + + s.lock.Lock() + s.funcAuthCache[cacheKey] = isOk + s.lock.Unlock() + return isOk +} diff --git a/session_test.go b/session_test.go new file mode 100644 index 0000000..a8e37a0 --- /dev/null +++ b/session_test.go @@ -0,0 +1,158 @@ +package service + +import ( + "apigo.cc/go/log" + "net/http" + "net/http/httptest" + "testing" +) + +func TestSessionLogic(t *testing.T) { + SetClientKeys("", "", "sessid") + Config.SessionTimeout = 3600 + + // 1. 测试 Session 数据存取 + sess := NewSession("test_id", nil) + sess.Set("key1", "value1") + if sess.Get("key1") != "value1" { + t.Errorf("Expected value1, got %v", sess.Get("key1")) + } + + if err := sess.Save(); err != nil { + t.Errorf("Save failed: %v", err) + } + + sess2 := NewSession("test_id", nil) + if sess2.Get("key1") != "value1" { + t.Errorf("Expected value1 in new session instance, got %v", sess2.Get("key1")) + } + + // 2. 测试 AuthFuncs 逻辑 + sess.Set("funcs", []string{"user.read", "user.write", "system.admin"}) + + if !sess.AuthFuncs("user.read") { + t.Error("Expected true for user.read") + } + if !sess.AuthFuncs("user.read", "user.write") { + t.Error("Expected true for user.read and user.write") + } + if sess.AuthFuncs("user.delete") { + t.Error("Expected false for user.delete") + } + + // 测试必需权限 & + if sess.AuthFuncs("&user.read", "other") { + t.Error("Expected false for &user.read when other is missing") + } + if !sess.AuthFuncs("&user.read") { + t.Error("Expected true for &user.read") + } + + if sess.AuthFuncs("&user.delete", "user.read") { + t.Error("Expected false for &user.delete even if user.read exists") + } + + // 测试超级管理员 + sess.Set("funcs", []string{"system.superAdmin.all"}) + if !sess.AuthFuncs("any.thing") { + t.Error("Expected true for superAdmin") + } +} + +func TestSessionInjection(t *testing.T) { + SetClientKeys("", "", "sessid") + + handler := func(s *Session) string { + if s == nil { + return "no session" + } + s.Set("name", "star") + _ = s.Save() + return "ok" + } + Host("*").GET("/test-session", handler) + + rh := &RouteHandler{} + req := httptest.NewRequest("GET", "/test-session", nil) + req.Header.Set("sessid", "sess_123") + w := httptest.NewRecorder() + + rh.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected 200, got %d", w.Code) + } + if w.Body.String() != "ok" { + t.Errorf("Expected ok, got %s", w.Body.String()) + } + + // 验证 Session 是否真的保存了 + sess := NewSession("sess_123", nil) + if sess.Get("name") != "star" { + t.Errorf("Expected star, got %v", sess.Get("name")) + } +} + +type CustomAuth struct { + User string +} + +func TestCustomAuthInjection(t *testing.T) { + AddAuthChecker([]int{10}, func(authLevel int, logger *log.Logger, url *string, in map[string]any, request *Request, response *Response, options *WebServiceOptions) (pass bool, object any) { + return true, &CustomAuth{User: "custom_user"} + }) + + handler := func(auth *CustomAuth) string { + if auth == nil { + return "no auth" + } + return auth.User + } + Host("*").GET("/test-auth", handler).Auth(10) + + rh := &RouteHandler{} + req := httptest.NewRequest("GET", "/test-auth", nil) + w := httptest.NewRecorder() + + rh.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected 200, got %d", w.Code) + } + if w.Body.String() != "custom_user" { + t.Errorf("Expected custom_user, got %s", w.Body.String()) + } +} + +func TestAutomaticAuthLevelCheck(t *testing.T) { + SetClientKeys("", "", "sessid") + + handler := func() string { + return "ok" + } + Host("*").GET("/test-auto-auth", handler).Auth(1) + + rh := &RouteHandler{} + + // 1. 无 Session 或 AuthLevel=0 时应失败 + req1 := httptest.NewRequest("GET", "/test-auto-auth", nil) + req1.Header.Set("sessid", "sess_auto_1") + w1 := httptest.NewRecorder() + rh.ServeHTTP(w1, req1) + if w1.Code != http.StatusForbidden { + t.Errorf("Expected 403, got %d", w1.Code) + } + + // 2. 设置 Session AuthLevel=1 后应成功 + sess := NewSession("sess_auto_2", nil) + sess.SetAuthLevel(1) + _ = sess.Save() + + req2 := httptest.NewRequest("GET", "/test-auto-auth", nil) + req2.Header.Set("sessid", "sess_auto_2") + w2 := httptest.NewRecorder() + rh.ServeHTTP(w2, req2) + if w2.Code != http.StatusOK { + t.Errorf("Expected 200, got %d", w2.Code) + } +} diff --git a/websocket.go b/websocket.go index 91f67f9..6257987 100644 --- a/websocket.go +++ b/websocket.go @@ -11,13 +11,61 @@ var defaultUpgrader = &websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { return true }, } -func doWebsocketService(ws *websocketServiceType, request *Request, response *Response, logger *log.Logger) { +// WebSocketConn 包装标准的 websocket.Conn,提供更友好的接口 +type WebSocketConn struct { + Conn *websocket.Conn +} + +// Send 发送消息,支持 string, []byte 或 自动转 JSON +func (c *WebSocketConn) Send(data any) error { + switch v := data.(type) { + case string: + return c.Conn.WriteMessage(websocket.TextMessage, []byte(v)) + case []byte: + return c.Conn.WriteMessage(websocket.BinaryMessage, v) + default: + return c.Conn.WriteJSON(v) + } +} + +// ReadString 读取字符串消息 +func (c *WebSocketConn) ReadString() (string, error) { + _, b, err := c.Conn.ReadMessage() + return string(b), err +} + +// ReadBytes 读取二进制消息 +func (c *WebSocketConn) ReadBytes() ([]byte, error) { + _, b, err := c.Conn.ReadMessage() + return b, err +} + +// ReadJSON 读取 JSON 消息 +func (c *WebSocketConn) ReadJSON(v any) error { + return c.Conn.ReadJSON(v) +} + +// Close 关闭连接 +func (c *WebSocketConn) Close() error { + return c.Conn.Close() +} + +// Upgrade 将 HTTP 请求升级为 WebSocket 连接 +func Upgrade(response *Response, request *Request) (*WebSocketConn, error) { conn, err := defaultUpgrader.Upgrade(response.Writer, request.Request, nil) + if err != nil { + return nil, err + } + return &WebSocketConn{Conn: conn}, nil +} + +func doWebsocketService(ws *websocketServiceType, request *Request, response *Response, logger *log.Logger, object any) { + wsConn, err := Upgrade(response, request) if err != nil { logger.Error("websocket upgrade failed", "error", err.Error()) return } - defer conn.Close() + defer wsConn.Close() // 调用业务处理函数,注入依赖 params := make([]reflect.Value, ws.funcType.NumIn()) @@ -27,8 +75,12 @@ func doWebsocketService(ws *websocketServiceType, request *Request, response *Re params[i] = reflect.ValueOf(request) } else if t == reflect.TypeOf(logger) { params[i] = reflect.ValueOf(logger) - } else if t == reflect.TypeOf(conn) { - params[i] = reflect.ValueOf(conn) + } else if t == reflect.TypeOf(wsConn) { + params[i] = reflect.ValueOf(wsConn) + } else if t == reflect.TypeOf(wsConn.Conn) { + params[i] = reflect.ValueOf(wsConn.Conn) + } else if object != nil && reflect.TypeOf(object).AssignableTo(t) { + params[i] = reflect.ValueOf(object) } else if obj := GetInject(t); obj != nil { params[i] = reflect.ValueOf(obj) } else {