Migrate service module from ssgo/s with modern Go features (by AI)

This commit is contained in:
AI Engineer 2026-05-08 07:27:06 +08:00
commit bdb104aa2f
28 changed files with 3135 additions and 0 deletions

267
.log.meta.json Normal file
View File

@ -0,0 +1,267 @@
{
"debug": [
{
"index": 0,
"name": "LogName",
"color": "cyan",
"hide": true
},
{
"index": 1,
"name": "LogType",
"color": "magenta",
"hide": true
},
{
"index": 2,
"name": "LogTime",
"format": "time"
},
{
"index": 3,
"name": "TraceId",
"color": "blue"
},
{
"index": 4,
"name": "Image",
"color": "darkGray",
"hide": true
},
{
"index": 5,
"name": "Server",
"color": "darkGray",
"hide": true
},
{
"index": 6,
"name": "Debug",
"withoutKey": true
},
{
"index": 7,
"name": "Extra"
}
],
"discover": [
{
"index": 0,
"name": "LogName",
"color": "cyan",
"hide": true
},
{
"index": 1,
"name": "LogType",
"color": "magenta",
"hide": true
},
{
"index": 2,
"name": "LogTime",
"format": "time"
},
{
"index": 3,
"name": "TraceId",
"color": "blue"
},
{
"index": 4,
"name": "Image",
"color": "darkGray",
"hide": true
},
{
"index": 5,
"name": "Server",
"color": "darkGray",
"hide": true
},
{
"index": 6,
"name": "App",
"color": "cyan"
},
{
"index": 7,
"name": "Method",
"color": "magenta"
},
{
"index": 8,
"name": "Path",
"color": "blue"
},
{
"index": 9,
"name": "Node",
"color": "yellow"
},
{
"index": 10,
"name": "Attempts"
},
{
"index": 11,
"name": "UsedTime",
"format": "%.2fms"
},
{
"index": 12,
"name": "Error",
"color": "red"
},
{
"index": 13,
"name": "Extra"
}
],
"error": [
{
"index": 0,
"name": "LogName",
"color": "cyan",
"hide": true
},
{
"index": 1,
"name": "LogType",
"color": "magenta",
"hide": true
},
{
"index": 2,
"name": "LogTime",
"format": "time"
},
{
"index": 3,
"name": "TraceId",
"color": "blue"
},
{
"index": 4,
"name": "Image",
"color": "darkGray",
"hide": true
},
{
"index": 5,
"name": "Server",
"color": "darkGray",
"hide": true
},
{
"index": 6,
"name": "Error",
"color": "red",
"withoutKey": true
},
{
"index": 7,
"name": "CallStacks"
},
{
"index": 8,
"name": "Extra"
}
],
"info": [
{
"index": 0,
"name": "LogName",
"color": "cyan",
"hide": true
},
{
"index": 1,
"name": "LogType",
"color": "magenta",
"hide": true
},
{
"index": 2,
"name": "LogTime",
"format": "time"
},
{
"index": 3,
"name": "TraceId",
"color": "blue"
},
{
"index": 4,
"name": "Image",
"color": "darkGray",
"hide": true
},
{
"index": 5,
"name": "Server",
"color": "darkGray",
"hide": true
},
{
"index": 6,
"name": "Info",
"color": "cyan",
"withoutKey": true
},
{
"index": 7,
"name": "Extra"
}
],
"warning": [
{
"index": 0,
"name": "LogName",
"color": "cyan",
"hide": true
},
{
"index": 1,
"name": "LogType",
"color": "magenta",
"hide": true
},
{
"index": 2,
"name": "LogTime",
"format": "time"
},
{
"index": 3,
"name": "TraceId",
"color": "blue"
},
{
"index": 4,
"name": "Image",
"color": "darkGray",
"hide": true
},
{
"index": 5,
"name": "Server",
"color": "darkGray",
"hide": true
},
{
"index": 6,
"name": "Warning",
"color": "yellow",
"withoutKey": true
},
{
"index": 7,
"name": "CallStacks"
},
{
"index": 8,
"name": "Extra"
}
]
}

216
DocTpl.html Normal file
View File

@ -0,0 +1,216 @@
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<title>{{.title}}</title>
<style>
html {
overflow: auto;
height: 100%;
}
body {
margin: 0;
padding: 10px;
background: #fff;
color: #333;
font-size: 16px;
overflow: hidden;
display: flex;
margin: 0;
padding: 0;
height: 100%;
}
.nav {
overflow-x: hidden;
overflow-y: auto;
flex: 1;
background: #333;
color: #fff;
height: 100%;
padding: 10px 0;
box-sizing: border-box;
}
.navItem {
width: 100%;
display: block;
white-space: nowrap;
overflow: hidden;
text-overflow: ellipsis;
padding: 4px 10px;
cursor: pointer;
user-select: none;
color: #fff;
text-decoration: none;
}
.navItem:hover {
background: #99ccff;
}
.navItem .memo {
font-weight: normal;
font-size: 12px;
color: #ccc;
}
.apiBox {
overflow: auto;
flex: 3;
padding: 8px;
height: 100%;
}
header {
border-bottom: #ddd 1px solid;
margin-bottom: 5px;
background: #333;
color: #fff;
padding: 12px;
display: flex;
align-items: baseline;
}
header > span {
font-weight: bold;
margin-right: 10px;
}
header > span.memo {
font-weight: normal;
font-size: 14px;
color: #faebd7;
}
label {
display: inline-block;
margin-right: 10px;
padding: 4px 8px;
font-size: 12px;
background: #ccc;
color: #000;
font-weight: bold;
border-radius: 4px;
}
label.authLevel {
background: #f90;
color: #000;
}
label.authLevel0 {
background: #ccc;
color: #000;
}
label.type {
background: #9cf;
color: #000;
}
section {
margin-bottom: 40px;
white-space: nowrap;
font-size: 12px;
}
header.Action, section.Action {
margin-left: 20px;
}
section > table {
width: 50%;
display: inline-table;
border-collapse: collapse;
vertical-align: top;
font-size: 16px;
}
section > table:last-child {
border-left: 1px solid #ddd;
}
tr:nth-child(even) {
background: #f9f9f9;
}
th {
padding: 8px;
}
td {
padding: 6px 12px;
white-space: pre-wrap;
}
td:last-child {
color: #666;
}
</style>
</head>
<body>
<div class="nav">
{{range .api}}
<a href="#{{.Path}}" class="navItem">
<span>{{.Path}}</span>
<span class="memo">{{.Memo}}</span>
</a>
{{end}}
</div>
<div class="apiBox">
{{range .api}}
<a name="{{.Path}}"></a>
<div style="height: 16px"></div>
<header class="{{.Type}}">
<span>{{.Path}}</span>
<span class="memo">{{.Memo}}</span>
{{if ne .Method ""}}<label>{{.Method}}</label>{{end}}
<label title="Auth Level" class="authLevel authLevel{{.AuthLevel}}">{{.AuthLevel}}</label>
{{if ne .Type "Web"}}<label class="type">{{.Type}}</label>{{end}}
</header>
<section class="{{.Type}}">
<table>
{{if isMap .In}}
<tr>
<th colspan="2">Request</th>
</tr>
{{range $k, $v := .In}}
<tr>
<td width="30%">{{$k}}</td>
<td width="70%">{{toText $v}}</td>
</tr>
{{end}}
{{else}}
<tr>
<td colspan="2">{{.In}}</td>
</tr>
{{end}}
</table>
<table>
{{if isMap .Out}}
<tr>
<th colspan="2">Response</th>
</tr>
{{range $k, $v := .Out}}
<tr>
<td width="30%">{{$k}}</td>
<td width="70%">{{toText $v}}</td>
</tr>
{{end}}
{{else}}
<tr>
<td colspan="2">{{.Out}}</td>
</tr>
{{end}}
</table>
</section>
{{else}}
<div><strong>no document</strong></div>
{{end}}
<div style="height: 800px"></div>
</div>
</body>
</html>

56
README.md Normal file
View File

