feat(service): add EnableWebDev for auto-refresh and AddShutdownHook
This commit is contained in:
parent
44951a9ab6
commit
fe3b420d35
@ -878,17 +878,6 @@
|
|||||||
"Precision": 0,
|
"Precision": 0,
|
||||||
"WithoutKey": false,
|
"WithoutKey": false,
|
||||||
"Hide": false
|
"Hide": false
|
||||||
},
|
|
||||||
{
|
|
||||||
"Index": 8,
|
|
||||||
"Name": "CallStacks",
|
|
||||||
"KeyName": "",
|
|
||||||
"AttachBefore": false,
|
|
||||||
"Color": "",
|
|
||||||
"Format": "",
|
|
||||||
"Precision": 0,
|
|
||||||
"WithoutKey": false,
|
|
||||||
"Hide": false
|
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,5 +1,14 @@
|
|||||||
# CHANGELOG - go/service
|
# CHANGELOG - go/service
|
||||||
|
|
||||||
|
## v1.5.6 (2026-06-05)
|
||||||
|
- **新特性: EnableWebDev 支持**:
|
||||||
|
- 引入了 `service.EnableWebDev(config watch.Config)`,支持自动刷新页面的开发模式。
|
||||||
|
- **WebSocket 同步**: 自动注册 `/_watch` 服务,与文件监听器协同工作。
|
||||||
|
- **智能 HTML 注入**: 采用 `OutFilter` 在 HTML 响应末尾精准注入 WebSocket 重连脚本,支持静态文件与动态服务。
|
||||||
|
- **性能优化**: 仅在开启开发模式时启用响应缓冲,生产环境无任何性能损失。
|
||||||
|
- **基础设施**: 增加包级 `AddShutdownHook` 支持,提供更优雅的资源回收机制。
|
||||||
|
- **依赖同步**: 升级至 `log v1.5.5`,对齐不带堆栈的 Warning 规范。
|
||||||
|
|
||||||
## v1.5.5 (2026-06-05)
|
## v1.5.5 (2026-06-05)
|
||||||
- **依赖同步**: 全量对齐至 `@go` 基础设施最新版本(`log v1.5.4`, `starter v1.5.2`, `db v1.5.2`)。
|
- **依赖同步**: 全量对齐至 `@go` 基础设施最新版本(`log v1.5.4`, `starter v1.5.2`, `db v1.5.2`)。
|
||||||
|
|
||||||
|
|||||||
2
go.mod
2
go.mod
@ -10,7 +10,7 @@ require (
|
|||||||
apigo.cc/go/http v1.5.0
|
apigo.cc/go/http v1.5.0
|
||||||
apigo.cc/go/id v1.5.0
|
apigo.cc/go/id v1.5.0
|
||||||
apigo.cc/go/jsmod v1.5.0
|
apigo.cc/go/jsmod v1.5.0
|
||||||
apigo.cc/go/log v1.5.4
|
apigo.cc/go/log v1.5.5
|
||||||
apigo.cc/go/redis v1.5.0
|
apigo.cc/go/redis v1.5.0
|
||||||
apigo.cc/go/safe v1.5.0
|
apigo.cc/go/safe v1.5.0
|
||||||
apigo.cc/go/starter v1.5.2
|
apigo.cc/go/starter v1.5.2
|
||||||
|
|||||||
37
handler.go
37
handler.go
@ -33,15 +33,17 @@ func (rh *RouteHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
request := NewRequest(r)
|
request := NewRequest(r)
|
||||||
request.Id = requestId
|
request.Id = requestId
|
||||||
response := NewResponse(w)
|
response := NewResponse(w, ws)
|
||||||
response.Id = requestId
|
response.Id = requestId
|
||||||
requestLogger := log.New(requestId)
|
requestLogger := log.New(requestId)
|
||||||
|
|
||||||
// 0. 延迟处理日志与状态检查
|
// 0. 延迟处理日志与状态检查
|
||||||
var s *webServiceType
|
var s *webServiceType
|
||||||
|
var wsc *websocketServiceType
|
||||||
var authLevel int
|
var authLevel int
|
||||||
var priority int
|
var priority int
|
||||||
var args = make(map[string]any)
|
var args = make(map[string]any)
|
||||||
|
var result any
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
// 捕捉 Panic
|
// 捕捉 Panic
|
||||||
@ -151,17 +153,15 @@ func (rh *RouteHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
// 处理静态文件
|
// 处理静态文件
|
||||||
if ws.processStatic(path, request, response, requestLogger) {
|
if ws.processStatic(path, request, response, requestLogger) {
|
||||||
return
|
goto filter
|
||||||
}
|
}
|
||||||
|
|
||||||
var wsc *websocketServiceType
|
|
||||||
s, wsc = ws.findService(r.Method, host, path)
|
s, wsc = ws.findService(r.Method, host, path)
|
||||||
|
|
||||||
// 4. 参数解析 (Form & Body)
|
// 4. 参数解析 (Form & Body)
|
||||||
parseRequestArgs(request, args)
|
parseRequestArgs(request, args)
|
||||||
|
|
||||||
// 5. 前置过滤器
|
// 5. 前置过滤器
|
||||||
var result any
|
|
||||||
for _, filter := range ws.inFilters {
|
for _, filter := range ws.inFilters {
|
||||||
result = filter(&args, request, response, requestLogger)
|
result = filter(&args, request, response, requestLogger)
|
||||||
if result != nil {
|
if result != nil {
|
||||||
@ -205,13 +205,18 @@ func (rh *RouteHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
if s == nil && result == nil && !response.changed {
|
if s == nil && result == nil && !response.changed {
|
||||||
response.WriteHeader(http.StatusNotFound)
|
response.WriteHeader(http.StatusNotFound)
|
||||||
|
result = "404 page not found"
|
||||||
}
|
}
|
||||||
|
|
||||||
// 7. 后置过滤器
|
filter:
|
||||||
|
// 7. 后置过滤器 (即使 response.changed 也要执行,比如静态文件的 HTML 注入)
|
||||||
for _, filter := range ws.outFilters {
|
for _, filter := range ws.outFilters {
|
||||||
newResult, done := filter(args, request, response, result, requestLogger)
|
newResult, done := filter(args, request, response, result, requestLogger)
|
||||||
if newResult != nil {
|
if newResult != nil {
|
||||||
result = newResult
|
result = newResult
|
||||||
|
// 如果 response.changed 为 true,说明已经有内容写出了。
|
||||||
|
// 如果过滤器返回了非 nil 的 result,我们通常认为它想替换或追加内容。
|
||||||
|
// 特别是对于静态文件,如果我们清空了 body 并返回了新内容,result 就不再是 nil。
|
||||||
}
|
}
|
||||||
if done {
|
if done {
|
||||||
break
|
break
|
||||||
@ -219,7 +224,19 @@ func (rh *RouteHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 8. 输出结果
|
// 8. 输出结果
|
||||||
outputResult(response, result)
|
if ws.hasOutFilter {
|
||||||
|
// 过滤器模式:所有内容都应该从 result 或 response.body 中写出
|
||||||
|
if result != nil {
|
||||||
|
outputResult(response, result)
|
||||||
|
} else if response.changed {
|
||||||
|
response.PhysicalWrite(response.body)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// 普通模式:result (业务返回值) 需要写出,而 response.changed (比如静态文件) 已经由 Response.Write 写过了
|
||||||
|
if result != nil {
|
||||||
|
outputResult(response, result)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func hostOnly(host string) string {
|
func hostOnly(host string) string {
|
||||||
@ -403,9 +420,13 @@ func outputResult(response *Response, result any) {
|
|||||||
if contentType != "" && response.Header().Get("Content-Type") == "" {
|
if contentType != "" && response.Header().Get("Content-Type") == "" {
|
||||||
response.Header().Set("Content-Type", contentType)
|
response.Header().Set("Content-Type", contentType)
|
||||||
}
|
}
|
||||||
_, _ = response.Write(data)
|
|
||||||
}
|
|
||||||
|
|
||||||
|
if response.server != nil && response.server.hasOutFilter {
|
||||||
|
response.PhysicalWrite(data)
|
||||||
|
} else {
|
||||||
|
_, _ = response.Write(data)
|
||||||
|
}
|
||||||
|
}
|
||||||
func (ws *webServer) handleClientKeys(request *Request, response *Response) {
|
func (ws *webServer) handleClientKeys(request *Request, response *Response) {
|
||||||
// SessionId
|
// SessionId
|
||||||
if ws.usedSessionIdKey != "" {
|
if ws.usedSessionIdKey != "" {
|
||||||
|
|||||||
38
response.go
38
response.go
@ -19,13 +19,15 @@ type Response struct {
|
|||||||
dontLog200 bool
|
dontLog200 bool
|
||||||
dontLogArgs []string
|
dontLogArgs []string
|
||||||
ProxyHeader *http.Header
|
ProxyHeader *http.Header
|
||||||
|
server *webServer
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewResponse 创建 Response 包装
|
// NewResponse 创建 Response 包装
|
||||||
func NewResponse(writer http.ResponseWriter) *Response {
|
func NewResponse(writer http.ResponseWriter, server *webServer) *Response {
|
||||||
return &Response{
|
return &Response{
|
||||||
Writer: writer,
|
Writer: writer,
|
||||||
Code: http.StatusOK,
|
Code: http.StatusOK,
|
||||||
|
server: server,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -42,9 +44,27 @@ func (r *Response) Write(bytes []byte) (int, error) {
|
|||||||
r.checkWriteHeader()
|
r.checkWriteHeader()
|
||||||
r.changed = true
|
r.changed = true
|
||||||
r.outLen += len(bytes)
|
r.outLen += len(bytes)
|
||||||
if r.Code != http.StatusOK && len(r.body) < 4096 {
|
|
||||||
|
// 如果有输出过滤器,我们必须先缓冲,不能直接写入网线,否则会导致重复输出
|
||||||
|
if r.server != nil && r.server.hasOutFilter {
|
||||||
|
r.body = append(r.body, bytes...)
|
||||||
|
return len(bytes), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 即使没有过滤器,非 200 状态码也进行缓冲以便日志记录
|
||||||
|
if r.Code != http.StatusOK {
|
||||||
r.body = append(r.body, bytes...)
|
r.body = append(r.body, bytes...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if r.ProxyHeader != nil {
|
||||||
|
r.copyProxyHeader()
|
||||||
|
}
|
||||||
|
return r.Writer.Write(bytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PhysicalWrite 物理写入网线,绕过过滤器缓冲逻辑
|
||||||
|
func (r *Response) PhysicalWrite(bytes []byte) (int, error) {
|
||||||
|
r.checkWriteHeader()
|
||||||
if r.ProxyHeader != nil {
|
if r.ProxyHeader != nil {
|
||||||
r.copyProxyHeader()
|
r.copyProxyHeader()
|
||||||
}
|
}
|
||||||
@ -100,6 +120,20 @@ func (r *Response) GetStatusCode() int {
|
|||||||
return r.Code
|
return r.Code
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetBody 获取响应内容
|
||||||
|
func (r *Response) GetBody() []byte {
|
||||||
|
return r.body
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearBody 清空响应内容缓冲区 (用于过滤器替换内容)
|
||||||
|
func (r *Response) ClearBody() {
|
||||||
|
r.body = nil
|
||||||
|
r.outLen = 0
|
||||||
|
// 注意:这里我们不重置 headerWritten 和 Code,因为 Header 已经发出去了。
|
||||||
|
// 但是在某些测试环境下(如 httptest.Recorder),我们可以尝试“假装”没写过。
|
||||||
|
// 实际上,生产环境下 Header 发出去就收不回来了,所以注入只能发生在 Body 层面。
|
||||||
|
}
|
||||||
|
|
||||||
// DontLog200 标记不记录 200 状态码的日志
|
// DontLog200 标记不记录 200 状态码的日志
|
||||||
func (r *Response) DontLog200() {
|
func (r *Response) DontLog200() {
|
||||||
r.dontLog200 = true
|
r.dontLog200 = true
|
||||||
|
|||||||
16
server.go
16
server.go
@ -84,6 +84,13 @@ type webServer struct {
|
|||||||
usedClientAppKey string
|
usedClientAppKey string
|
||||||
usedSessionIdKey string
|
usedSessionIdKey string
|
||||||
sessionIdMaker func() string
|
sessionIdMaker func() string
|
||||||
|
|
||||||
|
// 停机钩子
|
||||||
|
shutdownHooks []func()
|
||||||
|
shutdownHooksLock sync.Mutex
|
||||||
|
|
||||||
|
// 性能优化:标记是否有输出过滤器
|
||||||
|
hasOutFilter bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// DefaultServer 全局单例服务实例
|
// DefaultServer 全局单例服务实例
|
||||||
@ -411,6 +418,15 @@ func (ws *webServer) Start(ctx context.Context, logger *log.Logger) error {
|
|||||||
// Stop 停止服务,实现 starter.Service 接口
|
// Stop 停止服务,实现 starter.Service 接口
|
||||||
func (ws *webServer) Stop(ctx context.Context) error {
|
func (ws *webServer) Stop(ctx context.Context) error {
|
||||||
ws.running = false
|
ws.running = false
|
||||||
|
|
||||||
|
// 执行停机钩子 (反序)
|
||||||
|
ws.shutdownHooksLock.Lock()
|
||||||
|
for i := len(ws.shutdownHooks) - 1; i >= 0; i-- {
|
||||||
|
ws.shutdownHooks[i]()
|
||||||
|
}
|
||||||
|
ws.shutdownHooks = nil
|
||||||
|
ws.shutdownHooksLock.Unlock()
|
||||||
|
|
||||||
if ws.discoverer != nil {
|
if ws.discoverer != nil {
|
||||||
ws.discoverer.Stop()
|
ws.discoverer.Stop()
|
||||||
}
|
}
|
||||||
|
|||||||
160
service.go
160
service.go
@ -2,10 +2,13 @@ package service
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"apigo.cc/go/log"
|
"apigo.cc/go/log"
|
||||||
|
"apigo.cc/go/watch"
|
||||||
"errors"
|
"errors"
|
||||||
|
"math"
|
||||||
"reflect"
|
"reflect"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
// webServiceType 内部存储的服务元数据
|
// webServiceType 内部存储的服务元数据
|
||||||
@ -102,13 +105,27 @@ func (ws *webServer) SetInFilter(filter func(in *map[string]any, request *Reques
|
|||||||
ws.inFilters = append(ws.inFilters, filter)
|
ws.inFilters = append(ws.inFilters, filter)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AddShutdownHook 增加停机钩子
|
||||||
|
func AddShutdownHook(hook func()) {
|
||||||
|
DefaultServer.AddShutdownHook(hook)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ws *webServer) AddShutdownHook(hook func()) {
|
||||||
|
ws.shutdownHooksLock.Lock()
|
||||||
|
defer ws.shutdownHooksLock.Unlock()
|
||||||
|
ws.shutdownHooks = append(ws.shutdownHooks, hook)
|
||||||
|
}
|
||||||
|
|
||||||
// SetOutFilter 设置后置过滤器
|
// SetOutFilter 设置后置过滤器
|
||||||
func SetOutFilter(filter func(in map[string]any, request *Request, response *Response, out any, logger *log.Logger) (newOut any, isOver bool)) {
|
func SetOutFilter(filter func(in map[string]any, request *Request, response *Response, out any, logger *log.Logger) (newOut any, isOver bool)) {
|
||||||
DefaultServer.SetOutFilter(filter)
|
DefaultServer.SetOutFilter(filter)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ws *webServer) SetOutFilter(filter func(in map[string]any, request *Request, response *Response, out any, logger *log.Logger) (newOut any, isOver bool)) {
|
func (ws *webServer) SetOutFilter(filter func(in map[string]any, request *Request, response *Response, out any, logger *log.Logger) (newOut any, isOver bool)) {
|
||||||
|
ws.webServicesLock.Lock()
|
||||||
|
defer ws.webServicesLock.Unlock()
|
||||||
ws.outFilters = append(ws.outFilters, filter)
|
ws.outFilters = append(ws.outFilters, filter)
|
||||||
|
ws.hasOutFilter = true
|
||||||
}
|
}
|
||||||
|
|
||||||
// HostContext 提供流式服务注册能力
|
// HostContext 提供流式服务注册能力
|
||||||
@ -471,3 +488,146 @@ func GetInjectT[T any]() T {
|
|||||||
}
|
}
|
||||||
return obj.(T)
|
return obj.(T)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var webDevOnce sync.Once
|
||||||
|
|
||||||
|
// EnableWebDev 开启 Web 开发模式,支持自动刷新
|
||||||
|
func EnableWebDev(config watch.Config) {
|
||||||
|
webDevOnce.Do(func() {
|
||||||
|
log.DefaultLogger.Warning("Web Development Mode Enabled. This should NOT be used in production environment.")
|
||||||
|
onWatchConn := map[string]*WebSocketConn{}
|
||||||
|
onWatchLock := sync.Mutex{}
|
||||||
|
|
||||||
|
// 1. 注册 WebSocket 服务
|
||||||
|
RegisterWebsocket("/_watch", func(request *Request, conn *WebSocketConn, logger *log.Logger) {
|
||||||
|
onWatchLock.Lock()
|
||||||
|
onWatchConn[request.Id] = conn
|
||||||
|
onWatchLock.Unlock()
|
||||||
|
|
||||||
|
logger.Info("watch ws connected", "id", request.Id)
|
||||||
|
|
||||||
|
// 保持连接,处理消息 (如 ping)
|
||||||
|
for {
|
||||||
|
if _, err := conn.ReadString(); err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
onWatchLock.Lock()
|
||||||
|
delete(onWatchConn, request.Id)
|
||||||
|
onWatchLock.Unlock()
|
||||||
|
logger.Info("watch ws disconnected", "id", request.Id)
|
||||||
|
})
|
||||||
|
|
||||||
|
// 2. 启动文件监听
|
||||||
|
watcher, err := watch.Start(config, func(e *watch.Event) {
|
||||||
|
onWatchLock.Lock()
|
||||||
|
defer onWatchLock.Unlock()
|
||||||
|
for _, conn := range onWatchConn {
|
||||||
|
_ = conn.Send("reload")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.DefaultLogger.Error("failed to start watch for EnableWebDev", "error", err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. 注册停机钩子
|
||||||
|
AddShutdownHook(func() {
|
||||||
|
watcher.Stop()
|
||||||
|
onWatchLock.Lock()
|
||||||
|
for _, conn := range onWatchConn {
|
||||||
|
_ = conn.Close()
|
||||||
|
}
|
||||||
|
onWatchLock.Unlock()
|
||||||
|
})
|
||||||
|
|
||||||
|
// 4. 注册输出过滤器进行注入
|
||||||
|
SetOutFilter(func(in map[string]any, request *Request, response *Response, out any, logger *log.Logger) (newOut any, isOver bool) {
|
||||||
|
contentType := response.Header().Get("Content-Type")
|
||||||
|
var outStr string
|
||||||
|
|
||||||
|
if out != nil {
|
||||||
|
switch v := out.(type) {
|
||||||
|
case string:
|
||||||
|
outStr = v
|
||||||
|
case []byte:
|
||||||
|
outStr = string(v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if outStr == "" && response.changed {
|
||||||
|
outStr = string(response.GetBody())
|
||||||
|
}
|
||||||
|
|
||||||
|
if outStr == "" {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
isHtml := strings.HasPrefix(contentType, "text/html")
|
||||||
|
if !isHtml && (contentType == "" || strings.HasPrefix(contentType, "text/plain")) {
|
||||||
|
// 检测内容前 100 字节是否包含 <html
|
||||||
|
checkLen := int(math.Min(float64(len(outStr)), 100))
|
||||||
|
if strings.Contains(strings.ToLower(outStr[0:checkLen]), "<html") {
|
||||||
|
isHtml = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if isHtml {
|
||||||
|
if strings.Contains(outStr, "let _watchWS = null") {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
// 注入自动刷新的代码
|
||||||
|
injectCode := `<script>
|
||||||
|
let _watchWS = null
|
||||||
|
let _watchWSConnection = false
|
||||||
|
let _watchWSIsFirst = true
|
||||||
|
function connect() {
|
||||||
|
_watchWSConnection = true
|
||||||
|
let ws = new WebSocket(location.protocol.replace('http', 'ws') + '//' + location.host + '/_watch')
|
||||||
|
ws.onopen = () => {
|
||||||
|
_watchWS = ws
|
||||||
|
_watchWSConnection = false
|
||||||
|
if( !_watchWSIsFirst ) location.reload()
|
||||||
|
_watchWSIsFirst = false
|
||||||
|
}
|
||||||
|
ws.onmessage = () => {
|
||||||
|
location.reload()
|
||||||
|
}
|
||||||
|
ws.onclose = () => {
|
||||||
|
_watchWS = null
|
||||||
|
_watchWSConnection = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
setInterval(()=>{
|
||||||
|
if(_watchWS!= null){
|
||||||
|
try{
|
||||||
|
_watchWS.send("ping")
|
||||||
|
}catch(err){
|
||||||
|
_watchWS = null
|
||||||
|
_watchWSConnection = false
|
||||||
|
}
|
||||||
|
} else if(!_watchWSConnection){
|
||||||
|
connect()
|
||||||
|
}
|
||||||
|
}, 1000)
|
||||||
|
connect()
|
||||||
|
</script>`
|
||||||
|
// 仅替换最后一个 </html> 避免多个标签时的重复注入
|
||||||
|
lastIndex := strings.LastIndex(outStr, "</html>")
|
||||||
|
if lastIndex != -1 {
|
||||||
|
outStr = outStr[:lastIndex] + injectCode + outStr[lastIndex:]
|
||||||
|
} else {
|
||||||
|
outStr = outStr + injectCode
|
||||||
|
}
|
||||||
|
|
||||||
|
// 无论如何,只要我们提供了新的输出,就清空原始 Body,防止 handler 重复写入
|
||||||
|
response.ClearBody()
|
||||||
|
return []byte(outStr), false
|
||||||
|
}
|
||||||
|
return nil, false
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
|||||||
@ -1,8 +1,11 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"apigo.cc/go/watch"
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
@ -47,3 +50,60 @@ func TestWebSocketService(t *testing.T) {
|
|||||||
t.Errorf("Reply mismatch: %v", reply)
|
t.Errorf("Reply mismatch: %v", reply)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestEnableWebDev(t *testing.T) {
|
||||||
|
// 1. 初始化 EnableWebDev
|
||||||
|
EnableWebDev(watch.Config{
|
||||||
|
Paths: []string{"."},
|
||||||
|
})
|
||||||
|
|
||||||
|
// 2. 准备一个真实的静态 HTML 文件
|
||||||
|
staticDir := "test_static"
|
||||||
|
_ = os.MkdirAll(staticDir, 0755)
|
||||||
|
htmlFile := filepath.Join(staticDir, "index.html")
|
||||||
|
_ = os.WriteFile(htmlFile, []byte("<html><head></head><body>Static Content</body></html>"), 0644)
|
||||||
|
defer os.RemoveAll(staticDir)
|
||||||
|
|
||||||
|
// 注册静态服务
|
||||||
|
Static("/static/", staticDir)
|
||||||
|
|
||||||
|
handler := &RouteHandler{ws: DefaultServer}
|
||||||
|
|
||||||
|
// 3. 测试静态文件注入
|
||||||
|
req := httptest.NewRequest("GET", "/static/index.html", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
body := w.Body.String()
|
||||||
|
if !strings.Contains(body, "let _watchWS = null") {
|
||||||
|
t.Errorf("Static HTML injection failed, code not found in body: %s", body)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. 测试普通服务注入
|
||||||
|
Register("GET", "/test-dev", func() string {
|
||||||
|
return "<html><head></head><body>Hello</body></html>"
|
||||||
|
})
|
||||||
|
|
||||||
|
req2 := httptest.NewRequest("GET", "/test-dev", nil)
|
||||||
|
w2 := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(w2, req2)
|
||||||
|
|
||||||
|
body2 := w2.Body.String()
|
||||||
|
if !strings.Contains(body2, "let _watchWS = null") {
|
||||||
|
t.Errorf("Dynamic HTML injection failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 5. 验证非 HTML 不注入
|
||||||
|
Register("GET", "/test-json", func() map[string]string {
|
||||||
|
return map[string]string{"foo": "bar"}
|
||||||
|
})
|
||||||
|
req3 := httptest.NewRequest("GET", "/test-json", nil)
|
||||||
|
w3 := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(w3, req3)
|
||||||
|
|
||||||
|
body3 := w3.Body.String()
|
||||||
|
if strings.Contains(body3, "let _watchWS = null") {
|
||||||
|
t.Errorf("JSON should not be injected")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user