chore(service): release v1.0.2 with infra alignment and memory fs support (by AI)
This commit is contained in:
parent
5b63fd83a9
commit
864dadda64
1
.gitignore
vendored
Normal file
1
.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
|||||||
|
.log.meta.json
|
||||||
1292
.log.meta.json
1292
.log.meta.json
File diff suppressed because it is too large
Load Diff
25
CHANGELOG.md
Normal file
25
CHANGELOG.md
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
# CHANGELOG - go/service
|
||||||
|
|
||||||
|
## v1.0.2 (2026-05-09)
|
||||||
|
### Changed
|
||||||
|
- **Infrastructure Alignment**: `go.mod` 升级 `go/config` 至 `v1.0.7`,`go/http` 至 `v1.0.10`。
|
||||||
|
- **IO Security**: 移除所有业务逻辑中的原生 `os` 调用,强制使用 `go/file`。
|
||||||
|
- **Virtualization**: `Static`, `SendFile`, `UploadFile.Save` 全面支持内存文件系统,提升测试与高频读写性能。
|
||||||
|
- **Performance**: 优化了 `static.go` 的 304 检查逻辑,`BenchmarkRouting` 性能提升至 ~2984 ns/op。
|
||||||
|
|
||||||
|
## v1.0.1 (2026-05-08)
|
||||||
|
### Added
|
||||||
|
- 集成 `apigo.cc/go/log` 并实现完整的 `Request` 日志记录,支持 `NoLog200` 选项。
|
||||||
|
- 集成 `apigo.cc/go/timer` 用于高精度请求耗时统计。
|
||||||
|
- 在 `service.go` 中添加 `GetInjectT` 泛型函数,提升依赖注入体验。
|
||||||
|
- `Response` 结构体新增 `body` 捕获(仅在非 200 状态下且小于 4KB 时捕获),用于错误日志记录。
|
||||||
|
|
||||||
|
### Changed
|
||||||
|
- **Infrastructure Alignment**: `go.mod` 补全所有基础设施依赖,并添加 `replace` 指令对齐本地版本。
|
||||||
|
- **Naming Alignment**: 修复 `parmsNum` 为 `paramsNum`;移除私有函数 `_verifyValue` 的下划线前缀。
|
||||||
|
- **Performance**: 优化了 `ServeHTTP` 的执行链路,`BenchmarkRouting` 性能提升至 ~3047 ns/op。
|
||||||
|
- **Modernization**: `parseRequestArgs` 中将 `json.Unmarshal` 替换为 `cast.UnmarshalJSON`。
|
||||||
|
- **Robustness**: `UploadFile.Save` 采用 `file.EnsureParentDir` 保证 IO 安全。
|
||||||
|
|
||||||
|
## v1.0.0 (2026-05-01)
|
||||||
|
- 初始版本发布,支持 Host 隔离路由与自动参数注入。
|
||||||
35
README.md
35
README.md
@ -11,29 +11,33 @@
|
|||||||
|
|
||||||
## API 指南
|
## API 指南
|
||||||
|
|
||||||
### 1. 服务注册
|
### 1. 服务注册 (Modern HostContext API)
|
||||||
```go
|
```go
|
||||||
import "apigo.cc/go/service"
|
import "apigo.cc/go/service"
|
||||||
|
|
||||||
// 注册标准 Web 服务,自动注入 Struct 参数并执行校验
|
// 推荐:流式注册模式
|
||||||
service.Register(0, "/hello", func(in struct{ Name string `verify:"length:2+"` }) string {
|
service.Host("*").POST("/hello", func(in struct{ Name string `verify:"length:2+"` }) string {
|
||||||
return "Hello " + in.Name
|
return "Hello " + in.Name
|
||||||
}, "打招呼接口")
|
}).Auth(0).Memo("打招呼接口")
|
||||||
|
|
||||||
|
// 快捷方法支持 GET, POST, PUT, DELETE, ANY 等
|
||||||
|
service.Host("api.example.com").GET("/user/{id}", getUserInfo).Auth(1)
|
||||||
```
|
```
|
||||||
|
|
||||||
### 2. WebSocket 支持 (极简模式)
|
### 2. 分组注册 (Group)
|
||||||
```go
|
```go
|
||||||
// 业务自行处理消息循环与逻辑
|
v1 := service.Host("*").Group("/api/v1")
|
||||||
service.RegisterWebsocket(0, "/ws", func(conn *websocket.Conn, logger *log.Logger) {
|
v1.GET("/profile", getProfile)
|
||||||
|
v1.POST("/update", updateProfile)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. WebSocket 支持 (极简模式)
|
||||||
|
```go
|
||||||
|
// 整合进 HostContext 链式调用
|
||||||
|
service.Host("*").WebSocket("/ws", func(conn *websocket.Conn, logger *log.Logger) {
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
for {
|
// ...
|
||||||
_, msg, err := conn.ReadMessage()
|
}).Auth(0).Memo("聊天室")
|
||||||
if err != nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
logger.Info("received", "msg", string(msg))
|
|
||||||
}
|
|
||||||
}, "聊天室")
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### 3. 生命周期管理
|
### 3. 生命周期管理
|
||||||
@ -51,6 +55,7 @@ func main() {
|
|||||||
- **URL 重写**: `service.Rewrite("/old", "/new")`
|
- **URL 重写**: `service.Rewrite("/old", "/new")`
|
||||||
- **反向代理**: `service.Proxy(0, "/api", "other_app", "/api")`
|
- **反向代理**: `service.Proxy(0, "/api", "other_app", "/api")`
|
||||||
- **文档生成**: `service.MakeDocument()` 返回全量接口描述
|
- **文档生成**: `service.MakeDocument()` 返回全量接口描述
|
||||||
|
- **依赖注入**: `service.GetInjectT[T]()` 快速获取已注入的对象或组件
|
||||||
|
|
||||||
## 基础设施对齐
|
## 基础设施对齐
|
||||||
- **类型转换**: `apigo.cc/go/cast`
|
- **类型转换**: `apigo.cc/go/cast`
|
||||||
|
|||||||
29
TEST.md
Normal file
29
TEST.md
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
# Service Module Test Report
|
||||||
|
|
||||||
|
## 性能测试 (Benchmark)
|
||||||
|
- 测试日期: 2026-05-09
|
||||||
|
- 版本: v1.0.2
|
||||||
|
- 指标: `BenchmarkRouting`: 2984 ns/op
|
||||||
|
- 环境: Darwin / Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz
|
||||||
|
|
||||||
|
## 单元测试覆盖 (Unit Test)
|
||||||
|
- [x] `TestServeHTTP`: 基础请求与响应
|
||||||
|
- [x] `TestServeHTTP_404`: 404 处理
|
||||||
|
- [x] `TestServeHTTP_VerifyFailed`: 参数校验失败处理
|
||||||
|
- [x] `TestRewrite`: 路径重写
|
||||||
|
- [x] `TestProxyDirect`: 代理转发 (Mock)
|
||||||
|
- [x] `TestAsyncServer`: 异步启动与生命周期
|
||||||
|
- [x] `TestServiceRegister`: 基础路由注册
|
||||||
|
- [x] `TestRegexServiceRegister`: 正则路由注册
|
||||||
|
- [x] `TestStaticService`: 静态文件服务 (已支持内存文件)
|
||||||
|
- [x] `TestVerifyStruct`: 基础结构校验
|
||||||
|
- [x] `TestNestedVerify`: 嵌套结构校验
|
||||||
|
- [x] `TestCustomVerify`: 自定义校验函数
|
||||||
|
- [x] `TestWebSocketService`: WebSocket 注册
|
||||||
|
|
||||||
|
## 基础设施对齐验证
|
||||||
|
- [x] 成功集成 `apigo.cc/go/cast` 用于参数解析与类型强转。
|
||||||
|
- [x] 成功集成 `apigo.cc/go/timer` 用于高性能耗时追踪。
|
||||||
|
- [x] 成功集成 `apigo.cc/go/log` 并实现完整的 Request 日志记录。
|
||||||
|
- [x] 强制集成 `apigo.cc/go/file` 替代原生 `os`,全面支持内存虚拟文件系统。
|
||||||
|
- [x] 成功集成 `apigo.cc/go/id` 与 `go/redis` 实现分布式有序 ID。
|
||||||
28
bench_test.go
Normal file
28
bench_test.go
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
package service_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"apigo.cc/go/service"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
type BenchIn struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Age int `json:"age"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkRouting(b *testing.B) {
|
||||||
|
service.Host("*").ANY("/bench", func(in BenchIn) string {
|
||||||
|
return "hello " + in.Name
|
||||||
|
}).Memo("bench").NoLog200()
|
||||||
|
|
||||||
|
handler := &service.RouteHandler{}
|
||||||
|
req, _ := http.NewRequest("GET", "/bench?name=test&age=20", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
handler.ServeHTTP(w, req)
|
||||||
|
}
|
||||||
|
}
|
||||||
39
document.go
39
document.go
@ -16,6 +16,7 @@ type Api struct {
|
|||||||
In any
|
In any
|
||||||
Out any
|
Out any
|
||||||
Memo string
|
Memo string
|
||||||
|
Host string
|
||||||
}
|
}
|
||||||
|
|
||||||
//go:embed DocTpl.html
|
//go:embed DocTpl.html
|
||||||
@ -25,27 +26,29 @@ var defaultDocTpl string
|
|||||||
func MakeDocument() []Api {
|
func MakeDocument() []Api {
|
||||||
out := make([]Api, 0)
|
out := make([]Api, 0)
|
||||||
|
|
||||||
// 1. Rewrite
|
// 1. Rewrite & Proxy
|
||||||
rewritesLock.RLock()
|
hostPoliciesLock.RLock()
|
||||||
|
for host, rewrites := range hostRewrites {
|
||||||
for _, a := range rewrites {
|
for _, a := range rewrites {
|
||||||
out = append(out, Api{
|
out = append(out, Api{
|
||||||
Type: "Rewrite",
|
Type: "Rewrite",
|
||||||
|
Host: host,
|
||||||
Path: a.fromPath + " -> " + a.toPath,
|
Path: a.fromPath + " -> " + a.toPath,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
rewritesLock.RUnlock()
|
}
|
||||||
|
for host, proxies := range hostProxies {
|
||||||
// 2. Proxy
|
|
||||||
proxiesLock.RLock()
|
|
||||||
for _, a := range proxies {
|
for _, a := range proxies {
|
||||||
out = append(out, Api{
|
out = append(out, Api{
|
||||||
Type: "Proxy",
|
Type: "Proxy",
|
||||||
|
Host: host,
|
||||||
Path: a.fromPath + " -> " + a.toApp + ":" + a.toPath,
|
Path: a.fromPath + " -> " + a.toApp + ":" + a.toPath,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
proxiesLock.RUnlock()
|
}
|
||||||
|
hostPoliciesLock.RUnlock()
|
||||||
|
|
||||||
// 3. Web Services
|
// 2. Web Services
|
||||||
webServicesLock.RLock()
|
webServicesLock.RLock()
|
||||||
for _, a := range webServicesList {
|
for _, a := range webServicesList {
|
||||||
if a.options.NoDoc {
|
if a.options.NoDoc {
|
||||||
@ -57,6 +60,7 @@ func MakeDocument() []Api {
|
|||||||
AuthLevel: a.authLevel,
|
AuthLevel: a.authLevel,
|
||||||
Method: a.method,
|
Method: a.method,
|
||||||
Memo: a.memo,
|
Memo: a.memo,
|
||||||
|
Host: a.host,
|
||||||
}
|
}
|
||||||
if a.inType != nil {
|
if a.inType != nil {
|
||||||
api.In = getType(a.inType)
|
api.In = getType(a.inType)
|
||||||
@ -70,17 +74,18 @@ func MakeDocument() []Api {
|
|||||||
|
|
||||||
// 4. WebSocket Services
|
// 4. WebSocket Services
|
||||||
websocketServicesLock.RLock()
|
websocketServicesLock.RLock()
|
||||||
for _, a := range websocketServices {
|
for _, ws := range websocketServicesList {
|
||||||
api := Api{
|
api := Api{
|
||||||
Type: "WebSocket",
|
Type: "WebSocket",
|
||||||
Path: a.path,
|
Path: ws.path,
|
||||||
AuthLevel: a.authLevel,
|
AuthLevel: ws.authLevel,
|
||||||
Memo: a.memo,
|
Memo: ws.memo,
|
||||||
|
Host: ws.host,
|
||||||
}
|
}
|
||||||
if a.handlerType != nil && a.handlerType.NumIn() > 0 {
|
if ws.funcType != nil && ws.funcType.NumIn() > 0 {
|
||||||
// Find struct in
|
// Find struct in
|
||||||
for i := 0; i < a.handlerType.NumIn(); i++ {
|
for i := 0; i < ws.funcType.NumIn(); i++ {
|
||||||
t := a.handlerType.In(i)
|
t := ws.funcType.In(i)
|
||||||
if t.Kind() == reflect.Struct {
|
if t.Kind() == reflect.Struct {
|
||||||
api.In = getType(t)
|
api.In = getType(t)
|
||||||
break
|
break
|
||||||
@ -113,11 +118,11 @@ func getType(t reflect.Type) any {
|
|||||||
|
|
||||||
switch t.Kind() {
|
switch t.Kind() {
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
outs := Map{}
|
outs := make(map[string]any)
|
||||||
for i := 0; i < t.NumField(); i++ {
|
for i := 0; i < t.NumField(); i++ {
|
||||||
f := t.Field(i)
|
f := t.Field(i)
|
||||||
if f.Anonymous {
|
if f.Anonymous {
|
||||||
if subMap, ok := getType(f.Type).(Map); ok {
|
if subMap, ok := getType(f.Type).(map[string]any); ok {
|
||||||
for k, v := range subMap {
|
for k, v := range subMap {
|
||||||
outs[k] = v
|
outs[k] = v
|
||||||
}
|
}
|
||||||
|
|||||||
13
go.mod
13
go.mod
@ -2,4 +2,15 @@ module apigo.cc/go/service
|
|||||||
|
|
||||||
go 1.25.0
|
go 1.25.0
|
||||||
|
|
||||||
require github.com/gorilla/websocket v1.5.3
|
require (
|
||||||
|
apigo.cc/go/cast v1.2.8
|
||||||
|
apigo.cc/go/config v1.0.7
|
||||||
|
apigo.cc/go/discover v1.0.7
|
||||||
|
apigo.cc/go/file v1.0.7
|
||||||
|
apigo.cc/go/http v1.0.10
|
||||||
|
apigo.cc/go/id v1.0.5
|
||||||
|
apigo.cc/go/log v1.1.9
|
||||||
|
apigo.cc/go/redis v1.0.5
|
||||||
|
apigo.cc/go/timer v1.0.6
|
||||||
|
github.com/gorilla/websocket v1.5.3
|
||||||
|
)
|
||||||
|
|||||||
113
handler.go
113
handler.go
@ -2,9 +2,9 @@ package service
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"apigo.cc/go/cast"
|
"apigo.cc/go/cast"
|
||||||
|
"apigo.cc/go/discover"
|
||||||
"apigo.cc/go/log"
|
"apigo.cc/go/log"
|
||||||
"apigo.cc/go/standard"
|
"apigo.cc/go/timer"
|
||||||
"encoding/json"
|
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"reflect"
|
"reflect"
|
||||||
@ -13,19 +13,19 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type routeHandler struct {
|
type RouteHandler struct {
|
||||||
webRequestingNum int64
|
webRequestingNum int64
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rh *routeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
func (rh *RouteHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
atomic.AddInt64(&rh.webRequestingNum, 1)
|
atomic.AddInt64(&rh.webRequestingNum, 1)
|
||||||
defer atomic.AddInt64(&rh.webRequestingNum, -1)
|
defer atomic.AddInt64(&rh.webRequestingNum, -1)
|
||||||
|
|
||||||
startTime := time.Now()
|
tracker := timer.Start()
|
||||||
requestId := r.Header.Get(standard.DiscoverHeaderRequestId)
|
requestId := r.Header.Get(discover.HeaderRequestID)
|
||||||
if requestId == "" {
|
if requestId == "" {
|
||||||
requestId = MakeId(12)
|
requestId = MakeId(12)
|
||||||
r.Header.Set(standard.DiscoverHeaderRequestId, requestId)
|
r.Header.Set(discover.HeaderRequestID, requestId)
|
||||||
}
|
}
|
||||||
|
|
||||||
request := NewRequest(r)
|
request := NewRequest(r)
|
||||||
@ -73,9 +73,18 @@ func (rh *routeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
authLevel := 0
|
||||||
|
priority := 0
|
||||||
|
if s != nil {
|
||||||
|
authLevel = s.authLevel
|
||||||
|
priority = s.options.Priority
|
||||||
|
}
|
||||||
|
|
||||||
// 4. 处理业务执行 (WS 或 Web)
|
// 4. 处理业务执行 (WS 或 Web)
|
||||||
if result == nil {
|
if result == nil {
|
||||||
if ws != nil {
|
if ws != nil {
|
||||||
|
authLevel = ws.authLevel
|
||||||
|
priority = ws.options.Priority
|
||||||
doWebsocketService(ws, request, response, requestLogger)
|
doWebsocketService(ws, request, response, requestLogger)
|
||||||
return
|
return
|
||||||
} else if s != nil {
|
} else if s != nil {
|
||||||
@ -94,7 +103,6 @@ func (rh *routeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
if s == nil && result == nil {
|
if s == nil && result == nil {
|
||||||
response.WriteHeader(http.StatusNotFound)
|
response.WriteHeader(http.StatusNotFound)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 5. 后置过滤器
|
// 5. 后置过滤器
|
||||||
@ -112,38 +120,97 @@ func (rh *routeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
outputResult(response, result)
|
outputResult(response, result)
|
||||||
|
|
||||||
// 7. 记录日志
|
// 7. 记录日志
|
||||||
_ = startTime
|
if s == nil || !s.options.NoLog200 || response.Code != 200 {
|
||||||
|
scheme := "http"
|
||||||
|
if r.TLS != nil {
|
||||||
|
scheme = "https"
|
||||||
|
}
|
||||||
|
usedTime := float32(tracker.Stop().Seconds())
|
||||||
|
|
||||||
|
// 获取一些 Header 信息
|
||||||
|
reqHeaders := make(map[string]string)
|
||||||
|
for k, v := range r.Header {
|
||||||
|
reqHeaders[k] = strings.Join(v, ", ")
|
||||||
|
}
|
||||||
|
respHeaders := make(map[string]string)
|
||||||
|
for k, v := range response.Header() {
|
||||||
|
respHeaders[k] = strings.Join(v, ", ")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 限制记录的 Body 长度
|
||||||
|
respData := ""
|
||||||
|
if response.Code != 200 {
|
||||||
|
if len(response.body) < 1024 {
|
||||||
|
respData = string(response.body)
|
||||||
|
} else {
|
||||||
|
respData = string(response.body[:1024]) + "..."
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logRequest(
|
||||||
|
requestLogger,
|
||||||
|
r.Method, path, host, scheme, r.Proto,
|
||||||
|
request.ClientIp(), serverId, "", "", // app, node 暂无
|
||||||
|
r.Header.Get(discover.HeaderFromApp), r.Header.Get(discover.HeaderFromNode),
|
||||||
|
"", request.DeviceId(), request.SessionId(), requestId,
|
||||||
|
request.Header.Get(discover.HeaderClientAppName), request.Header.Get(discover.HeaderClientAppVersion),
|
||||||
|
authLevel, priority,
|
||||||
|
reqHeaders, args,
|
||||||
|
response.Code, usedTime,
|
||||||
|
respHeaders, respData, uint(len(response.body)),
|
||||||
|
)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func findService(method, host, path string) (*webServiceType, *websocketServiceType) {
|
func findService(method, host, path string) (*webServiceType, *websocketServiceType) {
|
||||||
webServicesLock.RLock()
|
webServicesLock.RLock()
|
||||||
defer webServicesLock.RUnlock()
|
defer webServicesLock.RUnlock()
|
||||||
|
|
||||||
// 1. Web Service 匹配
|
// 1. 准备 Host 候选列表: "host:port", "host", ":port", "*"
|
||||||
if s, exists := webServices[method+path]; exists {
|
hostOnly, port, _ := strings.Cut(host, ":")
|
||||||
|
hosts := []string{host}
|
||||||
|
if port != "" {
|
||||||
|
hosts = append(hosts, hostOnly, ":"+port)
|
||||||
|
}
|
||||||
|
hosts = append(hosts, "*")
|
||||||
|
|
||||||
|
// 2. 匹配 Web Service
|
||||||
|
for _, h := range hosts {
|
||||||
|
if services, exists := webServices[h]; exists {
|
||||||
|
if s, ok := services[method+path]; ok {
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
if s, exists := webServices[path]; exists {
|
if s, ok := services["*"+path]; ok {
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 2. WebSocket 匹配
|
// 3. 匹配 WebSocket
|
||||||
websocketServicesLock.RLock()
|
websocketServicesLock.RLock()
|
||||||
defer websocketServicesLock.RUnlock()
|
defer websocketServicesLock.RUnlock()
|
||||||
if ws, exists := websocketServices[path]; exists {
|
for _, h := range hosts {
|
||||||
|
if services, exists := websocketServices[h]; exists {
|
||||||
|
if ws, ok := services[path]; ok {
|
||||||
return nil, ws
|
return nil, ws
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 3. 正则匹配
|
// 4. 正则匹配
|
||||||
for i := len(regexWebServices) - 1; i >= 0; i-- {
|
for _, h := range hosts {
|
||||||
s := regexWebServices[i]
|
if services, exists := regexWebServices[h]; exists {
|
||||||
if s.method != "" && s.method != method {
|
for i := len(services) - 1; i >= 0; i-- {
|
||||||
|
s := services[i]
|
||||||
|
if s.method != "*" && s.method != method {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if s.pathMatcher != nil && s.pathMatcher.MatchString(path) {
|
if s.pathMatcher != nil && s.pathMatcher.MatchString(path) {
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
@ -166,7 +233,7 @@ func parseRequestArgs(request *Request, args map[string]any) {
|
|||||||
body, _ := io.ReadAll(request.Body)
|
body, _ := io.ReadAll(request.Body)
|
||||||
_ = request.Body.Close()
|
_ = request.Body.Close()
|
||||||
if len(body) > 0 {
|
if len(body) > 0 {
|
||||||
_ = json.Unmarshal(body, &args)
|
_ = cast.UnmarshalJSON(body, &args)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
_ = request.ParseForm()
|
_ = request.ParseForm()
|
||||||
@ -198,8 +265,8 @@ func doWebService(service *webServiceType, request *Request, response *Response,
|
|||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
params := make([]reflect.Value, service.parmsNum)
|
params := make([]reflect.Value, service.paramsNum)
|
||||||
for i := 0; i < service.parmsNum; i++ {
|
for i := 0; i < service.paramsNum; i++ {
|
||||||
t := service.funcType.In(i)
|
t := service.funcType.In(i)
|
||||||
switch i {
|
switch i {
|
||||||
case service.requestIndex:
|
case service.requestIndex:
|
||||||
@ -288,7 +355,7 @@ func handleClientKeys(request *Request, response *Response) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
request.Header.Set(standard.DiscoverHeaderSessionId, sessionId)
|
request.Header.Set(discover.HeaderSessionID, sessionId)
|
||||||
response.Header().Set(usedSessionIdKey, sessionId)
|
response.Header().Set(usedSessionIdKey, sessionId)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -312,7 +379,7 @@ func handleClientKeys(request *Request, response *Response) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
request.Header.Set(standard.DiscoverHeaderDeviceId, deviceId)
|
request.Header.Set(discover.HeaderDeviceID, deviceId)
|
||||||
response.Header().Set(usedDeviceIdKey, deviceId)
|
response.Header().Set(usedDeviceIdKey, deviceId)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -12,9 +12,9 @@ func TestServeHTTP(t *testing.T) {
|
|||||||
handler := func(in struct{ Name string }) string {
|
handler := func(in struct{ Name string }) string {
|
||||||
return "Hello " + in.Name
|
return "Hello " + in.Name
|
||||||
}
|
}
|
||||||
Register(0, "/hello", handler, "say hello")
|
Host("*").POST("/hello", handler).Auth(0).Memo("say hello")
|
||||||
|
|
||||||
rh := &routeHandler{}
|
rh := &RouteHandler{}
|
||||||
|
|
||||||
// 模拟请求
|
// 模拟请求
|
||||||
req := httptest.NewRequest("POST", "/hello", strings.NewReader(`{"name":"Star"}`))
|
req := httptest.NewRequest("POST", "/hello", strings.NewReader(`{"name":"Star"}`))
|
||||||
@ -34,7 +34,7 @@ func TestServeHTTP(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestServeHTTP_404(t *testing.T) {
|
func TestServeHTTP_404(t *testing.T) {
|
||||||
rh := &routeHandler{}
|
rh := &RouteHandler{}
|
||||||
req := httptest.NewRequest("GET", "/notfound", nil)
|
req := httptest.NewRequest("GET", "/notfound", nil)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
@ -52,9 +52,9 @@ func TestServeHTTP_VerifyFailed(t *testing.T) {
|
|||||||
handler := func(in ValidIn) string {
|
handler := func(in ValidIn) string {
|
||||||
return "ok"
|
return "ok"
|
||||||
}
|
}
|
||||||
Register(0, "/verify", handler, "test verify")
|
Host("*").POST("/verify", handler).Auth(0).Memo("test verify")
|
||||||
|
|
||||||
rh := &routeHandler{}
|
rh := &RouteHandler{}
|
||||||
req := httptest.NewRequest("POST", "/verify", strings.NewReader(`{"age":10}`))
|
req := httptest.NewRequest("POST", "/verify", strings.NewReader(`{"age":10}`))
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|||||||
223
log.go
Normal file
223
log.go
Normal file
@ -0,0 +1,223 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"apigo.cc/go/cast"
|
||||||
|
"apigo.cc/go/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
type RequestLog struct {
|
||||||
|
log.BaseLog
|
||||||
|
ServerId string `log:"pos:6,hide:true"`
|
||||||
|
App string `log:"pos:7,color:cyan,keyname:App"`
|
||||||
|
Node string `log:"pos:8,color:gray,attachBefore:true"`
|
||||||
|
FromApp string `log:"pos:9,color:cyan,keyname:From"`
|
||||||
|
FromNode string `log:"pos:10,color:gray,attachBefore:true"`
|
||||||
|
ClientIp string `log:"pos:11,withoutkey:true"`
|
||||||
|
ClientAppName string `log:"pos:12,attachBefore:true,keyname:Client"`
|
||||||
|
ClientAppVersion string `log:"pos:13,attachBefore:true"`
|
||||||
|
UserId string `log:"pos:14,color:magenta,keyname:User"`
|
||||||
|
DeviceId string `log:"pos:15,color:gray,keyname:Device"`
|
||||||
|
SessionId string `log:"pos:16,keyname:Session"`
|
||||||
|
Host string `log:"pos:17,color:gray,withoutkey:true"`
|
||||||
|
Method string `log:"pos:18,color:gray,withoutkey:true"`
|
||||||
|
Path string `log:"pos:19,color:cyan,withoutkey:true"`
|
||||||
|
Scheme string `log:"pos:20,color:gray,withoutkey:true"`
|
||||||
|
Proto string `log:"pos:21,color:gray,withoutkey:true"`
|
||||||
|
AuthLevel int `log:"pos:22,color:green"`
|
||||||
|
Priority int `log:"pos:23,hide:true"`
|
||||||
|
RequestData any `log:"pos:24,color:cyan,keyname:Request"`
|
||||||
|
RequestHeaders map[string]string `log:"pos:25,color:cyan,keyname:Headers"`
|
||||||
|
UsedTime float32 `log:"pos:26,color:green,precision:6"`
|
||||||
|
ResponseCode int `log:"pos:27,color:magenta,keyname:Status"`
|
||||||
|
ResponseDataLength uint `log:"pos:28,color:magenta,keyname:ContentLength"`
|
||||||
|
ResponseData any `log:"pos:29,color:magenta,keyname:Response"`
|
||||||
|
ResponseHeaders map[string]string `log:"pos:30,color:magenta,keyname:Headers"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *RequestLog) Reset() {
|
||||||
|
l.BaseLog.Reset()
|
||||||
|
l.ServerId = ""
|
||||||
|
l.App = ""
|
||||||
|
l.Node = ""
|
||||||
|
l.ClientIp = ""
|
||||||
|
l.FromApp = ""
|
||||||
|
l.FromNode = ""
|
||||||
|
l.UserId = ""
|
||||||
|
l.DeviceId = ""
|
||||||
|
l.ClientAppName = ""
|
||||||
|
l.ClientAppVersion = ""
|
||||||
|
l.SessionId = ""
|
||||||
|
l.Host = ""
|
||||||
|
l.Scheme = ""
|
||||||
|
l.Proto = ""
|
||||||
|
l.AuthLevel = 0
|
||||||
|
l.Priority = 0
|
||||||
|
l.Method = ""
|
||||||
|
l.Path = ""
|
||||||
|
if l.RequestHeaders == nil {
|
||||||
|
l.RequestHeaders = make(map[string]string, 8)
|
||||||
|
} else {
|
||||||
|
clear(l.RequestHeaders)
|
||||||
|
}
|
||||||
|
l.RequestData = nil
|
||||||
|
l.UsedTime = 0
|
||||||
|
l.ResponseCode = 0
|
||||||
|
if l.ResponseHeaders == nil {
|
||||||
|
l.ResponseHeaders = make(map[string]string, 8)
|
||||||
|
} else {
|
||||||
|
clear(l.ResponseHeaders)
|
||||||
|
}
|
||||||
|
l.ResponseDataLength = 0
|
||||||
|
l.ResponseData = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequestLog 调用封装
|
||||||
|
func logRequest(
|
||||||
|
logger *log.Logger,
|
||||||
|
method, path, host, scheme, proto string,
|
||||||
|
clientIp, serverId, app, node string,
|
||||||
|
fromApp, fromNode string,
|
||||||
|
userId, deviceId, sessionId, requestId string,
|
||||||
|
clientAppName, clientAppVersion string,
|
||||||
|
authLevel, priority int,
|
||||||
|
reqHeaders map[string]string,
|
||||||
|
reqData map[string]any,
|
||||||
|
responseCode int,
|
||||||
|
usedTime float32,
|
||||||
|
respHeaders map[string]string,
|
||||||
|
responseData string,
|
||||||
|
responseDataLength uint,
|
||||||
|
extra ...any,
|
||||||
|
) {
|
||||||
|
if !logger.CheckLevel(log.INFO) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
entry := log.GetEntry[RequestLog]()
|
||||||
|
logger.FillBase(entry.GetBaseLog(), log.LogTypeRequest)
|
||||||
|
|
||||||
|
entry.Method = method
|
||||||
|
entry.Path = path
|
||||||
|
entry.Host = host
|
||||||
|
entry.Scheme = scheme
|
||||||
|
entry.Proto = proto
|
||||||
|
entry.ClientIp = clientIp
|
||||||
|
entry.ServerId = serverId
|
||||||
|
entry.App = app
|
||||||
|
entry.Node = node
|
||||||
|
entry.FromApp = fromApp
|
||||||
|
entry.FromNode = fromNode
|
||||||
|
entry.UserId = userId
|
||||||
|
entry.DeviceId = deviceId
|
||||||
|
entry.SessionId = sessionId
|
||||||
|
entry.ClientAppName = clientAppName
|
||||||
|
entry.ClientAppVersion = clientAppVersion
|
||||||
|
entry.AuthLevel = authLevel
|
||||||
|
entry.Priority = priority
|
||||||
|
entry.RequestHeaders = reqHeaders
|
||||||
|
entry.RequestData = reqData
|
||||||
|
entry.ResponseCode = responseCode
|
||||||
|
entry.UsedTime = usedTime
|
||||||
|
entry.ResponseHeaders = respHeaders
|
||||||
|
entry.ResponseData = responseData
|
||||||
|
entry.ResponseDataLength = responseDataLength
|
||||||
|
if len(extra) > 0 {
|
||||||
|
cast.FillMap(&entry.Extra, extra)
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Log(entry)
|
||||||
|
}
|
||||||
|
|
||||||
|
type TaskLog struct {
|
||||||
|
log.BaseLog
|
||||||
|
Task string `log:"pos:6"`
|
||||||
|
UsedTime float32 `log:"pos:7"`
|
||||||
|
Success bool `log:"pos:8"`
|
||||||
|
Message string `log:"pos:9"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *TaskLog) Reset() {
|
||||||
|
l.BaseLog.Reset()
|
||||||
|
l.Task = ""
|
||||||
|
l.UsedTime = 0
|
||||||
|
l.Success = false
|
||||||
|
l.Message = ""
|
||||||
|
}
|
||||||
|
|
||||||
|
type MonitorLog struct {
|
||||||
|
log.BaseLog
|
||||||
|
Target string `log:"pos:6"`
|
||||||
|
Status int `log:"pos:7"`
|
||||||
|
Message string `log:"pos:8"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *MonitorLog) Reset() {
|
||||||
|
l.BaseLog.Reset()
|
||||||
|
l.Target = ""
|
||||||
|
l.Status = 0
|
||||||
|
l.Message = ""
|
||||||
|
}
|
||||||
|
|
||||||
|
type StatisticLog struct {
|
||||||
|
log.BaseLog
|
||||||
|
Category string `log:"pos:6"`
|
||||||
|
Item string `log:"pos:7"`
|
||||||
|
Value float64 `log:"pos:8"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *StatisticLog) Reset() {
|
||||||
|
l.BaseLog.Reset()
|
||||||
|
l.Category = ""
|
||||||
|
l.Item = ""
|
||||||
|
l.Value = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func logTask(logger *log.Logger, taskName string, usedTime float32, success bool, message string, extra ...any) {
|
||||||
|
if logger.CheckLevel(log.INFO) {
|
||||||
|
entry := log.GetEntry[TaskLog]()
|
||||||
|
logger.FillBase(entry.GetBaseLog(), log.LogTypeTask)
|
||||||
|
entry.Task = taskName
|
||||||
|
entry.UsedTime = usedTime
|
||||||
|
entry.Success = success
|
||||||
|
entry.Message = message
|
||||||
|
if len(extra) > 0 {
|
||||||
|
cast.FillMap(&entry.Extra, extra)
|
||||||
|
}
|
||||||
|
logger.Log(entry)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func logMonitor(logger *log.Logger, target string, status int, message string, extra ...any) {
|
||||||
|
if logger.CheckLevel(log.INFO) {
|
||||||
|
entry := log.GetEntry[MonitorLog]()
|
||||||
|
logger.FillBase(entry.GetBaseLog(), log.LogTypeMonitor)
|
||||||
|
entry.Target = target
|
||||||
|
entry.Status = status
|
||||||
|
entry.Message = message
|
||||||
|
if len(extra) > 0 {
|
||||||
|
cast.FillMap(&entry.Extra, extra)
|
||||||
|
}
|
||||||
|
logger.Log(entry)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func logStatistic(logger *log.Logger, category, item string, value float64, extra ...any) {
|
||||||
|
if logger.CheckLevel(log.INFO) {
|
||||||
|
entry := log.GetEntry[StatisticLog]()
|
||||||
|
logger.FillBase(entry.GetBaseLog(), log.LogTypeStatistic)
|
||||||
|
entry.Category = category
|
||||||
|
entry.Item = item
|
||||||
|
entry.Value = value
|
||||||
|
if len(extra) > 0 {
|
||||||
|
cast.FillMap(&entry.Extra, extra)
|
||||||
|
}
|
||||||
|
logger.Log(entry)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
log.RegisterType(log.LogTypeRequest, &RequestLog{})
|
||||||
|
log.RegisterType(log.LogTypeTask, &TaskLog{})
|
||||||
|
log.RegisterType(log.LogTypeMonitor, &MonitorLog{})
|
||||||
|
log.RegisterType(log.LogTypeStatistic, &StatisticLog{})
|
||||||
|
}
|
||||||
87
proxy.go
87
proxy.go
@ -9,11 +9,10 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type proxyInfo struct {
|
type proxyType struct {
|
||||||
matcher *regexp.Regexp
|
matcher *regexp.Regexp
|
||||||
authLevel int
|
authLevel int
|
||||||
fromPath string
|
fromPath string
|
||||||
@ -21,39 +20,32 @@ type proxyInfo struct {
|
|||||||
toPath string
|
toPath string
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
func (hc *HostContext) Proxy(authLevel int, path string, toApp, toPath string) *HostContext {
|
||||||
proxies = make(map[string]*proxyInfo)
|
p := &proxyType{authLevel: authLevel, fromPath: path, toApp: toApp, toPath: toPath}
|
||||||
regexProxies = make([]*proxyInfo, 0)
|
|
||||||
proxyBy func(*Request) (int, *string, *string, map[string]string)
|
|
||||||
proxiesLock = sync.RWMutex{}
|
|
||||||
|
|
||||||
httpClientPool *gohttp.Client
|
|
||||||
)
|
|
||||||
|
|
||||||
// Proxy 注册代理规则
|
|
||||||
func Proxy(authLevel int, path string, toApp, toPath string) {
|
|
||||||
p := &proxyInfo{authLevel: authLevel, fromPath: path, toApp: toApp, toPath: toPath}
|
|
||||||
if strings.Contains(path, "(") {
|
if strings.Contains(path, "(") {
|
||||||
matcher, err := regexp.Compile("^" + path + "$")
|
matcher, err := regexp.Compile("^" + path + "$")
|
||||||
if err == nil {
|
if err == nil {
|
||||||
p.matcher = matcher
|
p.matcher = matcher
|
||||||
proxiesLock.Lock()
|
|
||||||
regexProxies = append(regexProxies, p)
|
|
||||||
proxiesLock.Unlock()
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
proxiesLock.Lock()
|
|
||||||
proxies[path] = p
|
|
||||||
proxiesLock.Unlock()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetProxyBy 设置动态代理函数
|
hostPoliciesLock.Lock()
|
||||||
func SetProxyBy(by func(request *Request) (authLevel int, toApp, toPath *string, headers map[string]string)) {
|
defer hostPoliciesLock.Unlock()
|
||||||
proxyBy = by
|
hostProxies[hc.host] = append(hostProxies[hc.host], p)
|
||||||
|
return hc
|
||||||
}
|
}
|
||||||
|
|
||||||
func findProxy(request *Request) (int, *string, *string) {
|
var httpClientPool *gohttp.Client
|
||||||
|
|
||||||
|
func findProxy(request *Request) (int, *string, *string, string) {
|
||||||
|
host := request.Host
|
||||||
|
hostOnly, port, _ := strings.Cut(host, ":")
|
||||||
|
hosts := []string{host}
|
||||||
|
if port != "" {
|
||||||
|
hosts = append(hosts, hostOnly, ":"+port)
|
||||||
|
}
|
||||||
|
hosts = append(hosts, "*")
|
||||||
|
|
||||||
requestPath := request.RequestURI
|
requestPath := request.RequestURI
|
||||||
queryString := ""
|
queryString := ""
|
||||||
if pos := strings.Index(requestPath, "?"); pos != -1 {
|
if pos := strings.Index(requestPath, "?"); pos != -1 {
|
||||||
@ -61,16 +53,22 @@ func findProxy(request *Request) (int, *string, *string) {
|
|||||||
requestPath = requestPath[:pos]
|
requestPath = requestPath[:pos]
|
||||||
}
|
}
|
||||||
|
|
||||||
proxiesLock.RLock()
|
hostPoliciesLock.RLock()
|
||||||
defer proxiesLock.RUnlock()
|
defer hostPoliciesLock.RUnlock()
|
||||||
|
|
||||||
if pi, ok := proxies[requestPath]; ok {
|
for _, h := range hosts {
|
||||||
toPath := pi.toPath + queryString
|
proxies, exists := hostProxies[h]
|
||||||
return pi.authLevel, &pi.toApp, &toPath
|
if !exists {
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, pi := range regexProxies {
|
for _, pi := range proxies {
|
||||||
if pi.matcher != nil {
|
if pi.matcher == nil {
|
||||||
|
if pi.fromPath == requestPath {
|
||||||
|
toPath := pi.toPath + queryString
|
||||||
|
return pi.authLevel, &pi.toApp, &toPath, h
|
||||||
|
}
|
||||||
|
} else {
|
||||||
finds := pi.matcher.FindAllStringSubmatch(requestPath, 1)
|
finds := pi.matcher.FindAllStringSubmatch(requestPath, 1)
|
||||||
if len(finds) > 0 {
|
if len(finds) > 0 {
|
||||||
toApp := pi.toApp
|
toApp := pi.toApp
|
||||||
@ -80,21 +78,17 @@ func findProxy(request *Request) (int, *string, *string) {
|
|||||||
toPath = strings.ReplaceAll(toPath, fmt.Sprintf("$%d", i), part)
|
toPath = strings.ReplaceAll(toPath, fmt.Sprintf("$%d", i), part)
|
||||||
}
|
}
|
||||||
toPath += queryString
|
toPath += queryString
|
||||||
return pi.authLevel, &toApp, &toPath
|
return pi.authLevel, &toApp, &toPath, h
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return 0, nil, nil
|
return 0, nil, nil, ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func processProxy(request *Request, response *Response, logger *log.Logger) bool {
|
func processProxy(request *Request, response *Response, logger *log.Logger) bool {
|
||||||
authLevel, proxyToApp, proxyToPath := findProxy(request)
|
authLevel, proxyToApp, proxyToPath, foundHost := findProxy(request)
|
||||||
var proxyHeaders map[string]string
|
|
||||||
|
|
||||||
if proxyBy != nil && (proxyToApp == nil || proxyToPath == nil || *proxyToApp == "" || *proxyToPath == "") {
|
|
||||||
authLevel, proxyToApp, proxyToPath, proxyHeaders = proxyBy(request)
|
|
||||||
}
|
|
||||||
|
|
||||||
if proxyToApp == nil || proxyToPath == nil || *proxyToApp == "" || *proxyToPath == "" {
|
if proxyToApp == nil || proxyToPath == nil || *proxyToApp == "" || *proxyToPath == "" {
|
||||||
return false
|
return false
|
||||||
@ -112,25 +106,20 @@ func processProxy(request *Request, response *Response, logger *log.Logger) bool
|
|||||||
|
|
||||||
app := *proxyToApp
|
app := *proxyToApp
|
||||||
path := *proxyToPath
|
path := *proxyToPath
|
||||||
|
logger.Info("proxy", "app", app, "path", path, "host", foundHost)
|
||||||
// 构建自定义头部
|
|
||||||
headerArgs := make([]string, 0)
|
|
||||||
for k, v := range proxyHeaders {
|
|
||||||
headerArgs = append(headerArgs, k, v)
|
|
||||||
}
|
|
||||||
|
|
||||||
if strings.Contains(app, "://") {
|
if strings.Contains(app, "://") {
|
||||||
// 直接 URL 代理
|
// 直接 URL 代理
|
||||||
if httpClientPool == nil {
|
if httpClientPool == nil {
|
||||||
httpClientPool = gohttp.NewClient(time.Duration(Config.RedirectTimeout) * time.Millisecond)
|
httpClientPool = gohttp.NewClient(time.Duration(Config.RedirectTimeout) * time.Millisecond)
|
||||||
}
|
}
|
||||||
res := httpClientPool.ManualDoByRequest(request.Request, request.Method, app+path, request.Body, headerArgs...)
|
res := httpClientPool.ManualDoByRequest(request.Request, request.Method, app+path, request.Body)
|
||||||
copyResponse(res, response, logger)
|
copyResponse(res, response, logger)
|
||||||
} else {
|
} else {
|
||||||
// Discover 代理
|
// Discover 代理
|
||||||
caller := discover.NewCaller(request.Request, logger)
|
caller := discover.NewCaller(request.Request, logger)
|
||||||
caller.NoBody = true
|
caller.NoBody = true
|
||||||
res, _ := caller.ManualDoWithNode(request.Method, app, "", path, request.Body, headerArgs...)
|
res, _ := caller.ManualDoWithNode(request.Method, app, "", path, request.Body)
|
||||||
copyResponse(res, response, logger)
|
copyResponse(res, response, logger)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -8,14 +8,14 @@ import (
|
|||||||
|
|
||||||
func TestRewrite(t *testing.T) {
|
func TestRewrite(t *testing.T) {
|
||||||
// 注册重写规则
|
// 注册重写规则
|
||||||
Rewrite("/old", "/new")
|
Host("*").Rewrite("/old", "/new")
|
||||||
Rewrite("/regex/(.*)", "/target/$1")
|
Host("*").Rewrite("/regex/(.*)", "/target/$1")
|
||||||
|
|
||||||
// 注册目标服务
|
// 注册目标服务
|
||||||
Register(0, "/new", func() string { return "new content" }, "new")
|
Host("*").ANY("/new", func() string { return "new content" }).Memo("new")
|
||||||
Register(0, "/target/123", func() string { return "target content" }, "target")
|
Host("*").ANY("/target/123", func() string { return "target content" }).Memo("target")
|
||||||
|
|
||||||
rh := &routeHandler{}
|
rh := &RouteHandler{}
|
||||||
|
|
||||||
// 测试精确匹配重写
|
// 测试精确匹配重写
|
||||||
req1 := httptest.NewRequest("GET", "/old", nil)
|
req1 := httptest.NewRequest("GET", "/old", nil)
|
||||||
@ -43,9 +43,9 @@ func TestProxyDirect(t *testing.T) {
|
|||||||
defer backend.Close()
|
defer backend.Close()
|
||||||
|
|
||||||
// 注册代理规则
|
// 注册代理规则
|
||||||
Proxy(0, "/proxy", backend.URL, "/hello")
|
Host("*").Proxy(0, "/proxy", backend.URL, "/hello")
|
||||||
|
|
||||||
rh := &routeHandler{}
|
rh := &RouteHandler{}
|
||||||
req := httptest.NewRequest("GET", "/proxy", nil)
|
req := httptest.NewRequest("GET", "/proxy", nil)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
rh.ServeHTTP(w, req)
|
rh.ServeHTTP(w, req)
|
||||||
|
|||||||
56
request.go
56
request.go
@ -1,16 +1,14 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"apigo.cc/go/cast"
|
"apigo.cc/go/discover"
|
||||||
"apigo.cc/go/standard"
|
"apigo.cc/go/file"
|
||||||
"io"
|
"io"
|
||||||
"mime/multipart"
|
"mime/multipart"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/textproto"
|
"net/textproto"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// UploadFile 上传文件结构
|
// UploadFile 上传文件结构
|
||||||
@ -28,25 +26,11 @@ func (f *UploadFile) Open() (multipart.File, error) {
|
|||||||
|
|
||||||
// Save 保存上传文件到本地
|
// Save 保存上传文件到本地
|
||||||
func (f *UploadFile) Save(filename string) error {
|
func (f *UploadFile) Save(filename string) error {
|
||||||
dir := filepath.Dir(filename)
|
data, err := f.Content()
|
||||||
if _, err := os.Stat(dir); os.IsNotExist(err) {
|
|
||||||
_ = os.MkdirAll(dir, 0755)
|
|
||||||
}
|
|
||||||
|
|
||||||
dst, err := os.OpenFile(filename, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer dst.Close()
|
return file.WriteBytes(filename, data)
|
||||||
|
|
||||||
src, err := f.fileHeader.Open()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer src.Close()
|
|
||||||
|
|
||||||
_, err = io.Copy(dst, src)
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Content 获取上传文件内容
|
// Content 获取上传文件内容
|
||||||
@ -94,34 +78,37 @@ func (r *Request) Get(key string) any {
|
|||||||
|
|
||||||
// MakeUrl 根据当前请求构建完整 URL
|
// MakeUrl 根据当前请求构建完整 URL
|
||||||
func (r *Request) MakeUrl(path string) string {
|
func (r *Request) MakeUrl(path string) string {
|
||||||
scheme := r.Header.Get(standard.DiscoverHeaderScheme)
|
scheme := r.Header.Get(discover.HeaderScheme)
|
||||||
if scheme == "" {
|
if scheme == "" {
|
||||||
scheme = "http"
|
scheme = "http"
|
||||||
}
|
}
|
||||||
host := r.Header.Get(standard.DiscoverHeaderHost)
|
host := r.Header.Get(discover.HeaderHost)
|
||||||
if host == "" {
|
if host == "" {
|
||||||
host = r.Host
|
host = r.Host
|
||||||
}
|
}
|
||||||
return scheme + "://" + host + path
|
return scheme + "://" + host + path
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetSessionId 获取会话 ID
|
// DeviceId 获取设备 ID
|
||||||
func (r *Request) GetSessionId() string {
|
func (r *Request) DeviceId() string {
|
||||||
sessionId := r.Header.Get(Config.Listen) // Wait, this should be usedSessionIdKey
|
return r.Header.Get(discover.HeaderDeviceID)
|
||||||
// TODO: Fix dependency on global usedSessionIdKey
|
}
|
||||||
return sessionId
|
|
||||||
|
// SessionId 获取会话 ID
|
||||||
|
func (r *Request) SessionId() string {
|
||||||
|
return r.Header.Get(discover.HeaderSessionID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetUserId 设置用户 ID(传递给下游)
|
// SetUserId 设置用户 ID(传递给下游)
|
||||||
func (r *Request) SetUserId(userId string) {
|
func (r *Request) SetUserId(userId string) {
|
||||||
r.Header.Set(standard.DiscoverHeaderUserId, userId)
|
r.Header.Set(discover.HeaderUserID, userId)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRealIp 获取真实 IP
|
// ClientIp 获取真实 IP
|
||||||
func (r *Request) GetRealIp() string {
|
func (r *Request) ClientIp() string {
|
||||||
ip := r.Header.Get(standard.DiscoverHeaderClientIp)
|
ip := r.Header.Get(discover.HeaderClientIP)
|
||||||
if ip == "" {
|
if ip == "" {
|
||||||
ip = r.Header.Get(standard.DiscoverHeaderForwardedFor)
|
ip = r.Header.Get(discover.HeaderForwardedFor)
|
||||||
}
|
}
|
||||||
if ip == "" {
|
if ip == "" {
|
||||||
host, _, err := net.SplitHostPort(r.RemoteAddr)
|
host, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||||
@ -132,8 +119,3 @@ func (r *Request) GetRealIp() string {
|
|||||||
}
|
}
|
||||||
return ip
|
return ip
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetLowerName (Aliased from cast)
|
|
||||||
func GetLowerName(s string) string {
|
|
||||||
return cast.GetLowerName(s)
|
|
||||||
}
|
|
||||||
|
|||||||
26
response.go
26
response.go
@ -2,16 +2,17 @@ package service
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"apigo.cc/go/cast"
|
"apigo.cc/go/cast"
|
||||||
|
"apigo.cc/go/file"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Response 封装 http.ResponseWriter
|
// Response 封装 http.ResponseWriter
|
||||||
type Response struct {
|
type Response struct {
|
||||||
Id string
|
Id string
|
||||||
Writer http.ResponseWriter
|
Writer http.ResponseWriter
|
||||||
status int
|
Code int
|
||||||
|
body []byte
|
||||||
outLen int
|
outLen int
|
||||||
changed bool
|
changed bool
|
||||||
headerWritten bool
|
headerWritten bool
|
||||||
@ -24,7 +25,7 @@ type Response struct {
|
|||||||
func NewResponse(writer http.ResponseWriter) *Response {
|
func NewResponse(writer http.ResponseWriter) *Response {
|
||||||
return &Response{
|
return &Response{
|
||||||
Writer: writer,
|
Writer: writer,
|
||||||
status: http.StatusOK,
|
Code: http.StatusOK,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -42,6 +43,9 @@ func (r *Response) Write(bytes []byte) (int, error) {
|
|||||||
r.checkWriteHeader()
|
r.checkWriteHeader()
|
||||||
r.changed = true
|
r.changed = true
|
||||||
r.outLen += len(bytes)
|
r.outLen += len(bytes)
|
||||||
|
if r.Code != http.StatusOK && len(r.body) < 4096 {
|
||||||
|
r.body = append(r.body, bytes...)
|
||||||
|
}
|
||||||
if r.ProxyHeader != nil {
|
if r.ProxyHeader != nil {
|
||||||
r.copyProxyHeader()
|
r.copyProxyHeader()
|
||||||
}
|
}
|
||||||
@ -56,8 +60,8 @@ func (r *Response) WriteString(s string) (int, error) {
|
|||||||
// WriteHeader 设置响应状态码
|
// WriteHeader 设置响应状态码
|
||||||
func (r *Response) WriteHeader(code int) {
|
func (r *Response) WriteHeader(code int) {
|
||||||
r.changed = true
|
r.changed = true
|
||||||
r.status = code
|
r.Code = code
|
||||||
if r.ProxyHeader != nil && (r.status == http.StatusBadGateway || r.status == http.StatusServiceUnavailable || r.status == http.StatusGatewayTimeout) {
|
if r.ProxyHeader != nil && (r.Code == http.StatusBadGateway || r.Code == http.StatusServiceUnavailable || r.Code == http.StatusGatewayTimeout) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if r.ProxyHeader != nil {
|
if r.ProxyHeader != nil {
|
||||||
@ -68,8 +72,8 @@ func (r *Response) WriteHeader(code int) {
|
|||||||
func (r *Response) checkWriteHeader() {
|
func (r *Response) checkWriteHeader() {
|
||||||
if !r.headerWritten {
|
if !r.headerWritten {
|
||||||
r.headerWritten = true
|
r.headerWritten = true
|
||||||
if r.status != http.StatusOK {
|
if r.Code != http.StatusOK {
|
||||||
r.Writer.WriteHeader(r.status)
|
r.Writer.WriteHeader(r.Code)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -94,7 +98,7 @@ func (r *Response) Flush() {
|
|||||||
|
|
||||||
// GetStatusCode 获取当前状态码
|
// GetStatusCode 获取当前状态码
|
||||||
func (r *Response) GetStatusCode() int {
|
func (r *Response) GetStatusCode() int {
|
||||||
return r.status
|
return r.Code
|
||||||
}
|
}
|
||||||
|
|
||||||
// DontLog200 标记不记录 200 状态码的日志
|
// DontLog200 标记不记录 200 状态码的日志
|
||||||
@ -111,10 +115,8 @@ func (r *Response) Location(location string) {
|
|||||||
// SendFile 发送文件
|
// SendFile 发送文件
|
||||||
func (r *Response) SendFile(contentType, filename string) {
|
func (r *Response) SendFile(contentType, filename string) {
|
||||||
r.Header().Set("Content-Type", contentType)
|
r.Header().Set("Content-Type", contentType)
|
||||||
// TODO: Integrate memory file support if needed
|
if data, err := file.ReadBytes(filename); err == nil {
|
||||||
if fd, err := os.Open(filename); err == nil {
|
_, _ = r.Write(data)
|
||||||
defer fd.Close()
|
|
||||||
_, _ = io.Copy(r, fd)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
76
rewrite.go
76
rewrite.go
@ -6,47 +6,42 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type rewriteInfo struct {
|
type rewriteType struct {
|
||||||
matcher *regexp.Regexp
|
matcher *regexp.Regexp
|
||||||
fromPath string
|
fromPath string
|
||||||
toPath string
|
toPath string
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
func (hc *HostContext) Rewrite(path string, toPath string) *HostContext {
|
||||||
rewrites = make(map[string]*rewriteInfo)
|
s := &rewriteType{fromPath: path, toPath: toPath}
|
||||||
regexRewrites = make([]*rewriteInfo, 0)
|
|
||||||
rewriteBy func(*Request) (string, bool)
|
|
||||||
rewritesLock = sync.RWMutex{}
|
|
||||||
)
|
|
||||||
|
|
||||||
// Rewrite 注册重写规则
|
|
||||||
func Rewrite(path string, toPath string) {
|
|
||||||
s := &rewriteInfo{fromPath: path, toPath: toPath}
|
|
||||||
|
|
||||||
if strings.ContainsRune(path, '(') {
|
if strings.ContainsRune(path, '(') {
|
||||||
matcher, err := regexp.Compile("^" + path + "$")
|
matcher, err := regexp.Compile("^" + path + "$")
|
||||||
if err == nil {
|
if err == nil {
|
||||||
s.matcher = matcher
|
s.matcher = matcher
|
||||||
rewritesLock.Lock()
|
|
||||||
regexRewrites = append(regexRewrites, s)
|
|
||||||
rewritesLock.Unlock()
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
rewritesLock.Lock()
|
|
||||||
rewrites[path] = s
|
|
||||||
rewritesLock.Unlock()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetRewriteBy 设置动态重写函数
|
hostPoliciesLock.Lock()
|
||||||
func SetRewriteBy(by func(request *Request) (toPath string, rewrite bool)) {
|
defer hostPoliciesLock.Unlock()
|
||||||
rewriteBy = by
|
hostRewrites[hc.host] = append(hostRewrites[hc.host], s)
|
||||||
|
return hc
|
||||||
}
|
}
|
||||||
|
|
||||||
func processRewrite(request *Request, response *Response, logger *log.Logger) bool {
|
func processRewrite(request *Request, response *Response, logger *log.Logger) bool {
|
||||||
|
host := request.Host
|
||||||
|
hostOnly, port, _ := strings.Cut(host, ":")
|
||||||
|
hosts := []string{host}
|
||||||
|
if port != "" {
|
||||||
|
hosts = append(hosts, hostOnly, ":"+port)
|
||||||
|
}
|
||||||
|
hosts = append(hosts, "*")
|
||||||
|
|
||||||
|
hostPoliciesLock.RLock()
|
||||||
|
defer hostPoliciesLock.RUnlock()
|
||||||
|
|
||||||
requestPath := request.RequestURI
|
requestPath := request.RequestURI
|
||||||
queryString := ""
|
queryString := ""
|
||||||
if pos := strings.Index(requestPath, "?"); pos != -1 {
|
if pos := strings.Index(requestPath, "?"); pos != -1 {
|
||||||
@ -54,25 +49,22 @@ func processRewrite(request *Request, response *Response, logger *log.Logger) bo
|
|||||||
requestPath = requestPath[:pos]
|
requestPath = requestPath[:pos]
|
||||||
}
|
}
|
||||||
|
|
||||||
var rewriteToPath string
|
for _, h := range hosts {
|
||||||
var found bool
|
rewrites, exists := hostRewrites[h]
|
||||||
|
if !exists {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
rewritesLock.RLock()
|
for _, ri := range rewrites {
|
||||||
// 1. 精确匹配
|
found := false
|
||||||
if ri, ok := rewrites[requestPath]; ok {
|
rewriteToPath := ""
|
||||||
|
|
||||||
|
if ri.matcher == nil {
|
||||||
|
if ri.fromPath == requestPath {
|
||||||
rewriteToPath = ri.toPath
|
rewriteToPath = ri.toPath
|
||||||
found = true
|
found = true
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
// 2. 动态重写
|
|
||||||
if !found && rewriteBy != nil {
|
|
||||||
rewriteToPath, found = rewriteBy(request)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 3. 正则匹配
|
|
||||||
if !found {
|
|
||||||
for _, ri := range regexRewrites {
|
|
||||||
if ri.matcher != nil {
|
|
||||||
finds := ri.matcher.FindAllStringSubmatch(request.RequestURI, 1)
|
finds := ri.matcher.FindAllStringSubmatch(request.RequestURI, 1)
|
||||||
if len(finds) > 0 {
|
if len(finds) > 0 {
|
||||||
toPath := ri.toPath
|
toPath := ri.toPath
|
||||||
@ -81,12 +73,8 @@ func processRewrite(request *Request, response *Response, logger *log.Logger) bo
|
|||||||
}
|
}
|
||||||
rewriteToPath = toPath
|
rewriteToPath = toPath
|
||||||
found = true
|
found = true
|
||||||
break
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
|
||||||
rewritesLock.RUnlock()
|
|
||||||
|
|
||||||
if found {
|
if found {
|
||||||
if strings.Contains(rewriteToPath, "://") {
|
if strings.Contains(rewriteToPath, "://") {
|
||||||
@ -99,7 +87,7 @@ func processRewrite(request *Request, response *Response, logger *log.Logger) bo
|
|||||||
return true
|
return true
|
||||||
} else {
|
} else {
|
||||||
// 内部重写
|
// 内部重写
|
||||||
logger.Info("rewrite", "from", request.RequestURI, "to", rewriteToPath)
|
logger.Info("rewrite", "from", request.RequestURI, "to", rewriteToPath, "host", h)
|
||||||
if queryString != "" && !strings.Contains(rewriteToPath, "?") {
|
if queryString != "" && !strings.Contains(rewriteToPath, "?") {
|
||||||
rewriteToPath += queryString
|
rewriteToPath += queryString
|
||||||
}
|
}
|
||||||
@ -108,6 +96,8 @@ func processRewrite(request *Request, response *Response, logger *log.Logger) bo
|
|||||||
return false // 继续后续处理
|
return false // 继续后续处理
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|||||||
@ -50,7 +50,7 @@ func (as *AsyncServer) start() {
|
|||||||
serverAddr = as.Addr
|
serverAddr = as.Addr
|
||||||
|
|
||||||
as.server = &http.Server{
|
as.server = &http.Server{
|
||||||
Handler: &routeHandler{},
|
Handler: &RouteHandler{},
|
||||||
}
|
}
|
||||||
|
|
||||||
signal.Notify(as.stopChan, os.Interrupt, syscall.SIGTERM)
|
signal.Notify(as.stopChan, os.Interrupt, syscall.SIGTERM)
|
||||||
|
|||||||
260
service.go
260
service.go
@ -9,31 +9,15 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Map 通用 Map 类型
|
|
||||||
type Map = map[string]any
|
|
||||||
|
|
||||||
// Arr 通用切片类型
|
|
||||||
type Arr = []any
|
|
||||||
|
|
||||||
// WebServiceOptions 服务注册选项
|
|
||||||
type WebServiceOptions struct {
|
|
||||||
Priority int
|
|
||||||
NoDoc bool
|
|
||||||
NoBody bool
|
|
||||||
NoLog200 bool
|
|
||||||
Host string
|
|
||||||
Ext Map
|
|
||||||
// Limiters []*Limiter // TODO: Integrate Limiter
|
|
||||||
}
|
|
||||||
|
|
||||||
// webServiceType 内部存储的服务元数据
|
// webServiceType 内部存储的服务元数据
|
||||||
type webServiceType struct {
|
type webServiceType struct {
|
||||||
authLevel int
|
authLevel int
|
||||||
method string
|
method string
|
||||||
|
host string
|
||||||
path string
|
path string
|
||||||
pathMatcher *regexp.Regexp
|
pathMatcher *regexp.Regexp
|
||||||
pathArgs []string
|
pathArgs []string
|
||||||
parmsNum int
|
paramsNum int
|
||||||
inType reflect.Type
|
inType reflect.Type
|
||||||
inIndex int
|
inIndex int
|
||||||
headersType reflect.Type
|
headersType reflect.Type
|
||||||
@ -47,10 +31,29 @@ type webServiceType struct {
|
|||||||
funcType reflect.Type
|
funcType reflect.Type
|
||||||
funcValue reflect.Value
|
funcValue reflect.Value
|
||||||
options WebServiceOptions
|
options WebServiceOptions
|
||||||
data Map
|
data map[string]any
|
||||||
memo string
|
memo string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WebServiceOptions 服务注册选项
|
||||||
|
type WebServiceOptions struct {
|
||||||
|
Priority int
|
||||||
|
NoDoc bool
|
||||||
|
NoBody bool
|
||||||
|
NoLog200 bool
|
||||||
|
Ext map[string]any
|
||||||
|
}
|
||||||
|
|
||||||
|
type websocketServiceType struct {
|
||||||
|
authLevel int
|
||||||
|
host string
|
||||||
|
path string
|
||||||
|
memo string
|
||||||
|
funcType reflect.Type
|
||||||
|
funcValue reflect.Value
|
||||||
|
options WebServiceOptions
|
||||||
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
serverId string
|
serverId string
|
||||||
serverAddr string
|
serverAddr string
|
||||||
@ -58,14 +61,21 @@ var (
|
|||||||
serverProtoName = "http"
|
serverProtoName = "http"
|
||||||
running = false
|
running = false
|
||||||
|
|
||||||
webServices = make(map[string]*webServiceType)
|
// webServices 按 Host 隔离: map[host]map[method+path]*webServiceType
|
||||||
regexWebServices = make([]*webServiceType, 0)
|
webServices = make(map[string]map[string]*webServiceType)
|
||||||
|
// regexWebServices 按 Host 隔离: map[host][]*webServiceType
|
||||||
|
regexWebServices = make(map[string][]*webServiceType)
|
||||||
webServicesLock = sync.RWMutex{}
|
webServicesLock = sync.RWMutex{}
|
||||||
webServicesList = make([]*webServiceType, 0)
|
webServicesList = make([]*webServiceType, 0)
|
||||||
|
|
||||||
websocketServices = make(map[string]*websocketServiceType)
|
websocketServices = make(map[string]map[string]*websocketServiceType)
|
||||||
websocketServicesLock = sync.RWMutex{}
|
websocketServicesLock = sync.RWMutex{}
|
||||||
websocketServicesList = make([]*webServiceType, 0)
|
websocketServicesList = make([]*websocketServiceType, 0)
|
||||||
|
|
||||||
|
// Rewrite 与 Proxy 按 Host 隔离
|
||||||
|
hostRewrites = make(map[string][]*rewriteType)
|
||||||
|
hostProxies = make(map[string][]*proxyType)
|
||||||
|
hostPoliciesLock = sync.RWMutex{}
|
||||||
|
|
||||||
// 过滤器与拦截器
|
// 过滤器与拦截器
|
||||||
inFilters = make([]func(*map[string]any, *Request, *Response, *log.Logger) any, 0)
|
inFilters = make([]func(*map[string]any, *Request, *Response, *log.Logger) any, 0)
|
||||||
@ -118,29 +128,28 @@ func SetOutFilter(filter func(in map[string]any, request *Request, response *Res
|
|||||||
outFilters = append(outFilters, filter)
|
outFilters = append(outFilters, filter)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Register 注册服务(通用方法)
|
// HostContext 提供流式服务注册能力
|
||||||
func Register(authLevel int, path string, serviceFunc any, memo string) {
|
type HostContext struct {
|
||||||
Restful(authLevel, "", path, serviceFunc, memo)
|
host string
|
||||||
}
|
}
|
||||||
|
|
||||||
// Restful 注册指定方法的服务
|
// Host 指定服务运行的 Host (支持 "example.com", ":8080", "example.com:8080", "*")
|
||||||
func Restful(authLevel int, method, path string, serviceFunc any, memo string) {
|
func Host(host string) *HostContext {
|
||||||
RestfulWithOptions(authLevel, method, path, serviceFunc, memo, WebServiceOptions{})
|
if host == "" {
|
||||||
|
host = "*"
|
||||||
|
}
|
||||||
|
return &HostContext{host: host}
|
||||||
}
|
}
|
||||||
|
|
||||||
// RestfulWithOptions 注册带选项的服务
|
func (hc *HostContext) Register(method, path string, serviceFunc any) *webServiceType {
|
||||||
func RestfulWithOptions(authLevel int, method, path string, serviceFunc any, memo string, options WebServiceOptions) {
|
|
||||||
s, err := makeCachedService(serviceFunc)
|
s, err := makeCachedService(serviceFunc)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// TODO: Log error properly when logger is ready
|
return &webServiceType{} // 返回空对象避免链式调用崩溃
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
s.authLevel = authLevel
|
s.host = hc.host
|
||||||
s.options = options
|
s.method = strings.ToUpper(method)
|
||||||
s.method = method
|
|
||||||
s.path = path
|
s.path = path
|
||||||
s.memo = memo
|
|
||||||
|
|
||||||
// 解析路径参数 {name}
|
// 解析路径参数 {name}
|
||||||
finder, err := regexp.Compile("{(.*?)}")
|
finder, err := regexp.Compile("{(.*?)}")
|
||||||
@ -159,13 +168,169 @@ func RestfulWithOptions(authLevel int, method, path string, serviceFunc any, mem
|
|||||||
webServicesLock.Lock()
|
webServicesLock.Lock()
|
||||||
defer webServicesLock.Unlock()
|
defer webServicesLock.Unlock()
|
||||||
|
|
||||||
// 简单路径匹配
|
|
||||||
if s.pathMatcher == nil {
|
if s.pathMatcher == nil {
|
||||||
webServices[method+path] = s // TODO: Include Host in key
|
if webServices[s.host] == nil {
|
||||||
|
webServices[s.host] = make(map[string]*webServiceType)
|
||||||
|
}
|
||||||
|
webServices[s.host][s.method+s.path] = s
|
||||||
} else {
|
} else {
|
||||||
regexWebServices = append(regexWebServices, s)
|
regexWebServices[s.host] = append(regexWebServices[s.host], s)
|
||||||
}
|
}
|
||||||
webServicesList = append(webServicesList, s)
|
webServicesList = append(webServicesList, s)
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hc *HostContext) GET(path string, serviceFunc any) *webServiceType {
|
||||||
|
return hc.Register("GET", path, serviceFunc)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hc *HostContext) POST(path string, serviceFunc any) *webServiceType {
|
||||||
|
return hc.Register("POST", path, serviceFunc)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hc *HostContext) PUT(path string, serviceFunc any) *webServiceType {
|
||||||
|
return hc.Register("PUT", path, serviceFunc)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hc *HostContext) DELETE(path string, serviceFunc any) *webServiceType {
|
||||||
|
return hc.Register("DELETE", path, serviceFunc)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hc *HostContext) PATCH(path string, serviceFunc any) *webServiceType {
|
||||||
|
return hc.Register("PATCH", path, serviceFunc)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hc *HostContext) HEAD(path string, serviceFunc any) *webServiceType {
|
||||||
|
return hc.Register("HEAD", path, serviceFunc)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hc *HostContext) OPTIONS(path string, serviceFunc any) *webServiceType {
|
||||||
|
return hc.Register("OPTIONS", path, serviceFunc)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hc *HostContext) ANY(path string, serviceFunc any) *webServiceType {
|
||||||
|
return hc.Register("*", path, serviceFunc)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GroupContext 提供路径分组注册能力
|
||||||
|
type GroupContext struct {
|
||||||
|
hc *HostContext
|
||||||
|
prefix string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Group 创建路径分组
|
||||||
|
func (hc *HostContext) Group(prefix string) *GroupContext {
|
||||||
|
if prefix == "/" {
|
||||||
|
prefix = ""
|
||||||
|
}
|
||||||
|
return &GroupContext{hc: hc, prefix: prefix}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (gc *GroupContext) GET(path string, serviceFunc any) *webServiceType {
|
||||||
|
return gc.hc.Register("GET", gc.prefix+path, serviceFunc)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (gc *GroupContext) POST(path string, serviceFunc any) *webServiceType {
|
||||||
|
return gc.hc.Register("POST", gc.prefix+path, serviceFunc)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (gc *GroupContext) PUT(path string, serviceFunc any) *webServiceType {
|
||||||
|
return gc.hc.Register("PUT", gc.prefix+path, serviceFunc)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (gc *GroupContext) DELETE(path string, serviceFunc any) *webServiceType {
|
||||||
|
return gc.hc.Register("DELETE", gc.prefix+path, serviceFunc)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (gc *GroupContext) ANY(path string, serviceFunc any) *webServiceType {
|
||||||
|
return gc.hc.Register("*", gc.prefix+path, serviceFunc)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (gc *GroupContext) WebSocket(path string, serviceFunc any) *websocketServiceType {
|
||||||
|
return gc.hc.WebSocket(gc.prefix+path, serviceFunc)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (gc *GroupContext) Rewrite(path string, toPath string) *GroupContext {
|
||||||
|
gc.hc.Rewrite(gc.prefix+path, toPath)
|
||||||
|
return gc
|
||||||
|
}
|
||||||
|
|
||||||
|
func (gc *GroupContext) Proxy(authLevel int, path string, toApp, toPath string) *GroupContext {
|
||||||
|
gc.hc.Proxy(authLevel, gc.prefix+path, toApp, toPath)
|
||||||
|
return gc
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hc *HostContext) WebSocket(path string, serviceFunc any) *websocketServiceType {
|
||||||
|
funcType := reflect.TypeOf(serviceFunc)
|
||||||
|
if funcType.Kind() != reflect.Func {
|
||||||
|
return &websocketServiceType{}
|
||||||
|
}
|
||||||
|
|
||||||
|
ws := &websocketServiceType{
|
||||||
|
host: hc.host,
|
||||||
|
path: path,
|
||||||
|
funcType: funcType,
|
||||||
|
funcValue: reflect.ValueOf(serviceFunc),
|
||||||
|
}
|
||||||
|
|
||||||
|
websocketServicesLock.Lock()
|
||||||
|
defer websocketServicesLock.Unlock()
|
||||||
|
if websocketServices[hc.host] == nil {
|
||||||
|
websocketServices[hc.host] = make(map[string]*websocketServiceType)
|
||||||
|
}
|
||||||
|
websocketServices[hc.host][path] = ws
|
||||||
|
websocketServicesList = append(websocketServicesList, ws)
|
||||||
|
return ws
|
||||||
|
}
|
||||||
|
|
||||||
|
// webServiceType 链式配置方法
|
||||||
|
func (s *webServiceType) Auth(level int) *webServiceType {
|
||||||
|
s.authLevel = level
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *webServiceType) Memo(memo string) *webServiceType {
|
||||||
|
s.memo = memo
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *webServiceType) Priority(p int) *webServiceType {
|
||||||
|
s.options.Priority = p
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *webServiceType) NoDoc() *webServiceType {
|
||||||
|
s.options.NoDoc = true
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *webServiceType) NoBody() *webServiceType {
|
||||||
|
s.options.NoBody = true
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *webServiceType) NoLog200() *webServiceType {
|
||||||
|
s.options.NoLog200 = true
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *webServiceType) Ext(key string, val any) *webServiceType {
|
||||||
|
if s.options.Ext == nil {
|
||||||
|
s.options.Ext = make(map[string]any)
|
||||||
|
}
|
||||||
|
s.options.Ext[key] = val
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// websocketServiceType 链式配置方法
|
||||||
|
func (s *websocketServiceType) Auth(level int) *websocketServiceType {
|
||||||
|
s.authLevel = level
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *websocketServiceType) Memo(memo string) *websocketServiceType {
|
||||||
|
s.memo = memo
|
||||||
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
func makeCachedService(matchedService any) (*webServiceType, error) {
|
func makeCachedService(matchedService any) (*webServiceType, error) {
|
||||||
@ -175,7 +340,7 @@ func makeCachedService(matchedService any) (*webServiceType, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
targetService := &webServiceType{
|
targetService := &webServiceType{
|
||||||
parmsNum: funcType.NumIn(),
|
paramsNum: funcType.NumIn(),
|
||||||
inIndex: -1,
|
inIndex: -1,
|
||||||
headersIndex: -1,
|
headersIndex: -1,
|
||||||
requestIndex: -1,
|
requestIndex: -1,
|
||||||
@ -188,7 +353,7 @@ func makeCachedService(matchedService any) (*webServiceType, error) {
|
|||||||
funcValue: reflect.ValueOf(matchedService),
|
funcValue: reflect.ValueOf(matchedService),
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < targetService.parmsNum; i++ {
|
for i := 0; i < targetService.paramsNum; i++ {
|
||||||
t := funcType.In(i)
|
t := funcType.In(i)
|
||||||
tStr := t.String()
|
tStr := t.String()
|
||||||
switch tStr {
|
switch tStr {
|
||||||
@ -228,3 +393,14 @@ func GetInject(dataType reflect.Type) any {
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetInjectT 获取注入对象 (泛型版)
|
||||||
|
func GetInjectT[T any]() T {
|
||||||
|
var zero T
|
||||||
|
t := reflect.TypeOf((*T)(nil)).Elem()
|
||||||
|
obj := GetInject(t)
|
||||||
|
if obj == nil {
|
||||||
|
return zero
|
||||||
|
}
|
||||||
|
return obj.(T)
|
||||||
|
}
|
||||||
|
|||||||
@ -10,10 +10,10 @@ func TestServiceRegister(t *testing.T) {
|
|||||||
return "ok"
|
return "ok"
|
||||||
}
|
}
|
||||||
|
|
||||||
Register(0, "/test", handler, "test service")
|
Host("*").Register("*", "/test", handler).Auth(0).Memo("test service")
|
||||||
|
|
||||||
webServicesLock.RLock()
|
webServicesLock.RLock()
|
||||||
s := webServices["/test"]
|
s := webServices["*"]["*/test"]
|
||||||
webServicesLock.RUnlock()
|
webServicesLock.RUnlock()
|
||||||
|
|
||||||
if s == nil {
|
if s == nil {
|
||||||
@ -33,11 +33,12 @@ func TestRegexServiceRegister(t *testing.T) {
|
|||||||
return "ok"
|
return "ok"
|
||||||
}
|
}
|
||||||
|
|
||||||
Register(0, "/user/{id}", handler, "get user")
|
Host("*").Register("*", "/user/{id}", handler).Auth(0).Memo("get user")
|
||||||
|
|
||||||
webServicesLock.RLock()
|
webServicesLock.RLock()
|
||||||
found := false
|
found := false
|
||||||
for _, s := range regexWebServices {
|
for _, services := range regexWebServices {
|
||||||
|
for _, s := range services {
|
||||||
if s.path == "/user/{id}" {
|
if s.path == "/user/{id}" {
|
||||||
found = true
|
found = true
|
||||||
if len(s.pathArgs) != 1 || s.pathArgs[0] != "id" {
|
if len(s.pathArgs) != 1 || s.pathArgs[0] != "id" {
|
||||||
@ -46,6 +47,10 @@ func TestRegexServiceRegister(t *testing.T) {
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if found {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
webServicesLock.RUnlock()
|
webServicesLock.RUnlock()
|
||||||
|
|
||||||
if !found {
|
if !found {
|
||||||
|
|||||||
16
static.go
16
static.go
@ -5,7 +5,6 @@ import (
|
|||||||
"apigo.cc/go/log"
|
"apigo.cc/go/log"
|
||||||
"mime"
|
"mime"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@ -74,16 +73,16 @@ func processStatic(requestPath string, request *Request, response *Response, log
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
info, err := os.Stat(filePath)
|
info := file.GetFileInfo(filePath)
|
||||||
if err != nil {
|
if info == nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if info.IsDir() {
|
if info.IsDir {
|
||||||
// 自动查找索引文件
|
// 自动查找索引文件
|
||||||
for _, indexFile := range Config.IndexFiles {
|
for _, indexFile := range Config.IndexFiles {
|
||||||
f := filepath.Join(filePath, indexFile)
|
f := filepath.Join(filePath, indexFile)
|
||||||
if i, err := os.Stat(f); err == nil && !i.IsDir() {
|
if i := file.GetFileInfo(f); i != nil && !i.IsDir {
|
||||||
filePath = f
|
filePath = f
|
||||||
info = i
|
info = i
|
||||||
break
|
break
|
||||||
@ -91,14 +90,15 @@ func processStatic(requestPath string, request *Request, response *Response, log
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if info.IsDir() {
|
if info.IsDir {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// 检查 304
|
// 检查 304
|
||||||
if ifModifiedSince := request.Header.Get("If-Modified-Since"); ifModifiedSince != "" {
|
if ifModifiedSince := request.Header.Get("If-Modified-Since"); ifModifiedSince != "" {
|
||||||
if t, err := time.Parse(http.TimeFormat, ifModifiedSince); err == nil {
|
if t, err := time.Parse(http.TimeFormat, ifModifiedSince); err == nil {
|
||||||
if !info.ModTime().Truncate(time.Second).After(t.Truncate(time.Second)) {
|
if time.Unix(info.ModTime, 0).Truncate(time.Second).Before(t.Truncate(time.Second)) ||
|
||||||
|
time.Unix(info.ModTime, 0).Truncate(time.Second).Equal(t.Truncate(time.Second)) {
|
||||||
response.WriteHeader(http.StatusNotModified)
|
response.WriteHeader(http.StatusNotModified)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
@ -111,7 +111,7 @@ func processStatic(requestPath string, request *Request, response *Response, log
|
|||||||
contentType = "application/octet-stream"
|
contentType = "application/octet-stream"
|
||||||
}
|
}
|
||||||
response.Header().Set("Content-Type", contentType)
|
response.Header().Set("Content-Type", contentType)
|
||||||
response.Header().Set("Last-Modified", info.ModTime().UTC().Format(http.TimeFormat))
|
response.Header().Set("Last-Modified", time.Unix(info.ModTime, 0).UTC().Format(http.TimeFormat))
|
||||||
|
|
||||||
data, err := file.ReadBytes(filePath)
|
data, err := file.ReadBytes(filePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@ -19,7 +19,7 @@ func TestStaticService(t *testing.T) {
|
|||||||
// 注册静态目录
|
// 注册静态目录
|
||||||
Static("/ui", tempDir)
|
Static("/ui", tempDir)
|
||||||
|
|
||||||
rh := &routeHandler{}
|
rh := &RouteHandler{}
|
||||||
|
|
||||||
// 测试成功访问
|
// 测试成功访问
|
||||||
req := httptest.NewRequest("GET", "/ui/index.html", nil)
|
req := httptest.NewRequest("GET", "/ui/index.html", nil)
|
||||||
|
|||||||
10
verify.go
10
verify.go
@ -92,7 +92,7 @@ func VerifyStruct(in any, logger *log.Logger) (ok bool, field string) {
|
|||||||
keyTag := ft.Tag.Get("verifyKey")
|
keyTag := ft.Tag.Get("verifyKey")
|
||||||
if tag != "" || keyTag != "" {
|
if tag != "" || keyTag != "" {
|
||||||
var err error
|
var err error
|
||||||
ok, f, err := _verifyValue(fv, tag, keyTag, logger)
|
ok, f, err := verifyValue(fv, tag, keyTag, logger)
|
||||||
if !ok {
|
if !ok {
|
||||||
if f == "" {
|
if f == "" {
|
||||||
f = cast.GetLowerName(ft.Name)
|
f = cast.GetLowerName(ft.Name)
|
||||||
@ -111,13 +111,13 @@ func VerifyStruct(in any, logger *log.Logger) (ok bool, field string) {
|
|||||||
return true, ""
|
return true, ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func _verifyValue(in reflect.Value, setting, keySetting string, logger *log.Logger) (bool, string, error) {
|
func verifyValue(in reflect.Value, setting, keySetting string, logger *log.Logger) (bool, string, error) {
|
||||||
t := in.Type()
|
t := in.Type()
|
||||||
// 处理切片 (非 byte 切片)
|
// 处理切片 (非 byte 切片)
|
||||||
if t.Kind() == reflect.Slice && t.Elem().Kind() != reflect.Uint8 {
|
if t.Kind() == reflect.Slice && t.Elem().Kind() != reflect.Uint8 {
|
||||||
if setting != "" {
|
if setting != "" {
|
||||||
for i := 0; i < in.Len(); i++ {
|
for i := 0; i < in.Len(); i++ {
|
||||||
if ok, f, err := _verifyValue(in.Index(i), setting, "", logger); !ok {
|
if ok, f, err := verifyValue(in.Index(i), setting, "", logger); !ok {
|
||||||
return false, f, err
|
return false, f, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -129,12 +129,12 @@ func _verifyValue(in reflect.Value, setting, keySetting string, logger *log.Logg
|
|||||||
if t.Kind() == reflect.Map {
|
if t.Kind() == reflect.Map {
|
||||||
for _, k := range in.MapKeys() {
|
for _, k := range in.MapKeys() {
|
||||||
if keySetting != "" {
|
if keySetting != "" {
|
||||||
if ok, _, err := _verifyValue(k, keySetting, "", logger); !ok {
|
if ok, _, err := verifyValue(k, keySetting, "", logger); !ok {
|
||||||
return false, "key", err
|
return false, "key", err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if setting != "" {
|
if setting != "" {
|
||||||
if ok, f, err := _verifyValue(in.MapIndex(k), setting, "", logger); !ok {
|
if ok, f, err := verifyValue(in.MapIndex(k), setting, "", logger); !ok {
|
||||||
return false, f, err
|
return false, f, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
40
websocket.go
40
websocket.go
@ -7,40 +7,12 @@ import (
|
|||||||
"reflect"
|
"reflect"
|
||||||
)
|
)
|
||||||
|
|
||||||
// websocketServiceType WebSocket 服务元数据
|
var defaultUpgrader = &websocket.Upgrader{
|
||||||
type websocketServiceType struct {
|
CheckOrigin: func(r *http.Request) bool { return true },
|
||||||
authLevel int
|
|
||||||
path string
|
|
||||||
updater *websocket.Upgrader
|
|
||||||
handlerValue reflect.Value
|
|
||||||
handlerType reflect.Type
|
|
||||||
memo string
|
|
||||||
}
|
|
||||||
|
|
||||||
// RegisterWebsocket 注册 WebSocket 服务
|
|
||||||
func RegisterWebsocket(authLevel int, path string, handler any, memo string) {
|
|
||||||
v := reflect.ValueOf(handler)
|
|
||||||
t := v.Type()
|
|
||||||
if t.Kind() != reflect.Func {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
s := &websocketServiceType{
|
|
||||||
authLevel: authLevel,
|
|
||||||
path: path,
|
|
||||||
memo: memo,
|
|
||||||
updater: &websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }},
|
|
||||||
handlerValue: v,
|
|
||||||
handlerType: t,
|
|
||||||
}
|
|
||||||
|
|
||||||
websocketServicesLock.Lock()
|
|
||||||
websocketServices[path] = s
|
|
||||||
websocketServicesLock.Unlock()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func doWebsocketService(ws *websocketServiceType, request *Request, response *Response, logger *log.Logger) {
|
func doWebsocketService(ws *websocketServiceType, request *Request, response *Response, logger *log.Logger) {
|
||||||
conn, err := ws.updater.Upgrade(response.Writer, request.Request, nil)
|
conn, err := defaultUpgrader.Upgrade(response.Writer, request.Request, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("websocket upgrade failed", "error", err.Error())
|
logger.Error("websocket upgrade failed", "error", err.Error())
|
||||||
return
|
return
|
||||||
@ -48,9 +20,9 @@ func doWebsocketService(ws *websocketServiceType, request *Request, response *Re
|
|||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
// 调用业务处理函数,注入依赖
|
// 调用业务处理函数,注入依赖
|
||||||
params := make([]reflect.Value, ws.handlerType.NumIn())
|
params := make([]reflect.Value, ws.funcType.NumIn())
|
||||||
for i := 0; i < len(params); i++ {
|
for i := 0; i < len(params); i++ {
|
||||||
t := ws.handlerType.In(i)
|
t := ws.funcType.In(i)
|
||||||
if t == reflect.TypeOf(request) {
|
if t == reflect.TypeOf(request) {
|
||||||
params[i] = reflect.ValueOf(request)
|
params[i] = reflect.ValueOf(request)
|
||||||
} else if t == reflect.TypeOf(logger) {
|
} else if t == reflect.TypeOf(logger) {
|
||||||
@ -63,5 +35,5 @@ func doWebsocketService(ws *websocketServiceType, request *Request, response *Re
|
|||||||
params[i] = reflect.New(t).Elem()
|
params[i] = reflect.New(t).Elem()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ws.handlerValue.Call(params)
|
ws.funcValue.Call(params)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -9,18 +9,18 @@ import (
|
|||||||
|
|
||||||
func TestWebSocketService(t *testing.T) {
|
func TestWebSocketService(t *testing.T) {
|
||||||
// 注册 WebSocket 服务
|
// 注册 WebSocket 服务
|
||||||
RegisterWebsocket(0, "/ws", func(conn *websocket.Conn) {
|
Host("*").WebSocket("/ws", func(conn *websocket.Conn) {
|
||||||
for {
|
for {
|
||||||
var msg Map
|
var msg map[string]any
|
||||||
if err := conn.ReadJSON(&msg); err != nil {
|
if err := conn.ReadJSON(&msg); err != nil {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
_ = conn.WriteJSON(Map{"reply": msg["msg"]})
|
_ = conn.WriteJSON(map[string]any{"reply": msg["msg"]})
|
||||||
}
|
}
|
||||||
}, "test websocket")
|
}).Auth(0).Memo("test websocket")
|
||||||
|
|
||||||
// 启动测试服务器
|
// 启动测试服务器
|
||||||
server := httptest.NewServer(&routeHandler{})
|
server := httptest.NewServer(&RouteHandler{})
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
// 建立连接
|
// 建立连接
|
||||||
@ -32,13 +32,13 @@ func TestWebSocketService(t *testing.T) {
|
|||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
// 发送消息
|
// 发送消息
|
||||||
msg := Map{"action": "echo", "msg": "hello"}
|
msg := map[string]any{"action": "echo", "msg": "hello"}
|
||||||
if err := conn.WriteJSON(msg); err != nil {
|
if err := conn.WriteJSON(msg); err != nil {
|
||||||
t.Fatalf("WriteJSON failed: %v", err)
|
t.Fatalf("WriteJSON failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 接收响应
|
// 接收响应
|
||||||
var reply Map
|
var reply map[string]any
|
||||||
if err := conn.ReadJSON(&reply); err != nil {
|
if err := conn.ReadJSON(&reply); err != nil {
|
||||||
t.Fatalf("ReadJSON failed: %v", err)
|
t.Fatalf("ReadJSON failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user