@ -0,0 +1,56 @@
# go/service (核心微服务框架)
极简、自动化的 Web 与 WebSocket 服务框架,实现极致的依赖注入与路由映射。
## 核心特性
- **路由反射**: 自动解析函数参数,支持 `*Request`, `*Response`, `*log.Logger` 及自定义结构体自动注入。
- **自动校验**: 集成 `verify` 引擎,通过 Struct Tag 实现入参合法性自动检查。
- **功能闭环**: 内置静态文件服务、WebSocket (带 Action 路由)、URL 重写、反向代理(对接 Discover
- **零摩擦启动**: 支持命令行指令管理 (start/stop/help) 及异步平滑启停。
## API 指南
### 1. 服务注册
```go
import "apigo.cc/go/service"
// 注册标准 Web 服务
service.Register(0, "/hello", func(in struct{ Name string }) string {
return "Hello " + in.Name
}, "打招呼接口")
// 注册 Restful 服务
service.Restful(0, "POST", "/user/{id}", func(args map[string]any) service.Result {
res := service.Result{}
res.OK()
return res
}, "更新用户")
```
### 2. WebSocket 支持
```go
ar := service.RegisterWebsocket(0, "/ws", onOpen, onClose, "聊天室")
ar.RegisterAction(0, "chat", func(in ChatMessage, sess *MySession) {
// 处理消息
}, "发送消息")
```
### 3. 增强插件
- **静态文件**: `service.Static("/ui", "./static_dir")`
- **URL 重写**: `service.Rewrite("/old", "/new")`
- **反向代理**: `service.Proxy(0, "/api", "other_app", "/api")`
### 4. 生命周期管理
```go
func main() {
service.CheckCmd() // 处理 start/stop/help 指令
service.Start() // 阻塞启动
}
```
## 基础设施对齐
- **类型转换**: `apigo.cc/go/cast`
- **日志系统**: `apigo.cc/go/log`
- **服务发现**: `apigo.cc/go/discover`
- **分布式 ID**: `apigo.cc/go/id`
- **文件操作**: `apigo.cc/go/file`

62
config.go Normal file
View File

@ -0,0 +1,62 @@
package service
// CertSet SSL 证书配置
type CertSet struct {
CertFile string
KeyFile string
}
// ServiceConfig 核心服务配置
type ServiceConfig struct {
Listen string // 监听端口(|隔开多个监听)(,隔开多个选项),例如 80,http|443|443:h2|127.0.0.1:8080,h2c
SSL map[string]*CertSet // SSL 证书配置key 为域名
NoLogGets bool // 不记录 GET 请求的日志
NoLogHeaders string // 不记录请求头中包含的这些字段,多个字段用逗号分隔
LogInputArrayNum int // 请求字段中容器类型在日志打印个数限制
LogInputFieldSize int // 请求字段中单个字段在日志打印长度限制
NoLogOutputFields string // 不记录响应字段中包含的这些字段
LogOutputArrayNum int // 响应字段中容器类型在日志打印个数限制
LogOutputFieldSize int // 响应字段中单个字段在日志打印长度限制
LogWebsocketAction bool // 记录 Websocket 中每个 Action 的请求日志
Compress bool // 是否启用压缩
CompressMinSize int // 启用压缩的最小长度
CompressMaxSize int // 启用压缩的最大长度
CheckDomain string // 心跳检测时使用域名
AccessTokens map[string]*int // 指定 Access-Token 验证及其对应的 auth-level
RedirectTimeout int // Proxy 和 Discover 发起请求时的超时时间 (ms)
AcceptXRealIpWithoutRequestId bool // 是否允许头部没有携带请求ID的 X-Real-IP 信息
StatisticTime bool // 是否开启请求时间统计
StatisticTimeInterval int // 统计时间间隔 (ms)
Fast bool // 是否启用快速模式
MaxUploadSize int64 // 最大上传文件大小 (Bytes)
IpPrefix string // Discover 服务发现时指定使用的 IP 网段
Cpu int // CPU 占用的核数限制
Memory int // 内存限制 (MB)
CpuMonitor bool // 记录 CPU 使用情况
MemoryMonitor bool // 记录内存使用情况
CpuLimitValue uint // CPU 自动重启阈值 (10-100)
MemoryLimitValue uint // 内存自动重启阈值 (10-100)
CpuLimitTimes uint // CPU 报警阈值连续次数
MemoryLimitTimes uint // 内存报警阈值连续次数
CookieScope string // Session Cookie 有效范围: host|domain|topDomain
SessionWithoutCookie bool // Session 禁用 Cookie
DeviceWithoutCookie bool // 设备ID禁用 Cookie
IdServer string // Redis 服务器连接 (用于全局唯一 ID 生成)
KeepKeyCase bool // 是否保持 Key 的首字母大小写
IndexFiles []string // 静态文件索引文件
IndexDir bool // 访问目录时显示文件列表
ReadTimeout int // 读取请求的超时时间 (ms)
ReadHeaderTimeout int // 读取请求头的超时时间 (ms)
WriteTimeout int // 响应写入的超时时间 (ms)
IdleTimeout int // 连接空闲超时时间 (ms)
MaxHeaderBytes int // 请求头的最大字节数
MaxHandlers int // 每个连接的最大处理程序数量
MaxConcurrentStreams uint32 // 每个连接的最大并发流数量
MaxDecoderHeaderTableSize uint32 // 解码器头表的最大大小
MaxEncoderHeaderTableSize uint32 // 编码器头表的最大大小
MaxReadFrameSize uint32 // 单个帧的最大读取大小
MaxUploadBufferPerConnection int32 // 每个连接的最大上传缓冲区大小
MaxUploadBufferPerStream int32 // 每个流的最大上传缓冲区大小
}
var Config = ServiceConfig{}

162
document.go Normal file
View File

@ -0,0 +1,162 @@
package service
import (
"apigo.cc/go/cast"
_ "embed"
"encoding/json"
"reflect"
)
// Api 接口文档信息
type Api struct {
Type string
Path string
AuthLevel int
Method string
In any
Out any
Memo string
}
//go:embed DocTpl.html
var defaultDocTpl string
// MakeDocument 生成文档数据
func MakeDocument() []Api {
out := make([]Api, 0)
// 1. Rewrite
rewritesLock.RLock()
for _, a := range rewrites {
out = append(out, Api{
Type: "Rewrite",
Path: a.fromPath + " -> " + a.toPath,
})
}
rewritesLock.RUnlock()
// 2. Proxy
proxiesLock.RLock()
for _, a := range proxies {
out = append(out, Api{
Type: "Proxy",
Path: a.fromPath + " -> " + a.toApp + ":" + a.toPath,
})
}
proxiesLock.RUnlock()
// 3. Web Services
webServicesLock.RLock()
for _, a := range webServicesList {
if a.options.NoDoc {
continue
}
api := Api{
Type: "Web",
Path: a.path,
AuthLevel: a.authLevel,
Method: a.method,
Memo: a.memo,
}
if a.inType != nil {
api.In = getType(a.inType)
}
if a.funcType.NumOut() > 0 {
api.Out = getType(a.funcType.Out(0))
}
out = append(out, api)
}
webServicesLock.RUnlock()
// 4. WebSocket Services
websocketServicesLock.RLock()
for _, a := range websocketServices {
api := Api{
Type: "WebSocket",
Path: a.path,
AuthLevel: a.authLevel,
Memo: a.memo,
}
if a.openFuncType != nil && a.openFuncType.NumIn() > 0 {
// Find struct in
for i := 0; i < a.openFuncType.NumIn(); i++ {
t := a.openFuncType.In(i)
if t.Kind() == reflect.Struct {
api.In = getType(t)
break
}
}
}
out = append(out, api)
for name, act := range a.actions {
actionApi := Api{
Type: "Action",
Path: name,
AuthLevel: act.authLevel,
Memo: act.memo,
}
if act.inType != nil {
actionApi.In = getType(act.inType)
}
if act.funcType.NumOut() > 0 {
actionApi.Out = getType(act.funcType.Out(0))
}
out = append(out, actionApi)
}
}
websocketServicesLock.RUnlock()
return out
}
// MakeJsonDocument 生成 JSON 格式文档
func MakeJsonDocument() string {
apis := MakeDocument()
data, _ := json.MarshalIndent(map[string]any{
"api": apis,
}, "", "\t")
return string(data)
}
func getType(t reflect.Type) any {
if t == nil {
return ""
}
for t.Kind() == reflect.Ptr {
t = t.Elem()
}
switch t.Kind() {
case reflect.Struct:
outs := Map{}
for i := 0; i < t.NumField(); i++ {
f := t.Field(i)
if f.Anonymous {
if subMap, ok := getType(f.Type).(Map); ok {
for k, v := range subMap {
outs[k] = v
}
}
} else {
outs[cast.GetLowerName(f.Name)] = getType(f.Type)
}
}
return outs
case reflect.Map:
return map[string]any{t.Key().String(): getType(t.Elem())}
case reflect.Slice:
return []any{getType(t.Elem())}
case reflect.Interface:
return "Any"
default:
return t.String()
}
}
// 自动注册文档服务
func init() {
Register(0, "/__DOC__", func() string {
return MakeJsonDocument()
}, "API Document")
}

5
go.mod Normal file
View File

@ -0,0 +1,5 @@
module apigo.cc/go/service
go 1.25.0
require github.com/gorilla/websocket v1.5.3

2
go.sum Normal file
View File

@ -0,0 +1,2 @@
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=

319
handler.go Normal file
View File

@ -0,0 +1,319 @@
package service
import (
"apigo.cc/go/cast"
"apigo.cc/go/id"
"apigo.cc/go/log"
"apigo.cc/go/standard"
"encoding/json"
"io"
"net/http"
"reflect"
"strings"
"sync/atomic"
"time"
)
type routeHandler struct {
webRequestingNum int64
}
func (rh *routeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
atomic.AddInt64(&rh.webRequestingNum, 1)
defer atomic.AddInt64(&rh.webRequestingNum, -1)
startTime := time.Now()
requestId := r.Header.Get(standard.DiscoverHeaderRequestId)
if requestId == "" {
requestId = id.MakeID(12)
r.Header.Set(standard.DiscoverHeaderRequestId, requestId)
}
request := NewRequest(r)
request.Id = requestId
response := NewResponse(w)
response.Id = requestId
defer response.checkWriteHeader()
// 处理 SessionId 和 DeviceId
handleClientKeys(request, response)
requestLogger := log.New(requestId)
// 0. 处理重写 (Rewrite)
if processRewrite(request, response, requestLogger) {
return
}
// 处理代理 (Proxy)
if processProxy(request, response, requestLogger) {
return
}
// 1. 路由匹配
path := r.URL.Path
host := r.Host
// 处理静态文件
if processStatic(path, request, response, requestLogger) {
return
}
s, ws := findService(r.Method, host, path)
// 2. 参数解析 (Form & Body)
args := make(map[string]any)
parseRequestArgs(request, args)
// 3. 前置过滤器
var result any
for _, filter := range inFilters {
result = filter(&args, request, response, requestLogger)
if result != nil {
break
}
}
// 4. 处理业务执行 (WS 或 Web)
if result == nil {
if ws != nil {
doWebsocketService(ws, request, response, requestLogger)
return
} else if s != nil {
// 鉴权
pass, obj := checkAuth(s, request, response, args, requestLogger)
if !pass {
if !response.changed {
response.WriteHeader(http.StatusForbidden)
}
return
}
// 执行业务
result = doWebService(s, request, response, args, nil, requestLogger, obj)
}
}
if s == nil && result == nil {
response.WriteHeader(http.StatusNotFound)
return
}
// 5. 后置过滤器
for _, filter := range outFilters {
newResult, done := filter(args, request, response, result, requestLogger)
if newResult != nil {
result = newResult
}
if done {
break
}
}
// 6. 输出结果
outputResult(response, result)
// 7. 记录日志
_ = startTime
}
func findService(method, host, path string) (*webServiceType, *websocketServiceType) {
webServicesLock.RLock()
defer webServicesLock.RUnlock()
// 1. Web Service 匹配
if s, exists := webServices[method+path]; exists {
return s, nil
}
if s, exists := webServices[path]; exists {
return s, nil
}
// 2. WebSocket 匹配
websocketServicesLock.RLock()
defer websocketServicesLock.RUnlock()
if ws, exists := websocketServices[path]; exists {
return nil, ws
}
// 3. 正则匹配
for i := len(regexWebServices) - 1; i >= 0; i-- {
s := regexWebServices[i]
if s.method != "" && s.method != method {
continue
}
if s.pathMatcher != nil && s.pathMatcher.MatchString(path) {
return s, nil
}
}
return nil, nil
}
func parseRequestArgs(request *Request, args map[string]any) {
// Query params
query := request.URL.Query()
for k, v := range query {
if len(v) == 1 {
args[k] = v[0]
} else {
args[k] = v
}
}
// Form params
if request.Method == http.MethodPost || request.Method == http.MethodPut {
contentType := request.Header.Get("Content-Type")
if strings.HasPrefix(contentType, "application/json") {
body, _ := io.ReadAll(request.Body)
_ = request.Body.Close()
if len(body) > 0 {
_ = json.Unmarshal(body, &args)
}
} else {
_ = request.ParseForm()
for k, v := range request.Form {
if len(v) == 1 {
args[k] = v[0]
} else {
args[k] = v
}
}
}
}
}
func checkAuth(s *webServiceType, request *Request, response *Response, args map[string]any, logger *log.Logger) (bool, any) {
ac := webAuthCheckers[s.authLevel]
if ac == nil {
ac = webAuthChecker
}
if ac == nil {
return true, nil
}
return ac(s.authLevel, logger, &request.RequestURI, args, request, response, &s.options)
}
func doWebService(service *webServiceType, request *Request, response *Response, args map[string]any,
result any, logger *log.Logger, object any) any {
if result != nil {
return result
}
params := make([]reflect.Value, service.parmsNum)
for i := 0; i < service.parmsNum; i++ {
t := service.funcType.In(i)
switch i {
case service.requestIndex:
params[i] = reflect.ValueOf(request)
case service.httpRequestIndex:
params[i] = reflect.ValueOf(request.Request)
case service.responseIndex:
params[i] = reflect.ValueOf(response)
case service.responseWriterIndex:
params[i] = reflect.ValueOf(response.Writer)
case service.loggerIndex:
params[i] = reflect.ValueOf(logger)
case service.inIndex:
in := reflect.New(service.inType).Interface()
cast.Convert(in, args)
// 参数校验
if service.inType.Kind() == reflect.Struct {
if ok, _ := VerifyStruct(in, logger); !ok {
response.WriteHeader(http.StatusBadRequest)
return "parameter verification failed"
}
}
params[i] = reflect.ValueOf(in).Elem()
default:
// 尝试依赖注入
if obj := GetInject(t); obj != nil {
params[i] = reflect.ValueOf(obj)
} else {
params[i] = reflect.New(t).Elem()
}
}
}
outs := service.funcValue.Call(params)
if len(outs) > 0 {
return outs[0].Interface()
}
return ""
}
func outputResult(response *Response, result any) {
if result == nil {
return
}
var data []byte
contentType := ""
switch v := result.(type) {
case string:
data = []byte(v)
case []byte:
data = v
default:
data, _ = cast.ToJSONBytes(result)
contentType = "application/json; charset=UTF-8"
}
if contentType != "" && response.Header().Get("Content-Type") == "" {
response.Header().Set("Content-Type", contentType)
}
_, _ = response.Write(data)
}
func handleClientKeys(request *Request, response *Response) {
// SessionId
if usedSessionIdKey != "" {
sessionId := request.Header.Get(usedSessionIdKey)
if sessionId == "" && !Config.SessionWithoutCookie {
if ck, err := request.Cookie(usedSessionIdKey); err == nil {
sessionId = ck.Value
}
}
if sessionId == "" {
if sessionIdMaker != nil {
sessionId = sessionIdMaker()
} else {
sessionId = id.MakeID(14)
}
if !Config.SessionWithoutCookie {
http.SetCookie(response.Writer, &http.Cookie{
Name: usedSessionIdKey,
Value: sessionId,
Path: "/",
HttpOnly: true,
})
}
}
request.Header.Set(standard.DiscoverHeaderSessionId, sessionId)
response.Header().Set(usedSessionIdKey, sessionId)
}
// DeviceId
if usedDeviceIdKey != "" {
deviceId := request.Header.Get(usedDeviceIdKey)
if deviceId == "" && !Config.DeviceWithoutCookie {
if ck, err := request.Cookie(usedDeviceIdKey); err == nil {
deviceId = ck.Value
}
}
if deviceId == "" {
deviceId = id.MakeID(14)
if !Config.DeviceWithoutCookie {
http.SetCookie(response.Writer, &http.Cookie{
Name: usedDeviceIdKey,
Value: deviceId,
Path: "/",
Expires: time.Now().AddDate(10, 0, 0),
HttpOnly: true,
})
}
}
request.Header.Set(standard.DiscoverHeaderDeviceId, deviceId)
response.Header().Set(usedDeviceIdKey, deviceId)
}
}

67
handler_test.go Normal file
View File

@ -0,0 +1,67 @@
package service
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
)
func TestServeHTTP(t *testing.T) {
// 注册服务
handler := func(in struct{ Name string }) string {
return "Hello " + in.Name
}
Register(0, "/hello", handler, "say hello")
rh := &routeHandler{}
// 模拟请求
req := httptest.NewRequest("POST", "/hello", strings.NewReader(`{"name":"Star"}`))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
rh.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", w.Code)
}
body := w.Body.String()
if body != "Hello Star" {
t.Errorf("Expected 'Hello Star', got '%s'", body)
}
}
func TestServeHTTP_404(t *testing.T) {
rh := &routeHandler{}
req := httptest.NewRequest("GET", "/notfound", nil)
w := httptest.NewRecorder()
rh.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Errorf("Expected status 404, got %d", w.Code)
}
}
func TestServeHTTP_VerifyFailed(t *testing.T) {
type ValidIn struct {
Age int `verify:"between:18-100"`
}
handler := func(in ValidIn) string {
return "ok"
}
Register(0, "/verify", handler, "test verify")
rh := &routeHandler{}
req := httptest.NewRequest("POST", "/verify", strings.NewReader(`{"age":10}`))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
rh.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("Expected status 400, got %d", w.Code)
}
}

