refactor(service): defer EnableWebDev initialization to Start lifecycle

This commit is contained in:
AI Engineer 2026-06-05 11:38:44 +08:00
parent fe3b420d35
commit ff34d11c9b
3 changed files with 24 additions and 9 deletions

View File

@ -7,6 +7,7 @@ import (
"apigo.cc/go/redis" "apigo.cc/go/redis"
"apigo.cc/go/safe" "apigo.cc/go/safe"
"apigo.cc/go/starter" "apigo.cc/go/starter"
"apigo.cc/go/watch"
"context" "context"
"fmt" "fmt"
"golang.org/x/net/http2" "golang.org/x/net/http2"
@ -91,6 +92,10 @@ type webServer struct {
// 性能优化:标记是否有输出过滤器 // 性能优化:标记是否有输出过滤器
hasOutFilter bool hasOutFilter bool
// Web 开发模式配置
webDevEnabled bool
webDevConfig watch.Config
} }
// DefaultServer 全局单例服务实例 // DefaultServer 全局单例服务实例
@ -303,6 +308,10 @@ func (ws *webServer) Start(ctx context.Context, logger *log.Logger) error {
addr = ":" + addr addr = ":" + addr
} }
if ws.webDevEnabled {
ws.initWebDev(logger)
}
appName := ws.Config.App appName := ws.Config.App
if appName == "" { if appName == "" {
appName = GetDefaultName() appName = GetDefaultName()

View File

@ -493,19 +493,22 @@ var webDevOnce sync.Once
// EnableWebDev 开启 Web 开发模式,支持自动刷新 // EnableWebDev 开启 Web 开发模式,支持自动刷新
func EnableWebDev(config watch.Config) { func EnableWebDev(config watch.Config) {
DefaultServer.webDevEnabled = true
DefaultServer.webDevConfig = config
}
func (ws *webServer) initWebDev(logger *log.Logger) {
webDevOnce.Do(func() { webDevOnce.Do(func() {
log.DefaultLogger.Warning("Web Development Mode Enabled. This should NOT be used in production environment.") logger.Warning("Web Development Mode Enabled. This should NOT be used in production environment.")
onWatchConn := map[string]*WebSocketConn{} onWatchConn := map[string]*WebSocketConn{}
onWatchLock := sync.Mutex{} onWatchLock := sync.Mutex{}
// 1. 注册 WebSocket 服务 // 1. 注册 WebSocket 服务
RegisterWebsocket("/_watch", func(request *Request, conn *WebSocketConn, logger *log.Logger) { ws.RegisterWebsocket("/_watch", func(request *Request, conn *WebSocketConn, logger *log.Logger) {
onWatchLock.Lock() onWatchLock.Lock()
onWatchConn[request.Id] = conn onWatchConn[request.Id] = conn
onWatchLock.Unlock() onWatchLock.Unlock()
logger.Info("watch ws connected", "id", request.Id)
// 保持连接,处理消息 (如 ping) // 保持连接,处理消息 (如 ping)
for { for {
if _, err := conn.ReadString(); err != nil { if _, err := conn.ReadString(); err != nil {
@ -516,11 +519,10 @@ func EnableWebDev(config watch.Config) {
onWatchLock.Lock() onWatchLock.Lock()
delete(onWatchConn, request.Id) delete(onWatchConn, request.Id)
onWatchLock.Unlock() onWatchLock.Unlock()
logger.Info("watch ws disconnected", "id", request.Id)
}) })
// 2. 启动文件监听 // 2. 启动文件监听
watcher, err := watch.Start(config, func(e *watch.Event) { watcher, err := watch.Start(ws.webDevConfig, func(e *watch.Event) {
onWatchLock.Lock() onWatchLock.Lock()
defer onWatchLock.Unlock() defer onWatchLock.Unlock()
for _, conn := range onWatchConn { for _, conn := range onWatchConn {
@ -529,12 +531,12 @@ func EnableWebDev(config watch.Config) {
}) })
if err != nil { if err != nil {
log.DefaultLogger.Error("failed to start watch for EnableWebDev", "error", err.Error()) logger.Error("failed to start watch for EnableWebDev", "error", err.Error())
return return
} }
// 3. 注册停机钩子 // 3. 注册停机钩子
AddShutdownHook(func() { ws.AddShutdownHook(func() {
watcher.Stop() watcher.Stop()
onWatchLock.Lock() onWatchLock.Lock()
for _, conn := range onWatchConn { for _, conn := range onWatchConn {
@ -544,7 +546,7 @@ func EnableWebDev(config watch.Config) {
}) })
// 4. 注册输出过滤器进行注入 // 4. 注册输出过滤器进行注入
SetOutFilter(func(in map[string]any, request *Request, response *Response, out any, logger *log.Logger) (newOut any, isOver bool) { ws.SetOutFilter(func(in map[string]any, request *Request, response *Response, out any, logger *log.Logger) (newOut any, isOver bool) {
contentType := response.Header().Get("Content-Type") contentType := response.Header().Get("Content-Type")
var outStr string var outStr string

View File

@ -1,6 +1,7 @@
package service package service
import ( import (
"apigo.cc/go/log"
"apigo.cc/go/watch" "apigo.cc/go/watch"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"net/http/httptest" "net/http/httptest"
@ -57,6 +58,9 @@ func TestEnableWebDev(t *testing.T) {
Paths: []string{"."}, Paths: []string{"."},
}) })
// 必须手动调用 initWebDev 或触发 Start因为现在的逻辑是延迟初始化的
DefaultServer.initWebDev(log.DefaultLogger)
// 2. 准备一个真实的静态 HTML 文件 // 2. 准备一个真实的静态 HTML 文件
staticDir := "test_static" staticDir := "test_static"
_ = os.MkdirAll(staticDir, 0755) _ = os.MkdirAll(staticDir, 0755)