commit bdb104aa2f05d1fa5bf2a474dee8cb58e874d608 Author: AI Engineer Date: Fri May 8 07:27:06 2026 +0800 Migrate service module from ssgo/s with modern Go features (by AI) diff --git a/.log.meta.json b/.log.meta.json new file mode 100644 index 0000000..9190980 --- /dev/null +++ b/.log.meta.json @@ -0,0 +1,267 @@ +{ + "debug": [ + { + "index": 0, + "name": "LogName", + "color": "cyan", + "hide": true + }, + { + "index": 1, + "name": "LogType", + "color": "magenta", + "hide": true + }, + { + "index": 2, + "name": "LogTime", + "format": "time" + }, + { + "index": 3, + "name": "TraceId", + "color": "blue" + }, + { + "index": 4, + "name": "Image", + "color": "darkGray", + "hide": true + }, + { + "index": 5, + "name": "Server", + "color": "darkGray", + "hide": true + }, + { + "index": 6, + "name": "Debug", + "withoutKey": true + }, + { + "index": 7, + "name": "Extra" + } + ], + "discover": [ + { + "index": 0, + "name": "LogName", + "color": "cyan", + "hide": true + }, + { + "index": 1, + "name": "LogType", + "color": "magenta", + "hide": true + }, + { + "index": 2, + "name": "LogTime", + "format": "time" + }, + { + "index": 3, + "name": "TraceId", + "color": "blue" + }, + { + "index": 4, + "name": "Image", + "color": "darkGray", + "hide": true + }, + { + "index": 5, + "name": "Server", + "color": "darkGray", + "hide": true + }, + { + "index": 6, + "name": "App", + "color": "cyan" + }, + { + "index": 7, + "name": "Method", + "color": "magenta" + }, + { + "index": 8, + "name": "Path", + "color": "blue" + }, + { + "index": 9, + "name": "Node", + "color": "yellow" + }, + { + "index": 10, + "name": "Attempts" + }, + { + "index": 11, + "name": "UsedTime", + "format": "%.2fms" + }, + { + "index": 12, + "name": "Error", + "color": "red" + }, + { + "index": 13, + "name": "Extra" + } + ], + "error": [ + { + "index": 0, + "name": "LogName", + "color": "cyan", + "hide": true + }, + { + "index": 1, + "name": "LogType", + "color": "magenta", + "hide": true + }, + { + "index": 2, + "name": "LogTime", + "format": "time" + }, + { + "index": 3, + "name": "TraceId", + "color": "blue" + }, + { + "index": 4, + "name": "Image", + "color": "darkGray", + "hide": true + }, + { + "index": 5, + "name": "Server", + "color": "darkGray", + "hide": true + }, + { + "index": 6, + "name": "Error", + "color": "red", + "withoutKey": true + }, + { + "index": 7, + "name": "CallStacks" + }, + { + "index": 8, + "name": "Extra" + } + ], + "info": [ + { + "index": 0, + "name": "LogName", + "color": "cyan", + "hide": true + }, + { + "index": 1, + "name": "LogType", + "color": "magenta", + "hide": true + }, + { + "index": 2, + "name": "LogTime", + "format": "time" + }, + { + "index": 3, + "name": "TraceId", + "color": "blue" + }, + { + "index": 4, + "name": "Image", + "color": "darkGray", + "hide": true + }, + { + "index": 5, + "name": "Server", + "color": "darkGray", + "hide": true + }, + { + "index": 6, + "name": "Info", + "color": "cyan", + "withoutKey": true + }, + { + "index": 7, + "name": "Extra" + } + ], + "warning": [ + { + "index": 0, + "name": "LogName", + "color": "cyan", + "hide": true + }, + { + "index": 1, + "name": "LogType", + "color": "magenta", + "hide": true + }, + { + "index": 2, + "name": "LogTime", + "format": "time" + }, + { + "index": 3, + "name": "TraceId", + "color": "blue" + }, + { + "index": 4, + "name": "Image", + "color": "darkGray", + "hide": true + }, + { + "index": 5, + "name": "Server", + "color": "darkGray", + "hide": true + }, + { + "index": 6, + "name": "Warning", + "color": "yellow", + "withoutKey": true + }, + { + "index": 7, + "name": "CallStacks" + }, + { + "index": 8, + "name": "Extra" + } + ] +} diff --git a/DocTpl.html b/DocTpl.html new file mode 100644 index 0000000..4d655ae --- /dev/null +++ b/DocTpl.html @@ -0,0 +1,216 @@ + + + + + {{.title}} + + + + +
+ {{range .api}} + +
+
+ {{.Path}} + {{.Memo}} + {{if ne .Method ""}}{{end}} + + {{if ne .Type "Web"}}{{end}} +
+
+ + {{if isMap .In}} + + + + {{range $k, $v := .In}} + + + + + {{end}} + {{else}} + + + + {{end}} +
Request
{{$k}}{{toText $v}}
{{.In}}
+ + {{if isMap .Out}} + + + + {{range $k, $v := .Out}} + + + + + {{end}} + {{else}} + + + + {{end}} +
Response
{{$k}}{{toText $v}}
{{.Out}}
+
+ {{else}} +
no document
+ {{end}} +
+
+ + diff --git a/README.md b/README.md new file mode 100644 index 0000000..bf588e2 --- /dev/null +++ b/README.md @@ -0,0 +1,56 @@ +# go/service (核心微服务框架) + +极简、自动化的 Web 与 WebSocket 服务框架,实现极致的依赖注入与路由映射。 + +## 核心特性 +- **路由反射**: 自动解析函数参数,支持 `*Request`, `*Response`, `*log.Logger` 及自定义结构体自动注入。 +- **自动校验**: 集成 `verify` 引擎,通过 Struct Tag 实现入参合法性自动检查。 +- **功能闭环**: 内置静态文件服务、WebSocket (带 Action 路由)、URL 重写、反向代理(对接 Discover)。 +- **零摩擦启动**: 支持命令行指令管理 (start/stop/help) 及异步平滑启停。 + +## API 指南 + +### 1. 服务注册 +```go +import "apigo.cc/go/service" + +// 注册标准 Web 服务 +service.Register(0, "/hello", func(in struct{ Name string }) string { + return "Hello " + in.Name +}, "打招呼接口") + +// 注册 Restful 服务 +service.Restful(0, "POST", "/user/{id}", func(args map[string]any) service.Result { + res := service.Result{} + res.OK() + return res +}, "更新用户") +``` + +### 2. WebSocket 支持 +```go +ar := service.RegisterWebsocket(0, "/ws", onOpen, onClose, "聊天室") +ar.RegisterAction(0, "chat", func(in ChatMessage, sess *MySession) { + // 处理消息 +}, "发送消息") +``` + +### 3. 增强插件 +- **静态文件**: `service.Static("/ui", "./static_dir")` +- **URL 重写**: `service.Rewrite("/old", "/new")` +- **反向代理**: `service.Proxy(0, "/api", "other_app", "/api")` + +### 4. 生命周期管理 +```go +func main() { + service.CheckCmd() // 处理 start/stop/help 指令 + service.Start() // 阻塞启动 +} +``` + +## 基础设施对齐 +- **类型转换**: `apigo.cc/go/cast` +- **日志系统**: `apigo.cc/go/log` +- **服务发现**: `apigo.cc/go/discover` +- **分布式 ID**: `apigo.cc/go/id` +- **文件操作**: `apigo.cc/go/file` diff --git a/config.go b/config.go new file mode 100644 index 0000000..2b6812b --- /dev/null +++ b/config.go @@ -0,0 +1,62 @@ +package service + +// CertSet SSL 证书配置 +type CertSet struct { + CertFile string + KeyFile string +} + +// ServiceConfig 核心服务配置 +type ServiceConfig struct { + Listen string // 监听端口(|隔开多个监听)(,隔开多个选项),例如 80,http|443|443:h2|127.0.0.1:8080,h2c + SSL map[string]*CertSet // SSL 证书配置,key 为域名 + NoLogGets bool // 不记录 GET 请求的日志 + NoLogHeaders string // 不记录请求头中包含的这些字段,多个字段用逗号分隔 + LogInputArrayNum int // 请求字段中容器类型在日志打印个数限制 + LogInputFieldSize int // 请求字段中单个字段在日志打印长度限制 + NoLogOutputFields string // 不记录响应字段中包含的这些字段 + LogOutputArrayNum int // 响应字段中容器类型在日志打印个数限制 + LogOutputFieldSize int // 响应字段中单个字段在日志打印长度限制 + LogWebsocketAction bool // 记录 Websocket 中每个 Action 的请求日志 + Compress bool // 是否启用压缩 + CompressMinSize int // 启用压缩的最小长度 + CompressMaxSize int // 启用压缩的最大长度 + CheckDomain string // 心跳检测时使用域名 + AccessTokens map[string]*int // 指定 Access-Token 验证及其对应的 auth-level + RedirectTimeout int // Proxy 和 Discover 发起请求时的超时时间 (ms) + AcceptXRealIpWithoutRequestId bool // 是否允许头部没有携带请求ID的 X-Real-IP 信息 + StatisticTime bool // 是否开启请求时间统计 + StatisticTimeInterval int // 统计时间间隔 (ms) + Fast bool // 是否启用快速模式 + MaxUploadSize int64 // 最大上传文件大小 (Bytes) + IpPrefix string // Discover 服务发现时指定使用的 IP 网段 + Cpu int // CPU 占用的核数限制 + Memory int // 内存限制 (MB) + CpuMonitor bool // 记录 CPU 使用情况 + MemoryMonitor bool // 记录内存使用情况 + CpuLimitValue uint // CPU 自动重启阈值 (10-100) + MemoryLimitValue uint // 内存自动重启阈值 (10-100) + CpuLimitTimes uint // CPU 报警阈值连续次数 + MemoryLimitTimes uint // 内存报警阈值连续次数 + CookieScope string // Session Cookie 有效范围: host|domain|topDomain + SessionWithoutCookie bool // Session 禁用 Cookie + DeviceWithoutCookie bool // 设备ID禁用 Cookie + IdServer string // Redis 服务器连接 (用于全局唯一 ID 生成) + KeepKeyCase bool // 是否保持 Key 的首字母大小写 + IndexFiles []string // 静态文件索引文件 + IndexDir bool // 访问目录时显示文件列表 + ReadTimeout int // 读取请求的超时时间 (ms) + ReadHeaderTimeout int // 读取请求头的超时时间 (ms) + WriteTimeout int // 响应写入的超时时间 (ms) + IdleTimeout int // 连接空闲超时时间 (ms) + MaxHeaderBytes int // 请求头的最大字节数 + MaxHandlers int // 每个连接的最大处理程序数量 + MaxConcurrentStreams uint32 // 每个连接的最大并发流数量 + MaxDecoderHeaderTableSize uint32 // 解码器头表的最大大小 + MaxEncoderHeaderTableSize uint32 // 编码器头表的最大大小 + MaxReadFrameSize uint32 // 单个帧的最大读取大小 + MaxUploadBufferPerConnection int32 // 每个连接的最大上传缓冲区大小 + MaxUploadBufferPerStream int32 // 每个流的最大上传缓冲区大小 +} + +var Config = ServiceConfig{} diff --git a/document.go b/document.go new file mode 100644 index 0000000..428148c --- /dev/null +++ b/document.go @@ -0,0 +1,162 @@ +package service + +import ( + "apigo.cc/go/cast" + _ "embed" + "encoding/json" + "reflect" +) + +// Api 接口文档信息 +type Api struct { + Type string + Path string + AuthLevel int + Method string + In any + Out any + Memo string +} + +//go:embed DocTpl.html +var defaultDocTpl string + +// MakeDocument 生成文档数据 +func MakeDocument() []Api { + out := make([]Api, 0) + + // 1. Rewrite + rewritesLock.RLock() + for _, a := range rewrites { + out = append(out, Api{ + Type: "Rewrite", + Path: a.fromPath + " -> " + a.toPath, + }) + } + rewritesLock.RUnlock() + + // 2. Proxy + proxiesLock.RLock() + for _, a := range proxies { + out = append(out, Api{ + Type: "Proxy", + Path: a.fromPath + " -> " + a.toApp + ":" + a.toPath, + }) + } + proxiesLock.RUnlock() + + // 3. Web Services + webServicesLock.RLock() + for _, a := range webServicesList { + if a.options.NoDoc { + continue + } + api := Api{ + Type: "Web", + Path: a.path, + AuthLevel: a.authLevel, + Method: a.method, + Memo: a.memo, + } + if a.inType != nil { + api.In = getType(a.inType) + } + if a.funcType.NumOut() > 0 { + api.Out = getType(a.funcType.Out(0)) + } + out = append(out, api) + } + webServicesLock.RUnlock() + + // 4. WebSocket Services + websocketServicesLock.RLock() + for _, a := range websocketServices { + api := Api{ + Type: "WebSocket", + Path: a.path, + AuthLevel: a.authLevel, + Memo: a.memo, + } + if a.openFuncType != nil && a.openFuncType.NumIn() > 0 { + // Find struct in + for i := 0; i < a.openFuncType.NumIn(); i++ { + t := a.openFuncType.In(i) + if t.Kind() == reflect.Struct { + api.In = getType(t) + break + } + } + } + out = append(out, api) + + for name, act := range a.actions { + actionApi := Api{ + Type: "Action", + Path: name, + AuthLevel: act.authLevel, + Memo: act.memo, + } + if act.inType != nil { + actionApi.In = getType(act.inType) + } + if act.funcType.NumOut() > 0 { + actionApi.Out = getType(act.funcType.Out(0)) + } + out = append(out, actionApi) + } + } + websocketServicesLock.RUnlock() + + return out +} + +// MakeJsonDocument 生成 JSON 格式文档 +func MakeJsonDocument() string { + apis := MakeDocument() + data, _ := json.MarshalIndent(map[string]any{ + "api": apis, + }, "", "\t") + return string(data) +} + +func getType(t reflect.Type) any { + if t == nil { + return "" + } + for t.Kind() == reflect.Ptr { + t = t.Elem() + } + + switch t.Kind() { + case reflect.Struct: + outs := Map{} + for i := 0; i < t.NumField(); i++ { + f := t.Field(i) + if f.Anonymous { + if subMap, ok := getType(f.Type).(Map); ok { + for k, v := range subMap { + outs[k] = v + } + } + } else { + outs[cast.GetLowerName(f.Name)] = getType(f.Type) + } + } + return outs + case reflect.Map: + return map[string]any{t.Key().String(): getType(t.Elem())} + case reflect.Slice: + return []any{getType(t.Elem())} + case reflect.Interface: + return "Any" + default: + return t.String() + } +} + +// 自动注册文档服务 +func init() { + Register(0, "/__DOC__", func() string { + return MakeJsonDocument() + }, "API Document") +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..479da22 --- /dev/null +++ b/go.mod @@ -0,0 +1,5 @@ +module apigo.cc/go/service + +go 1.25.0 + +require github.com/gorilla/websocket v1.5.3 diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..25a9fc4 --- /dev/null +++ b/go.sum @@ -0,0 +1,2 @@ +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= diff --git a/handler.go b/handler.go new file mode 100644 index 0000000..181fa08 --- /dev/null +++ b/handler.go @@ -0,0 +1,319 @@ +package service + +import ( + "apigo.cc/go/cast" + "apigo.cc/go/id" + "apigo.cc/go/log" + "apigo.cc/go/standard" + "encoding/json" + "io" + "net/http" + "reflect" + "strings" + "sync/atomic" + "time" +) + +type routeHandler struct { + webRequestingNum int64 +} + +func (rh *routeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + atomic.AddInt64(&rh.webRequestingNum, 1) + defer atomic.AddInt64(&rh.webRequestingNum, -1) + + startTime := time.Now() + requestId := r.Header.Get(standard.DiscoverHeaderRequestId) + if requestId == "" { + requestId = id.MakeID(12) + r.Header.Set(standard.DiscoverHeaderRequestId, requestId) + } + + request := NewRequest(r) + request.Id = requestId + response := NewResponse(w) + response.Id = requestId + defer response.checkWriteHeader() + + // 处理 SessionId 和 DeviceId + handleClientKeys(request, response) + + requestLogger := log.New(requestId) + + // 0. 处理重写 (Rewrite) + if processRewrite(request, response, requestLogger) { + return + } + + // 处理代理 (Proxy) + if processProxy(request, response, requestLogger) { + return + } + + // 1. 路由匹配 + path := r.URL.Path + host := r.Host + + // 处理静态文件 + if processStatic(path, request, response, requestLogger) { + return + } + + s, ws := findService(r.Method, host, path) + + // 2. 参数解析 (Form & Body) + args := make(map[string]any) + parseRequestArgs(request, args) + + // 3. 前置过滤器 + var result any + for _, filter := range inFilters { + result = filter(&args, request, response, requestLogger) + if result != nil { + break + } + } + + // 4. 处理业务执行 (WS 或 Web) + if result == nil { + if ws != nil { + doWebsocketService(ws, request, response, requestLogger) + return + } else if s != nil { + // 鉴权 + pass, obj := checkAuth(s, request, response, args, requestLogger) + if !pass { + if !response.changed { + response.WriteHeader(http.StatusForbidden) + } + return + } + // 执行业务 + result = doWebService(s, request, response, args, nil, requestLogger, obj) + } + } + + if s == nil && result == nil { + response.WriteHeader(http.StatusNotFound) + return + } + + // 5. 后置过滤器 + for _, filter := range outFilters { + newResult, done := filter(args, request, response, result, requestLogger) + if newResult != nil { + result = newResult + } + if done { + break + } + } + + // 6. 输出结果 + outputResult(response, result) + + // 7. 记录日志 + _ = startTime +} + +func findService(method, host, path string) (*webServiceType, *websocketServiceType) { + webServicesLock.RLock() + defer webServicesLock.RUnlock() + + // 1. Web Service 匹配 + if s, exists := webServices[method+path]; exists { + return s, nil + } + if s, exists := webServices[path]; exists { + return s, nil + } + + // 2. WebSocket 匹配 + websocketServicesLock.RLock() + defer websocketServicesLock.RUnlock() + if ws, exists := websocketServices[path]; exists { + return nil, ws + } + + // 3. 正则匹配 + for i := len(regexWebServices) - 1; i >= 0; i-- { + s := regexWebServices[i] + if s.method != "" && s.method != method { + continue + } + if s.pathMatcher != nil && s.pathMatcher.MatchString(path) { + return s, nil + } + } + + return nil, nil +} + +func parseRequestArgs(request *Request, args map[string]any) { + // Query params + query := request.URL.Query() + for k, v := range query { + if len(v) == 1 { + args[k] = v[0] + } else { + args[k] = v + } + } + + // Form params + if request.Method == http.MethodPost || request.Method == http.MethodPut { + contentType := request.Header.Get("Content-Type") + if strings.HasPrefix(contentType, "application/json") { + body, _ := io.ReadAll(request.Body) + _ = request.Body.Close() + if len(body) > 0 { + _ = json.Unmarshal(body, &args) + } + } else { + _ = request.ParseForm() + for k, v := range request.Form { + if len(v) == 1 { + args[k] = v[0] + } else { + args[k] = v + } + } + } + } +} + +func checkAuth(s *webServiceType, request *Request, response *Response, args map[string]any, logger *log.Logger) (bool, any) { + ac := webAuthCheckers[s.authLevel] + if ac == nil { + ac = webAuthChecker + } + if ac == nil { + return true, nil + } + return ac(s.authLevel, logger, &request.RequestURI, args, request, response, &s.options) +} + +func doWebService(service *webServiceType, request *Request, response *Response, args map[string]any, + result any, logger *log.Logger, object any) any { + if result != nil { + return result + } + + params := make([]reflect.Value, service.parmsNum) + for i := 0; i < service.parmsNum; i++ { + t := service.funcType.In(i) + switch i { + case service.requestIndex: + params[i] = reflect.ValueOf(request) + case service.httpRequestIndex: + params[i] = reflect.ValueOf(request.Request) + case service.responseIndex: + params[i] = reflect.ValueOf(response) + case service.responseWriterIndex: + params[i] = reflect.ValueOf(response.Writer) + case service.loggerIndex: + params[i] = reflect.ValueOf(logger) + case service.inIndex: + in := reflect.New(service.inType).Interface() + cast.Convert(in, args) + // 参数校验 + if service.inType.Kind() == reflect.Struct { + if ok, _ := VerifyStruct(in, logger); !ok { + response.WriteHeader(http.StatusBadRequest) + return "parameter verification failed" + } + } + params[i] = reflect.ValueOf(in).Elem() + default: + // 尝试依赖注入 + if obj := GetInject(t); obj != nil { + params[i] = reflect.ValueOf(obj) + } else { + params[i] = reflect.New(t).Elem() + } + } + } + + outs := service.funcValue.Call(params) + if len(outs) > 0 { + return outs[0].Interface() + } + return "" +} + +func outputResult(response *Response, result any) { + if result == nil { + return + } + + var data []byte + contentType := "" + + switch v := result.(type) { + case string: + data = []byte(v) + case []byte: + data = v + default: + data, _ = cast.ToJSONBytes(result) + contentType = "application/json; charset=UTF-8" + } + + if contentType != "" && response.Header().Get("Content-Type") == "" { + response.Header().Set("Content-Type", contentType) + } + _, _ = response.Write(data) +} + +func handleClientKeys(request *Request, response *Response) { + // SessionId + if usedSessionIdKey != "" { + sessionId := request.Header.Get(usedSessionIdKey) + if sessionId == "" && !Config.SessionWithoutCookie { + if ck, err := request.Cookie(usedSessionIdKey); err == nil { + sessionId = ck.Value + } + } + if sessionId == "" { + if sessionIdMaker != nil { + sessionId = sessionIdMaker() + } else { + sessionId = id.MakeID(14) + } + if !Config.SessionWithoutCookie { + http.SetCookie(response.Writer, &http.Cookie{ + Name: usedSessionIdKey, + Value: sessionId, + Path: "/", + HttpOnly: true, + }) + } + } + request.Header.Set(standard.DiscoverHeaderSessionId, sessionId) + response.Header().Set(usedSessionIdKey, sessionId) + } + + // DeviceId + if usedDeviceIdKey != "" { + deviceId := request.Header.Get(usedDeviceIdKey) + if deviceId == "" && !Config.DeviceWithoutCookie { + if ck, err := request.Cookie(usedDeviceIdKey); err == nil { + deviceId = ck.Value + } + } + if deviceId == "" { + deviceId = id.MakeID(14) + if !Config.DeviceWithoutCookie { + http.SetCookie(response.Writer, &http.Cookie{ + Name: usedDeviceIdKey, + Value: deviceId, + Path: "/", + Expires: time.Now().AddDate(10, 0, 0), + HttpOnly: true, + }) + } + } + request.Header.Set(standard.DiscoverHeaderDeviceId, deviceId) + response.Header().Set(usedDeviceIdKey, deviceId) + } +} diff --git a/handler_test.go b/handler_test.go new file mode 100644 index 0000000..a5b03f3 --- /dev/null +++ b/handler_test.go @@ -0,0 +1,67 @@ +package service + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestServeHTTP(t *testing.T) { + // 注册服务 + handler := func(in struct{ Name string }) string { + return "Hello " + in.Name + } + Register(0, "/hello", handler, "say hello") + + rh := &routeHandler{} + + // 模拟请求 + req := httptest.NewRequest("POST", "/hello", strings.NewReader(`{"name":"Star"}`)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + rh.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d", w.Code) + } + + body := w.Body.String() + if body != "Hello Star" { + t.Errorf("Expected 'Hello Star', got '%s'", body) + } +} + +func TestServeHTTP_404(t *testing.T) { + rh := &routeHandler{} + req := httptest.NewRequest("GET", "/notfound", nil) + w := httptest.NewRecorder() + + rh.ServeHTTP(w, req) + + if w.Code != http.StatusNotFound { + t.Errorf("Expected status 404, got %d", w.Code) + } +} + +func TestServeHTTP_VerifyFailed(t *testing.T) { + type ValidIn struct { + Age int `verify:"between:18-100"` + } + handler := func(in ValidIn) string { + return "ok" + } + Register(0, "/verify", handler, "test verify") + + rh := &routeHandler{} + req := httptest.NewRequest("POST", "/verify", strings.NewReader(`{"age":10}`)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + rh.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("Expected status 400, got %d", w.Code) + } +} diff --git a/proxy.go b/proxy.go new file mode 100644 index 0000000..969c923 --- /dev/null +++ b/proxy.go @@ -0,0 +1,171 @@ +package service + +import ( + "apigo.cc/go/discover" + gohttp "apigo.cc/go/http" + "apigo.cc/go/log" + "fmt" + "io" + "net/http" + "regexp" + "strings" + "sync" + "time" +) + +type proxyInfo struct { + matcher *regexp.Regexp + authLevel int + fromPath string + toApp string + toPath string +} + +var ( + proxies = make(map[string]*proxyInfo) + regexProxies = make([]*proxyInfo, 0) + proxyBy func(*Request) (int, *string, *string, map[string]string) + proxiesLock = sync.RWMutex{} + + httpClientPool *gohttp.Client +) + +// Proxy 注册代理规则 +func Proxy(authLevel int, path string, toApp, toPath string) { + p := &proxyInfo{authLevel: authLevel, fromPath: path, toApp: toApp, toPath: toPath} + if strings.Contains(path, "(") { + matcher, err := regexp.Compile("^" + path + "$") + if err == nil { + p.matcher = matcher + proxiesLock.Lock() + regexProxies = append(regexProxies, p) + proxiesLock.Unlock() + } + } else { + proxiesLock.Lock() + proxies[path] = p + proxiesLock.Unlock() + } +} + +// SetProxyBy 设置动态代理函数 +func SetProxyBy(by func(request *Request) (authLevel int, toApp, toPath *string, headers map[string]string)) { + proxyBy = by +} + +func findProxy(request *Request) (int, *string, *string) { + requestPath := request.RequestURI + queryString := "" + if pos := strings.Index(requestPath, "?"); pos != -1 { + queryString = requestPath[pos:] + requestPath = requestPath[:pos] + } + + proxiesLock.RLock() + defer proxiesLock.RUnlock() + + if pi, ok := proxies[requestPath]; ok { + toPath := pi.toPath + queryString + return pi.authLevel, &pi.toApp, &toPath + } + + for _, pi := range regexProxies { + if pi.matcher != nil { + finds := pi.matcher.FindAllStringSubmatch(requestPath, 1) + if len(finds) > 0 { + toApp := pi.toApp + toPath := pi.toPath + for i, part := range finds[0] { + toApp = strings.ReplaceAll(toApp, fmt.Sprintf("$%d", i), part) + toPath = strings.ReplaceAll(toPath, fmt.Sprintf("$%d", i), part) + } + toPath += queryString + return pi.authLevel, &toApp, &toPath + } + } + } + + return 0, nil, nil +} + +func processProxy(request *Request, response *Response, logger *log.Logger) bool { + authLevel, proxyToApp, proxyToPath := findProxy(request) + var proxyHeaders map[string]string + + if proxyBy != nil && (proxyToApp == nil || proxyToPath == nil || *proxyToApp == "" || *proxyToPath == "") { + authLevel, proxyToApp, proxyToPath, proxyHeaders = proxyBy(request) + } + + if proxyToApp == nil || proxyToPath == nil || *proxyToApp == "" || *proxyToPath == "" { + return false + } + + // 鉴权 + pass, obj := checkAuthForProxy(authLevel, request, response, logger) + if !pass { + if !response.changed { + response.WriteHeader(http.StatusForbidden) + } + return true + } + _ = obj // Currently unused in proxy + + app := *proxyToApp + path := *proxyToPath + + // 构建自定义头部 + headerArgs := make([]string, 0) + for k, v := range proxyHeaders { + headerArgs = append(headerArgs, k, v) + } + + if strings.Contains(app, "://") { + // 直接 URL 代理 + if httpClientPool == nil { + httpClientPool = gohttp.NewClient(time.Duration(Config.RedirectTimeout) * time.Millisecond) + } + res := httpClientPool.ManualDoByRequest(request.Request, request.Method, app+path, request.Body, headerArgs...) + copyResponse(res, response, logger) + } else { + // Discover 代理 + caller := discover.NewCaller(request.Request, logger) + caller.NoBody = true + res, _ := caller.ManualDoWithNode(request.Method, app, "", path, request.Body, headerArgs...) + copyResponse(res, response, logger) + } + + return true +} + +func checkAuthForProxy(authLevel int, request *Request, response *Response, logger *log.Logger) (bool, any) { + ac := webAuthCheckers[authLevel] + if ac == nil { + ac = webAuthChecker + } + if ac == nil { + return true, nil + } + return ac(authLevel, logger, &request.RequestURI, nil, request, response, nil) +} + +func copyResponse(res *gohttp.Result, response *Response, logger *log.Logger) { + if res.Error != nil || res.Response == nil { + response.WriteHeader(http.StatusBadGateway) + if res.Error != nil { + _, _ = response.WriteString(res.Error.Error()) + } + return + } + + for k, v := range res.Response.Header { + response.Header().Set(k, v[0]) + } + response.WriteHeader(res.Response.StatusCode) + if res.Response.Body != nil { + defer res.Response.Body.Close() + _, err := io.Copy(response.Writer, res.Response.Body) + if err != nil { + logger.Error("proxy copy body failed", "error", err.Error()) + } + } +} diff --git a/proxy_test.go b/proxy_test.go new file mode 100644 index 0000000..5f177b6 --- /dev/null +++ b/proxy_test.go @@ -0,0 +1,62 @@ +package service + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +func TestRewrite(t *testing.T) { + // 注册重写规则 + Rewrite("/old", "/new") + Rewrite("/regex/(.*)", "/target/$1") + + // 注册目标服务 + Register(0, "/new", func() string { return "new content" }, "new") + Register(0, "/target/123", func() string { return "target content" }, "target") + + rh := &routeHandler{} + + // 测试精确匹配重写 + req1 := httptest.NewRequest("GET", "/old", nil) + w1 := httptest.NewRecorder() + rh.ServeHTTP(w1, req1) + if w1.Body.String() != "new content" { + t.Errorf("Expected 'new content', got '%s'", w1.Body.String()) + } + + // 测试正则匹配重写 + req2 := httptest.NewRequest("GET", "/regex/123", nil) + w2 := httptest.NewRecorder() + rh.ServeHTTP(w2, req2) + if w2.Body.String() != "target content" { + t.Errorf("Expected 'target content', got '%s'", w2.Body.String()) + } +} + +func TestProxyDirect(t *testing.T) { + // 启动后端服务器 + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Backend", "ok") + w.Write([]byte("backend content")) + })) + defer backend.Close() + + // 注册代理规则 + Proxy(0, "/proxy", backend.URL, "/hello") + + rh := &routeHandler{} + req := httptest.NewRequest("GET", "/proxy", nil) + w := httptest.NewRecorder() + rh.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected 200, got %d", w.Code) + } + if w.Header().Get("X-Backend") != "ok" { + t.Error("Header X-Backend mismatch") + } + if w.Body.String() != "backend content" { + t.Errorf("Expected 'backend content', got '%s'", w.Body.String()) + } +} diff --git a/request.go b/request.go new file mode 100644 index 0000000..800a9ad --- /dev/null +++ b/request.go @@ -0,0 +1,139 @@ +package service + +import ( + "apigo.cc/go/cast" + "apigo.cc/go/standard" + "io" + "mime/multipart" + "net" + "net/http" + "net/textproto" + "net/url" + "os" + "path/filepath" +) + +// UploadFile 上传文件结构 +type UploadFile struct { + fileHeader *multipart.FileHeader + Filename string + Header textproto.MIMEHeader + Size int64 +} + +// Open 打开上传文件 +func (f *UploadFile) Open() (multipart.File, error) { + return f.fileHeader.Open() +} + +// Save 保存上传文件到本地 +func (f *UploadFile) Save(filename string) error { + dir := filepath.Dir(filename) + if _, err := os.Stat(dir); os.IsNotExist(err) { + _ = os.MkdirAll(dir, 0755) + } + + dst, err := os.OpenFile(filename, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) + if err != nil { + return err + } + defer dst.Close() + + src, err := f.fileHeader.Open() + if err != nil { + return err + } + defer src.Close() + + _, err = io.Copy(dst, src) + return err +} + +// Content 获取上传文件内容 +func (f *UploadFile) Content() ([]byte, error) { + src, err := f.fileHeader.Open() + if err != nil { + return nil, err + } + defer src.Close() + return io.ReadAll(src) +} + +// Request 封装 http.Request +type Request struct { + *http.Request + contextValues map[string]any + Id string +} + +// NewRequest 创建 Request 包装 +func NewRequest(httpRequest *http.Request) *Request { + return &Request{ + Request: httpRequest, + contextValues: make(map[string]any), + } +} + +// ResetPath 重写请求路径 +func (r *Request) ResetPath(path string) { + r.RequestURI = path + if u, err := url.Parse(path); err == nil { + r.URL = u + } +} + +// Set 设置请求上下文变量 +func (r *Request) Set(key string, value any) { + r.contextValues[key] = value +} + +// Get 获取请求上下文变量 +func (r *Request) Get(key string) any { + return r.contextValues[key] +} + +// MakeUrl 根据当前请求构建完整 URL +func (r *Request) MakeUrl(path string) string { + scheme := r.Header.Get(standard.DiscoverHeaderScheme) + if scheme == "" { + scheme = "http" + } + host := r.Header.Get(standard.DiscoverHeaderHost) + if host == "" { + host = r.Host + } + return scheme + "://" + host + path +} + +// GetSessionId 获取会话 ID +func (r *Request) GetSessionId() string { + sessionId := r.Header.Get(Config.Listen) // Wait, this should be usedSessionIdKey + // TODO: Fix dependency on global usedSessionIdKey + return sessionId +} + +// SetUserId 设置用户 ID(传递给下游) +func (r *Request) SetUserId(userId string) { + r.Header.Set(standard.DiscoverHeaderUserId, userId) +} + +// GetRealIp 获取真实 IP +func (r *Request) GetRealIp() string { + ip := r.Header.Get(standard.DiscoverHeaderClientIp) + if ip == "" { + ip = r.Header.Get(standard.DiscoverHeaderForwardedFor) + } + if ip == "" { + host, _, err := net.SplitHostPort(r.RemoteAddr) + if err == nil { + return host + } + return r.RemoteAddr + } + return ip +} + +// GetLowerName (Aliased from cast) +func GetLowerName(s string) string { + return cast.GetLowerName(s) +} diff --git a/response.go b/response.go new file mode 100644 index 0000000..e51d926 --- /dev/null +++ b/response.go @@ -0,0 +1,152 @@ +package service + +import ( + "apigo.cc/go/cast" + "io" + "net/http" + "os" +) + +// Response 封装 http.ResponseWriter +type Response struct { + Id string + Writer http.ResponseWriter + status int + outLen int + changed bool + headerWritten bool + dontLog200 bool + dontLogArgs []string + ProxyHeader *http.Header +} + +// NewResponse 创建 Response 包装 +func NewResponse(writer http.ResponseWriter) *Response { + return &Response{ + Writer: writer, + status: http.StatusOK, + } +} + +// Header 获取响应头部 +func (r *Response) Header() http.Header { + r.changed = true + if r.ProxyHeader != nil { + return *r.ProxyHeader + } + return r.Writer.Header() +} + +// Write 写入响应内容 +func (r *Response) Write(bytes []byte) (int, error) { + r.checkWriteHeader() + r.changed = true + r.outLen += len(bytes) + if r.ProxyHeader != nil { + r.copyProxyHeader() + } + return r.Writer.Write(bytes) +} + +// WriteString 写入字符串响应 +func (r *Response) WriteString(s string) (int, error) { + return r.Write([]byte(s)) +} + +// WriteHeader 设置响应状态码 +func (r *Response) WriteHeader(code int) { + r.changed = true + r.status = code + if r.ProxyHeader != nil && (r.status == http.StatusBadGateway || r.status == http.StatusServiceUnavailable || r.status == http.StatusGatewayTimeout) { + return + } + if r.ProxyHeader != nil { + r.copyProxyHeader() + } +} + +func (r *Response) checkWriteHeader() { + if !r.headerWritten { + r.headerWritten = true + if r.status != http.StatusOK { + r.Writer.WriteHeader(r.status) + } + } +} + +func (r *Response) copyProxyHeader() { + src := *r.ProxyHeader + dst := r.Writer.Header() + for k, vv := range src { + for _, v := range vv { + dst.Add(k, v) + } + } + r.ProxyHeader = nil +} + +// Flush 刷新响应缓冲区 +func (r *Response) Flush() { + if flusher, ok := r.Writer.(http.Flusher); ok { + flusher.Flush() + } +} + +// GetStatusCode 获取当前状态码 +func (r *Response) GetStatusCode() int { + return r.status +} + +// DontLog200 标记不记录 200 状态码的日志 +func (r *Response) DontLog200() { + r.dontLog200 = true +} + +// Location 设置重定向地址 +func (r *Response) Location(location string) { + r.WriteHeader(http.StatusFound) + r.Header().Set("Location", location) +} + +// SendFile 发送文件 +func (r *Response) SendFile(contentType, filename string) { + r.Header().Set("Content-Type", contentType) + // TODO: Integrate memory file support if needed + if fd, err := os.Open(filename); err == nil { + defer fd.Close() + _, _ = io.Copy(r, fd) + } +} + +// DownloadFile 下载文件 +func (r *Response) DownloadFile(contentType, filename string, data any) { + if contentType == "" { + contentType = "application/octet-stream" + } + r.Header().Set("Content-Type", contentType) + + if filename != "" { + r.Header().Set("Content-Disposition", "attachment; filename="+filename) + } + + var outBytes []byte + var reader io.Reader + + switch v := data.(type) { + case []byte: + outBytes = v + case string: + outBytes = []byte(v) + case io.Reader: + reader = v + default: + outBytes, _ = cast.ToJSONBytes(data) + } + + if outBytes != nil { + r.Header().Set("Content-Length", cast.String(len(outBytes))) + _, _ = r.Write(outBytes) + } else if reader != nil { + _, _ = io.Copy(r, reader) + } +} diff --git a/rewrite.go b/rewrite.go new file mode 100644 index 0000000..70d340b --- /dev/null +++ b/rewrite.go @@ -0,0 +1,113 @@ +package service + +import ( + "apigo.cc/go/log" + "fmt" + "net/url" + "regexp" + "strings" + "sync" +) + +type rewriteInfo struct { + matcher *regexp.Regexp + fromPath string + toPath string +} + +var ( + rewrites = make(map[string]*rewriteInfo) + regexRewrites = make([]*rewriteInfo, 0) + rewriteBy func(*Request) (string, bool) + rewritesLock = sync.RWMutex{} +) + +// Rewrite 注册重写规则 +func Rewrite(path string, toPath string) { + s := &rewriteInfo{fromPath: path, toPath: toPath} + + if strings.ContainsRune(path, '(') { + matcher, err := regexp.Compile("^" + path + "$") + if err == nil { + s.matcher = matcher + rewritesLock.Lock() + regexRewrites = append(regexRewrites, s) + rewritesLock.Unlock() + } + } else { + rewritesLock.Lock() + rewrites[path] = s + rewritesLock.Unlock() + } +} + +// SetRewriteBy 设置动态重写函数 +func SetRewriteBy(by func(request *Request) (toPath string, rewrite bool)) { + rewriteBy = by +} + +func processRewrite(request *Request, response *Response, logger *log.Logger) bool { + requestPath := request.RequestURI + queryString := "" + if pos := strings.Index(requestPath, "?"); pos != -1 { + queryString = requestPath[pos:] + requestPath = requestPath[:pos] + } + + var rewriteToPath string + var found bool + + rewritesLock.RLock() + // 1. 精确匹配 + if ri, ok := rewrites[requestPath]; ok { + rewriteToPath = ri.toPath + found = true + } + + // 2. 动态重写 + if !found && rewriteBy != nil { + rewriteToPath, found = rewriteBy(request) + } + + // 3. 正则匹配 + if !found { + for _, ri := range regexRewrites { + if ri.matcher != nil { + finds := ri.matcher.FindAllStringSubmatch(request.RequestURI, 1) + if len(finds) > 0 { + toPath := ri.toPath + for i, part := range finds[0] { + toPath = strings.ReplaceAll(toPath, fmt.Sprintf("$%d", i), part) + } + rewriteToPath = toPath + found = true + break + } + } + } + } + rewritesLock.RUnlock() + + if found { + if strings.Contains(rewriteToPath, "://") { + // 外部重定向 + if !strings.Contains(rewriteToPath, "?") && queryString != "" { + rewriteToPath += queryString + } + response.Header().Set("Location", rewriteToPath) + response.WriteHeader(302) + return true + } else { + // 内部重写 + logger.Info("rewrite", "from", request.RequestURI, "to", rewriteToPath) + if queryString != "" && !strings.Contains(rewriteToPath, "?") { + rewriteToPath += queryString + } + request.RequestURI = rewriteToPath + request.URL, _ = url.Parse(rewriteToPath) + return false // 继续后续处理 + } + } + + return false +} diff --git a/server.go b/server.go new file mode 100644 index 0000000..28a63b3 --- /dev/null +++ b/server.go @@ -0,0 +1,88 @@ +package service + +import ( + "apigo.cc/go/log" + "context" + "net" + "net/http" + "os" + "os/signal" + "syscall" + "time" +) + +// AsyncServer 异步服务实例 +type AsyncServer struct { + server *http.Server + listener net.Listener + Addr string + stopChan chan os.Signal + startChan chan bool +} + +// AsyncStart 异步启动服务 +func AsyncStart() *AsyncServer { + as := &AsyncServer{ + startChan: make(chan bool, 1), + stopChan: make(chan os.Signal, 1), + } + + go as.start() + + <-as.startChan + return as +} + +func (as *AsyncServer) start() { + if Config.Listen == "" { + Config.Listen = ":8080" // 默认端口 + } + + listener, err := net.Listen("tcp", Config.Listen) + if err != nil { + log.DefaultLogger.Error("failed to listen", "addr", Config.Listen, "error", err.Error()) + as.startChan <- false + return + } + + as.listener = listener + as.Addr = listener.Addr().String() + serverAddr = as.Addr + + as.server = &http.Server{ + Handler: &routeHandler{}, + } + + signal.Notify(as.stopChan, os.Interrupt, syscall.SIGTERM) + + go func() { + log.DefaultLogger.Info("service starting", "addr", as.Addr) + as.startChan <- true + if err := as.server.Serve(listener); err != nil && err != http.ErrServerClosed { + log.DefaultLogger.Error("server error", "error", err.Error()) + } + }() +} + +// Stop 停止服务 +func (as *AsyncServer) Stop() { + log.DefaultLogger.Info("service stopping") + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if err := as.server.Shutdown(ctx); err != nil { + log.DefaultLogger.Error("server shutdown error", "error", err.Error()) + } + log.DefaultLogger.Info("service stopped") +} + +// Wait 等待服务结束 (信号监听) +func (as *AsyncServer) Wait() { + <-as.stopChan + as.Stop() +} + +// Start 同步启动服务 +func Start() { + AsyncStart().Wait() +} diff --git a/server_test.go b/server_test.go new file mode 100644 index 0000000..70d00f3 --- /dev/null +++ b/server_test.go @@ -0,0 +1,31 @@ +package service + +import ( + "net/http" + "testing" +) + +func TestAsyncServer(t *testing.T) { + Config.Listen = ":0" // 随机端口 + as := AsyncStart() + if as.Addr == "" { + t.Fatal("AsyncStart failed to get address") + } + + // 测试服务是否可用 + resp, err := http.Get("http://" + as.Addr + "/__CHECK__") + if err == nil { + // 虽然没有注册 /__CHECK__,但应该返回 404 而非连接拒绝 + if resp.StatusCode != http.StatusNotFound { + t.Errorf("Expected 404, got %d", resp.StatusCode) + } + } + + as.Stop() + + // 确认服务已关闭 + _, err = http.Get("http://" + as.Addr + "/__CHECK__") + if err == nil { + t.Error("Server should be closed") + } +} diff --git a/service.go b/service.go new file mode 100644 index 0000000..5dea62c --- /dev/null +++ b/service.go @@ -0,0 +1,225 @@ +package service + +import ( + "apigo.cc/go/log" + "errors" + "reflect" + "regexp" + "strings" + "sync" +) + +// WebServiceOptions 服务注册选项 +type WebServiceOptions struct { + Priority int + NoDoc bool + NoBody bool + NoLog200 bool + Host string + Ext Map + // Limiters []*Limiter // TODO: Integrate Limiter +} + +// webServiceType 内部存储的服务元数据 +type webServiceType struct { + authLevel int + method string + path string + pathMatcher *regexp.Regexp + pathArgs []string + parmsNum int + inType reflect.Type + inIndex int + headersType reflect.Type + headersIndex int + requestIndex int + httpRequestIndex int + responseIndex int + responseWriterIndex int + loggerIndex int + callerIndex int + funcType reflect.Type + funcValue reflect.Value + options WebServiceOptions + data Map + memo string +} + +var ( + serverId string + serverAddr string + serverProto = "http" + serverProtoName = "http" + running = false + + webServices = make(map[string]*webServiceType) + regexWebServices = make([]*webServiceType, 0) + webServicesLock = sync.RWMutex{} + webServicesList = make([]*webServiceType, 0) + + websocketServices = make(map[string]*websocketServiceType) + regexWebsocketServices = make([]*websocketServiceType, 0) + websocketServicesLock = sync.RWMutex{} + websocketServicesList = make([]*websocketServiceType, 0) + + // 过滤器与拦截器 + inFilters = make([]func(*map[string]any, *Request, *Response, *log.Logger) any, 0) + outFilters = make([]func(map[string]any, *Request, *Response, any, *log.Logger) (any, bool), 0) + errorHandle func(any, *Request, *Response) any + webAuthChecker func(int, *log.Logger, *string, map[string]any, *Request, *Response, *WebServiceOptions) (pass bool, object any) + webAuthCheckers = make(map[int]func(int, *log.Logger, *string, map[string]any, *Request, *Response, *WebServiceOptions) (pass bool, object any)) + + // 注入点 + injectObjects = make(map[reflect.Type]any) + injectFunctions = make(map[reflect.Type]func() any) + + usedDeviceIdKey string + usedClientAppKey string + usedSessionIdKey string + sessionIdMaker func() string +) + +// SetClientKeys 设置客户端标识相关的 Key 映射 +func SetClientKeys(deviceIdKey, clientAppKey, sessionIdKey string) { + usedDeviceIdKey = deviceIdKey + usedClientAppKey = clientAppKey + usedSessionIdKey = sessionIdKey +} + +// SetSessionIdMaker 设置自定义会话 ID 生成器 +func SetSessionIdMaker(maker func() string) { + sessionIdMaker = maker +} + +// SetAuthChecker 设置全局鉴权器 +func SetAuthChecker(authChecker func(authLevel int, logger *log.Logger, url *string, in map[string]any, request *Request, response *Response, options *WebServiceOptions) (pass bool, object any)) { + webAuthChecker = authChecker +} + +// AddAuthChecker 为指定级别添加鉴权器 +func AddAuthChecker(authLevels []int, authChecker func(authLevel int, logger *log.Logger, url *string, in map[string]any, request *Request, response *Response, options *WebServiceOptions) (pass bool, object any)) { + for _, al := range authLevels { + webAuthCheckers[al] = authChecker + } +} + +// SetInFilter 设置前置过滤器 +func SetInFilter(filter func(in *map[string]any, request *Request, response *Response, logger *log.Logger) (out any)) { + inFilters = append(inFilters, filter) +} + +// SetOutFilter 设置后置过滤器 +func SetOutFilter(filter func(in map[string]any, request *Request, response *Response, out any, logger *log.Logger) (newOut any, isOver bool)) { + outFilters = append(outFilters, filter) +} + +// Register 注册服务(通用方法) +func Register(authLevel int, path string, serviceFunc any, memo string) { + Restful(authLevel, "", path, serviceFunc, memo) +} + +// Restful 注册指定方法的服务 +func Restful(authLevel int, method, path string, serviceFunc any, memo string) { + RestfulWithOptions(authLevel, method, path, serviceFunc, memo, WebServiceOptions{}) +} + +// RestfulWithOptions 注册带选项的服务 +func RestfulWithOptions(authLevel int, method, path string, serviceFunc any, memo string, options WebServiceOptions) { + s, err := makeCachedService(serviceFunc) + if err != nil { + // TODO: Log error properly when logger is ready + return + } + + s.authLevel = authLevel + s.options = options + s.method = method + s.path = path + s.memo = memo + + // 解析路径参数 {name} + finder, err := regexp.Compile("{(.*?)}") + if err == nil { + keyName := regexp.QuoteMeta(path) + finds := finder.FindAllStringSubmatch(path, 20) + for _, found := range finds { + keyName = strings.Replace(keyName, regexp.QuoteMeta(found[0]), "(.*?)", 1) + s.pathArgs = append(s.pathArgs, found[1]) + } + if len(s.pathArgs) > 0 { + s.pathMatcher, _ = regexp.Compile("^" + keyName + "$") + } + } + + webServicesLock.Lock() + defer webServicesLock.Unlock() + + // 简单路径匹配 + if s.pathMatcher == nil { + webServices[method+path] = s // TODO: Include Host in key + } else { + regexWebServices = append(regexWebServices, s) + } + webServicesList = append(webServicesList, s) +} + +func makeCachedService(matchedService any) (*webServiceType, error) { + funcType := reflect.TypeOf(matchedService) + if funcType.Kind() != reflect.Func { + return nil, errors.New("handler must be a function") + } + + targetService := &webServiceType{ + parmsNum: funcType.NumIn(), + inIndex: -1, + headersIndex: -1, + requestIndex: -1, + httpRequestIndex: -1, + responseIndex: -1, + responseWriterIndex: -1, + loggerIndex: -1, + callerIndex: -1, + funcType: funcType, + funcValue: reflect.ValueOf(matchedService), + } + + for i := 0; i < targetService.parmsNum; i++ { + t := funcType.In(i) + tStr := t.String() + switch tStr { + case "*service.Request": + targetService.requestIndex = i + case "*http.Request": + targetService.httpRequestIndex = i + case "*service.Response": + targetService.responseIndex = i + case "http.ResponseWriter": + targetService.responseWriterIndex = i + case "*log.Logger": + targetService.loggerIndex = i + default: + if t.Kind() == reflect.Struct || (t.Kind() == reflect.Map && t.Elem().Kind() == reflect.Interface) { + if targetService.inType == nil { + targetService.inIndex = i + targetService.inType = t + } else if targetService.headersType == nil { + targetService.headersIndex = i + targetService.headersType = t + } + } + } + } + + return targetService, nil +} + +// GetInject 获取注入对象 +func GetInject(dataType reflect.Type) any { + if obj, exists := injectObjects[dataType]; exists { + return obj + } + if factory, exists := injectFunctions[dataType]; exists { + return factory() + } + return nil +} diff --git a/service_test.go b/service_test.go new file mode 100644 index 0000000..1e4c400 --- /dev/null +++ b/service_test.go @@ -0,0 +1,54 @@ +package service + +import ( + "apigo.cc/go/log" + "testing" +) + +func TestServiceRegister(t *testing.T) { + handler := func(req *Request, logger *log.Logger) string { + return "ok" + } + + Register(0, "/test", handler, "test service") + + webServicesLock.RLock() + s := webServices["/test"] + webServicesLock.RUnlock() + + if s == nil { + t.Fatal("Service not registered") + } + + if s.requestIndex != 0 { + t.Errorf("requestIndex mismatch: expected 0, got %d", s.requestIndex) + } + if s.loggerIndex != 1 { + t.Errorf("loggerIndex mismatch: expected 1, got %d", s.loggerIndex) + } +} + +func TestRegexServiceRegister(t *testing.T) { + handler := func(args map[string]any) string { + return "ok" + } + + Register(0, "/user/{id}", handler, "get user") + + webServicesLock.RLock() + found := false + for _, s := range regexWebServices { + if s.path == "/user/{id}" { + found = true + if len(s.pathArgs) != 1 || s.pathArgs[0] != "id" { + t.Errorf("pathArgs mismatch: %v", s.pathArgs) + } + break + } + } + webServicesLock.RUnlock() + + if !found { + t.Fatal("Regex service not registered") + } +} diff --git a/starter.go b/starter.go new file mode 100644 index 0000000..970005f --- /dev/null +++ b/starter.go @@ -0,0 +1,49 @@ +package service + +import ( + "fmt" + "os" + "path/filepath" +) + +// StartCmd 命令行命令定义 +type StartCmd struct { + Name string + Comment string + Func func() +} + +var startCmds = []StartCmd{ + {"start", "Start server", Start}, +} + +// AddCmd 添加自定义命令行命令 +func AddCmd(name, comment string, function func()) { + startCmds = append(startCmds, StartCmd{name, comment, function}) +} + +// CheckCmd 检查并执行命令行命令 +func CheckCmd() { + if len(os.Args) > 1 { + cmd := os.Args[1] + if cmd == "help" || cmd == "--help" { + showHelp() + os.Exit(0) + } + + for _, cmdInfo := range startCmds { + if cmd == cmdInfo.Name { + cmdInfo.Func() + os.Exit(0) + } + } + } +} + +func showHelp() { + fmt.Printf("Usage: %s [command]\n\n", filepath.Base(os.Args[0])) + fmt.Println("Available commands:") + for _, cmdInfo := range startCmds { + fmt.Printf(" %-10s %s\n", cmdInfo.Name, cmdInfo.Comment) + } +} diff --git a/static.go b/static.go new file mode 100644 index 0000000..5d33d67 --- /dev/null +++ b/static.go @@ -0,0 +1,123 @@ +package service + +import ( + "apigo.cc/go/file" + "apigo.cc/go/log" + "mime" + "net/http" + "os" + "path/filepath" + "strings" + "sync" + "time" +) + +var ( + statics = make(map[string]*string) + staticsByHost = make(map[string]map[string]*string) + staticsByHostLock = sync.RWMutex{} +) + +// Static 注册静态文件目录 +func Static(path, rootPath string) { + StaticByHost(path, rootPath, "") +} + +// StaticByHost 为指定域名注册静态文件目录 +func StaticByHost(path, rootPath, host string) { + if !filepath.IsAbs(rootPath) { + if absPath, err := filepath.Abs(rootPath); err == nil { + rootPath = absPath + } + } + + staticsByHostLock.Lock() + defer staticsByHostLock.Unlock() + + if host == "" { + statics[path] = &rootPath + } else { + if staticsByHost[host] == nil { + staticsByHost[host] = make(map[string]*string) + } + staticsByHost[host][path] = &rootPath + } +} + +func getStaticFilePath(requestPath, host string) string { + staticsByHostLock.RLock() + defer staticsByHostLock.RUnlock() + + // 优先匹配指定域名的配置 + if hostConfig, exists := staticsByHost[host]; exists { + if filePath := findMatchedPath(hostConfig, requestPath); filePath != "" { + return filePath + } + } + + // 匹配全局配置 + return findMatchedPath(statics, requestPath) +} + +func findMatchedPath(config map[string]*string, requestPath string) string { + for urlPath, rootPath := range config { + if strings.HasPrefix(requestPath, urlPath) { + return filepath.Join(*rootPath, requestPath[len(urlPath):]) + } + } + return "" +} + +func processStatic(requestPath string, request *Request, response *Response, logger *log.Logger) bool { + filePath := getStaticFilePath(requestPath, request.Host) + if filePath == "" { + return false + } + + info, err := os.Stat(filePath) + if err != nil { + return false + } + + if info.IsDir() { + // 自动查找索引文件 + for _, indexFile := range Config.IndexFiles { + f := filepath.Join(filePath, indexFile) + if i, err := os.Stat(f); err == nil && !i.IsDir() { + filePath = f + info = i + break + } + } + } + + if info.IsDir() { + return false + } + + // 检查 304 + if ifModifiedSince := request.Header.Get("If-Modified-Since"); ifModifiedSince != "" { + if t, err := time.Parse(http.TimeFormat, ifModifiedSince); err == nil { + if !info.ModTime().Truncate(time.Second).After(t.Truncate(time.Second)) { + response.WriteHeader(http.StatusNotModified) + return true + } + } + } + + // 发送文件 + contentType := mime.TypeByExtension(filepath.Ext(filePath)) + if contentType == "" { + contentType = "application/octet-stream" + } + response.Header().Set("Content-Type", contentType) + response.Header().Set("Last-Modified", info.ModTime().UTC().Format(http.TimeFormat)) + + data, err := file.ReadBytes(filePath) + if err != nil { + return false + } + + _, _ = response.Write(data) + return true +} diff --git a/static_test.go b/static_test.go new file mode 100644 index 0000000..20a6056 --- /dev/null +++ b/static_test.go @@ -0,0 +1,43 @@ +package service + +import ( + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" +) + +func TestStaticService(t *testing.T) { + // 创建临时测试目录和文件 + tempDir, _ := os.MkdirTemp("", "static_test") + defer os.RemoveAll(tempDir) + + testFile := filepath.Join(tempDir, "index.html") + os.WriteFile(testFile, []byte("