171
proxy.go Normal file
View File

@ -0,0 +1,171 @@
package service
import (
"apigo.cc/go/discover"
gohttp "apigo.cc/go/http"
"apigo.cc/go/log"
"fmt"
"io"
"net/http"
"regexp"
"strings"
"sync"
"time"
)
type proxyInfo struct {
matcher *regexp.Regexp
authLevel int
fromPath string
toApp string
toPath string
}
var (
proxies = make(map[string]*proxyInfo)
regexProxies = make([]*proxyInfo, 0)
proxyBy func(*Request) (int, *string, *string, map[string]string)
proxiesLock = sync.RWMutex{}
httpClientPool *gohttp.Client
)
// Proxy 注册代理规则
func Proxy(authLevel int, path string, toApp, toPath string) {
p := &proxyInfo{authLevel: authLevel, fromPath: path, toApp: toApp, toPath: toPath}
if strings.Contains(path, "(") {
matcher, err := regexp.Compile("^" + path + "$")
if err == nil {
p.matcher = matcher
proxiesLock.Lock()
regexProxies = append(regexProxies, p)
proxiesLock.Unlock()
}
} else {
proxiesLock.Lock()
proxies[path] = p
proxiesLock.Unlock()
}
}
// SetProxyBy 设置动态代理函数
func SetProxyBy(by func(request *Request) (authLevel int, toApp, toPath *string, headers map[string]string)) {
proxyBy = by
}
func findProxy(request *Request) (int, *string, *string) {
requestPath := request.RequestURI
queryString := ""
if pos := strings.Index(requestPath, "?"); pos != -1 {
queryString = requestPath[pos:]
requestPath = requestPath[:pos]
}
proxiesLock.RLock()
defer proxiesLock.RUnlock()
if pi, ok := proxies[requestPath]; ok {
toPath := pi.toPath + queryString
return pi.authLevel, &pi.toApp, &toPath
}
for _, pi := range regexProxies {
if pi.matcher != nil {
finds := pi.matcher.FindAllStringSubmatch(requestPath, 1)
if len(finds) > 0 {
toApp := pi.toApp
toPath := pi.toPath
for i, part := range finds[0] {
toApp = strings.ReplaceAll(toApp, fmt.Sprintf("$%d", i), part)
toPath = strings.ReplaceAll(toPath, fmt.Sprintf("$%d", i), part)
}
toPath += queryString
return pi.authLevel, &toApp, &toPath
}
}
}
return 0, nil, nil
}
func processProxy(request *Request, response *Response, logger *log.Logger) bool {
authLevel, proxyToApp, proxyToPath := findProxy(request)
var proxyHeaders map[string]string
if proxyBy != nil && (proxyToApp == nil || proxyToPath == nil || *proxyToApp == "" || *proxyToPath == "") {
authLevel, proxyToApp, proxyToPath, proxyHeaders = proxyBy(request)
}
if proxyToApp == nil || proxyToPath == nil || *proxyToApp == "" || *proxyToPath == "" {
return false
}
// 鉴权
pass, obj := checkAuthForProxy(authLevel, request, response, logger)
if !pass {
if !response.changed {
response.WriteHeader(http.StatusForbidden)
}
return true
}
_ = obj // Currently unused in proxy
app := *proxyToApp
path := *proxyToPath
// 构建自定义头部
headerArgs := make([]string, 0)
for k, v := range proxyHeaders {
headerArgs = append(headerArgs, k, v)
}
if strings.Contains(app, "://") {
// 直接 URL 代理
if httpClientPool == nil {
httpClientPool = gohttp.NewClient(time.Duration(Config.RedirectTimeout) * time.Millisecond)
}
res := httpClientPool.ManualDoByRequest(request.Request, request.Method, app+path, request.Body, headerArgs...)
copyResponse(res, response, logger)
} else {
// Discover 代理
caller := discover.NewCaller(request.Request, logger)
caller.NoBody = true
res, _ := caller.ManualDoWithNode(request.Method, app, "", path, request.Body, headerArgs...)
copyResponse(res, response, logger)
}
return true
}
func checkAuthForProxy(authLevel int, request *Request, response *Response, logger *log.Logger) (bool, any) {
ac := webAuthCheckers[authLevel]
if ac == nil {
ac = webAuthChecker
}
if ac == nil {
return true, nil
}
return ac(authLevel, logger, &request.RequestURI, nil, request, response, nil)
}
func copyResponse(res *gohttp.Result, response *Response, logger *log.Logger) {
if res.Error != nil || res.Response == nil {
response.WriteHeader(http.StatusBadGateway)
if res.Error != nil {
_, _ = response.WriteString(res.Error.Error())
}
return
}
for k, v := range res.Response.Header {
response.Header().Set(k, v[0])
}
response.WriteHeader(res.Response.StatusCode)
if res.Response.Body != nil {
defer res.Response.Body.Close()
_, err := io.Copy(response.Writer, res.Response.Body)
if err != nil {
logger.Error("proxy copy body failed", "error", err.Error())
}
}
}

