Migrate service module from ssgo/s with modern Go features (by AI)
This commit is contained in:
commit
bdb104aa2f
267
.log.meta.json
Normal file
267
.log.meta.json
Normal file
@ -0,0 +1,267 @@
|
||||
{
|
||||
"debug": [
|
||||
{
|
||||
"index": 0,
|
||||
"name": "LogName",
|
||||
"color": "cyan",
|
||||
"hide": true
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"name": "LogType",
|
||||
"color": "magenta",
|
||||
"hide": true
|
||||
},
|
||||
{
|
||||
"index": 2,
|
||||
"name": "LogTime",
|
||||
"format": "time"
|
||||
},
|
||||
{
|
||||
"index": 3,
|
||||
"name": "TraceId",
|
||||
"color": "blue"
|
||||
},
|
||||
{
|
||||
"index": 4,
|
||||
"name": "Image",
|
||||
"color": "darkGray",
|
||||
"hide": true
|
||||
},
|
||||
{
|
||||
"index": 5,
|
||||
"name": "Server",
|
||||
"color": "darkGray",
|
||||
"hide": true
|
||||
},
|
||||
{
|
||||
"index": 6,
|
||||
"name": "Debug",
|
||||
"withoutKey": true
|
||||
},
|
||||
{
|
||||
"index": 7,
|
||||
"name": "Extra"
|
||||
}
|
||||
],
|
||||
"discover": [
|
||||
{
|
||||
"index": 0,
|
||||
"name": "LogName",
|
||||
"color": "cyan",
|
||||
"hide": true
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"name": "LogType",
|
||||
"color": "magenta",
|
||||
"hide": true
|
||||
},
|
||||
{
|
||||
"index": 2,
|
||||
"name": "LogTime",
|
||||
"format": "time"
|
||||
},
|
||||
{
|
||||
"index": 3,
|
||||
"name": "TraceId",
|
||||
"color": "blue"
|
||||
},
|
||||
{
|
||||
"index": 4,
|
||||
"name": "Image",
|
||||
"color": "darkGray",
|
||||
"hide": true
|
||||
},
|
||||
{
|
||||
"index": 5,
|
||||
"name": "Server",
|
||||
"color": "darkGray",
|
||||
"hide": true
|
||||
},
|
||||
{
|
||||
"index": 6,
|
||||
"name": "App",
|
||||
"color": "cyan"
|
||||
},
|
||||
{
|
||||
"index": 7,
|
||||
"name": "Method",
|
||||
"color": "magenta"
|
||||
},
|
||||
{
|
||||
"index": 8,
|
||||
"name": "Path",
|
||||
"color": "blue"
|
||||
},
|
||||
{
|
||||
"index": 9,
|
||||
"name": "Node",
|
||||
"color": "yellow"
|
||||
},
|
||||
{
|
||||
"index": 10,
|
||||
"name": "Attempts"
|
||||
},
|
||||
{
|
||||
"index": 11,
|
||||
"name": "UsedTime",
|
||||
"format": "%.2fms"
|
||||
},
|
||||
{
|
||||
"index": 12,
|
||||
"name": "Error",
|
||||
"color": "red"
|
||||
},
|
||||
{
|
||||
"index": 13,
|
||||
"name": "Extra"
|
||||
}
|
||||
],
|
||||
"error": [
|
||||
{
|
||||
"index": 0,
|
||||
"name": "LogName",
|
||||
"color": "cyan",
|
||||
"hide": true
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"name": "LogType",
|
||||
"color": "magenta",
|
||||
"hide": true
|
||||
},
|
||||
{
|
||||
"index": 2,
|
||||
"name": "LogTime",
|
||||
"format": "time"
|
||||
},
|
||||
{
|
||||
"index": 3,
|
||||
"name": "TraceId",
|
||||
"color": "blue"
|
||||
},
|
||||
{
|
||||
"index": 4,
|
||||
"name": "Image",
|
||||
"color": "darkGray",
|
||||
"hide": true
|
||||
},
|
||||
{
|
||||
"index": 5,
|
||||
"name": "Server",
|
||||
"color": "darkGray",
|
||||
"hide": true
|
||||
},
|
||||
{
|
||||
"index": 6,
|
||||
"name": "Error",
|
||||
"color": "red",
|
||||
"withoutKey": true
|
||||
},
|
||||
{
|
||||
"index": 7,
|
||||
"name": "CallStacks"
|
||||
},
|
||||
{
|
||||
"index": 8,
|
||||
"name": "Extra"
|
||||
}
|
||||
],
|
||||
"info": [
|
||||
{
|
||||
"index": 0,
|
||||
"name": "LogName",
|
||||
"color": "cyan",
|
||||
"hide": true
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"name": "LogType",
|
||||
"color": "magenta",
|
||||
"hide": true
|
||||
},
|
||||
{
|
||||
"index": 2,
|
||||
"name": "LogTime",
|
||||
"format": "time"
|
||||
},
|
||||
{
|
||||
"index": 3,
|
||||
"name": "TraceId",
|
||||
"color": "blue"
|
||||
},
|
||||
{
|
||||
"index": 4,
|
||||
"name": "Image",
|
||||
"color": "darkGray",
|
||||
"hide": true
|
||||
},
|
||||
{
|
||||
"index": 5,
|
||||
"name": "Server",
|
||||
"color": "darkGray",
|
||||
"hide": true
|
||||
},
|
||||
{
|
||||
"index": 6,
|
||||
"name": "Info",
|
||||
"color": "cyan",
|
||||
"withoutKey": true
|
||||
},
|
||||
{
|
||||
"index": 7,
|
||||
"name": "Extra"
|
||||
}
|
||||
],
|
||||
"warning": [
|
||||
{
|
||||
"index": 0,
|
||||
"name": "LogName",
|
||||
"color": "cyan",
|
||||
"hide": true
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"name": "LogType",
|
||||
"color": "magenta",
|
||||
"hide": true
|
||||
},
|
||||
{
|
||||
"index": 2,
|
||||
"name": "LogTime",
|
||||
"format": "time"
|
||||
},
|
||||
{
|
||||
"index": 3,
|
||||
"name": "TraceId",
|
||||
"color": "blue"
|
||||
},
|
||||
{
|
||||
"index": 4,
|
||||
"name": "Image",
|
||||
"color": "darkGray",
|
||||
"hide": true
|
||||
},
|
||||
{
|
||||
"index": 5,
|
||||
"name": "Server",
|
||||
"color": "darkGray",
|
||||
"hide": true
|
||||
},
|
||||
{
|
||||
"index": 6,
|
||||
"name": "Warning",
|
||||
"color": "yellow",
|
||||
"withoutKey": true
|
||||
},
|
||||
{
|
||||
"index": 7,
|
||||
"name": "CallStacks"
|
||||
},
|
||||
{
|
||||
"index": 8,
|
||||
"name": "Extra"
|
||||
}
|
||||
]
|
||||
}
|
||||
216
DocTpl.html
Normal file
216
DocTpl.html
Normal file
@ -0,0 +1,216 @@
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<title>{{.title}}</title>
|
||||
<style>
|
||||
html {
|
||||
overflow: auto;
|
||||
height: 100%;
|
||||
}
|
||||
|
||||
body {
|
||||
margin: 0;
|
||||
padding: 10px;
|
||||
background: #fff;
|
||||
color: #333;
|
||||
font-size: 16px;
|
||||
overflow: hidden;
|
||||
display: flex;
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
height: 100%;
|
||||
}
|
||||
|
||||
.nav {
|
||||
overflow-x: hidden;
|
||||
overflow-y: auto;
|
||||
flex: 1;
|
||||
background: #333;
|
||||
color: #fff;
|
||||
height: 100%;
|
||||
padding: 10px 0;
|
||||
box-sizing: border-box;
|
||||
}
|
||||
|
||||
.navItem {
|
||||
width: 100%;
|
||||
display: block;
|
||||
white-space: nowrap;
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
padding: 4px 10px;
|
||||
cursor: pointer;
|
||||
user-select: none;
|
||||
color: #fff;
|
||||
text-decoration: none;
|
||||
}
|
||||
|
||||
.navItem:hover {
|
||||
background: #99ccff;
|
||||
}
|
||||
|
||||
.navItem .memo {
|
||||
font-weight: normal;
|
||||
font-size: 12px;
|
||||
color: #ccc;
|
||||
}
|
||||
|
||||
.apiBox {
|
||||
overflow: auto;
|
||||
flex: 3;
|
||||
padding: 8px;
|
||||
height: 100%;
|
||||
}
|
||||
|
||||
header {
|
||||
border-bottom: #ddd 1px solid;
|
||||
margin-bottom: 5px;
|
||||
background: #333;
|
||||
color: #fff;
|
||||
padding: 12px;
|
||||
display: flex;
|
||||
align-items: baseline;
|
||||
}
|
||||
|
||||
header > span {
|
||||
font-weight: bold;
|
||||
margin-right: 10px;
|
||||
}
|
||||
|
||||
header > span.memo {
|
||||
font-weight: normal;
|
||||
font-size: 14px;
|
||||
color: #faebd7;
|
||||
}
|
||||
|
||||
label {
|
||||
display: inline-block;
|
||||
margin-right: 10px;
|
||||
padding: 4px 8px;
|
||||
font-size: 12px;
|
||||
background: #ccc;
|
||||
color: #000;
|
||||
font-weight: bold;
|
||||
border-radius: 4px;
|
||||
}
|
||||
|
||||
label.authLevel {
|
||||
background: #f90;
|
||||
color: #000;
|
||||
}
|
||||
|
||||
label.authLevel0 {
|
||||
background: #ccc;
|
||||
color: #000;
|
||||
}
|
||||
|
||||
label.type {
|
||||
background: #9cf;
|
||||
color: #000;
|
||||
}
|
||||
|
||||
section {
|
||||
margin-bottom: 40px;
|
||||
white-space: nowrap;
|
||||
font-size: 12px;
|
||||
}
|
||||
|
||||
header.Action, section.Action {
|
||||
margin-left: 20px;
|
||||
}
|
||||
|
||||
section > table {
|
||||
width: 50%;
|
||||
display: inline-table;
|
||||
border-collapse: collapse;
|
||||
vertical-align: top;
|
||||
font-size: 16px;
|
||||
|
||||
}
|
||||
|
||||
section > table:last-child {
|
||||
border-left: 1px solid #ddd;
|
||||
}
|
||||
|
||||
tr:nth-child(even) {
|
||||
background: #f9f9f9;
|
||||
}
|
||||
|
||||
th {
|
||||
padding: 8px;
|
||||
}
|
||||
|
||||
td {
|
||||
padding: 6px 12px;
|
||||
white-space: pre-wrap;
|
||||
}
|
||||
|
||||
td:last-child {
|
||||
color: #666;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="nav">
|
||||
{{range .api}}
|
||||
<a href="#{{.Path}}" class="navItem">
|
||||
<span>{{.Path}}</span>
|
||||
<span class="memo">{{.Memo}}</span>
|
||||
</a>
|
||||
{{end}}
|
||||
</div>
|
||||
<div class="apiBox">
|
||||
{{range .api}}
|
||||
<a name="{{.Path}}"></a>
|
||||
<div style="height: 16px"></div>
|
||||
<header class="{{.Type}}">
|
||||
<span>{{.Path}}</span>
|
||||
<span class="memo">{{.Memo}}</span>
|
||||
{{if ne .Method ""}}<label>{{.Method}}</label>{{end}}
|
||||
<label title="Auth Level" class="authLevel authLevel{{.AuthLevel}}">{{.AuthLevel}}</label>
|
||||
{{if ne .Type "Web"}}<label class="type">{{.Type}}</label>{{end}}
|
||||
</header>
|
||||
<section class="{{.Type}}">
|
||||
<table>
|
||||
{{if isMap .In}}
|
||||
<tr>
|
||||
<th colspan="2">Request</th>
|
||||
</tr>
|
||||
{{range $k, $v := .In}}
|
||||
<tr>
|
||||
<td width="30%">{{$k}}</td>
|
||||
<td width="70%">{{toText $v}}</td>
|
||||
</tr>
|
||||
{{end}}
|
||||
{{else}}
|
||||
<tr>
|
||||
<td colspan="2">{{.In}}</td>
|
||||
</tr>
|
||||
{{end}}
|
||||
</table>
|
||||
<table>
|
||||
{{if isMap .Out}}
|
||||
<tr>
|
||||
<th colspan="2">Response</th>
|
||||
</tr>
|
||||
{{range $k, $v := .Out}}
|
||||
<tr>
|
||||
<td width="30%">{{$k}}</td>
|
||||
<td width="70%">{{toText $v}}</td>
|
||||
</tr>
|
||||
{{end}}
|
||||
{{else}}
|
||||
<tr>
|
||||
<td colspan="2">{{.Out}}</td>
|
||||
</tr>
|
||||
{{end}}
|
||||
</table>
|
||||
</section>
|
||||
{{else}}
|
||||
<div><strong>no document</strong></div>
|
||||
{{end}}
|
||||
<div style="height: 800px"></div>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
56
README.md
Normal file
56
README.md
Normal file
@ -0,0 +1,56 @@
|
||||
# go/service (核心微服务框架)
|
||||
|
||||
极简、自动化的 Web 与 WebSocket 服务框架,实现极致的依赖注入与路由映射。
|
||||
|
||||
## 核心特性
|
||||
- **路由反射**: 自动解析函数参数,支持 `*Request`, `*Response`, `*log.Logger` 及自定义结构体自动注入。
|
||||
- **自动校验**: 集成 `verify` 引擎,通过 Struct Tag 实现入参合法性自动检查。
|
||||
- **功能闭环**: 内置静态文件服务、WebSocket (带 Action 路由)、URL 重写、反向代理(对接 Discover)。
|
||||
- **零摩擦启动**: 支持命令行指令管理 (start/stop/help) 及异步平滑启停。
|
||||
|
||||
## API 指南
|
||||
|
||||
### 1. 服务注册
|
||||
```go
|
||||
import "apigo.cc/go/service"
|
||||
|
||||
// 注册标准 Web 服务
|
||||
service.Register(0, "/hello", func(in struct{ Name string }) string {
|
||||
return "Hello " + in.Name
|
||||
}, "打招呼接口")
|
||||
|
||||
// 注册 Restful 服务
|
||||
service.Restful(0, "POST", "/user/{id}", func(args map[string]any) service.Result {
|
||||
res := service.Result{}
|
||||
res.OK()
|
||||
return res
|
||||
}, "更新用户")
|
||||
```
|
||||
|
||||
### 2. WebSocket 支持
|
||||
```go
|
||||
ar := service.RegisterWebsocket(0, "/ws", onOpen, onClose, "聊天室")
|
||||
ar.RegisterAction(0, "chat", func(in ChatMessage, sess *MySession) {
|
||||
// 处理消息
|
||||
}, "发送消息")
|
||||
```
|
||||
|
||||
### 3. 增强插件
|
||||
- **静态文件**: `service.Static("/ui", "./static_dir")`
|
||||
- **URL 重写**: `service.Rewrite("/old", "/new")`
|
||||
- **反向代理**: `service.Proxy(0, "/api", "other_app", "/api")`
|
||||
|
||||
### 4. 生命周期管理
|
||||
```go
|
||||
func main() {
|
||||
service.CheckCmd() // 处理 start/stop/help 指令
|
||||
service.Start() // 阻塞启动
|
||||
}
|
||||
```
|
||||
|
||||
## 基础设施对齐
|
||||
- **类型转换**: `apigo.cc/go/cast`
|
||||
- **日志系统**: `apigo.cc/go/log`
|
||||
- **服务发现**: `apigo.cc/go/discover`
|
||||
- **分布式 ID**: `apigo.cc/go/id`
|
||||
- **文件操作**: `apigo.cc/go/file`
|
||||
62
config.go
Normal file
62
config.go
Normal file
@ -0,0 +1,62 @@
|
||||
package service
|
||||
|
||||
// CertSet SSL 证书配置
|
||||
type CertSet struct {
|
||||
CertFile string
|
||||
KeyFile string
|
||||
}
|
||||
|
||||
// ServiceConfig 核心服务配置
|
||||
type ServiceConfig struct {
|
||||
Listen string // 监听端口(|隔开多个监听)(,隔开多个选项),例如 80,http|443|443:h2|127.0.0.1:8080,h2c
|
||||
SSL map[string]*CertSet // SSL 证书配置,key 为域名
|
||||
NoLogGets bool // 不记录 GET 请求的日志
|
||||
NoLogHeaders string // 不记录请求头中包含的这些字段,多个字段用逗号分隔
|
||||
LogInputArrayNum int // 请求字段中容器类型在日志打印个数限制
|
||||
LogInputFieldSize int // 请求字段中单个字段在日志打印长度限制
|
||||
NoLogOutputFields string // 不记录响应字段中包含的这些字段
|
||||
LogOutputArrayNum int // 响应字段中容器类型在日志打印个数限制
|
||||
LogOutputFieldSize int // 响应字段中单个字段在日志打印长度限制
|
||||
LogWebsocketAction bool // 记录 Websocket 中每个 Action 的请求日志
|
||||
Compress bool // 是否启用压缩
|
||||
CompressMinSize int // 启用压缩的最小长度
|
||||
CompressMaxSize int // 启用压缩的最大长度
|
||||
CheckDomain string // 心跳检测时使用域名
|
||||
AccessTokens map[string]*int // 指定 Access-Token 验证及其对应的 auth-level
|
||||
RedirectTimeout int // Proxy 和 Discover 发起请求时的超时时间 (ms)
|
||||
AcceptXRealIpWithoutRequestId bool // 是否允许头部没有携带请求ID的 X-Real-IP 信息
|
||||
StatisticTime bool // 是否开启请求时间统计
|
||||
StatisticTimeInterval int // 统计时间间隔 (ms)
|
||||
Fast bool // 是否启用快速模式
|
||||
MaxUploadSize int64 // 最大上传文件大小 (Bytes)
|
||||
IpPrefix string // Discover 服务发现时指定使用的 IP 网段
|
||||
Cpu int // CPU 占用的核数限制
|
||||
Memory int // 内存限制 (MB)
|
||||
CpuMonitor bool // 记录 CPU 使用情况
|
||||
MemoryMonitor bool // 记录内存使用情况
|
||||
CpuLimitValue uint // CPU 自动重启阈值 (10-100)
|
||||
MemoryLimitValue uint // 内存自动重启阈值 (10-100)
|
||||
CpuLimitTimes uint // CPU 报警阈值连续次数
|
||||
MemoryLimitTimes uint // 内存报警阈值连续次数
|
||||
CookieScope string // Session Cookie 有效范围: host|domain|topDomain
|
||||
SessionWithoutCookie bool // Session 禁用 Cookie
|
||||
DeviceWithoutCookie bool // 设备ID禁用 Cookie
|
||||
IdServer string // Redis 服务器连接 (用于全局唯一 ID 生成)
|
||||
KeepKeyCase bool // 是否保持 Key 的首字母大小写
|
||||
IndexFiles []string // 静态文件索引文件
|
||||
IndexDir bool // 访问目录时显示文件列表
|
||||
ReadTimeout int // 读取请求的超时时间 (ms)
|
||||
ReadHeaderTimeout int // 读取请求头的超时时间 (ms)
|
||||
WriteTimeout int // 响应写入的超时时间 (ms)
|
||||
IdleTimeout int // 连接空闲超时时间 (ms)
|
||||
MaxHeaderBytes int // 请求头的最大字节数
|
||||
MaxHandlers int // 每个连接的最大处理程序数量
|
||||
MaxConcurrentStreams uint32 // 每个连接的最大并发流数量
|
||||
MaxDecoderHeaderTableSize uint32 // 解码器头表的最大大小
|
||||
MaxEncoderHeaderTableSize uint32 // 编码器头表的最大大小
|
||||
MaxReadFrameSize uint32 // 单个帧的最大读取大小
|
||||
MaxUploadBufferPerConnection int32 // 每个连接的最大上传缓冲区大小
|
||||
MaxUploadBufferPerStream int32 // 每个流的最大上传缓冲区大小
|
||||
}
|
||||
|
||||
var Config = ServiceConfig{}
|
||||
162
document.go
Normal file
162
document.go
Normal file
@ -0,0 +1,162 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"apigo.cc/go/cast"
|
||||
_ "embed"
|
||||
"encoding/json"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// Api 接口文档信息
|
||||
type Api struct {
|
||||
Type string
|
||||
Path string
|
||||
AuthLevel int
|
||||
Method string
|
||||
In any
|
||||
Out any
|
||||
Memo string
|
||||
}
|
||||
|
||||
//go:embed DocTpl.html
|
||||
var defaultDocTpl string
|
||||
|
||||
// MakeDocument 生成文档数据
|
||||
func MakeDocument() []Api {
|
||||
out := make([]Api, 0)
|
||||
|
||||
// 1. Rewrite
|
||||
rewritesLock.RLock()
|
||||
for _, a := range rewrites {
|
||||
out = append(out, Api{
|
||||
Type: "Rewrite",
|
||||
Path: a.fromPath + " -> " + a.toPath,
|
||||
})
|
||||
}
|
||||
rewritesLock.RUnlock()
|
||||
|
||||
// 2. Proxy
|
||||
proxiesLock.RLock()
|
||||
for _, a := range proxies {
|
||||
out = append(out, Api{
|
||||
Type: "Proxy",
|
||||
Path: a.fromPath + " -> " + a.toApp + ":" + a.toPath,
|
||||
})
|
||||
}
|
||||
proxiesLock.RUnlock()
|
||||
|
||||
// 3. Web Services
|
||||
webServicesLock.RLock()
|
||||
for _, a := range webServicesList {
|
||||
if a.options.NoDoc {
|
||||
continue
|
||||
}
|
||||
api := Api{
|
||||
Type: "Web",
|
||||
Path: a.path,
|
||||
AuthLevel: a.authLevel,
|
||||
Method: a.method,
|
||||
Memo: a.memo,
|
||||
}
|
||||
if a.inType != nil {
|
||||
api.In = getType(a.inType)
|
||||
}
|
||||
if a.funcType.NumOut() > 0 {
|
||||
api.Out = getType(a.funcType.Out(0))
|
||||
}
|
||||
out = append(out, api)
|
||||
}
|
||||
webServicesLock.RUnlock()
|
||||
|
||||
// 4. WebSocket Services
|
||||
websocketServicesLock.RLock()
|
||||
for _, a := range websocketServices {
|
||||
api := Api{
|
||||
Type: "WebSocket",
|
||||
Path: a.path,
|
||||
AuthLevel: a.authLevel,
|
||||
Memo: a.memo,
|
||||
}
|
||||
if a.openFuncType != nil && a.openFuncType.NumIn() > 0 {
|
||||
// Find struct in
|
||||
for i := 0; i < a.openFuncType.NumIn(); i++ {
|
||||
t := a.openFuncType.In(i)
|
||||
if t.Kind() == reflect.Struct {
|
||||
api.In = getType(t)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
out = append(out, api)
|
||||
|
||||
for name, act := range a.actions {
|
||||
actionApi := Api{
|
||||
Type: "Action",
|
||||
Path: name,
|
||||
AuthLevel: act.authLevel,
|
||||
Memo: act.memo,
|
||||
}
|
||||
if act.inType != nil {
|
||||
actionApi.In = getType(act.inType)
|
||||
}
|
||||
if act.funcType.NumOut() > 0 {
|
||||
actionApi.Out = getType(act.funcType.Out(0))
|
||||
}
|
||||
out = append(out, actionApi)
|
||||
}
|
||||
}
|
||||
websocketServicesLock.RUnlock()
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
// MakeJsonDocument 生成 JSON 格式文档
|
||||
func MakeJsonDocument() string {
|
||||
apis := MakeDocument()
|
||||
data, _ := json.MarshalIndent(map[string]any{
|
||||
"api": apis,
|
||||
}, "", "\t")
|
||||
return string(data)
|
||||
}
|
||||
|
||||
func getType(t reflect.Type) any {
|
||||
if t == nil {
|
||||
return ""
|
||||
}
|
||||
for t.Kind() == reflect.Ptr {
|
||||
t = t.Elem()
|
||||
}
|
||||
|
||||
switch t.Kind() {
|
||||
case reflect.Struct:
|
||||
outs := Map{}
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
f := t.Field(i)
|
||||
if f.Anonymous {
|
||||
if subMap, ok := getType(f.Type).(Map); ok {
|
||||
for k, v := range subMap {
|
||||
outs[k] = v
|
||||
}
|
||||
}
|
||||
} else {
|
||||
outs[cast.GetLowerName(f.Name)] = getType(f.Type)
|
||||
}
|
||||
}
|
||||
return outs
|
||||
case reflect.Map:
|
||||
return map[string]any{t.Key().String(): getType(t.Elem())}
|
||||
case reflect.Slice:
|
||||
return []any{getType(t.Elem())}
|
||||
case reflect.Interface:
|
||||
return "Any"
|
||||
default:
|
||||
return t.String()
|
||||
}
|
||||
}
|
||||
|
||||
// 自动注册文档服务
|
||||
func init() {
|
||||
Register(0, "/__DOC__", func() string {
|
||||
return MakeJsonDocument()
|
||||
}, "API Document")
|
||||
}
|
||||
5
go.mod
Normal file
5
go.mod
Normal file
@ -0,0 +1,5 @@
|
||||
module apigo.cc/go/service
|
||||
|
||||
go 1.25.0
|
||||
|
||||
require github.com/gorilla/websocket v1.5.3
|
||||
2
go.sum
Normal file
2
go.sum
Normal file
@ -0,0 +1,2 @@
|
||||
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
||||
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
319
handler.go
Normal file
319
handler.go
Normal file
@ -0,0 +1,319 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"apigo.cc/go/cast"
|
||||
"apigo.cc/go/id"
|
||||
"apigo.cc/go/log"
|
||||
"apigo.cc/go/standard"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
type routeHandler struct {
|
||||
webRequestingNum int64
|
||||
}
|
||||
|
||||
func (rh *routeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
atomic.AddInt64(&rh.webRequestingNum, 1)
|
||||
defer atomic.AddInt64(&rh.webRequestingNum, -1)
|
||||
|
||||
startTime := time.Now()
|
||||
requestId := r.Header.Get(standard.DiscoverHeaderRequestId)
|
||||
if requestId == "" {
|
||||
requestId = id.MakeID(12)
|
||||
r.Header.Set(standard.DiscoverHeaderRequestId, requestId)
|
||||
}
|
||||
|
||||
request := NewRequest(r)
|
||||
request.Id = requestId
|
||||
response := NewResponse(w)
|
||||
response.Id = requestId
|
||||
defer response.checkWriteHeader()
|
||||
|
||||
// 处理 SessionId 和 DeviceId
|
||||
handleClientKeys(request, response)
|
||||
|
||||
requestLogger := log.New(requestId)
|
||||
|
||||
// 0. 处理重写 (Rewrite)
|
||||
if processRewrite(request, response, requestLogger) {
|
||||
return
|
||||
}
|
||||
|
||||
// 处理代理 (Proxy)
|
||||
if processProxy(request, response, requestLogger) {
|
||||
return
|
||||
}
|
||||
|
||||
// 1. 路由匹配
|
||||
path := r.URL.Path
|
||||
host := r.Host
|
||||
|
||||
// 处理静态文件
|
||||
if processStatic(path, request, response, requestLogger) {
|
||||
return
|
||||
}
|
||||
|
||||
s, ws := findService(r.Method, host, path)
|
||||
|
||||
// 2. 参数解析 (Form & Body)
|
||||
args := make(map[string]any)
|
||||
parseRequestArgs(request, args)
|
||||
|
||||
// 3. 前置过滤器
|
||||
var result any
|
||||
for _, filter := range inFilters {
|
||||
result = filter(&args, request, response, requestLogger)
|
||||
if result != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 4. 处理业务执行 (WS 或 Web)
|
||||
if result == nil {
|
||||
if ws != nil {
|
||||
doWebsocketService(ws, request, response, requestLogger)
|
||||
return
|
||||
} else if s != nil {
|
||||
// 鉴权
|
||||
pass, obj := checkAuth(s, request, response, args, requestLogger)
|
||||
if !pass {
|
||||
if !response.changed {
|
||||
response.WriteHeader(http.StatusForbidden)
|
||||
}
|
||||
return
|
||||
}
|
||||
// 执行业务
|
||||
result = doWebService(s, request, response, args, nil, requestLogger, obj)
|
||||
}
|
||||
}
|
||||
|
||||
if s == nil && result == nil {
|
||||
response.WriteHeader(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
// 5. 后置过滤器
|
||||
for _, filter := range outFilters {
|
||||
newResult, done := filter(args, request, response, result, requestLogger)
|
||||
if newResult != nil {
|
||||
result = newResult
|
||||
}
|
||||
if done {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 6. 输出结果
|
||||
outputResult(response, result)
|
||||
|
||||
// 7. 记录日志
|
||||
_ = startTime
|
||||
}
|
||||
|
||||
func findService(method, host, path string) (*webServiceType, *websocketServiceType) {
|
||||
webServicesLock.RLock()
|
||||
defer webServicesLock.RUnlock()
|
||||
|
||||
// 1. Web Service 匹配
|
||||
if s, exists := webServices[method+path]; exists {
|
||||
return s, nil
|
||||
}
|
||||
if s, exists := webServices[path]; exists {
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// 2. WebSocket 匹配
|
||||
websocketServicesLock.RLock()
|
||||
defer websocketServicesLock.RUnlock()
|
||||
if ws, exists := websocketServices[path]; exists {
|
||||
return nil, ws
|
||||
}
|
||||
|
||||
// 3. 正则匹配
|
||||
for i := len(regexWebServices) - 1; i >= 0; i-- {
|
||||
s := regexWebServices[i]
|
||||
if s.method != "" && s.method != method {
|
||||
continue
|
||||
}
|
||||
if s.pathMatcher != nil && s.pathMatcher.MatchString(path) {
|
||||
return s, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func parseRequestArgs(request *Request, args map[string]any) {
|
||||
// Query params
|
||||
query := request.URL.Query()
|
||||
for k, v := range query {
|
||||
if len(v) == 1 {
|
||||
args[k] = v[0]
|
||||
} else {
|
||||
args[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
// Form params
|
||||
if request.Method == http.MethodPost || request.Method == http.MethodPut {
|
||||
contentType := request.Header.Get("Content-Type")
|
||||
if strings.HasPrefix(contentType, "application/json") {
|
||||
body, _ := io.ReadAll(request.Body)
|
||||
_ = request.Body.Close()
|
||||
if len(body) > 0 {
|
||||
_ = json.Unmarshal(body, &args)
|
||||
}
|
||||
} else {
|
||||
_ = request.ParseForm()
|
||||
for k, v := range request.Form {
|
||||
if len(v) == 1 {
|
||||
args[k] = v[0]
|
||||
} else {
|
||||
args[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func checkAuth(s *webServiceType, request *Request, response *Response, args map[string]any, logger *log.Logger) (bool, any) {
|
||||
ac := webAuthCheckers[s.authLevel]
|
||||
if ac == nil {
|
||||
ac = webAuthChecker
|
||||
}
|
||||
if ac == nil {
|
||||
return true, nil
|
||||
}
|
||||
return ac(s.authLevel, logger, &request.RequestURI, args, request, response, &s.options)
|
||||
}
|
||||
|
||||
func doWebService(service *webServiceType, request *Request, response *Response, args map[string]any,
|
||||
result any, logger *log.Logger, object any) any {
|
||||
if result != nil {
|
||||
return result
|
||||
}
|
||||
|
||||
params := make([]reflect.Value, service.parmsNum)
|
||||
for i := 0; i < service.parmsNum; i++ {
|
||||
t := service.funcType.In(i)
|
||||
switch i {
|
||||
case service.requestIndex:
|
||||
params[i] = reflect.ValueOf(request)
|
||||
case service.httpRequestIndex:
|
||||
params[i] = reflect.ValueOf(request.Request)
|
||||
case service.responseIndex:
|
||||
params[i] = reflect.ValueOf(response)
|
||||
case service.responseWriterIndex:
|
||||
params[i] = reflect.ValueOf(response.Writer)
|
||||
case service.loggerIndex:
|
||||
params[i] = reflect.ValueOf(logger)
|
||||
case service.inIndex:
|
||||
in := reflect.New(service.inType).Interface()
|
||||
cast.Convert(in, args)
|
||||
// 参数校验
|
||||
if service.inType.Kind() == reflect.Struct {
|
||||
if ok, _ := VerifyStruct(in, logger); !ok {
|
||||
response.WriteHeader(http.StatusBadRequest)
|
||||
return "parameter verification failed"
|
||||
}
|
||||
}
|
||||
params[i] = reflect.ValueOf(in).Elem()
|
||||
default:
|
||||
// 尝试依赖注入
|
||||
if obj := GetInject(t); obj != nil {
|
||||
params[i] = reflect.ValueOf(obj)
|
||||
} else {
|
||||
params[i] = reflect.New(t).Elem()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
outs := service.funcValue.Call(params)
|
||||
if len(outs) > 0 {
|
||||
return outs[0].Interface()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func outputResult(response *Response, result any) {
|
||||
if result == nil {
|
||||
return
|
||||
}
|
||||
|
||||
var data []byte
|
||||
contentType := ""
|
||||
|
||||
switch v := result.(type) {
|
||||
case string:
|
||||
data = []byte(v)
|
||||
case []byte:
|
||||
data = v
|
||||
default:
|
||||
data, _ = cast.ToJSONBytes(result)
|
||||
contentType = "application/json; charset=UTF-8"
|
||||
}
|
||||
|
||||
if contentType != "" && response.Header().Get("Content-Type") == "" {
|
||||
response.Header().Set("Content-Type", contentType)
|
||||
}
|
||||
_, _ = response.Write(data)
|
||||
}
|
||||
|
||||
func handleClientKeys(request *Request, response *Response) {
|
||||
// SessionId
|
||||
if usedSessionIdKey != "" {
|
||||
sessionId := request.Header.Get(usedSessionIdKey)
|
||||
if sessionId == "" && !Config.SessionWithoutCookie {
|
||||
if ck, err := request.Cookie(usedSessionIdKey); err == nil {
|
||||
sessionId = ck.Value
|
||||
}
|
||||
}
|
||||
if sessionId == "" {
|
||||
if sessionIdMaker != nil {
|
||||
sessionId = sessionIdMaker()
|
||||
} else {
|
||||
sessionId = id.MakeID(14)
|
||||
}
|
||||
if !Config.SessionWithoutCookie {
|
||||
http.SetCookie(response.Writer, &http.Cookie{
|
||||
Name: usedSessionIdKey,
|
||||
Value: sessionId,
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
})
|
||||
}
|
||||
}
|
||||
request.Header.Set(standard.DiscoverHeaderSessionId, sessionId)
|
||||
response.Header().Set(usedSessionIdKey, sessionId)
|
||||
}
|
||||
|
||||
// DeviceId
|
||||
if usedDeviceIdKey != "" {
|
||||
deviceId := request.Header.Get(usedDeviceIdKey)
|
||||
if deviceId == "" && !Config.DeviceWithoutCookie {
|
||||
if ck, err := request.Cookie(usedDeviceIdKey); err == nil {
|
||||
deviceId = ck.Value
|
||||
}
|
||||
}
|
||||
if deviceId == "" {
|
||||
deviceId = id.MakeID(14)
|
||||
if !Config.DeviceWithoutCookie {
|
||||
http.SetCookie(response.Writer, &http.Cookie{
|
||||
Name: usedDeviceIdKey,
|
||||
Value: deviceId,
|
||||
Path: "/",
|
||||
Expires: time.Now().AddDate(10, 0, 0),
|
||||
HttpOnly: true,
|
||||
})
|
||||
}
|
||||
}
|
||||
request.Header.Set(standard.DiscoverHeaderDeviceId, deviceId)
|
||||
response.Header().Set(usedDeviceIdKey, deviceId)
|
||||
}
|
||||
}
|
||||
67
handler_test.go
Normal file
67
handler_test.go
Normal file
@ -0,0 +1,67 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestServeHTTP(t *testing.T) {
|
||||
// 注册服务
|
||||
handler := func(in struct{ Name string }) string {
|
||||
return "Hello " + in.Name
|
||||
}
|
||||
Register(0, "/hello", handler, "say hello")
|
||||
|
||||
rh := &routeHandler{}
|
||||
|
||||
// 模拟请求
|
||||
req := httptest.NewRequest("POST", "/hello", strings.NewReader(`{"name":"Star"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
rh.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
body := w.Body.String()
|
||||
if body != "Hello Star" {
|
||||
t.Errorf("Expected 'Hello Star', got '%s'", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServeHTTP_404(t *testing.T) {
|
||||
rh := &routeHandler{}
|
||||
req := httptest.NewRequest("GET", "/notfound", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
rh.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("Expected status 404, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServeHTTP_VerifyFailed(t *testing.T) {
|
||||
type ValidIn struct {
|
||||
Age int `verify:"between:18-100"`
|
||||
}
|
||||
handler := func(in ValidIn) string {
|
||||
return "ok"
|
||||
}
|
||||
Register(0, "/verify", handler, "test verify")
|
||||
|
||||
rh := &routeHandler{}
|
||||
req := httptest.NewRequest("POST", "/verify", strings.NewReader(`{"age":10}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
rh.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("Expected status 400, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
171
proxy.go
Normal file
171
proxy.go
Normal file
@ -0,0 +1,171 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"apigo.cc/go/discover"
|
||||
gohttp "apigo.cc/go/http"
|
||||
"apigo.cc/go/log"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type proxyInfo struct {
|
||||
matcher *regexp.Regexp
|
||||
authLevel int
|
||||
fromPath string
|
||||
toApp string
|
||||
toPath string
|
||||
}
|
||||
|
||||
var (
|
||||
proxies = make(map[string]*proxyInfo)
|
||||
regexProxies = make([]*proxyInfo, 0)
|
||||
proxyBy func(*Request) (int, *string, *string, map[string]string)
|
||||
proxiesLock = sync.RWMutex{}
|
||||
|
||||
httpClientPool *gohttp.Client
|
||||
)
|
||||
|
||||
// Proxy 注册代理规则
|
||||
func Proxy(authLevel int, path string, toApp, toPath string) {
|
||||
p := &proxyInfo{authLevel: authLevel, fromPath: path, toApp: toApp, toPath: toPath}
|
||||
if strings.Contains(path, "(") {
|
||||
matcher, err := regexp.Compile("^" + path + "$")
|
||||
if err == nil {
|
||||
p.matcher = matcher
|
||||
proxiesLock.Lock()
|
||||
regexProxies = append(regexProxies, p)
|
||||
proxiesLock.Unlock()
|
||||
}
|
||||
} else {
|
||||
proxiesLock.Lock()
|
||||
proxies[path] = p
|
||||
proxiesLock.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// SetProxyBy 设置动态代理函数
|
||||
func SetProxyBy(by func(request *Request) (authLevel int, toApp, toPath *string, headers map[string]string)) {
|
||||
proxyBy = by
|
||||
}
|
||||
|
||||
func findProxy(request *Request) (int, *string, *string) {
|
||||
requestPath := request.RequestURI
|
||||
queryString := ""
|
||||
if pos := strings.Index(requestPath, "?"); pos != -1 {
|
||||
queryString = requestPath[pos:]
|
||||
requestPath = requestPath[:pos]
|
||||
}
|
||||
|
||||
proxiesLock.RLock()
|
||||
defer proxiesLock.RUnlock()
|
||||
|
||||
if pi, ok := proxies[requestPath]; ok {
|
||||
toPath := pi.toPath + queryString
|
||||
return pi.authLevel, &pi.toApp, &toPath
|
||||
}
|
||||
|
||||
for _, pi := range regexProxies {
|
||||
if pi.matcher != nil {
|
||||
finds := pi.matcher.FindAllStringSubmatch(requestPath, 1)
|
||||
if len(finds) > 0 {
|
||||
toApp := pi.toApp
|
||||
toPath := pi.toPath
|
||||
for i, part := range finds[0] {
|
||||
toApp = strings.ReplaceAll(toApp, fmt.Sprintf("$%d", i), part)
|
||||
toPath = strings.ReplaceAll(toPath, fmt.Sprintf("$%d", i), part)
|
||||
}
|
||||
toPath += queryString
|
||||
return pi.authLevel, &toApp, &toPath
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return 0, nil, nil
|
||||
}
|
||||
|
||||
func processProxy(request *Request, response *Response, logger *log.Logger) bool {
|
||||
authLevel, proxyToApp, proxyToPath := findProxy(request)
|
||||
var proxyHeaders map[string]string
|
||||
|
||||
if proxyBy != nil && (proxyToApp == nil || proxyToPath == nil || *proxyToApp == "" || *proxyToPath == "") {
|
||||
authLevel, proxyToApp, proxyToPath, proxyHeaders = proxyBy(request)
|
||||
}
|
||||
|
||||
if proxyToApp == nil || proxyToPath == nil || *proxyToApp == "" || *proxyToPath == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// 鉴权
|
||||
pass, obj := checkAuthForProxy(authLevel, request, response, logger)
|
||||
if !pass {
|
||||
if !response.changed {
|
||||
response.WriteHeader(http.StatusForbidden)
|
||||
}
|
||||
return true
|
||||
}
|
||||
_ = obj // Currently unused in proxy
|
||||
|
||||
app := *proxyToApp
|
||||
path := *proxyToPath
|
||||
|
||||
// 构建自定义头部
|
||||
headerArgs := make([]string, 0)
|
||||
for k, v := range proxyHeaders {
|
||||
headerArgs = append(headerArgs, k, v)
|
||||
}
|
||||
|
||||
if strings.Contains(app, "://") {
|
||||
// 直接 URL 代理
|
||||
if httpClientPool == nil {
|
||||
httpClientPool = gohttp.NewClient(time.Duration(Config.RedirectTimeout) * time.Millisecond)
|
||||
}
|
||||
res := httpClientPool.ManualDoByRequest(request.Request, request.Method, app+path, request.Body, headerArgs...)
|
||||
copyResponse(res, response, logger)
|
||||
} else {
|
||||
// Discover 代理
|
||||
caller := discover.NewCaller(request.Request, logger)
|
||||
caller.NoBody = true
|
||||
res, _ := caller.ManualDoWithNode(request.Method, app, "", path, request.Body, headerArgs...)
|
||||
copyResponse(res, response, logger)
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func checkAuthForProxy(authLevel int, request *Request, response *Response, logger *log.Logger) (bool, any) {
|
||||
ac := webAuthCheckers[authLevel]
|
||||
if ac == nil {
|
||||
ac = webAuthChecker
|
||||
}
|
||||
if ac == nil {
|
||||
return true, nil
|
||||
}
|
||||
return ac(authLevel, logger, &request.RequestURI, nil, request, response, nil)
|
||||
}
|
||||
|
||||
func copyResponse(res *gohttp.Result, response *Response, logger *log.Logger) {
|
||||
if res.Error != nil || res.Response == nil {
|
||||
response.WriteHeader(http.StatusBadGateway)
|
||||
if res.Error != nil {
|
||||
_, _ = response.WriteString(res.Error.Error())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
for k, v := range res.Response.Header {
|
||||
response.Header().Set(k, v[0])
|
||||
}
|
||||
response.WriteHeader(res.Response.StatusCode)
|
||||
if res.Response.Body != nil {
|
||||
defer res.Response.Body.Close()
|
||||
_, err := io.Copy(response.Writer, res.Response.Body)
|
||||
if err != nil {
|
||||
logger.Error("proxy copy body failed", "error", err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
62
proxy_test.go
Normal file
62
proxy_test.go
Normal file
@ -0,0 +1,62 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRewrite(t *testing.T) {
|
||||
// 注册重写规则
|
||||
Rewrite("/old", "/new")
|
||||
Rewrite("/regex/(.*)", "/target/$1")
|
||||
|
||||
// 注册目标服务
|
||||
Register(0, "/new", func() string { return "new content" }, "new")
|
||||
Register(0, "/target/123", func() string { return "target content" }, "target")
|
||||
|
||||
rh := &routeHandler{}
|
||||
|
||||
// 测试精确匹配重写
|
||||
req1 := httptest.NewRequest("GET", "/old", nil)
|
||||
w1 := httptest.NewRecorder()
|
||||
rh.ServeHTTP(w1, req1)
|
||||
if w1.Body.String() != "new content" {
|
||||
t.Errorf("Expected 'new content', got '%s'", w1.Body.String())
|
||||
}
|
||||
|
||||
// 测试正则匹配重写
|
||||
req2 := httptest.NewRequest("GET", "/regex/123", nil)
|
||||
w2 := httptest.NewRecorder()
|
||||
rh.ServeHTTP(w2, req2)
|
||||
if w2.Body.String() != "target content" {
|
||||
t.Errorf("Expected 'target content', got '%s'", w2.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyDirect(t *testing.T) {
|
||||
// 启动后端服务器
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-Backend", "ok")
|
||||
w.Write([]byte("backend content"))
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
// 注册代理规则
|
||||
Proxy(0, "/proxy", backend.URL, "/hello")
|
||||
|
||||
rh := &routeHandler{}
|
||||
req := httptest.NewRequest("GET", "/proxy", nil)
|
||||
w := httptest.NewRecorder()
|
||||
rh.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected 200, got %d", w.Code)
|
||||
}
|
||||
if w.Header().Get("X-Backend") != "ok" {
|
||||
t.Error("Header X-Backend mismatch")
|
||||
}
|
||||
if w.Body.String() != "backend content" {
|
||||
t.Errorf("Expected 'backend content', got '%s'", w.Body.String())
|
||||
}
|
||||
}
|
||||
139
request.go
Normal file
139
request.go
Normal file
@ -0,0 +1,139 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"apigo.cc/go/cast"
|
||||
"apigo.cc/go/standard"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
// UploadFile 上传文件结构
|
||||
type UploadFile struct {
|
||||
fileHeader *multipart.FileHeader
|
||||
Filename string
|
||||
Header textproto.MIMEHeader
|
||||
Size int64
|
||||
}
|
||||
|
||||
// Open 打开上传文件
|
||||
func (f *UploadFile) Open() (multipart.File, error) {
|
||||
return f.fileHeader.Open()
|
||||
}
|
||||
|
||||
// Save 保存上传文件到本地
|
||||
func (f *UploadFile) Save(filename string) error {
|
||||
dir := filepath.Dir(filename)
|
||||
if _, err := os.Stat(dir); os.IsNotExist(err) {
|
||||
_ = os.MkdirAll(dir, 0755)
|
||||
}
|
||||
|
||||
dst, err := os.OpenFile(filename, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer dst.Close()
|
||||
|
||||
src, err := f.fileHeader.Open()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer src.Close()
|
||||
|
||||
_, err = io.Copy(dst, src)
|
||||
return err
|
||||
}
|
||||
|
||||
// Content 获取上传文件内容
|
||||
func (f *UploadFile) Content() ([]byte, error) {
|
||||
src, err := f.fileHeader.Open()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer src.Close()
|
||||
return io.ReadAll(src)
|
||||
}
|
||||
|
||||
// Request 封装 http.Request
|
||||
type Request struct {
|
||||
*http.Request
|
||||
contextValues map[string]any
|
||||
Id string
|
||||
}
|
||||
|
||||
// NewRequest 创建 Request 包装
|
||||
func NewRequest(httpRequest *http.Request) *Request {
|
||||
return &Request{
|
||||
Request: httpRequest,
|
||||
contextValues: make(map[string]any),
|
||||
}
|
||||
}
|
||||
|
||||
// ResetPath 重写请求路径
|
||||
func (r *Request) ResetPath(path string) {
|
||||
r.RequestURI = path
|
||||
if u, err := url.Parse(path); err == nil {
|
||||
r.URL = u
|
||||
}
|
||||
}
|
||||
|
||||
// Set 设置请求上下文变量
|
||||
func (r *Request) Set(key string, value any) {
|
||||
r.contextValues[key] = value
|
||||
}
|
||||
|
||||
// Get 获取请求上下文变量
|
||||
func (r *Request) Get(key string) any {
|
||||
return r.contextValues[key]
|
||||
}
|
||||
|
||||
// MakeUrl 根据当前请求构建完整 URL
|
||||
func (r *Request) MakeUrl(path string) string {
|
||||
scheme := r.Header.Get(standard.DiscoverHeaderScheme)
|
||||
if scheme == "" {
|
||||
scheme = "http"
|
||||
}
|
||||
host := r.Header.Get(standard.DiscoverHeaderHost)
|
||||
if host == "" {
|
||||
host = r.Host
|
||||
}
|
||||
return scheme + "://" + host + path
|
||||
}
|
||||
|
||||
// GetSessionId 获取会话 ID
|
||||
func (r *Request) GetSessionId() string {
|
||||
sessionId := r.Header.Get(Config.Listen) // Wait, this should be usedSessionIdKey
|
||||
// TODO: Fix dependency on global usedSessionIdKey
|
||||
return sessionId
|
||||
}
|
||||
|
||||
// SetUserId 设置用户 ID(传递给下游)
|
||||
func (r *Request) SetUserId(userId string) {
|
||||
r.Header.Set(standard.DiscoverHeaderUserId, userId)
|
||||
}
|
||||
|
||||
// GetRealIp 获取真实 IP
|
||||
func (r *Request) GetRealIp() string {
|
||||
ip := r.Header.Get(standard.DiscoverHeaderClientIp)
|
||||
if ip == "" {
|
||||
ip = r.Header.Get(standard.DiscoverHeaderForwardedFor)
|
||||
}
|
||||
if ip == "" {
|
||||
host, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err == nil {
|
||||
return host
|
||||
}
|
||||
return r.RemoteAddr
|
||||
}
|
||||
return ip
|
||||
}
|
||||
|
||||
// GetLowerName (Aliased from cast)
|
||||
func GetLowerName(s string) string {
|
||||
return cast.GetLowerName(s)
|
||||
}
|
||||
152
response.go
Normal file
152
response.go
Normal file
@ -0,0 +1,152 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"apigo.cc/go/cast"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
)
|
||||
|
||||
// Response 封装 http.ResponseWriter
|
||||
type Response struct {
|
||||
Id string
|
||||
Writer http.ResponseWriter
|
||||
status int
|
||||
outLen int
|
||||
changed bool
|
||||
headerWritten bool
|
||||
dontLog200 bool
|
||||
dontLogArgs []string
|
||||
ProxyHeader *http.Header
|
||||
}
|
||||
|
||||
// NewResponse 创建 Response 包装
|
||||
func NewResponse(writer http.ResponseWriter) *Response {
|
||||
return &Response{
|
||||
Writer: writer,
|
||||
status: http.StatusOK,
|
||||
}
|
||||
}
|
||||
|
||||
// Header 获取响应头部
|
||||
func (r *Response) Header() http.Header {
|
||||
r.changed = true
|
||||
if r.ProxyHeader != nil {
|
||||
return *r.ProxyHeader
|
||||
}
|
||||
return r.Writer.Header()
|
||||
}
|
||||
|
||||
// Write 写入响应内容
|
||||
func (r *Response) Write(bytes []byte) (int, error) {
|
||||
r.checkWriteHeader()
|
||||
r.changed = true
|
||||
r.outLen += len(bytes)
|
||||
if r.ProxyHeader != nil {
|
||||
r.copyProxyHeader()
|
||||
}
|
||||
return r.Writer.Write(bytes)
|
||||
}
|
||||
|
||||
// WriteString 写入字符串响应
|
||||
func (r *Response) WriteString(s string) (int, error) {
|
||||
return r.Write([]byte(s))
|
||||
}
|
||||
|
||||
// WriteHeader 设置响应状态码
|
||||
func (r *Response) WriteHeader(code int) {
|
||||
r.changed = true
|
||||
r.status = code
|
||||
if r.ProxyHeader != nil && (r.status == http.StatusBadGateway || r.status == http.StatusServiceUnavailable || r.status == http.StatusGatewayTimeout) {
|
||||
return
|
||||
}
|
||||
if r.ProxyHeader != nil {
|
||||
r.copyProxyHeader()
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Response) checkWriteHeader() {
|
||||
if !r.headerWritten {
|
||||
r.headerWritten = true
|
||||
if r.status != http.StatusOK {
|
||||
r.Writer.WriteHeader(r.status)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Response) copyProxyHeader() {
|
||||
src := *r.ProxyHeader
|
||||
dst := r.Writer.Header()
|
||||
for k, vv := range src {
|
||||
for _, v := range vv {
|
||||
dst.Add(k, v)
|
||||
}
|
||||
}
|
||||
r.ProxyHeader = nil
|
||||
}
|
||||
|
||||
// Flush 刷新响应缓冲区
|
||||
func (r *Response) Flush() {
|
||||
if flusher, ok := r.Writer.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
// GetStatusCode 获取当前状态码
|
||||
func (r *Response) GetStatusCode() int {
|
||||
return r.status
|
||||
}
|
||||
|
||||
// DontLog200 标记不记录 200 状态码的日志
|
||||
func (r *Response) DontLog200() {
|
||||
r.dontLog200 = true
|
||||
}
|
||||
|
||||
// Location 设置重定向地址
|
||||
func (r *Response) Location(location string) {
|
||||
r.WriteHeader(http.StatusFound)
|
||||
r.Header().Set("Location", location)
|
||||
}
|
||||
|
||||
// SendFile 发送文件
|
||||
func (r *Response) SendFile(contentType, filename string) {
|
||||
r.Header().Set("Content-Type", contentType)
|
||||
// TODO: Integrate memory file support if needed
|
||||
if fd, err := os.Open(filename); err == nil {
|
||||
defer fd.Close()
|
||||
_, _ = io.Copy(r, fd)
|
||||
}
|
||||
}
|
||||
|
||||
// DownloadFile 下载文件
|
||||
func (r *Response) DownloadFile(contentType, filename string, data any) {
|
||||
if contentType == "" {
|
||||
contentType = "application/octet-stream"
|
||||
}
|
||||
r.Header().Set("Content-Type", contentType)
|
||||
|
||||
if filename != "" {
|
||||
r.Header().Set("Content-Disposition", "attachment; filename="+filename)
|
||||
}
|
||||
|
||||
var outBytes []byte
|
||||
var reader io.Reader
|
||||
|
||||
switch v := data.(type) {
|
||||
case []byte:
|
||||
outBytes = v
|
||||
case string:
|
||||
outBytes = []byte(v)
|
||||
case io.Reader:
|
||||
reader = v
|
||||
default:
|
||||
outBytes, _ = cast.ToJSONBytes(data)
|
||||
}
|
||||
|
||||
if outBytes != nil {
|
||||
r.Header().Set("Content-Length", cast.String(len(outBytes)))
|
||||
_, _ = r.Write(outBytes)
|
||||
} else if reader != nil {
|
||||
_, _ = io.Copy(r, reader)
|
||||
}
|
||||
}
|
||||
113
rewrite.go
Normal file
113
rewrite.go
Normal file
@ -0,0 +1,113 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"apigo.cc/go/log"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type rewriteInfo struct {
|
||||
matcher *regexp.Regexp
|
||||
fromPath string
|
||||
toPath string
|
||||
}
|
||||
|
||||
var (
|
||||
rewrites = make(map[string]*rewriteInfo)
|
||||
regexRewrites = make([]*rewriteInfo, 0)
|
||||
rewriteBy func(*Request) (string, bool)
|
||||
rewritesLock = sync.RWMutex{}
|
||||
)
|
||||
|
||||
// Rewrite 注册重写规则
|
||||
func Rewrite(path string, toPath string) {
|
||||
s := &rewriteInfo{fromPath: path, toPath: toPath}
|
||||
|
||||
if strings.ContainsRune(path, '(') {
|
||||
matcher, err := regexp.Compile("^" + path + "$")
|
||||
if err == nil {
|
||||
s.matcher = matcher
|
||||
rewritesLock.Lock()
|
||||
regexRewrites = append(regexRewrites, s)
|
||||
rewritesLock.Unlock()
|
||||
}
|
||||
} else {
|
||||
rewritesLock.Lock()
|
||||
rewrites[path] = s
|
||||
rewritesLock.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// SetRewriteBy 设置动态重写函数
|
||||
func SetRewriteBy(by func(request *Request) (toPath string, rewrite bool)) {
|
||||
rewriteBy = by
|
||||
}
|
||||
|
||||
func processRewrite(request *Request, response *Response, logger *log.Logger) bool {
|
||||
requestPath := request.RequestURI
|
||||
queryString := ""
|
||||
if pos := strings.Index(requestPath, "?"); pos != -1 {
|
||||
queryString = requestPath[pos:]
|
||||
requestPath = requestPath[:pos]
|
||||
}
|
||||
|
||||
var rewriteToPath string
|
||||
var found bool
|
||||
|
||||
rewritesLock.RLock()
|
||||
// 1. 精确匹配
|
||||
if ri, ok := rewrites[requestPath]; ok {
|
||||
rewriteToPath = ri.toPath
|
||||
found = true
|
||||
}
|
||||
|
||||
// 2. 动态重写
|
||||
if !found && rewriteBy != nil {
|
||||
rewriteToPath, found = rewriteBy(request)
|
||||
}
|
||||
|
||||
// 3. 正则匹配
|
||||
if !found {
|
||||
for _, ri := range regexRewrites {
|
||||
if ri.matcher != nil {
|
||||
finds := ri.matcher.FindAllStringSubmatch(request.RequestURI, 1)
|
||||
if len(finds) > 0 {
|
||||
toPath := ri.toPath
|
||||
for i, part := range finds[0] {
|
||||
toPath = strings.ReplaceAll(toPath, fmt.Sprintf("$%d", i), part)
|
||||
}
|
||||
rewriteToPath = toPath
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
rewritesLock.RUnlock()
|
||||
|
||||
if found {
|
||||
if strings.Contains(rewriteToPath, "://") {
|
||||
// 外部重定向
|
||||
if !strings.Contains(rewriteToPath, "?") && queryString != "" {
|
||||
rewriteToPath += queryString
|
||||
}
|
||||
response.Header().Set("Location", rewriteToPath)
|
||||
response.WriteHeader(302)
|
||||
return true
|
||||
} else {
|
||||
// 内部重写
|
||||
logger.Info("rewrite", "from", request.RequestURI, "to", rewriteToPath)
|
||||
if queryString != "" && !strings.Contains(rewriteToPath, "?") {
|
||||
rewriteToPath += queryString
|
||||
}
|
||||
request.RequestURI = rewriteToPath
|
||||
request.URL, _ = url.Parse(rewriteToPath)
|
||||
return false // 继续后续处理
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
88
server.go
Normal file
88
server.go
Normal file
@ -0,0 +1,88 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"apigo.cc/go/log"
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
// AsyncServer 异步服务实例
|
||||
type AsyncServer struct {
|
||||
server *http.Server
|
||||
listener net.Listener
|
||||
Addr string
|
||||
stopChan chan os.Signal
|
||||
startChan chan bool
|
||||
}
|
||||
|
||||
// AsyncStart 异步启动服务
|
||||
func AsyncStart() *AsyncServer {
|
||||
as := &AsyncServer{
|
||||
startChan: make(chan bool, 1),
|
||||
stopChan: make(chan os.Signal, 1),
|
||||
}
|
||||
|
||||
go as.start()
|
||||
|
||||
<-as.startChan
|
||||
return as
|
||||
}
|
||||
|
||||
func (as *AsyncServer) start() {
|
||||
if Config.Listen == "" {
|
||||
Config.Listen = ":8080" // 默认端口
|
||||
}
|
||||
|
||||
listener, err := net.Listen("tcp", Config.Listen)
|
||||
if err != nil {
|
||||
log.DefaultLogger.Error("failed to listen", "addr", Config.Listen, "error", err.Error())
|
||||
as.startChan <- false
|
||||
return
|
||||
}
|
||||
|
||||
as.listener = listener
|
||||
as.Addr = listener.Addr().String()
|
||||
serverAddr = as.Addr
|
||||
|
||||
as.server = &http.Server{
|
||||
Handler: &routeHandler{},
|
||||
}
|
||||
|
||||
signal.Notify(as.stopChan, os.Interrupt, syscall.SIGTERM)
|
||||
|
||||
go func() {
|
||||
log.DefaultLogger.Info("service starting", "addr", as.Addr)
|
||||
as.startChan <- true
|
||||
if err := as.server.Serve(listener); err != nil && err != http.ErrServerClosed {
|
||||
log.DefaultLogger.Error("server error", "error", err.Error())
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Stop 停止服务
|
||||
func (as *AsyncServer) Stop() {
|
||||
log.DefaultLogger.Info("service stopping")
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := as.server.Shutdown(ctx); err != nil {
|
||||
log.DefaultLogger.Error("server shutdown error", "error", err.Error())
|
||||
}
|
||||
log.DefaultLogger.Info("service stopped")
|
||||
}
|
||||
|
||||
// Wait 等待服务结束 (信号监听)
|
||||
func (as *AsyncServer) Wait() {
|
||||
<-as.stopChan
|
||||
as.Stop()
|
||||
}
|
||||
|
||||
// Start 同步启动服务
|
||||
func Start() {
|
||||
AsyncStart().Wait()
|
||||
}
|
||||
31
server_test.go
Normal file
31
server_test.go
Normal file
@ -0,0 +1,31 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAsyncServer(t *testing.T) {
|
||||
Config.Listen = ":0" // 随机端口
|
||||
as := AsyncStart()
|
||||
if as.Addr == "" {
|
||||
t.Fatal("AsyncStart failed to get address")
|
||||
}
|
||||
|
||||
// 测试服务是否可用
|
||||
resp, err := http.Get("http://" + as.Addr + "/__CHECK__")
|
||||
if err == nil {
|
||||
// 虽然没有注册 /__CHECK__,但应该返回 404 而非连接拒绝
|
||||
if resp.StatusCode != http.StatusNotFound {
|
||||
t.Errorf("Expected 404, got %d", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
as.Stop()
|
||||
|
||||
// 确认服务已关闭
|
||||
_, err = http.Get("http://" + as.Addr + "/__CHECK__")
|
||||
if err == nil {
|
||||
t.Error("Server should be closed")
|
||||
}
|
||||
}
|
||||
225
service.go
Normal file
225
service.go
Normal file
@ -0,0 +1,225 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"apigo.cc/go/log"
|
||||
"errors"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// WebServiceOptions 服务注册选项
|
||||
type WebServiceOptions struct {
|
||||
Priority int
|
||||
NoDoc bool
|
||||
NoBody bool
|
||||
NoLog200 bool
|
||||
Host string
|
||||
Ext Map
|
||||
// Limiters []*Limiter // TODO: Integrate Limiter
|
||||
}
|
||||
|
||||
// webServiceType 内部存储的服务元数据
|
||||
type webServiceType struct {
|
||||
authLevel int
|
||||
method string
|
||||
path string
|
||||
pathMatcher *regexp.Regexp
|
||||
pathArgs []string
|
||||
parmsNum int
|
||||
inType reflect.Type
|
||||
inIndex int
|
||||
headersType reflect.Type
|
||||
headersIndex int
|
||||
requestIndex int
|
||||
httpRequestIndex int
|
||||
responseIndex int
|
||||
responseWriterIndex int
|
||||
loggerIndex int
|
||||
callerIndex int
|
||||
funcType reflect.Type
|
||||
funcValue reflect.Value
|
||||
options WebServiceOptions
|
||||
data Map
|
||||
memo string
|
||||
}
|
||||
|
||||
var (
|
||||
serverId string
|
||||
serverAddr string
|
||||
serverProto = "http"
|
||||
serverProtoName = "http"
|
||||
running = false
|
||||
|
||||
webServices = make(map[string]*webServiceType)
|
||||
regexWebServices = make([]*webServiceType, 0)
|
||||
webServicesLock = sync.RWMutex{}
|
||||
webServicesList = make([]*webServiceType, 0)
|
||||
|
||||
websocketServices = make(map[string]*websocketServiceType)
|
||||
regexWebsocketServices = make([]*websocketServiceType, 0)
|
||||
websocketServicesLock = sync.RWMutex{}
|
||||
websocketServicesList = make([]*websocketServiceType, 0)
|
||||
|
||||
// 过滤器与拦截器
|
||||
inFilters = make([]func(*map[string]any, *Request, *Response, *log.Logger) any, 0)
|
||||
outFilters = make([]func(map[string]any, *Request, *Response, any, *log.Logger) (any, bool), 0)
|
||||
errorHandle func(any, *Request, *Response) any
|
||||
webAuthChecker func(int, *log.Logger, *string, map[string]any, *Request, *Response, *WebServiceOptions) (pass bool, object any)
|
||||
webAuthCheckers = make(map[int]func(int, *log.Logger, *string, map[string]any, *Request, *Response, *WebServiceOptions) (pass bool, object any))
|
||||
|
||||
// 注入点
|
||||
injectObjects = make(map[reflect.Type]any)
|
||||
injectFunctions = make(map[reflect.Type]func() any)
|
||||
|
||||
usedDeviceIdKey string
|
||||
usedClientAppKey string
|
||||
usedSessionIdKey string
|
||||
sessionIdMaker func() string
|
||||
)
|
||||
|
||||
// SetClientKeys 设置客户端标识相关的 Key 映射
|
||||
func SetClientKeys(deviceIdKey, clientAppKey, sessionIdKey string) {
|
||||
usedDeviceIdKey = deviceIdKey
|
||||
usedClientAppKey = clientAppKey
|
||||
usedSessionIdKey = sessionIdKey
|
||||
}
|
||||
|
||||
// SetSessionIdMaker 设置自定义会话 ID 生成器
|
||||
func SetSessionIdMaker(maker func() string) {
|
||||
sessionIdMaker = maker
|
||||
}
|
||||
|
||||
// SetAuthChecker 设置全局鉴权器
|
||||
func SetAuthChecker(authChecker func(authLevel int, logger *log.Logger, url *string, in map[string]any, request *Request, response *Response, options *WebServiceOptions) (pass bool, object any)) {
|
||||
webAuthChecker = authChecker
|
||||
}
|
||||
|
||||
// AddAuthChecker 为指定级别添加鉴权器
|
||||
func AddAuthChecker(authLevels []int, authChecker func(authLevel int, logger *log.Logger, url *string, in map[string]any, request *Request, response *Response, options *WebServiceOptions) (pass bool, object any)) {
|
||||
for _, al := range authLevels {
|
||||
webAuthCheckers[al] = authChecker
|
||||
}
|
||||
}
|
||||
|
||||
// SetInFilter 设置前置过滤器
|
||||
func SetInFilter(filter func(in *map[string]any, request *Request, response *Response, logger *log.Logger) (out any)) {
|
||||
inFilters = append(inFilters, filter)
|
||||
}
|
||||
|
||||
// SetOutFilter 设置后置过滤器
|
||||
func SetOutFilter(filter func(in map[string]any, request *Request, response *Response, out any, logger *log.Logger) (newOut any, isOver bool)) {
|
||||
outFilters = append(outFilters, filter)
|
||||
}
|
||||
|
||||
// Register 注册服务(通用方法)
|
||||
func Register(authLevel int, path string, serviceFunc any, memo string) {
|
||||
Restful(authLevel, "", path, serviceFunc, memo)
|
||||
}
|
||||
|
||||
// Restful 注册指定方法的服务
|
||||
func Restful(authLevel int, method, path string, serviceFunc any, memo string) {
|
||||
RestfulWithOptions(authLevel, method, path, serviceFunc, memo, WebServiceOptions{})
|
||||
}
|
||||
|
||||
// RestfulWithOptions 注册带选项的服务
|
||||
func RestfulWithOptions(authLevel int, method, path string, serviceFunc any, memo string, options WebServiceOptions) {
|
||||
s, err := makeCachedService(serviceFunc)
|
||||
if err != nil {
|
||||
// TODO: Log error properly when logger is ready
|
||||
return
|
||||
}
|
||||
|
||||
s.authLevel = authLevel
|
||||
s.options = options
|
||||
s.method = method
|
||||
s.path = path
|
||||
s.memo = memo
|
||||
|
||||
// 解析路径参数 {name}
|
||||
finder, err := regexp.Compile("{(.*?)}")
|
||||
if err == nil {
|
||||
keyName := regexp.QuoteMeta(path)
|
||||
finds := finder.FindAllStringSubmatch(path, 20)
|
||||
for _, found := range finds {
|
||||
keyName = strings.Replace(keyName, regexp.QuoteMeta(found[0]), "(.*?)", 1)
|
||||
s.pathArgs = append(s.pathArgs, found[1])
|
||||
}
|
||||
if len(s.pathArgs) > 0 {
|
||||
s.pathMatcher, _ = regexp.Compile("^" + keyName + "$")
|
||||
}
|
||||
}
|
||||
|
||||
webServicesLock.Lock()
|
||||
defer webServicesLock.Unlock()
|
||||
|
||||
// 简单路径匹配
|
||||
if s.pathMatcher == nil {
|
||||
webServices[method+path] = s // TODO: Include Host in key
|
||||
} else {
|
||||
regexWebServices = append(regexWebServices, s)
|
||||
}
|
||||
webServicesList = append(webServicesList, s)
|
||||
}
|
||||
|
||||
func makeCachedService(matchedService any) (*webServiceType, error) {
|
||||
funcType := reflect.TypeOf(matchedService)
|
||||
if funcType.Kind() != reflect.Func {
|
||||
return nil, errors.New("handler must be a function")
|
||||
}
|
||||
|
||||
targetService := &webServiceType{
|
||||
parmsNum: funcType.NumIn(),
|
||||
inIndex: -1,
|
||||
headersIndex: -1,
|
||||
requestIndex: -1,
|
||||
httpRequestIndex: -1,
|
||||
responseIndex: -1,
|
||||
responseWriterIndex: -1,
|
||||
loggerIndex: -1,
|
||||
callerIndex: -1,
|
||||
funcType: funcType,
|
||||
funcValue: reflect.ValueOf(matchedService),
|
||||
}
|
||||
|
||||
for i := 0; i < targetService.parmsNum; i++ {
|
||||
t := funcType.In(i)
|
||||
tStr := t.String()
|
||||
switch tStr {
|
||||
case "*service.Request":
|
||||
targetService.requestIndex = i
|
||||
case "*http.Request":
|
||||
targetService.httpRequestIndex = i
|
||||
case "*service.Response":
|
||||
targetService.responseIndex = i
|
||||
case "http.ResponseWriter":
|
||||
targetService.responseWriterIndex = i
|
||||
case "*log.Logger":
|
||||
targetService.loggerIndex = i
|
||||
default:
|
||||
if t.Kind() == reflect.Struct || (t.Kind() == reflect.Map && t.Elem().Kind() == reflect.Interface) {
|
||||
if targetService.inType == nil {
|
||||
targetService.inIndex = i
|
||||
targetService.inType = t
|
||||
} else if targetService.headersType == nil {
|
||||
targetService.headersIndex = i
|
||||
targetService.headersType = t
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return targetService, nil
|
||||
}
|
||||
|
||||
// GetInject 获取注入对象
|
||||
func GetInject(dataType reflect.Type) any {
|
||||
if obj, exists := injectObjects[dataType]; exists {
|
||||
return obj
|
||||
}
|
||||
if factory, exists := injectFunctions[dataType]; exists {
|
||||
return factory()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
54
service_test.go
Normal file
54
service_test.go
Normal file
@ -0,0 +1,54 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"apigo.cc/go/log"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestServiceRegister(t *testing.T) {
|
||||
handler := func(req *Request, logger *log.Logger) string {
|
||||
return "ok"
|
||||
}
|
||||
|
||||
Register(0, "/test", handler, "test service")
|
||||
|
||||
webServicesLock.RLock()
|
||||
s := webServices["/test"]
|
||||
webServicesLock.RUnlock()
|
||||
|
||||
if s == nil {
|
||||
t.Fatal("Service not registered")
|
||||
}
|
||||
|
||||
if s.requestIndex != 0 {
|
||||
t.Errorf("requestIndex mismatch: expected 0, got %d", s.requestIndex)
|
||||
}
|
||||
if s.loggerIndex != 1 {
|
||||
t.Errorf("loggerIndex mismatch: expected 1, got %d", s.loggerIndex)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegexServiceRegister(t *testing.T) {
|
||||
handler := func(args map[string]any) string {
|
||||
return "ok"
|
||||
}
|
||||
|
||||
Register(0, "/user/{id}", handler, "get user")
|
||||
|
||||
webServicesLock.RLock()
|
||||
found := false
|
||||
for _, s := range regexWebServices {
|
||||
if s.path == "/user/{id}" {
|
||||
found = true
|
||||
if len(s.pathArgs) != 1 || s.pathArgs[0] != "id" {
|
||||
t.Errorf("pathArgs mismatch: %v", s.pathArgs)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
webServicesLock.RUnlock()
|
||||
|
||||
if !found {
|
||||
t.Fatal("Regex service not registered")
|
||||
}
|
||||
}
|
||||
49
starter.go
Normal file
49
starter.go
Normal file
@ -0,0 +1,49 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
// StartCmd 命令行命令定义
|
||||
type StartCmd struct {
|
||||
Name string
|
||||
Comment string
|
||||
Func func()
|
||||
}
|
||||
|
||||
var startCmds = []StartCmd{
|
||||
{"start", "Start server", Start},
|
||||
}
|
||||
|
||||
// AddCmd 添加自定义命令行命令
|
||||
func AddCmd(name, comment string, function func()) {
|
||||
startCmds = append(startCmds, StartCmd{name, comment, function})
|
||||
}
|
||||
|
||||
// CheckCmd 检查并执行命令行命令
|
||||
func CheckCmd() {
|
||||
if len(os.Args) > 1 {
|
||||
cmd := os.Args[1]
|
||||
if cmd == "help" || cmd == "--help" {
|
||||
showHelp()
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
for _, cmdInfo := range startCmds {
|
||||
if cmd == cmdInfo.Name {
|
||||
cmdInfo.Func()
|
||||
os.Exit(0)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func showHelp() {
|
||||
fmt.Printf("Usage: %s [command]\n\n", filepath.Base(os.Args[0]))
|
||||
fmt.Println("Available commands:")
|
||||
for _, cmdInfo := range startCmds {
|
||||
fmt.Printf(" %-10s %s\n", cmdInfo.Name, cmdInfo.Comment)
|
||||
}
|
||||
}
|
||||
123
static.go
Normal file
123
static.go
Normal file
@ -0,0 +1,123 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"apigo.cc/go/file"
|
||||
"apigo.cc/go/log"
|
||||
"mime"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
statics = make(map[string]*string)
|
||||
staticsByHost = make(map[string]map[string]*string)
|
||||
staticsByHostLock = sync.RWMutex{}
|
||||
)
|
||||
|
||||
// Static 注册静态文件目录
|
||||
func Static(path, rootPath string) {
|
||||
StaticByHost(path, rootPath, "")
|
||||
}
|
||||
|
||||
// StaticByHost 为指定域名注册静态文件目录
|
||||
func StaticByHost(path, rootPath, host string) {
|
||||
if !filepath.IsAbs(rootPath) {
|
||||
if absPath, err := filepath.Abs(rootPath); err == nil {
|
||||
rootPath = absPath
|
||||
}
|
||||
}
|
||||
|
||||
staticsByHostLock.Lock()
|
||||
defer staticsByHostLock.Unlock()
|
||||
|
||||
if host == "" {
|
||||
statics[path] = &rootPath
|
||||
} else {
|
||||
if staticsByHost[host] == nil {
|
||||
staticsByHost[host] = make(map[string]*string)
|
||||
}
|
||||
staticsByHost[host][path] = &rootPath
|
||||
}
|
||||
}
|
||||
|
||||
func getStaticFilePath(requestPath, host string) string {
|
||||
staticsByHostLock.RLock()
|
||||
defer staticsByHostLock.RUnlock()
|
||||
|
||||
// 优先匹配指定域名的配置
|
||||
if hostConfig, exists := staticsByHost[host]; exists {
|
||||
if filePath := findMatchedPath(hostConfig, requestPath); filePath != "" {
|
||||
return filePath
|
||||
}
|
||||
}
|
||||
|
||||
// 匹配全局配置
|
||||
return findMatchedPath(statics, requestPath)
|
||||
}
|
||||
|
||||
func findMatchedPath(config map[string]*string, requestPath string) string {
|
||||
for urlPath, rootPath := range config {
|
||||
if strings.HasPrefix(requestPath, urlPath) {
|
||||
return filepath.Join(*rootPath, requestPath[len(urlPath):])
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func processStatic(requestPath string, request *Request, response *Response, logger *log.Logger) bool {
|
||||
filePath := getStaticFilePath(requestPath, request.Host)
|
||||
if filePath == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
info, err := os.Stat(filePath)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if info.IsDir() {
|
||||
// 自动查找索引文件
|
||||
for _, indexFile := range Config.IndexFiles {
|
||||
f := filepath.Join(filePath, indexFile)
|
||||
if i, err := os.Stat(f); err == nil && !i.IsDir() {
|
||||
filePath = f
|
||||
info = i
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if info.IsDir() {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查 304
|
||||
if ifModifiedSince := request.Header.Get("If-Modified-Since"); ifModifiedSince != "" {
|
||||
if t, err := time.Parse(http.TimeFormat, ifModifiedSince); err == nil {
|
||||
if !info.ModTime().Truncate(time.Second).After(t.Truncate(time.Second)) {
|
||||
response.WriteHeader(http.StatusNotModified)
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 发送文件
|
||||
contentType := mime.TypeByExtension(filepath.Ext(filePath))
|
||||
if contentType == "" {
|
||||
contentType = "application/octet-stream"
|
||||
}
|
||||
response.Header().Set("Content-Type", contentType)
|
||||
response.Header().Set("Last-Modified", info.ModTime().UTC().Format(http.TimeFormat))
|
||||
|
||||
data, err := file.ReadBytes(filePath)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
_, _ = response.Write(data)
|
||||
return true
|
||||
}
|
||||
43
static_test.go
Normal file
43
static_test.go
Normal file
@ -0,0 +1,43 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestStaticService(t *testing.T) {
|
||||
// 创建临时测试目录和文件
|
||||
tempDir, _ := os.MkdirTemp("", "static_test")
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
testFile := filepath.Join(tempDir, "index.html")
|
||||
os.WriteFile(testFile, []byte("<h1>Static Page</h1>"), 0644)
|
||||
|
||||
// 注册静态目录
|
||||
Static("/ui", tempDir)
|
||||
|
||||
rh := &routeHandler{}
|
||||
|
||||
// 测试成功访问
|
||||
req := httptest.NewRequest("GET", "/ui/index.html", nil)
|
||||
w := httptest.NewRecorder()
|
||||
rh.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected 200, got %d", w.Code)
|
||||
}
|
||||
if body := w.Body.String(); body != "<h1>Static Page</h1>" {
|
||||
t.Errorf("Content mismatch: %s", body)
|
||||
}
|
||||
|
||||
// 测试 404
|
||||
req404 := httptest.NewRequest("GET", "/ui/notfound.html", nil)
|
||||
w404 := httptest.NewRecorder()
|
||||
rh.ServeHTTP(w404, req404)
|
||||
if w404.Code != http.StatusNotFound {
|
||||
t.Errorf("Expected 404 for missing file, got %d", w404.Code)
|
||||
}
|
||||
}
|
||||
68
types.go
Normal file
68
types.go
Normal file
@ -0,0 +1,68 @@
|
||||
package service
|
||||
|
||||
// Map 通用 Map 类型
|
||||
type Map = map[string]any
|
||||
|
||||
// Arr 通用切片类型
|
||||
type Arr = []any
|
||||
|
||||
// Argot 错误码/标识符类型
|
||||
type Argot string
|
||||
|
||||
// Result 通用返回结构
|
||||
type Result struct {
|
||||
Ok bool `json:"ok"`
|
||||
Argot Argot `json:"argot,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
}
|
||||
|
||||
// CodeResult 带状态码的返回结构
|
||||
type CodeResult struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message,omitempty"`
|
||||
}
|
||||
|
||||
// ArgotInfo 标识符信息(用于文档生成)
|
||||
type ArgotInfo struct {
|
||||
Name Argot
|
||||
Memo string
|
||||
}
|
||||
|
||||
// OK 设置成功状态
|
||||
func (r *Result) OK(argots ...Argot) {
|
||||
r.Ok = true
|
||||
if len(argots) > 0 {
|
||||
r.Argot = argots[0]
|
||||
}
|
||||
}
|
||||
|
||||
// Failed 设置失败状态
|
||||
func (r *Result) Failed(message string, argots ...Argot) {
|
||||
r.Ok = false
|
||||
r.Message = message
|
||||
if len(argots) > 0 {
|
||||
r.Argot = argots[0]
|
||||
}
|
||||
}
|
||||
|
||||
// Done 根据布尔值设置状态
|
||||
func (r *Result) Done(ok bool, failedMessage string, argots ...Argot) {
|
||||
r.Ok = ok
|
||||
if !ok {
|
||||
r.Message = failedMessage
|
||||
if len(argots) > 0 {
|
||||
r.Argot = argots[0]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OK 设置成功状态 (Code=1)
|
||||
func (r *CodeResult) OK() {
|
||||
r.Code = 1
|
||||
}
|
||||
|
||||
// Failed 设置失败状态与错误码
|
||||
func (r *CodeResult) Failed(code int, message string) {
|
||||
r.Code = code
|
||||
r.Message = message
|
||||
}
|
||||
41
types_test.go
Normal file
41
types_test.go
Normal file
@ -0,0 +1,41 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestResult(t *testing.T) {
|
||||
r := &Result{}
|
||||
r.OK()
|
||||
if !r.Ok {
|
||||
t.Error("Result.OK() failed")
|
||||
}
|
||||
|
||||
r.Failed("error", Argot("ERR_CODE"))
|
||||
if r.Ok || r.Message != "error" || r.Argot != "ERR_CODE" {
|
||||
t.Error("Result.Failed() failed")
|
||||
}
|
||||
|
||||
r.Done(true, "never")
|
||||
if !r.Ok {
|
||||
t.Error("Result.Done(true) failed")
|
||||
}
|
||||
|
||||
r.Done(false, "failed", Argot("FAIL"))
|
||||
if r.Ok || r.Message != "failed" || r.Argot != "FAIL" {
|
||||
t.Error("Result.Done(false) failed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodeResult(t *testing.T) {
|
||||
cr := &CodeResult{}
|
||||
cr.OK()
|
||||
if cr.Code != 1 {
|
||||
t.Error("CodeResult.OK() failed")
|
||||
}
|
||||
|
||||
cr.Failed(500, "internal error")
|
||||
if cr.Code != 500 || cr.Message != "internal error" {
|
||||
t.Error("CodeResult.Failed() failed")
|
||||
}
|
||||
}
|
||||
20
utility.go
Normal file
20
utility.go
Normal file
@ -0,0 +1,20 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"apigo.cc/go/id"
|
||||
)
|
||||
|
||||
// MakeId 生成指定长度的 ID
|
||||
func MakeId(size int) string {
|
||||
return id.MakeID(size)
|
||||
}
|
||||
|
||||
// MakeIdForMysql 生成适用于 MySQL 的有序 ID
|
||||
func MakeIdForMysql(size int) string {
|
||||
return id.DefaultIDMaker.GetForMysql(size)
|
||||
}
|
||||
|
||||
// MakeIdForPostgreSQL 生成适用于 PostgreSQL 的有序 ID
|
||||
func MakeIdForPostgreSQL(size int) string {
|
||||
return id.DefaultIDMaker.GetForPostgreSQL(size)
|
||||
}
|
||||
305
verify.go
Normal file
305
verify.go
Normal file
@ -0,0 +1,305 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"apigo.cc/go/cast"
|
||||
"apigo.cc/go/log"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// VerifyType 校验类型
|
||||
type VerifyType uint8
|
||||
|
||||
const (
|
||||
VerifyUnknown VerifyType = iota
|
||||
VerifyRegex
|
||||
VerifyStringLength
|
||||
VerifyGreaterThan
|
||||
VerifyLessThan
|
||||
VerifyBetween
|
||||
VerifyInList
|
||||
VerifyByFunc
|
||||
)
|
||||
|
||||
// VerifySet 校验规则集
|
||||
type VerifySet struct {
|
||||
Type VerifyType
|
||||
Regex *regexp.Regexp
|
||||
StringArgs []string
|
||||
IntArgs []int
|
||||
FloatArgs []float64
|
||||
Func func(any, []string) bool
|
||||
}
|
||||
|
||||
var (
|
||||
verifySets = make(map[string]*VerifySet)
|
||||
verifySetsLock = sync.RWMutex{}
|
||||
verifyFunctions = make(map[string]func(any, []string) bool)
|
||||
verifyFunctionsLock = sync.RWMutex{}
|
||||
)
|
||||
|
||||
// RegisterVerifyFunc 注册自定义校验函数
|
||||
func RegisterVerifyFunc(name string, f func(in any, args []string) bool) {
|
||||
verifyFunctionsLock.Lock()
|
||||
verifyFunctions[name] = f
|
||||
verifyFunctionsLock.Unlock()
|
||||
}
|
||||
|
||||
// RegisterVerify 注册预定义校验规则
|
||||
func RegisterVerify(name, setting string) {
|
||||
set, _ := compileVerifySet(setting)
|
||||
if set != nil {
|
||||
verifySetsLock.Lock()
|
||||
verifySets[name] = set
|
||||
verifySetsLock.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// VerifyStruct 校验结构体
|
||||
func VerifyStruct(in any, logger *log.Logger) (ok bool, field string) {
|
||||
v := cast.RealValue(reflect.ValueOf(in))
|
||||
if v.Kind() != reflect.Struct {
|
||||
if logger != nil {
|
||||
logger.Error("verify input is not struct", "type", v.Type().String())
|
||||
}
|
||||
return false, ""
|
||||
}
|
||||
|
||||
for i := 0; i < v.NumField(); i++ {
|
||||
ft := v.Type().Field(i)
|
||||
fv := v.Field(i)
|
||||
|
||||
// 忽略空指针、空切片、空 Map
|
||||
if (fv.Kind() == reflect.Ptr && fv.IsNil()) ||
|
||||
(fv.Kind() == reflect.Slice && fv.Len() == 0) ||
|
||||
(fv.Kind() == reflect.Map && fv.Len() == 0) {
|
||||
continue
|
||||
}
|
||||
|
||||
if ft.Anonymous {
|
||||
// 处理嵌套结构体(继承)
|
||||
if fv.CanInterface() {
|
||||
if ok, f := VerifyStruct(fv.Interface(), logger); !ok {
|
||||
return false, f
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
tag := ft.Tag.Get("verify")
|
||||
keyTag := ft.Tag.Get("verifyKey")
|
||||
if tag != "" || keyTag != "" {
|
||||
var err error
|
||||
ok, f, err := _verifyValue(fv, tag, keyTag, logger)
|
||||
if !ok {
|
||||
if f == "" {
|
||||
f = cast.GetLowerName(ft.Name)
|
||||
}
|
||||
if logger != nil {
|
||||
if err != nil {
|
||||
logger.Error(err.Error(), "field", f)
|
||||
} else {
|
||||
logger.Warning("verify failed", "field", f, "tag", tag)
|
||||
}
|
||||
}
|
||||
return false, f
|
||||
}
|
||||
}
|
||||
}
|
||||
return true, ""
|
||||
}
|
||||
|
||||
func _verifyValue(in reflect.Value, setting, keySetting string, logger *log.Logger) (bool, string, error) {
|
||||
t := in.Type()
|
||||
// 处理切片 (非 byte 切片)
|
||||
if t.Kind() == reflect.Slice && t.Elem().Kind() != reflect.Uint8 {
|
||||
if setting != "" {
|
||||
for i := 0; i < in.Len(); i++ {
|
||||
if ok, f, err := _verifyValue(in.Index(i), setting, "", logger); !ok {
|
||||
return false, f, err
|
||||
}
|
||||
}
|
||||
}
|
||||
return true, "", nil
|
||||
}
|
||||
|
||||
// 处理 Map
|
||||
if t.Kind() == reflect.Map {
|
||||
for _, k := range in.MapKeys() {
|
||||
if keySetting != "" {
|
||||
if ok, _, err := _verifyValue(k, keySetting, "", logger); !ok {
|
||||
return false, "key", err
|
||||
}
|
||||
}
|
||||
if setting != "" {
|
||||
if ok, f, err := _verifyValue(in.MapIndex(k), setting, "", logger); !ok {
|
||||
return false, f, err
|
||||
}
|
||||
}
|
||||
}
|
||||
return true, "", nil
|
||||
}
|
||||
|
||||
// 处理嵌套 Struct
|
||||
if t.Kind() == reflect.Struct {
|
||||
ok, f := VerifyStruct(in.Interface(), logger)
|
||||
return ok, f, nil
|
||||
}
|
||||
|
||||
// 基础校验
|
||||
if setting == "" {
|
||||
return true, "", nil
|
||||
}
|
||||
|
||||
ok, err := verify(in.Interface(), setting)
|
||||
return ok, "", err
|
||||
}
|
||||
|
||||
func verify(in any, setting string) (bool, error) {
|
||||
if len(setting) < 2 {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
verifySetsLock.RLock()
|
||||
set, exists := verifySets[setting]
|
||||
verifySetsLock.RUnlock()
|
||||
|
||||
if !exists {
|
||||
var err error
|
||||
set, err = compileVerifySet(setting)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
verifySetsLock.Lock()
|
||||
verifySets[setting] = set
|
||||
verifySetsLock.Unlock()
|
||||
}
|
||||
|
||||
switch set.Type {
|
||||
case VerifyByFunc:
|
||||
return set.Func(in, set.StringArgs), nil
|
||||
case VerifyRegex:
|
||||
return set.Regex.MatchString(cast.String(in)), nil
|
||||
case VerifyStringLength:
|
||||
l := len(cast.String(in))
|
||||
if len(set.StringArgs) > 0 {
|
||||
if set.StringArgs[0] == "+" {
|
||||
return l >= set.IntArgs[0], nil
|
||||
} else if set.StringArgs[0] == "-" {
|
||||
return l <= set.IntArgs[0], nil
|
||||
}
|
||||
}
|
||||
if len(set.IntArgs) > 1 {
|
||||
return l >= set.IntArgs[0] && l <= set.IntArgs[1], nil
|
||||
}
|
||||
return l == set.IntArgs[0], nil
|
||||
case VerifyGreaterThan:
|
||||
return cast.Float64(in) > set.FloatArgs[0], nil
|
||||
case VerifyLessThan:
|
||||
return cast.Float64(in) < set.FloatArgs[0], nil
|
||||
case VerifyBetween:
|
||||
val := cast.Float64(in)
|
||||
return val >= set.FloatArgs[0] && val <= set.FloatArgs[1], nil
|
||||
case VerifyInList:
|
||||
s := cast.String(in)
|
||||
for _, item := range set.StringArgs {
|
||||
if item == s {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func compileVerifySet(setting string) (*VerifySet, error) {
|
||||
set := &VerifySet{Type: VerifyUnknown}
|
||||
if setting == "" {
|
||||
return set, nil
|
||||
}
|
||||
|
||||
if setting[0] != '^' {
|
||||
key := setting
|
||||
args := ""
|
||||
if pos := strings.IndexByte(setting, ':'); pos != -1 {
|
||||
key = setting[:pos]
|
||||
args = setting[pos+1:]
|
||||
}
|
||||
|
||||
// 优先查找自定义函数
|
||||
verifyFunctionsLock.RLock()
|
||||
f, exists := verifyFunctions[key]
|
||||
verifyFunctionsLock.RUnlock()
|
||||
if exists {
|
||||
set.Type = VerifyByFunc
|
||||
set.Func = f
|
||||
if args != "" {
|
||||
set.StringArgs = strings.Split(args, ",")
|
||||
}
|
||||
return set, nil
|
||||
}
|
||||
|
||||
// 内置规则
|
||||
switch key {
|
||||
case "length":
|
||||
set.Type = VerifyStringLength
|
||||
if args == "" {
|
||||
args = "1+"
|
||||
}
|
||||
last := args[len(args)-1]
|
||||
if last == '+' || last == '-' {
|
||||
set.StringArgs = []string{string(last)}
|
||||
args = args[:len(args)-1]
|
||||
}
|
||||
// 同时支持逗号和中划线
|
||||
sep := ","
|
||||
if strings.Contains(args, "-") && !strings.Contains(args, ",") {
|
||||
sep = "-"
|
||||
}
|
||||
if strings.Contains(args, sep) {
|
||||
a := strings.Split(args, sep)
|
||||
set.IntArgs = []int{cast.Int(a[0]), cast.Int(a[1])}
|
||||
} else {
|
||||
set.IntArgs = []int{cast.Int(args)}
|
||||
}
|
||||
return set, nil
|
||||
case "between":
|
||||
set.Type = VerifyBetween
|
||||
if args == "" {
|
||||
args = "1-100000000"
|
||||
}
|
||||
a := strings.Split(args, "-")
|
||||
if len(a) == 1 {
|
||||
set.FloatArgs = []float64{0, cast.Float64(a[0])}
|
||||
} else {
|
||||
set.FloatArgs = []float64{cast.Float64(a[0]), cast.Float64(a[1])}
|
||||
}
|
||||
return set, nil
|
||||
case "gt":
|
||||
set.Type = VerifyGreaterThan
|
||||
set.FloatArgs = []float64{cast.Float64(args)}
|
||||
return set, nil
|
||||
case "lt":
|
||||
set.Type = VerifyLessThan
|
||||
set.FloatArgs = []float64{cast.Float64(args)}
|
||||
return set, nil
|
||||
case "in":
|
||||
set.Type = VerifyInList
|
||||
if args != "" {
|
||||
set.StringArgs = strings.Split(args, ",")
|
||||
}
|
||||
return set, nil
|
||||
}
|
||||
}
|
||||
|
||||
// 默认视为正则表达式
|
||||
rx, err := regexp.Compile(setting)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
set.Type = VerifyRegex
|
||||
set.Regex = rx
|
||||
return set, nil
|
||||
}
|
||||
76
verify_test.go
Normal file
76
verify_test.go
Normal file
@ -0,0 +1,76 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
type TestUser struct {
|
||||
Name string `verify:"length:2-10"`
|
||||
Age int `verify:"between:18-100"`
|
||||
Type string `verify:"in:admin,user,guest"`
|
||||
}
|
||||
|
||||
type NestedStruct struct {
|
||||
TestUser
|
||||
Note string `verify:"^.{1,20}$"`
|
||||
}
|
||||
|
||||
func TestVerifyStruct(t *testing.T) {
|
||||
u := TestUser{Name: "Star", Age: 25, Type: "admin"}
|
||||
if ok, f := VerifyStruct(u, nil); !ok {
|
||||
t.Errorf("VerifyStruct failed on valid user, field: %s", f)
|
||||
}
|
||||
|
||||
u.Name = "S"
|
||||
if ok, f := VerifyStruct(u, nil); ok || f != "name" {
|
||||
t.Errorf("VerifyStruct should fail on short name, got ok=%v, field=%s", ok, f)
|
||||
}
|
||||
|
||||
u.Name = "Star"
|
||||
u.Age = 10
|
||||
if ok, f := VerifyStruct(u, nil); ok || f != "age" {
|
||||
t.Errorf("VerifyStruct should fail on young age, got ok=%v, field=%s", ok, f)
|
||||
}
|
||||
|
||||
u.Age = 25
|
||||
u.Type = "invalid"
|
||||
if ok, f := VerifyStruct(u, nil); ok || f != "type" {
|
||||
t.Errorf("VerifyStruct should fail on invalid type, got ok=%v, field=%s", ok, f)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNestedVerify(t *testing.T) {
|
||||
n := NestedStruct{
|
||||
TestUser: TestUser{Name: "Star", Age: 25, Type: "user"},
|
||||
Note: "Hello",
|
||||
}
|
||||
if ok, f := VerifyStruct(n, nil); !ok {
|
||||
t.Errorf("Nested VerifyStruct failed on valid data, field: %s", f)
|
||||
}
|
||||
|
||||
n.TestUser.Age = 5
|
||||
if ok, f := VerifyStruct(n, nil); ok || f != "age" {
|
||||
t.Errorf("Nested VerifyStruct should fail on nested age, got ok=%v, field=%s", ok, f)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCustomVerify(t *testing.T) {
|
||||
RegisterVerifyFunc("odd", func(in any, args []string) bool {
|
||||
val := in.(int)
|
||||
return val%2 != 0
|
||||
})
|
||||
|
||||
type OddStruct struct {
|
||||
Num int `verify:"odd"`
|
||||
}
|
||||
|
||||
o := OddStruct{Num: 3}
|
||||
if ok, f := VerifyStruct(o, nil); !ok {
|
||||
t.Errorf("Custom verify failed on odd number, field: %s", f)
|
||||
}
|
||||
|
||||
o.Num = 4
|
||||
if ok, f := VerifyStruct(o, nil); ok || f != "num" {
|
||||
t.Errorf("Custom verify should fail on even number, got ok=%v, field=%s", ok, f)
|
||||
}
|
||||
}
|
||||
175
websocket.go
Normal file
175
websocket.go
Normal file
@ -0,0 +1,175 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"apigo.cc/go/cast"
|
||||
"apigo.cc/go/log"
|
||||
"github.com/gorilla/websocket"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"regexp"
|
||||
)
|
||||
|
||||
// websocketServiceType WebSocket 服务元数据
|
||||
type websocketServiceType struct {
|
||||
authLevel int
|
||||
path string
|
||||
pathMatcher *regexp.Regexp
|
||||
pathArgs []string
|
||||
updater *websocket.Upgrader
|
||||
openFuncValue reflect.Value
|
||||
openFuncType reflect.Type
|
||||
closeFuncValue reflect.Value
|
||||
closeFuncType reflect.Type
|
||||
sessionType reflect.Type
|
||||
actions map[string]*websocketActionType
|
||||
isSimple bool
|
||||
options WebServiceOptions
|
||||
memo string
|
||||
}
|
||||
|
||||
// websocketActionType WebSocket Action 元数据
|
||||
type websocketActionType struct {
|
||||
authLevel int
|
||||
funcValue reflect.Value
|
||||
funcType reflect.Type
|
||||
inType reflect.Type
|
||||
memo string
|
||||
}
|
||||
|
||||
// ActionRegister WebSocket Action 注册器
|
||||
type ActionRegister struct {
|
||||
ws *websocketServiceType
|
||||
}
|
||||
|
||||
// RegisterWebsocket 注册 WebSocket 服务
|
||||
func RegisterWebsocket(authLevel int, path string, onOpen, onClose any, memo string) *ActionRegister {
|
||||
s := &websocketServiceType{
|
||||
authLevel: authLevel,
|
||||
path: path,
|
||||
memo: memo,
|
||||
updater: &websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }},
|
||||
actions: make(map[string]*websocketActionType),
|
||||
}
|
||||
|
||||
if onOpen != nil {
|
||||
s.openFuncValue = reflect.ValueOf(onOpen)
|
||||
s.openFuncType = s.openFuncValue.Type()
|
||||
if s.openFuncType.NumOut() > 0 {
|
||||
s.sessionType = s.openFuncType.Out(0)
|
||||
}
|
||||
}
|
||||
|
||||
if onClose != nil {
|
||||
s.closeFuncValue = reflect.ValueOf(onClose)
|
||||
s.closeFuncType = s.closeFuncValue.Type()
|
||||
}
|
||||
|
||||
websocketServicesLock.Lock()
|
||||
websocketServices[path] = s
|
||||
websocketServicesLock.Unlock()
|
||||
|
||||
return &ActionRegister{ws: s}
|
||||
}
|
||||
|
||||
// RegisterAction 注册 WebSocket Action
|
||||
func (ar *ActionRegister) RegisterAction(authLevel int, name string, action any, memo string) {
|
||||
v := reflect.ValueOf(action)
|
||||
t := v.Type()
|
||||
a := &websocketActionType{
|
||||
authLevel: authLevel,
|
||||
funcValue: v,
|
||||
funcType: t,
|
||||
memo: memo,
|
||||
}
|
||||
|
||||
// 查找输入参数类型
|
||||
for i := 0; i < t.NumIn(); i++ {
|
||||
inT := t.In(i)
|
||||
if inT.Kind() == reflect.Struct {
|
||||
a.inType = inT
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
ar.ws.actions[name] = a
|
||||
}
|
||||
|
||||
func doWebsocketService(ws *websocketServiceType, request *Request, response *Response, logger *log.Logger) {
|
||||
conn, err := ws.updater.Upgrade(response.Writer, request.Request, nil)
|
||||
if err != nil {
|
||||
logger.Error("websocket upgrade failed", "error", err.Error())
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
var session any
|
||||
if ws.openFuncValue.IsValid() {
|
||||
// 简化版:仅支持基础参数注入
|
||||
params := make([]reflect.Value, ws.openFuncType.NumIn())
|
||||
for i := 0; i < len(params); i++ {
|
||||
t := ws.openFuncType.In(i)
|
||||
if t == reflect.TypeOf(request) {
|
||||
params[i] = reflect.ValueOf(request)
|
||||
} else if t == reflect.TypeOf(logger) {
|
||||
params[i] = reflect.ValueOf(logger)
|
||||
} else {
|
||||
params[i] = reflect.New(t).Elem()
|
||||
}
|
||||
}
|
||||
outs := ws.openFuncValue.Call(params)
|
||||
if len(outs) > 0 {
|
||||
session = outs[0].Interface()
|
||||
}
|
||||
}
|
||||
|
||||
for {
|
||||
var msg Map
|
||||
if err := conn.ReadJSON(&msg); err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
actionName := cast.String(msg["action"])
|
||||
action := ws.actions[actionName]
|
||||
if action == nil {
|
||||
action = ws.actions[""] // 默认 action
|
||||
}
|
||||
|
||||
if action != nil {
|
||||
params := make([]reflect.Value, action.funcType.NumIn())
|
||||
for i := 0; i < len(params); i++ {
|
||||
t := action.funcType.In(i)
|
||||
if t == ws.sessionType {
|
||||
params[i] = reflect.ValueOf(session)
|
||||
} else if t == reflect.TypeOf(conn) {
|
||||
params[i] = reflect.ValueOf(conn)
|
||||
} else if t.Kind() == reflect.Struct {
|
||||
in := reflect.New(t).Interface()
|
||||
cast.Convert(in, msg)
|
||||
params[i] = reflect.ValueOf(in).Elem()
|
||||
} else {
|
||||
params[i] = reflect.New(t).Elem()
|
||||
}
|
||||
}
|
||||
outs := action.funcValue.Call(params)
|
||||
if len(outs) > 0 {
|
||||
result := outs[0].Interface()
|
||||
if result != nil {
|
||||
_ = conn.WriteJSON(result)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if ws.closeFuncValue.IsValid() {
|
||||
params := make([]reflect.Value, ws.closeFuncType.NumIn())
|
||||
for i := 0; i < len(params); i++ {
|
||||
t := ws.closeFuncType.In(i)
|
||||
if t == ws.sessionType {
|
||||
params[i] = reflect.ValueOf(session)
|
||||
} else {
|
||||
params[i] = reflect.New(t).Elem()
|
||||
}
|
||||
}
|
||||
ws.closeFuncValue.Call(params)
|
||||
}
|
||||
}
|
||||
44
websocket_test.go
Normal file
44
websocket_test.go
Normal file
@ -0,0 +1,44 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"github.com/gorilla/websocket"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestWebSocketService(t *testing.T) {
|
||||
// 注册 WebSocket 服务
|
||||
ar := RegisterWebsocket(0, "/ws", nil, nil, "test websocket")
|
||||
ar.RegisterAction(0, "echo", func(in struct{ Msg string }) Map {
|
||||
return Map{"action": "echo", "reply": in.Msg}
|
||||
}, "echo action")
|
||||
|
||||
// 启动测试服务器
|
||||
server := httptest.NewServer(&routeHandler{})
|
||||
defer server.Close()
|
||||
|
||||
// 建立连接
|
||||
wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/ws"
|
||||
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Dial failed: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// 发送消息
|
||||
msg := Map{"action": "echo", "msg": "hello"}
|
||||
if err := conn.WriteJSON(msg); err != nil {
|
||||
t.Fatalf("WriteJSON failed: %v", err)
|
||||
}
|
||||
|
||||
// 接收响应
|
||||
var reply Map
|
||||
if err := conn.ReadJSON(&reply); err != nil {
|
||||
t.Fatalf("ReadJSON failed: %v", err)
|
||||
}
|
||||
|
||||
if reply["reply"] != "hello" {
|
||||
t.Errorf("Reply mismatch: %v", reply)
|
||||
}
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user