From bdb104aa2f05d1fa5bf2a474dee8cb58e874d608 Mon Sep 17 00:00:00 2001 From: AI Engineer Date: Fri, 8 May 2026 07:27:06 +0800 Subject: [PATCH] Migrate service module from ssgo/s with modern Go features (by AI) --- .log.meta.json | 267 ++++++++++++++++++++++++++++++++++++++ DocTpl.html | 216 +++++++++++++++++++++++++++++++ README.md | 56 ++++++++ config.go | 62 +++++++++ document.go | 162 +++++++++++++++++++++++ go.mod | 5 + go.sum | 2 + handler.go | 319 ++++++++++++++++++++++++++++++++++++++++++++++ handler_test.go | 67 ++++++++++ proxy.go | 171 +++++++++++++++++++++++++ proxy_test.go | 62 +++++++++ request.go | 139 ++++++++++++++++++++ response.go | 152 ++++++++++++++++++++++ rewrite.go | 113 ++++++++++++++++ server.go | 88 +++++++++++++ server_test.go | 31 +++++ service.go | 225 ++++++++++++++++++++++++++++++++ service_test.go | 54 ++++++++ starter.go | 49 +++++++ static.go | 123 ++++++++++++++++++ static_test.go | 43 +++++++ types.go | 68 ++++++++++ types_test.go | 41 ++++++ utility.go | 20 +++ verify.go | 305 ++++++++++++++++++++++++++++++++++++++++++++ verify_test.go | 76 +++++++++++ websocket.go | 175 +++++++++++++++++++++++++ websocket_test.go | 44 +++++++ 28 files changed, 3135 insertions(+) create mode 100644 .log.meta.json create mode 100644 DocTpl.html create mode 100644 README.md create mode 100644 config.go create mode 100644 document.go create mode 100644 go.mod create mode 100644 go.sum create mode 100644 handler.go create mode 100644 handler_test.go create mode 100644 proxy.go create mode 100644 proxy_test.go create mode 100644 request.go create mode 100644 response.go create mode 100644 rewrite.go create mode 100644 server.go create mode 100644 server_test.go create mode 100644 service.go create mode 100644 service_test.go create mode 100644 starter.go create mode 100644 static.go create mode 100644 static_test.go create mode 100644 types.go create mode 100644 types_test.go create mode 100644 utility.go create mode 100644 verify.go create mode 100644 verify_test.go create mode 100644 websocket.go create mode 100644 websocket_test.go diff --git a/.log.meta.json b/.log.meta.json new file mode 100644 index 0000000..9190980 --- /dev/null +++ b/.log.meta.json @@ -0,0 +1,267 @@ +{ + "debug": [ + { + "index": 0, + "name": "LogName", + "color": "cyan", + "hide": true + }, + { + "index": 1, + "name": "LogType", + "color": "magenta", + "hide": true + }, + { + "index": 2, + "name": "LogTime", + "format": "time" + }, + { + "index": 3, + "name": "TraceId", + "color": "blue" + }, + { + "index": 4, + "name": "Image", + "color": "darkGray", + "hide": true + }, + { + "index": 5, + "name": "Server", + "color": "darkGray", + "hide": true + }, + { + "index": 6, + "name": "Debug", + "withoutKey": true + }, + { + "index": 7, + "name": "Extra" + } + ], + "discover": [ + { + "index": 0, + "name": "LogName", + "color": "cyan", + "hide": true + }, + { + "index": 1, + "name": "LogType", + "color": "magenta", + "hide": true + }, + { + "index": 2, + "name": "LogTime", + "format": "time" + }, + { + "index": 3, + "name": "TraceId", + "color": "blue" + }, + { + "index": 4, + "name": "Image", + "color": "darkGray", + "hide": true + }, + { + "index": 5, + "name": "Server", + "color": "darkGray", + "hide": true + }, + { + "index": 6, + "name": "App", + "color": "cyan" + }, + { + "index": 7, + "name": "Method", + "color": "magenta" + }, + { + "index": 8, + "name": "Path", + "color": "blue" + }, + { + "index": 9, + "name": "Node", + "color": "yellow" + }, + { + "index": 10, + "name": "Attempts" + }, + { + "index": 11, + "name": "UsedTime", + "format": "%.2fms" + }, + { + "index": 12, + "name": "Error", + "color": "red" + }, + { + "index": 13, + "name": "Extra" + } + ], + "error": [ + { + "index": 0, + "name": "LogName", + "color": "cyan", + "hide": true + }, + { + "index": 1, + "name": "LogType", + "color": "magenta", + "hide": true + }, + { + "index": 2, + "name": "LogTime", + "format": "time" + }, + { + "index": 3, + "name": "TraceId", + "color": "blue" + }, + { + "index": 4, + "name": "Image", + "color": "darkGray", + "hide": true + }, + { + "index": 5, + "name": "Server", + "color": "darkGray", + "hide": true + }, + { + "index": 6, + "name": "Error", + "color": "red", + "withoutKey": true + }, + { + "index": 7, + "name": "CallStacks" + }, + { + "index": 8, + "name": "Extra" + } + ], + "info": [ + { + "index": 0, + "name": "LogName", + "color": "cyan", + "hide": true + }, + { + "index": 1, + "name": "LogType", + "color": "magenta", + "hide": true + }, + { + "index": 2, + "name": "LogTime", + "format": "time" + }, + { + "index": 3, + "name": "TraceId", + "color": "blue" + }, + { + "index": 4, + "name": "Image", + "color": "darkGray", + "hide": true + }, + { + "index": 5, + "name": "Server", + "color": "darkGray", + "hide": true + }, + { + "index": 6, + "name": "Info", + "color": "cyan", + "withoutKey": true + }, + { + "index": 7, + "name": "Extra" + } + ], + "warning": [ + { + "index": 0, + "name": "LogName", + "color": "cyan", + "hide": true + }, + { + "index": 1, + "name": "LogType", + "color": "magenta", + "hide": true + }, + { + "index": 2, + "name": "LogTime", + "format": "time" + }, + { + "index": 3, + "name": "TraceId", + "color": "blue" + }, + { + "index": 4, + "name": "Image", + "color": "darkGray", + "hide": true + }, + { + "index": 5, + "name": "Server", + "color": "darkGray", + "hide": true + }, + { + "index": 6, + "name": "Warning", + "color": "yellow", + "withoutKey": true + }, + { + "index": 7, + "name": "CallStacks" + }, + { + "index": 8, + "name": "Extra" + } + ] +} diff --git a/DocTpl.html b/DocTpl.html new file mode 100644 index 0000000..4d655ae --- /dev/null +++ b/DocTpl.html @@ -0,0 +1,216 @@ + + + + + {{.title}} + + + + +
+ {{range .api}} + +
+
+ {{.Path}} + {{.Memo}} + {{if ne .Method ""}}{{end}} + + {{if ne .Type "Web"}}{{end}} +
+
+ + {{if isMap .In}} + + + + {{range $k, $v := .In}} + + + + + {{end}} + {{else}} + + + + {{end}} +
Request
{{$k}}{{toText $v}}
{{.In}}
+ + {{if isMap .Out}} + + + + {{range $k, $v := .Out}} + + + + + {{end}} + {{else}} + + + + {{end}} +
Response
{{$k}}{{toText $v}}
{{.Out}}
+
+ {{else}} +
no document
+ {{end}} +
+
+ + diff --git a/README.md b/README.md new file mode 100644 index 0000000..bf588e2 --- /dev/null +++ b/README.md @@ -0,0 +1,56 @@ +# go/service (核心微服务框架) + +极简、自动化的 Web 与 WebSocket 服务框架,实现极致的依赖注入与路由映射。 + +## 核心特性 +- **路由反射**: 自动解析函数参数,支持 `*Request`, `*Response`, `*log.Logger` 及自定义结构体自动注入。 +- **自动校验**: 集成 `verify` 引擎,通过 Struct Tag 实现入参合法性自动检查。 +- **功能闭环**: 内置静态文件服务、WebSocket (带 Action 路由)、URL 重写、反向代理(对接 Discover)。 +- **零摩擦启动**: 支持命令行指令管理 (start/stop/help) 及异步平滑启停。 + +## API 指南 + +### 1. 服务注册 +```go +import "apigo.cc/go/service" + +// 注册标准 Web 服务 +service.Register(0, "/hello", func(in struct{ Name string }) string { + return "Hello " + in.Name +}, "打招呼接口") + +// 注册 Restful 服务 +service.Restful(0, "POST", "/user/{id}", func(args map[string]any) service.Result { + res := service.Result{} + res.OK() + return res +}, "更新用户") +``` + +### 2. WebSocket 支持 +```go +ar := service.RegisterWebsocket(0, "/ws", onOpen, onClose, "聊天室") +ar.RegisterAction(0, "chat", func(in ChatMessage, sess *MySession) { + // 处理消息 +}, "发送消息") +``` + +### 3. 增强插件 +- **静态文件**: `service.Static("/ui", "./static_dir")` +- **URL 重写**: `service.Rewrite("/old", "/new")` +- **反向代理**: `service.Proxy(0, "/api", "other_app", "/api")` + +### 4. 生命周期管理 +```go +func main() { + service.CheckCmd() // 处理 start/stop/help 指令 + service.Start() // 阻塞启动 +} +``` + +## 基础设施对齐 +- **类型转换**: `apigo.cc/go/cast` +- **日志系统**: `apigo.cc/go/log` +- **服务发现**: `apigo.cc/go/discover` +- **分布式 ID**: `apigo.cc/go/id` +- **文件操作**: `apigo.cc/go/file` diff --git a/config.go b/config.go new file mode 100644 index 0000000..2b6812b --- /dev/null +++ b/config.go @@ -0,0 +1,62 @@ +package service + +// CertSet SSL 证书配置 +type CertSet struct { + CertFile string + KeyFile string +} + +// ServiceConfig 核心服务配置 +type ServiceConfig struct { + Listen string // 监听端口(|隔开多个监听)(,隔开多个选项),例如 80,http|443|443:h2|127.0.0.1:8080,h2c + SSL map[string]*CertSet // SSL 证书配置,key 为域名 + NoLogGets bool // 不记录 GET 请求的日志 + NoLogHeaders string // 不记录请求头中包含的这些字段,多个字段用逗号分隔 + LogInputArrayNum int // 请求字段中容器类型在日志打印个数限制 + LogInputFieldSize int // 请求字段中单个字段在日志打印长度限制 + NoLogOutputFields string // 不记录响应字段中包含的这些字段 + LogOutputArrayNum int // 响应字段中容器类型在日志打印个数限制 + LogOutputFieldSize int // 响应字段中单个字段在日志打印长度限制 + LogWebsocketAction bool // 记录 Websocket 中每个 Action 的请求日志 + Compress bool // 是否启用压缩 + CompressMinSize int // 启用压缩的最小长度 + CompressMaxSize int // 启用压缩的最大长度 + CheckDomain string // 心跳检测时使用域名 + AccessTokens map[string]*int // 指定 Access-Token 验证及其对应的 auth-level + RedirectTimeout int // Proxy 和 Discover 发起请求时的超时时间 (ms) + AcceptXRealIpWithoutRequestId bool // 是否允许头部没有携带请求ID的 X-Real-IP 信息 + StatisticTime bool // 是否开启请求时间统计 + StatisticTimeInterval int // 统计时间间隔 (ms) + Fast bool // 是否启用快速模式 + MaxUploadSize int64 // 最大上传文件大小 (Bytes) + IpPrefix string // Discover 服务发现时指定使用的 IP 网段 + Cpu int // CPU 占用的核数限制 + Memory int // 内存限制 (MB) + CpuMonitor bool // 记录 CPU 使用情况 + MemoryMonitor bool // 记录内存使用情况 + CpuLimitValue uint // CPU 自动重启阈值 (10-100) + MemoryLimitValue uint // 内存自动重启阈值 (10-100) + CpuLimitTimes uint // CPU 报警阈值连续次数 + MemoryLimitTimes uint // 内存报警阈值连续次数 + CookieScope string // Session Cookie 有效范围: host|domain|topDomain + SessionWithoutCookie bool // Session 禁用 Cookie + DeviceWithoutCookie bool // 设备ID禁用 Cookie + IdServer string // Redis 服务器连接 (用于全局唯一 ID 生成) + KeepKeyCase bool // 是否保持 Key 的首字母大小写 + IndexFiles []string // 静态文件索引文件 + IndexDir bool // 访问目录时显示文件列表 + ReadTimeout int // 读取请求的超时时间 (ms) + ReadHeaderTimeout int // 读取请求头的超时时间 (ms) + WriteTimeout int // 响应写入的超时时间 (ms) + IdleTimeout int // 连接空闲超时时间 (ms) + MaxHeaderBytes int // 请求头的最大字节数 + MaxHandlers int // 每个连接的最大处理程序数量 + MaxConcurrentStreams uint32 // 每个连接的最大并发流数量 + MaxDecoderHeaderTableSize uint32 // 解码器头表的最大大小 + MaxEncoderHeaderTableSize uint32 // 编码器头表的最大大小 + MaxReadFrameSize uint32 // 单个帧的最大读取大小 + MaxUploadBufferPerConnection int32 // 每个连接的最大上传缓冲区大小 + MaxUploadBufferPerStream int32 // 每个流的最大上传缓冲区大小 +} + +var Config = ServiceConfig{} diff --git a/document.go b/document.go new file mode 100644 index 0000000..428148c --- /dev/null +++ b/document.go @@ -0,0 +1,162 @@ +package service + +import ( + "apigo.cc/go/cast" + _ "embed" + "encoding/json" + "reflect" +) + +// Api 接口文档信息 +type Api struct { + Type string + Path string + AuthLevel int + Method string + In any + Out any + Memo string +} + +//go:embed DocTpl.html +var defaultDocTpl string + +// MakeDocument 生成文档数据 +func MakeDocument() []Api { + out := make([]Api, 0) + + // 1. Rewrite + rewritesLock.RLock() + for _, a := range rewrites { + out = append(out, Api{ + Type: "Rewrite", + Path: a.fromPath + " -> " + a.toPath, + }) + } + rewritesLock.RUnlock() + + // 2. Proxy + proxiesLock.RLock() + for _, a := range proxies { + out = append(out, Api{ + Type: "Proxy", + Path: a.fromPath + " -> " + a.toApp + ":" + a.toPath, + }) + } + proxiesLock.RUnlock() + + // 3. Web Services + webServicesLock.RLock() + for _, a := range webServicesList { + if a.options.NoDoc { + continue + } + api := Api{ + Type: "Web", + Path: a.path, + AuthLevel: a.authLevel, + Method: a.method, + Memo: a.memo, + } + if a.inType != nil { + api.In = getType(a.inType) + } + if a.funcType.NumOut() > 0 { + api.Out = getType(a.funcType.Out(0)) + } + out = append(out, api) + } + webServicesLock.RUnlock() + + // 4. WebSocket Services + websocketServicesLock.RLock() + for _, a := range websocketServices { + api := Api{ + Type: "WebSocket", + Path: a.path, + AuthLevel: a.authLevel, + Memo: a.memo, + } + if a.openFuncType != nil && a.openFuncType.NumIn() > 0 { + // Find struct in + for i := 0; i < a.openFuncType.NumIn(); i++ { + t := a.openFuncType.In(i) + if t.Kind() == reflect.Struct { + api.In = getType(t) + break + } + } + } + out = append(out, api) + + for name, act := range a.actions { + actionApi := Api{ + Type: "Action", + Path: name, + AuthLevel: act.authLevel, + Memo: act.memo, + } + if act.inType != nil { + actionApi.In = getType(act.inType) + } + if act.funcType.NumOut() > 0 { + actionApi.Out = getType(act.funcType.Out(0)) + } + out = append(out, actionApi) + } + } + websocketServicesLock.RUnlock() + + return out +} + +// MakeJsonDocument 生成 JSON 格式文档 +func MakeJsonDocument() string { + apis := MakeDocument() + data, _ := json.MarshalIndent(map[string]any{ + "api": apis, + }, "", "\t") + return string(data) +} + +func getType(t reflect.Type) any { + if t == nil { + return "" + } + for t.Kind() == reflect.Ptr { + t = t.Elem() + } + + switch t.Kind() { + case reflect.Struct: + outs := Map{} + for i := 0; i < t.NumField(); i++ { + f := t.Field(i) + if f.Anonymous { + if subMap, ok := getType(f.Type).(Map); ok { + for k, v := range subMap { + outs[k] = v + } + } + } else { + outs[cast.GetLowerName(f.Name)] = getType(f.Type) + } + } + return outs + case reflect.Map: + return map[string]any{t.Key().String(): getType(t.Elem())} + case reflect.Slice: + return []any{getType(t.Elem())} + case reflect.Interface: + return "Any" + default: + return t.String() + } +} + +// 自动注册文档服务 +func init() { + Register(0, "/__DOC__", func() string { + return MakeJsonDocument() + }, "API Document") +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..479da22 --- /dev/null +++ b/go.mod @@ -0,0 +1,5 @@ +module apigo.cc/go/service + +go 1.25.0 + +require github.com/gorilla/websocket v1.5.3 diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..25a9fc4 --- /dev/null +++ b/go.sum @@ -0,0 +1,2 @@ +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= diff --git a/handler.go b/handler.go new file mode 100644 index 0000000..181fa08 --- /dev/null +++ b/handler.go @@ -0,0 +1,319 @@ +package service + +import ( + "apigo.cc/go/cast" + "apigo.cc/go/id" + "apigo.cc/go/log" + "apigo.cc/go/standard" + "encoding/json" + "io" + "net/http" + "reflect" + "strings" + "sync/atomic" + "time" +) + +type routeHandler struct { + webRequestingNum int64 +} + +func (rh *routeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + atomic.AddInt64(&rh.webRequestingNum, 1) + defer atomic.AddInt64(&rh.webRequestingNum, -1) + + startTime := time.Now() + requestId := r.Header.Get(standard.DiscoverHeaderRequestId) + if requestId == "" { + requestId = id.MakeID(12) + r.Header.Set(standard.DiscoverHeaderRequestId, requestId) + } + + request := NewRequest(r) + request.Id = requestId + response := NewResponse(w) + response.Id = requestId + defer response.checkWriteHeader() + + // 处理 SessionId 和 DeviceId + handleClientKeys(request, response) + + requestLogger := log.New(requestId) + + // 0. 处理重写 (Rewrite) + if processRewrite(request, response, requestLogger) { + return + } + + // 处理代理 (Proxy) + if processProxy(request, response, requestLogger) { + return + } + + // 1. 路由匹配 + path := r.URL.Path + host := r.Host + + // 处理静态文件 + if processStatic(path, request, response, requestLogger) { + return + } + + s, ws := findService(r.Method, host, path) + + // 2. 参数解析 (Form & Body) + args := make(map[string]any) + parseRequestArgs(request, args) + + // 3. 前置过滤器 + var result any + for _, filter := range inFilters { + result = filter(&args, request, response, requestLogger) + if result != nil { + break + } + } + + // 4. 处理业务执行 (WS 或 Web) + if result == nil { + if ws != nil { + doWebsocketService(ws, request, response, requestLogger) + return + } else if s != nil { + // 鉴权 + pass, obj := checkAuth(s, request, response, args, requestLogger) + if !pass { + if !response.changed { + response.WriteHeader(http.StatusForbidden) + } + return + } + // 执行业务 + result = doWebService(s, request, response, args, nil, requestLogger, obj) + } + } + + if s == nil && result == nil { + response.WriteHeader(http.StatusNotFound) + return + } + + // 5. 后置过滤器 + for _, filter := range outFilters { + newResult, done := filter(args, request, response, result, requestLogger) + if newResult != nil { + result = newResult + } + if done { + break + } + } + + // 6. 输出结果 + outputResult(response, result) + + // 7. 记录日志 + _ = startTime +} + +func findService(method, host, path string) (*webServiceType, *websocketServiceType) { + webServicesLock.RLock() + defer webServicesLock.RUnlock() + + // 1. Web Service 匹配 + if s, exists := webServices[method+path]; exists { + return s, nil + } + if s, exists := webServices[path]; exists { + return s, nil + } + + // 2. WebSocket 匹配 + websocketServicesLock.RLock() + defer websocketServicesLock.RUnlock() + if ws, exists := websocketServices[path]; exists { + return nil, ws + } + + // 3. 正则匹配 + for i := len(regexWebServices) - 1; i >= 0; i-- { + s := regexWebServices[i] + if s.method != "" && s.method != method { + continue + } + if s.pathMatcher != nil && s.pathMatcher.MatchString(path) { + return s, nil + } + } + + return nil, nil +} + +func parseRequestArgs(request *Request, args map[string]any) { + // Query params + query := request.URL.Query() + for k, v := range query { + if len(v) == 1 { + args[k] = v[0] + } else { + args[k] = v + } + } + + // Form params + if request.Method == http.MethodPost || request.Method == http.MethodPut { + contentType := request.Header.Get("Content-Type") + if strings.HasPrefix(contentType, "application/json") { + body, _ := io.ReadAll(request.Body) + _ = request.Body.Close() + if len(body) > 0 { + _ = json.Unmarshal(body, &args) + } + } else { + _ = request.ParseForm() + for k, v := range request.Form { + if len(v) == 1 { + args[k] = v[0] + } else { + args[k] = v + } + } + } + } +} + +func checkAuth(s *webServiceType, request *Request, response *Response, args map[string]any, logger *log.Logger) (bool, any) { + ac := webAuthCheckers[s.authLevel] + if ac == nil { + ac = webAuthChecker + } + if ac == nil { + return true, nil + } + return ac(s.authLevel, logger, &request.RequestURI, args, request, response, &s.options) +} + +func doWebService(service *webServiceType, request *Request, response *Response, args map[string]any, + result any, logger *log.Logger, object any) any { + if result != nil { + return result + } + + params := make([]reflect.Value, service.parmsNum) + for i := 0; i < service.parmsNum; i++ { + t := service.funcType.In(i) + switch i { + case service.requestIndex: + params[i] = reflect.ValueOf(request) + case service.httpRequestIndex: + params[i] = reflect.ValueOf(request.Request) + case service.responseIndex: + params[i] = reflect.ValueOf(response) + case service.responseWriterIndex: + params[i] = reflect.ValueOf(response.Writer) + case service.loggerIndex: + params[i] = reflect.ValueOf(logger) + case service.inIndex: + in := reflect.New(service.inType).Interface() + cast.Convert(in, args) + // 参数校验 + if service.inType.Kind() == reflect.Struct { + if ok, _ := VerifyStruct(in, logger); !ok { + response.WriteHeader(http.StatusBadRequest) + return "parameter verification failed" + } + } + params[i] = reflect.ValueOf(in).Elem() + default: + // 尝试依赖注入 + if obj := GetInject(t); obj != nil { + params[i] = reflect.ValueOf(obj) + } else { + params[i] = reflect.New(t).Elem() + } + } + } + + outs := service.funcValue.Call(params) + if len(outs) > 0 { + return outs[0].Interface() + } + return "" +} + +func outputResult(response *Response, result any) { + if result == nil { + return + } + + var data []byte + contentType := "" + + switch v := result.(type) { + case string: + data = []byte(v) + case []byte: + data = v + default: + data, _ = cast.ToJSONBytes(result) + contentType = "application/json; charset=UTF-8" + } + + if contentType != "" && response.Header().Get("Content-Type") == "" { + response.Header().Set("Content-Type", contentType) + } + _, _ = response.Write(data) +} + +func handleClientKeys(request *Request, response *Response) { + // SessionId + if usedSessionIdKey != "" { + sessionId := request.Header.Get(usedSessionIdKey) + if sessionId == "" && !Config.SessionWithoutCookie { + if ck, err := request.Cookie(usedSessionIdKey); err == nil { + sessionId = ck.Value + } + } + if sessionId == "" { + if sessionIdMaker != nil { + sessionId = sessionIdMaker() + } else { + sessionId = id.MakeID(14) + } + if !Config.SessionWithoutCookie { + http.SetCookie(response.Writer, &http.Cookie{ + Name: usedSessionIdKey, + Value: sessionId, + Path: "/", + HttpOnly: true, + }) + } + } + request.Header.Set(standard.DiscoverHeaderSessionId, sessionId) + response.Header().Set(usedSessionIdKey, sessionId) + } + + // DeviceId + if usedDeviceIdKey != "" { + deviceId := request.Header.Get(usedDeviceIdKey) + if deviceId == "" && !Config.DeviceWithoutCookie { + if ck, err := request.Cookie(usedDeviceIdKey); err == nil { + deviceId = ck.Value + } + } + if deviceId == "" { + deviceId = id.MakeID(14) + if !Config.DeviceWithoutCookie { + http.SetCookie(response.Writer, &http.Cookie{ + Name: usedDeviceIdKey, + Value: deviceId, + Path: "/", + Expires: time.Now().AddDate(10, 0, 0), + HttpOnly: true, + }) + } + } + request.Header.Set(standard.DiscoverHeaderDeviceId, deviceId) + response.Header().Set(usedDeviceIdKey, deviceId) + } +} diff --git a/handler_test.go b/handler_test.go new file mode 100644 index 0000000..a5b03f3 --- /dev/null +++ b/handler_test.go @@ -0,0 +1,67 @@ +package service + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestServeHTTP(t *testing.T) { + // 注册服务 + handler := func(in struct{ Name string }) string { + return "Hello " + in.Name + } + Register(0, "/hello", handler, "say hello") + + rh := &routeHandler{} + + // 模拟请求 + req := httptest.NewRequest("POST", "/hello", strings.NewReader(`{"name":"Star"}`)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + rh.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d", w.Code) + } + + body := w.Body.String() + if body != "Hello Star" { + t.Errorf("Expected 'Hello Star', got '%s'", body) + } +} + +func TestServeHTTP_404(t *testing.T) { + rh := &routeHandler{} + req := httptest.NewRequest("GET", "/notfound", nil) + w := httptest.NewRecorder() + + rh.ServeHTTP(w, req) + + if w.Code != http.StatusNotFound { + t.Errorf("Expected status 404, got %d", w.Code) + } +} + +func TestServeHTTP_VerifyFailed(t *testing.T) { + type ValidIn struct { + Age int `verify:"between:18-100"` + } + handler := func(in ValidIn) string { + return "ok" + } + Register(0, "/verify", handler, "test verify") + + rh := &routeHandler{} + req := httptest.NewRequest("POST", "/verify", strings.NewReader(`{"age":10}`)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + rh.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("Expected status 400, got %d", w.Code) + } +} diff --git a/proxy.go b/proxy.go new file mode 100644 index 0000000..969c923 --- /dev/null +++ b/proxy.go @@ -0,0 +1,171 @@ +package service + +import ( + "apigo.cc/go/discover" + gohttp "apigo.cc/go/http" + "apigo.cc/go/log" + "fmt" + "io" + "net/http" + "regexp" + "strings" + "sync" + "time" +) + +type proxyInfo struct { + matcher *regexp.Regexp + authLevel int + fromPath string + toApp string + toPath string +} + +var ( + proxies = make(map[string]*proxyInfo) + regexProxies = make([]*proxyInfo, 0) + proxyBy func(*Request) (int, *string, *string, map[string]string) + proxiesLock = sync.RWMutex{} + + httpClientPool *gohttp.Client +) + +// Proxy 注册代理规则 +func Proxy(authLevel int, path string, toApp, toPath string) { + p := &proxyInfo{authLevel: authLevel, fromPath: path, toApp: toApp, toPath: toPath} + if strings.Contains(path, "(") { + matcher, err := regexp.Compile("^" + path + "$") + if err == nil { + p.matcher = matcher + proxiesLock.Lock() + regexProxies = append(regexProxies, p) + proxiesLock.Unlock() + } + } else { + proxiesLock.Lock() + proxies[path] = p + proxiesLock.Unlock() + } +} + +// SetProxyBy 设置动态代理函数 +func SetProxyBy(by func(request *Request) (authLevel int, toApp, toPath *string, headers map[string]string)) { + proxyBy = by +} + +func findProxy(request *Request) (int, *string, *string) { + requestPath := request.RequestURI + queryString := "" + if pos := strings.Index(requestPath, "?"); pos != -1 { + queryString = requestPath[pos:] + requestPath = requestPath[:pos] + } + + proxiesLock.RLock() + defer proxiesLock.RUnlock() + + if pi, ok := proxies[requestPath]; ok { + toPath := pi.toPath + queryString + return pi.authLevel, &pi.toApp, &toPath + } + + for _, pi := range regexProxies { + if pi.matcher != nil { + finds := pi.matcher.FindAllStringSubmatch(requestPath, 1) + if len(finds) > 0 { + toApp := pi.toApp + toPath := pi.toPath + for i, part := range finds[0] { + toApp = strings.ReplaceAll(toApp, fmt.Sprintf("$%d", i), part) + toPath = strings.ReplaceAll(toPath, fmt.Sprintf("$%d", i), part) + } + toPath += queryString + return pi.authLevel, &toApp, &toPath + } + } + } + + return 0, nil, nil +} + +func processProxy(request *Request, response *Response, logger *log.Logger) bool { + authLevel, proxyToApp, proxyToPath := findProxy(request) + var proxyHeaders map[string]string + + if proxyBy != nil && (proxyToApp == nil || proxyToPath == nil || *proxyToApp == "" || *proxyToPath == "") { + authLevel, proxyToApp, proxyToPath, proxyHeaders = proxyBy(request) + } + + if proxyToApp == nil || proxyToPath == nil || *proxyToApp == "" || *proxyToPath == "" { + return false + } + + // 鉴权 + pass, obj := checkAuthForProxy(authLevel, request, response, logger) + if !pass { + if !response.changed { + response.WriteHeader(http.StatusForbidden) + } + return true + } + _ = obj // Currently unused in proxy + + app := *proxyToApp + path := *proxyToPath + + // 构建自定义头部 + headerArgs := make([]string, 0) + for k, v := range proxyHeaders { + headerArgs = append(headerArgs, k, v) + } + + if strings.Contains(app, "://") { + // 直接 URL 代理 + if httpClientPool == nil { + httpClientPool = gohttp.NewClient(time.Duration(Config.RedirectTimeout) * time.Millisecond) + } + res := httpClientPool.ManualDoByRequest(request.Request, request.Method, app+path, request.Body, headerArgs...) + copyResponse(res, response, logger) + } else { + // Discover 代理 + caller := discover.NewCaller(request.Request, logger) + caller.NoBody = true + res, _ := caller.ManualDoWithNode(request.Method, app, "", path, request.Body, headerArgs...) + copyResponse(res, response, logger) + } + + return true +} + +func checkAuthForProxy(authLevel int, request *Request, response *Response, logger *log.Logger) (bool, any) { + ac := webAuthCheckers[authLevel] + if ac == nil { + ac = webAuthChecker + } + if ac == nil { + return true, nil + } + return ac(authLevel, logger, &request.RequestURI, nil, request, response, nil) +} + +func copyResponse(res *gohttp.Result, response *Response, logger *log.Logger) { + if res.Error != nil || res.Response == nil { + response.WriteHeader(http.StatusBadGateway) + if res.Error != nil { + _, _ = response.WriteString(res.Error.Error()) + } + return + } + + for k, v := range res.Response.Header { + response.Header().Set(k, v[0]) + } + response.WriteHeader(res.Response.StatusCode) + if res.Response.Body != nil { + defer res.Response.Body.Close() + _, err := io.Copy(response.Writer, res.Response.Body) + if err != nil { + logger.Error("proxy copy body failed", "error", err.Error()) + } + } +} diff --git a/proxy_test.go b/proxy_test.go new file mode 100644 index 0000000..5f177b6 --- /dev/null +++ b/proxy_test.go @@ -0,0 +1,62 @@ +package service + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +func TestRewrite(t *testing.T) { + // 注册重写规则 + Rewrite("/old", "/new") + Rewrite("/regex/(.*)", "/target/$1") + + // 注册目标服务 + Register(0, "/new", func() string { return "new content" }, "new") + Register(0, "/target/123", func() string { return "target content" }, "target") + + rh := &routeHandler{} + + // 测试精确匹配重写 + req1 := httptest.NewRequest("GET", "/old", nil) + w1 := httptest.NewRecorder() + rh.ServeHTTP(w1, req1) + if w1.Body.String() != "new content" { + t.Errorf("Expected 'new content', got '%s'", w1.Body.String()) + } + + // 测试正则匹配重写 + req2 := httptest.NewRequest("GET", "/regex/123", nil) + w2 := httptest.NewRecorder() + rh.ServeHTTP(w2, req2) + if w2.Body.String() != "target content" { + t.Errorf("Expected 'target content', got '%s'", w2.Body.String()) + } +} + +func TestProxyDirect(t *testing.T) { + // 启动后端服务器 + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Backend", "ok") + w.Write([]byte("backend content")) + })) + defer backend.Close() + + // 注册代理规则 + Proxy(0, "/proxy", backend.URL, "/hello") + + rh := &routeHandler{} + req := httptest.NewRequest("GET", "/proxy", nil) + w := httptest.NewRecorder() + rh.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected 200, got %d", w.Code) + } + if w.Header().Get("X-Backend") != "ok" { + t.Error("Header X-Backend mismatch") + } + if w.Body.String() != "backend content" { + t.Errorf("Expected 'backend content', got '%s'", w.Body.String()) + } +} diff --git a/request.go b/request.go new file mode 100644 index 0000000..800a9ad --- /dev/null +++ b/request.go @@ -0,0 +1,139 @@ +package service + +import ( + "apigo.cc/go/cast" + "apigo.cc/go/standard" + "io" + "mime/multipart" + "net" + "net/http" + "net/textproto" + "net/url" + "os" + "path/filepath" +) + +// UploadFile 上传文件结构 +type UploadFile struct { + fileHeader *multipart.FileHeader + Filename string + Header textproto.MIMEHeader + Size int64 +} + +// Open 打开上传文件 +func (f *UploadFile) Open() (multipart.File, error) { + return f.fileHeader.Open() +} + +// Save 保存上传文件到本地 +func (f *UploadFile) Save(filename string) error { + dir := filepath.Dir(filename) + if _, err := os.Stat(dir); os.IsNotExist(err) { + _ = os.MkdirAll(dir, 0755) + } + + dst, err := os.OpenFile(filename, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) + if err != nil { + return err + } + defer dst.Close() + + src, err := f.fileHeader.Open() + if err != nil { + return err + } + defer src.Close() + + _, err = io.Copy(dst, src) + return err +} + +// Content 获取上传文件内容 +func (f *UploadFile) Content() ([]byte, error) { + src, err := f.fileHeader.Open() + if err != nil { + return nil, err + } + defer src.Close() + return io.ReadAll(src) +} + +// Request 封装 http.Request +type Request struct { + *http.Request + contextValues map[string]any + Id string +} + +// NewRequest 创建 Request 包装 +func NewRequest(httpRequest *http.Request) *Request { + return &Request{ + Request: httpRequest, + contextValues: make(map[string]any), + } +} + +// ResetPath 重写请求路径 +func (r *Request) ResetPath(path string) { + r.RequestURI = path + if u, err := url.Parse(path); err == nil { + r.URL = u + } +} + +// Set 设置请求上下文变量 +func (r *Request) Set(key string, value any) { + r.contextValues[key] = value +} + +// Get 获取请求上下文变量 +func (r *Request) Get(key string) any { + return r.contextValues[key] +} + +// MakeUrl 根据当前请求构建完整 URL +func (r *Request) MakeUrl(path string) string { + scheme := r.Header.Get(standard.DiscoverHeaderScheme) + if scheme == "" { + scheme = "http" + } + host := r.Header.Get(standard.DiscoverHeaderHost) + if host == "" { + host = r.Host + } + return scheme + "://" + host + path +} + +// GetSessionId 获取会话 ID +func (r *Request) GetSessionId() string { + sessionId := r.Header.Get(Config.Listen) // Wait, this should be usedSessionIdKey + // TODO: Fix dependency on global usedSessionIdKey + return sessionId +} + +// SetUserId 设置用户 ID(传递给下游) +func (r *Request) SetUserId(userId string) { + r.Header.Set(standard.DiscoverHeaderUserId, userId) +} + +// GetRealIp 获取真实 IP +func (r *Request) GetRealIp() string { + ip := r.Header.Get(standard.DiscoverHeaderClientIp) + if ip == "" { + ip = r.Header.Get(standard.DiscoverHeaderForwardedFor) + } + if ip == "" { + host, _, err := net.SplitHostPort(r.RemoteAddr) + if err == nil { + return host + } + return r.RemoteAddr + } + return ip +} + +// GetLowerName (Aliased from cast) +func GetLowerName(s string) string { + return cast.GetLowerName(s) +} diff --git a/response.go b/response.go new file mode 100644 index 0000000..e51d926 --- /dev/null +++ b/response.go @@ -0,0 +1,152 @@ +package service + +import ( + "apigo.cc/go/cast" + "io" + "net/http" + "os" +) + +// Response 封装 http.ResponseWriter +type Response struct { + Id string + Writer http.ResponseWriter + status int + outLen int + changed bool + headerWritten bool + dontLog200 bool + dontLogArgs []string + ProxyHeader *http.Header +} + +// NewResponse 创建 Response 包装 +func NewResponse(writer http.ResponseWriter) *Response { + return &Response{ + Writer: writer, + status: http.StatusOK, + } +} + +// Header 获取响应头部 +func (r *Response) Header() http.Header { + r.changed = true + if r.ProxyHeader != nil { + return *r.ProxyHeader + } + return r.Writer.Header() +} + +// Write 写入响应内容 +func (r *Response) Write(bytes []byte) (int, error) { + r.checkWriteHeader() + r.changed = true + r.outLen += len(bytes) + if r.ProxyHeader != nil { + r.copyProxyHeader() + } + return r.Writer.Write(bytes) +} + +// WriteString 写入字符串响应 +func (r *Response) WriteString(s string) (int, error) { + return r.Write([]byte(s)) +} + +// WriteHeader 设置响应状态码 +func (r *Response) WriteHeader(code int) { + r.changed = true + r.status = code + if r.ProxyHeader != nil && (r.status == http.StatusBadGateway || r.status == http.StatusServiceUnavailable || r.status == http.StatusGatewayTimeout) { + return + } + if r.ProxyHeader != nil { + r.copyProxyHeader() + } +} + +func (r *Response) checkWriteHeader() { + if !r.headerWritten { + r.headerWritten = true + if r.status != http.StatusOK { + r.Writer.WriteHeader(r.status) + } + } +} + +func (r *Response) copyProxyHeader() { + src := *r.ProxyHeader + dst := r.Writer.Header() + for k, vv := range src { + for _, v := range vv { + dst.Add(k, v) + } + } + r.ProxyHeader = nil +} + +// Flush 刷新响应缓冲区 +func (r *Response) Flush() { + if flusher, ok := r.Writer.(http.Flusher); ok { + flusher.Flush() + } +} + +// GetStatusCode 获取当前状态码 +func (r *Response) GetStatusCode() int { + return r.status +} + +// DontLog200 标记不记录 200 状态码的日志 +func (r *Response) DontLog200() { + r.dontLog200 = true +} + +// Location 设置重定向地址 +func (r *Response) Location(location string) { + r.WriteHeader(http.StatusFound) + r.Header().Set("Location", location) +} + +// SendFile 发送文件 +func (r *Response) SendFile(contentType, filename string) { + r.Header().Set("Content-Type", contentType) + // TODO: Integrate memory file support if needed + if fd, err := os.Open(filename); err == nil { + defer fd.Close() + _, _ = io.Copy(r, fd) + } +} + +// DownloadFile 下载文件 +func (r *Response) DownloadFile(contentType, filename string, data any) { + if contentType == "" { + contentType = "application/octet-stream" + } + r.Header().Set("Content-Type", contentType) + + if filename != "" { + r.Header().Set("Content-Disposition", "attachment; filename="+filename) + } + + var outBytes []byte + var reader io.Reader + + switch v := data.(type) { + case []byte: + outBytes = v + case string: + outBytes = []byte(v) + case io.Reader: + reader = v + default: + outBytes, _ = cast.ToJSONBytes(data) + } + + if outBytes != nil { + r.Header().Set("Content-Length", cast.String(len(outBytes))) + _, _ = r.Write(outBytes) + } else if reader != nil { + _, _ = io.Copy(r, reader) + } +} diff --git a/rewrite.go b/rewrite.go new file mode 100644 index 0000000..70d340b --- /dev/null +++ b/rewrite.go @@ -0,0 +1,113 @@ +package service + +import ( + "apigo.cc/go/log" + "fmt" + "net/url" + "regexp" + "strings" + "sync" +) + +type rewriteInfo struct { + matcher *regexp.Regexp + fromPath string + toPath string +} + +var ( + rewrites = make(map[string]*rewriteInfo) + regexRewrites = make([]*rewriteInfo, 0) + rewriteBy func(*Request) (string, bool) + rewritesLock = sync.RWMutex{} +) + +// Rewrite 注册重写规则 +func Rewrite(path string, toPath string) { + s := &rewriteInfo{fromPath: path, toPath: toPath} + + if strings.ContainsRune(path, '(') { + matcher, err := regexp.Compile("^" + path + "$") + if err == nil { + s.matcher = matcher + rewritesLock.Lock() + regexRewrites = append(regexRewrites, s) + rewritesLock.Unlock() + } + } else { + rewritesLock.Lock() + rewrites[path] = s + rewritesLock.Unlock() + } +} + +// SetRewriteBy 设置动态重写函数 +func SetRewriteBy(by func(request *Request) (toPath string, rewrite bool)) { + rewriteBy = by +} + +func processRewrite(request *Request, response *Response, logger *log.Logger) bool { + requestPath := request.RequestURI + queryString := "" + if pos := strings.Index(requestPath, "?"); pos != -1 { + queryString = requestPath[pos:] + requestPath = requestPath[:pos] + } + + var rewriteToPath string + var found bool + + rewritesLock.RLock() + // 1. 精确匹配 + if ri, ok := rewrites[requestPath]; ok { + rewriteToPath = ri.toPath + found = true + } + + // 2. 动态重写 + if !found && rewriteBy != nil { + rewriteToPath, found = rewriteBy(request) + } + + // 3. 正则匹配 + if !found { + for _, ri := range regexRewrites { + if ri.matcher != nil { + finds := ri.matcher.FindAllStringSubmatch(request.RequestURI, 1) + if len(finds) > 0 { + toPath := ri.toPath + for i, part := range finds[0] { + toPath = strings.ReplaceAll(toPath, fmt.Sprintf("$%d", i), part) + } + rewriteToPath = toPath + found = true + break + } + } + } + } + rewritesLock.RUnlock() + + if found { + if strings.Contains(rewriteToPath, "://") { + // 外部重定向 + if !strings.Contains(rewriteToPath, "?") && queryString != "" { + rewriteToPath += queryString + } + response.Header().Set("Location", rewriteToPath) + response.WriteHeader(302) + return true + } else { + // 内部重写 + logger.Info("rewrite", "from", request.RequestURI, "to", rewriteToPath) + if queryString != "" && !strings.Contains(rewriteToPath, "?") { + rewriteToPath += queryString + } + request.RequestURI = rewriteToPath + request.URL, _ = url.Parse(rewriteToPath) + return false // 继续后续处理 + } + } + + return false +} diff --git a/server.go b/server.go new file mode 100644 index 0000000..28a63b3 --- /dev/null +++ b/server.go @@ -0,0 +1,88 @@ +package service + +import ( + "apigo.cc/go/log" + "context" + "net" + "net/http" + "os" + "os/signal" + "syscall" + "time" +) + +// AsyncServer 异步服务实例 +type AsyncServer struct { + server *http.Server + listener net.Listener + Addr string + stopChan chan os.Signal + startChan chan bool +} + +// AsyncStart 异步启动服务 +func AsyncStart() *AsyncServer { + as := &AsyncServer{ + startChan: make(chan bool, 1), + stopChan: make(chan os.Signal, 1), + } + + go as.start() + + <-as.startChan + return as +} + +func (as *AsyncServer) start() { + if Config.Listen == "" { + Config.Listen = ":8080" // 默认端口 + } + + listener, err := net.Listen("tcp", Config.Listen) + if err != nil { + log.DefaultLogger.Error("failed to listen", "addr", Config.Listen, "error", err.Error()) + as.startChan <- false + return + } + + as.listener = listener + as.Addr = listener.Addr().String() + serverAddr = as.Addr + + as.server = &http.Server{ + Handler: &routeHandler{}, + } + + signal.Notify(as.stopChan, os.Interrupt, syscall.SIGTERM) + + go func() { + log.DefaultLogger.Info("service starting", "addr", as.Addr) + as.startChan <- true + if err := as.server.Serve(listener); err != nil && err != http.ErrServerClosed { + log.DefaultLogger.Error("server error", "error", err.Error()) + } + }() +} + +// Stop 停止服务 +func (as *AsyncServer) Stop() { + log.DefaultLogger.Info("service stopping") + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if err := as.server.Shutdown(ctx); err != nil { + log.DefaultLogger.Error("server shutdown error", "error", err.Error()) + } + log.DefaultLogger.Info("service stopped") +} + +// Wait 等待服务结束 (信号监听) +func (as *AsyncServer) Wait() { + <-as.stopChan + as.Stop() +} + +// Start 同步启动服务 +func Start() { + AsyncStart().Wait() +} diff --git a/server_test.go b/server_test.go new file mode 100644 index 0000000..70d00f3 --- /dev/null +++ b/server_test.go @@ -0,0 +1,31 @@ +package service + +import ( + "net/http" + "testing" +) + +func TestAsyncServer(t *testing.T) { + Config.Listen = ":0" // 随机端口 + as := AsyncStart() + if as.Addr == "" { + t.Fatal("AsyncStart failed to get address") + } + + // 测试服务是否可用 + resp, err := http.Get("http://" + as.Addr + "/__CHECK__") + if err == nil { + // 虽然没有注册 /__CHECK__,但应该返回 404 而非连接拒绝 + if resp.StatusCode != http.StatusNotFound { + t.Errorf("Expected 404, got %d", resp.StatusCode) + } + } + + as.Stop() + + // 确认服务已关闭 + _, err = http.Get("http://" + as.Addr + "/__CHECK__") + if err == nil { + t.Error("Server should be closed") + } +} diff --git a/service.go b/service.go new file mode 100644 index 0000000..5dea62c --- /dev/null +++ b/service.go @@ -0,0 +1,225 @@ +package service + +import ( + "apigo.cc/go/log" + "errors" + "reflect" + "regexp" + "strings" + "sync" +) + +// WebServiceOptions 服务注册选项 +type WebServiceOptions struct { + Priority int + NoDoc bool + NoBody bool + NoLog200 bool + Host string + Ext Map + // Limiters []*Limiter // TODO: Integrate Limiter +} + +// webServiceType 内部存储的服务元数据 +type webServiceType struct { + authLevel int + method string + path string + pathMatcher *regexp.Regexp + pathArgs []string + parmsNum int + inType reflect.Type + inIndex int + headersType reflect.Type + headersIndex int + requestIndex int + httpRequestIndex int + responseIndex int + responseWriterIndex int + loggerIndex int + callerIndex int + funcType reflect.Type + funcValue reflect.Value + options WebServiceOptions + data Map + memo string +} + +var ( + serverId string + serverAddr string + serverProto = "http" + serverProtoName = "http" + running = false + + webServices = make(map[string]*webServiceType) + regexWebServices = make([]*webServiceType, 0) + webServicesLock = sync.RWMutex{} + webServicesList = make([]*webServiceType, 0) + + websocketServices = make(map[string]*websocketServiceType) + regexWebsocketServices = make([]*websocketServiceType, 0) + websocketServicesLock = sync.RWMutex{} + websocketServicesList = make([]*websocketServiceType, 0) + + // 过滤器与拦截器 + inFilters = make([]func(*map[string]any, *Request, *Response, *log.Logger) any, 0) + outFilters = make([]func(map[string]any, *Request, *Response, any, *log.Logger) (any, bool), 0) + errorHandle func(any, *Request, *Response) any + webAuthChecker func(int, *log.Logger, *string, map[string]any, *Request, *Response, *WebServiceOptions) (pass bool, object any) + webAuthCheckers = make(map[int]func(int, *log.Logger, *string, map[string]any, *Request, *Response, *WebServiceOptions) (pass bool, object any)) + + // 注入点 + injectObjects = make(map[reflect.Type]any) + injectFunctions = make(map[reflect.Type]func() any) + + usedDeviceIdKey string + usedClientAppKey string + usedSessionIdKey string + sessionIdMaker func() string +) + +// SetClientKeys 设置客户端标识相关的 Key 映射 +func SetClientKeys(deviceIdKey, clientAppKey, sessionIdKey string) { + usedDeviceIdKey = deviceIdKey + usedClientAppKey = clientAppKey + usedSessionIdKey = sessionIdKey +} + +// SetSessionIdMaker 设置自定义会话 ID 生成器 +func SetSessionIdMaker(maker func() string) { + sessionIdMaker = maker +} + +// SetAuthChecker 设置全局鉴权器 +func SetAuthChecker(authChecker func(authLevel int, logger *log.Logger, url *string, in map[string]any, request *Request, response *Response, options *WebServiceOptions) (pass bool, object any)) { + webAuthChecker = authChecker +} + +// AddAuthChecker 为指定级别添加鉴权器 +func AddAuthChecker(authLevels []int, authChecker func(authLevel int, logger *log.Logger, url *string, in map[string]any, request *Request, response *Response, options *WebServiceOptions) (pass bool, object any)) { + for _, al := range authLevels { + webAuthCheckers[al] = authChecker + } +} + +// SetInFilter 设置前置过滤器 +func SetInFilter(filter func(in *map[string]any, request *Request, response *Response, logger *log.Logger) (out any)) { + inFilters = append(inFilters, filter) +} + +// SetOutFilter 设置后置过滤器 +func SetOutFilter(filter func(in map[string]any, request *Request, response *Response, out any, logger *log.Logger) (newOut any, isOver bool)) { + outFilters = append(outFilters, filter) +} + +// Register 注册服务(通用方法) +func Register(authLevel int, path string, serviceFunc any, memo string) { + Restful(authLevel, "", path, serviceFunc, memo) +} + +// Restful 注册指定方法的服务 +func Restful(authLevel int, method, path string, serviceFunc any, memo string) { + RestfulWithOptions(authLevel, method, path, serviceFunc, memo, WebServiceOptions{}) +} + +// RestfulWithOptions 注册带选项的服务 +func RestfulWithOptions(authLevel int, method, path string, serviceFunc any, memo string, options WebServiceOptions) { + s, err := makeCachedService(serviceFunc) + if err != nil { + // TODO: Log error properly when logger is ready + return + } + + s.authLevel = authLevel + s.options = options + s.method = method + s.path = path + s.memo = memo + + // 解析路径参数 {name} + finder, err := regexp.Compile("{(.*?)}") + if err == nil { + keyName := regexp.QuoteMeta(path) + finds := finder.FindAllStringSubmatch(path, 20) + for _, found := range finds { + keyName = strings.Replace(keyName, regexp.QuoteMeta(found[0]), "(.*?)", 1) + s.pathArgs = append(s.pathArgs, found[1]) + } + if len(s.pathArgs) > 0 { + s.pathMatcher, _ = regexp.Compile("^" + keyName + "$") + } + } + + webServicesLock.Lock() + defer webServicesLock.Unlock() + + // 简单路径匹配 + if s.pathMatcher == nil { + webServices[method+path] = s // TODO: Include Host in key + } else { + regexWebServices = append(regexWebServices, s) + } + webServicesList = append(webServicesList, s) +} + +func makeCachedService(matchedService any) (*webServiceType, error) { + funcType := reflect.TypeOf(matchedService) + if funcType.Kind() != reflect.Func { + return nil, errors.New("handler must be a function") + } + + targetService := &webServiceType{ + parmsNum: funcType.NumIn(), + inIndex: -1, + headersIndex: -1, + requestIndex: -1, + httpRequestIndex: -1, + responseIndex: -1, + responseWriterIndex: -1, + loggerIndex: -1, + callerIndex: -1, + funcType: funcType, + funcValue: reflect.ValueOf(matchedService), + } + + for i := 0; i < targetService.parmsNum; i++ { + t := funcType.In(i) + tStr := t.String() + switch tStr { + case "*service.Request": + targetService.requestIndex = i + case "*http.Request": + targetService.httpRequestIndex = i + case "*service.Response": + targetService.responseIndex = i + case "http.ResponseWriter": + targetService.responseWriterIndex = i + case "*log.Logger": + targetService.loggerIndex = i + default: + if t.Kind() == reflect.Struct || (t.Kind() == reflect.Map && t.Elem().Kind() == reflect.Interface) { + if targetService.inType == nil { + targetService.inIndex = i + targetService.inType = t + } else if targetService.headersType == nil { + targetService.headersIndex = i + targetService.headersType = t + } + } + } + } + + return targetService, nil +} + +// GetInject 获取注入对象 +func GetInject(dataType reflect.Type) any { + if obj, exists := injectObjects[dataType]; exists { + return obj + } + if factory, exists := injectFunctions[dataType]; exists { + return factory() + } + return nil +} diff --git a/service_test.go b/service_test.go new file mode 100644 index 0000000..1e4c400 --- /dev/null +++ b/service_test.go @@ -0,0 +1,54 @@ +package service + +import ( + "apigo.cc/go/log" + "testing" +) + +func TestServiceRegister(t *testing.T) { + handler := func(req *Request, logger *log.Logger) string { + return "ok" + } + + Register(0, "/test", handler, "test service") + + webServicesLock.RLock() + s := webServices["/test"] + webServicesLock.RUnlock() + + if s == nil { + t.Fatal("Service not registered") + } + + if s.requestIndex != 0 { + t.Errorf("requestIndex mismatch: expected 0, got %d", s.requestIndex) + } + if s.loggerIndex != 1 { + t.Errorf("loggerIndex mismatch: expected 1, got %d", s.loggerIndex) + } +} + +func TestRegexServiceRegister(t *testing.T) { + handler := func(args map[string]any) string { + return "ok" + } + + Register(0, "/user/{id}", handler, "get user") + + webServicesLock.RLock() + found := false + for _, s := range regexWebServices { + if s.path == "/user/{id}" { + found = true + if len(s.pathArgs) != 1 || s.pathArgs[0] != "id" { + t.Errorf("pathArgs mismatch: %v", s.pathArgs) + } + break + } + } + webServicesLock.RUnlock() + + if !found { + t.Fatal("Regex service not registered") + } +} diff --git a/starter.go b/starter.go new file mode 100644 index 0000000..970005f --- /dev/null +++ b/starter.go @@ -0,0 +1,49 @@ +package service + +import ( + "fmt" + "os" + "path/filepath" +) + +// StartCmd 命令行命令定义 +type StartCmd struct { + Name string + Comment string + Func func() +} + +var startCmds = []StartCmd{ + {"start", "Start server", Start}, +} + +// AddCmd 添加自定义命令行命令 +func AddCmd(name, comment string, function func()) { + startCmds = append(startCmds, StartCmd{name, comment, function}) +} + +// CheckCmd 检查并执行命令行命令 +func CheckCmd() { + if len(os.Args) > 1 { + cmd := os.Args[1] + if cmd == "help" || cmd == "--help" { + showHelp() + os.Exit(0) + } + + for _, cmdInfo := range startCmds { + if cmd == cmdInfo.Name { + cmdInfo.Func() + os.Exit(0) + } + } + } +} + +func showHelp() { + fmt.Printf("Usage: %s [command]\n\n", filepath.Base(os.Args[0])) + fmt.Println("Available commands:") + for _, cmdInfo := range startCmds { + fmt.Printf(" %-10s %s\n", cmdInfo.Name, cmdInfo.Comment) + } +} diff --git a/static.go b/static.go new file mode 100644 index 0000000..5d33d67 --- /dev/null +++ b/static.go @@ -0,0 +1,123 @@ +package service + +import ( + "apigo.cc/go/file" + "apigo.cc/go/log" + "mime" + "net/http" + "os" + "path/filepath" + "strings" + "sync" + "time" +) + +var ( + statics = make(map[string]*string) + staticsByHost = make(map[string]map[string]*string) + staticsByHostLock = sync.RWMutex{} +) + +// Static 注册静态文件目录 +func Static(path, rootPath string) { + StaticByHost(path, rootPath, "") +} + +// StaticByHost 为指定域名注册静态文件目录 +func StaticByHost(path, rootPath, host string) { + if !filepath.IsAbs(rootPath) { + if absPath, err := filepath.Abs(rootPath); err == nil { + rootPath = absPath + } + } + + staticsByHostLock.Lock() + defer staticsByHostLock.Unlock() + + if host == "" { + statics[path] = &rootPath + } else { + if staticsByHost[host] == nil { + staticsByHost[host] = make(map[string]*string) + } + staticsByHost[host][path] = &rootPath + } +} + +func getStaticFilePath(requestPath, host string) string { + staticsByHostLock.RLock() + defer staticsByHostLock.RUnlock() + + // 优先匹配指定域名的配置 + if hostConfig, exists := staticsByHost[host]; exists { + if filePath := findMatchedPath(hostConfig, requestPath); filePath != "" { + return filePath + } + } + + // 匹配全局配置 + return findMatchedPath(statics, requestPath) +} + +func findMatchedPath(config map[string]*string, requestPath string) string { + for urlPath, rootPath := range config { + if strings.HasPrefix(requestPath, urlPath) { + return filepath.Join(*rootPath, requestPath[len(urlPath):]) + } + } + return "" +} + +func processStatic(requestPath string, request *Request, response *Response, logger *log.Logger) bool { + filePath := getStaticFilePath(requestPath, request.Host) + if filePath == "" { + return false + } + + info, err := os.Stat(filePath) + if err != nil { + return false + } + + if info.IsDir() { + // 自动查找索引文件 + for _, indexFile := range Config.IndexFiles { + f := filepath.Join(filePath, indexFile) + if i, err := os.Stat(f); err == nil && !i.IsDir() { + filePath = f + info = i + break + } + } + } + + if info.IsDir() { + return false + } + + // 检查 304 + if ifModifiedSince := request.Header.Get("If-Modified-Since"); ifModifiedSince != "" { + if t, err := time.Parse(http.TimeFormat, ifModifiedSince); err == nil { + if !info.ModTime().Truncate(time.Second).After(t.Truncate(time.Second)) { + response.WriteHeader(http.StatusNotModified) + return true + } + } + } + + // 发送文件 + contentType := mime.TypeByExtension(filepath.Ext(filePath)) + if contentType == "" { + contentType = "application/octet-stream" + } + response.Header().Set("Content-Type", contentType) + response.Header().Set("Last-Modified", info.ModTime().UTC().Format(http.TimeFormat)) + + data, err := file.ReadBytes(filePath) + if err != nil { + return false + } + + _, _ = response.Write(data) + return true +} diff --git a/static_test.go b/static_test.go new file mode 100644 index 0000000..20a6056 --- /dev/null +++ b/static_test.go @@ -0,0 +1,43 @@ +package service + +import ( + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" +) + +func TestStaticService(t *testing.T) { + // 创建临时测试目录和文件 + tempDir, _ := os.MkdirTemp("", "static_test") + defer os.RemoveAll(tempDir) + + testFile := filepath.Join(tempDir, "index.html") + os.WriteFile(testFile, []byte("