62
proxy_test.go Normal file
View File

@ -0,0 +1,62 @@
package service
import (
"net/http"
"net/http/httptest"
"testing"
)
func TestRewrite(t *testing.T) {
// 注册重写规则
Rewrite("/old", "/new")
Rewrite("/regex/(.*)", "/target/$1")
// 注册目标服务
Register(0, "/new", func() string { return "new content" }, "new")
Register(0, "/target/123", func() string { return "target content" }, "target")
rh := &routeHandler{}
// 测试精确匹配重写
req1 := httptest.NewRequest("GET", "/old", nil)
w1 := httptest.NewRecorder()
rh.ServeHTTP(w1, req1)
if w1.Body.String() != "new content" {
t.Errorf("Expected 'new content', got '%s'", w1.Body.String())
}
// 测试正则匹配重写
req2 := httptest.NewRequest("GET", "/regex/123", nil)
w2 := httptest.NewRecorder()
rh.ServeHTTP(w2, req2)
if w2.Body.String() != "target content" {
t.Errorf("Expected 'target content', got '%s'", w2.Body.String())
}
}
func TestProxyDirect(t *testing.T) {
// 启动后端服务器
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Backend", "ok")
w.Write([]byte("backend content"))
}))
defer backend.Close()
// 注册代理规则
Proxy(0, "/proxy", backend.URL, "/hello")
rh := &routeHandler{}
req := httptest.NewRequest("GET", "/proxy", nil)
w := httptest.NewRecorder()
rh.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected 200, got %d", w.Code)
}
if w.Header().Get("X-Backend") != "ok" {
t.Error("Header X-Backend mismatch")
}
if w.Body.String() != "backend content" {
t.Errorf("Expected 'backend content', got '%s'", w.Body.String())
}
}

139
request.go Normal file
View File

