diff --git a/server.go b/server.go index 909ef12..2b1e73b 100644 --- a/server.go +++ b/server.go @@ -7,6 +7,7 @@ import ( "apigo.cc/go/redis" "apigo.cc/go/safe" "apigo.cc/go/starter" + "apigo.cc/go/watch" "context" "fmt" "golang.org/x/net/http2" @@ -91,6 +92,10 @@ type webServer struct { // 性能优化:标记是否有输出过滤器 hasOutFilter bool + + // Web 开发模式配置 + webDevEnabled bool + webDevConfig watch.Config } // DefaultServer 全局单例服务实例 @@ -303,6 +308,10 @@ func (ws *webServer) Start(ctx context.Context, logger *log.Logger) error { addr = ":" + addr } + if ws.webDevEnabled { + ws.initWebDev(logger) + } + appName := ws.Config.App if appName == "" { appName = GetDefaultName() diff --git a/service.go b/service.go index e5e879c..aa33655 100644 --- a/service.go +++ b/service.go @@ -493,19 +493,22 @@ var webDevOnce sync.Once // EnableWebDev 开启 Web 开发模式,支持自动刷新 func EnableWebDev(config watch.Config) { + DefaultServer.webDevEnabled = true + DefaultServer.webDevConfig = config +} + +func (ws *webServer) initWebDev(logger *log.Logger) { webDevOnce.Do(func() { - log.DefaultLogger.Warning("Web Development Mode Enabled. This should NOT be used in production environment.") + logger.Warning("Web Development Mode Enabled. This should NOT be used in production environment.") onWatchConn := map[string]*WebSocketConn{} onWatchLock := sync.Mutex{} // 1. 注册 WebSocket 服务 - RegisterWebsocket("/_watch", func(request *Request, conn *WebSocketConn, logger *log.Logger) { + ws.RegisterWebsocket("/_watch", func(request *Request, conn *WebSocketConn, logger *log.Logger) { onWatchLock.Lock() onWatchConn[request.Id] = conn onWatchLock.Unlock() - logger.Info("watch ws connected", "id", request.Id) - // 保持连接,处理消息 (如 ping) for { if _, err := conn.ReadString(); err != nil { @@ -516,11 +519,10 @@ func EnableWebDev(config watch.Config) { onWatchLock.Lock() delete(onWatchConn, request.Id) onWatchLock.Unlock() - logger.Info("watch ws disconnected", "id", request.Id) }) // 2. 启动文件监听 - watcher, err := watch.Start(config, func(e *watch.Event) { + watcher, err := watch.Start(ws.webDevConfig, func(e *watch.Event) { onWatchLock.Lock() defer onWatchLock.Unlock() for _, conn := range onWatchConn { @@ -529,12 +531,12 @@ func EnableWebDev(config watch.Config) { }) if err != nil { - log.DefaultLogger.Error("failed to start watch for EnableWebDev", "error", err.Error()) + logger.Error("failed to start watch for EnableWebDev", "error", err.Error()) return } // 3. 注册停机钩子 - AddShutdownHook(func() { + ws.AddShutdownHook(func() { watcher.Stop() onWatchLock.Lock() for _, conn := range onWatchConn { @@ -544,7 +546,7 @@ func EnableWebDev(config watch.Config) { }) // 4. 注册输出过滤器进行注入 - SetOutFilter(func(in map[string]any, request *Request, response *Response, out any, logger *log.Logger) (newOut any, isOver bool) { + ws.SetOutFilter(func(in map[string]any, request *Request, response *Response, out any, logger *log.Logger) (newOut any, isOver bool) { contentType := response.Header().Get("Content-Type") var outStr string diff --git a/websocket_test.go b/websocket_test.go index 02da9d4..7163647 100644 --- a/websocket_test.go +++ b/websocket_test.go @@ -1,6 +1,7 @@ package service import ( + "apigo.cc/go/log" "apigo.cc/go/watch" "github.com/gorilla/websocket" "net/http/httptest" @@ -57,6 +58,9 @@ func TestEnableWebDev(t *testing.T) { Paths: []string{"."}, }) + // 必须手动调用 initWebDev 或触发 Start,因为现在的逻辑是延迟初始化的 + DefaultServer.initWebDev(log.DefaultLogger) + // 2. 准备一个真实的静态 HTML 文件 staticDir := "test_static" _ = os.MkdirAll(staticDir, 0755)