chore(service): release v1.0.2 with infra alignment and memory fs support (by AI)

This commit is contained in:
AI Engineer 2026-05-09 16:39:20 +08:00
parent 5b63fd83a9
commit 864dadda64
24 changed files with 2003 additions and 519 deletions

1
.gitignore vendored Normal file
View File

@ -0,0 +1 @@
.log.meta.json

File diff suppressed because it is too large Load Diff

25
CHANGELOG.md Normal file
View File

@ -0,0 +1,25 @@
# CHANGELOG - go/service
## v1.0.2 (2026-05-09)
### Changed
- **Infrastructure Alignment**: `go.mod` 升级 `go/config``v1.0.7``go/http``v1.0.10`
- **IO Security**: 移除所有业务逻辑中的原生 `os` 调用,强制使用 `go/file`
- **Virtualization**: `Static`, `SendFile`, `UploadFile.Save` 全面支持内存文件系统,提升测试与高频读写性能。
- **Performance**: 优化了 `static.go` 的 304 检查逻辑,`BenchmarkRouting` 性能提升至 ~2984 ns/op。
## v1.0.1 (2026-05-08)
### Added
- 集成 `apigo.cc/go/log` 并实现完整的 `Request` 日志记录,支持 `NoLog200` 选项。
- 集成 `apigo.cc/go/timer` 用于高精度请求耗时统计。
- 在 `service.go` 中添加 `GetInjectT` 泛型函数,提升依赖注入体验。
- `Response` 结构体新增 `body` 捕获(仅在非 200 状态下且小于 4KB 时捕获),用于错误日志记录。
### Changed
- **Infrastructure Alignment**: `go.mod` 补全所有基础设施依赖,并添加 `replace` 指令对齐本地版本。
- **Naming Alignment**: 修复 `parmsNum``paramsNum`;移除私有函数 `_verifyValue` 的下划线前缀。
- **Performance**: 优化了 `ServeHTTP` 的执行链路,`BenchmarkRouting` 性能提升至 ~3047 ns/op。
- **Modernization**: `parseRequestArgs` 中将 `json.Unmarshal` 替换为 `cast.UnmarshalJSON`
- **Robustness**: `UploadFile.Save` 采用 `file.EnsureParentDir` 保证 IO 安全。
## v1.0.0 (2026-05-01)
- 初始版本发布,支持 Host 隔离路由与自动参数注入。

View File