@ -0,0 +1,139 @@
package service
import (
"apigo.cc/go/cast"
"apigo.cc/go/standard"
"io"
"mime/multipart"
"net"
"net/http"
"net/textproto"
"net/url"
"os"
"path/filepath"
)
// UploadFile 上传文件结构
type UploadFile struct {
fileHeader *multipart.FileHeader
Filename string
Header textproto.MIMEHeader
Size int64
}
// Open 打开上传文件
func (f *UploadFile) Open() (multipart.File, error) {
return f.fileHeader.Open()
}
// Save 保存上传文件到本地
func (f *UploadFile) Save(filename string) error {
dir := filepath.Dir(filename)
if _, err := os.Stat(dir); os.IsNotExist(err) {
_ = os.MkdirAll(dir, 0755)
}
dst, err := os.OpenFile(filename, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
if err != nil {
return err
}
defer dst.Close()
src, err := f.fileHeader.Open()
if err != nil {
return err
}
defer src.Close()
_, err = io.Copy(dst, src)
return err
}
// Content 获取上传文件内容
func (f *UploadFile) Content() ([]byte, error) {
src, err := f.fileHeader.Open()
if err != nil {
return nil, err
}
defer src.Close()
return io.ReadAll(src)
}
// Request 封装 http.Request
type Request struct {
*http.Request
contextValues map[string]any
Id string
}
// NewRequest 创建 Request 包装
func NewRequest(httpRequest *http.Request) *Request {
return &Request{
Request: httpRequest,
contextValues: make(map[string]any),
}
}
// ResetPath 重写请求路径
func (r *Request) ResetPath(path string) {
r.RequestURI = path
if u, err := url.Parse(path); err == nil {
r.URL = u
}
}
// Set 设置请求上下文变量
func (r *Request) Set(key string, value any) {
r.contextValues[key] = value
}
// Get 获取请求上下文变量
func (r *Request) Get(key string) any {
return r.contextValues[key]
}
// MakeUrl 根据当前请求构建完整 URL
func (r *Request) MakeUrl(path string) string {
scheme := r.Header.Get(standard.DiscoverHeaderScheme)
if scheme == "" {
scheme = "http"
}
host := r.Header.Get(standard.DiscoverHeaderHost)
if host == "" {
host = r.Host
}
return scheme + "://" + host + path
}
// GetSessionId 获取会话 ID
func (r *Request) GetSessionId() string {
sessionId := r.Header.Get(Config.Listen) // Wait, this should be usedSessionIdKey
// TODO: Fix dependency on global usedSessionIdKey
return sessionId
}
// SetUserId 设置用户 ID传递给下游
func (r *Request) SetUserId(userId string) {
r.Header.Set(standard.DiscoverHeaderUserId, userId)
}
// GetRealIp 获取真实 IP
func (r *Request) GetRealIp() string {
ip := r.Header.Get(standard.DiscoverHeaderClientIp)
if ip == "" {
ip = r.Header.Get(standard.DiscoverHeaderForwardedFor)
}
if ip == "" {
host, _, err := net.SplitHostPort(r.RemoteAddr)
if err == nil {
return host
}
return r.RemoteAddr
}
return ip
}
// GetLowerName (Aliased from cast)
func GetLowerName(s string) string {
return cast.GetLowerName(s)
}

152
response.go Normal file
View File

@ -0,0 +1,152 @@
package service
import (
"apigo.cc/go/cast"
"io"
"net/http"
"os"
)
// Response 封装 http.ResponseWriter
type Response struct {
Id string
Writer http.ResponseWriter
status int
outLen int
changed bool
headerWritten bool
dontLog200 bool
dontLogArgs []string
ProxyHeader *http.Header
}
// NewResponse 创建 Response 包装
func NewResponse(writer http.ResponseWriter) *Response {
return &Response{
Writer: writer,
status: http.StatusOK,
}
}
// Header 获取响应头部
func (r *Response) Header() http.Header {
r.changed = true
if r.ProxyHeader != nil {
return *r.ProxyHeader
}
return r.Writer.Header()
}
// Write 写入响应内容
func (r *Response) Write(bytes []byte) (int, error) {
r.checkWriteHeader()
r.changed = true
r.outLen += len(bytes)
if r.ProxyHeader != nil {
r.copyProxyHeader()
}
return r.Writer.Write(bytes)
}
// WriteString 写入字符串响应
func (r *Response) WriteString(s string) (int, error) {
return r.Write([]byte(s))
}
// WriteHeader 设置响应状态码
func (r *Response) WriteHeader(code int) {
r.changed = true
r.status = code
if r.ProxyHeader != nil && (r.status == http.StatusBadGateway || r.status == http.StatusServiceUnavailable || r.status == http.StatusGatewayTimeout) {
return
}
if r.ProxyHeader != nil {
r.copyProxyHeader()
}
}
func (r *Response) checkWriteHeader() {
if !r.headerWritten {
r.headerWritten = true
if r.status != http.StatusOK {
r.Writer.WriteHeader(r.status)
}
}
}
func (r *Response) copyProxyHeader() {
src := *r.ProxyHeader
dst := r.Writer.Header()
for k, vv := range src {
for _, v := range vv {
dst.Add(k, v)
}
}
r.ProxyHeader = nil
}
// Flush 刷新响应缓冲区
func (r *Response) Flush() {
if flusher, ok := r.Writer.(http.Flusher); ok {
flusher.Flush()
}
}
// GetStatusCode 获取当前状态码
func (r *Response) GetStatusCode() int {
return r.status
}
// DontLog200 标记不记录 200 状态码的日志
func (r *Response) DontLog200() {
r.dontLog200 = true
}
// Location 设置重定向地址
func (r *Response) Location(location string) {
r.WriteHeader(http.StatusFound)
r.Header().Set("Location", location)
}
// SendFile 发送文件
func (r *Response) SendFile(contentType, filename string) {
r.Header().Set("Content-Type", contentType)
// TODO: Integrate memory file support if needed
if fd, err := os.Open(filename); err == nil {
defer fd.Close()
_, _ = io.Copy(r, fd)
}
}
// DownloadFile 下载文件
func (r *Response) DownloadFile(contentType, filename string, data any) {
if contentType == "" {
contentType = "application/octet-stream"
}
r.Header().Set("Content-Type", contentType)
if filename != "" {
r.Header().Set("Content-Disposition", "attachment; filename="+filename)
}
var outBytes []byte
var reader io.Reader
switch v := data.(type) {
case []byte:
outBytes = v
case string:
outBytes = []byte(v)
case io.Reader:
reader = v
default:
outBytes, _ = cast.ToJSONBytes(data)
}
if outBytes != nil {
r.Header().Set("Content-Length", cast.String(len(outBytes)))
_, _ = r.Write(outBytes)
} else if reader != nil {
_, _ = io.Copy(r, reader)
}
}

113
rewrite.go Normal file
View File

@ -0,0 +1,113 @@
package service
import (
"apigo.cc/go/log"
"fmt"
"net/url"
"regexp"
"strings"
"sync"
)
type rewriteInfo struct {
matcher *regexp.Regexp
fromPath string
toPath string
}
var (
rewrites = make(map[string]*rewriteInfo)
regexRewrites = make([]*rewriteInfo, 0)
rewriteBy func(*Request) (string, bool)
rewritesLock = sync.RWMutex{}
)
// Rewrite 注册重写规则
func Rewrite(path string, toPath string) {
s := &rewriteInfo{fromPath: path, toPath: toPath}
if strings.ContainsRune(path, '(') {
matcher, err := regexp.Compile("^" + path + "$")
if err == nil {
s.matcher = matcher
rewritesLock.Lock()
regexRewrites = append(regexRewrites, s)
rewritesLock.Unlock()
}
} else {
rewritesLock.Lock()
rewrites[path] = s
rewritesLock.Unlock()
}
}
// SetRewriteBy 设置动态重写函数
func SetRewriteBy(by func(request *Request) (toPath string, rewrite bool)) {
rewriteBy = by
}
func processRewrite(request *Request, response *Response, logger *log.Logger) bool {
requestPath := request.RequestURI
queryString := ""
if pos := strings.Index(requestPath, "?"); pos != -1 {
queryString = requestPath[pos:]
requestPath = requestPath[:pos]
}
var rewriteToPath string
var found bool
rewritesLock.RLock()
// 1. 精确匹配
if ri, ok := rewrites[requestPath]; ok {
rewriteToPath = ri.toPath
found = true
}
// 2. 动态重写
if !found && rewriteBy != nil {
rewriteToPath, found = rewriteBy(request)
}
// 3. 正则匹配
if !found {
for _, ri := range regexRewrites {
if ri.matcher != nil {
finds := ri.matcher.FindAllStringSubmatch(request.RequestURI, 1)
if len(finds) > 0 {
toPath := ri.toPath
for i, part := range finds[0] {
toPath = strings.ReplaceAll(toPath, fmt.Sprintf("$%d", i), part)
}
rewriteToPath = toPath
found = true
break
}
}
}
}
rewritesLock.RUnlock()
if found {
if strings.Contains(rewriteToPath, "://") {
// 外部重定向
if !strings.Contains(rewriteToPath, "?") && queryString != "" {
rewriteToPath += queryString
}
response.Header().Set("Location", rewriteToPath)
response.WriteHeader(302)
return true
} else {
// 内部重写
logger.Info("rewrite", "from", request.RequestURI, "to", rewriteToPath)
if queryString != "" && !strings.Contains(rewriteToPath, "?") {
rewriteToPath += queryString
}
request.RequestURI = rewriteToPath
request.URL, _ = url.Parse(rewriteToPath)
return false // 继续后续处理
}
}
return false
}

88
server.go Normal file
View File

@ -0,0 +1,88 @@
package service
import (
"apigo.cc/go/log"
"context"
"net"
"net/http"
"os"
"os/signal"
"syscall"
"time"
)
// AsyncServer 异步服务实例
type AsyncServer struct {
server *http.Server
listener net.Listener
Addr string
stopChan chan os.Signal
startChan chan bool
}
// AsyncStart 异步启动服务
func AsyncStart() *AsyncServer {
as := &AsyncServer{
startChan: make(chan bool, 1),
stopChan: make(chan os.Signal, 1),
}
go as.start()
<-as.startChan
return as
}
func (as *AsyncServer) start() {
if Config.Listen == "" {
Config.Listen = ":8080" // 默认端口
}
listener, err := net.Listen("tcp", Config.Listen)
if err != nil {
log.DefaultLogger.Error("failed to listen", "addr", Config.Listen, "error", err.Error())
as.startChan <- false
return
}
as.listener = listener
as.Addr = listener.Addr().String()
serverAddr = as.Addr
as.server = &http.Server{
Handler: &routeHandler{},
}
signal.Notify(as.stopChan, os.Interrupt, syscall.SIGTERM)
go func() {
log.DefaultLogger.Info("service starting", "addr", as.Addr)
as.startChan <- true
if err := as.server.Serve(listener); err != nil && err != http.ErrServerClosed {
log.DefaultLogger.Error("server error", "error", err.Error())
}
}()
}
// Stop 停止服务
func (as *AsyncServer) Stop() {
log.DefaultLogger.Info("service stopping")
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := as.server.Shutdown(ctx); err != nil {
log.DefaultLogger.Error("server shutdown error", "error", err.Error())
}
log.DefaultLogger.Info("service stopped")
}
// Wait 等待服务结束 (信号监听)
func (as *AsyncServer) Wait() {
<-as.stopChan
as.Stop()
}
// Start 同步启动服务
func Start() {
AsyncStart().Wait()
}

31
server_test.go Normal file
View File

@ -0,0 +1,31 @@
package service
import (
"net/http"
"testing"
)
func TestAsyncServer(t *testing.T) {
Config.Listen = ":0" // 随机端口
as := AsyncStart()
if as.Addr == "" {
t.Fatal("AsyncStart failed to get address")
}
// 测试服务是否可用
resp, err := http.Get("http://" + as.Addr + "/__CHECK__")
if err == nil {
// 虽然没有注册 /__CHECK__但应该返回 404 而非连接拒绝
if resp.StatusCode != http.StatusNotFound {
t.Errorf("Expected 404, got %d", resp.StatusCode)
}
}
as.Stop()
// 确认服务已关闭
_, err = http.Get("http://" + as.Addr + "/__CHECK__")
if err == nil {
t.Error("Server should be closed")
}
}

225
service.go Normal file
View File

@ -0,0 +1,225 @@
package service
import (
"apigo.cc/go/log"
"errors"
"reflect"
"regexp"
"strings"
"sync"
)
// WebServiceOptions 服务注册选项
type WebServiceOptions struct {
Priority int
NoDoc bool
NoBody bool
NoLog200 bool
Host string
Ext Map
// Limiters []*Limiter // TODO: Integrate Limiter
}
// webServiceType 内部存储的服务元数据
type webServiceType struct {
authLevel int
method string
path string
pathMatcher *regexp.Regexp
pathArgs []string
parmsNum int
inType reflect.Type
inIndex int
headersType reflect.Type
headersIndex int
requestIndex int
httpRequestIndex int
responseIndex int
responseWriterIndex int
loggerIndex int
callerIndex int
funcType reflect.Type
funcValue reflect.Value
options WebServiceOptions
data Map
memo string
}
var (
serverId string
serverAddr string
serverProto = "http"
serverProtoName = "http"
running = false
webServices = make(map[string]*webServiceType)
regexWebServices = make([]*webServiceType, 0)
webServicesLock = sync.RWMutex{}
webServicesList = make([]*webServiceType, 0)
websocketServices = make(map[string]*websocketServiceType)
regexWebsocketServices = make([]*websocketServiceType, 0)
websocketServicesLock = sync.RWMutex{}
websocketServicesList = make([]*websocketServiceType, 0)
// 过滤器与拦截器
inFilters = make([]func(*map[string]any, *Request, *Response, *log.Logger) any, 0)
outFilters = make([]func(map[string]any, *Request, *Response, any, *log.Logger) (any, bool), 0)
errorHandle func(any, *Request, *Response) any
webAuthChecker func(int, *log.Logger, *string, map[string]any, *Request, *Response, *WebServiceOptions) (pass bool, object any)
webAuthCheckers = make(map[int]func(int, *log.Logger, *string, map[string]any, *Request, *Response, *WebServiceOptions) (pass bool, object any))
// 注入点
injectObjects = make(map[reflect.Type]any)
injectFunctions = make(map[reflect.Type]func() any)
usedDeviceIdKey string
usedClientAppKey string
usedSessionIdKey string
sessionIdMaker func() string
)
// SetClientKeys 设置客户端标识相关的 Key 映射
func SetClientKeys(deviceIdKey, clientAppKey, sessionIdKey string) {
usedDeviceIdKey = deviceIdKey
usedClientAppKey = clientAppKey
usedSessionIdKey = sessionIdKey
}
// SetSessionIdMaker 设置自定义会话 ID 生成器
func SetSessionIdMaker(maker func() string) {
sessionIdMaker = maker
}
// SetAuthChecker 设置全局鉴权器
func SetAuthChecker(authChecker func(authLevel int, logger *log.Logger, url *string, in map[string]any, request *Request, response *Response, options *WebServiceOptions) (pass bool, object any)) {
webAuthChecker = authChecker
}
// AddAuthChecker 为指定级别添加鉴权器
func AddAuthChecker(authLevels []int, authChecker func(authLevel int, logger *log.Logger, url *string, in map[string]any, request *Request, response *Response, options *WebServiceOptions) (pass bool, object any)) {
for _, al := range authLevels {
webAuthCheckers[al] = authChecker
}
}
// SetInFilter 设置前置过滤器
func SetInFilter(filter func(in *map[string]any, request *Request, response *Response, logger *log.Logger) (out any)) {
inFilters = append(inFilters, filter)
}
// SetOutFilter 设置后置过滤器
func SetOutFilter(filter func(in map[string]any, request *Request, response *Response, out any, logger *log.Logger) (newOut any, isOver bool)) {
outFilters = append(outFilters, filter)
}
// Register 注册服务(通用方法)
func Register(authLevel int, path string, serviceFunc any, memo string) {
Restful(authLevel, "", path, serviceFunc, memo)
}
// Restful 注册指定方法的服务
func Restful(authLevel int, method, path string, serviceFunc any, memo string) {
RestfulWithOptions(authLevel, method, path, serviceFunc, memo, WebServiceOptions{})
}
// RestfulWithOptions 注册带选项的服务
func RestfulWithOptions(authLevel int, method, path string, serviceFunc any, memo string, options WebServiceOptions) {
s, err := makeCachedService(serviceFunc)
if err != nil {
// TODO: Log error properly when logger is ready
return
}
s.authLevel = authLevel
s.options = options
s.method = method
s.path = path
s.memo = memo
// 解析路径参数 {name}
finder, err := regexp.Compile("{(.*?)}")
if err == nil {
keyName := regexp.QuoteMeta(path)
finds := finder.FindAllStringSubmatch(path, 20)
for _, found := range finds {
keyName = strings.Replace(keyName, regexp.QuoteMeta(found[0]), "(.*?)", 1)
s.pathArgs = append(s.pathArgs, found[1])
}
if len(s.pathArgs) > 0 {
s.pathMatcher, _ = regexp.Compile("^" + keyName + "$")
}
}
webServicesLock.Lock()
defer webServicesLock.Unlock()
// 简单路径匹配
if s.pathMatcher == nil {
webServices[method+path] = s // TODO: Include Host in key
} else {
regexWebServices = append(regexWebServices, s)
}
webServicesList = append(webServicesList, s)
}
func makeCachedService(matchedService any) (*webServiceType, error) {
funcType := reflect.TypeOf(matchedService)
if funcType.Kind() != reflect.Func {
return nil, errors.New("handler must be a function")
}
targetService := &webServiceType{
parmsNum: funcType.NumIn(),
inIndex: -1,
headersIndex: -1,
requestIndex: -1,
httpRequestIndex: -1,
responseIndex: -1,
responseWriterIndex: -1,
loggerIndex: -1,
callerIndex: -1,
funcType: funcType,
funcValue: reflect.ValueOf(matchedService),
}
for i := 0; i < targetService.parmsNum; i++ {
t := funcType.In(i)
tStr := t.String()
switch tStr {
case "*service.Request":
targetService.requestIndex = i
case "*http.Request":
targetService.httpRequestIndex = i
case "*service.Response":
targetService.responseIndex = i
case "http.ResponseWriter":
targetService.responseWriterIndex = i
case "*log.Logger":
targetService.loggerIndex = i
default:
if t.Kind() == reflect.Struct || (t.Kind() == reflect.Map && t.Elem().Kind() == reflect.Interface) {
if targetService.inType == nil {
targetService.inIndex = i
targetService.inType = t
} else if targetService.headersType == nil {
targetService.headersIndex = i
targetService.headersType = t
}
}
}
}
return targetService, nil
}
// GetInject 获取注入对象
func GetInject(dataType reflect.Type) any {
if obj, exists := injectObjects[dataType]; exists {
return obj
}
if factory, exists := injectFunctions[dataType]; exists {
return factory()
}
return nil
}

54
service_test.go Normal file
View File

@ -0,0 +1,54 @@
package service
import (
"apigo.cc/go/log"
"testing"
)
func TestServiceRegister(t *testing.T) {
handler := func(req *Request, logger *log.Logger) string {
return "ok"
}
Register(0, "/test", handler, "test service")
webServicesLock.RLock()
s := webServices["/test"]
webServicesLock.RUnlock()
if s == nil {
t.Fatal("Service not registered")
}
if s.requestIndex != 0 {
t.Errorf("requestIndex mismatch: expected 0, got %d", s.requestIndex)
}
if s.loggerIndex != 1 {
t.Errorf("loggerIndex mismatch: expected 1, got %d", s.loggerIndex)
}
}
func TestRegexServiceRegister(t *testing.T) {
handler := func(args map[string]any) string {
return "ok"
}
Register(0, "/user/{id}", handler, "get user")
webServicesLock.RLock()
found := false
for _, s := range regexWebServices {
if s.path == "/user/{id}" {
found = true
if len(s.pathArgs) != 1 || s.pathArgs[0] != "id" {
t.Errorf("pathArgs mismatch: %v", s.pathArgs)
}
break
}
}
webServicesLock.RUnlock()
if !found {
t.Fatal("Regex service not registered")
}
}

49
starter.go Normal file
View File

@ -0,0 +1,49 @@
package service
import (
"fmt"
"os"
"path/filepath"
)
// StartCmd 命令行命令定义
type StartCmd struct {
Name string
Comment string
Func func()
}
var startCmds = []StartCmd{
{"start", "Start server", Start},
}
// AddCmd 添加自定义命令行命令
func AddCmd(name, comment string, function func()) {
startCmds = append(startCmds, StartCmd{name, comment, function})
}
// CheckCmd 检查并执行命令行命令
func CheckCmd() {
if len(os.Args) > 1 {
cmd := os.Args[1]
if cmd == "help" || cmd == "--help" {
showHelp()
os.Exit(0)
}
for _, cmdInfo := range startCmds {
if cmd == cmdInfo.Name {
cmdInfo.Func()
os.Exit(0)
}
}
}
}
func showHelp() {
fmt.Printf("Usage: %s [command]\n\n", filepath.Base(os.Args[0]))
fmt.Println("Available commands:")
for _, cmdInfo := range startCmds {
fmt.Printf(" %-10s %s\n", cmdInfo.Name, cmdInfo.Comment)
}
}

123
static.go Normal file
View File

@ -0,0 +1,123 @@
package service
import (
"apigo.cc/go/file"
"apigo.cc/go/log"
"mime"
"net/http"
"os"
"path/filepath"
"strings"
"sync"
"time"
)
var (
statics = make(map[string]*string)
staticsByHost = make(map[string]map[string]*string)
staticsByHostLock = sync.RWMutex{}
)
// Static 注册静态文件目录
func Static(path, rootPath string) {
StaticByHost(path, rootPath, "")
}
// StaticByHost 为指定域名注册静态文件目录
func StaticByHost(path, rootPath, host string) {
if !filepath.IsAbs(rootPath) {
if absPath, err := filepath.Abs(rootPath); err == nil {
rootPath = absPath
}
}
staticsByHostLock.Lock()
defer staticsByHostLock.Unlock()
if host == "" {
statics[path] = &rootPath
} else {
if staticsByHost[host] == nil {
staticsByHost[host] = make(map[string]*string)
}
staticsByHost[host][path] = &rootPath
}
}
func getStaticFilePath(requestPath, host string) string {
staticsByHostLock.RLock()
defer staticsByHostLock.RUnlock()
// 优先匹配指定域名的配置
if hostConfig, exists := staticsByHost[host]; exists {
if filePath := findMatchedPath(hostConfig, requestPath); filePath != "" {
return filePath
}
}
// 匹配全局配置
return findMatchedPath(statics, requestPath)
}
func findMatchedPath(config map[string]*string, requestPath string) string {
for urlPath, rootPath := range config {
if strings.HasPrefix(requestPath, urlPath) {
return filepath.Join(*rootPath, requestPath[len(urlPath):])
}
}
return ""
}
func processStatic(requestPath string, request *Request, response *Response, logger *log.Logger) bool {
filePath := getStaticFilePath(requestPath, request.Host)
if filePath == "" {
return false
}
info, err := os.Stat(filePath)
if err != nil {
return false
}
if info.IsDir() {
// 自动查找索引文件
for _, indexFile := range Config.IndexFiles {
f := filepath.Join(filePath, indexFile)
if i, err := os.Stat(f); err == nil && !i.IsDir() {
filePath = f
info = i
break
}
}
}
if info.IsDir() {
return false
}
// 检查 304
if ifModifiedSince := request.Header.Get("If-Modified-Since"); ifModifiedSince != "" {
if t, err := time.Parse(http.TimeFormat, ifModifiedSince); err == nil {
if !info.ModTime().Truncate(time.Second).After(t.Truncate(time.Second)) {
response.WriteHeader(http.StatusNotModified)
return true
}
}
}
// 发送文件
contentType := mime.TypeByExtension(filepath.Ext(filePath))
if contentType == "" {
contentType = "application/octet-stream"
}
response.Header().Set("Content-Type", contentType)
response.Header().Set("Last-Modified", info.ModTime().UTC().Format(http.TimeFormat))
data, err := file.ReadBytes(filePath)
if err != nil {
return false
}
_, _ = response.Write(data)
return true
}

43
static_test.go Normal file
View File

@ -0,0 +1,43 @@
package service
import (
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
)
func TestStaticService(t *testing.T) {
// 创建临时测试目录和文件
tempDir, _ := os.MkdirTemp("", "static_test")
defer os.RemoveAll(tempDir)
testFile := filepath.Join(tempDir, "index.html")
os.WriteFile(testFile, []byte("<h1>Static Page</h1>"), 0644)
// 注册静态目录
Static("/ui", tempDir)
rh := &routeHandler{}
// 测试成功访问
req := httptest.NewRequest("GET", "/ui/index.html", nil)
w := httptest.NewRecorder()
rh.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected 200, got %d", w.Code)
}
if body := w.Body.String(); body != "<h1>Static Page</h1>" {
t.Errorf("Content mismatch: %s", body)
}
// 测试 404
req404 := httptest.NewRequest("GET", "/ui/notfound.html", nil)
w404 := httptest.NewRecorder()
rh.ServeHTTP(w404, req404)
if w404.Code != http.StatusNotFound {
t.Errorf("Expected 404 for missing file, got %d", w404.Code)
}
}

