Compare commits
12 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
038e14f3d8 | ||
|
|
a7c08cdf26 | ||
|
|
c891d37fad | ||
|
|
fbf7e6475c | ||
|
|
582de60053 | ||
|
|
94a4be81ec | ||
|
|
ff34d11c9b | ||
|
|
fe3b420d35 | ||
|
|
44951a9ab6 | ||
|
|
31c243e406 | ||
|
|
a4f5af3338 | ||
|
|
c88139e202 |
@ -878,17 +878,6 @@
|
|||||||
"Precision": 0,
|
"Precision": 0,
|
||||||
"WithoutKey": false,
|
"WithoutKey": false,
|
||||||
"Hide": false
|
"Hide": false
|
||||||
},
|
|
||||||
{
|
|
||||||
"Index": 8,
|
|
||||||
"Name": "CallStacks",
|
|
||||||
"KeyName": "",
|
|
||||||
"AttachBefore": false,
|
|
||||||
"Color": "",
|
|
||||||
"Format": "",
|
|
||||||
"Precision": 0,
|
|
||||||
"WithoutKey": false,
|
|
||||||
"Hide": false
|
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|||||||
53
CHANGELOG.md
53
CHANGELOG.md
@ -1,6 +1,57 @@
|
|||||||
# CHANGELOG - go/service
|
# CHANGELOG - go/service
|
||||||
|
|
||||||
## v1.5.1 (2026-06-04)
|
## v1.5.12 (2026-06-07)
|
||||||
|
- **基础设施对齐: 切换至 starter v1.5.3 编排模式**:
|
||||||
|
- 弃用已废弃的 `starter.Run()`,全面转向 `starter.Start() / starter.Wait()`。
|
||||||
|
|
||||||
|
## v1.5.11 (2026-06-06)
|
||||||
|
- **修复: 路由与静态文件匹配鲁棒性增强**:
|
||||||
|
- **路径参数提取**: 修复了正则匹配路由(如 `{name}`)无法正确提取并注入路径参数到业务函数的问题。
|
||||||
|
- **静态文件匹配**: 引入 `hostStatics` 有序路由表,实现“最长前缀匹配”策略,解决在复杂或重叠的静态目录配置下的匹配歧义问题。
|
||||||
|
- **URL 兼容性**: 针对包含空格、中文字符及特殊符号的复杂 URL,在路由与静态文件匹配阶段统一进行 Robust 处理,彻底解决 404 隐患。
|
||||||
|
- **Host 匹配增强**: 验证并明确了 Host 匹配的灵活性,支持 `hostname:port`, `hostname`, `:port` 的自动降级匹配。
|
||||||
|
|
||||||
|
## v1.5.10 (2026-06-05)
|
||||||
|
- **修复: Static 服务 URL 解码**:
|
||||||
|
- 修复了 \`service.Static()\` 在处理包含空格或特殊字符(如 \`%20\`)的请求路径时,因未解码导致文件匹配失败的问题。
|
||||||
|
- 在路由匹配阶段提前进行 URL Path 解码,提升整体路径匹配的健壮性。
|
||||||
|
|
||||||
|
## v1.5.9 (2026-06-05)
|
||||||
|
- **优化: 低代码环境深度对齐**:
|
||||||
|
- **JS 友好型 Header**: 引入 \`service.Header\` 包装类,提供大小写不敏感的 \`Get/Set/Add/Del\` 方法,提升脚本开发体验。
|
||||||
|
- **Cookie 遮蔽**: 在 \`Request\` 中实现方法遮蔽,确保 JS 侧看到的 Cookie 参数均为简化的 \`Service_Cookie\`,彻底解决穿透问题。
|
||||||
|
- **API 统一**: 将 \`Request.Headers()\` 重命名为 \`Request.Header()\`,与 \`Response.Header()\` 保持命名对齐。
|
||||||
|
- **重构**: 给内部字段(如 \`*http.Request\`, \`ResponseWriter\`)增加 \`js:"-"\` 标签,精准管控对 JS 暴露的 API 边界。
|
||||||
|
- **修复**: 解决了因 Header 包装导致的 Go 内部代码(\`handler.go\`, \`static.go\`)编译错误。
|
||||||
|
- **新特性: EnableWebDev 支持**:
|
||||||
|
- 引入了 `service.EnableWebDev(config watch.Config)`,支持自动刷新页面的开发模式。
|
||||||
|
- **WebSocket 同步**: 自动注册 `/_watch` 服务,与文件监听器协同工作。
|
||||||
|
- **智能 HTML 注入**: 采用 `OutFilter` 在 HTML 响应末尾精准注入 WebSocket 重连脚本,支持静态文件与动态服务。
|
||||||
|
- **性能优化**: 仅在开启开发模式时启用响应缓冲,生产环境无任何性能损失。
|
||||||
|
- **基础设施**: 增加包级 `AddShutdownHook` 支持,提供更优雅的资源回收机制。
|
||||||
|
- **依赖同步**: 升级至 `log v1.5.5`,对齐不带堆栈的 Warning 规范。
|
||||||
|
|
||||||
|
## v1.5.5 (2026-06-05)
|
||||||
|
- **依赖同步**: 全量对齐至 `@go` 基础设施最新版本(`log v1.5.4`, `starter v1.5.2`, `db v1.5.2`)。
|
||||||
|
|
||||||
|
## v1.5.4 (2026-06-05)
|
||||||
|
- **优化: 生命周期日志剥离**:
|
||||||
|
- 彻底移除了 `WebServer` 内部冗余的 `starting / stopping / stopped` 控制台日志输出。
|
||||||
|
- **权责对齐**: 服务的生命周期审计现已全量交由 `apigo.cc/go/starter (v1.5.2)` 接管。
|
||||||
|
- **专注输出**: `service` 模块现在仅负责输出关键的监听信息(如 `starting listener addr:[::]:8001 proto:http`),并自动继承 `starter` 分配的长 TraceID 建立链路关联。
|
||||||
|
|
||||||
|
## v1.5.3 (2026-06-04)
|
||||||
|
- **新特性**:
|
||||||
|
- `Static` 静态文件服务增加默认索引文件识别:当请求目录时,若未配置 `IndexFiles`,会自动尝试匹配 `index.html` 或 `index.htm`。
|
||||||
|
- **优化**:
|
||||||
|
- 依赖升级至 `apigo.cc/go/log v1.5.3`,实现对第三方库原始日志的自动劫持与 `serverId` 注入。
|
||||||
|
- 清理了启动阶段多余的诊断日志,确保“零配置”启动时的控制台纯净度。
|
||||||
|
- **稳定性**:
|
||||||
|
- 修复了 `Start()` 函数的幂等性问题,防止在手动调用与 `starter` 框架调度冲突时产生重复绑定端口的错误。
|
||||||
|
|
||||||
|
## v1.5.2 (2026-06-04)
|
||||||
|
- **架构重构**: 彻底移除全局状态泄露,重构为严格的单例模式。
|
||||||
|
...
|
||||||
- **修复**: 在 `WebServer.Start` 中显式调用 `config.Load(&Config, "service")`,确保启动时自动从 `env.yaml` 加载 `service:` 块。
|
- **修复**: 在 `WebServer.Start` 中显式调用 `config.Load(&Config, "service")`,确保启动时自动从 `env.yaml` 加载 `service:` 块。
|
||||||
- **修复**: 优化 `WebServer.Reload` 的配置加载逻辑,确保与启动加载逻辑保持一致。
|
- **修复**: 优化 `WebServer.Reload` 的配置加载逻辑,确保与启动加载逻辑保持一致。
|
||||||
|
|
||||||
|
|||||||
85
config.go
85
config.go
@ -1,10 +1,5 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
|
||||||
"apigo.cc/go/cast"
|
|
||||||
"path/filepath"
|
|
||||||
)
|
|
||||||
|
|
||||||
// CertSet SSL 证书配置
|
// CertSet SSL 证书配置
|
||||||
type CertSet struct {
|
type CertSet struct {
|
||||||
CertFile string
|
CertFile string
|
||||||
@ -73,83 +68,3 @@ type ServiceConfig struct {
|
|||||||
Rewrites map[string]map[string]any
|
Rewrites map[string]map[string]any
|
||||||
Statics map[string]map[string]string
|
Statics map[string]map[string]string
|
||||||
}
|
}
|
||||||
|
|
||||||
var Config = ServiceConfig{}
|
|
||||||
|
|
||||||
// ApplyConfig 将 ServiceConfig 中的路由策略应用到内部的文件级策略中
|
|
||||||
func ApplyConfig() {
|
|
||||||
hostPoliciesLock.Lock()
|
|
||||||
defer hostPoliciesLock.Unlock()
|
|
||||||
|
|
||||||
// 1. Proxies KV 解析
|
|
||||||
fileProxies = make(map[string][]*proxyType)
|
|
||||||
for host, kv := range Config.Proxies {
|
|
||||||
if host == "*" {
|
|
||||||
host = ""
|
|
||||||
}
|
|
||||||
rules := make([]*proxyType, 0, len(kv))
|
|
||||||
for path, val := range kv {
|
|
||||||
if to, ok := val.(string); ok {
|
|
||||||
// 极简 KV 模式: "/api/*": "user-svc/v1/*"
|
|
||||||
rules = append(rules, parseProxyRule(0, path, "", "", to))
|
|
||||||
} else {
|
|
||||||
// 对象模式: "/api/*": {"To": "...", "Auth": 1}
|
|
||||||
m, _ := cast.ToMap[string, any](val)
|
|
||||||
rules = append(rules, parseProxyRule(
|
|
||||||
cast.Int(m["Auth"]),
|
|
||||||
path,
|
|
||||||
cast.String(m["ToApp"]),
|
|
||||||
cast.String(m["ToPath"]),
|
|
||||||
cast.String(m["To"]),
|
|
||||||
))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
fileProxies[host] = rules
|
|
||||||
rebuildProxiesUnderLock(host)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 2. Rewrites KV 解析
|
|
||||||
fileRewrites = make(map[string][]*rewriteType)
|
|
||||||
for host, kv := range Config.Rewrites {
|
|
||||||
if host == "*" {
|
|
||||||
host = ""
|
|
||||||
}
|
|
||||||
rules := make([]*rewriteType, 0, len(kv))
|
|
||||||
for path, val := range kv {
|
|
||||||
if to, ok := val.(string); ok {
|
|
||||||
rules = append(rules, parseRewriteRule(path, "", to))
|
|
||||||
} else {
|
|
||||||
m, _ := cast.ToMap[string, any](val)
|
|
||||||
rules = append(rules, parseRewriteRule(
|
|
||||||
path,
|
|
||||||
cast.String(m["ToPath"]),
|
|
||||||
cast.String(m["To"]),
|
|
||||||
))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
fileRewrites[host] = rules
|
|
||||||
rebuildRewritesUnderLock(host)
|
|
||||||
}
|
|
||||||
|
|
||||||
staticsByHostLock.Lock()
|
|
||||||
defer staticsByHostLock.Unlock()
|
|
||||||
fileStatics = make(map[string]map[string]*string)
|
|
||||||
|
|
||||||
for host, config := range Config.Statics {
|
|
||||||
if host == "*" {
|
|
||||||
host = ""
|
|
||||||
}
|
|
||||||
newStatics := make(map[string]*string, len(config))
|
|
||||||
for path, rootPath := range config {
|
|
||||||
rp := rootPath
|
|
||||||
if !filepath.IsAbs(rp) {
|
|
||||||
if absPath, err := filepath.Abs(rp); err == nil {
|
|
||||||
rp = absPath
|
|
||||||
}
|
|
||||||
}
|
|
||||||
newStatics[path] = &rp
|
|
||||||
}
|
|
||||||
fileStatics[host] = newStatics
|
|
||||||
rebuildStaticsUnderLock(host)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
38
document.go
38
document.go
@ -21,11 +21,15 @@ type Api struct {
|
|||||||
|
|
||||||
// MakeDocument 生成文档数据
|
// MakeDocument 生成文档数据
|
||||||
func MakeDocument() []Api {
|
func MakeDocument() []Api {
|
||||||
|
return DefaultServer.MakeDocument()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ws *WebServer) MakeDocument() []Api {
|
||||||
out := make([]Api, 0)
|
out := make([]Api, 0)
|
||||||
|
|
||||||
// 1. Rewrite & Proxy
|
// 1. Rewrite & Proxy
|
||||||
hostPoliciesLock.RLock()
|
ws.hostPoliciesLock.RLock()
|
||||||
for host, rewrites := range hostRewrites {
|
for host, rewrites := range ws.hostRewrites {
|
||||||
for _, a := range rewrites {
|
for _, a := range rewrites {
|
||||||
out = append(out, Api{
|
out = append(out, Api{
|
||||||
Type: "Rewrite",
|
Type: "Rewrite",
|
||||||
@ -34,7 +38,7 @@ func MakeDocument() []Api {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for host, proxies := range hostProxies {
|
for host, proxies := range ws.hostProxies {
|
||||||
for _, a := range proxies {
|
for _, a := range proxies {
|
||||||
out = append(out, Api{
|
out = append(out, Api{
|
||||||
Type: "Proxy",
|
Type: "Proxy",
|
||||||
@ -43,11 +47,11 @@ func MakeDocument() []Api {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
hostPoliciesLock.RUnlock()
|
ws.hostPoliciesLock.RUnlock()
|
||||||
|
|
||||||
// 2. Web Services
|
// 2. Web Services
|
||||||
webServicesLock.RLock()
|
ws.webServicesLock.RLock()
|
||||||
for _, a := range webServicesList {
|
for _, a := range ws.webServicesList {
|
||||||
if a.options.NoDoc {
|
if a.options.NoDoc {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -67,22 +71,22 @@ func MakeDocument() []Api {
|
|||||||
}
|
}
|
||||||
out = append(out, api)
|
out = append(out, api)
|
||||||
}
|
}
|
||||||
webServicesLock.RUnlock()
|
ws.webServicesLock.RUnlock()
|
||||||
|
|
||||||
// 4. WebSocket Services
|
// 4. WebSocket Services
|
||||||
websocketServicesLock.RLock()
|
ws.websocketServicesLock.RLock()
|
||||||
for _, ws := range websocketServicesList {
|
for _, wsc := range ws.websocketServicesList {
|
||||||
api := Api{
|
api := Api{
|
||||||
Type: "WebSocket",
|
Type: "WebSocket",
|
||||||
Path: ws.path,
|
Path: wsc.path,
|
||||||
AuthLevel: ws.authLevel,
|
AuthLevel: wsc.authLevel,
|
||||||
Memo: ws.memo,
|
Memo: wsc.memo,
|
||||||
Host: ws.host,
|
Host: wsc.host,
|
||||||
}
|
}
|
||||||
if ws.funcType != nil && ws.funcType.NumIn() > 0 {
|
if wsc.funcType != nil && wsc.funcType.NumIn() > 0 {
|
||||||
// Find struct in
|
// Find struct in
|
||||||
for i := 0; i < ws.funcType.NumIn(); i++ {
|
for i := 0; i < wsc.funcType.NumIn(); i++ {
|
||||||
t := ws.funcType.In(i)
|
t := wsc.funcType.In(i)
|
||||||
if t.Kind() == reflect.Struct {
|
if t.Kind() == reflect.Struct {
|
||||||
api.In = getType(t)
|
api.In = getType(t)
|
||||||
break
|
break
|
||||||
@ -91,7 +95,7 @@ func MakeDocument() []Api {
|
|||||||
}
|
}
|
||||||
out = append(out, api)
|
out = append(out, api)
|
||||||
}
|
}
|
||||||
websocketServicesLock.RUnlock()
|
ws.websocketServicesLock.RUnlock()
|
||||||
|
|
||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|||||||
8
go.mod
8
go.mod
@ -10,10 +10,10 @@ require (
|
|||||||
apigo.cc/go/http v1.5.0
|
apigo.cc/go/http v1.5.0
|
||||||
apigo.cc/go/id v1.5.0
|
apigo.cc/go/id v1.5.0
|
||||||
apigo.cc/go/jsmod v1.5.0
|
apigo.cc/go/jsmod v1.5.0
|
||||||
apigo.cc/go/log v1.5.0
|
apigo.cc/go/log v1.5.5
|
||||||
apigo.cc/go/redis v1.5.0
|
apigo.cc/go/redis v1.5.0
|
||||||
apigo.cc/go/safe v1.5.0
|
apigo.cc/go/safe v1.5.0
|
||||||
apigo.cc/go/starter v1.5.0
|
apigo.cc/go/starter v1.5.3
|
||||||
apigo.cc/go/timer v1.5.0
|
apigo.cc/go/timer v1.5.0
|
||||||
github.com/gorilla/websocket v1.5.3
|
github.com/gorilla/websocket v1.5.3
|
||||||
golang.org/x/net v0.54.0
|
golang.org/x/net v0.54.0
|
||||||
@ -25,8 +25,8 @@ require (
|
|||||||
apigo.cc/go/rand v1.5.0 // indirect
|
apigo.cc/go/rand v1.5.0 // indirect
|
||||||
apigo.cc/go/shell v1.5.0 // indirect
|
apigo.cc/go/shell v1.5.0 // indirect
|
||||||
github.com/gomodule/redigo v2.0.0+incompatible // indirect
|
github.com/gomodule/redigo v2.0.0+incompatible // indirect
|
||||||
golang.org/x/crypto v0.51.0 // indirect
|
golang.org/x/crypto v0.52.0 // indirect
|
||||||
golang.org/x/sys v0.44.0 // indirect
|
golang.org/x/sys v0.45.0 // indirect
|
||||||
golang.org/x/text v0.37.0 // indirect
|
golang.org/x/text v0.37.0 // indirect
|
||||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||||
)
|
)
|
||||||
|
|||||||
18
go.sum
18
go.sum
@ -1,7 +1,7 @@
|
|||||||
apigo.cc/go/cast v1.5.0 h1:UBGJtFQ8eJPMQXs37cUgqd7YQo1zI9opuSDBDmn2/pE=
|
apigo.cc/go/cast v1.5.0 h1:UBGJtFQ8eJPMQXs37cUgqd7YQo1zI9opuSDBDmn2/pE=
|
||||||
apigo.cc/go/cast v1.5.0/go.mod h1:z2GW5p5WCZGEqVVIJUdhl232vRbLf2Qu4EDlEakX/D8=
|
apigo.cc/go/cast v1.5.0/go.mod h1:z2GW5p5WCZGEqVVIJUdhl232vRbLf2Qu4EDlEakX/D8=
|
||||||
apigo.cc/go/config v1.5.0 h1:Yuz9QEb11XXG4XkhDi/ueT2M1T3Q9PElE5tiakvjehs=
|
apigo.cc/go/config v1.5.1 h1:rpj7oCzlsDV3f2/YK3Pb+CHbfr2DL5Vyyv6VNkobJP4=
|
||||||
apigo.cc/go/config v1.5.0/go.mod h1:jdMiDLPa9gzB8/FFZvm9jOopUqdxb7XSX+0OeWcZZUM=
|
apigo.cc/go/config v1.5.1/go.mod h1:jdMiDLPa9gzB8/FFZvm9jOopUqdxb7XSX+0OeWcZZUM=
|
||||||
apigo.cc/go/crypto v1.5.0 h1:Nxz7a6VKCdvaF258IU0NkjQyureOLxfR308Sy2iftUI=
|
apigo.cc/go/crypto v1.5.0 h1:Nxz7a6VKCdvaF258IU0NkjQyureOLxfR308Sy2iftUI=
|
||||||
apigo.cc/go/crypto v1.5.0/go.mod h1:F9M6nXv+5328r1ZwbTvI6fcr8VdgqHVzALOcsdv6ntE=
|
apigo.cc/go/crypto v1.5.0/go.mod h1:F9M6nXv+5328r1ZwbTvI6fcr8VdgqHVzALOcsdv6ntE=
|
||||||
apigo.cc/go/discover v1.5.0 h1:RGHulidyAHCZdGfpFytFUl3ur4aNVMXKlfJbAMCvgpo=
|
apigo.cc/go/discover v1.5.0 h1:RGHulidyAHCZdGfpFytFUl3ur4aNVMXKlfJbAMCvgpo=
|
||||||
@ -16,8 +16,7 @@ apigo.cc/go/id v1.5.0 h1:MjNWPhBhDsoXaLeJDv/0wfJmVMU9EvOs8pWYfsTQ6e8=
|
|||||||
apigo.cc/go/id v1.5.0/go.mod h1:qhu4a1/KLc/XcBpcsRu+mXZt7U7Wvd9zMcPs4VspuPA=
|
apigo.cc/go/id v1.5.0/go.mod h1:qhu4a1/KLc/XcBpcsRu+mXZt7U7Wvd9zMcPs4VspuPA=
|
||||||
apigo.cc/go/jsmod v1.5.0 h1:JgQtJNiJWy1NOP9AzE8NX5VXJkpO/x3GqLsCCSny5Ec=
|
apigo.cc/go/jsmod v1.5.0 h1:JgQtJNiJWy1NOP9AzE8NX5VXJkpO/x3GqLsCCSny5Ec=
|
||||||
apigo.cc/go/jsmod v1.5.0/go.mod h1:bmyeZtOAP/j5am+YRnaiM89smysK24K7ebk0koFtsSw=
|
apigo.cc/go/jsmod v1.5.0/go.mod h1:bmyeZtOAP/j5am+YRnaiM89smysK24K7ebk0koFtsSw=
|
||||||
apigo.cc/go/log v1.5.0 h1:kQuLLtbt33mEuc/xJVcy8NODXkso/QKSZWNclKrSpsI=
|
apigo.cc/go/log v1.5.5 h1:AFU7d7AQxkpgDHl7SnlEwd6yzGSFAlnrrjbrNDQnQHI=
|
||||||
apigo.cc/go/log v1.5.0/go.mod h1:Djy+I5aLhGB/EjwRz4KHqkVEz584IAD55FAFiIfInuo=
|
|
||||||
apigo.cc/go/rand v1.5.0 h1:1o8hh8fhdBuk1/h02IvugvamuT3dkWbVJrqEJVQKB2E=
|
apigo.cc/go/rand v1.5.0 h1:1o8hh8fhdBuk1/h02IvugvamuT3dkWbVJrqEJVQKB2E=
|
||||||
apigo.cc/go/rand v1.5.0/go.mod h1:Lh98S2dm9UY0X+M+kNQQEKyXHG5pcCKSFPyXN0QCGdk=
|
apigo.cc/go/rand v1.5.0/go.mod h1:Lh98S2dm9UY0X+M+kNQQEKyXHG5pcCKSFPyXN0QCGdk=
|
||||||
apigo.cc/go/redis v1.5.0 h1:VXNDqzKj87BchF7ubDEH+T6lp8NrjeK0izU4ooo7u1A=
|
apigo.cc/go/redis v1.5.0 h1:VXNDqzKj87BchF7ubDEH+T6lp8NrjeK0izU4ooo7u1A=
|
||||||
@ -26,8 +25,7 @@ apigo.cc/go/safe v1.5.0 h1:W1NblmcU8cex1f9Y5z8mNLUJOzZTE1s6fszb3FbhGnk=
|
|||||||
apigo.cc/go/safe v1.5.0/go.mod h1:OfQ5d6COePSGEuPvMeOk6KagX2sezw7nvKh7exj9SeM=
|
apigo.cc/go/safe v1.5.0/go.mod h1:OfQ5d6COePSGEuPvMeOk6KagX2sezw7nvKh7exj9SeM=
|
||||||
apigo.cc/go/shell v1.5.0 h1:WLDMMqUU0INeaBDmQsTPr0h/NfB2RknAtiJ5NL467+Q=
|
apigo.cc/go/shell v1.5.0 h1:WLDMMqUU0INeaBDmQsTPr0h/NfB2RknAtiJ5NL467+Q=
|
||||||
apigo.cc/go/shell v1.5.0/go.mod h1:rYHA77d5hEsQHcJrbAWf1pHy0sxayeJ0gU55LA/JWQk=
|
apigo.cc/go/shell v1.5.0/go.mod h1:rYHA77d5hEsQHcJrbAWf1pHy0sxayeJ0gU55LA/JWQk=
|
||||||
apigo.cc/go/starter v1.5.0 h1:z6wnDrGx/iM6Z+A86FbIW4Y1rNywGzPNY+y2vYQJeMw=
|
apigo.cc/go/starter v1.5.3 h1:kakDapul+l63w3Ah1pnBxD1mup9Fbt821omWCiaGwCE=
|
||||||
apigo.cc/go/starter v1.5.0/go.mod h1:ru2vVCIvBYDWZ9SmPP4JLyEueUh71Y24ww/wDvCT+Vs=
|
|
||||||
apigo.cc/go/timer v1.5.0 h1:iPo/IQn+iuhBRI1/MR1txwZnamef/RBBfOiIlBiqkgk=
|
apigo.cc/go/timer v1.5.0 h1:iPo/IQn+iuhBRI1/MR1txwZnamef/RBBfOiIlBiqkgk=
|
||||||
apigo.cc/go/timer v1.5.0/go.mod h1:kOnqTTX+zA4AH7SfC+LpUm4ZvS+DVyWWMqul/V5QWJs=
|
apigo.cc/go/timer v1.5.0/go.mod h1:kOnqTTX+zA4AH7SfC+LpUm4ZvS+DVyWWMqul/V5QWJs=
|
||||||
github.com/gomodule/redigo v2.0.0+incompatible h1:K/R+8tc58AaqLkqG2Ol3Qk+DR/TlNuhuh457pBFPtt0=
|
github.com/gomodule/redigo v2.0.0+incompatible h1:K/R+8tc58AaqLkqG2Ol3Qk+DR/TlNuhuh457pBFPtt0=
|
||||||
@ -40,12 +38,12 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
|||||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||||
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
|
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
|
||||||
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
|
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
|
||||||
golang.org/x/crypto v0.51.0 h1:IBPXwPfKxY7cWQZ38ZCIRPI50YLeevDLlLnyC5wRGTI=
|
golang.org/x/crypto v0.52.0 h1:RMs7fP2rXdep0CftQlK8Uf+kibLm7qkCcradZWYz988=
|
||||||
golang.org/x/crypto v0.51.0/go.mod h1:8AdwkbraGNABw2kOX6YFPs3WM22XqI4EXEd8g+x7Oc8=
|
golang.org/x/crypto v0.52.0/go.mod h1:1QgfPxDqh0T2M/elOJtp9RvuR95kVjir0e6/BvEmGbc=
|
||||||
golang.org/x/net v0.54.0 h1:2zJIZAxAHV/OHCDTCOHAYehQzLfSXuf/5SoL/Dv6w/w=
|
golang.org/x/net v0.54.0 h1:2zJIZAxAHV/OHCDTCOHAYehQzLfSXuf/5SoL/Dv6w/w=
|
||||||
golang.org/x/net v0.54.0/go.mod h1:Sj4oj8jK6XmHpBZU/zWHw3BV3abl4Kvi+Ut7cQcY+cQ=
|
golang.org/x/net v0.54.0/go.mod h1:Sj4oj8jK6XmHpBZU/zWHw3BV3abl4Kvi+Ut7cQcY+cQ=
|
||||||
golang.org/x/sys v0.44.0 h1:ildZl3J4uzeKP07r2F++Op7E9B29JRUy+a27EibtBTQ=
|
golang.org/x/sys v0.45.0 h1:dO4czNzziLiiXplLQgBCEpCvXQ3dnkn0SdaZSYdQ+FY=
|
||||||
golang.org/x/sys v0.44.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
golang.org/x/sys v0.45.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||||
golang.org/x/text v0.37.0 h1:Cqjiwd9eSg8e0QAkyCaQTNHFIIzWtidPahFWR83rTrc=
|
golang.org/x/text v0.37.0 h1:Cqjiwd9eSg8e0QAkyCaQTNHFIIzWtidPahFWR83rTrc=
|
||||||
golang.org/x/text v0.37.0/go.mod h1:a5sjxXGs9hsn/AJVwuElvCAo9v8QYLzvavO5z2PiM38=
|
golang.org/x/text v0.37.0/go.mod h1:a5sjxXGs9hsn/AJVwuElvCAo9v8QYLzvavO5z2PiM38=
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
|
|||||||
152
handler.go
152
handler.go
@ -7,6 +7,7 @@ import (
|
|||||||
"apigo.cc/go/timer"
|
"apigo.cc/go/timer"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/url"
|
||||||
"reflect"
|
"reflect"
|
||||||
"runtime/debug"
|
"runtime/debug"
|
||||||
"strings"
|
"strings"
|
||||||
@ -15,10 +16,12 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type RouteHandler struct {
|
type RouteHandler struct {
|
||||||
|
ws *WebServer
|
||||||
webRequestingNum int64
|
webRequestingNum int64
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rh *RouteHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
func (rh *RouteHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ws := rh.ws
|
||||||
atomic.AddInt64(&rh.webRequestingNum, 1)
|
atomic.AddInt64(&rh.webRequestingNum, 1)
|
||||||
defer atomic.AddInt64(&rh.webRequestingNum, -1)
|
defer atomic.AddInt64(&rh.webRequestingNum, -1)
|
||||||
|
|
||||||
@ -31,15 +34,17 @@ func (rh *RouteHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
request := NewRequest(r)
|
request := NewRequest(r)
|
||||||
request.Id = requestId
|
request.Id = requestId
|
||||||
response := NewResponse(w)
|
response := NewResponse(w, ws)
|
||||||
response.Id = requestId
|
response.Id = requestId
|
||||||
requestLogger := log.New(requestId)
|
requestLogger := log.New(requestId)
|
||||||
|
|
||||||
// 0. 延迟处理日志与状态检查
|
// 0. 延迟处理日志与状态检查
|
||||||
var s *webServiceType
|
var s *webServiceType
|
||||||
|
var wsc *websocketServiceType
|
||||||
var authLevel int
|
var authLevel int
|
||||||
var priority int
|
var priority int
|
||||||
var args = make(map[string]any)
|
var args = make(map[string]any)
|
||||||
|
var result any
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
// 捕捉 Panic
|
// 捕捉 Panic
|
||||||
@ -55,7 +60,7 @@ func (rh *RouteHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
// 记录日志
|
// 记录日志
|
||||||
if (s == nil || !s.options.NoLog200 || response.Code != 200) &&
|
if (s == nil || !s.options.NoLog200 || response.Code != 200) &&
|
||||||
!(Config.NoLogGets && r.Method == http.MethodGet && response.Code == 200) {
|
!(ws.Config.NoLogGets && r.Method == http.MethodGet && response.Code == 200) {
|
||||||
|
|
||||||
scheme := "http"
|
scheme := "http"
|
||||||
if r.TLS != nil {
|
if r.TLS != nil {
|
||||||
@ -65,7 +70,7 @@ func (rh *RouteHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
// 过滤请求头
|
// 过滤请求头
|
||||||
reqHeaders := make(map[string]string)
|
reqHeaders := make(map[string]string)
|
||||||
noLogHeaders := strings.Split(Config.NoLogHeaders, ",")
|
noLogHeaders := strings.Split(ws.Config.NoLogHeaders, ",")
|
||||||
for k, v := range r.Header {
|
for k, v := range r.Header {
|
||||||
skip := false
|
skip := false
|
||||||
for _, nl := range noLogHeaders {
|
for _, nl := range noLogHeaders {
|
||||||
@ -81,7 +86,7 @@ func (rh *RouteHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
// 过滤响应头
|
// 过滤响应头
|
||||||
respHeaders := make(map[string]string)
|
respHeaders := make(map[string]string)
|
||||||
for k, v := range response.Header() {
|
for k, v := range response.Header().H {
|
||||||
respHeaders[k] = strings.Join(v, ", ")
|
respHeaders[k] = strings.Join(v, ", ")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -93,7 +98,7 @@ func (rh *RouteHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
} else {
|
} else {
|
||||||
respData = string(response.body[:1024]) + "..."
|
respData = string(response.body[:1024]) + "..."
|
||||||
}
|
}
|
||||||
} else if Config.NoLogOutputFields != "" {
|
} else if ws.Config.NoLogOutputFields != "" {
|
||||||
// 简单的字段过滤逻辑 (如果是 JSON 对象)
|
// 简单的字段过滤逻辑 (如果是 JSON 对象)
|
||||||
// 这里可以根据 Config.NoLogOutputFields, LogOutputArrayNum, LogOutputFieldSize 进行更复杂的处理
|
// 这里可以根据 Config.NoLogOutputFields, LogOutputArrayNum, LogOutputFieldSize 进行更复杂的处理
|
||||||
// 暂按字符串截断处理
|
// 暂按字符串截断处理
|
||||||
@ -109,8 +114,8 @@ func (rh *RouteHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
entry.Scheme = scheme
|
entry.Scheme = scheme
|
||||||
entry.Proto = r.Proto
|
entry.Proto = r.Proto
|
||||||
entry.ClientIp = request.ClientIp()
|
entry.ClientIp = request.ClientIp()
|
||||||
entry.ServerId = serverId
|
entry.ServerId = ws.serverId
|
||||||
entry.App = Config.App
|
entry.App = ws.Config.App
|
||||||
entry.FromApp = r.Header.Get(discover.HeaderFromApp)
|
entry.FromApp = r.Header.Get(discover.HeaderFromApp)
|
||||||
entry.FromNode = r.Header.Get(discover.HeaderFromNode)
|
entry.FromNode = r.Header.Get(discover.HeaderFromNode)
|
||||||
entry.DeviceId = request.DeviceId()
|
entry.DeviceId = request.DeviceId()
|
||||||
@ -131,36 +136,34 @@ func (rh *RouteHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
// 处理 SessionId 和 DeviceId
|
// 处理 SessionId 和 DeviceId
|
||||||
handleClientKeys(request, response)
|
ws.handleClientKeys(request, response)
|
||||||
|
|
||||||
// 1. 处理重写 (Rewrite)
|
// 1. 处理重写 (Rewrite)
|
||||||
if processRewrite(request, response, requestLogger) {
|
if ws.processRewrite(request, response, requestLogger) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. 处理代理 (Proxy)
|
// 2. 处理代理 (Proxy)
|
||||||
if processProxy(request, response, requestLogger) {
|
if ws.processProxy(request, response, requestLogger) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. 路由匹配
|
// 3. 路由匹配
|
||||||
path := r.URL.Path
|
path, _ := url.PathUnescape(r.URL.Path)
|
||||||
host := r.Host
|
host := r.Host
|
||||||
|
|
||||||
// 处理静态文件
|
// 处理静态文件
|
||||||
if processStatic(path, request, response, requestLogger) {
|
if ws.processStatic(path, request, response, requestLogger) {
|
||||||
return
|
goto filter
|
||||||
}
|
}
|
||||||
|
|
||||||
var ws *websocketServiceType
|
s, wsc = ws.findService(r.Method, host, path, args)
|
||||||
s, ws = findService(r.Method, host, path)
|
|
||||||
|
|
||||||
// 4. 参数解析 (Form & Body)
|
// 4. 参数解析 (Form & Body)
|
||||||
parseRequestArgs(request, args)
|
parseRequestArgs(request, args)
|
||||||
|
|
||||||
// 5. 前置过滤器
|
// 5. 前置过滤器
|
||||||
var result any
|
for _, filter := range ws.inFilters {
|
||||||
for _, filter := range inFilters {
|
|
||||||
result = filter(&args, request, response, requestLogger)
|
result = filter(&args, request, response, requestLogger)
|
||||||
if result != nil {
|
if result != nil {
|
||||||
break
|
break
|
||||||
@ -174,22 +177,22 @@ func (rh *RouteHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
// 6. 处理业务执行 (WS 或 Web)
|
// 6. 处理业务执行 (WS 或 Web)
|
||||||
if result == nil {
|
if result == nil {
|
||||||
if ws != nil {
|
if wsc != nil {
|
||||||
authLevel = ws.authLevel
|
authLevel = wsc.authLevel
|
||||||
priority = ws.options.Priority
|
priority = wsc.options.Priority
|
||||||
// 鉴权
|
// 鉴权
|
||||||
pass, obj := checkAuth(ws.authLevel, &ws.options, request, response, args, requestLogger)
|
pass, obj := ws.checkAuth(wsc.authLevel, &wsc.options, request, response, args, requestLogger)
|
||||||
if !pass {
|
if !pass {
|
||||||
if !response.changed {
|
if !response.changed {
|
||||||
response.WriteHeader(http.StatusForbidden)
|
response.WriteHeader(http.StatusForbidden)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
doWebsocketService(ws, request, response, requestLogger, obj)
|
ws.doWebsocketService(wsc, request, response, requestLogger, obj)
|
||||||
return
|
return
|
||||||
} else if s != nil {
|
} else if s != nil {
|
||||||
// 鉴权
|
// 鉴权
|
||||||
pass, obj := checkAuth(s.authLevel, &s.options, request, response, args, requestLogger)
|
pass, obj := ws.checkAuth(s.authLevel, &s.options, request, response, args, requestLogger)
|
||||||
if !pass {
|
if !pass {
|
||||||
if !response.changed {
|
if !response.changed {
|
||||||
response.WriteHeader(http.StatusForbidden)
|
response.WriteHeader(http.StatusForbidden)
|
||||||
@ -197,19 +200,24 @@ func (rh *RouteHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
// 执行业务
|
// 执行业务
|
||||||
result = doWebService(s, request, response, args, nil, requestLogger, obj)
|
result = ws.doWebService(s, request, response, args, nil, requestLogger, obj)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if s == nil && result == nil && !response.changed {
|
if s == nil && result == nil && !response.changed {
|
||||||
response.WriteHeader(http.StatusNotFound)
|
response.WriteHeader(http.StatusNotFound)
|
||||||
|
result = "404 page not found"
|
||||||
}
|
}
|
||||||
|
|
||||||
// 7. 后置过滤器
|
filter:
|
||||||
for _, filter := range outFilters {
|
// 7. 后置过滤器 (即使 response.changed 也要执行,比如静态文件的 HTML 注入)
|
||||||
|
for _, filter := range ws.outFilters {
|
||||||
newResult, done := filter(args, request, response, result, requestLogger)
|
newResult, done := filter(args, request, response, result, requestLogger)
|
||||||
if newResult != nil {
|
if newResult != nil {
|
||||||
result = newResult
|
result = newResult
|
||||||
|
// 如果 response.changed 为 true,说明已经有内容写出了。
|
||||||
|
// 如果过滤器返回了非 nil 的 result,我们通常认为它想替换或追加内容。
|
||||||
|
// 特别是对于静态文件,如果我们清空了 body 并返回了新内容,result 就不再是 nil。
|
||||||
}
|
}
|
||||||
if done {
|
if done {
|
||||||
break
|
break
|
||||||
@ -217,7 +225,19 @@ func (rh *RouteHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 8. 输出结果
|
// 8. 输出结果
|
||||||
|
if ws.hasOutFilter {
|
||||||
|
// 过滤器模式:所有内容都应该从 result 或 response.body 中写出
|
||||||
|
if result != nil {
|
||||||
outputResult(response, result)
|
outputResult(response, result)
|
||||||
|
} else if response.changed {
|
||||||
|
response.PhysicalWrite(response.body)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// 普通模式:result (业务返回值) 需要写出,而 response.changed (比如静态文件) 已经由 Response.Write 写过了
|
||||||
|
if result != nil {
|
||||||
|
outputResult(response, result)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func hostOnly(host string) string {
|
func hostOnly(host string) string {
|
||||||
@ -225,9 +245,9 @@ func hostOnly(host string) string {
|
|||||||
return h
|
return h
|
||||||
}
|
}
|
||||||
|
|
||||||
func findService(method, host, path string) (*webServiceType, *websocketServiceType) {
|
func (ws *WebServer) findService(method, host, path string, args map[string]any) (*webServiceType, *websocketServiceType) {
|
||||||
webServicesLock.RLock()
|
ws.webServicesLock.RLock()
|
||||||
defer webServicesLock.RUnlock()
|
defer ws.webServicesLock.RUnlock()
|
||||||
|
|
||||||
// 1. 准备 Host 候选列表: "host:port", "host", ":port", "*"
|
// 1. 准备 Host 候选列表: "host:port", "host", ":port", "*"
|
||||||
hostOnly, port, _ := strings.Cut(host, ":")
|
hostOnly, port, _ := strings.Cut(host, ":")
|
||||||
@ -239,7 +259,7 @@ func findService(method, host, path string) (*webServiceType, *websocketServiceT
|
|||||||
|
|
||||||
// 2. 匹配 Web Service
|
// 2. 匹配 Web Service
|
||||||
for _, h := range hosts {
|
for _, h := range hosts {
|
||||||
if services, exists := webServices[h]; exists {
|
if services, exists := ws.webServices[h]; exists {
|
||||||
if s, ok := services[method+path]; ok {
|
if s, ok := services[method+path]; ok {
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
@ -250,10 +270,10 @@ func findService(method, host, path string) (*webServiceType, *websocketServiceT
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 3. 匹配 WebSocket
|
// 3. 匹配 WebSocket
|
||||||
websocketServicesLock.RLock()
|
ws.websocketServicesLock.RLock()
|
||||||
defer websocketServicesLock.RUnlock()
|
defer ws.websocketServicesLock.RUnlock()
|
||||||
for _, h := range hosts {
|
for _, h := range hosts {
|
||||||
if services, exists := websocketServices[h]; exists {
|
if services, exists := ws.websocketServices[h]; exists {
|
||||||
if ws, ok := services[path]; ok {
|
if ws, ok := services[path]; ok {
|
||||||
return nil, ws
|
return nil, ws
|
||||||
}
|
}
|
||||||
@ -262,13 +282,21 @@ func findService(method, host, path string) (*webServiceType, *websocketServiceT
|
|||||||
|
|
||||||
// 4. 正则匹配
|
// 4. 正则匹配
|
||||||
for _, h := range hosts {
|
for _, h := range hosts {
|
||||||
if services, exists := regexWebServices[h]; exists {
|
if services, exists := ws.regexWebServices[h]; exists {
|
||||||
for i := len(services) - 1; i >= 0; i-- {
|
for i := len(services) - 1; i >= 0; i-- {
|
||||||
s := services[i]
|
s := services[i]
|
||||||
if s.method != "*" && s.method != method {
|
if s.method != "*" && s.method != method {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if s.pathMatcher != nil && s.pathMatcher.MatchString(path) {
|
if s.pathMatcher != nil && s.pathMatcher.MatchString(path) {
|
||||||
|
matches := s.pathMatcher.FindStringSubmatch(path)
|
||||||
|
if len(matches) > 1 {
|
||||||
|
for i, name := range s.pathArgs {
|
||||||
|
if i+1 < len(matches) {
|
||||||
|
args[name] = matches[i+1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -291,7 +319,7 @@ func parseRequestArgs(request *Request, args map[string]any) {
|
|||||||
|
|
||||||
// Form params
|
// Form params
|
||||||
if request.Method == http.MethodPost || request.Method == http.MethodPut {
|
if request.Method == http.MethodPost || request.Method == http.MethodPut {
|
||||||
contentType := request.Header.Get("Content-Type")
|
contentType := request.Header().Get("Content-Type")
|
||||||
if strings.HasPrefix(contentType, "application/json") {
|
if strings.HasPrefix(contentType, "application/json") {
|
||||||
body, _ := io.ReadAll(request.Body)
|
body, _ := io.ReadAll(request.Body)
|
||||||
_ = request.Body.Close()
|
_ = request.Body.Close()
|
||||||
@ -311,10 +339,10 @@ func parseRequestArgs(request *Request, args map[string]any) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkAuth(authLevel int, options *WebServiceOptions, request *Request, response *Response, args map[string]any, logger *log.Logger) (bool, any) {
|
func (ws *WebServer) checkAuth(authLevel int, options *WebServiceOptions, request *Request, response *Response, args map[string]any, logger *log.Logger) (bool, any) {
|
||||||
ac := webAuthCheckers[authLevel]
|
ac := ws.webAuthCheckers[authLevel]
|
||||||
if ac == nil {
|
if ac == nil {
|
||||||
ac = webAuthChecker
|
ac = ws.webAuthChecker
|
||||||
}
|
}
|
||||||
if ac == nil {
|
if ac == nil {
|
||||||
sess := NewSession(request.SessionId(), logger)
|
sess := NewSession(request.SessionId(), logger)
|
||||||
@ -330,7 +358,7 @@ func checkAuth(authLevel int, options *WebServiceOptions, request *Request, resp
|
|||||||
return pass, obj
|
return pass, obj
|
||||||
}
|
}
|
||||||
|
|
||||||
func doWebService(service *webServiceType, request *Request, response *Response, args map[string]any,
|
func (ws *WebServer) doWebService(service *webServiceType, request *Request, response *Response, args map[string]any,
|
||||||
result any, logger *log.Logger, object any) any {
|
result any, logger *log.Logger, object any) any {
|
||||||
if result != nil {
|
if result != nil {
|
||||||
return result
|
return result
|
||||||
@ -365,7 +393,7 @@ func doWebService(service *webServiceType, request *Request, response *Response,
|
|||||||
// 尝试依赖注入
|
// 尝试依赖注入
|
||||||
if object != nil && reflect.TypeOf(object).AssignableTo(t) {
|
if object != nil && reflect.TypeOf(object).AssignableTo(t) {
|
||||||
params[i] = reflect.ValueOf(object)
|
params[i] = reflect.ValueOf(object)
|
||||||
} else if obj := GetInject(t); obj != nil {
|
} else if obj := ws.GetInject(t); obj != nil {
|
||||||
params[i] = reflect.ValueOf(obj)
|
params[i] = reflect.ValueOf(obj)
|
||||||
} else {
|
} else {
|
||||||
params[i] = reflect.New(t).Elem()
|
params[i] = reflect.New(t).Elem()
|
||||||
@ -401,50 +429,54 @@ func outputResult(response *Response, result any) {
|
|||||||
if contentType != "" && response.Header().Get("Content-Type") == "" {
|
if contentType != "" && response.Header().Get("Content-Type") == "" {
|
||||||
response.Header().Set("Content-Type", contentType)
|
response.Header().Set("Content-Type", contentType)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if response.server != nil && response.server.hasOutFilter {
|
||||||
|
response.PhysicalWrite(data)
|
||||||
|
} else {
|
||||||
_, _ = response.Write(data)
|
_, _ = response.Write(data)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
func handleClientKeys(request *Request, response *Response) {
|
func (ws *WebServer) handleClientKeys(request *Request, response *Response) {
|
||||||
// SessionId
|
// SessionId
|
||||||
if usedSessionIdKey != "" {
|
if ws.usedSessionIdKey != "" {
|
||||||
sessionId := request.Header.Get(usedSessionIdKey)
|
sessionId := request.Header().Get(ws.usedSessionIdKey)
|
||||||
if sessionId == "" && !Config.SessionWithoutCookie {
|
if sessionId == "" && !ws.Config.SessionWithoutCookie {
|
||||||
if ck, err := request.Cookie(usedSessionIdKey); err == nil {
|
if ck := request.GetCookie(ws.usedSessionIdKey); ck != nil {
|
||||||
sessionId = ck.Value
|
sessionId = ck.Value
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if sessionId == "" {
|
if sessionId == "" {
|
||||||
if sessionIdMaker != nil {
|
if ws.sessionIdMaker != nil {
|
||||||
sessionId = sessionIdMaker()
|
sessionId = ws.sessionIdMaker()
|
||||||
} else {
|
} else {
|
||||||
sessionId = IDMaker.Get11Bytes900MPerSecond()
|
sessionId = IDMaker.Get11Bytes900MPerSecond()
|
||||||
}
|
}
|
||||||
if !Config.SessionWithoutCookie {
|
if !ws.Config.SessionWithoutCookie {
|
||||||
http.SetCookie(response.Writer, &http.Cookie{
|
http.SetCookie(response.Writer, &http.Cookie{
|
||||||
Name: usedSessionIdKey,
|
Name: ws.usedSessionIdKey,
|
||||||
Value: sessionId,
|
Value: sessionId,
|
||||||
Path: "/",
|
Path: "/",
|
||||||
HttpOnly: true,
|
HttpOnly: true,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
request.Header.Set(discover.HeaderSessionID, sessionId)
|
request.Request.Header.Set(discover.HeaderSessionID, sessionId)
|
||||||
response.Header().Set(usedSessionIdKey, sessionId)
|
response.Header().Set(ws.usedSessionIdKey, sessionId)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeviceId
|
// DeviceId
|
||||||
if usedDeviceIdKey != "" {
|
if ws.usedDeviceIdKey != "" {
|
||||||
deviceId := request.Header.Get(usedDeviceIdKey)
|
deviceId := request.Header().Get(ws.usedDeviceIdKey)
|
||||||
if deviceId == "" && !Config.DeviceWithoutCookie {
|
if deviceId == "" && !ws.Config.DeviceWithoutCookie {
|
||||||
if ck, err := request.Cookie(usedDeviceIdKey); err == nil {
|
if ck := request.GetCookie(ws.usedDeviceIdKey); ck != nil {
|
||||||
deviceId = ck.Value
|
deviceId = ck.Value
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if deviceId == "" {
|
if deviceId == "" {
|
||||||
deviceId = IDMaker.Get11Bytes900MPerSecond()
|
deviceId = IDMaker.Get11Bytes900MPerSecond()
|
||||||
if !Config.DeviceWithoutCookie {
|
if !ws.Config.DeviceWithoutCookie {
|
||||||
http.SetCookie(response.Writer, &http.Cookie{
|
http.SetCookie(response.Writer, &http.Cookie{
|
||||||
Name: usedDeviceIdKey,
|
Name: ws.usedDeviceIdKey,
|
||||||
Value: deviceId,
|
Value: deviceId,
|
||||||
Path: "/",
|
Path: "/",
|
||||||
Expires: time.Now().AddDate(10, 0, 0),
|
Expires: time.Now().AddDate(10, 0, 0),
|
||||||
@ -452,7 +484,7 @@ func handleClientKeys(request *Request, response *Response) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
request.Header.Set(discover.HeaderDeviceID, deviceId)
|
request.Request.Header.Set(discover.HeaderDeviceID, deviceId)
|
||||||
response.Header().Set(usedDeviceIdKey, deviceId)
|
response.Header().Set(ws.usedDeviceIdKey, deviceId)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -14,7 +14,7 @@ func TestServeHTTP(t *testing.T) {
|
|||||||
}
|
}
|
||||||
Host("*").POST("/hello", handler).Auth(0).Memo("say hello")
|
Host("*").POST("/hello", handler).Auth(0).Memo("say hello")
|
||||||
|
|
||||||
rh := &RouteHandler{}
|
rh := &RouteHandler{ws: DefaultServer}
|
||||||
|
|
||||||
// 模拟请求
|
// 模拟请求
|
||||||
req := httptest.NewRequest("POST", "/hello", strings.NewReader(`{"name":"Star"}`))
|
req := httptest.NewRequest("POST", "/hello", strings.NewReader(`{"name":"Star"}`))
|
||||||
@ -34,7 +34,7 @@ func TestServeHTTP(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestServeHTTP_404(t *testing.T) {
|
func TestServeHTTP_404(t *testing.T) {
|
||||||
rh := &RouteHandler{}
|
rh := &RouteHandler{ws: DefaultServer}
|
||||||
req := httptest.NewRequest("GET", "/notfound", nil)
|
req := httptest.NewRequest("GET", "/notfound", nil)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
@ -54,7 +54,7 @@ func TestServeHTTP_VerifyFailed(t *testing.T) {
|
|||||||
}
|
}
|
||||||
Host("*").POST("/verify", handler).Auth(0).Memo("test verify")
|
Host("*").POST("/verify", handler).Auth(0).Memo("test verify")
|
||||||
|
|
||||||
rh := &RouteHandler{}
|
rh := &RouteHandler{ws: DefaultServer}
|
||||||
req := httptest.NewRequest("POST", "/verify", strings.NewReader(`{"age":10}`))
|
req := httptest.NewRequest("POST", "/verify", strings.NewReader(`{"age":10}`))
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
@ -71,7 +71,7 @@ func TestServeHTTP_Panic(t *testing.T) {
|
|||||||
panic("intentional panic")
|
panic("intentional panic")
|
||||||
})
|
})
|
||||||
|
|
||||||
rh := &RouteHandler{}
|
rh := &RouteHandler{ws: DefaultServer}
|
||||||
req := httptest.NewRequest("GET", "/panic", nil)
|
req := httptest.NewRequest("GET", "/panic", nil)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
|||||||
14
js_export.go
14
js_export.go
@ -6,15 +6,15 @@ import (
|
|||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
jsmod.Register("service", map[string]any{
|
jsmod.Register("service", map[string]any{
|
||||||
// 类型占位工厂 (用于 AI 发现类型结构)
|
// 类型占位工厂 (用于 AI 发现类型结构,生成文档时隐藏)
|
||||||
"newRequest": func() *Request { return &Request{} },
|
"__exportRequest": func() *Request { return &Request{} },
|
||||||
"newResponse": func() *Response { return &Response{} },
|
"__exportResponse": func() *Response { return &Response{} },
|
||||||
"newWebSocket": func() *WebSocketConn { return &WebSocketConn{} },
|
"__exportWebSocket": func() *WebSocketConn { return &WebSocketConn{} },
|
||||||
"newSession": func() *Session { return &Session{} },
|
"__exportSession": func() *Session { return &Session{} },
|
||||||
"newFile": func() *jsUploadFile { return &jsUploadFile{} },
|
"__exportFile": func() *jsUploadFile { return &jsUploadFile{} },
|
||||||
|
|
||||||
// 功能函数
|
// 功能函数
|
||||||
"upgrade": Upgrade,
|
"Upgrade": Upgrade,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
77
proxy.go
77
proxy.go
@ -6,6 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"reflect"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@ -77,24 +78,17 @@ func parseProxyRule(authLevel int, path, toApp, toPath string, to string) *proxy
|
|||||||
func (hc *HostContext) Proxy(authLevel int, path string, to string) *HostContext {
|
func (hc *HostContext) Proxy(authLevel int, path string, to string) *HostContext {
|
||||||
p := parseProxyRule(authLevel, path, "", "", to)
|
p := parseProxyRule(authLevel, path, "", "", to)
|
||||||
|
|
||||||
hostPoliciesLock.Lock()
|
hc.ws.hostPoliciesLock.Lock()
|
||||||
defer hostPoliciesLock.Unlock()
|
defer hc.ws.hostPoliciesLock.Unlock()
|
||||||
codeProxies[hc.host] = append(codeProxies[hc.host], p)
|
if hc.ws.codeProxies[hc.host] == nil {
|
||||||
rebuildProxiesUnderLock(hc.host)
|
hc.ws.codeProxies[hc.host] = make([]*proxyType, 0)
|
||||||
|
}
|
||||||
|
hc.ws.codeProxies[hc.host] = append(hc.ws.codeProxies[hc.host], p)
|
||||||
|
hc.ws.rebuildProxiesUnderLock(hc.host)
|
||||||
return hc
|
return hc
|
||||||
}
|
}
|
||||||
|
|
||||||
func rebuildProxiesUnderLock(host string) {
|
func (ws *WebServer) findProxy(request *Request) (int, *string, *string, string) {
|
||||||
var combined []*proxyType
|
|
||||||
combined = append(combined, codeProxies[host]...)
|
|
||||||
combined = append(combined, fileProxies[host]...)
|
|
||||||
combined = append(combined, dynamicProxies[host]...)
|
|
||||||
hostProxies[host] = combined
|
|
||||||
}
|
|
||||||
|
|
||||||
var httpClientPool *gohttp.Client
|
|
||||||
|
|
||||||
func findProxy(request *Request) (int, *string, *string, string) {
|
|
||||||
host := request.Host
|
host := request.Host
|
||||||
hostOnly, port, _ := strings.Cut(host, ":")
|
hostOnly, port, _ := strings.Cut(host, ":")
|
||||||
hosts := []string{host}
|
hosts := []string{host}
|
||||||
@ -110,11 +104,11 @@ func findProxy(request *Request) (int, *string, *string, string) {
|
|||||||
requestPath = requestPath[:pos]
|
requestPath = requestPath[:pos]
|
||||||
}
|
}
|
||||||
|
|
||||||
hostPoliciesLock.RLock()
|
ws.hostPoliciesLock.RLock()
|
||||||
defer hostPoliciesLock.RUnlock()
|
defer ws.hostPoliciesLock.RUnlock()
|
||||||
|
|
||||||
for _, h := range hosts {
|
for _, h := range hosts {
|
||||||
proxies, exists := hostProxies[h]
|
proxies, exists := ws.hostProxies[h]
|
||||||
if !exists {
|
if !exists {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -151,15 +145,15 @@ func findProxy(request *Request) (int, *string, *string, string) {
|
|||||||
return 0, nil, nil, ""
|
return 0, nil, nil, ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func processProxy(request *Request, response *Response, logger *log.Logger) bool {
|
func (ws *WebServer) processProxy(request *Request, response *Response, logger *log.Logger) bool {
|
||||||
authLevel, proxyToApp, proxyToPath, foundHost := findProxy(request)
|
authLevel, proxyToApp, proxyToPath, foundHost := ws.findProxy(request)
|
||||||
|
|
||||||
if proxyToApp == nil || proxyToPath == nil || *proxyToApp == "" || *proxyToPath == "" {
|
if proxyToApp == nil || proxyToPath == nil || *proxyToApp == "" || *proxyToPath == "" {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// 鉴权
|
// 鉴权
|
||||||
pass, obj := checkAuthForProxy(authLevel, request, response, logger)
|
pass, obj := ws.checkAuthForProxy(authLevel, request, response, logger)
|
||||||
if !pass {
|
if !pass {
|
||||||
if !response.changed {
|
if !response.changed {
|
||||||
response.WriteHeader(http.StatusForbidden)
|
response.WriteHeader(http.StatusForbidden)
|
||||||
@ -174,19 +168,16 @@ func processProxy(request *Request, response *Response, logger *log.Logger) bool
|
|||||||
|
|
||||||
if strings.Contains(app, "://") {
|
if strings.Contains(app, "://") {
|
||||||
// 直接 URL 代理
|
// 直接 URL 代理
|
||||||
if httpClientPool == nil {
|
res := ws.getHttpClient().ManualDoByRequest(request.Request, request.Method, app+path, request.Body)
|
||||||
httpClientPool = gohttp.NewClient(time.Duration(Config.RedirectTimeout) * time.Millisecond)
|
|
||||||
}
|
|
||||||
res := httpClientPool.ManualDoByRequest(request.Request, request.Method, app+path, request.Body)
|
|
||||||
copyResponse(res, response, logger)
|
copyResponse(res, response, logger)
|
||||||
} else {
|
} else {
|
||||||
// Discover 代理
|
// Discover 代理
|
||||||
if GlobalDiscoverer == nil {
|
if ws.discoverer == nil {
|
||||||
logger.Error("proxy failed: GlobalDiscoverer is not initialized")
|
logger.Error("proxy failed: Discoverer is not initialized")
|
||||||
response.WriteHeader(http.StatusBadGateway)
|
response.WriteHeader(http.StatusBadGateway)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
caller := GlobalDiscoverer.NewCaller(request.Request, logger)
|
caller := ws.discoverer.NewCaller(request.Request, logger)
|
||||||
caller.NoBody = true
|
caller.NoBody = true
|
||||||
res, _ := caller.ManualDoWithNode(request.Method, app, "", path, request.Body)
|
res, _ := caller.ManualDoWithNode(request.Method, app, "", path, request.Body)
|
||||||
copyResponse(res, response, logger)
|
copyResponse(res, response, logger)
|
||||||
@ -195,10 +186,22 @@ func processProxy(request *Request, response *Response, logger *log.Logger) bool
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkAuthForProxy(authLevel int, request *Request, response *Response, logger *log.Logger) (bool, any) {
|
func (ws *WebServer) getHttpClient() *gohttp.Client {
|
||||||
ac := webAuthCheckers[authLevel]
|
// 尝试从注入对象获取
|
||||||
|
if obj := ws.GetInject(reflect.TypeOf(&gohttp.Client{})); obj != nil {
|
||||||
|
return obj.(*gohttp.Client)
|
||||||
|
}
|
||||||
|
timeout := time.Duration(ws.Config.RedirectTimeout) * time.Millisecond
|
||||||
|
if timeout <= 0 {
|
||||||
|
timeout = 30 * time.Second
|
||||||
|
}
|
||||||
|
return gohttp.NewClient(timeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ws *WebServer) checkAuthForProxy(authLevel int, request *Request, response *Response, logger *log.Logger) (bool, any) {
|
||||||
|
ac := ws.webAuthCheckers[authLevel]
|
||||||
if ac == nil {
|
if ac == nil {
|
||||||
ac = webAuthChecker
|
ac = ws.webAuthChecker
|
||||||
}
|
}
|
||||||
if ac == nil {
|
if ac == nil {
|
||||||
return true, nil
|
return true, nil
|
||||||
@ -239,13 +242,17 @@ type ProxyRule struct {
|
|||||||
|
|
||||||
// ReplaceProxies 使用全量指针替换的方式 (Copy-on-Write) 无缝更新指定 host 的动态代理规则。
|
// ReplaceProxies 使用全量指针替换的方式 (Copy-on-Write) 无缝更新指定 host 的动态代理规则。
|
||||||
func ReplaceProxies(host string, rules []ProxyRule) {
|
func ReplaceProxies(host string, rules []ProxyRule) {
|
||||||
|
DefaultServer.ReplaceProxies(host, rules)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ws *WebServer) ReplaceProxies(host string, rules []ProxyRule) {
|
||||||
newProxies := make([]*proxyType, 0, len(rules))
|
newProxies := make([]*proxyType, 0, len(rules))
|
||||||
for _, r := range rules {
|
for _, r := range rules {
|
||||||
newProxies = append(newProxies, parseProxyRule(r.AuthLevel, r.Path, r.ToApp, r.ToPath, r.To))
|
newProxies = append(newProxies, parseProxyRule(r.AuthLevel, r.Path, r.ToApp, r.ToPath, r.To))
|
||||||
}
|
}
|
||||||
|
|
||||||
hostPoliciesLock.Lock()
|
ws.hostPoliciesLock.Lock()
|
||||||
defer hostPoliciesLock.Unlock()
|
defer ws.hostPoliciesLock.Unlock()
|
||||||
dynamicProxies[host] = newProxies
|
ws.dynamicProxies[host] = newProxies
|
||||||
rebuildProxiesUnderLock(host)
|
ws.rebuildProxiesUnderLock(host)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -15,7 +15,7 @@ func TestRewrite(t *testing.T) {
|
|||||||
Host("*").ANY("/new", func() string { return "new content" }).Memo("new")
|
Host("*").ANY("/new", func() string { return "new content" }).Memo("new")
|
||||||
Host("*").ANY("/target/123", func() string { return "target content" }).Memo("target")
|
Host("*").ANY("/target/123", func() string { return "target content" }).Memo("target")
|
||||||
|
|
||||||
rh := &RouteHandler{}
|
rh := &RouteHandler{ws: DefaultServer}
|
||||||
|
|
||||||
// 测试精确匹配重写
|
// 测试精确匹配重写
|
||||||
req1 := httptest.NewRequest("GET", "/old", nil)
|
req1 := httptest.NewRequest("GET", "/old", nil)
|
||||||
@ -45,7 +45,7 @@ func TestProxyDirect(t *testing.T) {
|
|||||||
// 注册代理规则
|
// 注册代理规则
|
||||||
Host("*").Proxy(0, "/proxy", backend.URL+"/hello")
|
Host("*").Proxy(0, "/proxy", backend.URL+"/hello")
|
||||||
|
|
||||||
rh := &RouteHandler{}
|
rh := &RouteHandler{ws: DefaultServer}
|
||||||
req := httptest.NewRequest("GET", "/proxy", nil)
|
req := httptest.NewRequest("GET", "/proxy", nil)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
rh.ServeHTTP(w, req)
|
rh.ServeHTTP(w, req)
|
||||||
|
|||||||
30
reload.go
30
reload.go
@ -5,24 +5,30 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
type reloadHook struct {
|
||||||
reloadHooks []func() error
|
hooks []func() error
|
||||||
reloadLock sync.RWMutex
|
lock sync.RWMutex
|
||||||
)
|
}
|
||||||
|
|
||||||
|
var globalReloadHook = &reloadHook{}
|
||||||
|
|
||||||
// OnReload 注册一个在接收到 SIGHUP 信号时触发的重新加载钩子
|
// OnReload 注册一个在接收到 SIGHUP 信号时触发的重新加载钩子
|
||||||
func OnReload(handler func() error) {
|
func OnReload(handler func() error) {
|
||||||
reloadLock.Lock()
|
DefaultServer.OnReload(handler)
|
||||||
defer reloadLock.Unlock()
|
}
|
||||||
reloadHooks = append(reloadHooks, handler)
|
|
||||||
|
func (ws *WebServer) OnReload(handler func() error) {
|
||||||
|
globalReloadHook.lock.Lock()
|
||||||
|
defer globalReloadHook.lock.Unlock()
|
||||||
|
globalReloadHook.hooks = append(globalReloadHook.hooks, handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
// triggerReload 触发所有注册的重新加载钩子
|
// triggerReload 触发所有注册的重新加载钩子
|
||||||
func triggerReload() error {
|
func (ws *WebServer) triggerReload() error {
|
||||||
reloadLock.RLock()
|
globalReloadHook.lock.RLock()
|
||||||
hooks := make([]func() error, len(reloadHooks))
|
hooks := make([]func() error, len(globalReloadHook.hooks))
|
||||||
copy(hooks, reloadHooks)
|
copy(hooks, globalReloadHook.hooks)
|
||||||
reloadLock.RUnlock()
|
globalReloadHook.lock.RUnlock()
|
||||||
|
|
||||||
for _, hook := range hooks {
|
for _, hook := range hooks {
|
||||||
if err := hook(); err != nil {
|
if err := hook(); err != nil {
|
||||||
|
|||||||
125
request.go
125
request.go
@ -33,13 +33,120 @@ func (f *UploadFile) Content() ([]byte, error) {
|
|||||||
return io.ReadAll(src)
|
return io.ReadAll(src)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Header 包装 http.Header 以提供 JS 友好的方法
|
||||||
|
type Header struct {
|
||||||
|
H http.Header `js:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Header) Get(key string) string {
|
||||||
|
return h.H.Get(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Header) Set(key, value string) {
|
||||||
|
h.H.Set(key, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Header) Add(key, value string) {
|
||||||
|
h.H.Add(key, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Header) Del(key string) {
|
||||||
|
h.H.Del(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Header) Values(key string) []string {
|
||||||
|
return h.H.Values(key)
|
||||||
|
}
|
||||||
|
|
||||||
// Request 封装 http.Request
|
// Request 封装 http.Request
|
||||||
type Request struct {
|
type Request struct {
|
||||||
*http.Request
|
*http.Request `js:"-"`
|
||||||
contextValues map[string]any
|
contextValues map[string]any `js:"-"`
|
||||||
Id string
|
Id string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Cookie 简化的 JS 友好 Cookie 结构
|
||||||
|
type Cookie struct {
|
||||||
|
Name string
|
||||||
|
Value string
|
||||||
|
Path string
|
||||||
|
Domain string
|
||||||
|
MaxAge int
|
||||||
|
Secure bool
|
||||||
|
HttpOnly bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Request) GetCookie(name string) *Cookie {
|
||||||
|
c, err := r.Request.Cookie(name)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &Cookie{
|
||||||
|
Name: c.Name,
|
||||||
|
Value: c.Value,
|
||||||
|
Path: c.Path,
|
||||||
|
Domain: c.Domain,
|
||||||
|
MaxAge: c.MaxAge,
|
||||||
|
Secure: c.Secure,
|
||||||
|
HttpOnly: c.HttpOnly,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Request) GetCookies() []*Cookie {
|
||||||
|
res := make([]*Cookie, 0)
|
||||||
|
for _, c := range r.Request.Cookies() {
|
||||||
|
res = append(res, &Cookie{
|
||||||
|
Name: c.Name,
|
||||||
|
Value: c.Value,
|
||||||
|
Path: c.Path,
|
||||||
|
Domain: c.Domain,
|
||||||
|
MaxAge: c.MaxAge,
|
||||||
|
Secure: c.Secure,
|
||||||
|
HttpOnly: c.HttpOnly,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddCookie 遮蔽原生的 AddCookie
|
||||||
|
func (r *Request) AddCookie(c *Cookie) {
|
||||||
|
if c == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
r.Request.AddCookie(&http.Cookie{
|
||||||
|
Name: c.Name,
|
||||||
|
Value: c.Value,
|
||||||
|
Path: c.Path,
|
||||||
|
Domain: c.Domain,
|
||||||
|
MaxAge: c.MaxAge,
|
||||||
|
Secure: c.Secure,
|
||||||
|
HttpOnly: c.HttpOnly,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// CookiesNamed 遮蔽原生的 CookiesNamed
|
||||||
|
func (r *Request) CookiesNamed(name string) []*Cookie {
|
||||||
|
res := make([]*Cookie, 0)
|
||||||
|
for _, c := range r.Request.Cookies() {
|
||||||
|
if c.Name == name {
|
||||||
|
res = append(res, &Cookie{
|
||||||
|
Name: c.Name,
|
||||||
|
Value: c.Value,
|
||||||
|
Path: c.Path,
|
||||||
|
Domain: c.Domain,
|
||||||
|
MaxAge: c.MaxAge,
|
||||||
|
Secure: c.Secure,
|
||||||
|
HttpOnly: c.HttpOnly,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Request) Header() *Header {
|
||||||
|
return &Header{H: r.Request.Header}
|
||||||
|
}
|
||||||
|
|
||||||
// NewRequest 创建 Request 包装
|
// NewRequest 创建 Request 包装
|
||||||
func NewRequest(httpRequest *http.Request) *Request {
|
func NewRequest(httpRequest *http.Request) *Request {
|
||||||
return &Request{
|
return &Request{
|
||||||
@ -68,11 +175,11 @@ func (r *Request) Get(key string) any {
|
|||||||
|
|
||||||
// MakeUrl 根据当前请求构建完整 URL
|
// MakeUrl 根据当前请求构建完整 URL
|
||||||
func (r *Request) MakeUrl(path string) string {
|
func (r *Request) MakeUrl(path string) string {
|
||||||
scheme := r.Header.Get(discover.HeaderScheme)
|
scheme := r.Header().Get(discover.HeaderScheme)
|
||||||
if scheme == "" {
|
if scheme == "" {
|
||||||
scheme = "http"
|
scheme = "http"
|
||||||
}
|
}
|
||||||
host := r.Header.Get(discover.HeaderHost)
|
host := r.Header().Get(discover.HeaderHost)
|
||||||
if host == "" {
|
if host == "" {
|
||||||
host = r.Host
|
host = r.Host
|
||||||
}
|
}
|
||||||
@ -81,24 +188,24 @@ func (r *Request) MakeUrl(path string) string {
|
|||||||
|
|
||||||
// DeviceId 获取设备 ID
|
// DeviceId 获取设备 ID
|
||||||
func (r *Request) DeviceId() string {
|
func (r *Request) DeviceId() string {
|
||||||
return r.Header.Get(discover.HeaderDeviceID)
|
return r.Header().Get(discover.HeaderDeviceID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SessionId 获取会话 ID
|
// SessionId 获取会话 ID
|
||||||
func (r *Request) SessionId() string {
|
func (r *Request) SessionId() string {
|
||||||
return r.Header.Get(discover.HeaderSessionID)
|
return r.Header().Get(discover.HeaderSessionID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetUserId 设置用户 ID(传递给下游)
|
// SetUserId 设置用户 ID(传递给下游)
|
||||||
func (r *Request) SetUserId(userId string) {
|
func (r *Request) SetUserId(userId string) {
|
||||||
r.Header.Set(discover.HeaderUserID, userId)
|
r.Header().Set(discover.HeaderUserID, userId)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ClientIp 获取真实 IP
|
// ClientIp 获取真实 IP
|
||||||
func (r *Request) ClientIp() string {
|
func (r *Request) ClientIp() string {
|
||||||
ip := r.Header.Get(discover.HeaderClientIP)
|
ip := r.Header().Get(discover.HeaderClientIP)
|
||||||
if ip == "" {
|
if ip == "" {
|
||||||
ip = r.Header.Get(discover.HeaderForwardedFor)
|
ip = r.Header().Get(discover.HeaderForwardedFor)
|
||||||
}
|
}
|
||||||
if ip == "" {
|
if ip == "" {
|
||||||
host, _, err := net.SplitHostPort(r.RemoteAddr)
|
host, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||||
|
|||||||
75
response.go
75
response.go
@ -10,31 +10,48 @@ import (
|
|||||||
// Response 封装 http.ResponseWriter
|
// Response 封装 http.ResponseWriter
|
||||||
type Response struct {
|
type Response struct {
|
||||||
Id string
|
Id string
|
||||||
Writer http.ResponseWriter
|
Writer http.ResponseWriter `js:"-"`
|
||||||
Code int
|
Code int
|
||||||
body []byte
|
body []byte `js:"-"`
|
||||||
outLen int
|
outLen int `js:"-"`
|
||||||
changed bool
|
changed bool `js:"-"`
|
||||||
headerWritten bool
|
headerWritten bool `js:"-"`
|
||||||
dontLog200 bool
|
dontLog200 bool `js:"-"`
|
||||||
dontLogArgs []string
|
dontLogArgs []string `js:"-"`
|
||||||
ProxyHeader *http.Header
|
ProxyHeader *http.Header `js:"-"`
|
||||||
|
server *WebServer `js:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Response) SetCookie(cookie *Cookie) {
|
||||||
|
if cookie == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
http.SetCookie(r.Writer, &http.Cookie{
|
||||||
|
Name: cookie.Name,
|
||||||
|
Value: cookie.Value,
|
||||||
|
Path: cookie.Path,
|
||||||
|
Domain: cookie.Domain,
|
||||||
|
MaxAge: cookie.MaxAge,
|
||||||
|
Secure: cookie.Secure,
|
||||||
|
HttpOnly: cookie.HttpOnly,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewResponse 创建 Response 包装
|
// NewResponse 创建 Response 包装
|
||||||
func NewResponse(writer http.ResponseWriter) *Response {
|
func NewResponse(writer http.ResponseWriter, server *WebServer) *Response {
|
||||||
return &Response{
|
return &Response{
|
||||||
Writer: writer,
|
Writer: writer,
|
||||||
Code: http.StatusOK,
|
Code: http.StatusOK,
|
||||||
|
server: server,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Header 获取响应头部
|
// Header 获取响应头部
|
||||||
func (r *Response) Header() http.Header {
|
func (r *Response) Header() *Header {
|
||||||
if r.ProxyHeader != nil {
|
if r.ProxyHeader != nil {
|
||||||
return *r.ProxyHeader
|
return &Header{H: *r.ProxyHeader}
|
||||||
}
|
}
|
||||||
return r.Writer.Header()
|
return &Header{H: r.Writer.Header()}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write 写入响应内容
|
// Write 写入响应内容
|
||||||
@ -42,9 +59,27 @@ func (r *Response) Write(bytes []byte) (int, error) {
|
|||||||
r.checkWriteHeader()
|
r.checkWriteHeader()
|
||||||
r.changed = true
|
r.changed = true
|
||||||
r.outLen += len(bytes)
|
r.outLen += len(bytes)
|
||||||
if r.Code != http.StatusOK && len(r.body) < 4096 {
|
|
||||||
|
// 如果有输出过滤器,我们必须先缓冲,不能直接写入网线,否则会导致重复输出
|
||||||
|
if r.server != nil && r.server.hasOutFilter {
|
||||||
|
r.body = append(r.body, bytes...)
|
||||||
|
return len(bytes), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 即使没有过滤器,非 200 状态码也进行缓冲以便日志记录
|
||||||
|
if r.Code != http.StatusOK {
|
||||||
r.body = append(r.body, bytes...)
|
r.body = append(r.body, bytes...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if r.ProxyHeader != nil {
|
||||||
|
r.copyProxyHeader()
|
||||||
|
}
|
||||||
|
return r.Writer.Write(bytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PhysicalWrite 物理写入网线,绕过过滤器缓冲逻辑
|
||||||
|
func (r *Response) PhysicalWrite(bytes []byte) (int, error) {
|
||||||
|
r.checkWriteHeader()
|
||||||
if r.ProxyHeader != nil {
|
if r.ProxyHeader != nil {
|
||||||
r.copyProxyHeader()
|
r.copyProxyHeader()
|
||||||
}
|
}
|
||||||
@ -100,6 +135,20 @@ func (r *Response) GetStatusCode() int {
|
|||||||
return r.Code
|
return r.Code
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetBody 获取响应内容
|
||||||
|
func (r *Response) GetBody() []byte {
|
||||||
|
return r.body
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearBody 清空响应内容缓冲区 (用于过滤器替换内容)
|
||||||
|
func (r *Response) ClearBody() {
|
||||||
|
r.body = nil
|
||||||
|
r.outLen = 0
|
||||||
|
// 注意:这里我们不重置 headerWritten 和 Code,因为 Header 已经发出去了。
|
||||||
|
// 但是在某些测试环境下(如 httptest.Recorder),我们可以尝试“假装”没写过。
|
||||||
|
// 实际上,生产环境下 Header 发出去就收不回来了,所以注入只能发生在 Body 层面。
|
||||||
|
}
|
||||||
|
|
||||||
// DontLog200 标记不记录 200 状态码的日志
|
// DontLog200 标记不记录 200 状态码的日志
|
||||||
func (r *Response) DontLog200() {
|
func (r *Response) DontLog200() {
|
||||||
r.dontLog200 = true
|
r.dontLog200 = true
|
||||||
|
|||||||
36
rewrite.go
36
rewrite.go
@ -42,22 +42,14 @@ func parseRewriteRule(fromPath, toPath, to string) *rewriteType {
|
|||||||
func (hc *HostContext) Rewrite(path string, to string) *HostContext {
|
func (hc *HostContext) Rewrite(path string, to string) *HostContext {
|
||||||
s := parseRewriteRule(path, "", to)
|
s := parseRewriteRule(path, "", to)
|
||||||
|
|
||||||
hostPoliciesLock.Lock()
|
hc.ws.hostPoliciesLock.Lock()
|
||||||
defer hostPoliciesLock.Unlock()
|
defer hc.ws.hostPoliciesLock.Unlock()
|
||||||
codeRewrites[hc.host] = append(codeRewrites[hc.host], s)
|
hc.ws.codeRewrites[hc.host] = append(hc.ws.codeRewrites[hc.host], s)
|
||||||
rebuildRewritesUnderLock(hc.host)
|
hc.ws.rebuildRewritesUnderLock(hc.host)
|
||||||
return hc
|
return hc
|
||||||
}
|
}
|
||||||
|
|
||||||
func rebuildRewritesUnderLock(host string) {
|
func (ws *WebServer) processRewrite(request *Request, response *Response, logger *log.Logger) bool {
|
||||||
var combined []*rewriteType
|
|
||||||
combined = append(combined, codeRewrites[host]...)
|
|
||||||
combined = append(combined, fileRewrites[host]...)
|
|
||||||
combined = append(combined, dynamicRewrites[host]...)
|
|
||||||
hostRewrites[host] = combined
|
|
||||||
}
|
|
||||||
|
|
||||||
func processRewrite(request *Request, response *Response, logger *log.Logger) bool {
|
|
||||||
host := request.Host
|
host := request.Host
|
||||||
hostOnly, port, _ := strings.Cut(host, ":")
|
hostOnly, port, _ := strings.Cut(host, ":")
|
||||||
hosts := []string{host}
|
hosts := []string{host}
|
||||||
@ -66,8 +58,8 @@ func processRewrite(request *Request, response *Response, logger *log.Logger) bo
|
|||||||
}
|
}
|
||||||
hosts = append(hosts, "*")
|
hosts = append(hosts, "*")
|
||||||
|
|
||||||
hostPoliciesLock.RLock()
|
ws.hostPoliciesLock.RLock()
|
||||||
defer hostPoliciesLock.RUnlock()
|
defer ws.hostPoliciesLock.RUnlock()
|
||||||
|
|
||||||
requestPath := request.RequestURI
|
requestPath := request.RequestURI
|
||||||
queryString := ""
|
queryString := ""
|
||||||
@ -77,7 +69,7 @@ func processRewrite(request *Request, response *Response, logger *log.Logger) bo
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, h := range hosts {
|
for _, h := range hosts {
|
||||||
rewrites, exists := hostRewrites[h]
|
rewrites, exists := ws.hostRewrites[h]
|
||||||
if !exists {
|
if !exists {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -144,13 +136,17 @@ type RewriteRule struct {
|
|||||||
|
|
||||||
// ReplaceRewrites 使用 Copy-on-Write 机制原子地替换指定 host 下的动态重写规则。
|
// ReplaceRewrites 使用 Copy-on-Write 机制原子地替换指定 host 下的动态重写规则。
|
||||||
func ReplaceRewrites(host string, rules []RewriteRule) {
|
func ReplaceRewrites(host string, rules []RewriteRule) {
|
||||||
|
DefaultServer.ReplaceRewrites(host, rules)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ws *WebServer) ReplaceRewrites(host string, rules []RewriteRule) {
|
||||||
newRewrites := make([]*rewriteType, 0, len(rules))
|
newRewrites := make([]*rewriteType, 0, len(rules))
|
||||||
for _, r := range rules {
|
for _, r := range rules {
|
||||||
newRewrites = append(newRewrites, parseRewriteRule(r.Path, r.ToPath, r.To))
|
newRewrites = append(newRewrites, parseRewriteRule(r.Path, r.ToPath, r.To))
|
||||||
}
|
}
|
||||||
|
|
||||||
hostPoliciesLock.Lock()
|
ws.hostPoliciesLock.Lock()
|
||||||
defer hostPoliciesLock.Unlock()
|
defer ws.hostPoliciesLock.Unlock()
|
||||||
dynamicRewrites[host] = newRewrites
|
ws.dynamicRewrites[host] = newRewrites
|
||||||
rebuildRewritesUnderLock(host)
|
ws.rebuildRewritesUnderLock(host)
|
||||||
}
|
}
|
||||||
|
|||||||
114
robustness_test.go
Normal file
114
robustness_test.go
Normal file
@ -0,0 +1,114 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
"net/url"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestStaticRobustness(t *testing.T) {
|
||||||
|
tempDir, _ := os.MkdirTemp("", "robustness_test")
|
||||||
|
defer os.RemoveAll(tempDir)
|
||||||
|
|
||||||
|
// 创建复杂的目录结构
|
||||||
|
subDir := filepath.Join(tempDir, "The NPC Awakens", "The Loop", "scene", "M")
|
||||||
|
_ = os.MkdirAll(subDir, 0755)
|
||||||
|
|
||||||
|
fileName := "画面逐渐亮起,铁匠铺的铁锤在无人操作的情况下,机械地敲击着烧红的铁块。_large.webp"
|
||||||
|
testFile := filepath.Join(subDir, fileName)
|
||||||
|
content := []byte("fake webp content")
|
||||||
|
_ = os.WriteFile(testFile, content, 0644)
|
||||||
|
|
||||||
|
// 注册静态目录
|
||||||
|
ws := NewWebServer()
|
||||||
|
ws.Config.App = "test"
|
||||||
|
ws.Static("/img/", tempDir)
|
||||||
|
|
||||||
|
rh := &RouteHandler{ws: ws}
|
||||||
|
|
||||||
|
// 构造编码后的请求路径
|
||||||
|
encodedPath := "/img/" + url.PathEscape("The NPC Awakens/The Loop/scene/M/画面逐渐亮起,铁匠铺的铁锤在无人操作的情况下,机械地敲击着烧红的铁块。_large.webp")
|
||||||
|
|
||||||
|
// 测试静态文件访问
|
||||||
|
req := httptest.NewRequest("GET", encodedPath+"?v=1780317467305", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
rh.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("Expected 200 for complex static file, got %d. Path: %s", w.Code, encodedPath)
|
||||||
|
} else if string(w.Body.Bytes()) != string(content) {
|
||||||
|
t.Errorf("Content mismatch for complex static file")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDynamicRobustness(t *testing.T) {
|
||||||
|
ws := NewWebServer()
|
||||||
|
ws.Config.App = "test"
|
||||||
|
|
||||||
|
pathPattern := "/api/scene/{name}"
|
||||||
|
ws.Host("*").GET(pathPattern, func(in struct{ Name string }) string {
|
||||||
|
return "Hello " + in.Name
|
||||||
|
})
|
||||||
|
|
||||||
|
rh := &RouteHandler{ws: ws}
|
||||||
|
|
||||||
|
complexName := "画面逐渐亮起,铁匠铺的铁锤"
|
||||||
|
encodedPath := "/api/scene/" + url.PathEscape(complexName)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", encodedPath, nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
rh.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("Expected 200 for complex dynamic path, got %d", w.Code)
|
||||||
|
}
|
||||||
|
expectedBody := "Hello " + complexName
|
||||||
|
if w.Body.String() != expectedBody {
|
||||||
|
t.Errorf("Got body: %s, expected: %s", w.Body.String(), expectedBody)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHostMatching(t *testing.T) {
|
||||||
|
ws := NewWebServer()
|
||||||
|
ws.Config.App = "test"
|
||||||
|
|
||||||
|
// 1. 注册只带端口的 Host
|
||||||
|
ws.Host(":8080").GET("/port", func() string { return "port" })
|
||||||
|
// 2. 注册只带域名的 Host
|
||||||
|
ws.Host("localhost").GET("/host", func() string { return "host" })
|
||||||
|
// 3. 注册完整 Host
|
||||||
|
ws.Host("example.com:9000").GET("/full", func() string { return "full" })
|
||||||
|
|
||||||
|
rh := &RouteHandler{ws: ws}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
requestHost string
|
||||||
|
path string
|
||||||
|
expected string
|
||||||
|
code int
|
||||||
|
}{
|
||||||
|
{"localhost:8080", "/port", "port", http.StatusOK},
|
||||||
|
{"otherhost:8080", "/port", "port", http.StatusOK},
|
||||||
|
{"localhost:9999", "/host", "host", http.StatusOK},
|
||||||
|
{"example.com:9000", "/full", "full", http.StatusOK},
|
||||||
|
{"example.com:8080", "/port", "port", http.StatusOK},
|
||||||
|
{"localhost:8080", "/host", "host", http.StatusOK},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
req := httptest.NewRequest("GET", tt.path, nil)
|
||||||
|
req.Host = tt.requestHost
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
rh.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != tt.code {
|
||||||
|
t.Errorf("Host [%s] Path [%s] expected code %d, got %d", tt.requestHost, tt.path, tt.code, w.Code)
|
||||||
|
}
|
||||||
|
if tt.code == http.StatusOK && w.Body.String() != tt.expected {
|
||||||
|
t.Errorf("Host [%s] Path [%s] expected body %s, got %s", tt.requestHost, tt.path, tt.expected, w.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
373
server.go
373
server.go
@ -7,6 +7,7 @@ import (
|
|||||||
"apigo.cc/go/redis"
|
"apigo.cc/go/redis"
|
||||||
"apigo.cc/go/safe"
|
"apigo.cc/go/safe"
|
||||||
"apigo.cc/go/starter"
|
"apigo.cc/go/starter"
|
||||||
|
"apigo.cc/go/watch"
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"golang.org/x/net/http2"
|
"golang.org/x/net/http2"
|
||||||
@ -14,26 +15,274 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"reflect"
|
||||||
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// GlobalDiscoverer 供服务框架内部使用的发现实例
|
type staticType struct {
|
||||||
var GlobalDiscoverer *discover.Discoverer
|
path string
|
||||||
|
rootPath *string
|
||||||
|
}
|
||||||
|
|
||||||
// WebServer 实现了 starter.Service 和 starter.Reloader 接口
|
|
||||||
type WebServer struct {
|
type WebServer struct {
|
||||||
|
Config ServiceConfig
|
||||||
server *http.Server
|
server *http.Server
|
||||||
listener net.Listener
|
listener net.Listener
|
||||||
Addr string
|
Addr string
|
||||||
useDiscover bool
|
useDiscover bool
|
||||||
discoverer *discover.Discoverer
|
discoverer *discover.Discoverer
|
||||||
logger *log.Logger
|
logger *log.Logger
|
||||||
|
|
||||||
|
// 运行时状态
|
||||||
|
serverId string
|
||||||
|
serverAddr string
|
||||||
|
running bool
|
||||||
|
|
||||||
|
// Web 服务注册 (按 Host 隔离)
|
||||||
|
webServices map[string]map[string]*webServiceType
|
||||||
|
regexWebServices map[string][]*webServiceType
|
||||||
|
webServicesLock sync.RWMutex
|
||||||
|
webServicesList []*webServiceType
|
||||||
|
|
||||||
|
websocketServices map[string]map[string]*websocketServiceType
|
||||||
|
websocketServicesLock sync.RWMutex
|
||||||
|
websocketServicesList []*websocketServiceType
|
||||||
|
|
||||||
|
// 路由策略 (按 Host 隔离)
|
||||||
|
hostRewrites map[string][]*rewriteType
|
||||||
|
hostProxies map[string][]*proxyType
|
||||||
|
|
||||||
|
codeProxies map[string][]*proxyType
|
||||||
|
fileProxies map[string][]*proxyType
|
||||||
|
dynamicProxies map[string][]*proxyType
|
||||||
|
|
||||||
|
codeRewrites map[string][]*rewriteType
|
||||||
|
fileRewrites map[string][]*rewriteType
|
||||||
|
dynamicRewrites map[string][]*rewriteType
|
||||||
|
|
||||||
|
hostPoliciesLock sync.RWMutex
|
||||||
|
|
||||||
|
// 静态文件服务
|
||||||
|
statics map[string]*string
|
||||||
|
staticsByHost map[string]map[string]*string
|
||||||
|
codeStatics map[string]map[string]*string
|
||||||
|
fileStatics map[string]map[string]*string
|
||||||
|
dynamicStatics map[string]map[string]*string
|
||||||
|
hostStatics map[string][]*staticType
|
||||||
|
staticsByHostLock sync.RWMutex
|
||||||
|
|
||||||
|
// 过滤器与拦截器
|
||||||
|
inFilters []func(*map[string]any, *Request, *Response, *log.Logger) any
|
||||||
|
outFilters []func(map[string]any, *Request, *Response, any, *log.Logger) (any, bool)
|
||||||
|
errorHandle func(any, *Request, *Response) any
|
||||||
|
webAuthChecker func(int, *log.Logger, *string, map[string]any, *Request, *Response, *WebServiceOptions) (pass bool, object any)
|
||||||
|
webAuthCheckers map[int]func(int, *log.Logger, *string, map[string]any, *Request, *Response, *WebServiceOptions) (pass bool, object any)
|
||||||
|
|
||||||
|
// 注入点
|
||||||
|
injectObjects map[reflect.Type]any
|
||||||
|
injectFunctions map[reflect.Type]func() any
|
||||||
|
|
||||||
|
// 客户端标识
|
||||||
|
usedDeviceIdKey string
|
||||||
|
usedClientAppKey string
|
||||||
|
usedSessionIdKey string
|
||||||
|
sessionIdMaker func() string
|
||||||
|
|
||||||
|
// 停机钩子
|
||||||
|
shutdownHooks []func()
|
||||||
|
shutdownHooksLock sync.Mutex
|
||||||
|
|
||||||
|
// 性能优化:标记是否有输出过滤器
|
||||||
|
hasOutFilter bool
|
||||||
|
|
||||||
|
// Web 开发模式配置
|
||||||
|
webDevEnabled bool
|
||||||
|
webDevConfig watch.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewWebServer 创建并返回一个新的 WebServer 实例
|
// DefaultServer 全局单例服务实例
|
||||||
|
var DefaultServer = NewWebServer()
|
||||||
|
|
||||||
|
// Config 全局配置对象 (指向 DefaultServer.Config)
|
||||||
|
var Config = &DefaultServer.Config
|
||||||
|
|
||||||
func NewWebServer() *WebServer {
|
func NewWebServer() *WebServer {
|
||||||
return &WebServer{}
|
ws := &WebServer{
|
||||||
|
webServices: make(map[string]map[string]*webServiceType),
|
||||||
|
regexWebServices: make(map[string][]*webServiceType),
|
||||||
|
webServicesList: make([]*webServiceType, 0),
|
||||||
|
websocketServices: make(map[string]map[string]*websocketServiceType),
|
||||||
|
websocketServicesList: make([]*websocketServiceType, 0),
|
||||||
|
hostRewrites: make(map[string][]*rewriteType),
|
||||||
|
hostProxies: make(map[string][]*proxyType),
|
||||||
|
codeProxies: make(map[string][]*proxyType),
|
||||||
|
fileProxies: make(map[string][]*proxyType),
|
||||||
|
dynamicProxies: make(map[string][]*proxyType),
|
||||||
|
codeRewrites: make(map[string][]*rewriteType),
|
||||||
|
fileRewrites: make(map[string][]*rewriteType),
|
||||||
|
dynamicRewrites: make(map[string][]*rewriteType),
|
||||||
|
statics: make(map[string]*string),
|
||||||
|
staticsByHost: make(map[string]map[string]*string),
|
||||||
|
codeStatics: make(map[string]map[string]*string),
|
||||||
|
fileStatics: make(map[string]map[string]*string),
|
||||||
|
dynamicStatics: make(map[string]map[string]*string),
|
||||||
|
hostStatics: make(map[string][]*staticType),
|
||||||
|
webAuthCheckers: make(map[int]func(int, *log.Logger, *string, map[string]any, *Request, *Response, *WebServiceOptions) (pass bool, object any)),
|
||||||
|
injectObjects: make(map[reflect.Type]any),
|
||||||
|
injectFunctions: make(map[reflect.Type]func() any),
|
||||||
|
}
|
||||||
|
return ws
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDiscovererForTest 提供给测试用例使用的后门方法,用于模拟断开或重置服务发现
|
||||||
|
func SetDiscovererForTest(d *discover.Discoverer) {
|
||||||
|
DefaultServer.discoverer = d
|
||||||
|
}
|
||||||
|
|
||||||
|
// ApplyConfig 将 ServiceConfig 中的路由策略应用到内部的文件级策略中
|
||||||
|
func (ws *WebServer) ApplyConfig() {
|
||||||
|
ws.hostPoliciesLock.Lock()
|
||||||
|
defer ws.hostPoliciesLock.Unlock()
|
||||||
|
|
||||||
|
// 1. Proxies KV 解析
|
||||||
|
ws.fileProxies = make(map[string][]*proxyType)
|
||||||
|
for host, kv := range ws.Config.Proxies {
|
||||||
|
h := host
|
||||||
|
if h == "*" {
|
||||||
|
h = ""
|
||||||
|
}
|
||||||
|
rules := make([]*proxyType, 0, len(kv))
|
||||||
|
for path, val := range kv {
|
||||||
|
if to, ok := val.(string); ok {
|
||||||
|
rules = append(rules, parseProxyRule(0, path, "", "", to))
|
||||||
|
} else {
|
||||||
|
// 对象模式
|
||||||
|
m := make(map[string]any)
|
||||||
|
if tm, ok := val.(map[string]any); ok {
|
||||||
|
m = tm
|
||||||
|
}
|
||||||
|
rules = append(rules, parseProxyRule(
|
||||||
|
int(reflect.ValueOf(m["Auth"]).Int()), // Simplified
|
||||||
|
path,
|
||||||
|
fmt.Sprint(m["ToApp"]),
|
||||||
|
fmt.Sprint(m["ToPath"]),
|
||||||
|
fmt.Sprint(m["To"]),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ws.fileProxies[h] = rules
|
||||||
|
ws.rebuildProxiesUnderLock(h)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Rewrites KV 解析
|
||||||
|
ws.fileRewrites = make(map[string][]*rewriteType)
|
||||||
|
for host, kv := range ws.Config.Rewrites {
|
||||||
|
h := host
|
||||||
|
if h == "*" {
|
||||||
|
h = ""
|
||||||
|
}
|
||||||
|
rules := make([]*rewriteType, 0, len(kv))
|
||||||
|
for path, val := range kv {
|
||||||
|
if to, ok := val.(string); ok {
|
||||||
|
rules = append(rules, parseRewriteRule(path, "", to))
|
||||||
|
} else {
|
||||||
|
m := make(map[string]any)
|
||||||
|
if tm, ok := val.(map[string]any); ok {
|
||||||
|
m = tm
|
||||||
|
}
|
||||||
|
rules = append(rules, parseRewriteRule(
|
||||||
|
path,
|
||||||
|
fmt.Sprint(m["ToPath"]),
|
||||||
|
fmt.Sprint(m["To"]),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ws.fileRewrites[h] = rules
|
||||||
|
ws.rebuildRewritesUnderLock(h)
|
||||||
|
}
|
||||||
|
|
||||||
|
ws.staticsByHostLock.Lock()
|
||||||
|
defer ws.staticsByHostLock.Unlock()
|
||||||
|
ws.fileStatics = make(map[string]map[string]*string)
|
||||||
|
|
||||||
|
for host, config := range ws.Config.Statics {
|
||||||
|
h := host
|
||||||
|
if h == "*" {
|
||||||
|
h = ""
|
||||||
|
}
|
||||||
|
newStatics := make(map[string]*string, len(config))
|
||||||
|
for path, rootPath := range config {
|
||||||
|
rp := rootPath
|
||||||
|
if !filepath.IsAbs(rp) {
|
||||||
|
if absPath, err := filepath.Abs(rp); err == nil {
|
||||||
|
rp = absPath
|
||||||
|
}
|
||||||
|
}
|
||||||
|
newStatics[path] = &rp
|
||||||
|
}
|
||||||
|
ws.fileStatics[h] = newStatics
|
||||||
|
ws.rebuildStaticsUnderLock(h)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 始终重新构建默认 Host 的静态路由,以合并代码定义的路由
|
||||||
|
ws.rebuildStaticsUnderLock("")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ws *WebServer) rebuildProxiesUnderLock(host string) {
|
||||||
|
combined := make([]*proxyType, 0)
|
||||||
|
combined = append(combined, ws.codeProxies[host]...)
|
||||||
|
combined = append(combined, ws.fileProxies[host]...)
|
||||||
|
combined = append(combined, ws.dynamicProxies[host]...)
|
||||||
|
|
||||||
|
sort.Slice(combined, func(i, j int) bool {
|
||||||
|
return len(combined[i].fromPath) > len(combined[j].fromPath)
|
||||||
|
})
|
||||||
|
ws.hostProxies[host] = combined
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ws *WebServer) rebuildRewritesUnderLock(host string) {
|
||||||
|
combined := make([]*rewriteType, 0)
|
||||||
|
combined = append(combined, ws.codeRewrites[host]...)
|
||||||
|
combined = append(combined, ws.fileRewrites[host]...)
|
||||||
|
combined = append(combined, ws.dynamicRewrites[host]...)
|
||||||
|
|
||||||
|
sort.Slice(combined, func(i, j int) bool {
|
||||||
|
return len(combined[i].fromPath) > len(combined[j].fromPath)
|
||||||
|
})
|
||||||
|
ws.hostRewrites[host] = combined
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ws *WebServer) rebuildStaticsUnderLock(host string) {
|
||||||
|
combined := make(map[string]*string)
|
||||||
|
for k, v := range ws.codeStatics[host] {
|
||||||
|
combined[k] = v
|
||||||
|
}
|
||||||
|
for k, v := range ws.fileStatics[host] {
|
||||||
|
combined[k] = v
|
||||||
|
}
|
||||||
|
for k, v := range ws.dynamicStatics[host] {
|
||||||
|
combined[k] = v
|
||||||
|
}
|
||||||
|
|
||||||
|
if host == "" {
|
||||||
|
ws.statics = combined
|
||||||
|
} else {
|
||||||
|
ws.staticsByHost[host] = combined
|
||||||
|
}
|
||||||
|
|
||||||
|
// 构造有序的静态路由列表 (按路径长度降序排列,实现最长匹配)
|
||||||
|
sorted := make([]*staticType, 0, len(combined))
|
||||||
|
for k, v := range combined {
|
||||||
|
sorted = append(sorted, &staticType{path: k, rootPath: v})
|
||||||
|
}
|
||||||
|
sort.Slice(sorted, func(i, j int) bool {
|
||||||
|
return len(sorted[i].path) > len(sorted[j].path)
|
||||||
|
})
|
||||||
|
ws.hostStatics[host] = sorted
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start 启动服务,实现 starter.Service 接口
|
// Start 启动服务,实现 starter.Service 接口
|
||||||
@ -44,12 +293,12 @@ func (ws *WebServer) Start(ctx context.Context, logger *log.Logger) error {
|
|||||||
ws.logger = logger
|
ws.logger = logger
|
||||||
|
|
||||||
// 初始加载配置
|
// 初始加载配置
|
||||||
if err := config.Load(&Config, "service"); err != nil {
|
if err := config.Load(&ws.Config, "service"); err != nil {
|
||||||
logger.Error("failed to load config during start", "error", err.Error())
|
logger.Error("failed to load config during start", "error", err.Error())
|
||||||
}
|
}
|
||||||
ApplyConfig()
|
ws.ApplyConfig()
|
||||||
|
|
||||||
listenStr := Config.Listen
|
listenStr := ws.Config.Listen
|
||||||
ws.useDiscover = false
|
ws.useDiscover = false
|
||||||
|
|
||||||
if listenStr == "" {
|
if listenStr == "" {
|
||||||
@ -57,7 +306,6 @@ func (ws *WebServer) Start(ctx context.Context, logger *log.Logger) error {
|
|||||||
ws.useDiscover = true
|
ws.useDiscover = true
|
||||||
}
|
}
|
||||||
|
|
||||||
// 解析第一个监听配置
|
|
||||||
part := strings.Split(listenStr, "|")[0]
|
part := strings.Split(listenStr, "|")[0]
|
||||||
addr, opts, _ := strings.Cut(part, ",")
|
addr, opts, _ := strings.Cut(part, ",")
|
||||||
|
|
||||||
@ -70,29 +318,30 @@ func (ws *WebServer) Start(ctx context.Context, logger *log.Logger) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if protocol == "" {
|
if protocol == "" {
|
||||||
protocol = "http" // Default to http
|
protocol = "http"
|
||||||
}
|
}
|
||||||
|
|
||||||
if !strings.Contains(addr, ":") {
|
if !strings.Contains(addr, ":") {
|
||||||
addr = ":" + addr
|
addr = ":" + addr
|
||||||
}
|
}
|
||||||
|
|
||||||
// 检查是否需要启动服务发现
|
if ws.webDevEnabled {
|
||||||
appName := Config.App
|
ws.initWebDev(logger)
|
||||||
|
}
|
||||||
|
|
||||||
|
appName := ws.Config.App
|
||||||
if appName == "" {
|
if appName == "" {
|
||||||
appName = GetDefaultName()
|
appName = GetDefaultName()
|
||||||
Config.App = appName
|
ws.Config.App = appName
|
||||||
}
|
}
|
||||||
if appName != "" || Config.Register != "" {
|
if appName != "" || ws.Config.Register != "" {
|
||||||
ws.useDiscover = true
|
ws.useDiscover = true
|
||||||
}
|
}
|
||||||
|
|
||||||
// 初始化服务器唯一标识 (8位,物理上限 3,844/s)
|
ws.serverId = IDMaker.Get8Bytes4KPerSecond()
|
||||||
serverId = IDMaker.Get8Bytes4KPerSecond()
|
|
||||||
|
|
||||||
// 初始化分布式 ID 生成器
|
if ws.Config.IdServer != "" {
|
||||||
if Config.IdServer != "" {
|
rd := redis.GetRedis(ws.Config.IdServer, log.New(ws.serverId))
|
||||||
rd := redis.GetRedis(Config.IdServer, log.New(serverId))
|
|
||||||
if rd.Error == nil {
|
if rd.Error == nil {
|
||||||
IDMaker = redis.NewIDMaker(rd)
|
IDMaker = redis.NewIDMaker(rd)
|
||||||
}
|
}
|
||||||
@ -105,38 +354,35 @@ func (ws *WebServer) Start(ctx context.Context, logger *log.Logger) error {
|
|||||||
|
|
||||||
ws.listener = listener
|
ws.listener = listener
|
||||||
ws.Addr = listener.Addr().String()
|
ws.Addr = listener.Addr().String()
|
||||||
serverAddr = ws.Addr
|
ws.serverAddr = ws.Addr
|
||||||
|
|
||||||
// 如果使用了随机端口且没有明确指定不需要服务发现,则开启
|
|
||||||
if addr == ":0" || strings.HasSuffix(addr, ":0") {
|
if addr == ":0" || strings.HasSuffix(addr, ":0") {
|
||||||
ws.useDiscover = true
|
ws.useDiscover = true
|
||||||
}
|
}
|
||||||
|
|
||||||
h2s := &http2.Server{}
|
h2s := &http2.Server{}
|
||||||
var handler http.Handler = &RouteHandler{}
|
var handler http.Handler = &RouteHandler{ws: ws}
|
||||||
if protocol == "h2c" {
|
if protocol == "h2c" {
|
||||||
handler = h2c.NewHandler(handler, h2s)
|
handler = h2c.NewHandler(handler, h2s)
|
||||||
}
|
}
|
||||||
|
|
||||||
ws.server = &http.Server{
|
ws.server = &http.Server{
|
||||||
Handler: handler,
|
Handler: handler,
|
||||||
ReadTimeout: time.Duration(Config.ReadTimeout) * time.Millisecond,
|
ReadTimeout: time.Duration(ws.Config.ReadTimeout) * time.Millisecond,
|
||||||
ReadHeaderTimeout: time.Duration(Config.ReadHeaderTimeout) * time.Millisecond,
|
ReadHeaderTimeout: time.Duration(ws.Config.ReadHeaderTimeout) * time.Millisecond,
|
||||||
WriteTimeout: time.Duration(Config.WriteTimeout) * time.Millisecond,
|
WriteTimeout: time.Duration(ws.Config.WriteTimeout) * time.Millisecond,
|
||||||
IdleTimeout: time.Duration(Config.IdleTimeout) * time.Millisecond,
|
IdleTimeout: time.Duration(ws.Config.IdleTimeout) * time.Millisecond,
|
||||||
MaxHeaderBytes: Config.MaxHeaderBytes,
|
MaxHeaderBytes: ws.Config.MaxHeaderBytes,
|
||||||
}
|
}
|
||||||
|
|
||||||
// 启动服务发现
|
|
||||||
if ws.useDiscover {
|
if ws.useDiscover {
|
||||||
_, port, _ := net.SplitHostPort(ws.Addr)
|
_, port, _ := net.SplitHostPort(ws.Addr)
|
||||||
ip := GetServerIp()
|
ip := GetServerIp()
|
||||||
discoverAddr := fmt.Sprintf("%s:%s", ip, port)
|
discoverAddr := fmt.Sprintf("%s:%s", ip, port)
|
||||||
|
|
||||||
// 转换配置
|
|
||||||
discConf := discover.Config{
|
discConf := discover.Config{
|
||||||
Weight: Config.Weight,
|
Weight: ws.Config.Weight,
|
||||||
CallRetryTimes: 10, // Default
|
CallRetryTimes: 10,
|
||||||
Calls: make(map[string]discover.CallConfig),
|
Calls: make(map[string]discover.CallConfig),
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -144,15 +390,15 @@ func (ws *WebServer) Start(ctx context.Context, logger *log.Logger) error {
|
|||||||
discConf.Weight = 100
|
discConf.Weight = 100
|
||||||
}
|
}
|
||||||
|
|
||||||
for name, call := range Config.Calls {
|
for name, call := range ws.Config.Calls {
|
||||||
dc := discover.CallConfig{
|
dc := discover.CallConfig{
|
||||||
Http2: call.Http2,
|
Http2: call.Http2,
|
||||||
SSL: call.SSL,
|
SSL: call.SSL,
|
||||||
}
|
}
|
||||||
if call.Timeout > 0 {
|
if call.Timeout > 0 {
|
||||||
dc.Timeout = time.Duration(call.Timeout) * time.Millisecond
|
dc.Timeout = time.Duration(call.Timeout) * time.Millisecond
|
||||||
} else if Config.RedirectTimeout > 0 {
|
} else if ws.Config.RedirectTimeout > 0 {
|
||||||
dc.Timeout = time.Duration(Config.RedirectTimeout) * time.Millisecond
|
dc.Timeout = time.Duration(ws.Config.RedirectTimeout) * time.Millisecond
|
||||||
}
|
}
|
||||||
if call.Token != "" {
|
if call.Token != "" {
|
||||||
dc.Token = safe.NewSafeBuf([]byte(call.Token))
|
dc.Token = safe.NewSafeBuf([]byte(call.Token))
|
||||||
@ -160,33 +406,30 @@ func (ws *WebServer) Start(ctx context.Context, logger *log.Logger) error {
|
|||||||
discConf.Calls[name] = dc
|
discConf.Calls[name] = dc
|
||||||
}
|
}
|
||||||
|
|
||||||
// 解析必需的 Register,支持环境变量 fallback
|
registry := ws.Config.Register
|
||||||
registry := Config.Register
|
|
||||||
if registry == "" {
|
if registry == "" {
|
||||||
registry = os.Getenv("DISCOVER_REGISTRY")
|
registry = os.Getenv("DISCOVER_REGISTRY")
|
||||||
}
|
}
|
||||||
if registry == "" {
|
|
||||||
registry = "127.0.0.1:6379::15" // Default fallback
|
|
||||||
}
|
|
||||||
|
|
||||||
|
if registry != "" {
|
||||||
ws.discoverer = discover.Start(registry, appName, discoverAddr, logger, discConf)
|
ws.discoverer = discover.Start(registry, appName, discoverAddr, logger, discConf)
|
||||||
GlobalDiscoverer = ws.discoverer
|
|
||||||
if ws.discoverer != nil {
|
if ws.discoverer != nil {
|
||||||
logger.Info("discover registered", "app", appName, "addr", discoverAddr)
|
logger.Info("discover registered", "app", appName, "addr", discoverAddr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
errChan := make(chan error, 1)
|
errChan := make(chan error, 1)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
logger.Info("service starting", "addr", ws.Addr, "proto", protocol)
|
logger.Info("starting listener", "addr", ws.Addr, "proto", protocol)
|
||||||
if err := ws.server.Serve(listener); err != nil && err != http.ErrServerClosed {
|
ws.running = true
|
||||||
|
if err := ws.server.Serve(ws.listener); err != nil && err != http.ErrServerClosed {
|
||||||
errChan <- err
|
errChan <- err
|
||||||
}
|
}
|
||||||
close(errChan)
|
close(errChan)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// 短暂等待验证是否闪退
|
|
||||||
select {
|
select {
|
||||||
case err := <-errChan:
|
case err := <-errChan:
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -200,11 +443,16 @@ func (ws *WebServer) Start(ctx context.Context, logger *log.Logger) error {
|
|||||||
|
|
||||||
// Stop 停止服务,实现 starter.Service 接口
|
// Stop 停止服务,实现 starter.Service 接口
|
||||||
func (ws *WebServer) Stop(ctx context.Context) error {
|
func (ws *WebServer) Stop(ctx context.Context) error {
|
||||||
logger := ws.logger
|
ws.running = false
|
||||||
if logger == nil {
|
|
||||||
logger = log.DefaultLogger
|
// 执行停机钩子 (反序)
|
||||||
|
ws.shutdownHooksLock.Lock()
|
||||||
|
for i := len(ws.shutdownHooks) - 1; i >= 0; i-- {
|
||||||
|
ws.shutdownHooks[i]()
|
||||||
}
|
}
|
||||||
logger.Info("service stopping")
|
ws.shutdownHooks = nil
|
||||||
|
ws.shutdownHooksLock.Unlock()
|
||||||
|
|
||||||
if ws.discoverer != nil {
|
if ws.discoverer != nil {
|
||||||
ws.discoverer.Stop()
|
ws.discoverer.Stop()
|
||||||
}
|
}
|
||||||
@ -213,13 +461,12 @@ func (ws *WebServer) Stop(ctx context.Context) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
logger.Info("service stopped")
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Status 检查服务健康状态,实现 starter.Service 接口
|
// Status 检查服务健康状态,实现 starter.Service 接口
|
||||||
func (ws *WebServer) Status() (string, error) {
|
func (ws *WebServer) Status() (string, error) {
|
||||||
if ws.server == nil {
|
if ws.server == nil || !ws.running {
|
||||||
return "", fmt.Errorf("server is not running")
|
return "", fmt.Errorf("server is not running")
|
||||||
}
|
}
|
||||||
return ws.Addr, nil
|
return ws.Addr, nil
|
||||||
@ -233,14 +480,12 @@ func (ws *WebServer) Reload() error {
|
|||||||
}
|
}
|
||||||
logger.Info("reloading configurations...")
|
logger.Info("reloading configurations...")
|
||||||
|
|
||||||
// 重新加载配置文件中的策略
|
if err := config.Load(&ws.Config, "service"); err != nil {
|
||||||
if err := config.Load(&Config, "service"); err != nil {
|
|
||||||
logger.Error("failed to load config during reload", "error", err.Error())
|
logger.Error("failed to load config during reload", "error", err.Error())
|
||||||
}
|
}
|
||||||
ApplyConfig()
|
ws.ApplyConfig()
|
||||||
|
|
||||||
// 触发业务挂载的 Hook
|
return ws.triggerReload()
|
||||||
return triggerReload()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// AsyncServer 兼容旧版异步服务实例
|
// AsyncServer 兼容旧版异步服务实例
|
||||||
@ -250,7 +495,7 @@ type AsyncServer struct {
|
|||||||
|
|
||||||
// Stop 兼容旧版的无参数停止方法
|
// Stop 兼容旧版的无参数停止方法
|
||||||
func (as *AsyncServer) Stop() {
|
func (as *AsyncServer) Stop() {
|
||||||
stopTimeout := time.Duration(Config.StopTimeout) * time.Millisecond
|
stopTimeout := time.Duration(as.Config.StopTimeout) * time.Millisecond
|
||||||
if stopTimeout <= 0 {
|
if stopTimeout <= 0 {
|
||||||
stopTimeout = 5 * time.Second
|
stopTimeout = 5 * time.Second
|
||||||
}
|
}
|
||||||
@ -261,9 +506,8 @@ func (as *AsyncServer) Stop() {
|
|||||||
|
|
||||||
// AsyncStart 兼容旧版的异步启动方法
|
// AsyncStart 兼容旧版的异步启动方法
|
||||||
func AsyncStart() *AsyncServer {
|
func AsyncStart() *AsyncServer {
|
||||||
ws := NewWebServer()
|
_ = DefaultServer.Start(context.Background(), log.DefaultLogger)
|
||||||
_ = ws.Start(context.Background(), log.DefaultLogger)
|
return &AsyncServer{WebServer: DefaultServer}
|
||||||
return &AsyncServer{WebServer: ws}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wait 等待服务结束 (兼容旧版,直接阻塞)
|
// Wait 等待服务结束 (兼容旧版,直接阻塞)
|
||||||
@ -271,12 +515,21 @@ func (as *AsyncServer) Wait() {
|
|||||||
select {}
|
select {}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var startOnce sync.Once
|
||||||
|
|
||||||
// Start 兼容旧版的同步启动方法 (通过内部注册 starter 实现)
|
// Start 兼容旧版的同步启动方法 (通过内部注册 starter 实现)
|
||||||
func Start() {
|
func Start() {
|
||||||
|
if DefaultServer.running {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
startOnce.Do(func() {
|
||||||
stopTimeout := time.Duration(Config.StopTimeout) * time.Millisecond
|
stopTimeout := time.Duration(Config.StopTimeout) * time.Millisecond
|
||||||
if stopTimeout <= 0 {
|
if stopTimeout <= 0 {
|
||||||
stopTimeout = 5 * time.Second
|
stopTimeout = 5 * time.Second
|
||||||
}
|
}
|
||||||
starter.Register("web-server", NewWebServer(), 100, 5*time.Second, stopTimeout)
|
starter.Register("web-server", DefaultServer, 100, 5*time.Second, stopTimeout)
|
||||||
starter.Run()
|
if err := starter.Start(); err == nil {
|
||||||
|
starter.Wait()
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
320
service.go
320
service.go
@ -2,7 +2,9 @@ package service
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"apigo.cc/go/log"
|
"apigo.cc/go/log"
|
||||||
|
"apigo.cc/go/watch"
|
||||||
"errors"
|
"errors"
|
||||||
|
"math"
|
||||||
"reflect"
|
"reflect"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
@ -54,121 +56,130 @@ type websocketServiceType struct {
|
|||||||
options WebServiceOptions
|
options WebServiceOptions
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
|
||||||
serverId string
|
|
||||||
serverAddr string
|
|
||||||
serverProto = "http"
|
|
||||||
serverProtoName = "http"
|
|
||||||
running = false
|
|
||||||
|
|
||||||
// webServices 按 Host 隔离: map[host]map[method+path]*webServiceType
|
|
||||||
webServices = make(map[string]map[string]*webServiceType)
|
|
||||||
// regexWebServices 按 Host 隔离: map[host][]*webServiceType
|
|
||||||
regexWebServices = make(map[string][]*webServiceType)
|
|
||||||
webServicesLock = sync.RWMutex{}
|
|
||||||
webServicesList = make([]*webServiceType, 0)
|
|
||||||
|
|
||||||
websocketServices = make(map[string]map[string]*websocketServiceType)
|
|
||||||
websocketServicesLock = sync.RWMutex{}
|
|
||||||
websocketServicesList = make([]*websocketServiceType, 0)
|
|
||||||
|
|
||||||
// Rewrite 与 Proxy 按 Host 隔离 (编译后的最终路由)
|
|
||||||
hostRewrites = make(map[string][]*rewriteType)
|
|
||||||
hostProxies = make(map[string][]*proxyType)
|
|
||||||
|
|
||||||
// 按来源隔离的策略,避免互相覆盖
|
|
||||||
codeProxies = make(map[string][]*proxyType)
|
|
||||||
fileProxies = make(map[string][]*proxyType)
|
|
||||||
dynamicProxies = make(map[string][]*proxyType)
|
|
||||||
|
|
||||||
codeRewrites = make(map[string][]*rewriteType)
|
|
||||||
fileRewrites = make(map[string][]*rewriteType)
|
|
||||||
dynamicRewrites = make(map[string][]*rewriteType)
|
|
||||||
|
|
||||||
hostPoliciesLock = sync.RWMutex{}
|
|
||||||
|
|
||||||
// 过滤器与拦截器
|
|
||||||
inFilters = make([]func(*map[string]any, *Request, *Response, *log.Logger) any, 0)
|
|
||||||
outFilters = make([]func(map[string]any, *Request, *Response, any, *log.Logger) (any, bool), 0)
|
|
||||||
errorHandle func(any, *Request, *Response) any
|
|
||||||
webAuthChecker func(int, *log.Logger, *string, map[string]any, *Request, *Response, *WebServiceOptions) (pass bool, object any)
|
|
||||||
webAuthCheckers = make(map[int]func(int, *log.Logger, *string, map[string]any, *Request, *Response, *WebServiceOptions) (pass bool, object any))
|
|
||||||
|
|
||||||
// 注入点
|
|
||||||
injectObjects = make(map[reflect.Type]any)
|
|
||||||
injectFunctions = make(map[reflect.Type]func() any)
|
|
||||||
|
|
||||||
usedDeviceIdKey string
|
|
||||||
usedClientAppKey string
|
|
||||||
usedSessionIdKey string
|
|
||||||
sessionIdMaker func() string
|
|
||||||
)
|
|
||||||
|
|
||||||
// SetClientKeys 设置客户端标识相关的 Key 映射
|
// SetClientKeys 设置客户端标识相关的 Key 映射
|
||||||
func SetClientKeys(deviceIdKey, clientAppKey, sessionIdKey string) {
|
func SetClientKeys(deviceIdKey, clientAppKey, sessionIdKey string) {
|
||||||
usedDeviceIdKey = deviceIdKey
|
DefaultServer.SetClientKeys(deviceIdKey, clientAppKey, sessionIdKey)
|
||||||
usedClientAppKey = clientAppKey
|
}
|
||||||
usedSessionIdKey = sessionIdKey
|
|
||||||
|
func (ws *WebServer) SetClientKeys(deviceIdKey, clientAppKey, sessionIdKey string) {
|
||||||
|
ws.usedDeviceIdKey = deviceIdKey
|
||||||
|
ws.usedClientAppKey = clientAppKey
|
||||||
|
ws.usedSessionIdKey = sessionIdKey
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetSessionIdMaker 设置自定义会话 ID 生成器
|
// SetSessionIdMaker 设置自定义会话 ID 生成器
|
||||||
func SetSessionIdMaker(maker func() string) {
|
func SetSessionIdMaker(maker func() string) {
|
||||||
sessionIdMaker = maker
|
DefaultServer.SetSessionIdMaker(maker)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ws *WebServer) SetSessionIdMaker(maker func() string) {
|
||||||
|
ws.sessionIdMaker = maker
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetAuthChecker 设置全局鉴权器
|
// SetAuthChecker 设置全局鉴权器
|
||||||
func SetAuthChecker(authChecker func(authLevel int, logger *log.Logger, url *string, in map[string]any, request *Request, response *Response, options *WebServiceOptions) (pass bool, object any)) {
|
func SetAuthChecker(authChecker func(authLevel int, logger *log.Logger, url *string, in map[string]any, request *Request, response *Response, options *WebServiceOptions) (pass bool, object any)) {
|
||||||
webAuthChecker = authChecker
|
DefaultServer.SetAuthChecker(authChecker)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ws *WebServer) SetAuthChecker(authChecker func(authLevel int, logger *log.Logger, url *string, in map[string]any, request *Request, response *Response, options *WebServiceOptions) (pass bool, object any)) {
|
||||||
|
ws.webAuthChecker = authChecker
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddAuthChecker 为指定级别添加鉴权器
|
// AddAuthChecker 为指定级别添加鉴权器
|
||||||
func AddAuthChecker(authLevels []int, authChecker func(authLevel int, logger *log.Logger, url *string, in map[string]any, request *Request, response *Response, options *WebServiceOptions) (pass bool, object any)) {
|
func AddAuthChecker(authLevels []int, authChecker func(authLevel int, logger *log.Logger, url *string, in map[string]any, request *Request, response *Response, options *WebServiceOptions) (pass bool, object any)) {
|
||||||
|
DefaultServer.AddAuthChecker(authLevels, authChecker)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ws *WebServer) AddAuthChecker(authLevels []int, authChecker func(authLevel int, logger *log.Logger, url *string, in map[string]any, request *Request, response *Response, options *WebServiceOptions) (pass bool, object any)) {
|
||||||
for _, al := range authLevels {
|
for _, al := range authLevels {
|
||||||
webAuthCheckers[al] = authChecker
|
ws.webAuthCheckers[al] = authChecker
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetInFilter 设置前置过滤器
|
// SetInFilter 设置前置过滤器
|
||||||
func SetInFilter(filter func(in *map[string]any, request *Request, response *Response, logger *log.Logger) (out any)) {
|
func SetInFilter(filter func(in *map[string]any, request *Request, response *Response, logger *log.Logger) (out any)) {
|
||||||
inFilters = append(inFilters, filter)
|
DefaultServer.SetInFilter(filter)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ws *WebServer) SetInFilter(filter func(in *map[string]any, request *Request, response *Response, logger *log.Logger) (out any)) {
|
||||||
|
ws.inFilters = append(ws.inFilters, filter)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddShutdownHook 增加停机钩子
|
||||||
|
func AddShutdownHook(hook func()) {
|
||||||
|
DefaultServer.AddShutdownHook(hook)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ws *WebServer) AddShutdownHook(hook func()) {
|
||||||
|
ws.shutdownHooksLock.Lock()
|
||||||
|
defer ws.shutdownHooksLock.Unlock()
|
||||||
|
ws.shutdownHooks = append(ws.shutdownHooks, hook)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetOutFilter 设置后置过滤器
|
// SetOutFilter 设置后置过滤器
|
||||||
func SetOutFilter(filter func(in map[string]any, request *Request, response *Response, out any, logger *log.Logger) (newOut any, isOver bool)) {
|
func SetOutFilter(filter func(in map[string]any, request *Request, response *Response, out any, logger *log.Logger) (newOut any, isOver bool)) {
|
||||||
outFilters = append(outFilters, filter)
|
DefaultServer.SetOutFilter(filter)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ws *WebServer) SetOutFilter(filter func(in map[string]any, request *Request, response *Response, out any, logger *log.Logger) (newOut any, isOver bool)) {
|
||||||
|
ws.webServicesLock.Lock()
|
||||||
|
defer ws.webServicesLock.Unlock()
|
||||||
|
ws.outFilters = append(ws.outFilters, filter)
|
||||||
|
ws.hasOutFilter = true
|
||||||
}
|
}
|
||||||
|
|
||||||
// HostContext 提供流式服务注册能力
|
// HostContext 提供流式服务注册能力
|
||||||
type HostContext struct {
|
type HostContext struct {
|
||||||
|
ws *WebServer
|
||||||
host string
|
host string
|
||||||
}
|
}
|
||||||
|
|
||||||
// Host 指定服务运行的 Host (支持 "example.com", ":8080", "example.com:8080", "*")
|
// Host 指定服务运行的 Host (支持 "example.com", ":8080", "example.com:8080", "*")
|
||||||
func Host(host string) *HostContext {
|
func Host(host string) *HostContext {
|
||||||
|
return DefaultServer.Host(host)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ws *WebServer) Host(host string) *HostContext {
|
||||||
if host == "" {
|
if host == "" {
|
||||||
host = "*"
|
host = "*"
|
||||||
}
|
}
|
||||||
return &HostContext{host: host}
|
return &HostContext{ws: ws, host: host}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Register 注册一个 Web 服务 (使用默认 Host "*")
|
// Register 注册一个 Web 服务 (使用默认 Host "*")
|
||||||
func Register(method, path string, serviceFunc any) *webServiceType {
|
func Register(method, path string, serviceFunc any) *webServiceType {
|
||||||
return Host("*").Register(method, path, serviceFunc)
|
return DefaultServer.Register(method, path, serviceFunc)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ws *WebServer) Register(method, path string, serviceFunc any) *webServiceType {
|
||||||
|
return ws.Host("*").Register(method, path, serviceFunc)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RegisterWebsocket 注册一个 WebSocket 服务 (使用默认 Host "*")
|
// RegisterWebsocket 注册一个 WebSocket 服务 (使用默认 Host "*")
|
||||||
func RegisterWebsocket(path string, serviceFunc any) *websocketServiceType {
|
func RegisterWebsocket(path string, serviceFunc any) *websocketServiceType {
|
||||||
return Host("*").WebSocket(path, serviceFunc)
|
return DefaultServer.RegisterWebsocket(path, serviceFunc)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ws *WebServer) RegisterWebsocket(path string, serviceFunc any) *websocketServiceType {
|
||||||
|
return ws.Host("*").WebSocket(path, serviceFunc)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Proxy 注册一个代理转发 (使用默认 Host "*")
|
// Proxy 注册一个代理转发 (使用默认 Host "*")
|
||||||
func Proxy(authLevel int, path string, to string) {
|
func Proxy(authLevel int, path string, to string) {
|
||||||
Host("*").Proxy(authLevel, path, to)
|
DefaultServer.Proxy(authLevel, path, to)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ws *WebServer) Proxy(authLevel int, path string, to string) {
|
||||||
|
ws.Host("*").Proxy(authLevel, path, to)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Restful 注册一个符合 RESTful 规范的服务结构体 (使用默认 Host "*")
|
// Restful 注册一个符合 RESTful 规范的服务结构体 (使用默认 Host "*")
|
||||||
func Restful(authLevel int, path string, serviceStruct any) {
|
func Restful(authLevel int, path string, serviceStruct any) {
|
||||||
Host("*").Restful(authLevel, path, serviceStruct)
|
DefaultServer.Restful(authLevel, path, serviceStruct)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ws *WebServer) Restful(authLevel int, path string, serviceStruct any) {
|
||||||
|
ws.Host("*").Restful(authLevel, path, serviceStruct)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hc *HostContext) Register(method, path string, serviceFunc any) *webServiceType {
|
func (hc *HostContext) Register(method, path string, serviceFunc any) *webServiceType {
|
||||||
@ -188,25 +199,27 @@ func (hc *HostContext) Register(method, path string, serviceFunc any) *webServic
|
|||||||
finds := finder.FindAllStringSubmatch(path, 20)
|
finds := finder.FindAllStringSubmatch(path, 20)
|
||||||
for _, found := range finds {
|
for _, found := range finds {
|
||||||
keyName = strings.Replace(keyName, regexp.QuoteMeta(found[0]), "(.*?)", 1)
|
keyName = strings.Replace(keyName, regexp.QuoteMeta(found[0]), "(.*?)", 1)
|
||||||
|
hc.ws.webServicesLock.Lock() // Fixed: use lock from ws
|
||||||
s.pathArgs = append(s.pathArgs, found[1])
|
s.pathArgs = append(s.pathArgs, found[1])
|
||||||
|
hc.ws.webServicesLock.Unlock()
|
||||||
}
|
}
|
||||||
if len(s.pathArgs) > 0 {
|
if len(s.pathArgs) > 0 {
|
||||||
s.pathMatcher, _ = regexp.Compile("^" + keyName + "$")
|
s.pathMatcher, _ = regexp.Compile("^" + keyName + "$")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
webServicesLock.Lock()
|
hc.ws.webServicesLock.Lock()
|
||||||
defer webServicesLock.Unlock()
|
defer hc.ws.webServicesLock.Unlock()
|
||||||
|
|
||||||
if s.pathMatcher == nil {
|
if s.pathMatcher == nil {
|
||||||
if webServices[s.host] == nil {
|
if hc.ws.webServices[s.host] == nil {
|
||||||
webServices[s.host] = make(map[string]*webServiceType)
|
hc.ws.webServices[s.host] = make(map[string]*webServiceType)
|
||||||
}
|
}
|
||||||
webServices[s.host][s.method+s.path] = s
|
hc.ws.webServices[s.host][s.method+s.path] = s
|
||||||
} else {
|
} else {
|
||||||
regexWebServices[s.host] = append(regexWebServices[s.host], s)
|
hc.ws.regexWebServices[s.host] = append(hc.ws.regexWebServices[s.host], s)
|
||||||
}
|
}
|
||||||
webServicesList = append(webServicesList, s)
|
hc.ws.webServicesList = append(hc.ws.webServicesList, s)
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -301,13 +314,13 @@ func (hc *HostContext) WebSocket(path string, serviceFunc any) *websocketService
|
|||||||
funcValue: reflect.ValueOf(serviceFunc),
|
funcValue: reflect.ValueOf(serviceFunc),
|
||||||
}
|
}
|
||||||
|
|
||||||
websocketServicesLock.Lock()
|
hc.ws.websocketServicesLock.Lock()
|
||||||
defer websocketServicesLock.Unlock()
|
defer hc.ws.websocketServicesLock.Unlock()
|
||||||
if websocketServices[hc.host] == nil {
|
if hc.ws.websocketServices[hc.host] == nil {
|
||||||
websocketServices[hc.host] = make(map[string]*websocketServiceType)
|
hc.ws.websocketServices[hc.host] = make(map[string]*websocketServiceType)
|
||||||
}
|
}
|
||||||
websocketServices[hc.host][path] = ws
|
hc.ws.websocketServices[hc.host][path] = ws
|
||||||
websocketServicesList = append(websocketServicesList, ws)
|
hc.ws.websocketServicesList = append(hc.ws.websocketServicesList, ws)
|
||||||
return ws
|
return ws
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -452,10 +465,14 @@ func makeCachedService(matchedService any) (*webServiceType, error) {
|
|||||||
|
|
||||||
// GetInject 获取注入对象
|
// GetInject 获取注入对象
|
||||||
func GetInject(dataType reflect.Type) any {
|
func GetInject(dataType reflect.Type) any {
|
||||||
if obj, exists := injectObjects[dataType]; exists {
|
return DefaultServer.GetInject(dataType)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ws *WebServer) GetInject(dataType reflect.Type) any {
|
||||||
|
if obj, exists := ws.injectObjects[dataType]; exists {
|
||||||
return obj
|
return obj
|
||||||
}
|
}
|
||||||
if factory, exists := injectFunctions[dataType]; exists {
|
if factory, exists := ws.injectFunctions[dataType]; exists {
|
||||||
return factory()
|
return factory()
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@ -465,9 +482,154 @@ func GetInject(dataType reflect.Type) any {
|
|||||||
func GetInjectT[T any]() T {
|
func GetInjectT[T any]() T {
|
||||||
var zero T
|
var zero T
|
||||||
t := reflect.TypeOf((*T)(nil)).Elem()
|
t := reflect.TypeOf((*T)(nil)).Elem()
|
||||||
obj := GetInject(t)
|
obj := DefaultServer.GetInject(t)
|
||||||
if obj == nil {
|
if obj == nil {
|
||||||
return zero
|
return zero
|
||||||
}
|
}
|
||||||
return obj.(T)
|
return obj.(T)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var webDevOnce sync.Once
|
||||||
|
|
||||||
|
// EnableWebDev 开启 Web 开发模式,支持自动刷新
|
||||||
|
func EnableWebDev(config watch.Config) {
|
||||||
|
DefaultServer.webDevEnabled = true
|
||||||
|
DefaultServer.webDevConfig = config
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ws *WebServer) initWebDev(logger *log.Logger) {
|
||||||
|
webDevOnce.Do(func() {
|
||||||
|
logger.Warning("Web Development Mode Enabled. This should NOT be used in production environment.")
|
||||||
|
onWatchConn := map[string]*WebSocketConn{}
|
||||||
|
onWatchLock := sync.Mutex{}
|
||||||
|
|
||||||
|
// 1. 注册 WebSocket 服务
|
||||||
|
ws.RegisterWebsocket("/_watch", func(request *Request, conn *WebSocketConn, logger *log.Logger) {
|
||||||
|
onWatchLock.Lock()
|
||||||
|
onWatchConn[request.Id] = conn
|
||||||
|
onWatchLock.Unlock()
|
||||||
|
|
||||||
|
// 保持连接,处理消息 (如 ping)
|
||||||
|
for {
|
||||||
|
if _, err := conn.ReadString(); err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
onWatchLock.Lock()
|
||||||
|
delete(onWatchConn, request.Id)
|
||||||
|
onWatchLock.Unlock()
|
||||||
|
})
|
||||||
|
|
||||||
|
// 2. 启动文件监听
|
||||||
|
watcher, err := watch.Start(ws.webDevConfig, func(e *watch.Event) {
|
||||||
|
onWatchLock.Lock()
|
||||||
|
defer onWatchLock.Unlock()
|
||||||
|
for _, conn := range onWatchConn {
|
||||||
|
_ = conn.Send("reload")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("failed to start watch for EnableWebDev", "error", err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. 注册停机钩子
|
||||||
|
ws.AddShutdownHook(func() {
|
||||||
|
watcher.Stop()
|
||||||
|
onWatchLock.Lock()
|
||||||
|
for _, conn := range onWatchConn {
|
||||||
|
_ = conn.Close()
|
||||||
|
}
|
||||||
|
onWatchLock.Unlock()
|
||||||
|
})
|
||||||
|
|
||||||
|
// 4. 注册输出过滤器进行注入
|
||||||
|
ws.SetOutFilter(func(in map[string]any, request *Request, response *Response, out any, logger *log.Logger) (newOut any, isOver bool) {
|
||||||
|
contentType := response.Header().Get("Content-Type")
|
||||||
|
var outStr string
|
||||||
|
|
||||||
|
if out != nil {
|
||||||
|
switch v := out.(type) {
|
||||||
|
case string:
|
||||||
|
outStr = v
|
||||||
|
case []byte:
|
||||||
|
outStr = string(v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if outStr == "" && response.changed {
|
||||||
|
outStr = string(response.GetBody())
|
||||||
|
}
|
||||||
|
|
||||||
|
if outStr == "" {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
isHtml := strings.HasPrefix(contentType, "text/html")
|
||||||
|
if !isHtml && (contentType == "" || strings.HasPrefix(contentType, "text/plain")) {
|
||||||
|
// 检测内容前 100 字节是否包含 <html
|
||||||
|
checkLen := int(math.Min(float64(len(outStr)), 100))
|
||||||
|
if strings.Contains(strings.ToLower(outStr[0:checkLen]), "<html") {
|
||||||
|
isHtml = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if isHtml {
|
||||||
|
if strings.Contains(outStr, "let _watchWS = null") {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
// 注入自动刷新的代码
|
||||||
|
injectCode := `<script>
|
||||||
|
let _watchWS = null
|
||||||
|
let _watchWSConnection = false
|
||||||
|
let _watchWSIsFirst = true
|
||||||
|
function connect() {
|
||||||
|
_watchWSConnection = true
|
||||||
|
let ws = new WebSocket(location.protocol.replace('http', 'ws') + '//' + location.host + '/_watch')
|
||||||
|
ws.onopen = () => {
|
||||||
|
_watchWS = ws
|
||||||
|
_watchWSConnection = false
|
||||||
|
if( !_watchWSIsFirst ) location.reload()
|
||||||
|
_watchWSIsFirst = false
|
||||||
|
}
|
||||||
|
ws.onmessage = () => {
|
||||||
|
location.reload()
|
||||||
|
}
|
||||||
|
ws.onclose = () => {
|
||||||
|
_watchWS = null
|
||||||
|
_watchWSConnection = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
setInterval(()=>{
|
||||||
|
if(_watchWS!= null){
|
||||||
|
try{
|
||||||
|
_watchWS.send("ping")
|
||||||
|
}catch(err){
|
||||||
|
_watchWS = null
|
||||||
|
_watchWSConnection = false
|
||||||
|
}
|
||||||
|
} else if(!_watchWSConnection){
|
||||||
|
connect()
|
||||||
|
}
|
||||||
|
}, 1000)
|
||||||
|
connect()
|
||||||
|
</script>`
|
||||||
|
// 仅替换最后一个 </html> 避免多个标签时的重复注入
|
||||||
|
lastIndex := strings.LastIndex(outStr, "</html>")
|
||||||
|
if lastIndex != -1 {
|
||||||
|
outStr = outStr[:lastIndex] + injectCode + outStr[lastIndex:]
|
||||||
|
} else {
|
||||||
|
outStr = outStr + injectCode
|
||||||
|
}
|
||||||
|
|
||||||
|
// 无论如何,只要我们提供了新的输出,就清空原始 Body,防止 handler 重复写入
|
||||||
|
response.ClearBody()
|
||||||
|
return []byte(outStr), false
|
||||||
|
}
|
||||||
|
return nil, false
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
|||||||
@ -12,9 +12,9 @@ func TestServiceRegister(t *testing.T) {
|
|||||||
|
|
||||||
Host("*").Register("*", "/test", handler).Auth(0).Memo("test service")
|
Host("*").Register("*", "/test", handler).Auth(0).Memo("test service")
|
||||||
|
|
||||||
webServicesLock.RLock()
|
DefaultServer.webServicesLock.RLock()
|
||||||
s := webServices["*"]["*/test"]
|
s := DefaultServer.webServices["*"]["*/test"]
|
||||||
webServicesLock.RUnlock()
|
DefaultServer.webServicesLock.RUnlock()
|
||||||
|
|
||||||
if s == nil {
|
if s == nil {
|
||||||
t.Fatal("Service not registered")
|
t.Fatal("Service not registered")
|
||||||
@ -35,9 +35,9 @@ func TestRegexServiceRegister(t *testing.T) {
|
|||||||
|
|
||||||
Host("*").Register("*", "/user/{id}", handler).Auth(0).Memo("get user")
|
Host("*").Register("*", "/user/{id}", handler).Auth(0).Memo("get user")
|
||||||
|
|
||||||
webServicesLock.RLock()
|
DefaultServer.webServicesLock.RLock()
|
||||||
found := false
|
found := false
|
||||||
for _, services := range regexWebServices {
|
for _, services := range DefaultServer.regexWebServices {
|
||||||
for _, s := range services {
|
for _, s := range services {
|
||||||
if s.path == "/user/{id}" {
|
if s.path == "/user/{id}" {
|
||||||
found = true
|
found = true
|
||||||
@ -51,7 +51,7 @@ func TestRegexServiceRegister(t *testing.T) {
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
webServicesLock.RUnlock()
|
DefaultServer.webServicesLock.RUnlock()
|
||||||
|
|
||||||
if !found {
|
if !found {
|
||||||
t.Fatal("Regex service not registered")
|
t.Fatal("Regex service not registered")
|
||||||
|
|||||||
@ -72,7 +72,7 @@ func TestSessionInjection(t *testing.T) {
|
|||||||
}
|
}
|
||||||
Host("*").GET("/test-session", handler)
|
Host("*").GET("/test-session", handler)
|
||||||
|
|
||||||
rh := &RouteHandler{}
|
rh := &RouteHandler{ws: DefaultServer}
|
||||||
req := httptest.NewRequest("GET", "/test-session", nil)
|
req := httptest.NewRequest("GET", "/test-session", nil)
|
||||||
req.Header.Set("sessid", "sess_123")
|
req.Header.Set("sessid", "sess_123")
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
@ -110,7 +110,7 @@ func TestCustomAuthInjection(t *testing.T) {
|
|||||||
}
|
}
|
||||||
Host("*").GET("/test-auth", handler).Auth(10)
|
Host("*").GET("/test-auth", handler).Auth(10)
|
||||||
|
|
||||||
rh := &RouteHandler{}
|
rh := &RouteHandler{ws: DefaultServer}
|
||||||
req := httptest.NewRequest("GET", "/test-auth", nil)
|
req := httptest.NewRequest("GET", "/test-auth", nil)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
@ -132,7 +132,7 @@ func TestAutomaticAuthLevelCheck(t *testing.T) {
|
|||||||
}
|
}
|
||||||
Host("*").GET("/test-auto-auth", handler).Auth(1)
|
Host("*").GET("/test-auto-auth", handler).Auth(1)
|
||||||
|
|
||||||
rh := &RouteHandler{}
|
rh := &RouteHandler{ws: DefaultServer}
|
||||||
|
|
||||||
// 1. 无 Session 或 AuthLevel=0 时应失败
|
// 1. 无 Session 或 AuthLevel=0 时应失败
|
||||||
req1 := httptest.NewRequest("GET", "/test-auto-auth", nil)
|
req1 := httptest.NewRequest("GET", "/test-auto-auth", nil)
|
||||||
|
|||||||
104
static.go
104
static.go
@ -5,58 +5,59 @@ import (
|
|||||||
"apigo.cc/go/log"
|
"apigo.cc/go/log"
|
||||||
"mime"
|
"mime"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/url"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
statics = make(map[string]*string)
|
|
||||||
staticsByHost = make(map[string]map[string]*string)
|
|
||||||
|
|
||||||
codeStatics = make(map[string]map[string]*string)
|
|
||||||
fileStatics = make(map[string]map[string]*string)
|
|
||||||
dynamicStatics = make(map[string]map[string]*string)
|
|
||||||
|
|
||||||
staticsByHostLock = sync.RWMutex{}
|
|
||||||
)
|
|
||||||
|
|
||||||
// Static 注册静态文件目录
|
// Static 注册静态文件目录
|
||||||
func (hc *HostContext) Static(path, rootPath string) *HostContext {
|
func (hc *HostContext) Static(path, rootPath string) *HostContext {
|
||||||
host := hc.host
|
host := hc.host
|
||||||
if host == "*" {
|
if host == "*" {
|
||||||
host = ""
|
host = ""
|
||||||
}
|
}
|
||||||
StaticByHost(path, rootPath, host)
|
hc.ws.StaticByHost(path, rootPath, host)
|
||||||
return hc
|
return hc
|
||||||
}
|
}
|
||||||
|
|
||||||
// Static 注册静态文件目录 (使用默认 Host "*")
|
// Static 注册静态文件目录 (使用默认 Host "*")
|
||||||
func Static(path, rootPath string) {
|
func Static(path, rootPath string) {
|
||||||
Host("*").Static(path, rootPath)
|
DefaultServer.Static(path, rootPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ws *WebServer) Static(path, rootPath string) {
|
||||||
|
ws.Host("*").Static(path, rootPath)
|
||||||
}
|
}
|
||||||
|
|
||||||
// StaticByHost 为指定域名注册静态文件目录
|
// StaticByHost 为指定域名注册静态文件目录
|
||||||
func StaticByHost(path, rootPath, host string) {
|
func StaticByHost(path, rootPath, host string) {
|
||||||
|
DefaultServer.StaticByHost(path, rootPath, host)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ws *WebServer) StaticByHost(path, rootPath, host string) {
|
||||||
if !filepath.IsAbs(rootPath) {
|
if !filepath.IsAbs(rootPath) {
|
||||||
if absPath, err := filepath.Abs(rootPath); err == nil {
|
if absPath, err := filepath.Abs(rootPath); err == nil {
|
||||||
rootPath = absPath
|
rootPath = absPath
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
staticsByHostLock.Lock()
|
ws.staticsByHostLock.Lock()
|
||||||
defer staticsByHostLock.Unlock()
|
defer ws.staticsByHostLock.Unlock()
|
||||||
|
|
||||||
if codeStatics[host] == nil {
|
if ws.codeStatics[host] == nil {
|
||||||
codeStatics[host] = make(map[string]*string)
|
ws.codeStatics[host] = make(map[string]*string)
|
||||||
}
|
}
|
||||||
codeStatics[host][path] = &rootPath
|
ws.codeStatics[host][path] = &rootPath
|
||||||
rebuildStaticsUnderLock(host)
|
ws.rebuildStaticsUnderLock(host)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReplaceStatics 使用 Copy-on-Write 机制原子地替换指定 host 下的动态静态目录规则
|
// ReplaceStatics 使用 Copy-on-Write 机制原子地替换指定 host 下的动态静态目录规则
|
||||||
func ReplaceStatics(host string, config map[string]string) {
|
func ReplaceStatics(host string, config map[string]string) {
|
||||||
|
DefaultServer.ReplaceStatics(host, config)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ws *WebServer) ReplaceStatics(host string, config map[string]string) {
|
||||||
newStatics := make(map[string]*string, len(config))
|
newStatics := make(map[string]*string, len(config))
|
||||||
for path, rootPath := range config {
|
for path, rootPath := range config {
|
||||||
rp := rootPath
|
rp := rootPath
|
||||||
@ -68,60 +69,38 @@ func ReplaceStatics(host string, config map[string]string) {
|
|||||||
newStatics[path] = &rp
|
newStatics[path] = &rp
|
||||||
}
|
}
|
||||||
|
|
||||||
staticsByHostLock.Lock()
|
ws.staticsByHostLock.Lock()
|
||||||
defer staticsByHostLock.Unlock()
|
defer ws.staticsByHostLock.Unlock()
|
||||||
|
|
||||||
dynamicStatics[host] = newStatics
|
ws.dynamicStatics[host] = newStatics
|
||||||
rebuildStaticsUnderLock(host)
|
ws.rebuildStaticsUnderLock(host)
|
||||||
}
|
}
|
||||||
|
|
||||||
func rebuildStaticsUnderLock(host string) {
|
func (ws *WebServer) getStaticFilePath(requestPath, host string) string {
|
||||||
combined := make(map[string]*string)
|
requestPath, _ = url.PathUnescape(requestPath)
|
||||||
|
ws.staticsByHostLock.RLock()
|
||||||
// 合并三种来源的静态路由
|
defer ws.staticsByHostLock.RUnlock()
|
||||||
for k, v := range codeStatics[host] {
|
|
||||||
combined[k] = v
|
|
||||||
}
|
|
||||||
for k, v := range fileStatics[host] {
|
|
||||||
combined[k] = v
|
|
||||||
}
|
|
||||||
for k, v := range dynamicStatics[host] {
|
|
||||||
combined[k] = v
|
|
||||||
}
|
|
||||||
|
|
||||||
if host == "" {
|
|
||||||
statics = combined
|
|
||||||
} else {
|
|
||||||
staticsByHost[host] = combined
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func getStaticFilePath(requestPath, host string) string {
|
|
||||||
staticsByHostLock.RLock()
|
|
||||||
defer staticsByHostLock.RUnlock()
|
|
||||||
|
|
||||||
// 优先匹配指定域名的配置
|
// 优先匹配指定域名的配置
|
||||||
if hostConfig, exists := staticsByHost[host]; exists {
|
if filePath := ws.findMatchedPathSorted(ws.hostStatics[host], requestPath); filePath != "" {
|
||||||
if filePath := findMatchedPath(hostConfig, requestPath); filePath != "" {
|
|
||||||
return filePath
|
return filePath
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// 匹配全局配置
|
// 匹配全局配置
|
||||||
return findMatchedPath(statics, requestPath)
|
return ws.findMatchedPathSorted(ws.hostStatics[""], requestPath)
|
||||||
}
|
}
|
||||||
|
|
||||||
func findMatchedPath(config map[string]*string, requestPath string) string {
|
func (ws *WebServer) findMatchedPathSorted(config []*staticType, requestPath string) string {
|
||||||
for urlPath, rootPath := range config {
|
for _, rule := range config {
|
||||||
if strings.HasPrefix(requestPath, urlPath) {
|
if strings.HasPrefix(requestPath, rule.path) {
|
||||||
return filepath.Join(*rootPath, requestPath[len(urlPath):])
|
return filepath.Join(*rule.rootPath, requestPath[len(rule.path):])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func processStatic(requestPath string, request *Request, response *Response, logger *log.Logger) bool {
|
func (ws *WebServer) processStatic(requestPath string, request *Request, response *Response, logger *log.Logger) bool {
|
||||||
filePath := getStaticFilePath(requestPath, request.Host)
|
filePath := ws.getStaticFilePath(requestPath, request.Host)
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@ -133,7 +112,12 @@ func processStatic(requestPath string, request *Request, response *Response, log
|
|||||||
|
|
||||||
if info.IsDir {
|
if info.IsDir {
|
||||||
// 自动查找索引文件
|
// 自动查找索引文件
|
||||||
for _, indexFile := range Config.IndexFiles {
|
indexFiles := ws.Config.IndexFiles
|
||||||
|
if len(indexFiles) == 0 {
|
||||||
|
indexFiles = []string{"index.html", "index.htm"}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, indexFile := range indexFiles {
|
||||||
f := filepath.Join(filePath, indexFile)
|
f := filepath.Join(filePath, indexFile)
|
||||||
if i := file.GetFileInfo(f); i != nil && !i.IsDir {
|
if i := file.GetFileInfo(f); i != nil && !i.IsDir {
|
||||||
filePath = f
|
filePath = f
|
||||||
@ -148,7 +132,7 @@ func processStatic(requestPath string, request *Request, response *Response, log
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 检查 304
|
// 检查 304
|
||||||
if ifModifiedSince := request.Header.Get("If-Modified-Since"); ifModifiedSince != "" {
|
if ifModifiedSince := request.Header().Get("If-Modified-Since"); ifModifiedSince != "" {
|
||||||
if t, err := time.Parse(http.TimeFormat, ifModifiedSince); err == nil {
|
if t, err := time.Parse(http.TimeFormat, ifModifiedSince); err == nil {
|
||||||
if time.Unix(info.ModTime, 0).Truncate(time.Second).Before(t.Truncate(time.Second)) ||
|
if time.Unix(info.ModTime, 0).Truncate(time.Second).Before(t.Truncate(time.Second)) ||
|
||||||
time.Unix(info.ModTime, 0).Truncate(time.Second).Equal(t.Truncate(time.Second)) {
|
time.Unix(info.ModTime, 0).Truncate(time.Second).Equal(t.Truncate(time.Second)) {
|
||||||
|
|||||||
@ -19,7 +19,7 @@ func TestStaticService(t *testing.T) {
|
|||||||
// 注册静态目录
|
// 注册静态目录
|
||||||
Static("/ui", tempDir)
|
Static("/ui", tempDir)
|
||||||
|
|
||||||
rh := &RouteHandler{}
|
rh := &RouteHandler{ws: DefaultServer}
|
||||||
|
|
||||||
// 测试成功访问
|
// 测试成功访问
|
||||||
req := httptest.NewRequest("GET", "/ui/index.html", nil)
|
req := httptest.NewRequest("GET", "/ui/index.html", nil)
|
||||||
@ -52,7 +52,7 @@ func TestHostStaticService(t *testing.T) {
|
|||||||
// 注册域名特定静态文件服务
|
// 注册域名特定静态文件服务
|
||||||
Host("example.com").Static("/host-ui", tempDir)
|
Host("example.com").Static("/host-ui", tempDir)
|
||||||
|
|
||||||
rh := &RouteHandler{}
|
rh := &RouteHandler{ws: DefaultServer}
|
||||||
|
|
||||||
// 1. 匹配域名访问
|
// 1. 匹配域名访问
|
||||||
req1 := httptest.NewRequest("GET", "/host-ui/index.html", nil)
|
req1 := httptest.NewRequest("GET", "/host-ui/index.html", nil)
|
||||||
@ -77,4 +77,3 @@ func TestHostStaticService(t *testing.T) {
|
|||||||
t.Errorf("Expected 404 for mismatched host, got %d", w2.Code)
|
t.Errorf("Expected 404 for mismatched host, got %d", w2.Code)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
10
websocket.go
10
websocket.go
@ -59,7 +59,7 @@ func Upgrade(response *Response, request *Request) (*WebSocketConn, error) {
|
|||||||
return &WebSocketConn{Conn: conn}, nil
|
return &WebSocketConn{Conn: conn}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func doWebsocketService(ws *websocketServiceType, request *Request, response *Response, logger *log.Logger, object any) {
|
func (ws *WebServer) doWebsocketService(wsc *websocketServiceType, request *Request, response *Response, logger *log.Logger, object any) {
|
||||||
wsConn, err := Upgrade(response, request)
|
wsConn, err := Upgrade(response, request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("websocket upgrade failed", "error", err.Error())
|
logger.Error("websocket upgrade failed", "error", err.Error())
|
||||||
@ -68,9 +68,9 @@ func doWebsocketService(ws *websocketServiceType, request *Request, response *Re
|
|||||||
defer wsConn.Close()
|
defer wsConn.Close()
|
||||||
|
|
||||||
// 调用业务处理函数,注入依赖
|
// 调用业务处理函数,注入依赖
|
||||||
params := make([]reflect.Value, ws.funcType.NumIn())
|
params := make([]reflect.Value, wsc.funcType.NumIn())
|
||||||
for i := 0; i < len(params); i++ {
|
for i := 0; i < len(params); i++ {
|
||||||
t := ws.funcType.In(i)
|
t := wsc.funcType.In(i)
|
||||||
if t == reflect.TypeOf(request) {
|
if t == reflect.TypeOf(request) {
|
||||||
params[i] = reflect.ValueOf(request)
|
params[i] = reflect.ValueOf(request)
|
||||||
} else if t == reflect.TypeOf(logger) {
|
} else if t == reflect.TypeOf(logger) {
|
||||||
@ -81,11 +81,11 @@ func doWebsocketService(ws *websocketServiceType, request *Request, response *Re
|
|||||||
params[i] = reflect.ValueOf(wsConn.Conn)
|
params[i] = reflect.ValueOf(wsConn.Conn)
|
||||||
} else if object != nil && reflect.TypeOf(object).AssignableTo(t) {
|
} else if object != nil && reflect.TypeOf(object).AssignableTo(t) {
|
||||||
params[i] = reflect.ValueOf(object)
|
params[i] = reflect.ValueOf(object)
|
||||||
} else if obj := GetInject(t); obj != nil {
|
} else if obj := ws.GetInject(t); obj != nil {
|
||||||
params[i] = reflect.ValueOf(obj)
|
params[i] = reflect.ValueOf(obj)
|
||||||
} else {
|
} else {
|
||||||
params[i] = reflect.New(t).Elem()
|
params[i] = reflect.New(t).Elem()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ws.funcValue.Call(params)
|
wsc.funcValue.Call(params)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,8 +1,12 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"apigo.cc/go/log"
|
||||||
|
"apigo.cc/go/watch"
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
@ -20,7 +24,7 @@ func TestWebSocketService(t *testing.T) {
|
|||||||
}).Auth(0).Memo("test websocket")
|
}).Auth(0).Memo("test websocket")
|
||||||
|
|
||||||
// 启动测试服务器
|
// 启动测试服务器
|
||||||
server := httptest.NewServer(&RouteHandler{})
|
server := httptest.NewServer(&RouteHandler{ws: DefaultServer})
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
// 建立连接
|
// 建立连接
|
||||||
@ -47,3 +51,63 @@ func TestWebSocketService(t *testing.T) {
|
|||||||
t.Errorf("Reply mismatch: %v", reply)
|
t.Errorf("Reply mismatch: %v", reply)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestEnableWebDev(t *testing.T) {
|
||||||
|
// 1. 初始化 EnableWebDev
|
||||||
|
EnableWebDev(watch.Config{
|
||||||
|
Paths: []string{"."},
|
||||||
|
})
|
||||||
|
|
||||||
|
// 必须手动调用 initWebDev 或触发 Start,因为现在的逻辑是延迟初始化的
|
||||||
|
DefaultServer.initWebDev(log.DefaultLogger)
|
||||||
|
|
||||||
|
// 2. 准备一个真实的静态 HTML 文件
|
||||||
|
staticDir := "test_static"
|
||||||
|
_ = os.MkdirAll(staticDir, 0755)
|
||||||
|
htmlFile := filepath.Join(staticDir, "index.html")
|
||||||
|
_ = os.WriteFile(htmlFile, []byte("<html><head></head><body>Static Content</body></html>"), 0644)
|
||||||
|
defer os.RemoveAll(staticDir)
|
||||||
|
|
||||||
|
// 注册静态服务
|
||||||
|
Static("/static/", staticDir)
|
||||||
|
|
||||||
|
handler := &RouteHandler{ws: DefaultServer}
|
||||||
|
|
||||||
|
// 3. 测试静态文件注入
|
||||||
|
req := httptest.NewRequest("GET", "/static/index.html", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
body := w.Body.String()
|
||||||
|
if !strings.Contains(body, "let _watchWS = null") {
|
||||||
|
t.Errorf("Static HTML injection failed, code not found in body: %s", body)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. 测试普通服务注入
|
||||||
|
Register("GET", "/test-dev", func() string {
|
||||||
|
return "<html><head></head><body>Hello</body></html>"
|
||||||
|
})
|
||||||
|
|
||||||
|
req2 := httptest.NewRequest("GET", "/test-dev", nil)
|
||||||
|
w2 := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(w2, req2)
|
||||||
|
|
||||||
|
body2 := w2.Body.String()
|
||||||
|
if !strings.Contains(body2, "let _watchWS = null") {
|
||||||
|
t.Errorf("Dynamic HTML injection failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 5. 验证非 HTML 不注入
|
||||||
|
Register("GET", "/test-json", func() map[string]string {
|
||||||
|
return map[string]string{"foo": "bar"}
|
||||||
|
})
|
||||||
|
req3 := httptest.NewRequest("GET", "/test-json", nil)
|
||||||
|
w3 := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(w3, req3)
|
||||||
|
|
||||||
|
body3 := w3.Body.String()
|
||||||
|
if strings.Contains(body3, "let _watchWS = null") {
|
||||||
|
t.Errorf("JSON should not be injected")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user