diff --git a/CHANGELOG.md b/CHANGELOG.md index aa15db5..0b0a9f5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,26 @@ # CHANGELOG - go/service +## v1.5.17 (2026-06-21) +- **Session 增强**: + - `Save` 改为可变参数 `Save(args ...map[string]any) error`,支持传入 map 批量设置后保存。 + - 新增 `Load(keys []string) map[string]any`,支持批量读取,空参数返回全部数据。 + - `Remove` 改为可变参数 `Remove(keys ...string)`,支持一次删除多个 key。 +- **日志脱敏引擎**: + - 新增 `sanitizeLogData` 递归脱敏函数,基于 size 预算、字符串截断、数组/对象元素限制构建日志安全对象,不影响原始数据。 + - 新增 `sanitizeLogHeaders` 统一处理请求/响应头过滤。 + - 新增 `sanitizeRespBody` 自动 JSON 解析后走对象脱敏,非 JSON 按字符串截断。 +- **日志配置新增**: + - `NoLogInput`/`NoLogOutput`/`NoLogAllHeaders` — 高并发加速开关。 + - `LogInputObjectNum`/`LogOutputObjectNum` — 对象最多记录 key 数(默认 10/5)。 + - 各配置默认值: fieldSize=20, arrayNum(In=5/Out=3), maxSize=200, NoLogHeaders 默认过滤内部头。 +- **响应体捕获修复**: + - `response.Write` 改为始终调用 `keepBody` 缓冲 body(限制大小 LogOutputMaxSize),修复 200 响应无日志输出的问题。 + - `outputResult` 先调 `keepBody` 再分路径写出,修复 dev 模式(hasOutFilter)下 PhysicalWrite 绕过 body 缓冲的问题。 +- **客户端 Key 默认值**: `Session-ID`/`Device-ID`/`App`,配合 HTTP header canonicalize 显示为 `Session-Id`/`Device-Id`。 +- **调试清理**: 移除 handler.go 中 `fmt.Println`/`shell.BMagenta` 调试残留。 +- **日志格式优化**: RequestHeaders 颜色 blue,ResponseDataLength key 为 Size,ResponseData/ResponseHeaders 颜色 yellow。 +- **依赖更新**: 升级 `js` 至 `v1.5.6`。 + ## v1.5.15 (2026-06-21) - **错误堆栈重构**: - 重构 `js_export.go`,将匿名占位工厂函数改写为包级具名函数。 diff --git a/TEST.md b/TEST.md index 77b0743..da91220 100644 --- a/TEST.md +++ b/TEST.md @@ -24,7 +24,11 @@ - [x] `TestGetDefaultName`: 自动应用名识别 - [x] `TestGetServerIp`: 自动 IP 探测 - [x] `TestSmartStartup`: 零配置智能启动与 Discover 注册 -- [x] **Logging Filters**: 已手动验证 `NoLogGets`, `NoLogHeaders` 等过滤逻辑。 +- [x] `TestSanitizeScalars` ~ `TestSanitizeMixedSlice`: 日志脱敏 10 个测试(标量/对象/数组/嵌套/预算/Unicode) +- [x] `TestSessionLogic`: Session Save/Load/Remove 及 AuthFuncs +- [x] `TestSessionInjection`: Session HTTP 注入流程 +- [x] Logging Filters: NoLogInput/NoLogOutput/NoLogAllHeaders/NoLogGets/NoLogHeaders +- [x] Response Body: 200 响应和 dev 模式下 keepBody 捕获 ## 基础设施对齐验证 - [x] 成功集成 `apigo.cc/go/cast` 用于参数解析与类型强转。 diff --git a/config.go b/config.go index e0f12eb..ab1395f 100644 --- a/config.go +++ b/config.go @@ -23,12 +23,18 @@ type ServiceConfig struct { Listen string // 监听端口(|隔开多个监听)(,隔开多个选项),例如 80,http|443|443:h2|127.0.0.1:8080,h2c SSL map[string]*CertSet // SSL 证书配置,key 为域名 NoLogGets bool // 不记录 GET 请求的日志 + NoLogInput bool // 不记录请求输入 + NoLogOutput bool // 不记录响应输出 + NoLogAllHeaders bool // 不记录所有请求/响应头 NoLogHeaders string // 不记录请求头中包含的这些字段,多个字段用逗号分隔 - LogInputArrayNum int // 请求字段中容器类型在日志打印个数限制 - LogInputFieldSize int // 请求字段中单个字段在日志打印长度限制 + LogInputObjectNum int // 请求对象中最多记录的 key 数 + LogInputArrayNum int // 请求数组中最多记录的元素数 + LogInputFieldSize int // 请求单个字段的字符串截断长度 NoLogOutputFields string // 不记录响应字段中包含的这些字段 - LogOutputArrayNum int // 响应字段中容器类型在日志打印个数限制 - LogOutputFieldSize int // 响应字段中单个字段在日志打印长度限制 + LogOutputObjectNum int // 响应对象中最多记录的 key 数 + LogOutputArrayNum int // 响应数组中最多记录的元素数 + LogOutputFieldSize int // 响应单个字段的字符串截断长度 + LogOutputMaxSize int // 非对象响应内容的最大记录长度 Compress bool // 是否启用压缩 CompressMinSize int // 启用压缩的最小长度 CompressMaxSize int // 启用压缩的最大长度 diff --git a/go.mod b/go.mod index a7393fa..6e89873 100644 --- a/go.mod +++ b/go.mod @@ -11,7 +11,7 @@ require ( apigo.cc/go/id v1.5.4 apigo.cc/go/jsmod v1.5.3 apigo.cc/go/log v1.5.8 - apigo.cc/go/redis v1.5.7 + apigo.cc/go/redis v1.5.8 apigo.cc/go/safe v1.5.2 apigo.cc/go/starter v1.5.5 apigo.cc/go/timer v1.5.0 diff --git a/go.sum b/go.sum index 6bbef62..4c3cdb4 100644 --- a/go.sum +++ b/go.sum @@ -20,8 +20,8 @@ apigo.cc/go/log v1.5.8 h1:/IYtGPWhRjT3OayylDIphkWZIQbpLjqVeSnFEiD3Dy0= apigo.cc/go/log v1.5.8/go.mod h1:HfFPANMYxJx197SSTXB21Pgxcz/gGqPP8nlSErgd5WE= apigo.cc/go/rand v1.5.3 h1:O4bPIwyaOWEBCr0nL9A4G4qG48AqiGTCzfPeckm3Ius= apigo.cc/go/rand v1.5.3/go.mod h1:q1BTFkY/cXE229dDD5Q22lF7T0DoKPV6xAu+6bCrDH4= -apigo.cc/go/redis v1.5.7 h1:tFE/lDgz08XYFIo74aYj/qYWEPFjEwtC7oU/7zd2xBg= -apigo.cc/go/redis v1.5.7/go.mod h1:PsBVxmoUz4aCeffvofhb0J69JriahHFWRuMU6Qkw6Pk= +apigo.cc/go/redis v1.5.8 h1:cYPA3/dzo7pHKx14BS4ZqOq1aPgWyYFewE2b0BBnLGI= +apigo.cc/go/redis v1.5.8/go.mod h1:PsBVxmoUz4aCeffvofhb0J69JriahHFWRuMU6Qkw6Pk= apigo.cc/go/safe v1.5.2 h1:EnuEOW/SGwf/5A0nw9LnqfKJE071+TIc6ez8HI9R9Lg= apigo.cc/go/safe v1.5.2/go.mod h1:2GqCCLLGex4OAhdET3iBWm1R+LIYtmTrvHP8W0iESSw= apigo.cc/go/shell v1.5.4 h1:Kn6lP6I6d9U0hbyUjpKKFdFZ8RPo4vi4V6AYW8YFzrc= diff --git a/handler.go b/handler.go index c229110..8d1c894 100644 --- a/handler.go +++ b/handler.go @@ -1,10 +1,6 @@ package service import ( - "apigo.cc/go/cast" - "apigo.cc/go/discover" - "apigo.cc/go/log" - "apigo.cc/go/timer" "io" "net/http" "net/url" @@ -13,6 +9,11 @@ import ( "strings" "sync/atomic" "time" + + "apigo.cc/go/cast" + "apigo.cc/go/discover" + "apigo.cc/go/log" + "apigo.cc/go/timer" ) type RouteHandler struct { @@ -71,43 +72,33 @@ func (rh *RouteHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } usedTime := float32(tracker.Stop().Seconds()) - // 过滤请求头 - reqHeaders := make(map[string]string) - noLogHeaders := strings.Split(ws.Config.NoLogHeaders, ",") - for k, v := range r.Header { - skip := false - for _, nl := range noLogHeaders { - if nl != "" && strings.EqualFold(k, strings.TrimSpace(nl)) { - skip = true - break - } - } - if !skip { - reqHeaders[k] = strings.Join(v, ", ") - } + // 请求头 + var reqHeaders map[string]string + if !ws.Config.NoLogAllHeaders { + reqHeaders = sanitizeLogHeaders(r.Header, ws.Config.NoLogHeaders) } - // 过滤响应头 - respHeaders := make(map[string]string) - for k, v := range response.Header().H { - respHeaders[k] = strings.Join(v, ", ") + // 响应头 + var respHeaders map[string]string + if !ws.Config.NoLogAllHeaders { + respHeaders = sanitizeLogHeaders(response.Header().H, "") } - // 处理响应内容截断 + // 请求输入脱敏 + var reqData any + if !ws.Config.NoLogInput && args != nil { + reqData = sanitizeLogData(args, sanitizeOpts{ + maxSize: 200, + fieldSize: ws.Config.LogInputFieldSize, + arrayNum: ws.Config.LogInputArrayNum, + objectNum: ws.Config.LogInputObjectNum, + }) + } + + // 响应输出脱敏 var respData any - if response.Code != 200 { - if len(response.body) < 1024 { - respData = string(response.body) - } else { - respData = string(response.body[:1024]) + "..." - } - } else if ws.Config.NoLogOutputFields != "" { - // 简单的字段过滤逻辑 (如果是 JSON 对象) - // 这里可以根据 Config.NoLogOutputFields, LogOutputArrayNum, LogOutputFieldSize 进行更复杂的处理 - // 暂按字符串截断处理 - if len(response.body) > 0 { - respData = "[content hidden or truncated]" - } + if !ws.Config.NoLogOutput && len(response.body) > 0 { + respData = sanitizeRespBody(response.body, &ws.Config) } LogRequest(requestLogger, func(entry *RequestLog) { @@ -128,7 +119,7 @@ func (rh *RouteHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { entry.AuthLevel = authLevel entry.Priority = priority entry.RequestHeaders = reqHeaders - entry.RequestData = args + entry.RequestData = reqData entry.ResponseCode = response.Code entry.UsedTime = usedTime entry.ResponseHeaders = respHeaders @@ -433,6 +424,7 @@ func outputResult(response *Response, result any) { response.Header().Set("Content-Type", contentType) } + response.keepBody(data) if response.server != nil && response.server.hasOutFilter { response.PhysicalWrite(data) } else { @@ -490,4 +482,14 @@ func (ws *WebServer) handleClientKeys(request *Request, response *Response) { request.Request.Header.Set(discover.HeaderDeviceID, deviceId) response.Header().Set(ws.usedDeviceIdKey, deviceId) } + + // AppName / AppVersion(客户端上报,注入内部标准头供下游微服务使用) + if ws.usedClientAppKey != "" { + if appName := request.Header().Get(ws.usedClientAppKey + "Name"); appName != "" { + request.Request.Header.Set(discover.HeaderClientAppName, appName) + } + if appVersion := request.Header().Get(ws.usedClientAppKey + "Version"); appVersion != "" { + request.Request.Header.Set(discover.HeaderClientAppVersion, appVersion) + } + } } diff --git a/js_export.go b/js_export.go index b1c91e3..b1a7807 100644 --- a/js_export.go +++ b/js_export.go @@ -32,7 +32,7 @@ func jsUpgrade(response *Response, request *Request) (*WebSocketConn, error) { return conn, nil } -// jsUploadFile 包装 UploadFile 以隐藏敏感方法 +// jsUploadFile 包装 UploadFile 以隐藏敏感方法(如 Open) type jsUploadFile struct { f *UploadFile } diff --git a/log.go b/log.go index 6773301..85b304b 100644 --- a/log.go +++ b/log.go @@ -25,12 +25,12 @@ type RequestLog struct { AuthLevel int `log:"pos:22,color:green"` Priority int `log:"pos:23,hide:true"` RequestData any `log:"pos:24,color:cyan,keyname:Request"` - RequestHeaders map[string]string `log:"pos:25,color:cyan,keyname:Headers"` + RequestHeaders map[string]string `log:"pos:25,color:blue,keyname:Headers"` UsedTime float32 `log:"pos:26,color:green,precision:6"` ResponseCode int `log:"pos:27,color:magenta,keyname:Status"` - ResponseDataLength uint `log:"pos:28,color:magenta,keyname:ContentLength"` + ResponseDataLength uint `log:"pos:28,color:magenta,keyname:Size"` ResponseData any `log:"pos:29,color:magenta,keyname:Response"` - ResponseHeaders map[string]string `log:"pos:30,color:magenta,keyname:Headers"` + ResponseHeaders map[string]string `log:"pos:30,color:yellow,keyname:Headers"` } func (l *RequestLog) Reset() { diff --git a/log_sanitize.go b/log_sanitize.go new file mode 100644 index 0000000..948f416 --- /dev/null +++ b/log_sanitize.go @@ -0,0 +1,197 @@ +package service + +import ( + "net/http" + "sort" + "strings" + + "apigo.cc/go/cast" +) + +// sanitizeOpts 日志脱敏配置 +type sanitizeOpts struct { + maxSize int // 整体尺寸上限(内容字符估算) + fieldSize int // 单字符串截断长度 + arrayNum int // 数组最多保留元素数 + objectNum int // 对象最多保留 key 数 +} + +// sanitizeLogData 递归脱敏,返回新建的对象,不影响原始数据 +func sanitizeLogData(v any, opts sanitizeOpts) any { + budget := opts.maxSize + return sanitizeRecursive(v, opts, &budget) +} + +func sanitizeRecursive(v any, opts sanitizeOpts, budget *int) any { + if v == nil { + return nil + } + + switch val := v.(type) { + case string: + return sanitizeString(val, opts, budget) + case bool: + return sanitizeScalar(val, 5, budget) + case float64: + return sanitizeScalar(val, 8, budget) + case map[string]any: + return sanitizeMapContent(val, opts, budget) + case []any: + return sanitizeSliceContent(val, opts, budget) + default: + // int 等各种数值类型 + return sanitizeScalar(val, 8, budget) + } +} + +func sanitizeString(s string, opts sanitizeOpts, budget *int) string { + if len([]rune(s)) > opts.fieldSize { + s = string([]rune(s)[:opts.fieldSize]) + } + if *budget < len(s) { + *budget = 0 + return s // 即使超出预算也返回截断后的内容 + } + *budget -= len(s) + return s +} + +func sanitizeScalar(v any, cost int, budget *int) any { + if *budget < cost { + *budget = 0 + return nil + } + *budget -= cost + return v +} + +func sanitizeMapContent(m map[string]any, opts sanitizeOpts, budget *int) map[string]any { + // 空 map 占 2 预算 + if *budget < 2 { + *budget = 0 + return map[string]any{} + } + *budget -= 2 + + result := make(map[string]any) + // 排序 key 保证遍历顺序确定性 + keys := make([]string, 0, len(m)) + for k := range m { + keys = append(keys, k) + } + sort.Strings(keys) + + count := 0 + for _, k := range keys { + if count >= opts.objectNum || *budget <= 0 { + break + } + + // 扣 key 长度 + keyCost := len(k) + if *budget < keyCost { + break + } + budgetBefore := *budget + *budget -= keyCost + + // 递归处理 value + processed := sanitizeRecursive(m[k], opts, budget) + + // 检查是否因预算不足返回了 nil(仅当原值非 nil 时) + if processed == nil && m[k] != nil { + *budget = budgetBefore + break + } + + result[k] = processed + count++ + } + return result +} + +func sanitizeSliceContent(s []any, opts sanitizeOpts, budget *int) []any { + // 空 slice 占 2 预算 + if *budget < 2 { + *budget = 0 + return []any{} + } + *budget -= 2 + + result := make([]any, 0, min(len(s), opts.arrayNum)) + count := 0 + for _, v := range s { + if count >= opts.arrayNum || *budget <= 0 { + break + } + + // 值预算由递归处理 + processed := sanitizeRecursive(v, opts, budget) + if processed == nil && v != nil { + break + } + + result = append(result, processed) + count++ + } + return result +} + +// sanitizeLogHeaders 过滤请求/响应头,排除 NoLogHeaders 中指定的字段 +func sanitizeLogHeaders(h http.Header, noLogHeaders string) map[string]string { + result := make(map[string]string) + excludes := strings.Split(noLogHeaders, ",") + for k, v := range h { + skip := false + for _, ex := range excludes { + if ex != "" && strings.EqualFold(k, strings.TrimSpace(ex)) { + skip = true + break + } + } + if !skip { + result[k] = strings.Join(v, ", ") + } + } + return result +} + +// sanitizeRespBody 对响应体进行脱敏:尝试 JSON 解析后走对象脱敏,失败则按字符串截断 +func sanitizeRespBody(body []byte, cfg *ServiceConfig) any { + // 尝试解析为 JSON 对象 + var parsed any + if err := cast.UnmarshalJSON(body, &parsed); err == nil && parsed != nil { + // 排除敏感字段 + if cfg.NoLogOutputFields != "" { + parsed = stripFields(parsed, cfg.NoLogOutputFields) + } + return sanitizeLogData(parsed, sanitizeOpts{ + maxSize: 200, + fieldSize: cfg.LogOutputFieldSize, + arrayNum: cfg.LogOutputArrayNum, + objectNum: cfg.LogOutputObjectNum, + }) + } + + // 非 JSON 内容,按字符串截断 + if len(body) > cfg.LogOutputMaxSize { + return string(body[:cfg.LogOutputMaxSize]) + "..." + } + return string(body) +} + +// stripFields 从对象中删除指定字段(仅处理顶层 map) +func stripFields(v any, fields string) any { + m, ok := v.(map[string]any) + if !ok { + return v + } + excludes := strings.Split(fields, ",") + for _, f := range excludes { + f = strings.TrimSpace(f) + if f != "" { + delete(m, f) + } + } + return m +} diff --git a/log_sanitize_test.go b/log_sanitize_test.go new file mode 100644 index 0000000..682ca72 --- /dev/null +++ b/log_sanitize_test.go @@ -0,0 +1,250 @@ +package service + +import ( + "reflect" + "testing" +) + +func TestSanitizeScalars(t *testing.T) { + opts := sanitizeOpts{maxSize: 200, fieldSize: 20, arrayNum: 3, objectNum: 10} + + tests := []struct { + name string + input any + expected any + }{ + {"nil", nil, nil}, + {"int", 42, 42}, + {"float", 3.14, 3.14}, + {"bool_true", true, true}, + {"bool_false", false, false}, + {"string_short", "hello", "hello"}, + {"string_exact", "12345678901234567890", "12345678901234567890"}, + {"string_over", "123456789012345678901234567890", "12345678901234567890"}, + {"string_unicode", "你好世界你好世界你好世界你好世界你好世界你好世界你好世界", "你好世界你好世界你好世界你好世界你好世界"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := sanitizeLogData(tt.input, opts) + if !reflect.DeepEqual(got, tt.expected) { + t.Errorf("got %v, want %v", got, tt.expected) + } + }) + } +} + +func TestSanitizeMapBasic(t *testing.T) { + opts := sanitizeOpts{maxSize: 200, fieldSize: 20, arrayNum: 3, objectNum: 10} + + input := map[string]any{ + "name": "John", + "bio": "A very long biography that should be truncated at twenty", + "status": "active", + } + expected := map[string]any{ + "name": "John", + "bio": "A very long biograph", + "status": "active", + } + + got := sanitizeLogData(input, opts) + if !reflect.DeepEqual(got, expected) { + t.Errorf("got %v, want %v", got, expected) + } +} + +func TestSanitizeMapObjectNum(t *testing.T) { + opts := sanitizeOpts{maxSize: 200, fieldSize: 20, arrayNum: 3, objectNum: 3} + + input := map[string]any{ + "k1": "v1", + "k2": "v2", + "k3": "v3", + "k4": "v4", + "k5": "v5", + } + expected := map[string]any{ + "k1": "v1", + "k2": "v2", + "k3": "v3", + } + + got := sanitizeLogData(input, opts) + if !reflect.DeepEqual(got, expected) { + t.Errorf("got %v, want %v", got, expected) + } +} + +func TestSanitizeSlice(t *testing.T) { + opts := sanitizeOpts{maxSize: 200, fieldSize: 20, arrayNum: 3, objectNum: 10} + + input := []any{1, 2, 3, 4, 5, 6, 7} + expected := []any{1, 2, 3} + + got := sanitizeLogData(input, opts) + if !reflect.DeepEqual(got, expected) { + t.Errorf("got %v, want %v", got, expected) + } +} + +func TestSanitizeSliceWithStrings(t *testing.T) { + opts := sanitizeOpts{maxSize: 200, fieldSize: 20, arrayNum: 3, objectNum: 10} + + input := []any{ + "short", + "this string is definitely way too long for twenty characters", + "ok", + "this would be fourth but arrayNum is 3", + } + expected := []any{ + "short", + "this string is defin", + "ok", + } + + got := sanitizeLogData(input, opts) + if !reflect.DeepEqual(got, expected) { + t.Errorf("got %v, want %v", got, expected) + } +} + +func TestSanitizeNested(t *testing.T) { + opts := sanitizeOpts{maxSize: 200, fieldSize: 20, arrayNum: 3, objectNum: 10} + + input := map[string]any{ + "user": map[string]any{ + "name": "John Doe", + "bio": "A very long description that should be truncated at twenty chars", + "tags": []any{"go", "javascript-long-name", "rust", "python", "java", "c++"}, + }, + "score": 100, + } + expected := map[string]any{ + "user": map[string]any{ + "name": "John Doe", + "bio": "A very long descript", + "tags": []any{"go", "javascript-long-name", "rust"}, + }, + "score": 100, + } + + got := sanitizeLogData(input, opts) + if !reflect.DeepEqual(got, expected) { + t.Errorf("got %v, want %v", got, expected) + } +} + +func TestSanitizeTopLevelArray(t *testing.T) { + opts := sanitizeOpts{maxSize: 200, fieldSize: 20, arrayNum: 3, objectNum: 10} + + input := []any{ + map[string]any{"id": 1, "name": "Alice"}, + map[string]any{"id": 2, "name": "Bob"}, + map[string]any{"id": 3, "name": "Charlie"}, + map[string]any{"id": 4, "name": "Diana"}, + map[string]any{"id": 5, "name": "Eve"}, + } + expected := []any{ + map[string]any{"id": 1, "name": "Alice"}, + map[string]any{"id": 2, "name": "Bob"}, + map[string]any{"id": 3, "name": "Charlie"}, + } + + got := sanitizeLogData(input, opts) + if !reflect.DeepEqual(got, expected) { + t.Errorf("got %v, want %v", got, expected) + } +} + +func TestSanitizeBudgetExhausted(t *testing.T) { + opts := sanitizeOpts{maxSize: 22, fieldSize: 20, arrayNum: 10, objectNum: 10} + + input := map[string]any{ + "a": "1234567890", + "b": "abcdefghij", + "c": "should-be-dropped", + } + expected := map[string]any{ + "a": "1234567890", + "b": "abcdefghij", + } + + got := sanitizeLogData(input, opts) + if !reflect.DeepEqual(got, expected) { + t.Errorf("got %v, want %v", got, expected) + } +} + +func TestSanitizeNestedBudgetExhausted(t *testing.T) { + opts := sanitizeOpts{maxSize: 30, fieldSize: 20, arrayNum: 10, objectNum: 10} + + input := map[string]any{ + "small": "hi", + "nested": map[string]any{ + "field1": "1234567890", + "field2": "abcdefghij", + }, + "after": "no", + } + // 按字母序: after(7) → nested(22) → small(被跳过) + expected := map[string]any{ + "after": "no", + "nested": map[string]any{ + "field1": "1234567890", + }, + } + + got := sanitizeLogData(input, opts) + if !reflect.DeepEqual(got, expected) { + t.Errorf("got %v, want %v", got, expected) + } +} + +func TestSanitizeEmptyContainers(t *testing.T) { + opts := sanitizeOpts{maxSize: 200, fieldSize: 20, arrayNum: 3, objectNum: 10} + + tests := []struct { + name string + input any + expected any + }{ + {"empty_map", map[string]any{}, map[string]any{}}, + {"empty_slice", []any{}, []any{}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := sanitizeLogData(tt.input, opts) + if !reflect.DeepEqual(got, tt.expected) { + t.Errorf("got %v, want %v", got, tt.expected) + } + }) + } +} + +func TestSanitizeMixedSlice(t *testing.T) { + opts := sanitizeOpts{maxSize: 200, fieldSize: 20, arrayNum: 10, objectNum: 10} + + input := []any{ + "ok", + 42, + true, + nil, + map[string]any{"nested": "value"}, + "this one is too long and will be trimmed to twenty chars", + nil, + } + expected := []any{ + "ok", + 42, + true, + nil, + map[string]any{"nested": "value"}, + "this one is too long", + nil, + } + + got := sanitizeLogData(input, opts) + if !reflect.DeepEqual(got, expected) { + t.Errorf("got %v, want %v", got, expected) + } +} diff --git a/response.go b/response.go index faf758f..9b61c8e 100644 --- a/response.go +++ b/response.go @@ -67,10 +67,8 @@ func (r *Response) Write(bytes []byte) (int, error) { return len(bytes), nil } - // 即使没有过滤器,非 200 状态码也进行缓冲以便日志记录 - if r.Code != http.StatusOK { - r.body = append(r.body, bytes...) - } + // 缓冲 body 用于日志记录 + r.keepBody(bytes) if r.ProxyHeader != nil { r.copyProxyHeader() @@ -82,6 +80,22 @@ func (r *Response) Write(bytes []byte) (int, error) { return n, nil } +// keepBody 缓冲数据用于日志记录,限制大小防止内存问题 +func (r *Response) keepBody(bytes []byte) { + maxBuf := 200 + if r.server != nil && r.server.Config.LogOutputMaxSize > 0 { + maxBuf = r.server.Config.LogOutputMaxSize + } + if len(r.body) < maxBuf { + space := maxBuf - len(r.body) + if len(bytes) <= space { + r.body = append(r.body, bytes...) + } else { + r.body = append(r.body, bytes[:space]...) + } + } +} + // PhysicalWrite 物理写入网线,绕过过滤器缓冲逻辑 func (r *Response) PhysicalWrite(bytes []byte) (int, error) { r.checkWriteHeader() diff --git a/server.go b/server.go index 5638261..6ebb1ff 100644 --- a/server.go +++ b/server.go @@ -110,6 +110,17 @@ var DefaultServer = NewWebServer() // Config 全局配置对象 (指向 DefaultServer.Config) var Config = &DefaultServer.Config +func init() { + Config.NoLogHeaders = "X-Request-Id,X-Device-Id,X-Session-Id,Cookie,Device-Id,Session-Id" + Config.LogInputObjectNum = 10 + Config.LogInputArrayNum = 5 + Config.LogInputFieldSize = 20 + Config.LogOutputObjectNum = 5 + Config.LogOutputArrayNum = 3 + Config.LogOutputFieldSize = 20 + Config.LogOutputMaxSize = 200 +} + func NewWebServer() *WebServer { ws := &WebServer{ webServices: make(map[string]map[string]*webServiceType), @@ -134,6 +145,10 @@ func NewWebServer() *WebServer { webAuthCheckers: make(map[int]func(int, *log.Logger, *string, map[string]any, *Request, *Response, *WebServiceOptions) (pass bool, object any)), injectObjects: make(map[reflect.Type]any), injectFunctions: make(map[reflect.Type]func() any), + + usedSessionIdKey: "Session-ID", + usedDeviceIdKey: "Device-ID", + usedClientAppKey: "App", } return ws } diff --git a/session.go b/session.go index 939b6dc..65588b2 100644 --- a/session.go +++ b/session.go @@ -1,14 +1,15 @@ package service import ( - "apigo.cc/go/cast" - "apigo.cc/go/jsmod" - "apigo.cc/go/log" - "apigo.cc/go/redis" "errors" "strings" "sync" "time" + + "apigo.cc/go/cast" + "apigo.cc/go/jsmod" + "apigo.cc/go/log" + "apigo.cc/go/redis" ) // Session 会话对象 @@ -74,11 +75,35 @@ func (s *Session) Get(key string) any { return s.data[key] } -// Remove 移除会话数据 -func (s *Session) Remove(key string) { +// Load 批量读取会话数据,keys 为空时返回全部数据 +func (s *Session) Load(keys []string) map[string]any { + s.lock.RLock() + defer s.lock.RUnlock() + + if len(keys) == 0 { + result := make(map[string]any, len(s.data)) + for k, v := range s.data { + result[k] = v + } + return result + } + + result := make(map[string]any, len(keys)) + for _, key := range keys { + if v, ok := s.data[key]; ok { + result[key] = v + } + } + return result +} + +// Remove 移除会话数据,支持传入多个 key +func (s *Session) Remove(keys ...string) { s.lock.Lock() defer s.lock.Unlock() - delete(s.data, key) + for _, key := range keys { + delete(s.data, key) + } } // SetAuthLevel 设置鉴权级别 @@ -91,11 +116,17 @@ func (s *Session) GetAuthLevel() int { return cast.Int(s.Get("_authLevel")) } -// Save 保存会话数据 -func (s *Session) Save() error { +// Save 保存会话数据,可选传入 map 用于批量设置后保存 +func (s *Session) Save(args ...map[string]any) error { s.lock.Lock() defer s.lock.Unlock() + if len(args) > 0 && args[0] != nil { + for k, v := range args[0] { + s.data[k] = v + } + } + timeout := Config.SessionTimeout if timeout <= 0 { timeout = 3600 diff --git a/session_test.go b/session_test.go index 36adbfc..6584d2c 100644 --- a/session_test.go +++ b/session_test.go @@ -27,6 +27,89 @@ func TestSessionLogic(t *testing.T) { t.Errorf("Expected value1 in new session instance, got %v", sess2.Get("key1")) } + // 1.1 测试 Save 批量设置 + sess3 := NewSession("test_batch", nil) + m := map[string]any{"a": 1, "b": "hello", "c": true} + if err := sess3.Save(m); err != nil { + t.Errorf("Save with map failed: %v", err) + } + sess4 := NewSession("test_batch", nil) + if sess4.Get("a") != 1 { + t.Errorf("Expected a=1, got %v", sess4.Get("a")) + } + if sess4.Get("b") != "hello" { + t.Errorf("Expected b=hello, got %v", sess4.Get("b")) + } + if sess4.Get("c") != true { + t.Errorf("Expected c=true, got %v", sess4.Get("c")) + } + + // 1.2 测试 Save 无参数仍然正常工作 + sess5 := NewSession("test_noarg", nil) + sess5.Set("x", "y") + if err := sess5.Save(); err != nil { + t.Errorf("Save without args failed: %v", err) + } + sess6 := NewSession("test_noarg", nil) + if sess6.Get("x") != "y" { + t.Errorf("Expected x=y, got %v", sess6.Get("x")) + } + + // 1.3 测试 Load 批量读取 + sess7 := NewSession("test_load", nil) + sess7.Set("k1", "v1") + sess7.Set("k2", "v2") + sess7.Set("k3", "v3") + _ = sess7.Save() + + sess8 := NewSession("test_load", nil) + result := sess8.Load([]string{"k1", "k3"}) + if len(result) != 2 { + t.Errorf("Expected 2 keys, got %d", len(result)) + } + if result["k1"] != "v1" { + t.Errorf("Expected k1=v1, got %v", result["k1"]) + } + if result["k3"] != "v3" { + t.Errorf("Expected k3=v3, got %v", result["k3"]) + } + if _, ok := result["k2"]; ok { + t.Error("k2 should not be in partial Load result") + } + + // 1.4 测试 Load 空参数返回全部数据 + allData := sess8.Load(nil) + if len(allData) < 3 { + t.Errorf("Expected at least 3 keys in full Load, got %d", len(allData)) + } + + // 1.5 测试 Remove 多 key + sess9 := NewSession("test_remove", nil) + sess9.Set("a", 1) + sess9.Set("b", 2) + sess9.Set("c", 3) + sess9.Remove("a", "c") + _ = sess9.Save() + + sess10 := NewSession("test_remove", nil) + if sess10.Get("a") != nil { + t.Errorf("Expected a removed, got %v", sess10.Get("a")) + } + if sess10.Get("b") != 2 { + t.Errorf("Expected b=2, got %v", sess10.Get("b")) + } + if sess10.Get("c") != nil { + t.Errorf("Expected c removed, got %v", sess10.Get("c")) + } + + // 1.6 测试 Remove 无参数(安全无操作) + sess9.Remove() + _ = sess9.Save() + sess11 := NewSession("test_remove", nil) + if sess11.Get("b") != 2 { + t.Errorf("Expected b=2 after no-arg Remove, got %v", sess11.Get("b")) + } + // 2. 测试 AuthFuncs 逻辑 sess.Set("funcs", []string{"user.read", "user.write", "system.admin"})