publish v1.5.2

This commit is contained in:
AI Engineer 2026-06-04 18:16:46 +08:00
parent 26e6fc4b9d
commit c88139e202
22 changed files with 664 additions and 532 deletions

View File

@ -1,5 +1,12 @@
# CHANGELOG - go/service # CHANGELOG - go/service
## v1.5.2 (2026-06-04)
- **架构重构**: 彻底移除全局状态泄露,重构为严格的单例模式。
- 将所有路由 (`webServices`)、代理配置 (`hostProxies`)、静态资源 (`statics`) 及鉴权器等 30 多个包级全局变量移入私有的 `webServer` 结构体中。
- `NewWebServer()` 改为直接返回全局安全的 `DefaultServer` 单例。
- 所有的包级公开方法(如 `Register`, `Proxy`, `Static` 等)现作为 `DefaultServer` 实例方法的安全代理。
- **依赖更新**: 升级 `apigo.cc/go/log``v1.5.1`,修复了在使用 `starter` 时异步控制台日志被静默丢弃的关键 Bug。
## v1.5.1 (2026-06-04) ## v1.5.1 (2026-06-04)
- **修复**: 在 `WebServer.Start` 中显式调用 `config.Load(&Config, "service")`,确保启动时自动从 `env.yaml` 加载 `service:` 块。 - **修复**: 在 `WebServer.Start` 中显式调用 `config.Load(&Config, "service")`,确保启动时自动从 `env.yaml` 加载 `service:` 块。
- **修复**: 优化 `WebServer.Reload` 的配置加载逻辑,确保与启动加载逻辑保持一致。 - **修复**: 优化 `WebServer.Reload` 的配置加载逻辑,确保与启动加载逻辑保持一致。

175
config.go
View File