Static Page

"), 0644) + + // 注册静态目录 + Static("/ui", tempDir) + + rh := &routeHandler{} + + // 测试成功访问 + req := httptest.NewRequest("GET", "/ui/index.html", nil) + w := httptest.NewRecorder() + rh.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected 200, got %d", w.Code) + } + if body := w.Body.String(); body != "

Static Page

" { + t.Errorf("Content mismatch: %s", body) + } + + // 测试 404 + req404 := httptest.NewRequest("GET", "/ui/notfound.html", nil) + w404 := httptest.NewRecorder() + rh.ServeHTTP(w404, req404) + if w404.Code != http.StatusNotFound { + t.Errorf("Expected 404 for missing file, got %d", w404.Code) + } +} diff --git a/types.go b/types.go new file mode 100644 index 0000000..5ce805f --- /dev/null +++ b/types.go @@ -0,0 +1,68 @@ +package service + +// Map 通用 Map 类型 +type Map = map[string]any + +// Arr 通用切片类型 +type Arr = []any + +// Argot 错误码/标识符类型 +type Argot string + +// Result 通用返回结构 +type Result struct { + Ok bool `json:"ok"` + Argot Argot `json:"argot,omitempty"` + Message string `json:"message,omitempty"` +} + +// CodeResult 带状态码的返回结构 +type CodeResult struct { + Code int `json:"code"` + Message string `json:"message,omitempty"` +} + +// ArgotInfo 标识符信息(用于文档生成) +type ArgotInfo struct { + Name Argot + Memo string +} + +// OK 设置成功状态 +func (r *Result) OK(argots ...Argot) { + r.Ok = true + if len(argots) > 0 { + r.Argot = argots[0] + } +} + +// Failed 设置失败状态 +func (r *Result) Failed(message string, argots ...Argot) { + r.Ok = false + r.Message = message + if len(argots) > 0 { + r.Argot = argots[0] + } +} + +// Done 根据布尔值设置状态 +func (r *Result) Done(ok bool, failedMessage string, argots ...Argot) { + r.Ok = ok + if !ok { + r.Message = failedMessage + if len(argots) > 0 { + r.Argot = argots[0] + } + } +} + +// OK 设置成功状态 (Code=1) +func (r *CodeResult) OK() { + r.Code = 1 +} + +// Failed 设置失败状态与错误码 +func (r *CodeResult) Failed(code int, message string) { + r.Code = code + r.Message = message +} diff --git a/types_test.go b/types_test.go new file mode 100644 index 0000000..65ba6ba --- /dev/null +++ b/types_test.go @@ -0,0 +1,41 @@ +package service + +import ( + "testing" +) + +func TestResult(t *testing.T) { + r := &Result{} + r.OK() + if !r.Ok { + t.Error("Result.OK() failed") + } + + r.Failed("error", Argot("ERR_CODE")) + if r.Ok || r.Message != "error" || r.Argot != "ERR_CODE" { + t.Error("Result.Failed() failed") + } + + r.Done(true, "never") + if !r.Ok { + t.Error("Result.Done(true) failed") + } + + r.Done(false, "failed", Argot("FAIL")) + if r.Ok || r.Message != "failed" || r.Argot != "FAIL" { + t.Error("Result.Done(false) failed") + } +} + +func TestCodeResult(t *testing.T) { + cr := &CodeResult{} + cr.OK() + if cr.Code != 1 { + t.Error("CodeResult.OK() failed") + } + + cr.Failed(500, "internal error") + if cr.Code != 500 || cr.Message != "internal error" { + t.Error("CodeResult.Failed() failed") + } +} diff --git a/utility.go b/utility.go new file mode 100644 index 0000000..293dcb0 --- /dev/null +++ b/utility.go @@ -0,0 +1,20 @@ +package service + +import ( + "apigo.cc/go/id" +) + +// MakeId 生成指定长度的 ID +func MakeId(size int) string { + return id.MakeID(size) +} + +// MakeIdForMysql 生成适用于 MySQL 的有序 ID +func MakeIdForMysql(size int) string { + return id.DefaultIDMaker.GetForMysql(size) +} + +// MakeIdForPostgreSQL 生成适用于 PostgreSQL 的有序 ID +func MakeIdForPostgreSQL(size int) string { + return id.DefaultIDMaker.GetForPostgreSQL(size) +} diff --git a/verify.go b/verify.go new file mode 100644 index 0000000..bce9bd0 --- /dev/null +++ b/verify.go @@ -0,0 +1,305 @@ +package service + +import ( + "apigo.cc/go/cast" + "apigo.cc/go/log" + "reflect" + "regexp" + "strings" + "sync" +) + +// VerifyType 校验类型 +type VerifyType uint8 + +const ( + VerifyUnknown VerifyType = iota + VerifyRegex + VerifyStringLength + VerifyGreaterThan + VerifyLessThan + VerifyBetween + VerifyInList + VerifyByFunc +) + +// VerifySet 校验规则集 +type VerifySet struct { + Type VerifyType + Regex *regexp.Regexp + StringArgs []string + IntArgs []int + FloatArgs []float64 + Func func(any, []string) bool +} + +var ( + verifySets = make(map[string]*VerifySet) + verifySetsLock = sync.RWMutex{} + verifyFunctions = make(map[string]func(any, []string) bool) + verifyFunctionsLock = sync.RWMutex{} +) + +// RegisterVerifyFunc 注册自定义校验函数 +func RegisterVerifyFunc(name string, f func(in any, args []string) bool) { + verifyFunctionsLock.Lock() + verifyFunctions[name] = f + verifyFunctionsLock.Unlock() +} + +// RegisterVerify 注册预定义校验规则 +func RegisterVerify(name, setting string) { + set, _ := compileVerifySet(setting) + if set != nil { + verifySetsLock.Lock() + verifySets[name] = set + verifySetsLock.Unlock() + } +} + +// VerifyStruct 校验结构体 +func VerifyStruct(in any, logger *log.Logger) (ok bool, field string) { + v := cast.RealValue(reflect.ValueOf(in)) + if v.Kind() != reflect.Struct { + if logger != nil { + logger.Error("verify input is not struct", "type", v.Type().String()) + } + return false, "" + } + + for i := 0; i < v.NumField(); i++ { + ft := v.Type().Field(i) + fv := v.Field(i) + + // 忽略空指针、空切片、空 Map + if (fv.Kind() == reflect.Ptr && fv.IsNil()) || + (fv.Kind() == reflect.Slice && fv.Len() == 0) || + (fv.Kind() == reflect.Map && fv.Len() == 0) { + continue + } + + if ft.Anonymous { + // 处理嵌套结构体(继承) + if fv.CanInterface() { + if ok, f := VerifyStruct(fv.Interface(), logger); !ok { + return false, f + } + } + continue + } + + tag := ft.Tag.Get("verify") + keyTag := ft.Tag.Get("verifyKey") + if tag != "" || keyTag != "" { + var err error + ok, f, err := _verifyValue(fv, tag, keyTag, logger) + if !ok { + if f == "" { + f = cast.GetLowerName(ft.Name) + } + if logger != nil { + if err != nil { + logger.Error(err.Error(), "field", f) + } else { + logger.Warning("verify failed", "field", f, "tag", tag) + } + } + return false, f + } + } + } + return true, "" +} + +func _verifyValue(in reflect.Value, setting, keySetting string, logger *log.Logger) (bool, string, error) { + t := in.Type() + // 处理切片 (非 byte 切片) + if t.Kind() == reflect.Slice && t.Elem().Kind() != reflect.Uint8 { + if setting != "" { + for i := 0; i < in.Len(); i++ { + if ok, f, err := _verifyValue(in.Index(i), setting, "", logger); !ok { + return false, f, err + } + } + } + return true, "", nil + } + + // 处理 Map + if t.Kind() == reflect.Map { + for _, k := range in.MapKeys() { + if keySetting != "" { + if ok, _, err := _verifyValue(k, keySetting, "", logger); !ok { + return false, "key", err + } + } + if setting != "" { + if ok, f, err := _verifyValue(in.MapIndex(k), setting, "", logger); !ok { + return false, f, err + } + } + } + return true, "", nil + } + + // 处理嵌套 Struct + if t.Kind() == reflect.Struct { + ok, f := VerifyStruct(in.Interface(), logger) + return ok, f, nil + } + + // 基础校验 + if setting == "" { + return true, "", nil + } + + ok, err := verify(in.Interface(), setting) + return ok, "", err +} + +func verify(in any, setting string) (bool, error) { + if len(setting) < 2 { + return false, nil + } + + verifySetsLock.RLock() + set, exists := verifySets[setting] + verifySetsLock.RUnlock() + + if !exists { + var err error + set, err = compileVerifySet(setting) + if err != nil { + return false, err + } + verifySetsLock.Lock() + verifySets[setting] = set + verifySetsLock.Unlock() + } + + switch set.Type { + case VerifyByFunc: + return set.Func(in, set.StringArgs), nil + case VerifyRegex: + return set.Regex.MatchString(cast.String(in)), nil + case VerifyStringLength: + l := len(cast.String(in)) + if len(set.StringArgs) > 0 { + if set.StringArgs[0] == "+" { + return l >= set.IntArgs[0], nil + } else if set.StringArgs[0] == "-" { + return l <= set.IntArgs[0], nil + } + } + if len(set.IntArgs) > 1 { + return l >= set.IntArgs[0] && l <= set.IntArgs[1], nil + } + return l == set.IntArgs[0], nil + case VerifyGreaterThan: + return cast.Float64(in) > set.FloatArgs[0], nil + case VerifyLessThan: + return cast.Float64(in) < set.FloatArgs[0], nil + case VerifyBetween: + val := cast.Float64(in) + return val >= set.FloatArgs[0] && val <= set.FloatArgs[1], nil + case VerifyInList: + s := cast.String(in) + for _, item := range set.StringArgs { + if item == s { + return true, nil + } + } + return false, nil + } + return false, nil +} + +func compileVerifySet(setting string) (*VerifySet, error) { + set := &VerifySet{Type: VerifyUnknown} + if setting == "" { + return set, nil + } + + if setting[0] != '^' { + key := setting + args := "" + if pos := strings.IndexByte(setting, ':'); pos != -1 { + key = setting[:pos] + args = setting[pos+1:] + } + + // 优先查找自定义函数 + verifyFunctionsLock.RLock() + f, exists := verifyFunctions[key] + verifyFunctionsLock.RUnlock() + if exists { + set.Type = VerifyByFunc + set.Func = f + if args != "" { + set.StringArgs = strings.Split(args, ",") + } + return set, nil + } + + // 内置规则 + switch key { + case "length": + set.Type = VerifyStringLength + if args == "" { + args = "1+" + } + last := args[len(args)-1] + if last == '+' || last == '-' { + set.StringArgs = []string{string(last)} + args = args[:len(args)-1] + } + // 同时支持逗号和中划线 + sep := "," + if strings.Contains(args, "-") && !strings.Contains(args, ",") { + sep = "-" + } + if strings.Contains(args, sep) { + a := strings.Split(args, sep) + set.IntArgs = []int{cast.Int(a[0]), cast.Int(a[1])} + } else { + set.IntArgs = []int{cast.Int(args)} + } + return set, nil + case "between": + set.Type = VerifyBetween + if args == "" { + args = "1-100000000" + } + a := strings.Split(args, "-") + if len(a) == 1 { + set.FloatArgs = []float64{0, cast.Float64(a[0])} + } else { + set.FloatArgs = []float64{cast.Float64(a[0]), cast.Float64(a[1])} + } + return set, nil + case "gt": + set.Type = VerifyGreaterThan + set.FloatArgs = []float64{cast.Float64(args)} + return set, nil + case "lt": + set.Type = VerifyLessThan + set.FloatArgs = []float64{cast.Float64(args)} + return set, nil + case "in": + set.Type = VerifyInList + if args != "" { + set.StringArgs = strings.Split(args, ",") + } + return set, nil + } + } + + // 默认视为正则表达式 + rx, err := regexp.Compile(setting) + if err != nil { + return nil, err + } + set.Type = VerifyRegex + set.Regex = rx + return set, nil +} diff --git a/verify_test.go b/verify_test.go new file mode 100644 index 0000000..b4f5d17 --- /dev/null +++ b/verify_test.go @@ -0,0 +1,76 @@ +package service + +import ( + "testing" +) + +type TestUser struct { + Name string `verify:"length:2-10"` + Age int `verify:"between:18-100"` + Type string `verify:"in:admin,user,guest"` +} + +type NestedStruct struct { + TestUser + Note string `verify:"^.{1,20}$"` +} + +func TestVerifyStruct(t *testing.T) { + u := TestUser{Name: "Star", Age: 25, Type: "admin"} + if ok, f := VerifyStruct(u, nil); !ok { + t.Errorf("VerifyStruct failed on valid user, field: %s", f) + } + + u.Name = "S" + if ok, f := VerifyStruct(u, nil); ok || f != "name" { + t.Errorf("VerifyStruct should fail on short name, got ok=%v, field=%s", ok, f) + } + + u.Name = "Star" + u.Age = 10 + if ok, f := VerifyStruct(u, nil); ok || f != "age" { + t.Errorf("VerifyStruct should fail on young age, got ok=%v, field=%s", ok, f) + } + + u.Age = 25 + u.Type = "invalid" + if ok, f := VerifyStruct(u, nil); ok || f != "type" { + t.Errorf("VerifyStruct should fail on invalid type, got ok=%v, field=%s", ok, f) + } +} + +func TestNestedVerify(t *testing.T) { + n := NestedStruct{ + TestUser: TestUser{Name: "Star", Age: 25, Type: "user"}, + Note: "Hello", + } + if ok, f := VerifyStruct(n, nil); !ok { + t.Errorf("Nested VerifyStruct failed on valid data, field: %s", f) + } + + n.TestUser.Age = 5 + if ok, f := VerifyStruct(n, nil); ok || f != "age" { + t.Errorf("Nested VerifyStruct should fail on nested age, got ok=%v, field=%s", ok, f) + } +} + +func TestCustomVerify(t *testing.T) { + RegisterVerifyFunc("odd", func(in any, args []string) bool { + val := in.(int) + return val%2 != 0 + }) + + type OddStruct struct { + Num int `verify:"odd"` + } + + o := OddStruct{Num: 3} + if ok, f := VerifyStruct(o, nil); !ok { + t.Errorf("Custom verify failed on odd number, field: %s", f) + } + + o.Num = 4 + if ok, f := VerifyStruct(o, nil); ok || f != "num" { + t.Errorf("Custom verify should fail on even number, got ok=%v, field=%s", ok, f) + } +} diff --git a/websocket.go b/websocket.go new file mode 100644 index 0000000..323eb5d --- /dev/null +++ b/websocket.go @@ -0,0 +1,175 @@ +package service + +import ( + "apigo.cc/go/cast" + "apigo.cc/go/log" + "github.com/gorilla/websocket" + "net/http" + "reflect" + "regexp" +) + +// websocketServiceType WebSocket 服务元数据 +type websocketServiceType struct { + authLevel int + path string + pathMatcher *regexp.Regexp + pathArgs []string + updater *websocket.Upgrader + openFuncValue reflect.Value + openFuncType reflect.Type + closeFuncValue reflect.Value + closeFuncType reflect.Type + sessionType reflect.Type + actions map[string]*websocketActionType + isSimple bool + options WebServiceOptions + memo string +} + +// websocketActionType WebSocket Action 元数据 +type websocketActionType struct { + authLevel int + funcValue reflect.Value + funcType reflect.Type + inType reflect.Type + memo string +} + +// ActionRegister WebSocket Action 注册器 +type ActionRegister struct { + ws *websocketServiceType +} + +// RegisterWebsocket 注册 WebSocket 服务 +func RegisterWebsocket(authLevel int, path string, onOpen, onClose any, memo string) *ActionRegister { + s := &websocketServiceType{ + authLevel: authLevel, + path: path, + memo: memo, + updater: &websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}, + actions: make(map[string]*websocketActionType), + } + + if onOpen != nil { + s.openFuncValue = reflect.ValueOf(onOpen) + s.openFuncType = s.openFuncValue.Type() + if s.openFuncType.NumOut() > 0 { + s.sessionType = s.openFuncType.Out(0) + } + } + + if onClose != nil { + s.closeFuncValue = reflect.ValueOf(onClose) + s.closeFuncType = s.closeFuncValue.Type() + } + + websocketServicesLock.Lock() + websocketServices[path] = s + websocketServicesLock.Unlock() + + return &ActionRegister{ws: s} +} + +// RegisterAction 注册 WebSocket Action +func (ar *ActionRegister) RegisterAction(authLevel int, name string, action any, memo string) { + v := reflect.ValueOf(action) + t := v.Type() + a := &websocketActionType{ + authLevel: authLevel, + funcValue: v, + funcType: t, + memo: memo, + } + + // 查找输入参数类型 + for i := 0; i < t.NumIn(); i++ { + inT := t.In(i) + if inT.Kind() == reflect.Struct { + a.inType = inT + break + } + } + + ar.ws.actions[name] = a +} + +func doWebsocketService(ws *websocketServiceType, request *Request, response *Response, logger *log.Logger) { + conn, err := ws.updater.Upgrade(response.Writer, request.Request, nil) + if err != nil { + logger.Error("websocket upgrade failed", "error", err.Error()) + return + } + defer conn.Close() + + var session any + if ws.openFuncValue.IsValid() { + // 简化版:仅支持基础参数注入 + params := make([]reflect.Value, ws.openFuncType.NumIn()) + for i := 0; i < len(params); i++ { + t := ws.openFuncType.In(i) + if t == reflect.TypeOf(request) { + params[i] = reflect.ValueOf(request) + } else if t == reflect.TypeOf(logger) { + params[i] = reflect.ValueOf(logger) + } else { + params[i] = reflect.New(t).Elem() + } + } + outs := ws.openFuncValue.Call(params) + if len(outs) > 0 { + session = outs[0].Interface() + } + } + + for { + var msg Map + if err := conn.ReadJSON(&msg); err != nil { + break + } + + actionName := cast.String(msg["action"]) + action := ws.actions[actionName] + if action == nil { + action = ws.actions[""] // 默认 action + } + + if action != nil { + params := make([]reflect.Value, action.funcType.NumIn()) + for i := 0; i < len(params); i++ { + t := action.funcType.In(i) + if t == ws.sessionType { + params[i] = reflect.ValueOf(session) + } else if t == reflect.TypeOf(conn) { + params[i] = reflect.ValueOf(conn) + } else if t.Kind() == reflect.Struct { + in := reflect.New(t).Interface() + cast.Convert(in, msg) + params[i] = reflect.ValueOf(in).Elem() + } else { + params[i] = reflect.New(t).Elem() + } + } + outs := action.funcValue.Call(params) + if len(outs) > 0 { + result := outs[0].Interface() + if result != nil { + _ = conn.WriteJSON(result) + } + } + } + } + + if ws.closeFuncValue.IsValid() { + params := make([]reflect.Value, ws.closeFuncType.NumIn()) + for i := 0; i < len(params); i++ { + t := ws.closeFuncType.In(i) + if t == ws.sessionType { + params[i] = reflect.ValueOf(session) + } else { + params[i] = reflect.New(t).Elem() + } + } + ws.closeFuncValue.Call(params) + } +} diff --git a/websocket_test.go b/websocket_test.go new file mode 100644 index 0000000..8689ff0 --- /dev/null +++ b/websocket_test.go @@ -0,0 +1,44 @@ +package service + +import ( + "github.com/gorilla/websocket" + "net/http/httptest" + "strings" + "testing" +) + +func TestWebSocketService(t *testing.T) { + // 注册 WebSocket 服务 + ar := RegisterWebsocket(0, "/ws", nil, nil, "test websocket") + ar.RegisterAction(0, "echo", func(in struct{ Msg string }) Map { + return Map{"action": "echo", "reply": in.Msg} + }, "echo action") + + // 启动测试服务器 + server := httptest.NewServer(&routeHandler{}) + defer server.Close() + + // 建立连接 + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/ws" + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + if err != nil { + t.Fatalf("Dial failed: %v", err) + } + defer conn.Close() + + // 发送消息 + msg := Map{"action": "echo", "msg": "hello"} + if err := conn.WriteJSON(msg); err != nil { + t.Fatalf("WriteJSON failed: %v", err) + } + + // 接收响应 + var reply Map + if err := conn.ReadJSON(&reply); err != nil { + t.Fatalf("ReadJSON failed: %v", err) + } + + if reply["reply"] != "hello" { + t.Errorf("Reply mismatch: %v", reply) + } +}