68
types.go Normal file
View File

@ -0,0 +1,68 @@
package service
// Map 通用 Map 类型
type Map = map[string]any
// Arr 通用切片类型
type Arr = []any
// Argot 错误码/标识符类型
type Argot string
// Result 通用返回结构
type Result struct {
Ok bool `json:"ok"`
Argot Argot `json:"argot,omitempty"`
Message string `json:"message,omitempty"`
}
// CodeResult 带状态码的返回结构
type CodeResult struct {
Code int `json:"code"`
Message string `json:"message,omitempty"`
}
// ArgotInfo 标识符信息(用于文档生成)
type ArgotInfo struct {
Name Argot
Memo string
}
// OK 设置成功状态
func (r *Result) OK(argots ...Argot) {
r.Ok = true
if len(argots) > 0 {
r.Argot = argots[0]
}
}
// Failed 设置失败状态
func (r *Result) Failed(message string, argots ...Argot) {
r.Ok = false
r.Message = message
if len(argots) > 0 {
r.Argot = argots[0]
}
}
// Done 根据布尔值设置状态
func (r *Result) Done(ok bool, failedMessage string, argots ...Argot) {
r.Ok = ok
if !ok {
r.Message = failedMessage
if len(argots) > 0 {
r.Argot = argots[0]
}
}
}
// OK 设置成功状态 (Code=1)
func (r *CodeResult) OK() {
r.Code = 1
}
// Failed 设置失败状态与错误码
func (r *CodeResult) Failed(code int, message string) {
r.Code = code
r.Message = message
}