@ -1,10 +1,5 @@
package service package service
import (
"apigo.cc/go/cast"
"path/filepath"
)
// CertSet SSL 证书配置 // CertSet SSL 证书配置
type CertSet struct { type CertSet struct {
CertFile string CertFile string
@ -21,135 +16,55 @@ type CallConfig struct {
// ServiceConfig 核心服务配置 // ServiceConfig 核心服务配置
type ServiceConfig struct { type ServiceConfig struct {
App string // 应用名称。优先从环境变量 DISCOVER_APP 获取,若为空则自动通过代码检测。 App string // 应用名称。优先从环境变量 DISCOVER_APP 获取,若为空则自动通过代码检测。
Register string // 发现服务注册中心地址。支持 Redis URL 或 Redis 配置名称。 Register string // 发现服务注册中心地址。支持 Redis URL 或 Redis 配置名称。
Weight int // 当前节点在发现服务中的权重 (默认 100) Weight int // 当前节点在发现服务中的权重 (默认 100)
Calls map[string]CallConfig // 依赖的下游服务调用配置 Calls map[string]CallConfig // 依赖的下游服务调用配置
Listen string // 监听端口(|隔开多个监听)(,隔开多个选项),例如 80,http|443|443:h2|127.0.0.1:8080,h2c Listen string // 监听端口(|隔开多个监听)(,隔开多个选项),例如 80,http|443|443:h2|127.0.0.1:8080,h2c
SSL map[string]*CertSet // SSL 证书配置key 为域名 SSL map[string]*CertSet // SSL 证书配置key 为域名
NoLogGets bool // 不记录 GET 请求的日志 NoLogGets bool // 不记录 GET 请求的日志
NoLogHeaders string // 不记录请求头中包含的这些字段,多个字段用逗号分隔 NoLogHeaders string // 不记录请求头中包含的这些字段,多个字段用逗号分隔
LogInputArrayNum int // 请求字段中容器类型在日志打印个数限制 LogInputArrayNum int // 请求字段中容器类型在日志打印个数限制
LogInputFieldSize int // 请求字段中单个字段在日志打印长度限制 LogInputFieldSize int // 请求字段中单个字段在日志打印长度限制
NoLogOutputFields string // 不记录响应字段中包含的这些字段 NoLogOutputFields string // 不记录响应字段中包含的这些字段
LogOutputArrayNum int // 响应字段中容器类型在日志打印个数限制 LogOutputArrayNum int // 响应字段中容器类型在日志打印个数限制
LogOutputFieldSize int // 响应字段中单个字段在日志打印长度限制 LogOutputFieldSize int // 响应字段中单个字段在日志打印长度限制
Compress bool // 是否启用压缩 Compress bool // 是否启用压缩
CompressMinSize int // 启用压缩的最小长度 CompressMinSize int // 启用压缩的最小长度
CompressMaxSize int // 启用压缩的最大长度 CompressMaxSize int // 启用压缩的最大长度
CheckDomain string // 心跳检测时使用域名 CheckDomain string // 心跳检测时使用域名
AccessTokens map[string]*int // 指定 Access-Token 验证及其对应的 auth-level AccessTokens map[string]*int // 指定 Access-Token 验证及其对应的 auth-level
RedirectTimeout int // Proxy 和 Discover 发起请求时的超时时间 (ms) RedirectTimeout int // Proxy 和 Discover 发起请求时的超时时间 (ms)
AcceptXRealIpWithoutRequestId bool // 是否允许头部没有携带请求ID的 X-Real-IP 信息 AcceptXRealIpWithoutRequestId bool // 是否允许头部没有携带请求ID的 X-Real-IP 信息
StatisticTime bool // 是否开启请求时间统计 StatisticTime bool // 是否开启请求时间统计
StatisticTimeInterval int // 统计时间间隔 (ms) StatisticTimeInterval int // 统计时间间隔 (ms)
MaxUploadSize int64 // 最大上传文件大小 (Bytes) MaxUploadSize int64 // 最大上传文件大小 (Bytes)
Cpu int // CPU 占用的核数限制 Cpu int // CPU 占用的核数限制
Memory int // 内存限制 (MB) Memory int // 内存限制 (MB)
CookieScope string // Session Cookie 有效范围: host|domain|topDomain CookieScope string // Session Cookie 有效范围: host|domain|topDomain
SessionWithoutCookie bool // Session 禁用 Cookie SessionWithoutCookie bool // Session 禁用 Cookie
SessionRedis string // Session 存储使用的 Redis 配置名称 (不设置则使用内存) SessionRedis string // Session 存储使用的 Redis 配置名称 (不设置则使用内存)
SessionTimeout int // Session 有效期 (秒,默认 3600) SessionTimeout int // Session 有效期 (秒,默认 3600)
DeviceWithoutCookie bool // 设备ID禁用 Cookie DeviceWithoutCookie bool // 设备ID禁用 Cookie
IdServer string // Redis 服务器连接 (用于全局唯一 ID 生成) IdServer string // Redis 服务器连接 (用于全局唯一 ID 生成)
IndexFiles []string // 静态文件索引文件 IndexFiles []string // 静态文件索引文件
IndexDir bool // 访问目录时显示文件列表 IndexDir bool // 访问目录时显示文件列表
ReadTimeout int // 读取请求的超时时间 (ms) ReadTimeout int // 读取请求的超时时间 (ms)
ReadHeaderTimeout int // 读取请求头的超时时间 (ms) ReadHeaderTimeout int // 读取请求头的超时时间 (ms)
WriteTimeout int // 响应写入的超时时间 (ms) WriteTimeout int // 响应写入的超时时间 (ms)
IdleTimeout int // 连接空闲超时时间 (ms) IdleTimeout int // 连接空闲超时时间 (ms)
MaxHeaderBytes int // 请求头的最大字节数 MaxHeaderBytes int // 请求头的最大字节数
MaxHandlers int // 每个连接的最大处理程序数量 MaxHandlers int // 每个连接的最大处理程序数量
MaxConcurrentStreams uint32 // 每个连接的最大并发流数量 MaxConcurrentStreams uint32 // 每个连接的最大并发流数量
MaxDecoderHeaderTableSize uint32 // 解码器头表的最大大小 MaxDecoderHeaderTableSize uint32 // 解码器头表的最大大小
MaxEncoderHeaderTableSize uint32 // 编码器头表的最大大小 MaxEncoderHeaderTableSize uint32 // 编码器头表的最大大小
MaxReadFrameSize uint32 // 单个帧的最大读取大小 MaxReadFrameSize uint32 // 单个帧的最大读取大小
MaxUploadBufferPerConnection int32 // 每个连接的最大上传缓冲区大小 MaxUploadBufferPerConnection int32 // 每个连接的最大上传缓冲区大小
MaxUploadBufferPerStream int32 // 每个流的最大上传缓冲区大小 MaxUploadBufferPerStream int32 // 每个流的最大上传缓冲区大小
StopTimeout int // 停止服务的超时时间 (ms) StopTimeout int // 停止服务的超时时间 (ms)
// 从配置文件中加载的静态路由策略 (按 Host 分组,全局配置用 "" 或 "*") // 从配置文件中加载的静态路由策略 (按 Host 分组,全局配置用 "" 或 "*")
Proxies map[string]map[string]any Proxies map[string]map[string]any
Rewrites map[string]map[string]any Rewrites map[string]map[string]any
Statics map[string]map[string]string Statics map[string]map[string]string
} }
var Config = ServiceConfig{}
// ApplyConfig 将 ServiceConfig 中的路由策略应用到内部的文件级策略中
func ApplyConfig() {
hostPoliciesLock.Lock()
defer hostPoliciesLock.Unlock()
// 1. Proxies KV 解析
fileProxies = make(map[string][]*proxyType)
for host, kv := range Config.Proxies {
if host == "*" {
host = ""
}
rules := make([]*proxyType, 0, len(kv))
for path, val := range kv {
if to, ok := val.(string); ok {
// 极简 KV 模式: "/api/*": "user-svc/v1/*"
rules = append(rules, parseProxyRule(0, path, "", "", to))
} else {
// 对象模式: "/api/*": {"To": "...", "Auth": 1}
m, _ := cast.ToMap[string, any](val)
rules = append(rules, parseProxyRule(
cast.Int(m["Auth"]),
path,
cast.String(m["ToApp"]),
cast.String(m["ToPath"]),
cast.String(m["To"]),
))
}
}
fileProxies[host] = rules
rebuildProxiesUnderLock(host)
}
// 2. Rewrites KV 解析
fileRewrites = make(map[string][]*rewriteType)
for host, kv := range Config.Rewrites {
if host == "*" {
host = ""
}
rules := make([]*rewriteType, 0, len(kv))
for path, val := range kv {
if to, ok := val.(string); ok {
rules = append(rules, parseRewriteRule(path, "", to))
} else {
m, _ := cast.ToMap[string, any](val)
rules = append(rules, parseRewriteRule(
path,
cast.String(m["ToPath"]),
cast.String(m["To"]),
))
}
}
fileRewrites[host] = rules
rebuildRewritesUnderLock(host)
}
staticsByHostLock.Lock()
defer staticsByHostLock.Unlock()
fileStatics = make(map[string]map[string]*string)
for host, config := range Config.Statics {
if host == "*" {
host = ""
}
newStatics := make(map[string]*string, len(config))
for path, rootPath := range config {
rp := rootPath
if !filepath.IsAbs(rp) {
if absPath, err := filepath.Abs(rp); err == nil {
rp = absPath
}
}
newStatics[path] = &rp
}
fileStatics[host] = newStatics
rebuildStaticsUnderLock(host)
}
}

View File

@ -21,11 +21,15 @@ type Api struct {
// MakeDocument 生成文档数据 // MakeDocument 生成文档数据
func MakeDocument() []Api { func MakeDocument() []Api {
return DefaultServer.MakeDocument()
}
func (ws *webServer) MakeDocument() []Api {
out := make([]Api, 0) out := make([]Api, 0)
// 1. Rewrite & Proxy // 1. Rewrite & Proxy
hostPoliciesLock.RLock() ws.hostPoliciesLock.RLock()
for host, rewrites := range hostRewrites { for host, rewrites := range ws.hostRewrites {
for _, a := range rewrites { for _, a := range rewrites {
out = append(out, Api{ out = append(out, Api{
Type: "Rewrite", Type: "Rewrite",
@ -34,7 +38,7 @@ func MakeDocument() []Api {
}) })
} }
} }
for host, proxies := range hostProxies { for host, proxies := range ws.hostProxies {
for _, a := range proxies { for _, a := range proxies {
out = append(out, Api{ out = append(out, Api{
Type: "Proxy", Type: "Proxy",
@ -43,11 +47,11 @@ func MakeDocument() []Api {
}) })
} }
} }
hostPoliciesLock.RUnlock() ws.hostPoliciesLock.RUnlock()
// 2. Web Services // 2. Web Services
webServicesLock.RLock() ws.webServicesLock.RLock()
for _, a := range webServicesList { for _, a := range ws.webServicesList {
if a.options.NoDoc { if a.options.NoDoc {
continue continue
} }
@ -67,22 +71,22 @@ func MakeDocument() []Api {
} }
out = append(out, api) out = append(out, api)
} }
webServicesLock.RUnlock() ws.webServicesLock.RUnlock()
// 4. WebSocket Services // 4. WebSocket Services
websocketServicesLock.RLock() ws.websocketServicesLock.RLock()
for _, ws := range websocketServicesList { for _, wsc := range ws.websocketServicesList {
api := Api{ api := Api{
Type: "WebSocket", Type: "WebSocket",
Path: ws.path, Path: wsc.path,
AuthLevel: ws.authLevel, AuthLevel: wsc.authLevel,
Memo: ws.memo, Memo: wsc.memo,
Host: ws.host, Host: wsc.host,
} }
if ws.funcType != nil && ws.funcType.NumIn() > 0 { if wsc.funcType != nil && wsc.funcType.NumIn() > 0 {
// Find struct in // Find struct in
for i := 0; i < ws.funcType.NumIn(); i++ { for i := 0; i < wsc.funcType.NumIn(); i++ {
t := ws.funcType.In(i) t := wsc.funcType.In(i)
if t.Kind() == reflect.Struct { if t.Kind() == reflect.Struct {
api.In = getType(t) api.In = getType(t)
break break
@ -91,7 +95,7 @@ func MakeDocument() []Api {
} }
out = append(out, api) out = append(out, api)
} }
websocketServicesLock.RUnlock() ws.websocketServicesLock.RUnlock()
return out return out
} }

6
go.mod
View File

@ -10,7 +10,7 @@ require (
apigo.cc/go/http v1.5.0 apigo.cc/go/http v1.5.0
apigo.cc/go/id v1.5.0 apigo.cc/go/id v1.5.0
apigo.cc/go/jsmod v1.5.0 apigo.cc/go/jsmod v1.5.0
apigo.cc/go/log v1.5.0 apigo.cc/go/log v1.5.2
apigo.cc/go/redis v1.5.0 apigo.cc/go/redis v1.5.0
apigo.cc/go/safe v1.5.0 apigo.cc/go/safe v1.5.0
apigo.cc/go/starter v1.5.0 apigo.cc/go/starter v1.5.0
@ -25,8 +25,8 @@ require (
apigo.cc/go/rand v1.5.0 // indirect apigo.cc/go/rand v1.5.0 // indirect
apigo.cc/go/shell v1.5.0 // indirect apigo.cc/go/shell v1.5.0 // indirect
github.com/gomodule/redigo v2.0.0+incompatible // indirect github.com/gomodule/redigo v2.0.0+incompatible // indirect
golang.org/x/crypto v0.51.0 // indirect golang.org/x/crypto v0.52.0 // indirect
golang.org/x/sys v0.44.0 // indirect golang.org/x/sys v0.45.0 // indirect
golang.org/x/text v0.37.0 // indirect golang.org/x/text v0.37.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect
) )

16
go.sum
View File

@ -1,7 +1,7 @@
apigo.cc/go/cast v1.5.0 h1:UBGJtFQ8eJPMQXs37cUgqd7YQo1zI9opuSDBDmn2/pE= apigo.cc/go/cast v1.5.0 h1:UBGJtFQ8eJPMQXs37cUgqd7YQo1zI9opuSDBDmn2/pE=
apigo.cc/go/cast v1.5.0/go.mod h1:z2GW5p5WCZGEqVVIJUdhl232vRbLf2Qu4EDlEakX/D8= apigo.cc/go/cast v1.5.0/go.mod h1:z2GW5p5WCZGEqVVIJUdhl232vRbLf2Qu4EDlEakX/D8=
apigo.cc/go/config v1.5.0 h1:Yuz9QEb11XXG4XkhDi/ueT2M1T3Q9PElE5tiakvjehs= apigo.cc/go/config v1.5.1 h1:rpj7oCzlsDV3f2/YK3Pb+CHbfr2DL5Vyyv6VNkobJP4=
apigo.cc/go/config v1.5.0/go.mod h1:jdMiDLPa9gzB8/FFZvm9jOopUqdxb7XSX+0OeWcZZUM= apigo.cc/go/config v1.5.1/go.mod h1:jdMiDLPa9gzB8/FFZvm9jOopUqdxb7XSX+0OeWcZZUM=
apigo.cc/go/crypto v1.5.0 h1:Nxz7a6VKCdvaF258IU0NkjQyureOLxfR308Sy2iftUI= apigo.cc/go/crypto v1.5.0 h1:Nxz7a6VKCdvaF258IU0NkjQyureOLxfR308Sy2iftUI=
apigo.cc/go/crypto v1.5.0/go.mod h1:F9M6nXv+5328r1ZwbTvI6fcr8VdgqHVzALOcsdv6ntE= apigo.cc/go/crypto v1.5.0/go.mod h1:F9M6nXv+5328r1ZwbTvI6fcr8VdgqHVzALOcsdv6ntE=
apigo.cc/go/discover v1.5.0 h1:RGHulidyAHCZdGfpFytFUl3ur4aNVMXKlfJbAMCvgpo= apigo.cc/go/discover v1.5.0 h1:RGHulidyAHCZdGfpFytFUl3ur4aNVMXKlfJbAMCvgpo=
@ -16,8 +16,8 @@ apigo.cc/go/id v1.5.0 h1:MjNWPhBhDsoXaLeJDv/0wfJmVMU9EvOs8pWYfsTQ6e8=
apigo.cc/go/id v1.5.0/go.mod h1:qhu4a1/KLc/XcBpcsRu+mXZt7U7Wvd9zMcPs4VspuPA= apigo.cc/go/id v1.5.0/go.mod h1:qhu4a1/KLc/XcBpcsRu+mXZt7U7Wvd9zMcPs4VspuPA=
apigo.cc/go/jsmod v1.5.0 h1:JgQtJNiJWy1NOP9AzE8NX5VXJkpO/x3GqLsCCSny5Ec= apigo.cc/go/jsmod v1.5.0 h1:JgQtJNiJWy1NOP9AzE8NX5VXJkpO/x3GqLsCCSny5Ec=
apigo.cc/go/jsmod v1.5.0/go.mod h1:bmyeZtOAP/j5am+YRnaiM89smysK24K7ebk0koFtsSw= apigo.cc/go/jsmod v1.5.0/go.mod h1:bmyeZtOAP/j5am+YRnaiM89smysK24K7ebk0koFtsSw=
apigo.cc/go/log v1.5.0 h1:kQuLLtbt33mEuc/xJVcy8NODXkso/QKSZWNclKrSpsI= apigo.cc/go/log v1.5.2 h1:ORcrDh6a4ghxIrm+TNLtm8HxjctwndGL2jCLctEIags=
apigo.cc/go/log v1.5.0/go.mod h1:Djy+I5aLhGB/EjwRz4KHqkVEz584IAD55FAFiIfInuo= apigo.cc/go/log v1.5.2/go.mod h1:Djy+I5aLhGB/EjwRz4KHqkVEz584IAD55FAFiIfInuo=
apigo.cc/go/rand v1.5.0 h1:1o8hh8fhdBuk1/h02IvugvamuT3dkWbVJrqEJVQKB2E= apigo.cc/go/rand v1.5.0 h1:1o8hh8fhdBuk1/h02IvugvamuT3dkWbVJrqEJVQKB2E=
apigo.cc/go/rand v1.5.0/go.mod h1:Lh98S2dm9UY0X+M+kNQQEKyXHG5pcCKSFPyXN0QCGdk= apigo.cc/go/rand v1.5.0/go.mod h1:Lh98S2dm9UY0X+M+kNQQEKyXHG5pcCKSFPyXN0QCGdk=
apigo.cc/go/redis v1.5.0 h1:VXNDqzKj87BchF7ubDEH+T6lp8NrjeK0izU4ooo7u1A= apigo.cc/go/redis v1.5.0 h1:VXNDqzKj87BchF7ubDEH+T6lp8NrjeK0izU4ooo7u1A=
@ -40,12 +40,12 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
golang.org/x/crypto v0.51.0 h1:IBPXwPfKxY7cWQZ38ZCIRPI50YLeevDLlLnyC5wRGTI= golang.org/x/crypto v0.52.0 h1:RMs7fP2rXdep0CftQlK8Uf+kibLm7qkCcradZWYz988=
golang.org/x/crypto v0.51.0/go.mod h1:8AdwkbraGNABw2kOX6YFPs3WM22XqI4EXEd8g+x7Oc8= golang.org/x/crypto v0.52.0/go.mod h1:1QgfPxDqh0T2M/elOJtp9RvuR95kVjir0e6/BvEmGbc=
golang.org/x/net v0.54.0 h1:2zJIZAxAHV/OHCDTCOHAYehQzLfSXuf/5SoL/Dv6w/w= golang.org/x/net v0.54.0 h1:2zJIZAxAHV/OHCDTCOHAYehQzLfSXuf/5SoL/Dv6w/w=
golang.org/x/net v0.54.0/go.mod h1:Sj4oj8jK6XmHpBZU/zWHw3BV3abl4Kvi+Ut7cQcY+cQ= golang.org/x/net v0.54.0/go.mod h1:Sj4oj8jK6XmHpBZU/zWHw3BV3abl4Kvi+Ut7cQcY+cQ=
golang.org/x/sys v0.44.0 h1:ildZl3J4uzeKP07r2F++Op7E9B29JRUy+a27EibtBTQ= golang.org/x/sys v0.45.0 h1:dO4czNzziLiiXplLQgBCEpCvXQ3dnkn0SdaZSYdQ+FY=
golang.org/x/sys v0.44.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/sys v0.45.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/text v0.37.0 h1:Cqjiwd9eSg8e0QAkyCaQTNHFIIzWtidPahFWR83rTrc= golang.org/x/text v0.37.0 h1:Cqjiwd9eSg8e0QAkyCaQTNHFIIzWtidPahFWR83rTrc=
golang.org/x/text v0.37.0/go.mod h1:a5sjxXGs9hsn/AJVwuElvCAo9v8QYLzvavO5z2PiM38= golang.org/x/text v0.37.0/go.mod h1:a5sjxXGs9hsn/AJVwuElvCAo9v8QYLzvavO5z2PiM38=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=

View File

@ -15,10 +15,12 @@ import (
) )
type RouteHandler struct { type RouteHandler struct {
ws *webServer
webRequestingNum int64 webRequestingNum int64
} }
func (rh *RouteHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (rh *RouteHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ws := rh.ws
atomic.AddInt64(&rh.webRequestingNum, 1) atomic.AddInt64(&rh.webRequestingNum, 1)
defer atomic.AddInt64(&rh.webRequestingNum, -1) defer atomic.AddInt64(&rh.webRequestingNum, -1)
@ -55,7 +57,7 @@ func (rh *RouteHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// 记录日志 // 记录日志
if (s == nil || !s.options.NoLog200 || response.Code != 200) && if (s == nil || !s.options.NoLog200 || response.Code != 200) &&
!(Config.NoLogGets && r.Method == http.MethodGet && response.Code == 200) { !(ws.Config.NoLogGets && r.Method == http.MethodGet && response.Code == 200) {
scheme := "http" scheme := "http"
if r.TLS != nil { if r.TLS != nil {
@ -65,7 +67,7 @@ func (rh *RouteHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// 过滤请求头 // 过滤请求头
reqHeaders := make(map[string]string) reqHeaders := make(map[string]string)
noLogHeaders := strings.Split(Config.NoLogHeaders, ",") noLogHeaders := strings.Split(ws.Config.NoLogHeaders, ",")
for k, v := range r.Header { for k, v := range r.Header {
skip := false skip := false
for _, nl := range noLogHeaders { for _, nl := range noLogHeaders {
@ -93,7 +95,7 @@ func (rh *RouteHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} else { } else {
respData = string(response.body[:1024]) + "..." respData = string(response.body[:1024]) + "..."
} }
} else if Config.NoLogOutputFields != "" { } else if ws.Config.NoLogOutputFields != "" {
// 简单的字段过滤逻辑 (如果是 JSON 对象) // 简单的字段过滤逻辑 (如果是 JSON 对象)
// 这里可以根据 Config.NoLogOutputFields, LogOutputArrayNum, LogOutputFieldSize 进行更复杂的处理 // 这里可以根据 Config.NoLogOutputFields, LogOutputArrayNum, LogOutputFieldSize 进行更复杂的处理
// 暂按字符串截断处理 // 暂按字符串截断处理
@ -109,8 +111,8 @@ func (rh *RouteHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
entry.Scheme = scheme entry.Scheme = scheme
entry.Proto = r.Proto entry.Proto = r.Proto
entry.ClientIp = request.ClientIp() entry.ClientIp = request.ClientIp()
entry.ServerId = serverId entry.ServerId = ws.serverId
entry.App = Config.App entry.App = ws.Config.App
entry.FromApp = r.Header.Get(discover.HeaderFromApp) entry.FromApp = r.Header.Get(discover.HeaderFromApp)
entry.FromNode = r.Header.Get(discover.HeaderFromNode) entry.FromNode = r.Header.Get(discover.HeaderFromNode)
entry.DeviceId = request.DeviceId() entry.DeviceId = request.DeviceId()
@ -131,15 +133,15 @@ func (rh *RouteHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}() }()
// 处理 SessionId 和 DeviceId // 处理 SessionId 和 DeviceId
handleClientKeys(request, response) ws.handleClientKeys(request, response)
// 1. 处理重写 (Rewrite) // 1. 处理重写 (Rewrite)
if processRewrite(request, response, requestLogger) { if ws.processRewrite(request, response, requestLogger) {
return return
} }
// 2. 处理代理 (Proxy) // 2. 处理代理 (Proxy)
if processProxy(request, response, requestLogger) { if ws.processProxy(request, response, requestLogger) {
return return
} }
@ -148,19 +150,19 @@ func (rh *RouteHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
host := r.Host host := r.Host
// 处理静态文件 // 处理静态文件
if processStatic(path, request, response, requestLogger) { if ws.processStatic(path, request, response, requestLogger) {
return return
} }
var ws *websocketServiceType var wsc *websocketServiceType
s, ws = findService(r.Method, host, path) s, wsc = ws.findService(r.Method, host, path)
// 4. 参数解析 (Form & Body) // 4. 参数解析 (Form & Body)
parseRequestArgs(request, args) parseRequestArgs(request, args)
// 5. 前置过滤器 // 5. 前置过滤器
var result any var result any
for _, filter := range inFilters { for _, filter := range ws.inFilters {
result = filter(&args, request, response, requestLogger) result = filter(&args, request, response, requestLogger)
if result != nil { if result != nil {
break break
@ -174,22 +176,22 @@ func (rh *RouteHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// 6. 处理业务执行 (WS 或 Web) // 6. 处理业务执行 (WS 或 Web)
if result == nil { if result == nil {
if ws != nil { if wsc != nil {
authLevel = ws.authLevel authLevel = wsc.authLevel
priority = ws.options.Priority priority = wsc.options.Priority
// 鉴权 // 鉴权
pass, obj := checkAuth(ws.authLevel, &ws.options, request, response, args, requestLogger) pass, obj := ws.checkAuth(wsc.authLevel, &wsc.options, request, response, args, requestLogger)
if !pass { if !pass {
if !response.changed { if !response.changed {
response.WriteHeader(http.StatusForbidden) response.WriteHeader(http.StatusForbidden)
} }
return return
} }
doWebsocketService(ws, request, response, requestLogger, obj) ws.doWebsocketService(wsc, request, response, requestLogger, obj)
return return
} else if s != nil { } else if s != nil {
// 鉴权 // 鉴权
pass, obj := checkAuth(s.authLevel, &s.options, request, response, args, requestLogger) pass, obj := ws.checkAuth(s.authLevel, &s.options, request, response, args, requestLogger)
if !pass { if !pass {
if !response.changed { if !response.changed {
response.WriteHeader(http.StatusForbidden) response.WriteHeader(http.StatusForbidden)
@ -197,7 +199,7 @@ func (rh *RouteHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return return
} }
// 执行业务 // 执行业务
result = doWebService(s, request, response, args, nil, requestLogger, obj) result = ws.doWebService(s, request, response, args, nil, requestLogger, obj)
} }
} }
@ -206,7 +208,7 @@ func (rh *RouteHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
// 7. 后置过滤器 // 7. 后置过滤器
for _, filter := range outFilters { for _, filter := range ws.outFilters {
newResult, done := filter(args, request, response, result, requestLogger) newResult, done := filter(args, request, response, result, requestLogger)
if newResult != nil { if newResult != nil {
result = newResult result = newResult
@ -225,9 +227,9 @@ func hostOnly(host string) string {
return h return h
} }
func findService(method, host, path string) (*webServiceType, *websocketServiceType) { func (ws *webServer) findService(method, host, path string) (*webServiceType, *websocketServiceType) {
webServicesLock.RLock() ws.webServicesLock.RLock()
defer webServicesLock.RUnlock() defer ws.webServicesLock.RUnlock()
// 1. 准备 Host 候选列表: "host:port", "host", ":port", "*" // 1. 准备 Host 候选列表: "host:port", "host", ":port", "*"
hostOnly, port, _ := strings.Cut(host, ":") hostOnly, port, _ := strings.Cut(host, ":")
@ -239,7 +241,7 @@ func findService(method, host, path string) (*webServiceType, *websocketServiceT
// 2. 匹配 Web Service // 2. 匹配 Web Service
for _, h := range hosts { for _, h := range hosts {
if services, exists := webServices[h]; exists { if services, exists := ws.webServices[h]; exists {
if s, ok := services[method+path]; ok { if s, ok := services[method+path]; ok {
return s, nil return s, nil
} }
@ -250,10 +252,10 @@ func findService(method, host, path string) (*webServiceType, *websocketServiceT
} }
// 3. 匹配 WebSocket // 3. 匹配 WebSocket
websocketServicesLock.RLock() ws.websocketServicesLock.RLock()
defer websocketServicesLock.RUnlock() defer ws.websocketServicesLock.RUnlock()
for _, h := range hosts { for _, h := range hosts {
if services, exists := websocketServices[h]; exists { if services, exists := ws.websocketServices[h]; exists {
if ws, ok := services[path]; ok { if ws, ok := services[path]; ok {
return nil, ws return nil, ws
} }
@ -262,7 +264,7 @@ func findService(method, host, path string) (*webServiceType, *websocketServiceT
// 4. 正则匹配 // 4. 正则匹配
for _, h := range hosts { for _, h := range hosts {
if services, exists := regexWebServices[h]; exists { if services, exists := ws.regexWebServices[h]; exists {
for i := len(services) - 1; i >= 0; i-- { for i := len(services) - 1; i >= 0; i-- {
s := services[i] s := services[i]
if s.method != "*" && s.method != method { if s.method != "*" && s.method != method {
@ -311,10 +313,10 @@ func parseRequestArgs(request *Request, args map[string]any) {
} }
} }
func checkAuth(authLevel int, options *WebServiceOptions, request *Request, response *Response, args map[string]any, logger *log.Logger) (bool, any) { func (ws *webServer) checkAuth(authLevel int, options *WebServiceOptions, request *Request, response *Response, args map[string]any, logger *log.Logger) (bool, any) {
ac := webAuthCheckers[authLevel] ac := ws.webAuthCheckers[authLevel]
if ac == nil { if ac == nil {
ac = webAuthChecker ac = ws.webAuthChecker
} }
if ac == nil { if ac == nil {
sess := NewSession(request.SessionId(), logger) sess := NewSession(request.SessionId(), logger)
@ -330,7 +332,7 @@ func checkAuth(authLevel int, options *WebServiceOptions, request *Request, resp
return pass, obj return pass, obj
} }
func doWebService(service *webServiceType, request *Request, response *Response, args map[string]any, func (ws *webServer) doWebService(service *webServiceType, request *Request, response *Response, args map[string]any,
result any, logger *log.Logger, object any) any { result any, logger *log.Logger, object any) any {
if result != nil { if result != nil {
return result return result
@ -365,7 +367,7 @@ func doWebService(service *webServiceType, request *Request, response *Response,
// 尝试依赖注入 // 尝试依赖注入
if object != nil && reflect.TypeOf(object).AssignableTo(t) { if object != nil && reflect.TypeOf(object).AssignableTo(t) {
params[i] = reflect.ValueOf(object) params[i] = reflect.ValueOf(object)
} else if obj := GetInject(t); obj != nil { } else if obj := ws.GetInject(t); obj != nil {
params[i] = reflect.ValueOf(obj) params[i] = reflect.ValueOf(obj)
} else { } else {
params[i] = reflect.New(t).Elem() params[i] = reflect.New(t).Elem()
@ -404,24 +406,24 @@ func outputResult(response *Response, result any) {
_, _ = response.Write(data) _, _ = response.Write(data)
} }
func handleClientKeys(request *Request, response *Response) { func (ws *webServer) handleClientKeys(request *Request, response *Response) {
// SessionId // SessionId
if usedSessionIdKey != "" { if ws.usedSessionIdKey != "" {
sessionId := request.Header.Get(usedSessionIdKey) sessionId := request.Header.Get(ws.usedSessionIdKey)
if sessionId == "" && !Config.SessionWithoutCookie { if sessionId == "" && !ws.Config.SessionWithoutCookie {
if ck, err := request.Cookie(usedSessionIdKey); err == nil { if ck, err := request.Cookie(ws.usedSessionIdKey); err == nil {
sessionId = ck.Value sessionId = ck.Value
} }
} }
if sessionId == "" { if sessionId == "" {
if sessionIdMaker != nil { if ws.sessionIdMaker != nil {
sessionId = sessionIdMaker() sessionId = ws.sessionIdMaker()
} else { } else {
sessionId = IDMaker.Get11Bytes900MPerSecond() sessionId = IDMaker.Get11Bytes900MPerSecond()
} }
if !Config.SessionWithoutCookie { if !ws.Config.SessionWithoutCookie {
http.SetCookie(response.Writer, &http.Cookie{ http.SetCookie(response.Writer, &http.Cookie{
Name: usedSessionIdKey, Name: ws.usedSessionIdKey,
Value: sessionId, Value: sessionId,
Path: "/", Path: "/",
HttpOnly: true, HttpOnly: true,
@ -429,22 +431,22 @@ func handleClientKeys(request *Request, response *Response) {
} }
} }
request.Header.Set(discover.HeaderSessionID, sessionId) request.Header.Set(discover.HeaderSessionID, sessionId)
response.Header().Set(usedSessionIdKey, sessionId) response.Header().Set(ws.usedSessionIdKey, sessionId)
} }
// DeviceId // DeviceId
if usedDeviceIdKey != "" { if ws.usedDeviceIdKey != "" {
deviceId := request.Header.Get(usedDeviceIdKey) deviceId := request.Header.Get(ws.usedDeviceIdKey)
if deviceId == "" && !Config.DeviceWithoutCookie { if deviceId == "" && !ws.Config.DeviceWithoutCookie {
if ck, err := request.Cookie(usedDeviceIdKey); err == nil { if ck, err := request.Cookie(ws.usedDeviceIdKey); err == nil {
deviceId = ck.Value deviceId = ck.Value
} }
} }
if deviceId == "" { if deviceId == "" {
deviceId = IDMaker.Get11Bytes900MPerSecond() deviceId = IDMaker.Get11Bytes900MPerSecond()
if !Config.DeviceWithoutCookie { if !ws.Config.DeviceWithoutCookie {
http.SetCookie(response.Writer, &http.Cookie{ http.SetCookie(response.Writer, &http.Cookie{
Name: usedDeviceIdKey, Name: ws.usedDeviceIdKey,
Value: deviceId, Value: deviceId,
Path: "/", Path: "/",
Expires: time.Now().AddDate(10, 0, 0), Expires: time.Now().AddDate(10, 0, 0),
@ -453,6 +455,6 @@ func handleClientKeys(request *Request, response *Response) {
} }
} }
request.Header.Set(discover.HeaderDeviceID, deviceId) request.Header.Set(discover.HeaderDeviceID, deviceId)
response.Header().Set(usedDeviceIdKey, deviceId) response.Header().Set(ws.usedDeviceIdKey, deviceId)
} }
} }

View File

@ -14,7 +14,7 @@ func TestServeHTTP(t *testing.T) {
} }
Host("*").POST("/hello", handler).Auth(0).Memo("say hello") Host("*").POST("/hello", handler).Auth(0).Memo("say hello")
rh := &RouteHandler{} rh := &RouteHandler{ws: DefaultServer}
// 模拟请求 // 模拟请求
req := httptest.NewRequest("POST", "/hello", strings.NewReader(`{"name":"Star"}`)) req := httptest.NewRequest("POST", "/hello", strings.NewReader(`{"name":"Star"}`))
@ -34,7 +34,7 @@ func TestServeHTTP(t *testing.T) {
} }
func TestServeHTTP_404(t *testing.T) { func TestServeHTTP_404(t *testing.T) {
rh := &RouteHandler{} rh := &RouteHandler{ws: DefaultServer}
req := httptest.NewRequest("GET", "/notfound", nil) req := httptest.NewRequest("GET", "/notfound", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
@ -54,7 +54,7 @@ func TestServeHTTP_VerifyFailed(t *testing.T) {
} }
Host("*").POST("/verify", handler).Auth(0).Memo("test verify") Host("*").POST("/verify", handler).Auth(0).Memo("test verify")
rh := &RouteHandler{} rh := &RouteHandler{ws: DefaultServer}
req := httptest.NewRequest("POST", "/verify", strings.NewReader(`{"age":10}`)) req := httptest.NewRequest("POST", "/verify", strings.NewReader(`{"age":10}`))
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder() w := httptest.NewRecorder()
@ -71,7 +71,7 @@ func TestServeHTTP_Panic(t *testing.T) {
panic("intentional panic") panic("intentional panic")
}) })
rh := &RouteHandler{} rh := &RouteHandler{ws: DefaultServer}
req := httptest.NewRequest("GET", "/panic", nil) req := httptest.NewRequest("GET", "/panic", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()

View File

@ -7,11 +7,11 @@ import (
func init() { func init() {
jsmod.Register("service", map[string]any{ jsmod.Register("service", map[string]any{
// 类型占位工厂 (用于 AI 发现类型结构) // 类型占位工厂 (用于 AI 发现类型结构)
"newRequest": func() *Request { return &Request{} }, "newRequest": func() *Request { return &Request{} },
"newResponse": func() *Response { return &Response{} }, "newResponse": func() *Response { return &Response{} },
"newWebSocket": func() *WebSocketConn { return &WebSocketConn{} }, "newWebSocket": func() *WebSocketConn { return &WebSocketConn{} },
"newSession": func() *Session { return &Session{} }, "newSession": func() *Session { return &Session{} },
"newFile": func() *jsUploadFile { return &jsUploadFile{} }, "newFile": func() *jsUploadFile { return &jsUploadFile{} },
// 功能函数 // 功能函数
"upgrade": Upgrade, "upgrade": Upgrade,

View File

@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"reflect"
"regexp" "regexp"
"strings" "strings"
"time" "time"
@ -77,24 +78,17 @@ func parseProxyRule(authLevel int, path, toApp, toPath string, to string) *proxy
func (hc *HostContext) Proxy(authLevel int, path string, to string) *HostContext { func (hc *HostContext) Proxy(authLevel int, path string, to string) *HostContext {
p := parseProxyRule(authLevel, path, "", "", to) p := parseProxyRule(authLevel, path, "", "", to)
hostPoliciesLock.Lock() hc.ws.hostPoliciesLock.Lock()
defer hostPoliciesLock.Unlock() defer hc.ws.hostPoliciesLock.Unlock()
codeProxies[hc.host] = append(codeProxies[hc.host], p) if hc.ws.codeProxies[hc.host] == nil {
rebuildProxiesUnderLock(hc.host) hc.ws.codeProxies[hc.host] = make([]*proxyType, 0)
}
hc.ws.codeProxies[hc.host] = append(hc.ws.codeProxies[hc.host], p)
hc.ws.rebuildProxiesUnderLock(hc.host)
return hc return hc
} }
func rebuildProxiesUnderLock(host string) { func (ws *webServer) findProxy(request *Request) (int, *string, *string, string) {
var combined []*proxyType
combined = append(combined, codeProxies[host]...)
combined = append(combined, fileProxies[host]...)
combined = append(combined, dynamicProxies[host]...)
hostProxies[host] = combined
}
var httpClientPool *gohttp.Client
func findProxy(request *Request) (int, *string, *string, string) {
host := request.Host host := request.Host
hostOnly, port, _ := strings.Cut(host, ":") hostOnly, port, _ := strings.Cut(host, ":")
hosts := []string{host} hosts := []string{host}
@ -110,11 +104,11 @@ func findProxy(request *Request) (int, *string, *string, string) {
requestPath = requestPath[:pos] requestPath = requestPath[:pos]
} }
hostPoliciesLock.RLock() ws.hostPoliciesLock.RLock()
defer hostPoliciesLock.RUnlock() defer ws.hostPoliciesLock.RUnlock()
for _, h := range hosts { for _, h := range hosts {
proxies, exists := hostProxies[h] proxies, exists := ws.hostProxies[h]
if !exists { if !exists {
continue continue
} }
@ -151,15 +145,15 @@ func findProxy(request *Request) (int, *string, *string, string) {
return 0, nil, nil, "" return 0, nil, nil, ""
} }
func processProxy(request *Request, response *Response, logger *log.Logger) bool { func (ws *webServer) processProxy(request *Request, response *Response, logger *log.Logger) bool {
authLevel, proxyToApp, proxyToPath, foundHost := findProxy(request) authLevel, proxyToApp, proxyToPath, foundHost := ws.findProxy(request)
if proxyToApp == nil || proxyToPath == nil || *proxyToApp == "" || *proxyToPath == "" { if proxyToApp == nil || proxyToPath == nil || *proxyToApp == "" || *proxyToPath == "" {
return false return false
} }
// 鉴权 // 鉴权
pass, obj := checkAuthForProxy(authLevel, request, response, logger) pass, obj := ws.checkAuthForProxy(authLevel, request, response, logger)
if !pass { if !pass {
if !response.changed { if !response.changed {
response.WriteHeader(http.StatusForbidden) response.WriteHeader(http.StatusForbidden)
@ -174,19 +168,16 @@ func processProxy(request *Request, response *Response, logger *log.Logger) bool
if strings.Contains(app, "://") { if strings.Contains(app, "://") {
// 直接 URL 代理 // 直接 URL 代理
if httpClientPool == nil { res := ws.getHttpClient().ManualDoByRequest(request.Request, request.Method, app+path, request.Body)
httpClientPool = gohttp.NewClient(time.Duration(Config.RedirectTimeout) * time.Millisecond)
}
res := httpClientPool.ManualDoByRequest(request.Request, request.Method, app+path, request.Body)
copyResponse(res, response, logger) copyResponse(res, response, logger)
} else { } else {
// Discover 代理 // Discover 代理
if GlobalDiscoverer == nil { if ws.discoverer == nil {
logger.Error("proxy failed: GlobalDiscoverer is not initialized") logger.Error("proxy failed: Discoverer is not initialized")
response.WriteHeader(http.StatusBadGateway) response.WriteHeader(http.StatusBadGateway)
return true return true
} }
caller := GlobalDiscoverer.NewCaller(request.Request, logger) caller := ws.discoverer.NewCaller(request.Request, logger)
caller.NoBody = true caller.NoBody = true
res, _ := caller.ManualDoWithNode(request.Method, app, "", path, request.Body) res, _ := caller.ManualDoWithNode(request.Method, app, "", path, request.Body)
copyResponse(res, response, logger) copyResponse(res, response, logger)
@ -195,10 +186,22 @@ func processProxy(request *Request, response *Response, logger *log.Logger) bool
return true return true
} }
func checkAuthForProxy(authLevel int, request *Request, response *Response, logger *log.Logger) (bool, any) { func (ws *webServer) getHttpClient() *gohttp.Client {
ac := webAuthCheckers[authLevel] // 尝试从注入对象获取
if obj := ws.GetInject(reflect.TypeOf(&gohttp.Client{})); obj != nil {
return obj.(*gohttp.Client)
}
timeout := time.Duration(ws.Config.RedirectTimeout) * time.Millisecond
if timeout <= 0 {
timeout = 30 * time.Second
}
return gohttp.NewClient(timeout)
}
func (ws *webServer) checkAuthForProxy(authLevel int, request *Request, response *Response, logger *log.Logger) (bool, any) {
ac := ws.webAuthCheckers[authLevel]
if ac == nil { if ac == nil {
ac = webAuthChecker ac = ws.webAuthChecker
} }
if ac == nil { if ac == nil {
return true, nil return true, nil
@ -239,13 +242,17 @@ type ProxyRule struct {
// ReplaceProxies 使用全量指针替换的方式 (Copy-on-Write) 无缝更新指定 host 的动态代理规则。 // ReplaceProxies 使用全量指针替换的方式 (Copy-on-Write) 无缝更新指定 host 的动态代理规则。
func ReplaceProxies(host string, rules []ProxyRule) { func ReplaceProxies(host string, rules []ProxyRule) {
DefaultServer.ReplaceProxies(host, rules)
}
func (ws *webServer) ReplaceProxies(host string, rules []ProxyRule) {
newProxies := make([]*proxyType, 0, len(rules)) newProxies := make([]*proxyType, 0, len(rules))
for _, r := range rules { for _, r := range rules {
newProxies = append(newProxies, parseProxyRule(r.AuthLevel, r.Path, r.ToApp, r.ToPath, r.To)) newProxies = append(newProxies, parseProxyRule(r.AuthLevel, r.Path, r.ToApp, r.ToPath, r.To))
} }
hostPoliciesLock.Lock() ws.hostPoliciesLock.Lock()
defer hostPoliciesLock.Unlock() defer ws.hostPoliciesLock.Unlock()
dynamicProxies[host] = newProxies ws.dynamicProxies[host] = newProxies
rebuildProxiesUnderLock(host) ws.rebuildProxiesUnderLock(host)
} }

View File

@ -15,7 +15,7 @@ func TestRewrite(t *testing.T) {
Host("*").ANY("/new", func() string { return "new content" }).Memo("new") Host("*").ANY("/new", func() string { return "new content" }).Memo("new")
Host("*").ANY("/target/123", func() string { return "target content" }).Memo("target") Host("*").ANY("/target/123", func() string { return "target content" }).Memo("target")
rh := &RouteHandler{} rh := &RouteHandler{ws: DefaultServer}
// 测试精确匹配重写 // 测试精确匹配重写
req1 := httptest.NewRequest("GET", "/old", nil) req1 := httptest.NewRequest("GET", "/old", nil)
@ -45,7 +45,7 @@ func TestProxyDirect(t *testing.T) {
// 注册代理规则 // 注册代理规则
Host("*").Proxy(0, "/proxy", backend.URL+"/hello") Host("*").Proxy(0, "/proxy", backend.URL+"/hello")
rh := &RouteHandler{} rh := &RouteHandler{ws: DefaultServer}
req := httptest.NewRequest("GET", "/proxy", nil) req := httptest.NewRequest("GET", "/proxy", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
rh.ServeHTTP(w, req) rh.ServeHTTP(w, req)

View File

@ -5,24 +5,30 @@ import (
"sync" "sync"
) )
var ( type reloadHook struct {
reloadHooks []func() error hooks []func() error
reloadLock sync.RWMutex lock sync.RWMutex
) }
var globalReloadHook = &reloadHook{}
// OnReload 注册一个在接收到 SIGHUP 信号时触发的重新加载钩子 // OnReload 注册一个在接收到 SIGHUP 信号时触发的重新加载钩子
func OnReload(handler func() error) { func OnReload(handler func() error) {
reloadLock.Lock() DefaultServer.OnReload(handler)
defer reloadLock.Unlock() }
reloadHooks = append(reloadHooks, handler)
func (ws *webServer) OnReload(handler func() error) {
globalReloadHook.lock.Lock()
defer globalReloadHook.lock.Unlock()
globalReloadHook.hooks = append(globalReloadHook.hooks, handler)
} }
// triggerReload 触发所有注册的重新加载钩子 // triggerReload 触发所有注册的重新加载钩子
func triggerReload() error { func (ws *webServer) triggerReload() error {
reloadLock.RLock() globalReloadHook.lock.RLock()
hooks := make([]func() error, len(reloadHooks)) hooks := make([]func() error, len(globalReloadHook.hooks))
copy(hooks, reloadHooks) copy(hooks, globalReloadHook.hooks)
reloadLock.RUnlock() globalReloadHook.lock.RUnlock()
for _, hook := range hooks { for _, hook := range hooks {
if err := hook(); err != nil { if err := hook(); err != nil {

View File

@ -42,22 +42,14 @@ func parseRewriteRule(fromPath, toPath, to string) *rewriteType {
func (hc *HostContext) Rewrite(path string, to string) *HostContext { func (hc *HostContext) Rewrite(path string, to string) *HostContext {
s := parseRewriteRule(path, "", to) s := parseRewriteRule(path, "", to)
hostPoliciesLock.Lock() hc.ws.hostPoliciesLock.Lock()
defer hostPoliciesLock.Unlock() defer hc.ws.hostPoliciesLock.Unlock()
codeRewrites[hc.host] = append(codeRewrites[hc.host], s) hc.ws.codeRewrites[hc.host] = append(hc.ws.codeRewrites[hc.host], s)
rebuildRewritesUnderLock(hc.host) hc.ws.rebuildRewritesUnderLock(hc.host)
return hc return hc
} }
func rebuildRewritesUnderLock(host string) { func (ws *webServer) processRewrite(request *Request, response *Response, logger *log.Logger) bool {
var combined []*rewriteType
combined = append(combined, codeRewrites[host]...)
combined = append(combined, fileRewrites[host]...)
combined = append(combined, dynamicRewrites[host]...)
hostRewrites[host] = combined
}
func processRewrite(request *Request, response *Response, logger *log.Logger) bool {
host := request.Host host := request.Host
hostOnly, port, _ := strings.Cut(host, ":") hostOnly, port, _ := strings.Cut(host, ":")
hosts := []string{host} hosts := []string{host}
@ -66,8 +58,8 @@ func processRewrite(request *Request, response *Response, logger *log.Logger) bo
} }
hosts = append(hosts, "*") hosts = append(hosts, "*")
hostPoliciesLock.RLock() ws.hostPoliciesLock.RLock()
defer hostPoliciesLock.RUnlock() defer ws.hostPoliciesLock.RUnlock()
requestPath := request.RequestURI requestPath := request.RequestURI
queryString := "" queryString := ""
@ -77,7 +69,7 @@ func processRewrite(request *Request, response *Response, logger *log.Logger) bo
} }
for _, h := range hosts { for _, h := range hosts {
rewrites, exists := hostRewrites[h] rewrites, exists := ws.hostRewrites[h]
if !exists { if !exists {
continue continue
} }
@ -144,13 +136,17 @@ type RewriteRule struct {
// ReplaceRewrites 使用 Copy-on-Write 机制原子地替换指定 host 下的动态重写规则。 // ReplaceRewrites 使用 Copy-on-Write 机制原子地替换指定 host 下的动态重写规则。
func ReplaceRewrites(host string, rules []RewriteRule) { func ReplaceRewrites(host string, rules []RewriteRule) {
DefaultServer.ReplaceRewrites(host, rules)
}
func (ws *webServer) ReplaceRewrites(host string, rules []RewriteRule) {
newRewrites := make([]*rewriteType, 0, len(rules)) newRewrites := make([]*rewriteType, 0, len(rules))
for _, r := range rules { for _, r := range rules {
newRewrites = append(newRewrites, parseRewriteRule(r.Path, r.ToPath, r.To)) newRewrites = append(newRewrites, parseRewriteRule(r.Path, r.ToPath, r.To))
} }
hostPoliciesLock.Lock() ws.hostPoliciesLock.Lock()
defer hostPoliciesLock.Unlock() defer ws.hostPoliciesLock.Unlock()
dynamicRewrites[host] = newRewrites ws.dynamicRewrites[host] = newRewrites
rebuildRewritesUnderLock(host) ws.rebuildRewritesUnderLock(host)
} }

352
server.go
View File

@ -14,42 +14,262 @@ import (
"net" "net"
"net/http" "net/http"
"os" "os"
"path/filepath"
"reflect"
"sort"
"strings" "strings"
"sync"
"time" "time"
) )
// GlobalDiscoverer 供服务框架内部使用的发现实例 type webServer struct {
var GlobalDiscoverer *discover.Discoverer Config ServiceConfig
// WebServer 实现了 starter.Service 和 starter.Reloader 接口
type WebServer struct {
server *http.Server server *http.Server
listener net.Listener listener net.Listener
Addr string Addr string
useDiscover bool useDiscover bool
discoverer *discover.Discoverer discoverer *discover.Discoverer
logger *log.Logger logger *log.Logger
// 运行时状态
serverId string
serverAddr string
running bool
// Web 服务注册 (按 Host 隔离)
webServices map[string]map[string]*webServiceType
regexWebServices map[string][]*webServiceType
webServicesLock sync.RWMutex
webServicesList []*webServiceType
websocketServices map[string]map[string]*websocketServiceType
websocketServicesLock sync.RWMutex
websocketServicesList []*websocketServiceType
// 路由策略 (按 Host 隔离)
hostRewrites map[string][]*rewriteType
hostProxies map[string][]*proxyType
codeProxies map[string][]*proxyType
fileProxies map[string][]*proxyType
dynamicProxies map[string][]*proxyType
codeRewrites map[string][]*rewriteType
fileRewrites map[string][]*rewriteType
dynamicRewrites map[string][]*rewriteType
hostPoliciesLock sync.RWMutex
// 静态文件服务
statics map[string]*string
staticsByHost map[string]map[string]*string
codeStatics map[string]map[string]*string
fileStatics map[string]map[string]*string
dynamicStatics map[string]map[string]*string
staticsByHostLock sync.RWMutex
// 过滤器与拦截器
inFilters []func(*map[string]any, *Request, *Response, *log.Logger) any
outFilters []func(map[string]any, *Request, *Response, any, *log.Logger) (any, bool)
errorHandle func(any, *Request, *Response) any
webAuthChecker func(int, *log.Logger, *string, map[string]any, *Request, *Response, *WebServiceOptions) (pass bool, object any)
webAuthCheckers map[int]func(int, *log.Logger, *string, map[string]any, *Request, *Response, *WebServiceOptions) (pass bool, object any)
// 注入点
injectObjects map[reflect.Type]any
injectFunctions map[reflect.Type]func() any
// 客户端标识
usedDeviceIdKey string
usedClientAppKey string
usedSessionIdKey string
sessionIdMaker func() string
} }
// NewWebServer 创建并返回一个新的 WebServer 实例 // DefaultServer 全局单例服务实例
func NewWebServer() *WebServer { var DefaultServer = newWebServer()
return &WebServer{}
// Config 全局配置对象 (指向 DefaultServer.Config)
var Config = &DefaultServer.Config
func newWebServer() *webServer {
ws := &webServer{
webServices: make(map[string]map[string]*webServiceType),
regexWebServices: make(map[string][]*webServiceType),
webServicesList: make([]*webServiceType, 0),
websocketServices: make(map[string]map[string]*websocketServiceType),
websocketServicesList: make([]*websocketServiceType, 0),
hostRewrites: make(map[string][]*rewriteType),
hostProxies: make(map[string][]*proxyType),
codeProxies: make(map[string][]*proxyType),
fileProxies: make(map[string][]*proxyType),
dynamicProxies: make(map[string][]*proxyType),
codeRewrites: make(map[string][]*rewriteType),
fileRewrites: make(map[string][]*rewriteType),
dynamicRewrites: make(map[string][]*rewriteType),
statics: make(map[string]*string),
staticsByHost: make(map[string]map[string]*string),
codeStatics: make(map[string]map[string]*string),
fileStatics: make(map[string]map[string]*string),
dynamicStatics: make(map[string]map[string]*string),
webAuthCheckers: make(map[int]func(int, *log.Logger, *string, map[string]any, *Request, *Response, *WebServiceOptions) (pass bool, object any)),
injectObjects: make(map[reflect.Type]any),
injectFunctions: make(map[reflect.Type]func() any),
}
return ws
}
// SetDiscovererForTest 提供给测试用例使用的后门方法,用于模拟断开或重置服务发现
func SetDiscovererForTest(d *discover.Discoverer) {
DefaultServer.discoverer = d
}
// ApplyConfig 将 ServiceConfig 中的路由策略应用到内部的文件级策略中
func (ws *webServer) ApplyConfig() {
ws.hostPoliciesLock.Lock()
defer ws.hostPoliciesLock.Unlock()
// 1. Proxies KV 解析
ws.fileProxies = make(map[string][]*proxyType)
for host, kv := range ws.Config.Proxies {
h := host
if h == "*" {
h = ""
}
rules := make([]*proxyType, 0, len(kv))
for path, val := range kv {
if to, ok := val.(string); ok {
rules = append(rules, parseProxyRule(0, path, "", "", to))
} else {
// 对象模式
m := make(map[string]any)
if tm, ok := val.(map[string]any); ok {
m = tm
}
rules = append(rules, parseProxyRule(
int(reflect.ValueOf(m["Auth"]).Int()), // Simplified
path,
fmt.Sprint(m["ToApp"]),
fmt.Sprint(m["ToPath"]),
fmt.Sprint(m["To"]),
))
}
}
ws.fileProxies[h] = rules
ws.rebuildProxiesUnderLock(h)
}
// 2. Rewrites KV 解析
ws.fileRewrites = make(map[string][]*rewriteType)
for host, kv := range ws.Config.Rewrites {
h := host
if h == "*" {
h = ""
}
rules := make([]*rewriteType, 0, len(kv))
for path, val := range kv {
if to, ok := val.(string); ok {
rules = append(rules, parseRewriteRule(path, "", to))
} else {
m := make(map[string]any)
if tm, ok := val.(map[string]any); ok {
m = tm
}
rules = append(rules, parseRewriteRule(
path,
fmt.Sprint(m["ToPath"]),
fmt.Sprint(m["To"]),
))
}
}
ws.fileRewrites[h] = rules
ws.rebuildRewritesUnderLock(h)
}
ws.staticsByHostLock.Lock()
defer ws.staticsByHostLock.Unlock()
ws.fileStatics = make(map[string]map[string]*string)
for host, config := range ws.Config.Statics {
h := host
if h == "*" {
h = ""
}
newStatics := make(map[string]*string, len(config))
for path, rootPath := range config {
rp := rootPath
if !filepath.IsAbs(rp) {
if absPath, err := filepath.Abs(rp); err == nil {
rp = absPath
}
}
newStatics[path] = &rp
}
ws.fileStatics[h] = newStatics
ws.rebuildStaticsUnderLock(h)
}
// 始终重新构建默认 Host 的静态路由,以合并代码定义的路由
ws.rebuildStaticsUnderLock("")
}
func (ws *webServer) rebuildProxiesUnderLock(host string) {
combined := make([]*proxyType, 0)
combined = append(combined, ws.codeProxies[host]...)
combined = append(combined, ws.fileProxies[host]...)
combined = append(combined, ws.dynamicProxies[host]...)
sort.Slice(combined, func(i, j int) bool {
return len(combined[i].fromPath) > len(combined[j].fromPath)
})
ws.hostProxies[host] = combined
}
func (ws *webServer) rebuildRewritesUnderLock(host string) {
combined := make([]*rewriteType, 0)
combined = append(combined, ws.codeRewrites[host]...)
combined = append(combined, ws.fileRewrites[host]...)
combined = append(combined, ws.dynamicRewrites[host]...)
sort.Slice(combined, func(i, j int) bool {
return len(combined[i].fromPath) > len(combined[j].fromPath)
})
ws.hostRewrites[host] = combined
}
func (ws *webServer) rebuildStaticsUnderLock(host string) {
combined := make(map[string]*string)
for k, v := range ws.codeStatics[host] {
combined[k] = v
}
for k, v := range ws.fileStatics[host] {
combined[k] = v
}
for k, v := range ws.dynamicStatics[host] {
combined[k] = v
}
if host == "" {
ws.statics = combined
} else {
ws.staticsByHost[host] = combined
}
} }
// Start 启动服务,实现 starter.Service 接口 // Start 启动服务,实现 starter.Service 接口
func (ws *WebServer) Start(ctx context.Context, logger *log.Logger) error { func (ws *webServer) Start(ctx context.Context, logger *log.Logger) error {
if logger == nil { if logger == nil {
logger = log.DefaultLogger logger = log.DefaultLogger
} }
ws.logger = logger ws.logger = logger
// 初始加载配置 // 初始加载配置
if err := config.Load(&Config, "service"); err != nil { if err := config.Load(&ws.Config, "service"); err != nil {
logger.Error("failed to load config during start", "error", err.Error()) logger.Error("failed to load config during start", "error", err.Error())
} }
ApplyConfig() ws.ApplyConfig()
listenStr := Config.Listen listenStr := ws.Config.Listen
ws.useDiscover = false ws.useDiscover = false
if listenStr == "" { if listenStr == "" {
@ -57,7 +277,6 @@ func (ws *WebServer) Start(ctx context.Context, logger *log.Logger) error {
ws.useDiscover = true ws.useDiscover = true
} }
// 解析第一个监听配置
part := strings.Split(listenStr, "|")[0] part := strings.Split(listenStr, "|")[0]
addr, opts, _ := strings.Cut(part, ",") addr, opts, _ := strings.Cut(part, ",")
@ -70,29 +289,26 @@ func (ws *WebServer) Start(ctx context.Context, logger *log.Logger) error {
} }
if protocol == "" { if protocol == "" {
protocol = "http" // Default to http protocol = "http"
} }
if !strings.Contains(addr, ":") { if !strings.Contains(addr, ":") {
addr = ":" + addr addr = ":" + addr
} }
// 检查是否需要启动服务发现 appName := ws.Config.App
appName := Config.App
if appName == "" { if appName == "" {
appName = GetDefaultName() appName = GetDefaultName()
Config.App = appName ws.Config.App = appName
} }
if appName != "" || Config.Register != "" { if appName != "" || ws.Config.Register != "" {
ws.useDiscover = true ws.useDiscover = true
} }
// 初始化服务器唯一标识 (8位物理上限 3,844/s) ws.serverId = IDMaker.Get8Bytes4KPerSecond()
serverId = IDMaker.Get8Bytes4KPerSecond()
// 初始化分布式 ID 生成器 if ws.Config.IdServer != "" {
if Config.IdServer != "" { rd := redis.GetRedis(ws.Config.IdServer, log.New(ws.serverId))
rd := redis.GetRedis(Config.IdServer, log.New(serverId))
if rd.Error == nil { if rd.Error == nil {
IDMaker = redis.NewIDMaker(rd) IDMaker = redis.NewIDMaker(rd)
} }
@ -105,38 +321,35 @@ func (ws *WebServer) Start(ctx context.Context, logger *log.Logger) error {
ws.listener = listener ws.listener = listener
ws.Addr = listener.Addr().String() ws.Addr = listener.Addr().String()
serverAddr = ws.Addr ws.serverAddr = ws.Addr
// 如果使用了随机端口且没有明确指定不需要服务发现,则开启
if addr == ":0" || strings.HasSuffix(addr, ":0") { if addr == ":0" || strings.HasSuffix(addr, ":0") {
ws.useDiscover = true ws.useDiscover = true
} }
h2s := &http2.Server{} h2s := &http2.Server{}
var handler http.Handler = &RouteHandler{} var handler http.Handler = &RouteHandler{ws: ws}
if protocol == "h2c" { if protocol == "h2c" {
handler = h2c.NewHandler(handler, h2s) handler = h2c.NewHandler(handler, h2s)
} }
ws.server = &http.Server{ ws.server = &http.Server{
Handler: handler, Handler: handler,
ReadTimeout: time.Duration(Config.ReadTimeout) * time.Millisecond, ReadTimeout: time.Duration(ws.Config.ReadTimeout) * time.Millisecond,
ReadHeaderTimeout: time.Duration(Config.ReadHeaderTimeout) * time.Millisecond, ReadHeaderTimeout: time.Duration(ws.Config.ReadHeaderTimeout) * time.Millisecond,
WriteTimeout: time.Duration(Config.WriteTimeout) * time.Millisecond, WriteTimeout: time.Duration(ws.Config.WriteTimeout) * time.Millisecond,
IdleTimeout: time.Duration(Config.IdleTimeout) * time.Millisecond, IdleTimeout: time.Duration(ws.Config.IdleTimeout) * time.Millisecond,
MaxHeaderBytes: Config.MaxHeaderBytes, MaxHeaderBytes: ws.Config.MaxHeaderBytes,
} }
// 启动服务发现
if ws.useDiscover { if ws.useDiscover {
_, port, _ := net.SplitHostPort(ws.Addr) _, port, _ := net.SplitHostPort(ws.Addr)
ip := GetServerIp() ip := GetServerIp()
discoverAddr := fmt.Sprintf("%s:%s", ip, port) discoverAddr := fmt.Sprintf("%s:%s", ip, port)
// 转换配置
discConf := discover.Config{ discConf := discover.Config{
Weight: Config.Weight, Weight: ws.Config.Weight,
CallRetryTimes: 10, // Default CallRetryTimes: 10,
Calls: make(map[string]discover.CallConfig), Calls: make(map[string]discover.CallConfig),
} }
@ -144,15 +357,15 @@ func (ws *WebServer) Start(ctx context.Context, logger *log.Logger) error {
discConf.Weight = 100 discConf.Weight = 100
} }
for name, call := range Config.Calls { for name, call := range ws.Config.Calls {
dc := discover.CallConfig{ dc := discover.CallConfig{
Http2: call.Http2, Http2: call.Http2,
SSL: call.SSL, SSL: call.SSL,
} }
if call.Timeout > 0 { if call.Timeout > 0 {
dc.Timeout = time.Duration(call.Timeout) * time.Millisecond dc.Timeout = time.Duration(call.Timeout) * time.Millisecond
} else if Config.RedirectTimeout > 0 { } else if ws.Config.RedirectTimeout > 0 {
dc.Timeout = time.Duration(Config.RedirectTimeout) * time.Millisecond dc.Timeout = time.Duration(ws.Config.RedirectTimeout) * time.Millisecond
} }
if call.Token != "" { if call.Token != "" {
dc.Token = safe.NewSafeBuf([]byte(call.Token)) dc.Token = safe.NewSafeBuf([]byte(call.Token))
@ -160,19 +373,16 @@ func (ws *WebServer) Start(ctx context.Context, logger *log.Logger) error {
discConf.Calls[name] = dc discConf.Calls[name] = dc
} }
// 解析必需的 Register支持环境变量 fallback registry := ws.Config.Register
registry := Config.Register
if registry == "" { if registry == "" {
registry = os.Getenv("DISCOVER_REGISTRY") registry = os.Getenv("DISCOVER_REGISTRY")
} }
if registry == "" {
registry = "127.0.0.1:6379::15" // Default fallback
}
ws.discoverer = discover.Start(registry, appName, discoverAddr, logger, discConf) if registry != "" {
GlobalDiscoverer = ws.discoverer ws.discoverer = discover.Start(registry, appName, discoverAddr, logger, discConf)
if ws.discoverer != nil { if ws.discoverer != nil {
logger.Info("discover registered", "app", appName, "addr", discoverAddr) logger.Info("discover registered", "app", appName, "addr", discoverAddr)
}
} }
} }
@ -180,13 +390,13 @@ func (ws *WebServer) Start(ctx context.Context, logger *log.Logger) error {
go func() { go func() {
logger.Info("service starting", "addr", ws.Addr, "proto", protocol) logger.Info("service starting", "addr", ws.Addr, "proto", protocol)
if err := ws.server.Serve(listener); err != nil && err != http.ErrServerClosed { ws.running = true
if err := ws.server.Serve(ws.listener); err != nil && err != http.ErrServerClosed {
errChan <- err errChan <- err
} }
close(errChan) close(errChan)
}() }()
// 短暂等待验证是否闪退
select { select {
case err := <-errChan: case err := <-errChan:
if err != nil { if err != nil {
@ -199,12 +409,13 @@ func (ws *WebServer) Start(ctx context.Context, logger *log.Logger) error {
} }
// Stop 停止服务,实现 starter.Service 接口 // Stop 停止服务,实现 starter.Service 接口
func (ws *WebServer) Stop(ctx context.Context) error { func (ws *webServer) Stop(ctx context.Context) error {
logger := ws.logger logger := ws.logger
if logger == nil { if logger == nil {
logger = log.DefaultLogger logger = log.DefaultLogger
} }
logger.Info("service stopping") logger.Info("service stopping")
ws.running = false
if ws.discoverer != nil { if ws.discoverer != nil {
ws.discoverer.Stop() ws.discoverer.Stop()
} }
@ -218,52 +429,49 @@ func (ws *WebServer) Stop(ctx context.Context) error {
} }
// Status 检查服务健康状态,实现 starter.Service 接口 // Status 检查服务健康状态,实现 starter.Service 接口
func (ws *WebServer) Status() (string, error) { func (ws *webServer) Status() (string, error) {
if ws.server == nil { if ws.server == nil || !ws.running {
return "", fmt.Errorf("server is not running") return "", fmt.Errorf("server is not running")
} }
return ws.Addr, nil return ws.Addr, nil
} }
// Reload 实现配置重新加载,实现 starter.Reloader 接口 // Reload 实现配置重新加载,实现 starter.Reloader 接口
func (ws *WebServer) Reload() error { func (ws *webServer) Reload() error {
logger := ws.logger logger := ws.logger
if logger == nil { if logger == nil {
logger = log.DefaultLogger logger = log.DefaultLogger
} }
logger.Info("reloading configurations...") logger.Info("reloading configurations...")
// 重新加载配置文件中的策略 if err := config.Load(&ws.Config, "service"); err != nil {
if err := config.Load(&Config, "service"); err != nil {
logger.Error("failed to load config during reload", "error", err.Error()) logger.Error("failed to load config during reload", "error", err.Error())
} }
ApplyConfig() ws.ApplyConfig()
// 触发业务挂载的 Hook return ws.triggerReload()
return triggerReload()
} }
// AsyncServer 兼容旧版异步服务实例 // AsyncServer 兼容旧版异步服务实例
type AsyncServer struct { type AsyncServer struct {
*WebServer *webServer
} }
// Stop 兼容旧版的无参数停止方法 // Stop 兼容旧版的无参数停止方法
func (as *AsyncServer) Stop() { func (as *AsyncServer) Stop() {
stopTimeout := time.Duration(Config.StopTimeout) * time.Millisecond stopTimeout := time.Duration(as.Config.StopTimeout) * time.Millisecond
if stopTimeout <= 0 { if stopTimeout <= 0 {
stopTimeout = 5 * time.Second stopTimeout = 5 * time.Second
} }
ctx, cancel := context.WithTimeout(context.Background(), stopTimeout) ctx, cancel := context.WithTimeout(context.Background(), stopTimeout)
defer cancel() defer cancel()
_ = as.WebServer.Stop(ctx) _ = as.webServer.Stop(ctx)
} }
// AsyncStart 兼容旧版的异步启动方法 // AsyncStart 兼容旧版的异步启动方法
func AsyncStart() *AsyncServer { func AsyncStart() *AsyncServer {
ws := NewWebServer() _ = DefaultServer.Start(context.Background(), log.DefaultLogger)
_ = ws.Start(context.Background(), log.DefaultLogger) return &AsyncServer{webServer: DefaultServer}
return &AsyncServer{WebServer: ws}
} }
// Wait 等待服务结束 (兼容旧版,直接阻塞) // Wait 等待服务结束 (兼容旧版,直接阻塞)
@ -271,12 +479,16 @@ func (as *AsyncServer) Wait() {
select {} select {}
} }
var startOnce sync.Once
// Start 兼容旧版的同步启动方法 (通过内部注册 starter 实现) // Start 兼容旧版的同步启动方法 (通过内部注册 starter 实现)
func Start() { func Start() {
stopTimeout := time.Duration(Config.StopTimeout) * time.Millisecond startOnce.Do(func() {
if stopTimeout <= 0 { stopTimeout := time.Duration(Config.StopTimeout) * time.Millisecond
stopTimeout = 5 * time.Second if stopTimeout <= 0 {
} stopTimeout = 5 * time.Second
starter.Register("web-server", NewWebServer(), 100, 5*time.Second, stopTimeout) }
starter.Run() starter.Register("web-server", DefaultServer, 100, 5*time.Second, stopTimeout)
starter.Run()
})
} }

View File

@ -6,7 +6,6 @@ import (
"reflect" "reflect"
"regexp" "regexp"
"strings" "strings"
"sync"
) )
// webServiceType 内部存储的服务元数据 // webServiceType 内部存储的服务元数据
@ -17,7 +16,7 @@ type webServiceType struct {
path string path string
pathMatcher *regexp.Regexp pathMatcher *regexp.Regexp
pathArgs []string pathArgs []string
paramsNum int paramsNum int
inType reflect.Type inType reflect.Type
inIndex int inIndex int
headersType reflect.Type headersType reflect.Type
@ -54,121 +53,116 @@ type websocketServiceType struct {
options WebServiceOptions options WebServiceOptions
} }
var (
serverId string
serverAddr string
serverProto = "http"
serverProtoName = "http"
running = false
// webServices 按 Host 隔离: map[host]map[method+path]*webServiceType
webServices = make(map[string]map[string]*webServiceType)
// regexWebServices 按 Host 隔离: map[host][]*webServiceType
regexWebServices = make(map[string][]*webServiceType)
webServicesLock = sync.RWMutex{}
webServicesList = make([]*webServiceType, 0)
websocketServices = make(map[string]map[string]*websocketServiceType)
websocketServicesLock = sync.RWMutex{}
websocketServicesList = make([]*websocketServiceType, 0)
// Rewrite 与 Proxy 按 Host 隔离 (编译后的最终路由)
hostRewrites = make(map[string][]*rewriteType)
hostProxies = make(map[string][]*proxyType)
// 按来源隔离的策略,避免互相覆盖
codeProxies = make(map[string][]*proxyType)
fileProxies = make(map[string][]*proxyType)
dynamicProxies = make(map[string][]*proxyType)
codeRewrites = make(map[string][]*rewriteType)
fileRewrites = make(map[string][]*rewriteType)
dynamicRewrites = make(map[string][]*rewriteType)
hostPoliciesLock = sync.RWMutex{}
// 过滤器与拦截器
inFilters = make([]func(*map[string]any, *Request, *Response, *log.Logger) any, 0)
outFilters = make([]func(map[string]any, *Request, *Response, any, *log.Logger) (any, bool), 0)
errorHandle func(any, *Request, *Response) any
webAuthChecker func(int, *log.Logger, *string, map[string]any, *Request, *Response, *WebServiceOptions) (pass bool, object any)
webAuthCheckers = make(map[int]func(int, *log.Logger, *string, map[string]any, *Request, *Response, *WebServiceOptions) (pass bool, object any))
// 注入点
injectObjects = make(map[reflect.Type]any)
injectFunctions = make(map[reflect.Type]func() any)
usedDeviceIdKey string
usedClientAppKey string
usedSessionIdKey string
sessionIdMaker func() string
)
// SetClientKeys 设置客户端标识相关的 Key 映射 // SetClientKeys 设置客户端标识相关的 Key 映射
func SetClientKeys(deviceIdKey, clientAppKey, sessionIdKey string) { func SetClientKeys(deviceIdKey, clientAppKey, sessionIdKey string) {
usedDeviceIdKey = deviceIdKey DefaultServer.SetClientKeys(deviceIdKey, clientAppKey, sessionIdKey)
usedClientAppKey = clientAppKey }
usedSessionIdKey = sessionIdKey
func (ws *webServer) SetClientKeys(deviceIdKey, clientAppKey, sessionIdKey string) {
ws.usedDeviceIdKey = deviceIdKey
ws.usedClientAppKey = clientAppKey
ws.usedSessionIdKey = sessionIdKey
} }
// SetSessionIdMaker 设置自定义会话 ID 生成器 // SetSessionIdMaker 设置自定义会话 ID 生成器
func SetSessionIdMaker(maker func() string) { func SetSessionIdMaker(maker func() string) {
sessionIdMaker = maker DefaultServer.SetSessionIdMaker(maker)
}
func (ws *webServer) SetSessionIdMaker(maker func() string) {
ws.sessionIdMaker = maker
} }
// SetAuthChecker 设置全局鉴权器 // SetAuthChecker 设置全局鉴权器
func SetAuthChecker(authChecker func(authLevel int, logger *log.Logger, url *string, in map[string]any, request *Request, response *Response, options *WebServiceOptions) (pass bool, object any)) { func SetAuthChecker(authChecker func(authLevel int, logger *log.Logger, url *string, in map[string]any, request *Request, response *Response, options *WebServiceOptions) (pass bool, object any)) {
webAuthChecker = authChecker DefaultServer.SetAuthChecker(authChecker)
}
func (ws *webServer) SetAuthChecker(authChecker func(authLevel int, logger *log.Logger, url *string, in map[string]any, request *Request, response *Response, options *WebServiceOptions) (pass bool, object any)) {
ws.webAuthChecker = authChecker
} }
// AddAuthChecker 为指定级别添加鉴权器 // AddAuthChecker 为指定级别添加鉴权器
func AddAuthChecker(authLevels []int, authChecker func(authLevel int, logger *log.Logger, url *string, in map[string]any, request *Request, response *Response, options *WebServiceOptions) (pass bool, object any)) { func AddAuthChecker(authLevels []int, authChecker func(authLevel int, logger *log.Logger, url *string, in map[string]any, request *Request, response *Response, options *WebServiceOptions) (pass bool, object any)) {
DefaultServer.AddAuthChecker(authLevels, authChecker)
}
func (ws *webServer) AddAuthChecker(authLevels []int, authChecker func(authLevel int, logger *log.Logger, url *string, in map[string]any, request *Request, response *Response, options *WebServiceOptions) (pass bool, object any)) {
for _, al := range authLevels { for _, al := range authLevels {
webAuthCheckers[al] = authChecker ws.webAuthCheckers[al] = authChecker
} }
} }
// SetInFilter 设置前置过滤器 // SetInFilter 设置前置过滤器
func SetInFilter(filter func(in *map[string]any, request *Request, response *Response, logger *log.Logger) (out any)) { func SetInFilter(filter func(in *map[string]any, request *Request, response *Response, logger *log.Logger) (out any)) {
inFilters = append(inFilters, filter) DefaultServer.SetInFilter(filter)
}
func (ws *webServer) SetInFilter(filter func(in *map[string]any, request *Request, response *Response, logger *log.Logger) (out any)) {
ws.inFilters = append(ws.inFilters, filter)
} }
// SetOutFilter 设置后置过滤器 // SetOutFilter 设置后置过滤器
func SetOutFilter(filter func(in map[string]any, request *Request, response *Response, out any, logger *log.Logger) (newOut any, isOver bool)) { func SetOutFilter(filter func(in map[string]any, request *Request, response *Response, out any, logger *log.Logger) (newOut any, isOver bool)) {
outFilters = append(outFilters, filter) DefaultServer.SetOutFilter(filter)
}
func (ws *webServer) SetOutFilter(filter func(in map[string]any, request *Request, response *Response, out any, logger *log.Logger) (newOut any, isOver bool)) {
ws.outFilters = append(ws.outFilters, filter)
} }
// HostContext 提供流式服务注册能力 // HostContext 提供流式服务注册能力
type HostContext struct { type HostContext struct {
ws *webServer
host string host string
} }
// Host 指定服务运行的 Host (支持 "example.com", ":8080", "example.com:8080", "*") // Host 指定服务运行的 Host (支持 "example.com", ":8080", "example.com:8080", "*")
func Host(host string) *HostContext { func Host(host string) *HostContext {
return DefaultServer.Host(host)
}
func (ws *webServer) Host(host string) *HostContext {
if host == "" { if host == "" {
host = "*" host = "*"
} }
return &HostContext{host: host} return &HostContext{ws: ws, host: host}
} }
// Register 注册一个 Web 服务 (使用默认 Host "*") // Register 注册一个 Web 服务 (使用默认 Host "*")
func Register(method, path string, serviceFunc any) *webServiceType { func Register(method, path string, serviceFunc any) *webServiceType {
return Host("*").Register(method, path, serviceFunc) return DefaultServer.Register(method, path, serviceFunc)
}
func (ws *webServer) Register(method, path string, serviceFunc any) *webServiceType {
return ws.Host("*").Register(method, path, serviceFunc)
} }
// RegisterWebsocket 注册一个 WebSocket 服务 (使用默认 Host "*") // RegisterWebsocket 注册一个 WebSocket 服务 (使用默认 Host "*")
func RegisterWebsocket(path string, serviceFunc any) *websocketServiceType { func RegisterWebsocket(path string, serviceFunc any) *websocketServiceType {
return Host("*").WebSocket(path, serviceFunc) return DefaultServer.RegisterWebsocket(path, serviceFunc)
}
func (ws *webServer) RegisterWebsocket(path string, serviceFunc any) *websocketServiceType {
return ws.Host("*").WebSocket(path, serviceFunc)
} }
// Proxy 注册一个代理转发 (使用默认 Host "*") // Proxy 注册一个代理转发 (使用默认 Host "*")
func Proxy(authLevel int, path string, to string) { func Proxy(authLevel int, path string, to string) {
Host("*").Proxy(authLevel, path, to) DefaultServer.Proxy(authLevel, path, to)
}
func (ws *webServer) Proxy(authLevel int, path string, to string) {
ws.Host("*").Proxy(authLevel, path, to)
} }
// Restful 注册一个符合 RESTful 规范的服务结构体 (使用默认 Host "*") // Restful 注册一个符合 RESTful 规范的服务结构体 (使用默认 Host "*")
func Restful(authLevel int, path string, serviceStruct any) { func Restful(authLevel int, path string, serviceStruct any) {
Host("*").Restful(authLevel, path, serviceStruct) DefaultServer.Restful(authLevel, path, serviceStruct)
}
func (ws *webServer) Restful(authLevel int, path string, serviceStruct any) {
ws.Host("*").Restful(authLevel, path, serviceStruct)
} }
func (hc *HostContext) Register(method, path string, serviceFunc any) *webServiceType { func (hc *HostContext) Register(method, path string, serviceFunc any) *webServiceType {
@ -188,25 +182,27 @@ func (hc *HostContext) Register(method, path string, serviceFunc any) *webServic
finds := finder.FindAllStringSubmatch(path, 20) finds := finder.FindAllStringSubmatch(path, 20)
for _, found := range finds { for _, found := range finds {
keyName = strings.Replace(keyName, regexp.QuoteMeta(found[0]), "(.*?)", 1) keyName = strings.Replace(keyName, regexp.QuoteMeta(found[0]), "(.*?)", 1)
hc.ws.webServicesLock.Lock() // Fixed: use lock from ws
s.pathArgs = append(s.pathArgs, found[1]) s.pathArgs = append(s.pathArgs, found[1])
hc.ws.webServicesLock.Unlock()
} }
if len(s.pathArgs) > 0 { if len(s.pathArgs) > 0 {
s.pathMatcher, _ = regexp.Compile("^" + keyName + "$") s.pathMatcher, _ = regexp.Compile("^" + keyName + "$")
} }
} }
webServicesLock.Lock() hc.ws.webServicesLock.Lock()
defer webServicesLock.Unlock() defer hc.ws.webServicesLock.Unlock()
if s.pathMatcher == nil { if s.pathMatcher == nil {
if webServices[s.host] == nil { if hc.ws.webServices[s.host] == nil {
webServices[s.host] = make(map[string]*webServiceType) hc.ws.webServices[s.host] = make(map[string]*webServiceType)
} }
webServices[s.host][s.method+s.path] = s hc.ws.webServices[s.host][s.method+s.path] = s
} else { } else {
regexWebServices[s.host] = append(regexWebServices[s.host], s) hc.ws.regexWebServices[s.host] = append(hc.ws.regexWebServices[s.host], s)
} }
webServicesList = append(webServicesList, s) hc.ws.webServicesList = append(hc.ws.webServicesList, s)
return s return s
} }
@ -301,13 +297,13 @@ func (hc *HostContext) WebSocket(path string, serviceFunc any) *websocketService
funcValue: reflect.ValueOf(serviceFunc), funcValue: reflect.ValueOf(serviceFunc),
} }
websocketServicesLock.Lock() hc.ws.websocketServicesLock.Lock()
defer websocketServicesLock.Unlock() defer hc.ws.websocketServicesLock.Unlock()
if websocketServices[hc.host] == nil { if hc.ws.websocketServices[hc.host] == nil {
websocketServices[hc.host] = make(map[string]*websocketServiceType) hc.ws.websocketServices[hc.host] = make(map[string]*websocketServiceType)
} }
websocketServices[hc.host][path] = ws hc.ws.websocketServices[hc.host][path] = ws
websocketServicesList = append(websocketServicesList, ws) hc.ws.websocketServicesList = append(hc.ws.websocketServicesList, ws)
return ws return ws
} }
@ -407,7 +403,7 @@ func makeCachedService(matchedService any) (*webServiceType, error) {
} }
targetService := &webServiceType{ targetService := &webServiceType{
paramsNum: funcType.NumIn(), paramsNum: funcType.NumIn(),
inIndex: -1, inIndex: -1,
headersIndex: -1, headersIndex: -1,
requestIndex: -1, requestIndex: -1,
@ -452,10 +448,14 @@ func makeCachedService(matchedService any) (*webServiceType, error) {
// GetInject 获取注入对象 // GetInject 获取注入对象
func GetInject(dataType reflect.Type) any { func GetInject(dataType reflect.Type) any {
if obj, exists := injectObjects[dataType]; exists { return DefaultServer.GetInject(dataType)
}
func (ws *webServer) GetInject(dataType reflect.Type) any {
if obj, exists := ws.injectObjects[dataType]; exists {
return obj return obj
} }
if factory, exists := injectFunctions[dataType]; exists { if factory, exists := ws.injectFunctions[dataType]; exists {
return factory() return factory()
} }
return nil return nil
@ -465,7 +465,7 @@ func GetInject(dataType reflect.Type) any {
func GetInjectT[T any]() T { func GetInjectT[T any]() T {
var zero T var zero T
t := reflect.TypeOf((*T)(nil)).Elem() t := reflect.TypeOf((*T)(nil)).Elem()
obj := GetInject(t) obj := DefaultServer.GetInject(t)
if obj == nil { if obj == nil {
return zero return zero
} }

View File

@ -12,9 +12,9 @@ func TestServiceRegister(t *testing.T) {
Host("*").Register("*", "/test", handler).Auth(0).Memo("test service") Host("*").Register("*", "/test", handler).Auth(0).Memo("test service")
webServicesLock.RLock() DefaultServer.webServicesLock.RLock()
s := webServices["*"]["*/test"] s := DefaultServer.webServices["*"]["*/test"]
webServicesLock.RUnlock() DefaultServer.webServicesLock.RUnlock()
if s == nil { if s == nil {
t.Fatal("Service not registered") t.Fatal("Service not registered")
@ -35,9 +35,9 @@ func TestRegexServiceRegister(t *testing.T) {
Host("*").Register("*", "/user/{id}", handler).Auth(0).Memo("get user") Host("*").Register("*", "/user/{id}", handler).Auth(0).Memo("get user")
webServicesLock.RLock() DefaultServer.webServicesLock.RLock()
found := false found := false
for _, services := range regexWebServices { for _, services := range DefaultServer.regexWebServices {
for _, s := range services { for _, s := range services {
if s.path == "/user/{id}" { if s.path == "/user/{id}" {
found = true found = true
@ -51,7 +51,7 @@ func TestRegexServiceRegister(t *testing.T) {
break break
} }
} }
webServicesLock.RUnlock() DefaultServer.webServicesLock.RUnlock()
if !found { if !found {
t.Fatal("Regex service not registered") t.Fatal("Regex service not registered")

View File

@ -72,7 +72,7 @@ func TestSessionInjection(t *testing.T) {
} }
Host("*").GET("/test-session", handler) Host("*").GET("/test-session", handler)
rh := &RouteHandler{} rh := &RouteHandler{ws: DefaultServer}
req := httptest.NewRequest("GET", "/test-session", nil) req := httptest.NewRequest("GET", "/test-session", nil)
req.Header.Set("sessid", "sess_123") req.Header.Set("sessid", "sess_123")
w := httptest.NewRecorder() w := httptest.NewRecorder()
@ -110,7 +110,7 @@ func TestCustomAuthInjection(t *testing.T) {
} }
Host("*").GET("/test-auth", handler).Auth(10) Host("*").GET("/test-auth", handler).Auth(10)
rh := &RouteHandler{} rh := &RouteHandler{ws: DefaultServer}
req := httptest.NewRequest("GET", "/test-auth", nil) req := httptest.NewRequest("GET", "/test-auth", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
@ -132,7 +132,7 @@ func TestAutomaticAuthLevelCheck(t *testing.T) {
} }
Host("*").GET("/test-auto-auth", handler).Auth(1) Host("*").GET("/test-auto-auth", handler).Auth(1)
rh := &RouteHandler{} rh := &RouteHandler{ws: DefaultServer}
// 1. 无 Session 或 AuthLevel=0 时应失败 // 1. 无 Session 或 AuthLevel=0 时应失败
req1 := httptest.NewRequest("GET", "/test-auto-auth", nil) req1 := httptest.NewRequest("GET", "/test-auto-auth", nil)

View File

@ -7,56 +7,56 @@ import (
"net/http" "net/http"
"path/filepath" "path/filepath"
"strings" "strings"
"sync"
"time" "time"
) )
var (
statics = make(map[string]*string)
staticsByHost = make(map[string]map[string]*string)
codeStatics = make(map[string]map[string]*string)
fileStatics = make(map[string]map[string]*string)
dynamicStatics = make(map[string]map[string]*string)
staticsByHostLock = sync.RWMutex{}
)
// Static 注册静态文件目录 // Static 注册静态文件目录
func (hc *HostContext) Static(path, rootPath string) *HostContext { func (hc *HostContext) Static(path, rootPath string) *HostContext {
host := hc.host host := hc.host
if host == "*" { if host == "*" {
host = "" host = ""
} }
StaticByHost(path, rootPath, host) hc.ws.StaticByHost(path, rootPath, host)
return hc return hc
} }
// Static 注册静态文件目录 (使用默认 Host "*") // Static 注册静态文件目录 (使用默认 Host "*")
func Static(path, rootPath string) { func Static(path, rootPath string) {
Host("*").Static(path, rootPath) DefaultServer.Static(path, rootPath)
}
func (ws *webServer) Static(path, rootPath string) {
ws.Host("*").Static(path, rootPath)
} }
// StaticByHost 为指定域名注册静态文件目录 // StaticByHost 为指定域名注册静态文件目录
func StaticByHost(path, rootPath, host string) { func StaticByHost(path, rootPath, host string) {
DefaultServer.StaticByHost(path, rootPath, host)
}
func (ws *webServer) StaticByHost(path, rootPath, host string) {
if !filepath.IsAbs(rootPath) { if !filepath.IsAbs(rootPath) {
if absPath, err := filepath.Abs(rootPath); err == nil { if absPath, err := filepath.Abs(rootPath); err == nil {
rootPath = absPath rootPath = absPath
} }
} }
staticsByHostLock.Lock() ws.staticsByHostLock.Lock()
defer staticsByHostLock.Unlock() defer ws.staticsByHostLock.Unlock()
if codeStatics[host] == nil { if ws.codeStatics[host] == nil {
codeStatics[host] = make(map[string]*string) ws.codeStatics[host] = make(map[string]*string)
} }
codeStatics[host][path] = &rootPath ws.codeStatics[host][path] = &rootPath
rebuildStaticsUnderLock(host) ws.rebuildStaticsUnderLock(host)
} }
// ReplaceStatics 使用 Copy-on-Write 机制原子地替换指定 host 下的动态静态目录规则 // ReplaceStatics 使用 Copy-on-Write 机制原子地替换指定 host 下的动态静态目录规则
func ReplaceStatics(host string, config map[string]string) { func ReplaceStatics(host string, config map[string]string) {
DefaultServer.ReplaceStatics(host, config)
}
func (ws *webServer) ReplaceStatics(host string, config map[string]string) {
newStatics := make(map[string]*string, len(config)) newStatics := make(map[string]*string, len(config))
for path, rootPath := range config { for path, rootPath := range config {
rp := rootPath rp := rootPath
@ -68,50 +68,29 @@ func ReplaceStatics(host string, config map[string]string) {
newStatics[path] = &rp newStatics[path] = &rp
} }
staticsByHostLock.Lock() ws.staticsByHostLock.Lock()
defer staticsByHostLock.Unlock() defer ws.staticsByHostLock.Unlock()
dynamicStatics[host] = newStatics ws.dynamicStatics[host] = newStatics
rebuildStaticsUnderLock(host) ws.rebuildStaticsUnderLock(host)
} }
func rebuildStaticsUnderLock(host string) { func (ws *webServer) getStaticFilePath(requestPath, host string) string {
combined := make(map[string]*string) ws.staticsByHostLock.RLock()
defer ws.staticsByHostLock.RUnlock()
// 合并三种来源的静态路由
for k, v := range codeStatics[host] {
combined[k] = v
}
for k, v := range fileStatics[host] {
combined[k] = v
}
for k, v := range dynamicStatics[host] {
combined[k] = v
}
if host == "" {
statics = combined
} else {
staticsByHost[host] = combined
}
}
func getStaticFilePath(requestPath, host string) string {
staticsByHostLock.RLock()
defer staticsByHostLock.RUnlock()
// 优先匹配指定域名的配置 // 优先匹配指定域名的配置
if hostConfig, exists := staticsByHost[host]; exists { if hostConfig, exists := ws.staticsByHost[host]; exists {
if filePath := findMatchedPath(hostConfig, requestPath); filePath != "" { if filePath := ws.findMatchedPath(hostConfig, requestPath); filePath != "" {
return filePath return filePath
} }
} }
// 匹配全局配置 // 匹配全局配置
return findMatchedPath(statics, requestPath) return ws.findMatchedPath(ws.statics, requestPath)
} }
func findMatchedPath(config map[string]*string, requestPath string) string { func (ws *webServer) findMatchedPath(config map[string]*string, requestPath string) string {
for urlPath, rootPath := range config { for urlPath, rootPath := range config {
if strings.HasPrefix(requestPath, urlPath) { if strings.HasPrefix(requestPath, urlPath) {
return filepath.Join(*rootPath, requestPath[len(urlPath):]) return filepath.Join(*rootPath, requestPath[len(urlPath):])
@ -120,8 +99,8 @@ func findMatchedPath(config map[string]*string, requestPath string) string {
return "" return ""
} }
func processStatic(requestPath string, request *Request, response *Response, logger *log.Logger) bool { func (ws *webServer) processStatic(requestPath string, request *Request, response *Response, logger *log.Logger) bool {
filePath := getStaticFilePath(requestPath, request.Host) filePath := ws.getStaticFilePath(requestPath, request.Host)
if filePath == "" { if filePath == "" {
return false return false
} }
@ -133,7 +112,12 @@ func processStatic(requestPath string, request *Request, response *Response, log
if info.IsDir { if info.IsDir {
// 自动查找索引文件 // 自动查找索引文件
for _, indexFile := range Config.IndexFiles { indexFiles := ws.Config.IndexFiles
if len(indexFiles) == 0 {
indexFiles = []string{"index.html", "index.htm"}
}
for _, indexFile := range indexFiles {
f := filepath.Join(filePath, indexFile) f := filepath.Join(filePath, indexFile)
if i := file.GetFileInfo(f); i != nil && !i.IsDir { if i := file.GetFileInfo(f); i != nil && !i.IsDir {
filePath = f filePath = f

View File

@ -19,7 +19,7 @@ func TestStaticService(t *testing.T) {
// 注册静态目录 // 注册静态目录
Static("/ui", tempDir) Static("/ui", tempDir)
rh := &RouteHandler{} rh := &RouteHandler{ws: DefaultServer}
// 测试成功访问 // 测试成功访问
req := httptest.NewRequest("GET", "/ui/index.html", nil) req := httptest.NewRequest("GET", "/ui/index.html", nil)
@ -52,7 +52,7 @@ func TestHostStaticService(t *testing.T) {
// 注册域名特定静态文件服务 // 注册域名特定静态文件服务
Host("example.com").Static("/host-ui", tempDir) Host("example.com").Static("/host-ui", tempDir)
rh := &RouteHandler{} rh := &RouteHandler{ws: DefaultServer}
// 1. 匹配域名访问 // 1. 匹配域名访问
req1 := httptest.NewRequest("GET", "/host-ui/index.html", nil) req1 := httptest.NewRequest("GET", "/host-ui/index.html", nil)
@ -77,4 +77,3 @@ func TestHostStaticService(t *testing.T) {
t.Errorf("Expected 404 for mismatched host, got %d", w2.Code) t.Errorf("Expected 404 for mismatched host, got %d", w2.Code)
} }
} }

View File

@ -59,7 +59,7 @@ func Upgrade(response *Response, request *Request) (*WebSocketConn, error) {
return &WebSocketConn{Conn: conn}, nil return &WebSocketConn{Conn: conn}, nil
} }
func doWebsocketService(ws *websocketServiceType, request *Request, response *Response, logger *log.Logger, object any) { func (ws *webServer) doWebsocketService(wsc *websocketServiceType, request *Request, response *Response, logger *log.Logger, object any) {
wsConn, err := Upgrade(response, request) wsConn, err := Upgrade(response, request)
if err != nil { if err != nil {
logger.Error("websocket upgrade failed", "error", err.Error()) logger.Error("websocket upgrade failed", "error", err.Error())
@ -68,9 +68,9 @@ func doWebsocketService(ws *websocketServiceType, request *Request, response *Re
defer wsConn.Close() defer wsConn.Close()
// 调用业务处理函数,注入依赖 // 调用业务处理函数,注入依赖
params := make([]reflect.Value, ws.funcType.NumIn()) params := make([]reflect.Value, wsc.funcType.NumIn())
for i := 0; i < len(params); i++ { for i := 0; i < len(params); i++ {
t := ws.funcType.In(i) t := wsc.funcType.In(i)
if t == reflect.TypeOf(request) { if t == reflect.TypeOf(request) {
params[i] = reflect.ValueOf(request) params[i] = reflect.ValueOf(request)
} else if t == reflect.TypeOf(logger) { } else if t == reflect.TypeOf(logger) {
@ -81,11 +81,11 @@ func doWebsocketService(ws *websocketServiceType, request *Request, response *Re
params[i] = reflect.ValueOf(wsConn.Conn) params[i] = reflect.ValueOf(wsConn.Conn)
} else if object != nil && reflect.TypeOf(object).AssignableTo(t) { } else if object != nil && reflect.TypeOf(object).AssignableTo(t) {
params[i] = reflect.ValueOf(object) params[i] = reflect.ValueOf(object)
} else if obj := GetInject(t); obj != nil { } else if obj := ws.GetInject(t); obj != nil {
params[i] = reflect.ValueOf(obj) params[i] = reflect.ValueOf(obj)
} else { } else {
params[i] = reflect.New(t).Elem() params[i] = reflect.New(t).Elem()
} }
} }
ws.funcValue.Call(params) wsc.funcValue.Call(params)
} }

View File

@ -20,7 +20,7 @@ func TestWebSocketService(t *testing.T) {
}).Auth(0).Memo("test websocket") }).Auth(0).Memo("test websocket")
// 启动测试服务器 // 启动测试服务器
server := httptest.NewServer(&RouteHandler{}) server := httptest.NewServer(&RouteHandler{ws: DefaultServer})
defer server.Close() defer server.Close()
// 建立连接 // 建立连接