Static Page

"), 0644) + + // 注册静态目录 + Static("/ui", tempDir) + + rh := &routeHandler{} + + // 测试成功访问 + req := httptest.NewRequest("GET", "/ui/index.html", nil) + w := httptest.NewRecorder() + rh.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected 200, got %d", w.Code) + } + if body := w.Body.String(); body != "

Static Page

" { + t.Errorf("Content mismatch: %s", body) + } + + // 测试 404 + req404 := httptest.NewRequest("GET", "/ui/notfound.html", nil) + w404 := httptest.NewRecorder() + rh.ServeHTTP(w404, req404) + if w404.Code != http.StatusNotFound { + t.Errorf("Expected 404 for missing file, got %d", w404.Code) + } +} diff --git a/types.go b/types.go new file mode 100644 index 0000000..5ce805f --- /dev/null +++ b/types.go @@ -0,0 +1,68 @@ +package service + +// Map 通用 Map 类型 +type Map = map[string]any + +// Arr 通用切片类型 +type Arr = []any + +// Argot 错误码/标识符类型 +type Argot string + +// Result 通用返回结构 +type Result struct { + Ok bool `json:"ok"` + Argot Argot `json:"argot,omitempty"` + Message string `json:"message,omitempty"` +} + +// CodeResult 带状态码的返回结构 +type CodeResult struct { + Code int `json:"code"` + Message string `json:"message,omitempty"` +} + +// ArgotInfo 标识符信息(用于文档生成) +type ArgotInfo struct { + Name Argot + Memo string +} + +// OK 设置成功状态 +func (r *Result) OK(argots ...Argot) { + r.Ok = true + if len(argots) > 0 { + r.Argot = argots[0] + } +} + +// Failed 设置失败状态 +func (r *Result) Failed(message string, argots ...Argot) { + r.Ok = false + r.Message = message + if len(argots) > 0 { + r.Argot = argots[0] + } +} + +// Done 根据布尔值设置状态 +func (r *Result) Done(ok bool, failedMessage string, argots ...Argot) { + r.Ok = ok + if !ok { + r.Message = failedMessage + if len(argots) > 0 { + r.Argot = argots[0] + } + } +} + +// OK 设置成功状态 (Code=1) +func (r *CodeResult) OK() { + r.Code = 1 +} + +// Failed 设置失败状态与错误码 +func (r *CodeResult) Failed(code int, message string) { + r.Code = code + r.Message = message +} diff --git a/types_test.go b/types_test.go new file mode 100644 index 0000000..65ba6ba --- /dev/null +++ b/types_test.go @@ -0,0 +1,41 @@ +package service + +import ( + "testing" +) + +func TestResult(t *testing.T) { + r := &Result{} + r.OK() + if !r.Ok { + t.Error("Result.OK() failed") + } + + r.Failed("error", Argot("ERR_CODE")) + if r.Ok || r.Message != "error" || r.Argot != "ERR_CODE" { + t.Error("Result.Failed() failed") + } + + r.Done(true, "never") + if !r.Ok { + t.Error("Result.Done(true) failed") + } + + r.Done(false, "failed", Argot("FAIL")) + if r.Ok || r.Message != "failed" || r.Argot != "FAIL" { + t.Error("Result.Done(false) failed") + } +} + +func TestCodeResult(t *testing.T) { + cr := &CodeResult{} + cr.OK() + if cr.Code != 1 { + t.Error("CodeResult.OK() failed") + } + + cr.Failed(500, "internal error") + if cr.Code != 500 || cr.Message != "internal error" { + t.Error("CodeResult.Failed() failed") + } +} diff --git a/utility.go b/utility.go new file mode 100644 index 0000000..293dcb0 --- /dev/null +++ b/utility.go @@ -0,0 +1,20 @@ +package service + +import ( + "apigo.cc/go/id" +) + +// MakeId 生成指定长度的 ID +func MakeId(size int) string { + return id.MakeID(size) +} + +// MakeIdForMysql 生成适用于 MySQL 的有序 ID +func MakeIdForMysql(size int) string { + return id.DefaultIDMaker.GetForMysql(size) +} + +// MakeIdForPostgreSQL 生成适用于 PostgreSQL 的有序 ID +func MakeIdForPostgreSQL(size int) string { + return id.DefaultIDMaker.GetForPostgreSQL(size) +} diff --git a/verify.go b/verify.go new file mode 100644 index 0000000..bce9bd0 --- /dev/null +++ b/verify.go @@ -0,0 +1,305 @@ +package service + +import ( + "apigo.cc/go/cast" + "apigo.cc/go/log" + "reflect" + "regexp" + "strings" + "sync" +) + +// VerifyType 校验类型 +type VerifyType uint8 + +const ( + VerifyUnknown VerifyType = iota + VerifyRegex + VerifyStringLength + VerifyGreaterThan + VerifyLessThan + VerifyBetween + VerifyInList + VerifyByFunc +) + +// VerifySet 校验规则集 +type VerifySet struct { + Type VerifyType + Regex *regexp.Regexp + StringArgs []string + IntArgs []int + FloatArgs []float64 + Func func(any, []string) bool +} + +var ( + verifySets = make(map[string]*VerifySet) + verifySetsLock = sync.RWMutex{} + verifyFunctions = make(map[string]func(any, []string) bool) + verifyFunctionsLock = sync.RWMutex{} +) + +// RegisterVerifyFunc 注册自定义校验函数 +func RegisterVerifyFunc(name string, f func(in any, args []string) bool) { + verifyFunctionsLock.Lock() + verifyFunctions[name] = f + verifyFunctionsLock.Unlock() +} + +// RegisterVerify 注册预定义校验规则 +func RegisterVerify(name, setting string) { + set, _ := compileVerifySet(setting) + if set != nil { + verifySetsLock.Lock() + verifySets[name] = set + verifySetsLock.Unlock() + } +} + +// VerifyStruct 校验结构体 +func VerifyStruct(in any, logger *log.Logger) (ok bool, field string) { + v := cast.RealValue(reflect.ValueOf(in)) + if v.Kind() != reflect.Struct { + if logger != nil { + logger.Error("verify input is not struct", "type", v.Type().String()) + } + return false, "" + } + + for i := 0; i < v.NumField(); i++ { + ft := v.Type().Field(i) + fv := v.Field(i) + + // 忽略空指针、空切片、空 Map + if (fv.Kind() == reflect.Ptr && fv.IsNil()) || + (fv.Kind() == reflect.Slice && fv.Len() == 0) || + (fv.Kind() == reflect.Map && fv.Len() == 0) { + continue + } + + if ft.Anonymous { + // 处理嵌套结构体(继承) + if fv.CanInterface() { + if ok, f := VerifyStruct(fv.Interface(), logger); !ok { + return false, f + } + } + continue + } + + tag := ft.Tag.Get("verify") + keyTag := ft.Tag.Get("verifyKey") + if tag != "" || keyTag != "" { + var err error + ok, f, err := _verifyValue(fv, tag, keyTag, logger) + if !ok { + if f == "" { + f = cast.GetLowerName(ft.Name) + } + if logger != nil { + if err != nil { + logger.Error(err.Error(), "field", f) + } else { + logger.Warning("verify failed", "field", f, "tag", tag) + } + } + return false, f + } + } + } + return true, "" +} + +func _verifyValue(in reflect.Value, setting, keySetting string, logger *log.Logger) (bool, string, error) { + t := in.Type() + // 处理切片 (非 byte 切片) + if t.Kind() == reflect.Slice && t.Elem().Kind() != reflect.Uint8 { + if setting != "" { + for i := 0; i < in.Len(); i++ { + if ok, f, err := _verifyValue(in.Index(i), setting, "", logger); !ok { + return false, f, err + } + } + } + return true, "", nil + } + + // 处理 Map + if t.Kind() == reflect.Map { + for _, k := range in.MapKeys() { + if keySetting != "" { + if ok, _, err := _verifyValue(k, keySetting, "", logger); !ok { + return false, "key", err + } + } + if setting != "" { + if ok, f, err := _verifyValue(in.MapIndex(k), setting, "", logger); !ok { + return false, f, err + } + } + } + return true, "", nil + } + + // 处理嵌套 Struct + if t.Kind() == reflect.Struct { + ok, f := VerifyStruct(in.Interface(), logger) + return ok, f, nil + } + + // 基础校验 + if setting == "" { + return true, "", nil + } + + ok, err := verify(in.Interface(), setting) + return ok, "", err +} + +func verify(in any, setting string) (bool, error) { + if len(setting) < 2 { + return false, nil + } + + verifySetsLock.RLock() + set, exists := verifySets[setting] + verifySetsLock.RUnlock() + + if !exists { + var err error + set, err = compileVerifySet(setting) + if err != nil { + return false, err + } + verifySetsLock.Lock() + verifySets[setting] = set + verifySetsLock.Unlock() + } + + switch set.Type { + case VerifyByFunc: + return set.Func(in, set.StringArgs), nil + case VerifyRegex: + return set.Regex.MatchString(cast.String(in)), nil + case VerifyStringLength: + l := len(cast.String(in)) + if len(set.StringArgs) > 0 { + if set.StringArgs[0] == "+" { + return l >= set.IntArgs[0], nil + } else if set.StringArgs[0] == "-" { + return l <= set.IntArgs[0], nil + } + } + if len(set.IntArgs) > 1 { + return l >= set.IntArgs[0] && l <= set.IntArgs[1], nil + } + return l == set.IntArgs[0], nil + case VerifyGreaterThan: + return cast.Float64(in) > set.FloatArgs[0], nil + case VerifyLessThan: + return cast.Float64(in) < set.FloatArgs[0], nil + case VerifyBetween: + val := cast.Float64(in) + return val >= set.FloatArgs[0] && val <= set.FloatArgs[1], nil + case VerifyInList: + s := cast.String(in) + for _, item := range set.StringArgs { + if item == s { + return true, nil + } + } + return false, nil + } + return false, nil +} + +func compileVerifySet(setting string) (*VerifySet, error) { + set := &VerifySet{Type: VerifyUnknown} + if setting == "" { + return set, nil + } + + if setting[0] != '^' { + key := setting + args := "" + if pos := strings.IndexByte(setting, ':'); pos != -1 { + key = setting[:pos] + args = setting[pos+1:] + } + + // 优先查找自定义函数 + verifyFunctionsLock.RLock() + f, exists := verifyFunctions[key] + verifyFunctionsLock.RUnlock() + if exists { + set.Type = VerifyByFunc + set.Func = f + if args != "" { + set.StringArgs = strings.Split(args, ",") + } + return set, nil + } + + // 内置规则 + switch key { + case "length": + set.Type = VerifyStringLength + if args == "" { + args = "1+" + } + last := args[len(args)-1] + if last == '+' || last == '-' { + set.StringArgs = []string{string(last)} + args = args[:len(args)-1] + } + // 同时支持逗号和中划线 + sep := "," + if strings.Contains(args, "-") && !strings.Contains(args, ",") { + sep = "-" + } + if strings.Contains(args, sep) { + a := strings.Split(args, sep) + set.IntArgs = []int{cast.Int(a[0]), cast.Int(a[1])} + } else { + set.IntArgs = []int{cast.Int(args)} + } + return set, nil + case "between": + set.Type = VerifyBetween + if args == "" { + args = "1-100000000" + } + a := strings.Split(args, "-") + if len(a) == 1 { + set.FloatArgs = []float64{0, cast.Float64(a[0])} + } else { + set.FloatArgs = []float64{cast.Float64(a[0]), cast.Float64(a[1])} + } + return set, nil + case "gt": + set.Type = VerifyGreaterThan + set.FloatArgs = []float64{cast.Float64(args)} + return set, nil + case "lt": + set.Type = VerifyLessThan + set.FloatArgs = []float64{cast.Float64(args)} + return set, nil + case "in": + set.Type = VerifyInList + if args != "" { + set.StringArgs = strings.Split(args, ",") + } + return set, nil + } + } + + // 默认视为正则表达式 + rx, err := regexp.Compile(setting) + if err != nil { + return nil, err + } + set.Type = VerifyRegex + set.Regex = rx + return set, nil +} diff --git a/verify_test.go b/verify_test.go new file mode 100644 index 0000000..b4f5d17 --- /dev/null +++ b/verify_test.go @@ -0,0 +1,76 @@ +package service + +import ( + "testing" +) + +type TestUser struct { + Name string `verify:"length:2-10"` + Age int `verify:"between:18-100"` + Type string `verify:"in:admin,user,guest"` +} + +type NestedStruct struct { + TestUser + Note string `verify:"^.{1,20}$"` +} + +func TestVerifyStruct(t *testing.T) { + u := TestUser{Name: "Star", Age: 25, Type: "admin"} + if ok, f := VerifyStruct(u, nil); !ok { + t.Errorf("VerifyStruct failed on valid user, field: %s", f) + } + + u.Name = "S" + if ok, f := VerifyStruct(u, nil); ok || f != "name" { + t.Errorf("VerifyStruct should fail on short name, got ok=%v, field=%s", ok, f) + } + + u.Name = "Star" + u.Age = 10 + if ok, f := VerifyStruct(u, nil); ok || f != "age" { + t.Errorf("VerifyStruct should fail on young age, got ok=%v, field=%s", ok, f) + } + + u.Age = 25 + u.Type = "invalid" + if ok, f := VerifyStruct(u, nil); ok || f != "type" { + t.Errorf("VerifyStruct should fail on invalid type, got ok=%v, field=%s", ok, f) + } +} + +func TestNestedVerify(t *testing.T) { + n := NestedStruct{ + TestUser: TestUser{Name: "Star", Age: 25, Type: "user"}, + Note: "Hello", + } + if ok, f := VerifyStruct(n, nil); !ok { + t.Errorf("Nested VerifyStruct failed on valid data, field: %s", f) + } + + n.TestUser.Age = 5 + if ok, f := VerifyStruct(n, nil); ok || f != "age" { + t.Errorf("Nested VerifyStruct should fail on nested age, got ok=%v, field=%s", ok, f) + } +} + +func TestCustomVerify(t *testing.T) { + RegisterVerifyFunc("odd", func(in any, args []string) bool { + val := in.(int) + return val%2 != 0 + }) + + type OddStruct struct { + Num int `verify:"odd"` + } + + o := OddStruct{Num: 3} + if ok, f := VerifyStruct(o, nil); !ok { + t.Errorf("Custom verify failed on odd number, field: %s", f) + } + + o.Num = 4 + if ok, f := VerifyStruct(o, nil); ok || f != "num" { + t.Errorf("Custom verify should fail on even number, got ok=%v, field=%s", ok, f) + } +} diff --git a/websocket.go b/websocket.go new file mode 100644 index 0000000..323eb5d --- /dev/null +++ b/websocket.go @@ -0,0 +1,175 @@ +package service + +import ( + "apigo.cc/go/cast" + "apigo.cc/go/log" + "github.com/gorilla/websocket" + "net/http" + "reflect" + "regexp" +) + +// websocketServiceType WebSocket 服务元数据 +type websocketServiceType struct { + authLevel int + path string + pathMatcher *regexp.Regexp + pathArgs []string + updater *websocket.Upgrader + openFuncValue reflect.Value + openFuncType reflect.Type + closeFuncValue reflect.Value + closeFuncType reflect.Type + sessionType reflect.Type + actions map[string]*websocketActionType + isSimple bool + options WebServiceOptions + memo string +} + +// websocketActionType WebSocket Action 元数据 +type websocketActionType struct { + authLevel int + funcValue reflect.Value + funcType reflect.Type + inType reflect.Type + memo string +} + +// ActionRegister WebSocket Action 注册器 +type ActionRegister struct { + ws *websocketServiceType +} + +// RegisterWebsocket 注册 WebSocket 服务 +func RegisterWebsocket(authLevel int, path string, onOpen, onClose any, memo string) *ActionRegister { + s := &websocketServiceType{ + authLevel: authLevel, + path: path, + memo: memo, + updater: &websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}, + actions: make(map[string]*websocketActionType), + } + + if onOpen != nil { + s.openFuncValue = reflect.ValueOf(onOpen) + s.openFuncType = s.openFuncValue.Type() + if s.openFuncType.NumOut() > 0 { + s.sessionType = s.openFuncType.Out(0) + } + } + + if onClose != nil { + s.closeFuncValue = reflect.ValueOf(onClose) + s.closeFuncType = s.closeFuncValue.Type() + } + + websocketServicesLock.Lock() + websocketServices[path] = s + websocketServicesLock.Unlock() + + return &ActionRegister{ws: s} +} + +// RegisterAction 注册 WebSocket Action +func (ar *ActionRegister) RegisterAction(authLevel int, name string, action any, memo string) { + v := reflect.ValueOf(action) + t := v.Type() + a := &websocketActionType{ + authLevel: authLevel, + funcValue: v, + funcType: t, + memo: memo, + } + + // 查找输入参数类型 + for i := 0; i < t.NumIn(); i++ { + inT := t.In(i) + if inT.Kind() == reflect.Struct { + a.inType = inT + break + } + } + + ar.ws.actions[name] = a +} + +func doWebsocketService(ws *websocketServiceType, request *Request, response *Response, logger *log.Logger) { + conn, err := ws.updater.Upgrade(response.Writer, request.Request, nil) + if err != nil { + logger.Error("websocket upgrade failed", "error", err.Error()) + return + } + defer conn.Close() + + var session any + if ws.openFuncValue.IsValid() { + // 简化版:仅支持基础参数注入 + params := make([]reflect.Value, ws.openFuncType.NumIn()) + for i := 0; i < len(params); i++ { + t := ws.openFuncType.In(i) + if t == reflect.TypeOf(request) { + params[i] = reflect.ValueOf(request) + } else if t == reflect.TypeOf(logger) { + params[i] = reflect.ValueOf(logger) + } else { + params[i] = reflect.New(t).Elem() + } + } + outs := ws.openFuncValue.Call(params) + if len(outs) > 0 { + session = outs[0].Interface() + } + } + + for { + var msg Map + if err := conn.ReadJSON(&msg); err != nil { + break + } + + actionName := cast.String(msg["action"]) + action := ws.actions[actionName] + if action == nil { + action = ws.actions[""] // 默认 action + } + + if action != nil { + params := make([]reflect.Value, action.funcType.NumIn()) + for i := 0; i < len(params); i++ { + t := action.funcType.In(i) + if t == ws.sessionType { + params[i] = reflect.ValueOf(session) + } else if t == reflect.TypeOf(conn) { + params[i] = reflect.ValueOf(conn) + } else if t.Kind() == reflect.Struct { + in := reflect.New(t).Interface() + cast.Convert(in, msg) + params[i] = reflect.ValueOf(in).Elem() + } else { + params[i] = reflect.New(t).Elem() + } + } + outs := action.funcValue.Call(params) + if len(outs) > 0 { + result := outs[0].Interface() + if result != nil { + _ = conn.WriteJSON(result) + } + } + } + } + + if ws.closeFuncValue.IsValid() { + params := make([]reflect.Value, ws.closeFuncType.NumIn()) + for i := 0; i < len(params); i++ { + t := ws.closeFuncType.In(i) + if t == ws.sessionType { + params[i] = reflect.ValueOf(session) + } else { + params[i] = reflect.New(t).Elem() + } + } + ws.closeFuncValue.Call(params) + } +} diff --git a/websocket_test.go b/websocket_test.go new file mode 100644 index 0000000..8689ff0 --- /dev/null +++ b/websocket_test.go @@ -0,0 +1,44 @@ +package service + +import ( + "github.com/gorilla/websocket" + "net/http/httptest" + "strings" + "testing" +) + +func TestWebSocketService(t *testing.T) { + // 注册 WebSocket 服务 + ar := RegisterWebsocket(0, "/ws", nil, nil, "test websocket") + ar.RegisterAction(0, "echo", func(in struct{ Msg string }) Map { + return Map{"action": "echo", "reply": in.Msg} + }, "echo action") + + // 启动测试服务器 + server := httptest.NewServer(&routeHandler{}) + defer server.Close() + + // 建立连接 + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/ws" + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + if err != nil { + t.Fatalf("Dial failed: %v", err) + } + defer conn.Close() + + // 发送消息 + msg := Map{"action": "echo", "msg": "hello"} + if err := conn.WriteJSON(msg); err != nil { + t.Fatalf("WriteJSON failed: %v", err) + } + + // 接收响应 + var reply Map + if err := conn.ReadJSON(&reply); err != nil { + t.Fatalf("ReadJSON failed: %v", err) + } + + if reply["reply"] != "hello" { + t.Errorf("Reply mismatch: %v", reply) + } +}