41
types_test.go Normal file
View File

@ -0,0 +1,41 @@
package service
import (
"testing"
)
func TestResult(t *testing.T) {
r := &Result{}
r.OK()
if !r.Ok {
t.Error("Result.OK() failed")
}
r.Failed("error", Argot("ERR_CODE"))
if r.Ok || r.Message != "error" || r.Argot != "ERR_CODE" {
t.Error("Result.Failed() failed")
}
r.Done(true, "never")
if !r.Ok {
t.Error("Result.Done(true) failed")
}
r.Done(false, "failed", Argot("FAIL"))
if r.Ok || r.Message != "failed" || r.Argot != "FAIL" {
t.Error("Result.Done(false) failed")
}
}
func TestCodeResult(t *testing.T) {
cr := &CodeResult{}
cr.OK()
if cr.Code != 1 {
t.Error("CodeResult.OK() failed")
}
cr.Failed(500, "internal error")
if cr.Code != 500 || cr.Message != "internal error" {
t.Error("CodeResult.Failed() failed")
}
}

20
utility.go Normal file
View File

@ -0,0 +1,20 @@
package service
import (
"apigo.cc/go/id"
)
// MakeId 生成指定长度的 ID
func MakeId(size int) string {
return id.MakeID(size)
}
// MakeIdForMysql 生成适用于 MySQL 的有序 ID
func MakeIdForMysql(size int) string {
return id.DefaultIDMaker.GetForMysql(size)
}
// MakeIdForPostgreSQL 生成适用于 PostgreSQL 的有序 ID
func MakeIdForPostgreSQL(size int) string {
return id.DefaultIDMaker.GetForPostgreSQL(size)
}

305
verify.go Normal file
View File

@ -0,0 +1,305 @@
package service
import (
"apigo.cc/go/cast"
"apigo.cc/go/log"
"reflect"
"regexp"
"strings"
"sync"
)
// VerifyType 校验类型
type VerifyType uint8
const (
VerifyUnknown VerifyType = iota
VerifyRegex
VerifyStringLength
VerifyGreaterThan
VerifyLessThan
VerifyBetween
VerifyInList
VerifyByFunc
)
// VerifySet 校验规则集
type VerifySet struct {
Type VerifyType
Regex *regexp.Regexp
StringArgs []string
IntArgs []int
FloatArgs []float64
Func func(any, []string) bool
}
var (
verifySets = make(map[string]*VerifySet)
verifySetsLock = sync.RWMutex{}
verifyFunctions = make(map[string]func(any, []string) bool)
verifyFunctionsLock = sync.RWMutex{}
)
// RegisterVerifyFunc 注册自定义校验函数
func RegisterVerifyFunc(name string, f func(in any, args []string) bool) {
verifyFunctionsLock.Lock()
verifyFunctions[name] = f
verifyFunctionsLock.Unlock()
}
// RegisterVerify 注册预定义校验规则
func RegisterVerify(name, setting string) {
set, _ := compileVerifySet(setting)
if set != nil {
verifySetsLock.Lock()
verifySets[name] = set
verifySetsLock.Unlock()
}
}
// VerifyStruct 校验结构体
func VerifyStruct(in any, logger *log.Logger) (ok bool, field string) {
v := cast.RealValue(reflect.ValueOf(in))
if v.Kind() != reflect.Struct {
if logger != nil {
logger.Error("verify input is not struct", "type", v.Type().String())
}
return false, ""
}
for i := 0; i < v.NumField(); i++ {
ft := v.Type().Field(i)
fv := v.Field(i)
// 忽略空指针、空切片、空 Map
if (fv.Kind() == reflect.Ptr && fv.IsNil()) ||
(fv.Kind() == reflect.Slice && fv.Len() == 0) ||
(fv.Kind() == reflect.Map && fv.Len() == 0) {
continue
}
if ft.Anonymous {
// 处理嵌套结构体(继承)
if fv.CanInterface() {
if ok, f := VerifyStruct(fv.Interface(), logger); !ok {
return false, f
}
}
continue
}
tag := ft.Tag.Get("verify")
keyTag := ft.Tag.Get("verifyKey")
if tag != "" || keyTag != "" {
var err error
ok, f, err := _verifyValue(fv, tag, keyTag, logger)
if !ok {
if f == "" {
f = cast.GetLowerName(ft.Name)
}
if logger != nil {
if err != nil {
logger.Error(err.Error(), "field", f)
} else {
logger.Warning("verify failed", "field", f, "tag", tag)
}
}
return false, f
}
}
}
return true, ""
}
func _verifyValue(in reflect.Value, setting, keySetting string, logger *log.Logger) (bool, string, error) {
t := in.Type()
// 处理切片 (非 byte 切片)
if t.Kind() == reflect.Slice && t.Elem().Kind() != reflect.Uint8 {
if setting != "" {
for i := 0; i < in.Len(); i++ {
if ok, f, err := _verifyValue(in.Index(i), setting, "", logger); !ok {
return false, f, err
}
}
}
return true, "", nil
}
// 处理 Map
if t.Kind() == reflect.Map {
for _, k := range in.MapKeys() {
if keySetting != "" {
if ok, _, err := _verifyValue(k, keySetting, "", logger); !ok {
return false, "key", err
}
}
if setting != "" {
if ok, f, err := _verifyValue(in.MapIndex(k), setting, "", logger); !ok {
return false, f, err
}
}
}
return true, "", nil
}
// 处理嵌套 Struct
if t.Kind() == reflect.Struct {
ok, f := VerifyStruct(in.Interface(), logger)
return ok, f, nil
}
// 基础校验
if setting == "" {
return true, "", nil
}
ok, err := verify(in.Interface(), setting)
return ok, "", err
}
func verify(in any, setting string) (bool, error) {
if len(setting) < 2 {
return false, nil
}
verifySetsLock.RLock()
set, exists := verifySets[setting]
verifySetsLock.RUnlock()
if !exists {
var err error
set, err = compileVerifySet(setting)
if err != nil {
return false, err
}
verifySetsLock.Lock()
verifySets[setting] = set
verifySetsLock.Unlock()
}
switch set.Type {
case VerifyByFunc:
return set.Func(in, set.StringArgs), nil
case VerifyRegex:
return set.Regex.MatchString(cast.String(in)), nil
case VerifyStringLength:
l := len(cast.String(in))
if len(set.StringArgs) > 0 {
if set.StringArgs[0] == "+" {
return l >= set.IntArgs[0], nil
} else if set.StringArgs[0] == "-" {
return l <= set.IntArgs[0], nil
}
}
if len(set.IntArgs) > 1 {
return l >= set.IntArgs[0] && l <= set.IntArgs[1], nil
}
return l == set.IntArgs[0], nil
case VerifyGreaterThan:
return cast.Float64(in) > set.FloatArgs[0], nil
case VerifyLessThan:
return cast.Float64(in) < set.FloatArgs[0], nil
case VerifyBetween:
val := cast.Float64(in)
return val >= set.FloatArgs[0] && val <= set.FloatArgs[1], nil
case VerifyInList:
s := cast.String(in)
for _, item := range set.StringArgs {
if item == s {
return true, nil
}
}
return false, nil
}
return false, nil
}
func compileVerifySet(setting string) (*VerifySet, error) {
set := &VerifySet{Type: VerifyUnknown}
if setting == "" {
return set, nil
}
if setting[0] != '^' {
key := setting
args := ""
if pos := strings.IndexByte(setting, ':'); pos != -1 {
key = setting[:pos]
args = setting[pos+1:]
}
// 优先查找自定义函数
verifyFunctionsLock.RLock()
f, exists := verifyFunctions[key]
verifyFunctionsLock.RUnlock()
if exists {
set.Type = VerifyByFunc
set.Func = f
if args != "" {
set.StringArgs = strings.Split(args, ",")
}
return set, nil
}
// 内置规则
switch key {
case "length":
set.Type = VerifyStringLength
if args == "" {
args = "1+"
}
last := args[len(args)-1]
if last == '+' || last == '-' {
set.StringArgs = []string{string(last)}
args = args[:len(args)-1]
}
// 同时支持逗号和中划线
sep := ","
if strings.Contains(args, "-") && !strings.Contains(args, ",") {
sep = "-"
}
if strings.Contains(args, sep) {
a := strings.Split(args, sep)
set.IntArgs = []int{cast.Int(a[0]), cast.Int(a[1])}
} else {
set.IntArgs = []int{cast.Int(args)}
}
return set, nil
case "between":
set.Type = VerifyBetween
if args == "" {
args = "1-100000000"
}
a := strings.Split(args, "-")
if len(a) == 1 {
set.FloatArgs = []float64{0, cast.Float64(a[0])}
} else {
set.FloatArgs = []float64{cast.Float64(a[0]), cast.Float64(a[1])}
}
return set, nil
case "gt":
set.Type = VerifyGreaterThan
set.FloatArgs = []float64{cast.Float64(args)}
return set, nil
case "lt":
set.Type = VerifyLessThan
set.FloatArgs = []float64{cast.Float64(args)}
return set, nil
case "in":
set.Type = VerifyInList
if args != "" {
set.StringArgs = strings.Split(args, ",")
}
return set, nil
}
}
// 默认视为正则表达式
rx, err := regexp.Compile(setting)
if err != nil {
return nil, err
}
set.Type = VerifyRegex
set.Regex = rx
return set, nil
}

