diff --git a/CHANGELOG.md b/CHANGELOG.md index 1bb0103..244eb39 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,12 @@ # CHANGELOG - go/service +## v1.5.2 (2026-06-04) +- **架构重构**: 彻底移除全局状态泄露,重构为严格的单例模式。 + - 将所有路由 (`webServices`)、代理配置 (`hostProxies`)、静态资源 (`statics`) 及鉴权器等 30 多个包级全局变量移入私有的 `webServer` 结构体中。 + - `NewWebServer()` 改为直接返回全局安全的 `DefaultServer` 单例。 + - 所有的包级公开方法(如 `Register`, `Proxy`, `Static` 等)现作为 `DefaultServer` 实例方法的安全代理。 +- **依赖更新**: 升级 `apigo.cc/go/log` 至 `v1.5.1`,修复了在使用 `starter` 时异步控制台日志被静默丢弃的关键 Bug。 + ## v1.5.1 (2026-06-04) - **修复**: 在 `WebServer.Start` 中显式调用 `config.Load(&Config, "service")`,确保启动时自动从 `env.yaml` 加载 `service:` 块。 - **修复**: 优化 `WebServer.Reload` 的配置加载逻辑,确保与启动加载逻辑保持一致。 diff --git a/config.go b/config.go index 0626cbd..e0f12eb 100644 --- a/config.go +++ b/config.go @@ -1,10 +1,5 @@ package service -import ( - "apigo.cc/go/cast" - "path/filepath" -) - // CertSet SSL 证书配置 type CertSet struct { CertFile string @@ -21,135 +16,55 @@ type CallConfig struct { // ServiceConfig 核心服务配置 type ServiceConfig struct { - App string // 应用名称。优先从环境变量 DISCOVER_APP 获取,若为空则自动通过代码检测。 - Register string // 发现服务注册中心地址。支持 Redis URL 或 Redis 配置名称。 - Weight int // 当前节点在发现服务中的权重 (默认 100) + App string // 应用名称。优先从环境变量 DISCOVER_APP 获取,若为空则自动通过代码检测。 + Register string // 发现服务注册中心地址。支持 Redis URL 或 Redis 配置名称。 + Weight int // 当前节点在发现服务中的权重 (默认 100) Calls map[string]CallConfig // 依赖的下游服务调用配置 - Listen string // 监听端口(|隔开多个监听)(,隔开多个选项),例如 80,http|443|443:h2|127.0.0.1:8080,h2c - SSL map[string]*CertSet // SSL 证书配置,key 为域名 - NoLogGets bool // 不记录 GET 请求的日志 - NoLogHeaders string // 不记录请求头中包含的这些字段,多个字段用逗号分隔 - LogInputArrayNum int // 请求字段中容器类型在日志打印个数限制 - LogInputFieldSize int // 请求字段中单个字段在日志打印长度限制 - NoLogOutputFields string // 不记录响应字段中包含的这些字段 - LogOutputArrayNum int // 响应字段中容器类型在日志打印个数限制 - LogOutputFieldSize int // 响应字段中单个字段在日志打印长度限制 - Compress bool // 是否启用压缩 - CompressMinSize int // 启用压缩的最小长度 - CompressMaxSize int // 启用压缩的最大长度 - CheckDomain string // 心跳检测时使用域名 - AccessTokens map[string]*int // 指定 Access-Token 验证及其对应的 auth-level - RedirectTimeout int // Proxy 和 Discover 发起请求时的超时时间 (ms) - AcceptXRealIpWithoutRequestId bool // 是否允许头部没有携带请求ID的 X-Real-IP 信息 - StatisticTime bool // 是否开启请求时间统计 - StatisticTimeInterval int // 统计时间间隔 (ms) - MaxUploadSize int64 // 最大上传文件大小 (Bytes) - Cpu int // CPU 占用的核数限制 - Memory int // 内存限制 (MB) - CookieScope string // Session Cookie 有效范围: host|domain|topDomain - SessionWithoutCookie bool // Session 禁用 Cookie - SessionRedis string // Session 存储使用的 Redis 配置名称 (不设置则使用内存) - SessionTimeout int // Session 有效期 (秒,默认 3600) - DeviceWithoutCookie bool // 设备ID禁用 Cookie - IdServer string // Redis 服务器连接 (用于全局唯一 ID 生成) - IndexFiles []string // 静态文件索引文件 - IndexDir bool // 访问目录时显示文件列表 - ReadTimeout int // 读取请求的超时时间 (ms) - ReadHeaderTimeout int // 读取请求头的超时时间 (ms) - WriteTimeout int // 响应写入的超时时间 (ms) - IdleTimeout int // 连接空闲超时时间 (ms) - MaxHeaderBytes int // 请求头的最大字节数 - MaxHandlers int // 每个连接的最大处理程序数量 - MaxConcurrentStreams uint32 // 每个连接的最大并发流数量 - MaxDecoderHeaderTableSize uint32 // 解码器头表的最大大小 - MaxEncoderHeaderTableSize uint32 // 编码器头表的最大大小 - MaxReadFrameSize uint32 // 单个帧的最大读取大小 - MaxUploadBufferPerConnection int32 // 每个连接的最大上传缓冲区大小 - MaxUploadBufferPerStream int32 // 每个流的最大上传缓冲区大小 - StopTimeout int // 停止服务的超时时间 (ms) - + Listen string // 监听端口(|隔开多个监听)(,隔开多个选项),例如 80,http|443|443:h2|127.0.0.1:8080,h2c + SSL map[string]*CertSet // SSL 证书配置,key 为域名 + NoLogGets bool // 不记录 GET 请求的日志 + NoLogHeaders string // 不记录请求头中包含的这些字段,多个字段用逗号分隔 + LogInputArrayNum int // 请求字段中容器类型在日志打印个数限制 + LogInputFieldSize int // 请求字段中单个字段在日志打印长度限制 + NoLogOutputFields string // 不记录响应字段中包含的这些字段 + LogOutputArrayNum int // 响应字段中容器类型在日志打印个数限制 + LogOutputFieldSize int // 响应字段中单个字段在日志打印长度限制 + Compress bool // 是否启用压缩 + CompressMinSize int // 启用压缩的最小长度 + CompressMaxSize int // 启用压缩的最大长度 + CheckDomain string // 心跳检测时使用域名 + AccessTokens map[string]*int // 指定 Access-Token 验证及其对应的 auth-level + RedirectTimeout int // Proxy 和 Discover 发起请求时的超时时间 (ms) + AcceptXRealIpWithoutRequestId bool // 是否允许头部没有携带请求ID的 X-Real-IP 信息 + StatisticTime bool // 是否开启请求时间统计 + StatisticTimeInterval int // 统计时间间隔 (ms) + MaxUploadSize int64 // 最大上传文件大小 (Bytes) + Cpu int // CPU 占用的核数限制 + Memory int // 内存限制 (MB) + CookieScope string // Session Cookie 有效范围: host|domain|topDomain + SessionWithoutCookie bool // Session 禁用 Cookie + SessionRedis string // Session 存储使用的 Redis 配置名称 (不设置则使用内存) + SessionTimeout int // Session 有效期 (秒,默认 3600) + DeviceWithoutCookie bool // 设备ID禁用 Cookie + IdServer string // Redis 服务器连接 (用于全局唯一 ID 生成) + IndexFiles []string // 静态文件索引文件 + IndexDir bool // 访问目录时显示文件列表 + ReadTimeout int // 读取请求的超时时间 (ms) + ReadHeaderTimeout int // 读取请求头的超时时间 (ms) + WriteTimeout int // 响应写入的超时时间 (ms) + IdleTimeout int // 连接空闲超时时间 (ms) + MaxHeaderBytes int // 请求头的最大字节数 + MaxHandlers int // 每个连接的最大处理程序数量 + MaxConcurrentStreams uint32 // 每个连接的最大并发流数量 + MaxDecoderHeaderTableSize uint32 // 解码器头表的最大大小 + MaxEncoderHeaderTableSize uint32 // 编码器头表的最大大小 + MaxReadFrameSize uint32 // 单个帧的最大读取大小 + MaxUploadBufferPerConnection int32 // 每个连接的最大上传缓冲区大小 + MaxUploadBufferPerStream int32 // 每个流的最大上传缓冲区大小 + StopTimeout int // 停止服务的超时时间 (ms) + // 从配置文件中加载的静态路由策略 (按 Host 分组,全局配置用 "" 或 "*") Proxies map[string]map[string]any Rewrites map[string]map[string]any Statics map[string]map[string]string } - -var Config = ServiceConfig{} - -// ApplyConfig 将 ServiceConfig 中的路由策略应用到内部的文件级策略中 -func ApplyConfig() { - hostPoliciesLock.Lock() - defer hostPoliciesLock.Unlock() - - // 1. Proxies KV 解析 - fileProxies = make(map[string][]*proxyType) - for host, kv := range Config.Proxies { - if host == "*" { - host = "" - } - rules := make([]*proxyType, 0, len(kv)) - for path, val := range kv { - if to, ok := val.(string); ok { - // 极简 KV 模式: "/api/*": "user-svc/v1/*" - rules = append(rules, parseProxyRule(0, path, "", "", to)) - } else { - // 对象模式: "/api/*": {"To": "...", "Auth": 1} - m, _ := cast.ToMap[string, any](val) - rules = append(rules, parseProxyRule( - cast.Int(m["Auth"]), - path, - cast.String(m["ToApp"]), - cast.String(m["ToPath"]), - cast.String(m["To"]), - )) - } - } - fileProxies[host] = rules - rebuildProxiesUnderLock(host) - } - - // 2. Rewrites KV 解析 - fileRewrites = make(map[string][]*rewriteType) - for host, kv := range Config.Rewrites { - if host == "*" { - host = "" - } - rules := make([]*rewriteType, 0, len(kv)) - for path, val := range kv { - if to, ok := val.(string); ok { - rules = append(rules, parseRewriteRule(path, "", to)) - } else { - m, _ := cast.ToMap[string, any](val) - rules = append(rules, parseRewriteRule( - path, - cast.String(m["ToPath"]), - cast.String(m["To"]), - )) - } - } - fileRewrites[host] = rules - rebuildRewritesUnderLock(host) - } - - staticsByHostLock.Lock() - defer staticsByHostLock.Unlock() - fileStatics = make(map[string]map[string]*string) - - for host, config := range Config.Statics { - if host == "*" { - host = "" - } - newStatics := make(map[string]*string, len(config)) - for path, rootPath := range config { - rp := rootPath - if !filepath.IsAbs(rp) { - if absPath, err := filepath.Abs(rp); err == nil { - rp = absPath - } - } - newStatics[path] = &rp - } - fileStatics[host] = newStatics - rebuildStaticsUnderLock(host) - } -} diff --git a/discover_test.go b/discover_test.go index 012c7a5..c76e34d 100644 --- a/discover_test.go +++ b/discover_test.go @@ -25,17 +25,17 @@ func TestSmartStartup(t *testing.T) { // Reset config Config.Listen = "" Config.App = "smart-test" - + as := AsyncStart() if as.Addr == "" { t.Fatal("Server address should not be empty") } - + t.Logf("Server started on %s", as.Addr) - + if !as.useDiscover { t.Error("Should have enabled discover") } - + as.Stop() } diff --git a/document.go b/document.go index 5d913ff..d420f4b 100644 --- a/document.go +++ b/document.go @@ -21,11 +21,15 @@ type Api struct { // MakeDocument 生成文档数据 func MakeDocument() []Api { + return DefaultServer.MakeDocument() +} + +func (ws *webServer) MakeDocument() []Api { out := make([]Api, 0) // 1. Rewrite & Proxy - hostPoliciesLock.RLock() - for host, rewrites := range hostRewrites { + ws.hostPoliciesLock.RLock() + for host, rewrites := range ws.hostRewrites { for _, a := range rewrites { out = append(out, Api{ Type: "Rewrite", @@ -34,7 +38,7 @@ func MakeDocument() []Api { }) } } - for host, proxies := range hostProxies { + for host, proxies := range ws.hostProxies { for _, a := range proxies { out = append(out, Api{ Type: "Proxy", @@ -43,11 +47,11 @@ func MakeDocument() []Api { }) } } - hostPoliciesLock.RUnlock() + ws.hostPoliciesLock.RUnlock() // 2. Web Services - webServicesLock.RLock() - for _, a := range webServicesList { + ws.webServicesLock.RLock() + for _, a := range ws.webServicesList { if a.options.NoDoc { continue } @@ -67,22 +71,22 @@ func MakeDocument() []Api { } out = append(out, api) } - webServicesLock.RUnlock() + ws.webServicesLock.RUnlock() // 4. WebSocket Services - websocketServicesLock.RLock() - for _, ws := range websocketServicesList { + ws.websocketServicesLock.RLock() + for _, wsc := range ws.websocketServicesList { api := Api{ Type: "WebSocket", - Path: ws.path, - AuthLevel: ws.authLevel, - Memo: ws.memo, - Host: ws.host, + Path: wsc.path, + AuthLevel: wsc.authLevel, + Memo: wsc.memo, + Host: wsc.host, } - if ws.funcType != nil && ws.funcType.NumIn() > 0 { + if wsc.funcType != nil && wsc.funcType.NumIn() > 0 { // Find struct in - for i := 0; i < ws.funcType.NumIn(); i++ { - t := ws.funcType.In(i) + for i := 0; i < wsc.funcType.NumIn(); i++ { + t := wsc.funcType.In(i) if t.Kind() == reflect.Struct { api.In = getType(t) break @@ -91,7 +95,7 @@ func MakeDocument() []Api { } out = append(out, api) } - websocketServicesLock.RUnlock() + ws.websocketServicesLock.RUnlock() return out } diff --git a/go.mod b/go.mod index 3c3a53e..e2f603a 100644 --- a/go.mod +++ b/go.mod @@ -10,7 +10,7 @@ require ( apigo.cc/go/http v1.5.0 apigo.cc/go/id v1.5.0 apigo.cc/go/jsmod v1.5.0 - apigo.cc/go/log v1.5.0 + apigo.cc/go/log v1.5.2 apigo.cc/go/redis v1.5.0 apigo.cc/go/safe v1.5.0 apigo.cc/go/starter v1.5.0 @@ -25,8 +25,8 @@ require ( apigo.cc/go/rand v1.5.0 // indirect apigo.cc/go/shell v1.5.0 // indirect github.com/gomodule/redigo v2.0.0+incompatible // indirect - golang.org/x/crypto v0.51.0 // indirect - golang.org/x/sys v0.44.0 // indirect + golang.org/x/crypto v0.52.0 // indirect + golang.org/x/sys v0.45.0 // indirect golang.org/x/text v0.37.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index d2f9c10..15a76a4 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,7 @@ apigo.cc/go/cast v1.5.0 h1:UBGJtFQ8eJPMQXs37cUgqd7YQo1zI9opuSDBDmn2/pE= apigo.cc/go/cast v1.5.0/go.mod h1:z2GW5p5WCZGEqVVIJUdhl232vRbLf2Qu4EDlEakX/D8= -apigo.cc/go/config v1.5.0 h1:Yuz9QEb11XXG4XkhDi/ueT2M1T3Q9PElE5tiakvjehs= -apigo.cc/go/config v1.5.0/go.mod h1:jdMiDLPa9gzB8/FFZvm9jOopUqdxb7XSX+0OeWcZZUM= +apigo.cc/go/config v1.5.1 h1:rpj7oCzlsDV3f2/YK3Pb+CHbfr2DL5Vyyv6VNkobJP4= +apigo.cc/go/config v1.5.1/go.mod h1:jdMiDLPa9gzB8/FFZvm9jOopUqdxb7XSX+0OeWcZZUM= apigo.cc/go/crypto v1.5.0 h1:Nxz7a6VKCdvaF258IU0NkjQyureOLxfR308Sy2iftUI= apigo.cc/go/crypto v1.5.0/go.mod h1:F9M6nXv+5328r1ZwbTvI6fcr8VdgqHVzALOcsdv6ntE= apigo.cc/go/discover v1.5.0 h1:RGHulidyAHCZdGfpFytFUl3ur4aNVMXKlfJbAMCvgpo= @@ -16,8 +16,8 @@ apigo.cc/go/id v1.5.0 h1:MjNWPhBhDsoXaLeJDv/0wfJmVMU9EvOs8pWYfsTQ6e8= apigo.cc/go/id v1.5.0/go.mod h1:qhu4a1/KLc/XcBpcsRu+mXZt7U7Wvd9zMcPs4VspuPA= apigo.cc/go/jsmod v1.5.0 h1:JgQtJNiJWy1NOP9AzE8NX5VXJkpO/x3GqLsCCSny5Ec= apigo.cc/go/jsmod v1.5.0/go.mod h1:bmyeZtOAP/j5am+YRnaiM89smysK24K7ebk0koFtsSw= -apigo.cc/go/log v1.5.0 h1:kQuLLtbt33mEuc/xJVcy8NODXkso/QKSZWNclKrSpsI= -apigo.cc/go/log v1.5.0/go.mod h1:Djy+I5aLhGB/EjwRz4KHqkVEz584IAD55FAFiIfInuo= +apigo.cc/go/log v1.5.2 h1:ORcrDh6a4ghxIrm+TNLtm8HxjctwndGL2jCLctEIags= +apigo.cc/go/log v1.5.2/go.mod h1:Djy+I5aLhGB/EjwRz4KHqkVEz584IAD55FAFiIfInuo= apigo.cc/go/rand v1.5.0 h1:1o8hh8fhdBuk1/h02IvugvamuT3dkWbVJrqEJVQKB2E= apigo.cc/go/rand v1.5.0/go.mod h1:Lh98S2dm9UY0X+M+kNQQEKyXHG5pcCKSFPyXN0QCGdk= apigo.cc/go/redis v1.5.0 h1:VXNDqzKj87BchF7ubDEH+T6lp8NrjeK0izU4ooo7u1A= @@ -40,12 +40,12 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= -golang.org/x/crypto v0.51.0 h1:IBPXwPfKxY7cWQZ38ZCIRPI50YLeevDLlLnyC5wRGTI= -golang.org/x/crypto v0.51.0/go.mod h1:8AdwkbraGNABw2kOX6YFPs3WM22XqI4EXEd8g+x7Oc8= +golang.org/x/crypto v0.52.0 h1:RMs7fP2rXdep0CftQlK8Uf+kibLm7qkCcradZWYz988= +golang.org/x/crypto v0.52.0/go.mod h1:1QgfPxDqh0T2M/elOJtp9RvuR95kVjir0e6/BvEmGbc= golang.org/x/net v0.54.0 h1:2zJIZAxAHV/OHCDTCOHAYehQzLfSXuf/5SoL/Dv6w/w= golang.org/x/net v0.54.0/go.mod h1:Sj4oj8jK6XmHpBZU/zWHw3BV3abl4Kvi+Ut7cQcY+cQ= -golang.org/x/sys v0.44.0 h1:ildZl3J4uzeKP07r2F++Op7E9B29JRUy+a27EibtBTQ= -golang.org/x/sys v0.44.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +golang.org/x/sys v0.45.0 h1:dO4czNzziLiiXplLQgBCEpCvXQ3dnkn0SdaZSYdQ+FY= +golang.org/x/sys v0.45.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/text v0.37.0 h1:Cqjiwd9eSg8e0QAkyCaQTNHFIIzWtidPahFWR83rTrc= golang.org/x/text v0.37.0/go.mod h1:a5sjxXGs9hsn/AJVwuElvCAo9v8QYLzvavO5z2PiM38= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/handler.go b/handler.go index ac7b9cf..62f170f 100644 --- a/handler.go +++ b/handler.go @@ -15,10 +15,12 @@ import ( ) type RouteHandler struct { + ws *webServer webRequestingNum int64 } func (rh *RouteHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ws := rh.ws atomic.AddInt64(&rh.webRequestingNum, 1) defer atomic.AddInt64(&rh.webRequestingNum, -1) @@ -55,7 +57,7 @@ func (rh *RouteHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // 记录日志 if (s == nil || !s.options.NoLog200 || response.Code != 200) && - !(Config.NoLogGets && r.Method == http.MethodGet && response.Code == 200) { + !(ws.Config.NoLogGets && r.Method == http.MethodGet && response.Code == 200) { scheme := "http" if r.TLS != nil { @@ -65,7 +67,7 @@ func (rh *RouteHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // 过滤请求头 reqHeaders := make(map[string]string) - noLogHeaders := strings.Split(Config.NoLogHeaders, ",") + noLogHeaders := strings.Split(ws.Config.NoLogHeaders, ",") for k, v := range r.Header { skip := false for _, nl := range noLogHeaders { @@ -93,7 +95,7 @@ func (rh *RouteHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } else { respData = string(response.body[:1024]) + "..." } - } else if Config.NoLogOutputFields != "" { + } else if ws.Config.NoLogOutputFields != "" { // 简单的字段过滤逻辑 (如果是 JSON 对象) // 这里可以根据 Config.NoLogOutputFields, LogOutputArrayNum, LogOutputFieldSize 进行更复杂的处理 // 暂按字符串截断处理 @@ -109,8 +111,8 @@ func (rh *RouteHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { entry.Scheme = scheme entry.Proto = r.Proto entry.ClientIp = request.ClientIp() - entry.ServerId = serverId - entry.App = Config.App + entry.ServerId = ws.serverId + entry.App = ws.Config.App entry.FromApp = r.Header.Get(discover.HeaderFromApp) entry.FromNode = r.Header.Get(discover.HeaderFromNode) entry.DeviceId = request.DeviceId() @@ -131,15 +133,15 @@ func (rh *RouteHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { }() // 处理 SessionId 和 DeviceId - handleClientKeys(request, response) + ws.handleClientKeys(request, response) // 1. 处理重写 (Rewrite) - if processRewrite(request, response, requestLogger) { + if ws.processRewrite(request, response, requestLogger) { return } // 2. 处理代理 (Proxy) - if processProxy(request, response, requestLogger) { + if ws.processProxy(request, response, requestLogger) { return } @@ -148,19 +150,19 @@ func (rh *RouteHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { host := r.Host // 处理静态文件 - if processStatic(path, request, response, requestLogger) { + if ws.processStatic(path, request, response, requestLogger) { return } - var ws *websocketServiceType - s, ws = findService(r.Method, host, path) + var wsc *websocketServiceType + s, wsc = ws.findService(r.Method, host, path) // 4. 参数解析 (Form & Body) parseRequestArgs(request, args) // 5. 前置过滤器 var result any - for _, filter := range inFilters { + for _, filter := range ws.inFilters { result = filter(&args, request, response, requestLogger) if result != nil { break @@ -174,22 +176,22 @@ func (rh *RouteHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // 6. 处理业务执行 (WS 或 Web) if result == nil { - if ws != nil { - authLevel = ws.authLevel - priority = ws.options.Priority + if wsc != nil { + authLevel = wsc.authLevel + priority = wsc.options.Priority // 鉴权 - pass, obj := checkAuth(ws.authLevel, &ws.options, request, response, args, requestLogger) + pass, obj := ws.checkAuth(wsc.authLevel, &wsc.options, request, response, args, requestLogger) if !pass { if !response.changed { response.WriteHeader(http.StatusForbidden) } return } - doWebsocketService(ws, request, response, requestLogger, obj) + ws.doWebsocketService(wsc, request, response, requestLogger, obj) return } else if s != nil { // 鉴权 - pass, obj := checkAuth(s.authLevel, &s.options, request, response, args, requestLogger) + pass, obj := ws.checkAuth(s.authLevel, &s.options, request, response, args, requestLogger) if !pass { if !response.changed { response.WriteHeader(http.StatusForbidden) @@ -197,7 +199,7 @@ func (rh *RouteHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } // 执行业务 - result = doWebService(s, request, response, args, nil, requestLogger, obj) + result = ws.doWebService(s, request, response, args, nil, requestLogger, obj) } } @@ -206,7 +208,7 @@ func (rh *RouteHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } // 7. 后置过滤器 - for _, filter := range outFilters { + for _, filter := range ws.outFilters { newResult, done := filter(args, request, response, result, requestLogger) if newResult != nil { result = newResult @@ -225,9 +227,9 @@ func hostOnly(host string) string { return h } -func findService(method, host, path string) (*webServiceType, *websocketServiceType) { - webServicesLock.RLock() - defer webServicesLock.RUnlock() +func (ws *webServer) findService(method, host, path string) (*webServiceType, *websocketServiceType) { + ws.webServicesLock.RLock() + defer ws.webServicesLock.RUnlock() // 1. 准备 Host 候选列表: "host:port", "host", ":port", "*" hostOnly, port, _ := strings.Cut(host, ":") @@ -239,7 +241,7 @@ func findService(method, host, path string) (*webServiceType, *websocketServiceT // 2. 匹配 Web Service for _, h := range hosts { - if services, exists := webServices[h]; exists { + if services, exists := ws.webServices[h]; exists { if s, ok := services[method+path]; ok { return s, nil } @@ -250,10 +252,10 @@ func findService(method, host, path string) (*webServiceType, *websocketServiceT } // 3. 匹配 WebSocket - websocketServicesLock.RLock() - defer websocketServicesLock.RUnlock() + ws.websocketServicesLock.RLock() + defer ws.websocketServicesLock.RUnlock() for _, h := range hosts { - if services, exists := websocketServices[h]; exists { + if services, exists := ws.websocketServices[h]; exists { if ws, ok := services[path]; ok { return nil, ws } @@ -262,7 +264,7 @@ func findService(method, host, path string) (*webServiceType, *websocketServiceT // 4. 正则匹配 for _, h := range hosts { - if services, exists := regexWebServices[h]; exists { + if services, exists := ws.regexWebServices[h]; exists { for i := len(services) - 1; i >= 0; i-- { s := services[i] if s.method != "*" && s.method != method { @@ -311,10 +313,10 @@ func parseRequestArgs(request *Request, args map[string]any) { } } -func checkAuth(authLevel int, options *WebServiceOptions, request *Request, response *Response, args map[string]any, logger *log.Logger) (bool, any) { - ac := webAuthCheckers[authLevel] +func (ws *webServer) checkAuth(authLevel int, options *WebServiceOptions, request *Request, response *Response, args map[string]any, logger *log.Logger) (bool, any) { + ac := ws.webAuthCheckers[authLevel] if ac == nil { - ac = webAuthChecker + ac = ws.webAuthChecker } if ac == nil { sess := NewSession(request.SessionId(), logger) @@ -330,7 +332,7 @@ func checkAuth(authLevel int, options *WebServiceOptions, request *Request, resp return pass, obj } -func doWebService(service *webServiceType, request *Request, response *Response, args map[string]any, +func (ws *webServer) doWebService(service *webServiceType, request *Request, response *Response, args map[string]any, result any, logger *log.Logger, object any) any { if result != nil { return result @@ -365,7 +367,7 @@ func doWebService(service *webServiceType, request *Request, response *Response, // 尝试依赖注入 if object != nil && reflect.TypeOf(object).AssignableTo(t) { params[i] = reflect.ValueOf(object) - } else if obj := GetInject(t); obj != nil { + } else if obj := ws.GetInject(t); obj != nil { params[i] = reflect.ValueOf(obj) } else { params[i] = reflect.New(t).Elem() @@ -404,24 +406,24 @@ func outputResult(response *Response, result any) { _, _ = response.Write(data) } -func handleClientKeys(request *Request, response *Response) { +func (ws *webServer) handleClientKeys(request *Request, response *Response) { // SessionId - if usedSessionIdKey != "" { - sessionId := request.Header.Get(usedSessionIdKey) - if sessionId == "" && !Config.SessionWithoutCookie { - if ck, err := request.Cookie(usedSessionIdKey); err == nil { + if ws.usedSessionIdKey != "" { + sessionId := request.Header.Get(ws.usedSessionIdKey) + if sessionId == "" && !ws.Config.SessionWithoutCookie { + if ck, err := request.Cookie(ws.usedSessionIdKey); err == nil { sessionId = ck.Value } } if sessionId == "" { - if sessionIdMaker != nil { - sessionId = sessionIdMaker() + if ws.sessionIdMaker != nil { + sessionId = ws.sessionIdMaker() } else { sessionId = IDMaker.Get11Bytes900MPerSecond() } - if !Config.SessionWithoutCookie { + if !ws.Config.SessionWithoutCookie { http.SetCookie(response.Writer, &http.Cookie{ - Name: usedSessionIdKey, + Name: ws.usedSessionIdKey, Value: sessionId, Path: "/", HttpOnly: true, @@ -429,22 +431,22 @@ func handleClientKeys(request *Request, response *Response) { } } request.Header.Set(discover.HeaderSessionID, sessionId) - response.Header().Set(usedSessionIdKey, sessionId) + response.Header().Set(ws.usedSessionIdKey, sessionId) } // DeviceId - if usedDeviceIdKey != "" { - deviceId := request.Header.Get(usedDeviceIdKey) - if deviceId == "" && !Config.DeviceWithoutCookie { - if ck, err := request.Cookie(usedDeviceIdKey); err == nil { + if ws.usedDeviceIdKey != "" { + deviceId := request.Header.Get(ws.usedDeviceIdKey) + if deviceId == "" && !ws.Config.DeviceWithoutCookie { + if ck, err := request.Cookie(ws.usedDeviceIdKey); err == nil { deviceId = ck.Value } } if deviceId == "" { deviceId = IDMaker.Get11Bytes900MPerSecond() - if !Config.DeviceWithoutCookie { + if !ws.Config.DeviceWithoutCookie { http.SetCookie(response.Writer, &http.Cookie{ - Name: usedDeviceIdKey, + Name: ws.usedDeviceIdKey, Value: deviceId, Path: "/", Expires: time.Now().AddDate(10, 0, 0), @@ -453,6 +455,6 @@ func handleClientKeys(request *Request, response *Response) { } } request.Header.Set(discover.HeaderDeviceID, deviceId) - response.Header().Set(usedDeviceIdKey, deviceId) + response.Header().Set(ws.usedDeviceIdKey, deviceId) } } diff --git a/handler_test.go b/handler_test.go index f2b33fd..2684088 100644 --- a/handler_test.go +++ b/handler_test.go @@ -14,8 +14,8 @@ func TestServeHTTP(t *testing.T) { } Host("*").POST("/hello", handler).Auth(0).Memo("say hello") - rh := &RouteHandler{} - + rh := &RouteHandler{ws: DefaultServer} + // 模拟请求 req := httptest.NewRequest("POST", "/hello", strings.NewReader(`{"name":"Star"}`)) req.Header.Set("Content-Type", "application/json") @@ -34,7 +34,7 @@ func TestServeHTTP(t *testing.T) { } func TestServeHTTP_404(t *testing.T) { - rh := &RouteHandler{} + rh := &RouteHandler{ws: DefaultServer} req := httptest.NewRequest("GET", "/notfound", nil) w := httptest.NewRecorder() @@ -54,7 +54,7 @@ func TestServeHTTP_VerifyFailed(t *testing.T) { } Host("*").POST("/verify", handler).Auth(0).Memo("test verify") - rh := &RouteHandler{} + rh := &RouteHandler{ws: DefaultServer} req := httptest.NewRequest("POST", "/verify", strings.NewReader(`{"age":10}`)) req.Header.Set("Content-Type", "application/json") w := httptest.NewRecorder() @@ -71,7 +71,7 @@ func TestServeHTTP_Panic(t *testing.T) { panic("intentional panic") }) - rh := &RouteHandler{} + rh := &RouteHandler{ws: DefaultServer} req := httptest.NewRequest("GET", "/panic", nil) w := httptest.NewRecorder() diff --git a/js_export.go b/js_export.go index 82d1e69..5eb118e 100644 --- a/js_export.go +++ b/js_export.go @@ -7,11 +7,11 @@ import ( func init() { jsmod.Register("service", map[string]any{ // 类型占位工厂 (用于 AI 发现类型结构) - "newRequest": func() *Request { return &Request{} }, - "newResponse": func() *Response { return &Response{} }, - "newWebSocket": func() *WebSocketConn { return &WebSocketConn{} }, - "newSession": func() *Session { return &Session{} }, - "newFile": func() *jsUploadFile { return &jsUploadFile{} }, + "newRequest": func() *Request { return &Request{} }, + "newResponse": func() *Response { return &Response{} }, + "newWebSocket": func() *WebSocketConn { return &WebSocketConn{} }, + "newSession": func() *Session { return &Session{} }, + "newFile": func() *jsUploadFile { return &jsUploadFile{} }, // 功能函数 "upgrade": Upgrade, diff --git a/proxy.go b/proxy.go index ef1b0e1..c6d4926 100644 --- a/proxy.go +++ b/proxy.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "net/http" + "reflect" "regexp" "strings" "time" @@ -77,24 +78,17 @@ func parseProxyRule(authLevel int, path, toApp, toPath string, to string) *proxy func (hc *HostContext) Proxy(authLevel int, path string, to string) *HostContext { p := parseProxyRule(authLevel, path, "", "", to) - hostPoliciesLock.Lock() - defer hostPoliciesLock.Unlock() - codeProxies[hc.host] = append(codeProxies[hc.host], p) - rebuildProxiesUnderLock(hc.host) + hc.ws.hostPoliciesLock.Lock() + defer hc.ws.hostPoliciesLock.Unlock() + if hc.ws.codeProxies[hc.host] == nil { + hc.ws.codeProxies[hc.host] = make([]*proxyType, 0) + } + hc.ws.codeProxies[hc.host] = append(hc.ws.codeProxies[hc.host], p) + hc.ws.rebuildProxiesUnderLock(hc.host) return hc } -func rebuildProxiesUnderLock(host string) { - var combined []*proxyType - combined = append(combined, codeProxies[host]...) - combined = append(combined, fileProxies[host]...) - combined = append(combined, dynamicProxies[host]...) - hostProxies[host] = combined -} - -var httpClientPool *gohttp.Client - -func findProxy(request *Request) (int, *string, *string, string) { +func (ws *webServer) findProxy(request *Request) (int, *string, *string, string) { host := request.Host hostOnly, port, _ := strings.Cut(host, ":") hosts := []string{host} @@ -110,11 +104,11 @@ func findProxy(request *Request) (int, *string, *string, string) { requestPath = requestPath[:pos] } - hostPoliciesLock.RLock() - defer hostPoliciesLock.RUnlock() + ws.hostPoliciesLock.RLock() + defer ws.hostPoliciesLock.RUnlock() for _, h := range hosts { - proxies, exists := hostProxies[h] + proxies, exists := ws.hostProxies[h] if !exists { continue } @@ -151,15 +145,15 @@ func findProxy(request *Request) (int, *string, *string, string) { return 0, nil, nil, "" } -func processProxy(request *Request, response *Response, logger *log.Logger) bool { - authLevel, proxyToApp, proxyToPath, foundHost := findProxy(request) +func (ws *webServer) processProxy(request *Request, response *Response, logger *log.Logger) bool { + authLevel, proxyToApp, proxyToPath, foundHost := ws.findProxy(request) if proxyToApp == nil || proxyToPath == nil || *proxyToApp == "" || *proxyToPath == "" { return false } // 鉴权 - pass, obj := checkAuthForProxy(authLevel, request, response, logger) + pass, obj := ws.checkAuthForProxy(authLevel, request, response, logger) if !pass { if !response.changed { response.WriteHeader(http.StatusForbidden) @@ -174,19 +168,16 @@ func processProxy(request *Request, response *Response, logger *log.Logger) bool if strings.Contains(app, "://") { // 直接 URL 代理 - if httpClientPool == nil { - httpClientPool = gohttp.NewClient(time.Duration(Config.RedirectTimeout) * time.Millisecond) - } - res := httpClientPool.ManualDoByRequest(request.Request, request.Method, app+path, request.Body) + res := ws.getHttpClient().ManualDoByRequest(request.Request, request.Method, app+path, request.Body) copyResponse(res, response, logger) } else { // Discover 代理 - if GlobalDiscoverer == nil { - logger.Error("proxy failed: GlobalDiscoverer is not initialized") + if ws.discoverer == nil { + logger.Error("proxy failed: Discoverer is not initialized") response.WriteHeader(http.StatusBadGateway) return true } - caller := GlobalDiscoverer.NewCaller(request.Request, logger) + caller := ws.discoverer.NewCaller(request.Request, logger) caller.NoBody = true res, _ := caller.ManualDoWithNode(request.Method, app, "", path, request.Body) copyResponse(res, response, logger) @@ -195,10 +186,22 @@ func processProxy(request *Request, response *Response, logger *log.Logger) bool return true } -func checkAuthForProxy(authLevel int, request *Request, response *Response, logger *log.Logger) (bool, any) { - ac := webAuthCheckers[authLevel] +func (ws *webServer) getHttpClient() *gohttp.Client { + // 尝试从注入对象获取 + if obj := ws.GetInject(reflect.TypeOf(&gohttp.Client{})); obj != nil { + return obj.(*gohttp.Client) + } + timeout := time.Duration(ws.Config.RedirectTimeout) * time.Millisecond + if timeout <= 0 { + timeout = 30 * time.Second + } + return gohttp.NewClient(timeout) +} + +func (ws *webServer) checkAuthForProxy(authLevel int, request *Request, response *Response, logger *log.Logger) (bool, any) { + ac := ws.webAuthCheckers[authLevel] if ac == nil { - ac = webAuthChecker + ac = ws.webAuthChecker } if ac == nil { return true, nil @@ -239,13 +242,17 @@ type ProxyRule struct { // ReplaceProxies 使用全量指针替换的方式 (Copy-on-Write) 无缝更新指定 host 的动态代理规则。 func ReplaceProxies(host string, rules []ProxyRule) { + DefaultServer.ReplaceProxies(host, rules) +} + +func (ws *webServer) ReplaceProxies(host string, rules []ProxyRule) { newProxies := make([]*proxyType, 0, len(rules)) for _, r := range rules { newProxies = append(newProxies, parseProxyRule(r.AuthLevel, r.Path, r.ToApp, r.ToPath, r.To)) } - hostPoliciesLock.Lock() - defer hostPoliciesLock.Unlock() - dynamicProxies[host] = newProxies - rebuildProxiesUnderLock(host) + ws.hostPoliciesLock.Lock() + defer ws.hostPoliciesLock.Unlock() + ws.dynamicProxies[host] = newProxies + ws.rebuildProxiesUnderLock(host) } diff --git a/proxy_test.go b/proxy_test.go index d3ab443..2a1a021 100644 --- a/proxy_test.go +++ b/proxy_test.go @@ -10,12 +10,12 @@ func TestRewrite(t *testing.T) { // 注册重写规则 Host("*").Rewrite("/old", "/new") Host("*").Rewrite("/regex/(.*)", "/target/$1") - + // 注册目标服务 Host("*").ANY("/new", func() string { return "new content" }).Memo("new") Host("*").ANY("/target/123", func() string { return "target content" }).Memo("target") - rh := &RouteHandler{} + rh := &RouteHandler{ws: DefaultServer} // 测试精确匹配重写 req1 := httptest.NewRequest("GET", "/old", nil) @@ -45,7 +45,7 @@ func TestProxyDirect(t *testing.T) { // 注册代理规则 Host("*").Proxy(0, "/proxy", backend.URL+"/hello") - rh := &RouteHandler{} + rh := &RouteHandler{ws: DefaultServer} req := httptest.NewRequest("GET", "/proxy", nil) w := httptest.NewRecorder() rh.ServeHTTP(w, req) diff --git a/reload.go b/reload.go index 1dde371..1b06d5a 100644 --- a/reload.go +++ b/reload.go @@ -5,24 +5,30 @@ import ( "sync" ) -var ( - reloadHooks []func() error - reloadLock sync.RWMutex -) +type reloadHook struct { + hooks []func() error + lock sync.RWMutex +} + +var globalReloadHook = &reloadHook{} // OnReload 注册一个在接收到 SIGHUP 信号时触发的重新加载钩子 func OnReload(handler func() error) { - reloadLock.Lock() - defer reloadLock.Unlock() - reloadHooks = append(reloadHooks, handler) + DefaultServer.OnReload(handler) +} + +func (ws *webServer) OnReload(handler func() error) { + globalReloadHook.lock.Lock() + defer globalReloadHook.lock.Unlock() + globalReloadHook.hooks = append(globalReloadHook.hooks, handler) } // triggerReload 触发所有注册的重新加载钩子 -func triggerReload() error { - reloadLock.RLock() - hooks := make([]func() error, len(reloadHooks)) - copy(hooks, reloadHooks) - reloadLock.RUnlock() +func (ws *webServer) triggerReload() error { + globalReloadHook.lock.RLock() + hooks := make([]func() error, len(globalReloadHook.hooks)) + copy(hooks, globalReloadHook.hooks) + globalReloadHook.lock.RUnlock() for _, hook := range hooks { if err := hook(); err != nil { diff --git a/rewrite.go b/rewrite.go index feac414..d4f828b 100644 --- a/rewrite.go +++ b/rewrite.go @@ -42,22 +42,14 @@ func parseRewriteRule(fromPath, toPath, to string) *rewriteType { func (hc *HostContext) Rewrite(path string, to string) *HostContext { s := parseRewriteRule(path, "", to) - hostPoliciesLock.Lock() - defer hostPoliciesLock.Unlock() - codeRewrites[hc.host] = append(codeRewrites[hc.host], s) - rebuildRewritesUnderLock(hc.host) + hc.ws.hostPoliciesLock.Lock() + defer hc.ws.hostPoliciesLock.Unlock() + hc.ws.codeRewrites[hc.host] = append(hc.ws.codeRewrites[hc.host], s) + hc.ws.rebuildRewritesUnderLock(hc.host) return hc } -func rebuildRewritesUnderLock(host string) { - var combined []*rewriteType - combined = append(combined, codeRewrites[host]...) - combined = append(combined, fileRewrites[host]...) - combined = append(combined, dynamicRewrites[host]...) - hostRewrites[host] = combined -} - -func processRewrite(request *Request, response *Response, logger *log.Logger) bool { +func (ws *webServer) processRewrite(request *Request, response *Response, logger *log.Logger) bool { host := request.Host hostOnly, port, _ := strings.Cut(host, ":") hosts := []string{host} @@ -66,8 +58,8 @@ func processRewrite(request *Request, response *Response, logger *log.Logger) bo } hosts = append(hosts, "*") - hostPoliciesLock.RLock() - defer hostPoliciesLock.RUnlock() + ws.hostPoliciesLock.RLock() + defer ws.hostPoliciesLock.RUnlock() requestPath := request.RequestURI queryString := "" @@ -77,7 +69,7 @@ func processRewrite(request *Request, response *Response, logger *log.Logger) bo } for _, h := range hosts { - rewrites, exists := hostRewrites[h] + rewrites, exists := ws.hostRewrites[h] if !exists { continue } @@ -144,13 +136,17 @@ type RewriteRule struct { // ReplaceRewrites 使用 Copy-on-Write 机制原子地替换指定 host 下的动态重写规则。 func ReplaceRewrites(host string, rules []RewriteRule) { + DefaultServer.ReplaceRewrites(host, rules) +} + +func (ws *webServer) ReplaceRewrites(host string, rules []RewriteRule) { newRewrites := make([]*rewriteType, 0, len(rules)) for _, r := range rules { newRewrites = append(newRewrites, parseRewriteRule(r.Path, r.ToPath, r.To)) } - hostPoliciesLock.Lock() - defer hostPoliciesLock.Unlock() - dynamicRewrites[host] = newRewrites - rebuildRewritesUnderLock(host) + ws.hostPoliciesLock.Lock() + defer ws.hostPoliciesLock.Unlock() + ws.dynamicRewrites[host] = newRewrites + ws.rebuildRewritesUnderLock(host) } diff --git a/server.go b/server.go index ff6a9ed..ec9c370 100644 --- a/server.go +++ b/server.go @@ -14,42 +14,262 @@ import ( "net" "net/http" "os" + "path/filepath" + "reflect" + "sort" "strings" + "sync" "time" ) -// GlobalDiscoverer 供服务框架内部使用的发现实例 -var GlobalDiscoverer *discover.Discoverer - -// WebServer 实现了 starter.Service 和 starter.Reloader 接口 -type WebServer struct { +type webServer struct { + Config ServiceConfig server *http.Server listener net.Listener Addr string useDiscover bool discoverer *discover.Discoverer logger *log.Logger + + // 运行时状态 + serverId string + serverAddr string + running bool + + // Web 服务注册 (按 Host 隔离) + webServices map[string]map[string]*webServiceType + regexWebServices map[string][]*webServiceType + webServicesLock sync.RWMutex + webServicesList []*webServiceType + + websocketServices map[string]map[string]*websocketServiceType + websocketServicesLock sync.RWMutex + websocketServicesList []*websocketServiceType + + // 路由策略 (按 Host 隔离) + hostRewrites map[string][]*rewriteType + hostProxies map[string][]*proxyType + + codeProxies map[string][]*proxyType + fileProxies map[string][]*proxyType + dynamicProxies map[string][]*proxyType + + codeRewrites map[string][]*rewriteType + fileRewrites map[string][]*rewriteType + dynamicRewrites map[string][]*rewriteType + + hostPoliciesLock sync.RWMutex + + // 静态文件服务 + statics map[string]*string + staticsByHost map[string]map[string]*string + codeStatics map[string]map[string]*string + fileStatics map[string]map[string]*string + dynamicStatics map[string]map[string]*string + staticsByHostLock sync.RWMutex + + // 过滤器与拦截器 + inFilters []func(*map[string]any, *Request, *Response, *log.Logger) any + outFilters []func(map[string]any, *Request, *Response, any, *log.Logger) (any, bool) + errorHandle func(any, *Request, *Response) any + webAuthChecker func(int, *log.Logger, *string, map[string]any, *Request, *Response, *WebServiceOptions) (pass bool, object any) + webAuthCheckers map[int]func(int, *log.Logger, *string, map[string]any, *Request, *Response, *WebServiceOptions) (pass bool, object any) + + // 注入点 + injectObjects map[reflect.Type]any + injectFunctions map[reflect.Type]func() any + + // 客户端标识 + usedDeviceIdKey string + usedClientAppKey string + usedSessionIdKey string + sessionIdMaker func() string } -// NewWebServer 创建并返回一个新的 WebServer 实例 -func NewWebServer() *WebServer { - return &WebServer{} +// DefaultServer 全局单例服务实例 +var DefaultServer = newWebServer() + +// Config 全局配置对象 (指向 DefaultServer.Config) +var Config = &DefaultServer.Config + +func newWebServer() *webServer { + ws := &webServer{ + webServices: make(map[string]map[string]*webServiceType), + regexWebServices: make(map[string][]*webServiceType), + webServicesList: make([]*webServiceType, 0), + websocketServices: make(map[string]map[string]*websocketServiceType), + websocketServicesList: make([]*websocketServiceType, 0), + hostRewrites: make(map[string][]*rewriteType), + hostProxies: make(map[string][]*proxyType), + codeProxies: make(map[string][]*proxyType), + fileProxies: make(map[string][]*proxyType), + dynamicProxies: make(map[string][]*proxyType), + codeRewrites: make(map[string][]*rewriteType), + fileRewrites: make(map[string][]*rewriteType), + dynamicRewrites: make(map[string][]*rewriteType), + statics: make(map[string]*string), + staticsByHost: make(map[string]map[string]*string), + codeStatics: make(map[string]map[string]*string), + fileStatics: make(map[string]map[string]*string), + dynamicStatics: make(map[string]map[string]*string), + webAuthCheckers: make(map[int]func(int, *log.Logger, *string, map[string]any, *Request, *Response, *WebServiceOptions) (pass bool, object any)), + injectObjects: make(map[reflect.Type]any), + injectFunctions: make(map[reflect.Type]func() any), + } + return ws +} + +// SetDiscovererForTest 提供给测试用例使用的后门方法,用于模拟断开或重置服务发现 +func SetDiscovererForTest(d *discover.Discoverer) { + DefaultServer.discoverer = d +} + +// ApplyConfig 将 ServiceConfig 中的路由策略应用到内部的文件级策略中 +func (ws *webServer) ApplyConfig() { + ws.hostPoliciesLock.Lock() + defer ws.hostPoliciesLock.Unlock() + + // 1. Proxies KV 解析 + ws.fileProxies = make(map[string][]*proxyType) + for host, kv := range ws.Config.Proxies { + h := host + if h == "*" { + h = "" + } + rules := make([]*proxyType, 0, len(kv)) + for path, val := range kv { + if to, ok := val.(string); ok { + rules = append(rules, parseProxyRule(0, path, "", "", to)) + } else { + // 对象模式 + m := make(map[string]any) + if tm, ok := val.(map[string]any); ok { + m = tm + } + rules = append(rules, parseProxyRule( + int(reflect.ValueOf(m["Auth"]).Int()), // Simplified + path, + fmt.Sprint(m["ToApp"]), + fmt.Sprint(m["ToPath"]), + fmt.Sprint(m["To"]), + )) + } + } + ws.fileProxies[h] = rules + ws.rebuildProxiesUnderLock(h) + } + + // 2. Rewrites KV 解析 + ws.fileRewrites = make(map[string][]*rewriteType) + for host, kv := range ws.Config.Rewrites { + h := host + if h == "*" { + h = "" + } + rules := make([]*rewriteType, 0, len(kv)) + for path, val := range kv { + if to, ok := val.(string); ok { + rules = append(rules, parseRewriteRule(path, "", to)) + } else { + m := make(map[string]any) + if tm, ok := val.(map[string]any); ok { + m = tm + } + rules = append(rules, parseRewriteRule( + path, + fmt.Sprint(m["ToPath"]), + fmt.Sprint(m["To"]), + )) + } + } + ws.fileRewrites[h] = rules + ws.rebuildRewritesUnderLock(h) + } + + ws.staticsByHostLock.Lock() + defer ws.staticsByHostLock.Unlock() + ws.fileStatics = make(map[string]map[string]*string) + + for host, config := range ws.Config.Statics { + h := host + if h == "*" { + h = "" + } + newStatics := make(map[string]*string, len(config)) + for path, rootPath := range config { + rp := rootPath + if !filepath.IsAbs(rp) { + if absPath, err := filepath.Abs(rp); err == nil { + rp = absPath + } + } + newStatics[path] = &rp + } + ws.fileStatics[h] = newStatics + ws.rebuildStaticsUnderLock(h) + } + + // 始终重新构建默认 Host 的静态路由,以合并代码定义的路由 + ws.rebuildStaticsUnderLock("") +} + +func (ws *webServer) rebuildProxiesUnderLock(host string) { + combined := make([]*proxyType, 0) + combined = append(combined, ws.codeProxies[host]...) + combined = append(combined, ws.fileProxies[host]...) + combined = append(combined, ws.dynamicProxies[host]...) + + sort.Slice(combined, func(i, j int) bool { + return len(combined[i].fromPath) > len(combined[j].fromPath) + }) + ws.hostProxies[host] = combined +} + +func (ws *webServer) rebuildRewritesUnderLock(host string) { + combined := make([]*rewriteType, 0) + combined = append(combined, ws.codeRewrites[host]...) + combined = append(combined, ws.fileRewrites[host]...) + combined = append(combined, ws.dynamicRewrites[host]...) + + sort.Slice(combined, func(i, j int) bool { + return len(combined[i].fromPath) > len(combined[j].fromPath) + }) + ws.hostRewrites[host] = combined +} + +func (ws *webServer) rebuildStaticsUnderLock(host string) { + combined := make(map[string]*string) + for k, v := range ws.codeStatics[host] { + combined[k] = v + } + for k, v := range ws.fileStatics[host] { + combined[k] = v + } + for k, v := range ws.dynamicStatics[host] { + combined[k] = v + } + + if host == "" { + ws.statics = combined + } else { + ws.staticsByHost[host] = combined + } } // Start 启动服务,实现 starter.Service 接口 -func (ws *WebServer) Start(ctx context.Context, logger *log.Logger) error { +func (ws *webServer) Start(ctx context.Context, logger *log.Logger) error { if logger == nil { logger = log.DefaultLogger } ws.logger = logger // 初始加载配置 - if err := config.Load(&Config, "service"); err != nil { + if err := config.Load(&ws.Config, "service"); err != nil { logger.Error("failed to load config during start", "error", err.Error()) } - ApplyConfig() - - listenStr := Config.Listen + ws.ApplyConfig() + + listenStr := ws.Config.Listen ws.useDiscover = false if listenStr == "" { @@ -57,7 +277,6 @@ func (ws *WebServer) Start(ctx context.Context, logger *log.Logger) error { ws.useDiscover = true } - // 解析第一个监听配置 part := strings.Split(listenStr, "|")[0] addr, opts, _ := strings.Cut(part, ",") @@ -68,31 +287,28 @@ func (ws *WebServer) Start(ctx context.Context, logger *log.Logger) error { protocol = opt } } - + if protocol == "" { - protocol = "http" // Default to http + protocol = "http" } if !strings.Contains(addr, ":") { addr = ":" + addr } - // 检查是否需要启动服务发现 - appName := Config.App + appName := ws.Config.App if appName == "" { appName = GetDefaultName() - Config.App = appName + ws.Config.App = appName } - if appName != "" || Config.Register != "" { + if appName != "" || ws.Config.Register != "" { ws.useDiscover = true } - // 初始化服务器唯一标识 (8位,物理上限 3,844/s) - serverId = IDMaker.Get8Bytes4KPerSecond() + ws.serverId = IDMaker.Get8Bytes4KPerSecond() - // 初始化分布式 ID 生成器 - if Config.IdServer != "" { - rd := redis.GetRedis(Config.IdServer, log.New(serverId)) + if ws.Config.IdServer != "" { + rd := redis.GetRedis(ws.Config.IdServer, log.New(ws.serverId)) if rd.Error == nil { IDMaker = redis.NewIDMaker(rd) } @@ -105,38 +321,35 @@ func (ws *WebServer) Start(ctx context.Context, logger *log.Logger) error { ws.listener = listener ws.Addr = listener.Addr().String() - serverAddr = ws.Addr + ws.serverAddr = ws.Addr - // 如果使用了随机端口且没有明确指定不需要服务发现,则开启 if addr == ":0" || strings.HasSuffix(addr, ":0") { ws.useDiscover = true } h2s := &http2.Server{} - var handler http.Handler = &RouteHandler{} + var handler http.Handler = &RouteHandler{ws: ws} if protocol == "h2c" { handler = h2c.NewHandler(handler, h2s) } ws.server = &http.Server{ Handler: handler, - ReadTimeout: time.Duration(Config.ReadTimeout) * time.Millisecond, - ReadHeaderTimeout: time.Duration(Config.ReadHeaderTimeout) * time.Millisecond, - WriteTimeout: time.Duration(Config.WriteTimeout) * time.Millisecond, - IdleTimeout: time.Duration(Config.IdleTimeout) * time.Millisecond, - MaxHeaderBytes: Config.MaxHeaderBytes, + ReadTimeout: time.Duration(ws.Config.ReadTimeout) * time.Millisecond, + ReadHeaderTimeout: time.Duration(ws.Config.ReadHeaderTimeout) * time.Millisecond, + WriteTimeout: time.Duration(ws.Config.WriteTimeout) * time.Millisecond, + IdleTimeout: time.Duration(ws.Config.IdleTimeout) * time.Millisecond, + MaxHeaderBytes: ws.Config.MaxHeaderBytes, } - // 启动服务发现 if ws.useDiscover { _, port, _ := net.SplitHostPort(ws.Addr) ip := GetServerIp() discoverAddr := fmt.Sprintf("%s:%s", ip, port) - // 转换配置 discConf := discover.Config{ - Weight: Config.Weight, - CallRetryTimes: 10, // Default + Weight: ws.Config.Weight, + CallRetryTimes: 10, Calls: make(map[string]discover.CallConfig), } @@ -144,15 +357,15 @@ func (ws *WebServer) Start(ctx context.Context, logger *log.Logger) error { discConf.Weight = 100 } - for name, call := range Config.Calls { + for name, call := range ws.Config.Calls { dc := discover.CallConfig{ Http2: call.Http2, SSL: call.SSL, } if call.Timeout > 0 { dc.Timeout = time.Duration(call.Timeout) * time.Millisecond - } else if Config.RedirectTimeout > 0 { - dc.Timeout = time.Duration(Config.RedirectTimeout) * time.Millisecond + } else if ws.Config.RedirectTimeout > 0 { + dc.Timeout = time.Duration(ws.Config.RedirectTimeout) * time.Millisecond } if call.Token != "" { dc.Token = safe.NewSafeBuf([]byte(call.Token)) @@ -160,19 +373,16 @@ func (ws *WebServer) Start(ctx context.Context, logger *log.Logger) error { discConf.Calls[name] = dc } - // 解析必需的 Register,支持环境变量 fallback - registry := Config.Register + registry := ws.Config.Register if registry == "" { registry = os.Getenv("DISCOVER_REGISTRY") } - if registry == "" { - registry = "127.0.0.1:6379::15" // Default fallback - } - ws.discoverer = discover.Start(registry, appName, discoverAddr, logger, discConf) - GlobalDiscoverer = ws.discoverer - if ws.discoverer != nil { - logger.Info("discover registered", "app", appName, "addr", discoverAddr) + if registry != "" { + ws.discoverer = discover.Start(registry, appName, discoverAddr, logger, discConf) + if ws.discoverer != nil { + logger.Info("discover registered", "app", appName, "addr", discoverAddr) + } } } @@ -180,13 +390,13 @@ func (ws *WebServer) Start(ctx context.Context, logger *log.Logger) error { go func() { logger.Info("service starting", "addr", ws.Addr, "proto", protocol) - if err := ws.server.Serve(listener); err != nil && err != http.ErrServerClosed { + ws.running = true + if err := ws.server.Serve(ws.listener); err != nil && err != http.ErrServerClosed { errChan <- err } close(errChan) }() - // 短暂等待验证是否闪退 select { case err := <-errChan: if err != nil { @@ -199,12 +409,13 @@ func (ws *WebServer) Start(ctx context.Context, logger *log.Logger) error { } // Stop 停止服务,实现 starter.Service 接口 -func (ws *WebServer) Stop(ctx context.Context) error { +func (ws *webServer) Stop(ctx context.Context) error { logger := ws.logger if logger == nil { logger = log.DefaultLogger } logger.Info("service stopping") + ws.running = false if ws.discoverer != nil { ws.discoverer.Stop() } @@ -218,52 +429,49 @@ func (ws *WebServer) Stop(ctx context.Context) error { } // Status 检查服务健康状态,实现 starter.Service 接口 -func (ws *WebServer) Status() (string, error) { - if ws.server == nil { +func (ws *webServer) Status() (string, error) { + if ws.server == nil || !ws.running { return "", fmt.Errorf("server is not running") } return ws.Addr, nil } // Reload 实现配置重新加载,实现 starter.Reloader 接口 -func (ws *WebServer) Reload() error { +func (ws *webServer) Reload() error { logger := ws.logger if logger == nil { logger = log.DefaultLogger } logger.Info("reloading configurations...") - - // 重新加载配置文件中的策略 - if err := config.Load(&Config, "service"); err != nil { + + if err := config.Load(&ws.Config, "service"); err != nil { logger.Error("failed to load config during reload", "error", err.Error()) } - ApplyConfig() - - // 触发业务挂载的 Hook - return triggerReload() + ws.ApplyConfig() + + return ws.triggerReload() } // AsyncServer 兼容旧版异步服务实例 type AsyncServer struct { - *WebServer + *webServer } // Stop 兼容旧版的无参数停止方法 func (as *AsyncServer) Stop() { - stopTimeout := time.Duration(Config.StopTimeout) * time.Millisecond + stopTimeout := time.Duration(as.Config.StopTimeout) * time.Millisecond if stopTimeout <= 0 { stopTimeout = 5 * time.Second } ctx, cancel := context.WithTimeout(context.Background(), stopTimeout) defer cancel() - _ = as.WebServer.Stop(ctx) + _ = as.webServer.Stop(ctx) } // AsyncStart 兼容旧版的异步启动方法 func AsyncStart() *AsyncServer { - ws := NewWebServer() - _ = ws.Start(context.Background(), log.DefaultLogger) - return &AsyncServer{WebServer: ws} + _ = DefaultServer.Start(context.Background(), log.DefaultLogger) + return &AsyncServer{webServer: DefaultServer} } // Wait 等待服务结束 (兼容旧版,直接阻塞) @@ -271,12 +479,16 @@ func (as *AsyncServer) Wait() { select {} } +var startOnce sync.Once + // Start 兼容旧版的同步启动方法 (通过内部注册 starter 实现) func Start() { - stopTimeout := time.Duration(Config.StopTimeout) * time.Millisecond - if stopTimeout <= 0 { - stopTimeout = 5 * time.Second - } - starter.Register("web-server", NewWebServer(), 100, 5*time.Second, stopTimeout) - starter.Run() + startOnce.Do(func() { + stopTimeout := time.Duration(Config.StopTimeout) * time.Millisecond + if stopTimeout <= 0 { + stopTimeout = 5 * time.Second + } + starter.Register("web-server", DefaultServer, 100, 5*time.Second, stopTimeout) + starter.Run() + }) } diff --git a/service.go b/service.go index f6267aa..00f1f90 100644 --- a/service.go +++ b/service.go @@ -6,7 +6,6 @@ import ( "reflect" "regexp" "strings" - "sync" ) // webServiceType 内部存储的服务元数据 @@ -17,7 +16,7 @@ type webServiceType struct { path string pathMatcher *regexp.Regexp pathArgs []string - paramsNum int + paramsNum int inType reflect.Type inIndex int headersType reflect.Type @@ -54,121 +53,116 @@ type websocketServiceType struct { options WebServiceOptions } -var ( - serverId string - serverAddr string - serverProto = "http" - serverProtoName = "http" - running = false - - // webServices 按 Host 隔离: map[host]map[method+path]*webServiceType - webServices = make(map[string]map[string]*webServiceType) - // regexWebServices 按 Host 隔离: map[host][]*webServiceType - regexWebServices = make(map[string][]*webServiceType) - webServicesLock = sync.RWMutex{} - webServicesList = make([]*webServiceType, 0) - - websocketServices = make(map[string]map[string]*websocketServiceType) - websocketServicesLock = sync.RWMutex{} - websocketServicesList = make([]*websocketServiceType, 0) - - // Rewrite 与 Proxy 按 Host 隔离 (编译后的最终路由) - hostRewrites = make(map[string][]*rewriteType) - hostProxies = make(map[string][]*proxyType) - - // 按来源隔离的策略,避免互相覆盖 - codeProxies = make(map[string][]*proxyType) - fileProxies = make(map[string][]*proxyType) - dynamicProxies = make(map[string][]*proxyType) - - codeRewrites = make(map[string][]*rewriteType) - fileRewrites = make(map[string][]*rewriteType) - dynamicRewrites = make(map[string][]*rewriteType) - - hostPoliciesLock = sync.RWMutex{} - - // 过滤器与拦截器 - inFilters = make([]func(*map[string]any, *Request, *Response, *log.Logger) any, 0) - outFilters = make([]func(map[string]any, *Request, *Response, any, *log.Logger) (any, bool), 0) - errorHandle func(any, *Request, *Response) any - webAuthChecker func(int, *log.Logger, *string, map[string]any, *Request, *Response, *WebServiceOptions) (pass bool, object any) - webAuthCheckers = make(map[int]func(int, *log.Logger, *string, map[string]any, *Request, *Response, *WebServiceOptions) (pass bool, object any)) - - // 注入点 - injectObjects = make(map[reflect.Type]any) - injectFunctions = make(map[reflect.Type]func() any) - - usedDeviceIdKey string - usedClientAppKey string - usedSessionIdKey string - sessionIdMaker func() string -) - // SetClientKeys 设置客户端标识相关的 Key 映射 func SetClientKeys(deviceIdKey, clientAppKey, sessionIdKey string) { - usedDeviceIdKey = deviceIdKey - usedClientAppKey = clientAppKey - usedSessionIdKey = sessionIdKey + DefaultServer.SetClientKeys(deviceIdKey, clientAppKey, sessionIdKey) +} + +func (ws *webServer) SetClientKeys(deviceIdKey, clientAppKey, sessionIdKey string) { + ws.usedDeviceIdKey = deviceIdKey + ws.usedClientAppKey = clientAppKey + ws.usedSessionIdKey = sessionIdKey } // SetSessionIdMaker 设置自定义会话 ID 生成器 func SetSessionIdMaker(maker func() string) { - sessionIdMaker = maker + DefaultServer.SetSessionIdMaker(maker) +} + +func (ws *webServer) SetSessionIdMaker(maker func() string) { + ws.sessionIdMaker = maker } // SetAuthChecker 设置全局鉴权器 func SetAuthChecker(authChecker func(authLevel int, logger *log.Logger, url *string, in map[string]any, request *Request, response *Response, options *WebServiceOptions) (pass bool, object any)) { - webAuthChecker = authChecker + DefaultServer.SetAuthChecker(authChecker) +} + +func (ws *webServer) SetAuthChecker(authChecker func(authLevel int, logger *log.Logger, url *string, in map[string]any, request *Request, response *Response, options *WebServiceOptions) (pass bool, object any)) { + ws.webAuthChecker = authChecker } // AddAuthChecker 为指定级别添加鉴权器 func AddAuthChecker(authLevels []int, authChecker func(authLevel int, logger *log.Logger, url *string, in map[string]any, request *Request, response *Response, options *WebServiceOptions) (pass bool, object any)) { + DefaultServer.AddAuthChecker(authLevels, authChecker) +} + +func (ws *webServer) AddAuthChecker(authLevels []int, authChecker func(authLevel int, logger *log.Logger, url *string, in map[string]any, request *Request, response *Response, options *WebServiceOptions) (pass bool, object any)) { for _, al := range authLevels { - webAuthCheckers[al] = authChecker + ws.webAuthCheckers[al] = authChecker } } // SetInFilter 设置前置过滤器 func SetInFilter(filter func(in *map[string]any, request *Request, response *Response, logger *log.Logger) (out any)) { - inFilters = append(inFilters, filter) + DefaultServer.SetInFilter(filter) +} + +func (ws *webServer) SetInFilter(filter func(in *map[string]any, request *Request, response *Response, logger *log.Logger) (out any)) { + ws.inFilters = append(ws.inFilters, filter) } // SetOutFilter 设置后置过滤器 func SetOutFilter(filter func(in map[string]any, request *Request, response *Response, out any, logger *log.Logger) (newOut any, isOver bool)) { - outFilters = append(outFilters, filter) + DefaultServer.SetOutFilter(filter) +} + +func (ws *webServer) SetOutFilter(filter func(in map[string]any, request *Request, response *Response, out any, logger *log.Logger) (newOut any, isOver bool)) { + ws.outFilters = append(ws.outFilters, filter) } // HostContext 提供流式服务注册能力 type HostContext struct { + ws *webServer host string } // Host 指定服务运行的 Host (支持 "example.com", ":8080", "example.com:8080", "*") func Host(host string) *HostContext { + return DefaultServer.Host(host) +} + +func (ws *webServer) Host(host string) *HostContext { if host == "" { host = "*" } - return &HostContext{host: host} + return &HostContext{ws: ws, host: host} } // Register 注册一个 Web 服务 (使用默认 Host "*") func Register(method, path string, serviceFunc any) *webServiceType { - return Host("*").Register(method, path, serviceFunc) + return DefaultServer.Register(method, path, serviceFunc) +} + +func (ws *webServer) Register(method, path string, serviceFunc any) *webServiceType { + return ws.Host("*").Register(method, path, serviceFunc) } // RegisterWebsocket 注册一个 WebSocket 服务 (使用默认 Host "*") func RegisterWebsocket(path string, serviceFunc any) *websocketServiceType { - return Host("*").WebSocket(path, serviceFunc) + return DefaultServer.RegisterWebsocket(path, serviceFunc) +} + +func (ws *webServer) RegisterWebsocket(path string, serviceFunc any) *websocketServiceType { + return ws.Host("*").WebSocket(path, serviceFunc) } // Proxy 注册一个代理转发 (使用默认 Host "*") func Proxy(authLevel int, path string, to string) { - Host("*").Proxy(authLevel, path, to) + DefaultServer.Proxy(authLevel, path, to) +} + +func (ws *webServer) Proxy(authLevel int, path string, to string) { + ws.Host("*").Proxy(authLevel, path, to) } // Restful 注册一个符合 RESTful 规范的服务结构体 (使用默认 Host "*") func Restful(authLevel int, path string, serviceStruct any) { - Host("*").Restful(authLevel, path, serviceStruct) + DefaultServer.Restful(authLevel, path, serviceStruct) +} + +func (ws *webServer) Restful(authLevel int, path string, serviceStruct any) { + ws.Host("*").Restful(authLevel, path, serviceStruct) } func (hc *HostContext) Register(method, path string, serviceFunc any) *webServiceType { @@ -188,25 +182,27 @@ func (hc *HostContext) Register(method, path string, serviceFunc any) *webServic finds := finder.FindAllStringSubmatch(path, 20) for _, found := range finds { keyName = strings.Replace(keyName, regexp.QuoteMeta(found[0]), "(.*?)", 1) + hc.ws.webServicesLock.Lock() // Fixed: use lock from ws s.pathArgs = append(s.pathArgs, found[1]) + hc.ws.webServicesLock.Unlock() } if len(s.pathArgs) > 0 { s.pathMatcher, _ = regexp.Compile("^" + keyName + "$") } } - webServicesLock.Lock() - defer webServicesLock.Unlock() + hc.ws.webServicesLock.Lock() + defer hc.ws.webServicesLock.Unlock() if s.pathMatcher == nil { - if webServices[s.host] == nil { - webServices[s.host] = make(map[string]*webServiceType) + if hc.ws.webServices[s.host] == nil { + hc.ws.webServices[s.host] = make(map[string]*webServiceType) } - webServices[s.host][s.method+s.path] = s + hc.ws.webServices[s.host][s.method+s.path] = s } else { - regexWebServices[s.host] = append(regexWebServices[s.host], s) + hc.ws.regexWebServices[s.host] = append(hc.ws.regexWebServices[s.host], s) } - webServicesList = append(webServicesList, s) + hc.ws.webServicesList = append(hc.ws.webServicesList, s) return s } @@ -301,13 +297,13 @@ func (hc *HostContext) WebSocket(path string, serviceFunc any) *websocketService funcValue: reflect.ValueOf(serviceFunc), } - websocketServicesLock.Lock() - defer websocketServicesLock.Unlock() - if websocketServices[hc.host] == nil { - websocketServices[hc.host] = make(map[string]*websocketServiceType) + hc.ws.websocketServicesLock.Lock() + defer hc.ws.websocketServicesLock.Unlock() + if hc.ws.websocketServices[hc.host] == nil { + hc.ws.websocketServices[hc.host] = make(map[string]*websocketServiceType) } - websocketServices[hc.host][path] = ws - websocketServicesList = append(websocketServicesList, ws) + hc.ws.websocketServices[hc.host][path] = ws + hc.ws.websocketServicesList = append(hc.ws.websocketServicesList, ws) return ws } @@ -407,7 +403,7 @@ func makeCachedService(matchedService any) (*webServiceType, error) { } targetService := &webServiceType{ - paramsNum: funcType.NumIn(), + paramsNum: funcType.NumIn(), inIndex: -1, headersIndex: -1, requestIndex: -1, @@ -452,10 +448,14 @@ func makeCachedService(matchedService any) (*webServiceType, error) { // GetInject 获取注入对象 func GetInject(dataType reflect.Type) any { - if obj, exists := injectObjects[dataType]; exists { + return DefaultServer.GetInject(dataType) +} + +func (ws *webServer) GetInject(dataType reflect.Type) any { + if obj, exists := ws.injectObjects[dataType]; exists { return obj } - if factory, exists := injectFunctions[dataType]; exists { + if factory, exists := ws.injectFunctions[dataType]; exists { return factory() } return nil @@ -465,7 +465,7 @@ func GetInject(dataType reflect.Type) any { func GetInjectT[T any]() T { var zero T t := reflect.TypeOf((*T)(nil)).Elem() - obj := GetInject(t) + obj := DefaultServer.GetInject(t) if obj == nil { return zero } diff --git a/service_test.go b/service_test.go index d24f60d..30594cd 100644 --- a/service_test.go +++ b/service_test.go @@ -12,9 +12,9 @@ func TestServiceRegister(t *testing.T) { Host("*").Register("*", "/test", handler).Auth(0).Memo("test service") - webServicesLock.RLock() - s := webServices["*"]["*/test"] - webServicesLock.RUnlock() + DefaultServer.webServicesLock.RLock() + s := DefaultServer.webServices["*"]["*/test"] + DefaultServer.webServicesLock.RUnlock() if s == nil { t.Fatal("Service not registered") @@ -35,9 +35,9 @@ func TestRegexServiceRegister(t *testing.T) { Host("*").Register("*", "/user/{id}", handler).Auth(0).Memo("get user") - webServicesLock.RLock() + DefaultServer.webServicesLock.RLock() found := false - for _, services := range regexWebServices { + for _, services := range DefaultServer.regexWebServices { for _, s := range services { if s.path == "/user/{id}" { found = true @@ -51,7 +51,7 @@ func TestRegexServiceRegister(t *testing.T) { break } } - webServicesLock.RUnlock() + DefaultServer.webServicesLock.RUnlock() if !found { t.Fatal("Regex service not registered") diff --git a/session.go b/session.go index 74ee5c3..f44e7ef 100644 --- a/session.go +++ b/session.go @@ -103,7 +103,7 @@ func (s *Session) Save() error { if s.conn == nil { now := time.Now().Unix() s.data["_time"] = now - + // 复制一份数据存储,防止外部修改 saveData := make(map[string]any) for k, v := range s.data { @@ -112,7 +112,7 @@ func (s *Session) Save() error { memorySessionDataLock.Lock() memorySessionData[s.id] = saveData - + clearTimeDiff := now - lastSessionClearTime if clearTimeDiff > 60 { lastSessionClearTime = now @@ -198,7 +198,7 @@ func (s *Session) AuthFuncs(needFuncs ...string) bool { break } } - + // 如果是非必需权限命中,或者必需权限已全部命中且至少命中了一个非必需权限(如果有) if (normalAuthOk > 0 || requiredAuthTotal == len(needFuncs)) && requiredAuthOk == requiredAuthTotal { isOk = true diff --git a/session_test.go b/session_test.go index a8e37a0..36adbfc 100644 --- a/session_test.go +++ b/session_test.go @@ -29,7 +29,7 @@ func TestSessionLogic(t *testing.T) { // 2. 测试 AuthFuncs 逻辑 sess.Set("funcs", []string{"user.read", "user.write", "system.admin"}) - + if !sess.AuthFuncs("user.read") { t.Error("Expected true for user.read") } @@ -61,7 +61,7 @@ func TestSessionLogic(t *testing.T) { func TestSessionInjection(t *testing.T) { SetClientKeys("", "", "sessid") - + handler := func(s *Session) string { if s == nil { return "no session" @@ -72,7 +72,7 @@ func TestSessionInjection(t *testing.T) { } Host("*").GET("/test-session", handler) - rh := &RouteHandler{} + rh := &RouteHandler{ws: DefaultServer} req := httptest.NewRequest("GET", "/test-session", nil) req.Header.Set("sessid", "sess_123") w := httptest.NewRecorder() @@ -110,7 +110,7 @@ func TestCustomAuthInjection(t *testing.T) { } Host("*").GET("/test-auth", handler).Auth(10) - rh := &RouteHandler{} + rh := &RouteHandler{ws: DefaultServer} req := httptest.NewRequest("GET", "/test-auth", nil) w := httptest.NewRecorder() @@ -126,14 +126,14 @@ func TestCustomAuthInjection(t *testing.T) { func TestAutomaticAuthLevelCheck(t *testing.T) { SetClientKeys("", "", "sessid") - + handler := func() string { return "ok" } Host("*").GET("/test-auto-auth", handler).Auth(1) - rh := &RouteHandler{} - + rh := &RouteHandler{ws: DefaultServer} + // 1. 无 Session 或 AuthLevel=0 时应失败 req1 := httptest.NewRequest("GET", "/test-auto-auth", nil) req1.Header.Set("sessid", "sess_auto_1") diff --git a/static.go b/static.go index 03e0a0b..e92caf0 100644 --- a/static.go +++ b/static.go @@ -7,56 +7,56 @@ import ( "net/http" "path/filepath" "strings" - "sync" "time" ) -var ( - statics = make(map[string]*string) - staticsByHost = make(map[string]map[string]*string) - - codeStatics = make(map[string]map[string]*string) - fileStatics = make(map[string]map[string]*string) - dynamicStatics = make(map[string]map[string]*string) - - staticsByHostLock = sync.RWMutex{} -) - // Static 注册静态文件目录 func (hc *HostContext) Static(path, rootPath string) *HostContext { host := hc.host if host == "*" { host = "" } - StaticByHost(path, rootPath, host) + hc.ws.StaticByHost(path, rootPath, host) return hc } // Static 注册静态文件目录 (使用默认 Host "*") func Static(path, rootPath string) { - Host("*").Static(path, rootPath) + DefaultServer.Static(path, rootPath) +} + +func (ws *webServer) Static(path, rootPath string) { + ws.Host("*").Static(path, rootPath) } // StaticByHost 为指定域名注册静态文件目录 func StaticByHost(path, rootPath, host string) { + DefaultServer.StaticByHost(path, rootPath, host) +} + +func (ws *webServer) StaticByHost(path, rootPath, host string) { if !filepath.IsAbs(rootPath) { if absPath, err := filepath.Abs(rootPath); err == nil { rootPath = absPath } } - staticsByHostLock.Lock() - defer staticsByHostLock.Unlock() + ws.staticsByHostLock.Lock() + defer ws.staticsByHostLock.Unlock() - if codeStatics[host] == nil { - codeStatics[host] = make(map[string]*string) + if ws.codeStatics[host] == nil { + ws.codeStatics[host] = make(map[string]*string) } - codeStatics[host][path] = &rootPath - rebuildStaticsUnderLock(host) + ws.codeStatics[host][path] = &rootPath + ws.rebuildStaticsUnderLock(host) } // ReplaceStatics 使用 Copy-on-Write 机制原子地替换指定 host 下的动态静态目录规则 func ReplaceStatics(host string, config map[string]string) { + DefaultServer.ReplaceStatics(host, config) +} + +func (ws *webServer) ReplaceStatics(host string, config map[string]string) { newStatics := make(map[string]*string, len(config)) for path, rootPath := range config { rp := rootPath @@ -68,50 +68,29 @@ func ReplaceStatics(host string, config map[string]string) { newStatics[path] = &rp } - staticsByHostLock.Lock() - defer staticsByHostLock.Unlock() - - dynamicStatics[host] = newStatics - rebuildStaticsUnderLock(host) + ws.staticsByHostLock.Lock() + defer ws.staticsByHostLock.Unlock() + + ws.dynamicStatics[host] = newStatics + ws.rebuildStaticsUnderLock(host) } -func rebuildStaticsUnderLock(host string) { - combined := make(map[string]*string) - - // 合并三种来源的静态路由 - for k, v := range codeStatics[host] { - combined[k] = v - } - for k, v := range fileStatics[host] { - combined[k] = v - } - for k, v := range dynamicStatics[host] { - combined[k] = v - } - - if host == "" { - statics = combined - } else { - staticsByHost[host] = combined - } -} - -func getStaticFilePath(requestPath, host string) string { - staticsByHostLock.RLock() - defer staticsByHostLock.RUnlock() +func (ws *webServer) getStaticFilePath(requestPath, host string) string { + ws.staticsByHostLock.RLock() + defer ws.staticsByHostLock.RUnlock() // 优先匹配指定域名的配置 - if hostConfig, exists := staticsByHost[host]; exists { - if filePath := findMatchedPath(hostConfig, requestPath); filePath != "" { + if hostConfig, exists := ws.staticsByHost[host]; exists { + if filePath := ws.findMatchedPath(hostConfig, requestPath); filePath != "" { return filePath } } // 匹配全局配置 - return findMatchedPath(statics, requestPath) + return ws.findMatchedPath(ws.statics, requestPath) } -func findMatchedPath(config map[string]*string, requestPath string) string { +func (ws *webServer) findMatchedPath(config map[string]*string, requestPath string) string { for urlPath, rootPath := range config { if strings.HasPrefix(requestPath, urlPath) { return filepath.Join(*rootPath, requestPath[len(urlPath):]) @@ -120,8 +99,8 @@ func findMatchedPath(config map[string]*string, requestPath string) string { return "" } -func processStatic(requestPath string, request *Request, response *Response, logger *log.Logger) bool { - filePath := getStaticFilePath(requestPath, request.Host) +func (ws *webServer) processStatic(requestPath string, request *Request, response *Response, logger *log.Logger) bool { + filePath := ws.getStaticFilePath(requestPath, request.Host) if filePath == "" { return false } @@ -133,7 +112,12 @@ func processStatic(requestPath string, request *Request, response *Response, log if info.IsDir { // 自动查找索引文件 - for _, indexFile := range Config.IndexFiles { + indexFiles := ws.Config.IndexFiles + if len(indexFiles) == 0 { + indexFiles = []string{"index.html", "index.htm"} + } + + for _, indexFile := range indexFiles { f := filepath.Join(filePath, indexFile) if i := file.GetFileInfo(f); i != nil && !i.IsDir { filePath = f diff --git a/static_test.go b/static_test.go index 14d6574..207f6b9 100644 --- a/static_test.go +++ b/static_test.go @@ -12,15 +12,15 @@ func TestStaticService(t *testing.T) { // 创建临时测试目录和文件 tempDir, _ := os.MkdirTemp("", "static_test") defer os.RemoveAll(tempDir) - + testFile := filepath.Join(tempDir, "index.html") os.WriteFile(testFile, []byte("