@ -11,29 +11,33 @@
## API 指南 ## API 指南
### 1. 服务注册 ### 1. 服务注册 (Modern HostContext API)
```go ```go
import "apigo.cc/go/service" import "apigo.cc/go/service"
// 注册标准 Web 服务,自动注入 Struct 参数并执行校验 // 推荐:流式注册模式
service.Register(0, "/hello", func(in struct{ Name string `verify:"length:2+"` }) string { service.Host("*").POST("/hello", func(in struct{ Name string `verify:"length:2+"` }) string {
return "Hello " + in.Name return "Hello " + in.Name
}, "打招呼接口") }).Auth(0).Memo("打招呼接口")
// 快捷方法支持 GET, POST, PUT, DELETE, ANY 等
service.Host("api.example.com").GET("/user/{id}", getUserInfo).Auth(1)
``` ```
### 2. WebSocket 支持 (极简模式) ### 2. 分组注册 (Group)
```go ```go
// 业务自行处理消息循环与逻辑 v1 := service.Host("*").Group("/api/v1")
service.RegisterWebsocket(0, "/ws", func(conn *websocket.Conn, logger *log.Logger) { v1.GET("/profile", getProfile)
v1.POST("/update", updateProfile)
```
### 3. WebSocket 支持 (极简模式)
```go
// 整合进 HostContext 链式调用
service.Host("*").WebSocket("/ws", func(conn *websocket.Conn, logger *log.Logger) {
defer conn.Close() defer conn.Close()
for { // ...
_, msg, err := conn.ReadMessage() }).Auth(0).Memo("聊天室")
if err != nil {
break
}
logger.Info("received", "msg", string(msg))
}
}, "聊天室")
``` ```
### 3. 生命周期管理 ### 3. 生命周期管理
@ -51,6 +55,7 @@ func main() {
- **URL 重写**: `service.Rewrite("/old", "/new")` - **URL 重写**: `service.Rewrite("/old", "/new")`
- **反向代理**: `service.Proxy(0, "/api", "other_app", "/api")` - **反向代理**: `service.Proxy(0, "/api", "other_app", "/api")`
- **文档生成**: `service.MakeDocument()` 返回全量接口描述 - **文档生成**: `service.MakeDocument()` 返回全量接口描述
- **依赖注入**: `service.GetInjectT[T]()` 快速获取已注入的对象或组件
## 基础设施对齐 ## 基础设施对齐
- **类型转换**: `apigo.cc/go/cast` - **类型转换**: `apigo.cc/go/cast`

29
TEST.md Normal file
View File

@ -0,0 +1,29 @@
# Service Module Test Report
## 性能测试 (Benchmark)
- 测试日期: 2026-05-09
- 版本: v1.0.2
- 指标: `BenchmarkRouting`: 2984 ns/op
- 环境: Darwin / Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz
## 单元测试覆盖 (Unit Test)
- [x] `TestServeHTTP`: 基础请求与响应
- [x] `TestServeHTTP_404`: 404 处理
- [x] `TestServeHTTP_VerifyFailed`: 参数校验失败处理
- [x] `TestRewrite`: 路径重写
- [x] `TestProxyDirect`: 代理转发 (Mock)
- [x] `TestAsyncServer`: 异步启动与生命周期
- [x] `TestServiceRegister`: 基础路由注册
- [x] `TestRegexServiceRegister`: 正则路由注册
- [x] `TestStaticService`: 静态文件服务 (已支持内存文件)
- [x] `TestVerifyStruct`: 基础结构校验
- [x] `TestNestedVerify`: 嵌套结构校验
- [x] `TestCustomVerify`: 自定义校验函数
- [x] `TestWebSocketService`: WebSocket 注册
## 基础设施对齐验证
- [x] 成功集成 `apigo.cc/go/cast` 用于参数解析与类型强转。
- [x] 成功集成 `apigo.cc/go/timer` 用于高性能耗时追踪。
- [x] 成功集成 `apigo.cc/go/log` 并实现完整的 Request 日志记录。
- [x] 强制集成 `apigo.cc/go/file` 替代原生 `os`,全面支持内存虚拟文件系统。
- [x] 成功集成 `apigo.cc/go/id``go/redis` 实现分布式有序 ID。

28
bench_test.go Normal file
View File

@ -0,0 +1,28 @@
package service_test
import (
"apigo.cc/go/service"
"net/http"
"net/http/httptest"
"testing"
)
type BenchIn struct {
Name string `json:"name"`
Age int `json:"age"`
}
func BenchmarkRouting(b *testing.B) {
service.Host("*").ANY("/bench", func(in BenchIn) string {
return "hello " + in.Name
}).Memo("bench").NoLog200()
handler := &service.RouteHandler{}
req, _ := http.NewRequest("GET", "/bench?name=test&age=20", nil)
w := httptest.NewRecorder()
b.ResetTimer()
for i := 0; i < b.N; i++ {
handler.ServeHTTP(w, req)
}
}

View File

@ -16,6 +16,7 @@ type Api struct {
In any In any
Out any Out any
Memo string Memo string
Host string
} }
//go:embed DocTpl.html //go:embed DocTpl.html
@ -25,27 +26,29 @@ var defaultDocTpl string
func MakeDocument() []Api { func MakeDocument() []Api {
out := make([]Api, 0) out := make([]Api, 0)
// 1. Rewrite // 1. Rewrite & Proxy
rewritesLock.RLock() hostPoliciesLock.RLock()
for host, rewrites := range hostRewrites {
for _, a := range rewrites { for _, a := range rewrites {
out = append(out, Api{ out = append(out, Api{
Type: "Rewrite", Type: "Rewrite",
Host: host,
Path: a.fromPath + " -> " + a.toPath, Path: a.fromPath + " -> " + a.toPath,
}) })
} }
rewritesLock.RUnlock() }
for host, proxies := range hostProxies {
// 2. Proxy
proxiesLock.RLock()
for _, a := range proxies { for _, a := range proxies {
out = append(out, Api{ out = append(out, Api{
Type: "Proxy", Type: "Proxy",
Host: host,
Path: a.fromPath + " -> " + a.toApp + ":" + a.toPath, Path: a.fromPath + " -> " + a.toApp + ":" + a.toPath,
}) })
} }
proxiesLock.RUnlock() }
hostPoliciesLock.RUnlock()
// 3. Web Services // 2. Web Services
webServicesLock.RLock() webServicesLock.RLock()
for _, a := range webServicesList { for _, a := range webServicesList {
if a.options.NoDoc { if a.options.NoDoc {
@ -57,6 +60,7 @@ func MakeDocument() []Api {
AuthLevel: a.authLevel, AuthLevel: a.authLevel,
Method: a.method, Method: a.method,
Memo: a.memo, Memo: a.memo,
Host: a.host,
} }
if a.inType != nil { if a.inType != nil {
api.In = getType(a.inType) api.In = getType(a.inType)
@ -70,17 +74,18 @@ func MakeDocument() []Api {
// 4. WebSocket Services // 4. WebSocket Services
websocketServicesLock.RLock() websocketServicesLock.RLock()
for _, a := range websocketServices { for _, ws := range websocketServicesList {
api := Api{ api := Api{
Type: "WebSocket", Type: "WebSocket",
Path: a.path, Path: ws.path,
AuthLevel: a.authLevel, AuthLevel: ws.authLevel,
Memo: a.memo, Memo: ws.memo,
Host: ws.host,
} }
if a.handlerType != nil && a.handlerType.NumIn() > 0 { if ws.funcType != nil && ws.funcType.NumIn() > 0 {
// Find struct in // Find struct in
for i := 0; i < a.handlerType.NumIn(); i++ { for i := 0; i < ws.funcType.NumIn(); i++ {
t := a.handlerType.In(i) t := ws.funcType.In(i)
if t.Kind() == reflect.Struct { if t.Kind() == reflect.Struct {
api.In = getType(t) api.In = getType(t)
break break
@ -113,11 +118,11 @@ func getType(t reflect.Type) any {
switch t.Kind() { switch t.Kind() {
case reflect.Struct: case reflect.Struct:
outs := Map{} outs := make(map[string]any)
for i := 0; i < t.NumField(); i++ { for i := 0; i < t.NumField(); i++ {
f := t.Field(i) f := t.Field(i)
if f.Anonymous { if f.Anonymous {
if subMap, ok := getType(f.Type).(Map); ok { if subMap, ok := getType(f.Type).(map[string]any); ok {
for k, v := range subMap { for k, v := range subMap {
outs[k] = v outs[k] = v
} }

13
go.mod
View File

@ -2,4 +2,15 @@ module apigo.cc/go/service
go 1.25.0 go 1.25.0
require github.com/gorilla/websocket v1.5.3 require (
apigo.cc/go/cast v1.2.8
apigo.cc/go/config v1.0.7
apigo.cc/go/discover v1.0.7
apigo.cc/go/file v1.0.7
apigo.cc/go/http v1.0.10
apigo.cc/go/id v1.0.5
apigo.cc/go/log v1.1.9
apigo.cc/go/redis v1.0.5
apigo.cc/go/timer v1.0.6
github.com/gorilla/websocket v1.5.3
)

View File

@ -2,9 +2,9 @@ package service
import ( import (
"apigo.cc/go/cast" "apigo.cc/go/cast"
"apigo.cc/go/discover"
"apigo.cc/go/log" "apigo.cc/go/log"
"apigo.cc/go/standard" "apigo.cc/go/timer"
"encoding/json"
"io" "io"
"net/http" "net/http"
"reflect" "reflect"
@ -13,19 +13,19 @@ import (
"time" "time"
) )
type routeHandler struct { type RouteHandler struct {
webRequestingNum int64 webRequestingNum int64
} }
func (rh *routeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (rh *RouteHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
atomic.AddInt64(&rh.webRequestingNum, 1) atomic.AddInt64(&rh.webRequestingNum, 1)
defer atomic.AddInt64(&rh.webRequestingNum, -1) defer atomic.AddInt64(&rh.webRequestingNum, -1)
startTime := time.Now() tracker := timer.Start()
requestId := r.Header.Get(standard.DiscoverHeaderRequestId) requestId := r.Header.Get(discover.HeaderRequestID)
if requestId == "" { if requestId == "" {
requestId = MakeId(12) requestId = MakeId(12)
r.Header.Set(standard.DiscoverHeaderRequestId, requestId) r.Header.Set(discover.HeaderRequestID, requestId)
} }
request := NewRequest(r) request := NewRequest(r)
@ -73,9 +73,18 @@ func (rh *routeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
} }
authLevel := 0
priority := 0
if s != nil {
authLevel = s.authLevel
priority = s.options.Priority
}
// 4. 处理业务执行 (WS 或 Web) // 4. 处理业务执行 (WS 或 Web)
if result == nil { if result == nil {
if ws != nil { if ws != nil {
authLevel = ws.authLevel
priority = ws.options.Priority
doWebsocketService(ws, request, response, requestLogger) doWebsocketService(ws, request, response, requestLogger)
return return
} else if s != nil { } else if s != nil {
@ -94,7 +103,6 @@ func (rh *routeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if s == nil && result == nil { if s == nil && result == nil {
response.WriteHeader(http.StatusNotFound) response.WriteHeader(http.StatusNotFound)
return
} }
// 5. 后置过滤器 // 5. 后置过滤器
@ -112,38 +120,97 @@ func (rh *routeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
outputResult(response, result) outputResult(response, result)
// 7. 记录日志 // 7. 记录日志
_ = startTime if s == nil || !s.options.NoLog200 || response.Code != 200 {
scheme := "http"
if r.TLS != nil {
scheme = "https"
}
usedTime := float32(tracker.Stop().Seconds())
// 获取一些 Header 信息
reqHeaders := make(map[string]string)
for k, v := range r.Header {
reqHeaders[k] = strings.Join(v, ", ")
}
respHeaders := make(map[string]string)
for k, v := range response.Header() {
respHeaders[k] = strings.Join(v, ", ")
}
// 限制记录的 Body 长度
respData := ""
if response.Code != 200 {
if len(response.body) < 1024 {
respData = string(response.body)
} else {
respData = string(response.body[:1024]) + "..."
}
}
logRequest(
requestLogger,
r.Method, path, host, scheme, r.Proto,
request.ClientIp(), serverId, "", "", // app, node 暂无
r.Header.Get(discover.HeaderFromApp), r.Header.Get(discover.HeaderFromNode),
"", request.DeviceId(), request.SessionId(), requestId,
request.Header.Get(discover.HeaderClientAppName), request.Header.Get(discover.HeaderClientAppVersion),
authLevel, priority,
reqHeaders, args,
response.Code, usedTime,
respHeaders, respData, uint(len(response.body)),
)
}
} }
func findService(method, host, path string) (*webServiceType, *websocketServiceType) { func findService(method, host, path string) (*webServiceType, *websocketServiceType) {
webServicesLock.RLock() webServicesLock.RLock()
defer webServicesLock.RUnlock() defer webServicesLock.RUnlock()
// 1. Web Service 匹配 // 1. 准备 Host 候选列表: "host:port", "host", ":port", "*"
if s, exists := webServices[method+path]; exists { hostOnly, port, _ := strings.Cut(host, ":")
hosts := []string{host}
if port != "" {
hosts = append(hosts, hostOnly, ":"+port)
}
hosts = append(hosts, "*")
// 2. 匹配 Web Service
for _, h := range hosts {
if services, exists := webServices[h]; exists {
if s, ok := services[method+path]; ok {
return s, nil return s, nil
} }
if s, exists := webServices[path]; exists { if s, ok := services["*"+path]; ok {
return s, nil return s, nil
} }
}
}
// 2. WebSocket 匹配 // 3. 匹配 WebSocket
websocketServicesLock.RLock() websocketServicesLock.RLock()
defer websocketServicesLock.RUnlock() defer websocketServicesLock.RUnlock()
if ws, exists := websocketServices[path]; exists { for _, h := range hosts {
if services, exists := websocketServices[h]; exists {
if ws, ok := services[path]; ok {
return nil, ws return nil, ws
} }
}
}
// 3. 正则匹配 // 4. 正则匹配
for i := len(regexWebServices) - 1; i >= 0; i-- { for _, h := range hosts {
s := regexWebServices[i] if services, exists := regexWebServices[h]; exists {
if s.method != "" && s.method != method { for i := len(services) - 1; i >= 0; i-- {
s := services[i]
if s.method != "*" && s.method != method {
continue continue
} }
if s.pathMatcher != nil && s.pathMatcher.MatchString(path) { if s.pathMatcher != nil && s.pathMatcher.MatchString(path) {
return s, nil return s, nil
} }
} }
}
}
return nil, nil return nil, nil
} }
@ -166,7 +233,7 @@ func parseRequestArgs(request *Request, args map[string]any) {
body, _ := io.ReadAll(request.Body) body, _ := io.ReadAll(request.Body)
_ = request.Body.Close() _ = request.Body.Close()
if len(body) > 0 { if len(body) > 0 {
_ = json.Unmarshal(body, &args) _ = cast.UnmarshalJSON(body, &args)
} }
} else { } else {
_ = request.ParseForm() _ = request.ParseForm()
@ -198,8 +265,8 @@ func doWebService(service *webServiceType, request *Request, response *Response,
return result return result
} }
params := make([]reflect.Value, service.parmsNum) params := make([]reflect.Value, service.paramsNum)
for i := 0; i < service.parmsNum; i++ { for i := 0; i < service.paramsNum; i++ {
t := service.funcType.In(i) t := service.funcType.In(i)
switch i { switch i {
case service.requestIndex: case service.requestIndex:
@ -288,7 +355,7 @@ func handleClientKeys(request *Request, response *Response) {
}) })
} }
} }
request.Header.Set(standard.DiscoverHeaderSessionId, sessionId) request.Header.Set(discover.HeaderSessionID, sessionId)
response.Header().Set(usedSessionIdKey, sessionId) response.Header().Set(usedSessionIdKey, sessionId)
} }
@ -312,7 +379,7 @@ func handleClientKeys(request *Request, response *Response) {
}) })
} }
} }
request.Header.Set(standard.DiscoverHeaderDeviceId, deviceId) request.Header.Set(discover.HeaderDeviceID, deviceId)
response.Header().Set(usedDeviceIdKey, deviceId) response.Header().Set(usedDeviceIdKey, deviceId)
} }
} }

View File

@ -12,9 +12,9 @@ func TestServeHTTP(t *testing.T) {
handler := func(in struct{ Name string }) string { handler := func(in struct{ Name string }) string {
return "Hello " + in.Name return "Hello " + in.Name
} }
Register(0, "/hello", handler, "say hello") Host("*").POST("/hello", handler).Auth(0).Memo("say hello")
rh := &routeHandler{} rh := &RouteHandler{}
// 模拟请求 // 模拟请求
req := httptest.NewRequest("POST", "/hello", strings.NewReader(`{"name":"Star"}`)) req := httptest.NewRequest("POST", "/hello", strings.NewReader(`{"name":"Star"}`))
@ -34,7 +34,7 @@ func TestServeHTTP(t *testing.T) {
} }
func TestServeHTTP_404(t *testing.T) { func TestServeHTTP_404(t *testing.T) {
rh := &routeHandler{} rh := &RouteHandler{}
req := httptest.NewRequest("GET", "/notfound", nil) req := httptest.NewRequest("GET", "/notfound", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
@ -52,9 +52,9 @@ func TestServeHTTP_VerifyFailed(t *testing.T) {
handler := func(in ValidIn) string { handler := func(in ValidIn) string {
return "ok" return "ok"
} }
Register(0, "/verify", handler, "test verify") Host("*").POST("/verify", handler).Auth(0).Memo("test verify")
rh := &routeHandler{} rh := &RouteHandler{}
req := httptest.NewRequest("POST", "/verify", strings.NewReader(`{"age":10}`)) req := httptest.NewRequest("POST", "/verify", strings.NewReader(`{"age":10}`))
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder() w := httptest.NewRecorder()

223
log.go Normal file
View File

@ -0,0 +1,223 @@
package service
import (
"apigo.cc/go/cast"
"apigo.cc/go/log"
)
type RequestLog struct {
log.BaseLog
ServerId string `log:"pos:6,hide:true"`
App string `log:"pos:7,color:cyan,keyname:App"`
Node string `log:"pos:8,color:gray,attachBefore:true"`
FromApp string `log:"pos:9,color:cyan,keyname:From"`
FromNode string `log:"pos:10,color:gray,attachBefore:true"`
ClientIp string `log:"pos:11,withoutkey:true"`
ClientAppName string `log:"pos:12,attachBefore:true,keyname:Client"`
ClientAppVersion string `log:"pos:13,attachBefore:true"`
UserId string `log:"pos:14,color:magenta,keyname:User"`
DeviceId string `log:"pos:15,color:gray,keyname:Device"`
SessionId string `log:"pos:16,keyname:Session"`
Host string `log:"pos:17,color:gray,withoutkey:true"`
Method string `log:"pos:18,color:gray,withoutkey:true"`
Path string `log:"pos:19,color:cyan,withoutkey:true"`
Scheme string `log:"pos:20,color:gray,withoutkey:true"`
Proto string `log:"pos:21,color:gray,withoutkey:true"`
AuthLevel int `log:"pos:22,color:green"`
Priority int `log:"pos:23,hide:true"`
RequestData any `log:"pos:24,color:cyan,keyname:Request"`
RequestHeaders map[string]string `log:"pos:25,color:cyan,keyname:Headers"`
UsedTime float32 `log:"pos:26,color:green,precision:6"`
ResponseCode int `log:"pos:27,color:magenta,keyname:Status"`
ResponseDataLength uint `log:"pos:28,color:magenta,keyname:ContentLength"`
ResponseData any `log:"pos:29,color:magenta,keyname:Response"`
ResponseHeaders map[string]string `log:"pos:30,color:magenta,keyname:Headers"`
}
func (l *RequestLog) Reset() {
l.BaseLog.Reset()
l.ServerId = ""
l.App = ""
l.Node = ""
l.ClientIp = ""
l.FromApp = ""
l.FromNode = ""
l.UserId = ""
l.DeviceId = ""
l.ClientAppName = ""
l.ClientAppVersion = ""
l.SessionId = ""
l.Host = ""
l.Scheme = ""
l.Proto = ""
l.AuthLevel = 0
l.Priority = 0
l.Method = ""
l.Path = ""
if l.RequestHeaders == nil {
l.RequestHeaders = make(map[string]string, 8)
} else {
clear(l.RequestHeaders)
}
l.RequestData = nil
l.UsedTime = 0
l.ResponseCode = 0
if l.ResponseHeaders == nil {
l.ResponseHeaders = make(map[string]string, 8)
} else {
clear(l.ResponseHeaders)
}
l.ResponseDataLength = 0
l.ResponseData = nil
}
// RequestLog 调用封装
func logRequest(
logger *log.Logger,
method, path, host, scheme, proto string,
clientIp, serverId, app, node string,
fromApp, fromNode string,
userId, deviceId, sessionId, requestId string,
clientAppName, clientAppVersion string,
authLevel, priority int,
reqHeaders map[string]string,
reqData map[string]any,
responseCode int,
usedTime float32,
respHeaders map[string]string,
responseData string,
responseDataLength uint,
extra ...any,
) {
if !logger.CheckLevel(log.INFO) {
return
}
entry := log.GetEntry[RequestLog]()
logger.FillBase(entry.GetBaseLog(), log.LogTypeRequest)
entry.Method = method
entry.Path = path
entry.Host = host
entry.Scheme = scheme
entry.Proto = proto
entry.ClientIp = clientIp
entry.ServerId = serverId
entry.App = app
entry.Node = node
entry.FromApp = fromApp
entry.FromNode = fromNode
entry.UserId = userId
entry.DeviceId = deviceId
entry.SessionId = sessionId
entry.ClientAppName = clientAppName
entry.ClientAppVersion = clientAppVersion
entry.AuthLevel = authLevel
entry.Priority = priority
entry.RequestHeaders = reqHeaders
entry.RequestData = reqData
entry.ResponseCode = responseCode
entry.UsedTime = usedTime
entry.ResponseHeaders = respHeaders
entry.ResponseData = responseData
entry.ResponseDataLength = responseDataLength
if len(extra) > 0 {
cast.FillMap(&entry.Extra, extra)
}
logger.Log(entry)
}
type TaskLog struct {
log.BaseLog
Task string `log:"pos:6"`
UsedTime float32 `log:"pos:7"`
Success bool `log:"pos:8"`
Message string `log:"pos:9"`
}
func (l *TaskLog) Reset() {
l.BaseLog.Reset()
l.Task = ""
l.UsedTime = 0
l.Success = false
l.Message = ""
}
type MonitorLog struct {
log.BaseLog
Target string `log:"pos:6"`
Status int `log:"pos:7"`
Message string `log:"pos:8"`
}
func (l *MonitorLog) Reset() {
l.BaseLog.Reset()
l.Target = ""
l.Status = 0
l.Message = ""
}
type StatisticLog struct {
log.BaseLog
Category string `log:"pos:6"`
Item string `log:"pos:7"`
Value float64 `log:"pos:8"`
}
func (l *StatisticLog) Reset() {
l.BaseLog.Reset()
l.Category = ""
l.Item = ""
l.Value = 0
}
func logTask(logger *log.Logger, taskName string, usedTime float32, success bool, message string, extra ...any) {
if logger.CheckLevel(log.INFO) {
entry := log.GetEntry[TaskLog]()
logger.FillBase(entry.GetBaseLog(), log.LogTypeTask)
entry.Task = taskName
entry.UsedTime = usedTime
entry.Success = success
entry.Message = message
if len(extra) > 0 {
cast.FillMap(&entry.Extra, extra)
}
logger.Log(entry)
}
}
func logMonitor(logger *log.Logger, target string, status int, message string, extra ...any) {
if logger.CheckLevel(log.INFO) {
entry := log.GetEntry[MonitorLog]()
logger.FillBase(entry.GetBaseLog(), log.LogTypeMonitor)
entry.Target = target
entry.Status = status
entry.Message = message
if len(extra) > 0 {
cast.FillMap(&entry.Extra, extra)
}
logger.Log(entry)
}
}
func logStatistic(logger *log.Logger, category, item string, value float64, extra ...any) {
if logger.CheckLevel(log.INFO) {
entry := log.GetEntry[StatisticLog]()
logger.FillBase(entry.GetBaseLog(), log.LogTypeStatistic)
entry.Category = category
entry.Item = item
entry.Value = value
if len(extra) > 0 {
cast.FillMap(&entry.Extra, extra)
}
logger.Log(entry)
}
}
func init() {
log.RegisterType(log.LogTypeRequest, &RequestLog{})
log.RegisterType(log.LogTypeTask, &TaskLog{})
log.RegisterType(log.LogTypeMonitor, &MonitorLog{})
log.RegisterType(log.LogTypeStatistic, &StatisticLog{})
}

View File

@ -9,11 +9,10 @@ import (
"net/http" "net/http"
"regexp" "regexp"
"strings" "strings"
"sync"
"time" "time"
) )
type proxyInfo struct { type proxyType struct {
matcher *regexp.Regexp matcher *regexp.Regexp
authLevel int authLevel int
fromPath string fromPath string
@ -21,39 +20,32 @@ type proxyInfo struct {
toPath string toPath string
} }
var ( func (hc *HostContext) Proxy(authLevel int, path string, toApp, toPath string) *HostContext {
proxies = make(map[string]*proxyInfo) p := &proxyType{authLevel: authLevel, fromPath: path, toApp: toApp, toPath: toPath}
regexProxies = make([]*proxyInfo, 0)
proxyBy func(*Request) (int, *string, *string, map[string]string)
proxiesLock = sync.RWMutex{}
httpClientPool *gohttp.Client
)
// Proxy 注册代理规则
func Proxy(authLevel int, path string, toApp, toPath string) {
p := &proxyInfo{authLevel: authLevel, fromPath: path, toApp: toApp, toPath: toPath}
if strings.Contains(path, "(") { if strings.Contains(path, "(") {
matcher, err := regexp.Compile("^" + path + "$") matcher, err := regexp.Compile("^" + path + "$")
if err == nil { if err == nil {
p.matcher = matcher p.matcher = matcher
proxiesLock.Lock()
regexProxies = append(regexProxies, p)
proxiesLock.Unlock()
}
} else {
proxiesLock.Lock()
proxies[path] = p
proxiesLock.Unlock()
} }
} }
// SetProxyBy 设置动态代理函数 hostPoliciesLock.Lock()
func SetProxyBy(by func(request *Request) (authLevel int, toApp, toPath *string, headers map[string]string)) { defer hostPoliciesLock.Unlock()
proxyBy = by hostProxies[hc.host] = append(hostProxies[hc.host], p)
return hc
} }
func findProxy(request *Request) (int, *string, *string) { var httpClientPool *gohttp.Client
func findProxy(request *Request) (int, *string, *string, string) {
host := request.Host
hostOnly, port, _ := strings.Cut(host, ":")
hosts := []string{host}
if port != "" {
hosts = append(hosts, hostOnly, ":"+port)
}
hosts = append(hosts, "*")
requestPath := request.RequestURI requestPath := request.RequestURI
queryString := "" queryString := ""
if pos := strings.Index(requestPath, "?"); pos != -1 { if pos := strings.Index(requestPath, "?"); pos != -1 {
@ -61,16 +53,22 @@ func findProxy(request *Request) (int, *string, *string) {
requestPath = requestPath[:pos] requestPath = requestPath[:pos]
} }
proxiesLock.RLock() hostPoliciesLock.RLock()
defer proxiesLock.RUnlock() defer hostPoliciesLock.RUnlock()
if pi, ok := proxies[requestPath]; ok { for _, h := range hosts {
toPath := pi.toPath + queryString proxies, exists := hostProxies[h]
return pi.authLevel, &pi.toApp, &toPath if !exists {
continue
} }
for _, pi := range regexProxies { for _, pi := range proxies {
if pi.matcher != nil { if pi.matcher == nil {
if pi.fromPath == requestPath {
toPath := pi.toPath + queryString
return pi.authLevel, &pi.toApp, &toPath, h
}
} else {
finds := pi.matcher.FindAllStringSubmatch(requestPath, 1) finds := pi.matcher.FindAllStringSubmatch(requestPath, 1)
if len(finds) > 0 { if len(finds) > 0 {
toApp := pi.toApp toApp := pi.toApp
@ -80,21 +78,17 @@ func findProxy(request *Request) (int, *string, *string) {
toPath = strings.ReplaceAll(toPath, fmt.Sprintf("$%d", i), part) toPath = strings.ReplaceAll(toPath, fmt.Sprintf("$%d", i), part)
} }
toPath += queryString toPath += queryString
return pi.authLevel, &toApp, &toPath return pi.authLevel, &toApp, &toPath, h
}
} }
} }
} }
return 0, nil, nil return 0, nil, nil, ""
} }
func processProxy(request *Request, response *Response, logger *log.Logger) bool { func processProxy(request *Request, response *Response, logger *log.Logger) bool {
authLevel, proxyToApp, proxyToPath := findProxy(request) authLevel, proxyToApp, proxyToPath, foundHost := findProxy(request)
var proxyHeaders map[string]string
if proxyBy != nil && (proxyToApp == nil || proxyToPath == nil || *proxyToApp == "" || *proxyToPath == "") {
authLevel, proxyToApp, proxyToPath, proxyHeaders = proxyBy(request)
}
if proxyToApp == nil || proxyToPath == nil || *proxyToApp == "" || *proxyToPath == "" { if proxyToApp == nil || proxyToPath == nil || *proxyToApp == "" || *proxyToPath == "" {
return false return false
@ -112,25 +106,20 @@ func processProxy(request *Request, response *Response, logger *log.Logger) bool
app := *proxyToApp app := *proxyToApp
path := *proxyToPath path := *proxyToPath
logger.Info("proxy", "app", app, "path", path, "host", foundHost)
// 构建自定义头部
headerArgs := make([]string, 0)
for k, v := range proxyHeaders {
headerArgs = append(headerArgs, k, v)
}
if strings.Contains(app, "://") { if strings.Contains(app, "://") {
// 直接 URL 代理 // 直接 URL 代理
if httpClientPool == nil { if httpClientPool == nil {
httpClientPool = gohttp.NewClient(time.Duration(Config.RedirectTimeout) * time.Millisecond) httpClientPool = gohttp.NewClient(time.Duration(Config.RedirectTimeout) * time.Millisecond)
} }
res := httpClientPool.ManualDoByRequest(request.Request, request.Method, app+path, request.Body, headerArgs...) res := httpClientPool.ManualDoByRequest(request.Request, request.Method, app+path, request.Body)
copyResponse(res, response, logger) copyResponse(res, response, logger)
} else { } else {
// Discover 代理 // Discover 代理
caller := discover.NewCaller(request.Request, logger) caller := discover.NewCaller(request.Request, logger)
caller.NoBody = true caller.NoBody = true
res, _ := caller.ManualDoWithNode(request.Method, app, "", path, request.Body, headerArgs...) res, _ := caller.ManualDoWithNode(request.Method, app, "", path, request.Body)
copyResponse(res, response, logger) copyResponse(res, response, logger)
} }

View File

@ -8,14 +8,14 @@ import (
func TestRewrite(t *testing.T) { func TestRewrite(t *testing.T) {
// 注册重写规则 // 注册重写规则
Rewrite("/old", "/new") Host("*").Rewrite("/old", "/new")
Rewrite("/regex/(.*)", "/target/$1") Host("*").Rewrite("/regex/(.*)", "/target/$1")
// 注册目标服务 // 注册目标服务
Register(0, "/new", func() string { return "new content" }, "new") Host("*").ANY("/new", func() string { return "new content" }).Memo("new")
Register(0, "/target/123", func() string { return "target content" }, "target") Host("*").ANY("/target/123", func() string { return "target content" }).Memo("target")
rh := &routeHandler{} rh := &RouteHandler{}
// 测试精确匹配重写 // 测试精确匹配重写
req1 := httptest.NewRequest("GET", "/old", nil) req1 := httptest.NewRequest("GET", "/old", nil)
@ -43,9 +43,9 @@ func TestProxyDirect(t *testing.T) {
defer backend.Close() defer backend.Close()
// 注册代理规则 // 注册代理规则
Proxy(0, "/proxy", backend.URL, "/hello") Host("*").Proxy(0, "/proxy", backend.URL, "/hello")
rh := &routeHandler{} rh := &RouteHandler{}
req := httptest.NewRequest("GET", "/proxy", nil) req := httptest.NewRequest("GET", "/proxy", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
rh.ServeHTTP(w, req) rh.ServeHTTP(w, req)

View File

@ -1,16 +1,14 @@
package service package service
import ( import (
"apigo.cc/go/cast" "apigo.cc/go/discover"
"apigo.cc/go/standard" "apigo.cc/go/file"
"io" "io"
"mime/multipart" "mime/multipart"
"net" "net"
"net/http" "net/http"
"net/textproto" "net/textproto"
"net/url" "net/url"
"os"
"path/filepath"
) )
// UploadFile 上传文件结构 // UploadFile 上传文件结构
@ -28,25 +26,11 @@ func (f *UploadFile) Open() (multipart.File, error) {
// Save 保存上传文件到本地 // Save 保存上传文件到本地
func (f *UploadFile) Save(filename string) error { func (f *UploadFile) Save(filename string) error {
dir := filepath.Dir(filename) data, err := f.Content()
if _, err := os.Stat(dir); os.IsNotExist(err) {
_ = os.MkdirAll(dir, 0755)
}
dst, err := os.OpenFile(filename, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
if err != nil { if err != nil {
return err return err
} }
defer dst.Close() return file.WriteBytes(filename, data)
src, err := f.fileHeader.Open()
if err != nil {
return err
}
defer src.Close()
_, err = io.Copy(dst, src)
return err
} }
// Content 获取上传文件内容 // Content 获取上传文件内容
@ -94,34 +78,37 @@ func (r *Request) Get(key string) any {
// MakeUrl 根据当前请求构建完整 URL // MakeUrl 根据当前请求构建完整 URL
func (r *Request) MakeUrl(path string) string { func (r *Request) MakeUrl(path string) string {
scheme := r.Header.Get(standard.DiscoverHeaderScheme) scheme := r.Header.Get(discover.HeaderScheme)
if scheme == "" { if scheme == "" {
scheme = "http" scheme = "http"
} }
host := r.Header.Get(standard.DiscoverHeaderHost) host := r.Header.Get(discover.HeaderHost)
if host == "" { if host == "" {
host = r.Host host = r.Host
} }
return scheme + "://" + host + path return scheme + "://" + host + path
} }
// GetSessionId 获取会话 ID // DeviceId 获取设备 ID
func (r *Request) GetSessionId() string { func (r *Request) DeviceId() string {
sessionId := r.Header.Get(Config.Listen) // Wait, this should be usedSessionIdKey return r.Header.Get(discover.HeaderDeviceID)
// TODO: Fix dependency on global usedSessionIdKey }
return sessionId
// SessionId 获取会话 ID
func (r *Request) SessionId() string {
return r.Header.Get(discover.HeaderSessionID)
} }
// SetUserId 设置用户 ID传递给下游 // SetUserId 设置用户 ID传递给下游
func (r *Request) SetUserId(userId string) { func (r *Request) SetUserId(userId string) {
r.Header.Set(standard.DiscoverHeaderUserId, userId) r.Header.Set(discover.HeaderUserID, userId)
} }
// GetRealIp 获取真实 IP // ClientIp 获取真实 IP
func (r *Request) GetRealIp() string { func (r *Request) ClientIp() string {
ip := r.Header.Get(standard.DiscoverHeaderClientIp) ip := r.Header.Get(discover.HeaderClientIP)
if ip == "" { if ip == "" {
ip = r.Header.Get(standard.DiscoverHeaderForwardedFor) ip = r.Header.Get(discover.HeaderForwardedFor)
} }
if ip == "" { if ip == "" {
host, _, err := net.SplitHostPort(r.RemoteAddr) host, _, err := net.SplitHostPort(r.RemoteAddr)
@ -132,8 +119,3 @@ func (r *Request) GetRealIp() string {
} }
return ip return ip
} }
// GetLowerName (Aliased from cast)
func GetLowerName(s string) string {
return cast.GetLowerName(s)
}

View File

@ -2,16 +2,17 @@ package service
import ( import (
"apigo.cc/go/cast" "apigo.cc/go/cast"
"apigo.cc/go/file"
"io" "io"
"net/http" "net/http"
"os"
) )
// Response 封装 http.ResponseWriter // Response 封装 http.ResponseWriter
type Response struct { type Response struct {
Id string Id string
Writer http.ResponseWriter Writer http.ResponseWriter
status int Code int
body []byte
outLen int outLen int
changed bool changed bool
headerWritten bool headerWritten bool
@ -24,7 +25,7 @@ type Response struct {
func NewResponse(writer http.ResponseWriter) *Response { func NewResponse(writer http.ResponseWriter) *Response {
return &Response{ return &Response{
Writer: writer, Writer: writer,
status: http.StatusOK, Code: http.StatusOK,
} }
} }
@ -42,6 +43,9 @@ func (r *Response) Write(bytes []byte) (int, error) {
r.checkWriteHeader() r.checkWriteHeader()
r.changed = true r.changed = true
r.outLen += len(bytes) r.outLen += len(bytes)
if r.Code != http.StatusOK && len(r.body) < 4096 {
r.body = append(r.body, bytes...)
}
if r.ProxyHeader != nil { if r.ProxyHeader != nil {
r.copyProxyHeader() r.copyProxyHeader()
} }
@ -56,8 +60,8 @@ func (r *Response) WriteString(s string) (int, error) {
// WriteHeader 设置响应状态码 // WriteHeader 设置响应状态码
func (r *Response) WriteHeader(code int) { func (r *Response) WriteHeader(code int) {
r.changed = true r.changed = true
r.status = code r.Code = code
if r.ProxyHeader != nil && (r.status == http.StatusBadGateway || r.status == http.StatusServiceUnavailable || r.status == http.StatusGatewayTimeout) { if r.ProxyHeader != nil && (r.Code == http.StatusBadGateway || r.Code == http.StatusServiceUnavailable || r.Code == http.StatusGatewayTimeout) {
return return
} }
if r.ProxyHeader != nil { if r.ProxyHeader != nil {
@ -68,8 +72,8 @@ func (r *Response) WriteHeader(code int) {
func (r *Response) checkWriteHeader() { func (r *Response) checkWriteHeader() {
if !r.headerWritten { if !r.headerWritten {
r.headerWritten = true r.headerWritten = true
if r.status != http.StatusOK { if r.Code != http.StatusOK {
r.Writer.WriteHeader(r.status) r.Writer.WriteHeader(r.Code)
} }
} }
} }
@ -94,7 +98,7 @@ func (r *Response) Flush() {
// GetStatusCode 获取当前状态码 // GetStatusCode 获取当前状态码
func (r *Response) GetStatusCode() int { func (r *Response) GetStatusCode() int {
return r.status return r.Code
} }
// DontLog200 标记不记录 200 状态码的日志 // DontLog200 标记不记录 200 状态码的日志
@ -111,10 +115,8 @@ func (r *Response) Location(location string) {
// SendFile 发送文件 // SendFile 发送文件
func (r *Response) SendFile(contentType, filename string) { func (r *Response) SendFile(contentType, filename string) {
r.Header().Set("Content-Type", contentType) r.Header().Set("Content-Type", contentType)
// TODO: Integrate memory file support if needed if data, err := file.ReadBytes(filename); err == nil {
if fd, err := os.Open(filename); err == nil { _, _ = r.Write(data)
defer fd.Close()
_, _ = io.Copy(r, fd)
} }
} }

View File

@ -6,47 +6,42 @@ import (
"net/url" "net/url"
"regexp" "regexp"
"strings" "strings"
"sync"
) )
type rewriteInfo struct { type rewriteType struct {
matcher *regexp.Regexp matcher *regexp.Regexp
fromPath string fromPath string
toPath string toPath string
} }
var ( func (hc *HostContext) Rewrite(path string, toPath string) *HostContext {
rewrites = make(map[string]*rewriteInfo) s := &rewriteType{fromPath: path, toPath: toPath}
regexRewrites = make([]*rewriteInfo, 0)
rewriteBy func(*Request) (string, bool)
rewritesLock = sync.RWMutex{}
)
// Rewrite 注册重写规则
func Rewrite(path string, toPath string) {
s := &rewriteInfo{fromPath: path, toPath: toPath}
if strings.ContainsRune(path, '(') { if strings.ContainsRune(path, '(') {
matcher, err := regexp.Compile("^" + path + "$") matcher, err := regexp.Compile("^" + path + "$")
if err == nil { if err == nil {
s.matcher = matcher s.matcher = matcher
rewritesLock.Lock()
regexRewrites = append(regexRewrites, s)
rewritesLock.Unlock()
}
} else {
rewritesLock.Lock()
rewrites[path] = s
rewritesLock.Unlock()
} }
} }
// SetRewriteBy 设置动态重写函数 hostPoliciesLock.Lock()
func SetRewriteBy(by func(request *Request) (toPath string, rewrite bool)) { defer hostPoliciesLock.Unlock()
rewriteBy = by hostRewrites[hc.host] = append(hostRewrites[hc.host], s)
return hc
} }
func processRewrite(request *Request, response *Response, logger *log.Logger) bool { func processRewrite(request *Request, response *Response, logger *log.Logger) bool {
host := request.Host
hostOnly, port, _ := strings.Cut(host, ":")
hosts := []string{host}
if port != "" {
hosts = append(hosts, hostOnly, ":"+port)
}
hosts = append(hosts, "*")
hostPoliciesLock.RLock()
defer hostPoliciesLock.RUnlock()
requestPath := request.RequestURI requestPath := request.RequestURI
queryString := "" queryString := ""
if pos := strings.Index(requestPath, "?"); pos != -1 { if pos := strings.Index(requestPath, "?"); pos != -1 {
@ -54,25 +49,22 @@ func processRewrite(request *Request, response *Response, logger *log.Logger) bo
requestPath = requestPath[:pos] requestPath = requestPath[:pos]
} }
var rewriteToPath string for _, h := range hosts {
var found bool rewrites, exists := hostRewrites[h]
if !exists {
continue
}
rewritesLock.RLock() for _, ri := range rewrites {
// 1. 精确匹配 found := false
if ri, ok := rewrites[requestPath]; ok { rewriteToPath := ""
if ri.matcher == nil {
if ri.fromPath == requestPath {
rewriteToPath = ri.toPath rewriteToPath = ri.toPath
found = true found = true
} }
} else {
// 2. 动态重写
if !found && rewriteBy != nil {
rewriteToPath, found = rewriteBy(request)
}
// 3. 正则匹配
if !found {
for _, ri := range regexRewrites {
if ri.matcher != nil {
finds := ri.matcher.FindAllStringSubmatch(request.RequestURI, 1) finds := ri.matcher.FindAllStringSubmatch(request.RequestURI, 1)
if len(finds) > 0 { if len(finds) > 0 {
toPath := ri.toPath toPath := ri.toPath
@ -81,12 +73,8 @@ func processRewrite(request *Request, response *Response, logger *log.Logger) bo
} }
rewriteToPath = toPath rewriteToPath = toPath
found = true found = true
break
} }
} }
}
}
rewritesLock.RUnlock()
if found { if found {
if strings.Contains(rewriteToPath, "://") { if strings.Contains(rewriteToPath, "://") {
@ -99,7 +87,7 @@ func processRewrite(request *Request, response *Response, logger *log.Logger) bo
return true return true
} else { } else {
// 内部重写 // 内部重写
logger.Info("rewrite", "from", request.RequestURI, "to", rewriteToPath) logger.Info("rewrite", "from", request.RequestURI, "to", rewriteToPath, "host", h)
if queryString != "" && !strings.Contains(rewriteToPath, "?") { if queryString != "" && !strings.Contains(rewriteToPath, "?") {
rewriteToPath += queryString rewriteToPath += queryString
} }
@ -108,6 +96,8 @@ func processRewrite(request *Request, response *Response, logger *log.Logger) bo
return false // 继续后续处理 return false // 继续后续处理
} }
} }
}
}
return false return false
} }

View File

@ -50,7 +50,7 @@ func (as *AsyncServer) start() {
serverAddr = as.Addr serverAddr = as.Addr
as.server = &http.Server{ as.server = &http.Server{
Handler: &routeHandler{}, Handler: &RouteHandler{},
} }
signal.Notify(as.stopChan, os.Interrupt, syscall.SIGTERM) signal.Notify(as.stopChan, os.Interrupt, syscall.SIGTERM)

View File

@ -9,31 +9,15 @@ import (
"sync" "sync"
) )
// Map 通用 Map 类型
type Map = map[string]any
// Arr 通用切片类型
type Arr = []any
// WebServiceOptions 服务注册选项
type WebServiceOptions struct {
Priority int
NoDoc bool
NoBody bool
NoLog200 bool
Host string
Ext Map
// Limiters []*Limiter // TODO: Integrate Limiter
}
// webServiceType 内部存储的服务元数据 // webServiceType 内部存储的服务元数据
type webServiceType struct { type webServiceType struct {
authLevel int authLevel int
method string method string
host string
path string path string
pathMatcher *regexp.Regexp pathMatcher *regexp.Regexp
pathArgs []string pathArgs []string
parmsNum int paramsNum int
inType reflect.Type inType reflect.Type
inIndex int inIndex int
headersType reflect.Type headersType reflect.Type
@ -47,10 +31,29 @@ type webServiceType struct {
funcType reflect.Type funcType reflect.Type
funcValue reflect.Value funcValue reflect.Value
options WebServiceOptions options WebServiceOptions
data Map data map[string]any
memo string memo string
} }
// WebServiceOptions 服务注册选项
type WebServiceOptions struct {
Priority int
NoDoc bool
NoBody bool
NoLog200 bool
Ext map[string]any
}
type websocketServiceType struct {
authLevel int
host string
path string
memo string
funcType reflect.Type
funcValue reflect.Value
options WebServiceOptions
}
var ( var (
serverId string serverId string
serverAddr string serverAddr string
@ -58,14 +61,21 @@ var (
serverProtoName = "http" serverProtoName = "http"
running = false running = false
webServices = make(map[string]*webServiceType) // webServices 按 Host 隔离: map[host]map[method+path]*webServiceType
regexWebServices = make([]*webServiceType, 0) webServices = make(map[string]map[string]*webServiceType)
// regexWebServices 按 Host 隔离: map[host][]*webServiceType
regexWebServices = make(map[string][]*webServiceType)
webServicesLock = sync.RWMutex{} webServicesLock = sync.RWMutex{}
webServicesList = make([]*webServiceType, 0) webServicesList = make([]*webServiceType, 0)
websocketServices = make(map[string]*websocketServiceType) websocketServices = make(map[string]map[string]*websocketServiceType)
websocketServicesLock = sync.RWMutex{} websocketServicesLock = sync.RWMutex{}
websocketServicesList = make([]*webServiceType, 0) websocketServicesList = make([]*websocketServiceType, 0)
// Rewrite 与 Proxy 按 Host 隔离
hostRewrites = make(map[string][]*rewriteType)
hostProxies = make(map[string][]*proxyType)
hostPoliciesLock = sync.RWMutex{}
// 过滤器与拦截器 // 过滤器与拦截器
inFilters = make([]func(*map[string]any, *Request, *Response, *log.Logger) any, 0) inFilters = make([]func(*map[string]any, *Request, *Response, *log.Logger) any, 0)
@ -118,29 +128,28 @@ func SetOutFilter(filter func(in map[string]any, request *Request, response *Res
outFilters = append(outFilters, filter) outFilters = append(outFilters, filter)
} }
// Register 注册服务(通用方法) // HostContext 提供流式服务注册能力
func Register(authLevel int, path string, serviceFunc any, memo string) { type HostContext struct {
Restful(authLevel, "", path, serviceFunc, memo) host string
} }
// Restful 注册指定方法的服务 // Host 指定服务运行的 Host (支持 "example.com", ":8080", "example.com:8080", "*")
func Restful(authLevel int, method, path string, serviceFunc any, memo string) { func Host(host string) *HostContext {
RestfulWithOptions(authLevel, method, path, serviceFunc, memo, WebServiceOptions{}) if host == "" {
host = "*"
}
return &HostContext{host: host}
} }
// RestfulWithOptions 注册带选项的服务 func (hc *HostContext) Register(method, path string, serviceFunc any) *webServiceType {
func RestfulWithOptions(authLevel int, method, path string, serviceFunc any, memo string, options WebServiceOptions) {
s, err := makeCachedService(serviceFunc) s, err := makeCachedService(serviceFunc)
if err != nil { if err != nil {
// TODO: Log error properly when logger is ready return &webServiceType{} // 返回空对象避免链式调用崩溃
return
} }
s.authLevel = authLevel s.host = hc.host
s.options = options s.method = strings.ToUpper(method)
s.method = method
s.path = path s.path = path
s.memo = memo
// 解析路径参数 {name} // 解析路径参数 {name}
finder, err := regexp.Compile("{(.*?)}") finder, err := regexp.Compile("{(.*?)}")
@ -159,13 +168,169 @@ func RestfulWithOptions(authLevel int, method, path string, serviceFunc any, mem
webServicesLock.Lock() webServicesLock.Lock()
defer webServicesLock.Unlock() defer webServicesLock.Unlock()
// 简单路径匹配
if s.pathMatcher == nil { if s.pathMatcher == nil {
webServices[method+path] = s // TODO: Include Host in key if webServices[s.host] == nil {
webServices[s.host] = make(map[string]*webServiceType)
}
webServices[s.host][s.method+s.path] = s
} else { } else {
regexWebServices = append(regexWebServices, s) regexWebServices[s.host] = append(regexWebServices[s.host], s)
} }
webServicesList = append(webServicesList, s) webServicesList = append(webServicesList, s)
return s
}
func (hc *HostContext) GET(path string, serviceFunc any) *webServiceType {
return hc.Register("GET", path, serviceFunc)
}
func (hc *HostContext) POST(path string, serviceFunc any) *webServiceType {
return hc.Register("POST", path, serviceFunc)
}
func (hc *HostContext) PUT(path string, serviceFunc any) *webServiceType {
return hc.Register("PUT", path, serviceFunc)
}
func (hc *HostContext) DELETE(path string, serviceFunc any) *webServiceType {
return hc.Register("DELETE", path, serviceFunc)
}
func (hc *HostContext) PATCH(path string, serviceFunc any) *webServiceType {
return hc.Register("PATCH", path, serviceFunc)
}
func (hc *HostContext) HEAD(path string, serviceFunc any) *webServiceType {
return hc.Register("HEAD", path, serviceFunc)
}
func (hc *HostContext) OPTIONS(path string, serviceFunc any) *webServiceType {
return hc.Register("OPTIONS", path, serviceFunc)
}
func (hc *HostContext) ANY(path string, serviceFunc any) *webServiceType {
return hc.Register("*", path, serviceFunc)
}
// GroupContext 提供路径分组注册能力
type GroupContext struct {
hc *HostContext
prefix string
}
// Group 创建路径分组
func (hc *HostContext) Group(prefix string) *GroupContext {
if prefix == "/" {
prefix = ""
}
return &GroupContext{hc: hc, prefix: prefix}
}
func (gc *GroupContext) GET(path string, serviceFunc any) *webServiceType {
return gc.hc.Register("GET", gc.prefix+path, serviceFunc)
}
func (gc *GroupContext) POST(path string, serviceFunc any) *webServiceType {
return gc.hc.Register("POST", gc.prefix+path, serviceFunc)
}
func (gc *GroupContext) PUT(path string, serviceFunc any) *webServiceType {
return gc.hc.Register("PUT", gc.prefix+path, serviceFunc)
}
func (gc *GroupContext) DELETE(path string, serviceFunc any) *webServiceType {
return gc.hc.Register("DELETE", gc.prefix+path, serviceFunc)
}
func (gc *GroupContext) ANY(path string, serviceFunc any) *webServiceType {
return gc.hc.Register("*", gc.prefix+path, serviceFunc)
}
func (gc *GroupContext) WebSocket(path string, serviceFunc any) *websocketServiceType {
return gc.hc.WebSocket(gc.prefix+path, serviceFunc)
}
func (gc *GroupContext) Rewrite(path string, toPath string) *GroupContext {
gc.hc.Rewrite(gc.prefix+path, toPath)
return gc
}
func (gc *GroupContext) Proxy(authLevel int, path string, toApp, toPath string) *GroupContext {
gc.hc.Proxy(authLevel, gc.prefix+path, toApp, toPath)
return gc
}
func (hc *HostContext) WebSocket(path string, serviceFunc any) *websocketServiceType {
funcType := reflect.TypeOf(serviceFunc)
if funcType.Kind() != reflect.Func {
return &websocketServiceType{}
}
ws := &websocketServiceType{
host: hc.host,
path: path,
funcType: funcType,
funcValue: reflect.ValueOf(serviceFunc),
}
websocketServicesLock.Lock()
defer websocketServicesLock.Unlock()
if websocketServices[hc.host] == nil {
websocketServices[hc.host] = make(map[string]*websocketServiceType)
}
websocketServices[hc.host][path] = ws
websocketServicesList = append(websocketServicesList, ws)
return ws
}
// webServiceType 链式配置方法
func (s *webServiceType) Auth(level int) *webServiceType {
s.authLevel = level
return s
}
func (s *webServiceType) Memo(memo string) *webServiceType {
s.memo = memo
return s
}
func (s *webServiceType) Priority(p int) *webServiceType {
s.options.Priority = p
return s
}
func (s *webServiceType) NoDoc() *webServiceType {
s.options.NoDoc = true
return s
}
func (s *webServiceType) NoBody() *webServiceType {
s.options.NoBody = true
return s
}
func (s *webServiceType) NoLog200() *webServiceType {
s.options.NoLog200 = true
return s
}
func (s *webServiceType) Ext(key string, val any) *webServiceType {
if s.options.Ext == nil {
s.options.Ext = make(map[string]any)
}
s.options.Ext[key] = val
return s
}
// websocketServiceType 链式配置方法
func (s *websocketServiceType) Auth(level int) *websocketServiceType {
s.authLevel = level
return s
}
func (s *websocketServiceType) Memo(memo string) *websocketServiceType {
s.memo = memo
return s
} }
func makeCachedService(matchedService any) (*webServiceType, error) { func makeCachedService(matchedService any) (*webServiceType, error) {
@ -175,7 +340,7 @@ func makeCachedService(matchedService any) (*webServiceType, error) {
} }
targetService := &webServiceType{ targetService := &webServiceType{
parmsNum: funcType.NumIn(), paramsNum: funcType.NumIn(),
inIndex: -1, inIndex: -1,
headersIndex: -1, headersIndex: -1,
requestIndex: -1, requestIndex: -1,
@ -188,7 +353,7 @@ func makeCachedService(matchedService any) (*webServiceType, error) {
funcValue: reflect.ValueOf(matchedService), funcValue: reflect.ValueOf(matchedService),
} }
for i := 0; i < targetService.parmsNum; i++ { for i := 0; i < targetService.paramsNum; i++ {
t := funcType.In(i) t := funcType.In(i)
tStr := t.String() tStr := t.String()
switch tStr { switch tStr {
@ -228,3 +393,14 @@ func GetInject(dataType reflect.Type) any {
} }
return nil return nil
} }
// GetInjectT 获取注入对象 (泛型版)
func GetInjectT[T any]() T {
var zero T
t := reflect.TypeOf((*T)(nil)).Elem()
obj := GetInject(t)
if obj == nil {
return zero
}
return obj.(T)
}

View File

@ -10,10 +10,10 @@ func TestServiceRegister(t *testing.T) {
return "ok" return "ok"
} }
Register(0, "/test", handler, "test service") Host("*").Register("*", "/test", handler).Auth(0).Memo("test service")
webServicesLock.RLock() webServicesLock.RLock()
s := webServices["/test"] s := webServices["*"]["*/test"]
webServicesLock.RUnlock() webServicesLock.RUnlock()
if s == nil { if s == nil {
@ -33,11 +33,12 @@ func TestRegexServiceRegister(t *testing.T) {
return "ok" return "ok"
} }
Register(0, "/user/{id}", handler, "get user") Host("*").Register("*", "/user/{id}", handler).Auth(0).Memo("get user")
webServicesLock.RLock() webServicesLock.RLock()
found := false found := false
for _, s := range regexWebServices { for _, services := range regexWebServices {
for _, s := range services {
if s.path == "/user/{id}" { if s.path == "/user/{id}" {
found = true found = true
if len(s.pathArgs) != 1 || s.pathArgs[0] != "id" { if len(s.pathArgs) != 1 || s.pathArgs[0] != "id" {
@ -46,6 +47,10 @@ func TestRegexServiceRegister(t *testing.T) {
break break
} }
} }
if found {
break
}
}
webServicesLock.RUnlock() webServicesLock.RUnlock()
if !found { if !found {

View File

@ -5,7 +5,6 @@ import (
"apigo.cc/go/log" "apigo.cc/go/log"
"mime" "mime"
"net/http" "net/http"
"os"
"path/filepath" "path/filepath"
"strings" "strings"
"sync" "sync"
@ -74,16 +73,16 @@ func processStatic(requestPath string, request *Request, response *Response, log
return false return false
} }
info, err := os.Stat(filePath) info := file.GetFileInfo(filePath)
if err != nil { if info == nil {
return false return false
} }
if info.IsDir() { if info.IsDir {
// 自动查找索引文件 // 自动查找索引文件
for _, indexFile := range Config.IndexFiles { for _, indexFile := range Config.IndexFiles {
f := filepath.Join(filePath, indexFile) f := filepath.Join(filePath, indexFile)
if i, err := os.Stat(f); err == nil && !i.IsDir() { if i := file.GetFileInfo(f); i != nil && !i.IsDir {
filePath = f filePath = f
info = i info = i
break break
@ -91,14 +90,15 @@ func processStatic(requestPath string, request *Request, response *Response, log
} }
} }
if info.IsDir() { if info.IsDir {
return false return false
} }
// 检查 304 // 检查 304
if ifModifiedSince := request.Header.Get("If-Modified-Since"); ifModifiedSince != "" { if ifModifiedSince := request.Header.Get("If-Modified-Since"); ifModifiedSince != "" {
if t, err := time.Parse(http.TimeFormat, ifModifiedSince); err == nil { if t, err := time.Parse(http.TimeFormat, ifModifiedSince); err == nil {
if !info.ModTime().Truncate(time.Second).After(t.Truncate(time.Second)) { if time.Unix(info.ModTime, 0).Truncate(time.Second).Before(t.Truncate(time.Second)) ||
time.Unix(info.ModTime, 0).Truncate(time.Second).Equal(t.Truncate(time.Second)) {
response.WriteHeader(http.StatusNotModified) response.WriteHeader(http.StatusNotModified)
return true return true
} }
@ -111,7 +111,7 @@ func processStatic(requestPath string, request *Request, response *Response, log
contentType = "application/octet-stream" contentType = "application/octet-stream"
} }
response.Header().Set("Content-Type", contentType) response.Header().Set("Content-Type", contentType)
response.Header().Set("Last-Modified", info.ModTime().UTC().Format(http.TimeFormat)) response.Header().Set("Last-Modified", time.Unix(info.ModTime, 0).UTC().Format(http.TimeFormat))
data, err := file.ReadBytes(filePath) data, err := file.ReadBytes(filePath)
if err != nil { if err != nil {

View File

@ -19,7 +19,7 @@ func TestStaticService(t *testing.T) {
// 注册静态目录 // 注册静态目录
Static("/ui", tempDir) Static("/ui", tempDir)
rh := &routeHandler{} rh := &RouteHandler{}
// 测试成功访问 // 测试成功访问
req := httptest.NewRequest("GET", "/ui/index.html", nil) req := httptest.NewRequest("GET", "/ui/index.html", nil)

View File

@ -92,7 +92,7 @@ func VerifyStruct(in any, logger *log.Logger) (ok bool, field string) {
keyTag := ft.Tag.Get("verifyKey") keyTag := ft.Tag.Get("verifyKey")
if tag != "" || keyTag != "" { if tag != "" || keyTag != "" {
var err error var err error
ok, f, err := _verifyValue(fv, tag, keyTag, logger) ok, f, err := verifyValue(fv, tag, keyTag, logger)
if !ok { if !ok {
if f == "" { if f == "" {
f = cast.GetLowerName(ft.Name) f = cast.GetLowerName(ft.Name)
@ -111,13 +111,13 @@ func VerifyStruct(in any, logger *log.Logger) (ok bool, field string) {
return true, "" return true, ""
} }
func _verifyValue(in reflect.Value, setting, keySetting string, logger *log.Logger) (bool, string, error) { func verifyValue(in reflect.Value, setting, keySetting string, logger *log.Logger) (bool, string, error) {
t := in.Type() t := in.Type()
// 处理切片 (非 byte 切片) // 处理切片 (非 byte 切片)
if t.Kind() == reflect.Slice && t.Elem().Kind() != reflect.Uint8 { if t.Kind() == reflect.Slice && t.Elem().Kind() != reflect.Uint8 {
if setting != "" { if setting != "" {
for i := 0; i < in.Len(); i++ { for i := 0; i < in.Len(); i++ {
if ok, f, err := _verifyValue(in.Index(i), setting, "", logger); !ok { if ok, f, err := verifyValue(in.Index(i), setting, "", logger); !ok {
return false, f, err return false, f, err
} }
} }
@ -129,12 +129,12 @@ func _verifyValue(in reflect.Value, setting, keySetting string, logger *log.Logg
if t.Kind() == reflect.Map { if t.Kind() == reflect.Map {
for _, k := range in.MapKeys() { for _, k := range in.MapKeys() {
if keySetting != "" { if keySetting != "" {
if ok, _, err := _verifyValue(k, keySetting, "", logger); !ok { if ok, _, err := verifyValue(k, keySetting, "", logger); !ok {
return false, "key", err return false, "key", err
} }
} }
if setting != "" { if setting != "" {
if ok, f, err := _verifyValue(in.MapIndex(k), setting, "", logger); !ok { if ok, f, err := verifyValue(in.MapIndex(k), setting, "", logger); !ok {
return false, f, err return false, f, err
} }
} }

View File

@ -7,40 +7,12 @@ import (
"reflect" "reflect"
) )
// websocketServiceType WebSocket 服务元数据 var defaultUpgrader = &websocket.Upgrader{
type websocketServiceType struct { CheckOrigin: func(r *http.Request) bool { return true },
authLevel int
path string
updater *websocket.Upgrader
handlerValue reflect.Value
handlerType reflect.Type
memo string
}
// RegisterWebsocket 注册 WebSocket 服务
func RegisterWebsocket(authLevel int, path string, handler any, memo string) {
v := reflect.ValueOf(handler)
t := v.Type()
if t.Kind() != reflect.Func {
return
}
s := &websocketServiceType{
authLevel: authLevel,
path: path,
memo: memo,
updater: &websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }},
handlerValue: v,
handlerType: t,
}
websocketServicesLock.Lock()
websocketServices[path] = s
websocketServicesLock.Unlock()
} }
func doWebsocketService(ws *websocketServiceType, request *Request, response *Response, logger *log.Logger) { func doWebsocketService(ws *websocketServiceType, request *Request, response *Response, logger *log.Logger) {
conn, err := ws.updater.Upgrade(response.Writer, request.Request, nil) conn, err := defaultUpgrader.Upgrade(response.Writer, request.Request, nil)
if err != nil { if err != nil {
logger.Error("websocket upgrade failed", "error", err.Error()) logger.Error("websocket upgrade failed", "error", err.Error())
return return
@ -48,9 +20,9 @@ func doWebsocketService(ws *websocketServiceType, request *Request, response *Re
defer conn.Close() defer conn.Close()
// 调用业务处理函数,注入依赖 // 调用业务处理函数,注入依赖
params := make([]reflect.Value, ws.handlerType.NumIn()) params := make([]reflect.Value, ws.funcType.NumIn())
for i := 0; i < len(params); i++ { for i := 0; i < len(params); i++ {
t := ws.handlerType.In(i) t := ws.funcType.In(i)
if t == reflect.TypeOf(request) { if t == reflect.TypeOf(request) {
params[i] = reflect.ValueOf(request) params[i] = reflect.ValueOf(request)
} else if t == reflect.TypeOf(logger) { } else if t == reflect.TypeOf(logger) {
@ -63,5 +35,5 @@ func doWebsocketService(ws *websocketServiceType, request *Request, response *Re
params[i] = reflect.New(t).Elem() params[i] = reflect.New(t).Elem()
} }
} }
ws.handlerValue.Call(params) ws.funcValue.Call(params)
} }

View File

@ -9,18 +9,18 @@ import (
func TestWebSocketService(t *testing.T) { func TestWebSocketService(t *testing.T) {
// 注册 WebSocket 服务 // 注册 WebSocket 服务
RegisterWebsocket(0, "/ws", func(conn *websocket.Conn) { Host("*").WebSocket("/ws", func(conn *websocket.Conn) {
for { for {
var msg Map var msg map[string]any
if err := conn.ReadJSON(&msg); err != nil { if err := conn.ReadJSON(&msg); err != nil {
break break
} }
_ = conn.WriteJSON(Map{"reply": msg["msg"]}) _ = conn.WriteJSON(map[string]any{"reply": msg["msg"]})
} }
}, "test websocket") }).Auth(0).Memo("test websocket")
// 启动测试服务器 // 启动测试服务器
server := httptest.NewServer(&routeHandler{}) server := httptest.NewServer(&RouteHandler{})
defer server.Close() defer server.Close()
// 建立连接 // 建立连接
@ -32,13 +32,13 @@ func TestWebSocketService(t *testing.T) {
defer conn.Close() defer conn.Close()
// 发送消息 // 发送消息
msg := Map{"action": "echo", "msg": "hello"} msg := map[string]any{"action": "echo", "msg": "hello"}
if err := conn.WriteJSON(msg); err != nil { if err := conn.WriteJSON(msg); err != nil {
t.Fatalf("WriteJSON failed: %v", err) t.Fatalf("WriteJSON failed: %v", err)
} }
// 接收响应 // 接收响应
var reply Map var reply map[string]any
if err := conn.ReadJSON(&reply); err != nil { if err := conn.ReadJSON(&reply); err != nil {
t.Fatalf("ReadJSON failed: %v", err) t.Fatalf("ReadJSON failed: %v", err)
} }