76
verify_test.go Normal file
View File

@ -0,0 +1,76 @@
package service
import (
"testing"
)
type TestUser struct {
Name string `verify:"length:2-10"`
Age int `verify:"between:18-100"`
Type string `verify:"in:admin,user,guest"`
}
type NestedStruct struct {
TestUser
Note string `verify:"^.{1,20}$"`
}
func TestVerifyStruct(t *testing.T) {
u := TestUser{Name: "Star", Age: 25, Type: "admin"}
if ok, f := VerifyStruct(u, nil); !ok {
t.Errorf("VerifyStruct failed on valid user, field: %s", f)
}
u.Name = "S"
if ok, f := VerifyStruct(u, nil); ok || f != "name" {
t.Errorf("VerifyStruct should fail on short name, got ok=%v, field=%s", ok, f)
}
u.Name = "Star"
u.Age = 10
if ok, f := VerifyStruct(u, nil); ok || f != "age" {
t.Errorf("VerifyStruct should fail on young age, got ok=%v, field=%s", ok, f)
}
u.Age = 25
u.Type = "invalid"
if ok, f := VerifyStruct(u, nil); ok || f != "type" {
t.Errorf("VerifyStruct should fail on invalid type, got ok=%v, field=%s", ok, f)
}
}
func TestNestedVerify(t *testing.T) {
n := NestedStruct{
TestUser: TestUser{Name: "Star", Age: 25, Type: "user"},
Note: "Hello",
}
if ok, f := VerifyStruct(n, nil); !ok {
t.Errorf("Nested VerifyStruct failed on valid data, field: %s", f)
}
n.TestUser.Age = 5
if ok, f := VerifyStruct(n, nil); ok || f != "age" {
t.Errorf("Nested VerifyStruct should fail on nested age, got ok=%v, field=%s", ok, f)
}
}
func TestCustomVerify(t *testing.T) {
RegisterVerifyFunc("odd", func(in any, args []string) bool {
val := in.(int)
return val%2 != 0
})
type OddStruct struct {
Num int `verify:"odd"`
}
o := OddStruct{Num: 3}
if ok, f := VerifyStruct(o, nil); !ok {
t.Errorf("Custom verify failed on odd number, field: %s", f)
}
o.Num = 4
if ok, f := VerifyStruct(o, nil); ok || f != "num" {
t.Errorf("Custom verify should fail on even number, got ok=%v, field=%s", ok, f)
}
}

175
websocket.go Normal file
View File

@ -0,0 +1,175 @@
package service
import (
"apigo.cc/go/cast"
"apigo.cc/go/log"
"github.com/gorilla/websocket"
"net/http"
"reflect"
"regexp"
)
// websocketServiceType WebSocket 服务元数据
type websocketServiceType struct {
authLevel int
path string
pathMatcher *regexp.Regexp
pathArgs []string
updater *websocket.Upgrader
openFuncValue reflect.Value
openFuncType reflect.Type
closeFuncValue reflect.Value
closeFuncType reflect.Type
sessionType reflect.Type
actions map[string]*websocketActionType
isSimple bool
options WebServiceOptions
memo string
}
// websocketActionType WebSocket Action 元数据
type websocketActionType struct {
authLevel int
funcValue reflect.Value
funcType reflect.Type
inType reflect.Type
memo string
}
// ActionRegister WebSocket Action 注册器
type ActionRegister struct {
ws *websocketServiceType
}
// RegisterWebsocket 注册 WebSocket 服务
func RegisterWebsocket(authLevel int, path string, onOpen, onClose any, memo string) *ActionRegister {
s := &websocketServiceType{
authLevel: authLevel,
path: path,
memo: memo,
updater: &websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }},
actions: make(map[string]*websocketActionType),
}
if onOpen != nil {
s.openFuncValue = reflect.ValueOf(onOpen)
s.openFuncType = s.openFuncValue.Type()
if s.openFuncType.NumOut() > 0 {
s.sessionType = s.openFuncType.Out(0)
}
}
if onClose != nil {
s.closeFuncValue = reflect.ValueOf(onClose)
s.closeFuncType = s.closeFuncValue.Type()
}
websocketServicesLock.Lock()
websocketServices[path] = s
websocketServicesLock.Unlock()
return &ActionRegister{ws: s}
}
// RegisterAction 注册 WebSocket Action
func (ar *ActionRegister) RegisterAction(authLevel int, name string, action any, memo string) {
v := reflect.ValueOf(action)
t := v.Type()
a := &websocketActionType{
authLevel: authLevel,
funcValue: v,
funcType: t,
memo: memo,
}
// 查找输入参数类型
for i := 0; i < t.NumIn(); i++ {
inT := t.In(i)
if inT.Kind() == reflect.Struct {
a.inType = inT
break
}
}
ar.ws.actions[name] = a
}
func doWebsocketService(ws *websocketServiceType, request *Request, response *Response, logger *log.Logger) {
conn, err := ws.updater.Upgrade(response.Writer, request.Request, nil)
if err != nil {
logger.Error("websocket upgrade failed", "error", err.Error())
return
}
defer conn.Close()
var session any
if ws.openFuncValue.IsValid() {
// 简化版:仅支持基础参数注入
params := make([]reflect.Value, ws.openFuncType.NumIn())
for i := 0; i < len(params); i++ {
t := ws.openFuncType.In(i)
if t == reflect.TypeOf(request) {
params[i] = reflect.ValueOf(request)
} else if t == reflect.TypeOf(logger) {
params[i] = reflect.ValueOf(logger)
} else {
params[i] = reflect.New(t).Elem()
}
}
outs := ws.openFuncValue.Call(params)
if len(outs) > 0 {
session = outs[0].Interface()
}
}
for {
var msg Map
if err := conn.ReadJSON(&msg); err != nil {
break
}
actionName := cast.String(msg["action"])
action := ws.actions[actionName]
if action == nil {
action = ws.actions[""] // 默认 action
}
if action != nil {
params := make([]reflect.Value, action.funcType.NumIn())
for i := 0; i < len(params); i++ {
t := action.funcType.In(i)
if t == ws.sessionType {
params[i] = reflect.ValueOf(session)
} else if t == reflect.TypeOf(conn) {
params[i] = reflect.ValueOf(conn)
} else if t.Kind() == reflect.Struct {
in := reflect.New(t).Interface()
cast.Convert(in, msg)
params[i] = reflect.ValueOf(in).Elem()
} else {
params[i] = reflect.New(t).Elem()
}
}
outs := action.funcValue.Call(params)
if len(outs) > 0 {
result := outs[0].Interface()
if result != nil {
_ = conn.WriteJSON(result)
}
}
}
}
if ws.closeFuncValue.IsValid() {
params := make([]reflect.Value, ws.closeFuncType.NumIn())
for i := 0; i < len(params); i++ {
t := ws.closeFuncType.In(i)
if t == ws.sessionType {
params[i] = reflect.ValueOf(session)
} else {
params[i] = reflect.New(t).Elem()
}
}
ws.closeFuncValue.Call(params)
}
}

44
websocket_test.go Normal file
View File

@ -0,0 +1,44 @@
package service
import (
"github.com/gorilla/websocket"
"net/http/httptest"
"strings"
"testing"
)
func TestWebSocketService(t *testing.T) {
// 注册 WebSocket 服务
ar := RegisterWebsocket(0, "/ws", nil, nil, "test websocket")
ar.RegisterAction(0, "echo", func(in struct{ Msg string }) Map {
return Map{"action": "echo", "reply": in.Msg}
}, "echo action")
// 启动测试服务器
server := httptest.NewServer(&routeHandler{})
defer server.Close()
// 建立连接
wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/ws"
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
if err != nil {
t.Fatalf("Dial failed: %v", err)
}
defer conn.Close()
// 发送消息
msg := Map{"action": "echo", "msg": "hello"}
if err := conn.WriteJSON(msg); err != nil {
t.Fatalf("WriteJSON failed: %v", err)
}
// 接收响应
var reply Map
if err := conn.ReadJSON(&reply); err != nil {
t.Fatalf("ReadJSON failed: %v", err)
}
if reply["reply"] != "hello" {
t.Errorf("Reply mismatch: %v", reply)
}
}