diff --git a/AppClient.go b/AppClient.go index 0ef0342..19f5bec 100644 --- a/AppClient.go +++ b/AppClient.go @@ -8,14 +8,15 @@ import ( // AppClient 用于管理单个请求的重试和负载均衡状态 type AppClient struct { - excludes map[string]bool // 本次请求已排除的节点 - attempts int // 本次请求的重试次数 - Logger *log.Logger // 用于日志记录的 Logger - App string // 目标应用名称 - Method string // 请求方法 - Path string // 请求路径 - Data map[string]any // 请求数据 - Headers map[string]string // 请求头 + discoverer *Discoverer + excludes map[string]bool // 本次请求已排除的节点 + attempts int // 本次请求的重试次数 + Logger *log.Logger // 用于日志记录的 Logger + App string // 目标应用名称 + Method string // 请求方法 + Path string // 请求路径 + Data map[string]any // 请求数据 + Headers map[string]string // 请求头 } // logError 记录 Discover 客户端错误 @@ -33,10 +34,10 @@ func (ac *AppClient) Next(app string, request *http.Request) *NodeInfo { // CheckApp 检查并尝试添加应用 func (ac *AppClient) CheckApp(app string) bool { - nodes := getAppNodes(app) + nodes := ac.discoverer.GetAppNodes(app) if nodes == nil { - if !addApp(app, "", true) { - ac.logError("app not found", "app", app, "calls", Config.Calls) + if !ac.discoverer.AddExternalApp(app, "") { + ac.logError("app not found", "app", app, "calls", ac.discoverer.Config.Calls) return false } } @@ -49,7 +50,7 @@ func (ac *AppClient) NextWithNode(app, withNode string, request *http.Request) * ac.excludes = make(map[string]bool) } - allNodes := getAppNodes(app) + allNodes := ac.discoverer.GetAppNodes(app) if len(allNodes) == 0 { ac.logError("node not found", "app", app) return nil @@ -63,7 +64,7 @@ func (ac *AppClient) NextWithNode(app, withNode string, request *http.Request) * readyNodes := make([]*NodeInfo, 0, len(allNodes)) for _, node := range allNodes { - if ac.excludes[node.Addr] || node.FailedTimes.Load() >= int32(Config.CallRetryTimes) { + if ac.excludes[node.Addr] || node.FailedTimes.Load() >= int32(ac.discoverer.Config.CallRetryTimes) { continue } readyNodes = append(readyNodes, node) @@ -80,7 +81,7 @@ func (ac *AppClient) NextWithNode(app, withNode string, request *http.Request) * var node *NodeInfo if len(readyNodes) > 0 { - node = settedLoadBalancer.Next(ac, readyNodes, request) + node = ac.discoverer.settedLoadBalancer.Next(ac, readyNodes, request) if node != nil { ac.excludes[node.Addr] = true } diff --git a/Caller.go b/Caller.go index 680e133..0d75591 100644 --- a/Caller.go +++ b/Caller.go @@ -43,14 +43,20 @@ func getHttpClient(app string, timeout time.Duration, h2c bool) *gohttp.Client { // Caller 用于发起服务间调用 type Caller struct { - Request *http.Request // 原始请求,用于透传 Header - NoBody bool // 是否不发送请求体 - logger *log.Logger // 用于日志记录的 Logger + discoverer *Discoverer + Request *http.Request // 原始请求,用于透传 Header + NoBody bool // 是否不发送请求体 + logger *log.Logger // 用于日志记录的 Logger } // NewCaller 创建一个新的调用器 func NewCaller(request *http.Request, logger *log.Logger) *Caller { - return &Caller{Request: request, logger: logger} + return DefaultDiscoverer.NewCaller(request, logger) +} + +// NewCaller 创建一个新的调用器实例 +func (d *Discoverer) NewCaller(request *http.Request, logger *log.Logger) *Caller { + return &Caller{discoverer: d, Request: request, logger: logger} } // logError 记录 Discover 调用器错误 @@ -125,9 +131,9 @@ func (c *Caller) doWithNode(manual bool, method, app, withNode, path string, dat callerHeaders[headers[i-1]] = headers[i] } - if isServer { - callerHeaders[HeaderFromApp] = Config.App - callerHeaders[HeaderFromNode] = myAddr + if c.discoverer.isServer { + callerHeaders[HeaderFromApp] = c.discoverer.Config.App + callerHeaders[HeaderFromNode] = c.discoverer.myAddr } callData := make(map[string]any) @@ -139,16 +145,17 @@ func (c *Caller) doWithNode(manual bool, method, app, withNode, path string, dat } appClient := AppClient{ - Logger: c.logger, - App: app, - Method: method, - Path: path, - Data: callData, - Headers: callerHeaders, + discoverer: c.discoverer, + Logger: c.logger, + App: app, + Method: method, + Path: path, + Data: callData, + Headers: callerHeaders, } - if settedRoute != nil { - settedRoute(&appClient, c.Request) + if c.discoverer.settedRoute != nil { + c.discoverer.settedRoute(&appClient, c.Request) app = appClient.App method = appClient.Method path = appClient.Path @@ -158,7 +165,7 @@ func (c *Caller) doWithNode(manual bool, method, app, withNode, path string, dat return &gohttp.Result{Error: fmt.Errorf("app %s not found", app)}, "" } - callInfo := getCallInfo(app) + callInfo := c.discoverer.getCallInfo(app) if callInfo != nil && callInfo.Token != "" { callerHeaders["Access-Token"] = callInfo.Token } @@ -222,7 +229,7 @@ func (c *Caller) doWithNode(manual bool, method, app, withNode, path string, dat responseTime := time.Since(startTime) usedTimeMs := float32(responseTime.Nanoseconds()) / 1e6 - settedLoadBalancer.Response(&appClient, node, res.Error, res.Response, responseTime) + c.discoverer.settedLoadBalancer.Response(&appClient, node, res.Error, res.Response, responseTime) if res.Error != nil || (res.Response != nil && res.Response.StatusCode >= 502 && res.Response.StatusCode <= 504) { node.FailedTimes.Add(1) @@ -236,14 +243,12 @@ func (c *Caller) doWithNode(manual bool, method, app, withNode, path string, dat c.logError(errStr, "app", app, "node", node.Addr, "path", path, "attempts", appClient.attempts) appClient.Log(node.Addr, usedTimeMs, fmt.Errorf("%s", errStr)) - // 仅做本地隔离,不再篡改全局注册中心状态 - if node.FailedTimes.Load() >= int32(Config.CallRetryTimes) { - logError("node isolated locally due to high failures", "app", app, "node", node.Addr) + if node.FailedTimes.Load() >= int32(c.discoverer.Config.CallRetryTimes) { + c.discoverer.logError("node isolated locally due to high failures", "app", app, "node", node.Addr) } continue } - // 请求成功,重置失败计数 node.FailedTimes.Store(0) appClient.Log(node.Addr, usedTimeMs, nil) if strings.ToUpper(method) == "WS" { diff --git a/Config.go b/Config.go index 3c4bfbf..aea514e 100644 --- a/Config.go +++ b/Config.go @@ -1,14 +1,17 @@ package discover -// Config 存储发现服务的全局配置 -var Config = struct { +// ConfigStruct 存储发现服务的配置 +type ConfigStruct struct { Registry string // 注册中心地址,如 redis://:@127.0.0.1:6379/15 App string // 当前应用名称 Weight int // 权重,默认为 100 Calls map[string]string // 调用的应用列表及其配置 CallRetryTimes int // 调用重试次数 IpPrefix string // 指定使用的 IP 网段 -}{ +} + +// Config 存储发现服务的全局配置(兼容旧代码) +var Config = ConfigStruct{ Weight: 100, CallRetryTimes: 10, } diff --git a/Discover.go b/Discover.go index 7dc9479..8de97de 100644 --- a/Discover.go +++ b/Discover.go @@ -19,26 +19,29 @@ import ( "apigo.cc/go/redis" ) -var ( +// Discoverer 发现服务实例 +type Discoverer struct { + Config ConfigStruct + serverRedisPool *redis.Redis clientRedisPool *redis.Redis pubsubRedisPool *redis.Redis - isServer = false - isClient = false - daemonRunning = false - myAddr = "" - _logger = log.DefaultLogger - _inited = false + isServer bool + isClient bool + daemonRunning bool + myAddr string + logger *log.Logger + inited bool daemonStopChan chan bool appLock sync.RWMutex - _calls = map[string]*callInfoType{} - _appNodes = map[string]map[string]*NodeInfo{} - appSubscribed = map[string]bool{} + calls map[string]*callInfoType + appNodes map[string]map[string]*NodeInfo + appSubscribed map[string]bool - settedRoute func(*AppClient, *http.Request) = nil - settedLoadBalancer LoadBalancer = &DefaultLoadBalancer{} -) + settedRoute func(*AppClient, *http.Request) + settedLoadBalancer LoadBalancer +} type callInfoType struct { Timeout time.Duration @@ -47,147 +50,168 @@ type callInfoType struct { SSL bool } -// IsServer 返回当前节点是否作为服务端运行 -func IsServer() bool { return isServer } +// DefaultDiscoverer 默认的全局发现服务实例 +var DefaultDiscoverer = NewDiscoverer() -// IsClient 返回当前节点是否作为客户端运行 -func IsClient() bool { return isClient } - -// logError 记录 Discover 内部错误 -func logError(msg string, extra ...any) { - _logger.Error("Discover: "+msg, append(extra, "app", Config.App, "addr", myAddr)...) +// NewDiscoverer 创建一个新的发现服务实例 +func NewDiscoverer() *Discoverer { + return &Discoverer{ + Config: ConfigStruct{ + Weight: 100, + CallRetryTimes: 10, + }, + logger: log.DefaultLogger, + calls: make(map[string]*callInfoType), + appNodes: make(map[string]map[string]*NodeInfo), + appSubscribed: make(map[string]bool), + settedLoadBalancer: &DefaultLoadBalancer{}, + } } -// logInfo 记录 Discover 内部信息 -func logInfo(msg string, extra ...any) { - _logger.Info("Discover: "+msg, append(extra, "app", Config.App, "addr", myAddr)...) +// IsServer 返回当前节点是否作为服务端运行 +func (d *Discoverer) IsServer() bool { return d.isServer } + +// IsClient 返回当前节点是否作为客户端运行 +func (d *Discoverer) IsClient() bool { return d.isClient } + +func (d *Discoverer) logError(msg string, extra ...any) { + d.logger.Error("Discover: "+msg, append(extra, "app", d.Config.App, "addr", d.myAddr)...) +} + +func (d *Discoverer) logInfo(msg string, extra ...any) { + d.logger.Info("Discover: "+msg, append(extra, "app", d.Config.App, "addr", d.myAddr)...) } // SetLogger 设置 Discover 使用的全局 Logger -func SetLogger(logger *log.Logger) { - _logger = logger +func (d *Discoverer) SetLogger(logger *log.Logger) { + d.logger = logger } -// Init 初始化 Discover 配置,通常由 Start 自动调用 -func Init() { - appLock.Lock() - defer appLock.Unlock() - if _inited { +// Init 初始化 Discover 配置 +func (d *Discoverer) Init() { + d.appLock.Lock() + defer d.appLock.Unlock() + if d.inited { return } - _inited = true - _ = config.Load(&Config, "discover") - - if Config.CallRetryTimes <= 0 { - Config.CallRetryTimes = 10 - } - if Config.Weight <= 0 { - Config.Weight = 100 - } - if Config.Registry == "" { - Config.Registry = DefaultRegistry + d.inited = true + // 如果是默认实例,尝试加载配置 + if d == DefaultDiscoverer { + _ = config.Load(&d.Config, "discover") + Config = d.Config // 保持 Config 变量同步 } - _logger = log.New(id.MakeID(12)) + if d.Config.CallRetryTimes <= 0 { + d.Config.CallRetryTimes = 10 + } + if d.Config.Weight <= 0 { + d.Config.Weight = 100 + } + if d.Config.Registry == "" { + d.Config.Registry = DefaultRegistry + } + + if d.logger == log.DefaultLogger || d.logger == nil { + d.logger = log.New(id.MakeID(12)) + } } // Start 启动服务发现,指定当前节点的外部访问地址 -func Start(addr string) bool { - Init() - myAddr = addr +func (d *Discoverer) Start(addr string) bool { + d.Init() + d.myAddr = addr - isServer = Config.App != "" && Config.Weight > 0 - if isServer && Config.Registry != "" { - serverRedisPool = redis.GetRedis(Config.Registry, _logger) - if serverRedisPool.Error != nil { - logError(serverRedisPool.Error.Error()) + d.isServer = d.Config.App != "" && d.Config.Weight > 0 + if d.isServer && d.Config.Registry != "" { + d.serverRedisPool = redis.GetRedis(d.Config.Registry, d.logger) + if d.serverRedisPool.Error != nil { + d.logError(d.serverRedisPool.Error.Error()) } // 注册节点 - if serverRedisPool.Do("HSET", Config.App, addr, Config.Weight).Error == nil { - serverRedisPool.Do("SETEX", Config.App+"_"+addr, 10, "1") - logInfo("registered") - serverRedisPool.PUBLISH("CH_"+Config.App, fmt.Sprintf("%s %d", addr, Config.Weight)) - daemonRunning = true - daemonStopChan = make(chan bool) - go daemon() + if d.serverRedisPool.Do("HSET", d.Config.App, addr, d.Config.Weight).Error == nil { + d.serverRedisPool.Do("SETEX", d.Config.App+"_"+addr, 10, "1") + d.logInfo("registered") + d.serverRedisPool.PUBLISH("CH_"+d.Config.App, fmt.Sprintf("%s %d", addr, d.Config.Weight)) + d.daemonRunning = true + d.daemonStopChan = make(chan bool) + go d.daemon() } else { - logError("register failed") + d.logError("register failed") } } - calls := getCalls() + calls := d.getCalls() if len(calls) > 0 { for app, conf := range calls { - addApp(app, conf, false) + d.addApp(app, conf, false) } - if !startSub() { + if !d.startSub() { return false } } return true } -func daemon() { - logInfo("daemon thread started") - // 每 5 秒心跳一次,降低 Redis 压力,TTL 保持 10 秒 +func (d *Discoverer) daemon() { + d.logInfo("daemon thread started") ticker := time.NewTicker(5 * time.Second) defer ticker.Stop() - for daemonRunning { + for d.daemonRunning { <-ticker.C - if !daemonRunning { + if !d.daemonRunning { break } - if isServer && serverRedisPool != nil { - if !serverRedisPool.Do("HEXISTS", Config.App, myAddr).Bool() { - logInfo("lost app registered info, re-registering") - if serverRedisPool.Do("HSET", Config.App, myAddr, Config.Weight).Error == nil { - serverRedisPool.Do("SETEX", Config.App+"_"+myAddr, 10, "1") - serverRedisPool.PUBLISH("CH_"+Config.App, fmt.Sprintf("%s %d", myAddr, Config.Weight)) + if d.isServer && d.serverRedisPool != nil { + if !d.serverRedisPool.Do("HEXISTS", d.Config.App, d.myAddr).Bool() { + d.logInfo("lost app registered info, re-registering") + if d.serverRedisPool.Do("HSET", d.Config.App, d.myAddr, d.Config.Weight).Error == nil { + d.serverRedisPool.Do("SETEX", d.Config.App+"_"+d.myAddr, 10, "1") + d.serverRedisPool.PUBLISH("CH_"+d.Config.App, fmt.Sprintf("%s %d", d.myAddr, d.Config.Weight)) } } else { - serverRedisPool.Do("SETEX", Config.App+"_"+myAddr, 10, "1") + d.serverRedisPool.Do("SETEX", d.Config.App+"_"+d.myAddr, 10, "1") } } } - logInfo("daemon thread stopped") - if daemonStopChan != nil { - daemonStopChan <- true + d.logInfo("daemon thread stopped") + if d.daemonStopChan != nil { + d.daemonStopChan <- true } } -func startSub() bool { - if Config.Registry == "" { +func (d *Discoverer) startSub() bool { + if d.Config.Registry == "" { return true } - appLock.Lock() - if clientRedisPool == nil { - clientRedisPool = redis.GetRedis(Config.Registry, _logger) + d.appLock.Lock() + if d.clientRedisPool == nil { + d.clientRedisPool = redis.GetRedis(d.Config.Registry, d.logger) } - if pubsubRedisPool == nil { - pubsubRedisPool = redis.GetRedis(Config.Registry, _logger.New(id.MakeID(12))) + if d.pubsubRedisPool == nil { + d.pubsubRedisPool = redis.GetRedis(d.Config.Registry, d.logger.New(id.MakeID(12))) // 订阅所有已注册的应用 - for app := range appSubscribed { - subscribeAppUnderLock(app) + for app := range d.appSubscribed { + d.subscribeAppUnderLock(app) } // 必须在释放锁之前完成配置,但在释放锁之后启动,避免死锁 - appLock.Unlock() - pubsubRedisPool.Start() - appLock.Lock() + d.appLock.Unlock() + d.pubsubRedisPool.Start() + d.appLock.Lock() } - isClient = true - appLock.Unlock() + d.isClient = true + d.appLock.Unlock() return true } -func subscribeAppUnderLock(app string) { - pubsubRedisPool.Subscribe("CH_"+app, func() { - fetchApp(app) +func (d *Discoverer) subscribeAppUnderLock(app string) { + d.pubsubRedisPool.Subscribe("CH_"+app, func() { + d.fetchApp(app) }, func(data []byte) { a := strings.Split(string(data), " ") addr := a[0] @@ -195,43 +219,42 @@ func subscribeAppUnderLock(app string) { if len(a) == 2 { weight = cast.Int(a[1]) } - logInfo("received node update", "app", app, "addr", addr, "weight", weight) - pushNode(app, addr, weight) + d.logInfo("received node update", "app", app, "addr", addr, "weight", weight) + d.pushNode(app, addr, weight) }) } // Stop 停止 Discover 并从注册中心注销当前节点 -func Stop() { - appLock.Lock() - if isClient && pubsubRedisPool != nil { - pubsubRedisPool.Stop() - isClient = false +func (d *Discoverer) Stop() { + d.appLock.Lock() + if d.isClient && d.pubsubRedisPool != nil { + d.pubsubRedisPool.Stop() + d.isClient = false } - if isServer { - daemonRunning = false - if serverRedisPool != nil { - serverRedisPool.Do("HDEL", Config.App, myAddr) - serverRedisPool.Do("DEL", Config.App+"_"+myAddr) - serverRedisPool.PUBLISH("CH_"+Config.App, fmt.Sprintf("%s %d", myAddr, 0)) + if d.isServer { + d.daemonRunning = false + if d.serverRedisPool != nil { + d.serverRedisPool.Do("HDEL", d.Config.App, d.myAddr) + d.serverRedisPool.Do("DEL", d.Config.App+"_"+d.myAddr) + d.serverRedisPool.PUBLISH("CH_"+d.Config.App, fmt.Sprintf("%s %d", d.myAddr, 0)) } - isServer = false + d.isServer = false } - appLock.Unlock() + d.appLock.Unlock() } // Wait 等待守护进程退出 -func Wait() { - if daemonStopChan != nil { - <-daemonStopChan - daemonStopChan = nil +func (d *Discoverer) Wait() { + if d.daemonStopChan != nil { + <-d.daemonStopChan + d.daemonStopChan = nil } } // EasyStart 自动根据环境变量和本地网卡信息启动 Discover -// 返回监听的 IP 和 端口 -func EasyStart() (string, int) { - Init() +func (d *Discoverer) EasyStart() (string, int) { + d.Init() port := 0 if listen := os.Getenv("DISCOVER_LISTEN"); listen != "" { if _, p, err := net.SplitHostPort(listen); err == nil { @@ -243,7 +266,7 @@ func EasyStart() (string, int) { ln, err := net.Listen("tcp", fmt.Sprintf(":%d", port)) if err != nil { - logError("failed to listen", "err", err) + d.logError("failed to listen", "err", err) return "", 0 } addrInfo := ln.Addr().(*net.TCPAddr) @@ -259,7 +282,7 @@ func EasyStart() (string, int) { if ip4 == nil || !ip4.IsGlobalUnicast() { continue } - if Config.IpPrefix != "" && strings.HasPrefix(ip4.String(), Config.IpPrefix) { + if d.Config.IpPrefix != "" && strings.HasPrefix(ip4.String(), d.Config.IpPrefix) { ip = ip4 break } @@ -271,7 +294,7 @@ func EasyStart() (string, int) { } addr := fmt.Sprintf("%s:%d", ip.String(), port) - if !Start(addr) { + if !d.Start(addr) { return "", 0 } @@ -279,50 +302,51 @@ func EasyStart() (string, int) { signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM) go func() { <-sigChan - Stop() + d.Stop() }() return ip.String(), port } // AddExternalApp 动态添加需要发现的外部应用 -func AddExternalApp(app, callConf string) bool { - if addApp(app, callConf, true) { - if !isClient { - startSub() +func (d *Discoverer) AddExternalApp(app, callConf string) bool { + if d.addApp(app, callConf, true) { + if !d.isClient { + d.startSub() } else { - appLock.Lock() - subscribeAppUnderLock(app) - appLock.Unlock() + d.appLock.Lock() + d.subscribeAppUnderLock(app) + d.appLock.Unlock() + d.fetchApp(app) // 同步拉取一次 } return true } return false } -// SetNode 手动设置某个服务的节点信息(不通过注册中心) -func SetNode(app, addr string, weight int) { - pushNode(app, addr, weight) +// SetNode 手动设置某个服务的节点信息 +func (d *Discoverer) SetNode(app, addr string, weight int) { + d.pushNode(app, addr, weight) } -func getCallInfo(app string) *callInfoType { - appLock.RLock() - defer appLock.RUnlock() - return _calls[app] +func (d *Discoverer) getCallInfo(app string) *callInfoType { + d.appLock.RLock() + defer d.appLock.RUnlock() + return d.calls[app] } var numberMatcher = regexp.MustCompile(`^\d+(s|ms|us|µs|ns?)?$`) -func addApp(app, callConf string, fetch bool) bool { - appLock.Lock() - if Config.Calls == nil { - Config.Calls = make(map[string]string) +func (d *Discoverer) addApp(app, callConf string, fetch bool) bool { + d.appLock.Lock() + if d.Config.Calls == nil { + d.Config.Calls = make(map[string]string) } - if Config.Calls[app] == callConf && _appNodes[app] != nil { - appLock.Unlock() + if d.Config.Calls[app] == callConf && d.appNodes[app] != nil { + d.appLock.Unlock() return false } - Config.Calls[app] = callConf + d.Config.Calls[app] = callConf callInfo := &callInfoType{ Timeout: 10 * time.Second, @@ -354,29 +378,29 @@ func addApp(app, callConf string, fetch bool) bool { } } - _calls[app] = callInfo - if _appNodes[app] == nil { - _appNodes[app] = make(map[string]*NodeInfo) + d.calls[app] = callInfo + if d.appNodes[app] == nil { + d.appNodes[app] = make(map[string]*NodeInfo) } - appSubscribed[app] = true - appLock.Unlock() + d.appSubscribed[app] = true + d.appLock.Unlock() - if fetch { - fetchApp(app) + if fetch && d.isClient { + d.fetchApp(app) } return true } -func fetchApp(app string) { - appLock.RLock() - pool := clientRedisPool - appLock.RUnlock() +func (d *Discoverer) fetchApp(app string) { + d.appLock.RLock() + pool := d.clientRedisPool + d.appLock.RUnlock() if pool == nil { return } results := pool.Do("HGETALL", app).ResultMap() - + // 检查存活 for addr := range results { if !pool.Do("EXISTS", app+"_"+addr).Bool() { @@ -385,84 +409,132 @@ func fetchApp(app string) { } } - currentNodes := getAppNodes(app) + currentNodes := d.getAppNodes(app) if currentNodes != nil { for addr := range currentNodes { if _, ok := results[addr]; !ok { - pushNode(app, addr, 0) + d.pushNode(app, addr, 0) } } } for addr, res := range results { - pushNode(app, addr, res.Int()) + d.pushNode(app, addr, res.Int()) } } -func getAppNodes(app string) map[string]*NodeInfo { - appLock.RLock() - defer appLock.RUnlock() - if _appNodes[app] == nil { +func (d *Discoverer) getAppNodes(app string) map[string]*NodeInfo { + d.appLock.RLock() + defer d.appLock.RUnlock() + if d.appNodes[app] == nil { return nil } nodes := make(map[string]*NodeInfo) - for k, v := range _appNodes[app] { + for k, v := range d.appNodes[app] { nodes[k] = v } return nodes } -func getCalls() map[string]string { - appLock.RLock() - defer appLock.RUnlock() +func (d *Discoverer) getCalls() map[string]string { + d.appLock.RLock() + defer d.appLock.RUnlock() calls := make(map[string]string) - for k, v := range Config.Calls { + for k, v := range d.Config.Calls { calls[k] = v } return calls } // GetAppNodes 获取某个应用的所有节点列表 -func GetAppNodes(app string) map[string]*NodeInfo { - return getAppNodes(app) +func (d *Discoverer) GetAppNodes(app string) map[string]*NodeInfo { + return d.getAppNodes(app) } -func pushNode(app, addr string, weight int) { - appLock.Lock() - defer appLock.Unlock() +func (d *Discoverer) pushNode(app, addr string, weight int) { + d.appLock.Lock() + defer d.appLock.Unlock() if weight <= 0 { - if _appNodes[app] != nil { - delete(_appNodes[app], addr) + if d.appNodes[app] != nil { + delete(d.appNodes[app], addr) } return } - if _appNodes[app] == nil { - _appNodes[app] = make(map[string]*NodeInfo) + if d.appNodes[app] == nil { + d.appNodes[app] = make(map[string]*NodeInfo) } - if node, ok := _appNodes[app][addr]; ok { + if node, ok := d.appNodes[app][addr]; ok { if node.Weight != weight { - // 调整 UsedTimes 保持相对均衡,使用 Load() 和 Store() used := node.UsedTimes.Load() node.UsedTimes.Store(uint64(float64(used) / float64(node.Weight) * float64(weight))) node.Weight = weight } } else { var avgUsed uint64 = 0 - if len(_appNodes[app]) > 0 { + if len(d.appNodes[app]) > 0 { var totalScore float64 - for _, n := range _appNodes[app] { + for _, n := range d.appNodes[app] { totalScore += float64(n.UsedTimes.Load()) / float64(n.Weight) } - avgUsed = uint64(totalScore / float64(len(_appNodes[app])) * float64(weight)) + avgUsed = uint64(totalScore / float64(len(d.appNodes[app])) * float64(weight)) } node := &NodeInfo{ Addr: addr, Weight: weight, } node.UsedTimes.Store(avgUsed) - _appNodes[app][addr] = node + d.appNodes[app][addr] = node } } + +// 以下是包级别 API,通过转发给 DefaultDiscoverer 实现兼容性 + +func IsServer() bool { return DefaultDiscoverer.IsServer() } +func IsClient() bool { return DefaultDiscoverer.IsClient() } + +func logError(msg string, extra ...any) { + DefaultDiscoverer.logError(msg, extra...) +} + +func logInfo(msg string, extra ...any) { + DefaultDiscoverer.logInfo(msg, extra...) +} + +func SetLogger(logger *log.Logger) { + DefaultDiscoverer.SetLogger(logger) +} + +func Init() { + DefaultDiscoverer.Init() +} + +func Start(addr string) bool { + return DefaultDiscoverer.Start(addr) +} + +func Stop() { + DefaultDiscoverer.Stop() +} + +func Wait() { + DefaultDiscoverer.Wait() +} + +func EasyStart() (string, int) { + return DefaultDiscoverer.EasyStart() +} + +func AddExternalApp(app, callConf string) bool { + return DefaultDiscoverer.AddExternalApp(app, callConf) +} + +func SetNode(app, addr string, weight int) { + DefaultDiscoverer.SetNode(app, addr, weight) +} + +func GetAppNodes(app string) map[string]*NodeInfo { + return DefaultDiscoverer.GetAppNodes(app) +} diff --git a/Discover_test.go b/Discover_test.go index f65a620..ea7f256 100644 --- a/Discover_test.go +++ b/Discover_test.go @@ -46,8 +46,8 @@ func TestDiscover(t *testing.T) { defer server.Close() // 配置 Discover - discover.Config.App = "test-app" - discover.Config.Registry = "redis://127.0.0.1:6379/15" + discover.DefaultDiscoverer.Config.App = "test-app" + discover.DefaultDiscoverer.Config.Registry = "redis://127.0.0.1:6379/15" // 启动 Discover if !discover.Start("127.0.0.1:18001") { diff --git a/LoadBalancer.go b/LoadBalancer.go index fefefbd..b137b73 100644 --- a/LoadBalancer.go +++ b/LoadBalancer.go @@ -7,7 +7,12 @@ import ( // SetLoadBalancer 设置全局负载均衡策略 func SetLoadBalancer(lb LoadBalancer) { - settedLoadBalancer = lb + DefaultDiscoverer.SetLoadBalancer(lb) +} + +// SetLoadBalancer 设置负载均衡策略 +func (d *Discoverer) SetLoadBalancer(lb LoadBalancer) { + d.settedLoadBalancer = lb } // LoadBalancer 负载均衡接口 diff --git a/MultiInstance_test.go b/MultiInstance_test.go new file mode 100644 index 0000000..1d643c8 --- /dev/null +++ b/MultiInstance_test.go @@ -0,0 +1,74 @@ +package discover_test + +import ( + "fmt" + "net" + "net/http" + "testing" + "time" + + "apigo.cc/go/discover" +) + +func TestMultipleDiscoverer(t *testing.T) { + // 启动两个模拟服务 + l1, _ := net.Listen("tcp", "127.0.0.1:18011") + mux1 := http.NewServeMux() + mux1.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { _, _ = w.Write([]byte("OK1")) }) + server1 := &http.Server{Handler: mux1} + go func() { _ = server1.Serve(l1) }() + defer server1.Close() + + l2, _ := net.Listen("tcp", "127.0.0.1:18012") + mux2 := http.NewServeMux() + mux2.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { _, _ = w.Write([]byte("OK2")) }) + server2 := &http.Server{Handler: mux2} + go func() { _ = server2.Serve(l2) }() + defer server2.Close() + + registry := "redis://127.0.0.1:6379/15" + + // 实例 1 + d1 := discover.NewDiscoverer() + d1.Config.App = "app1" + d1.Config.Registry = registry + if !d1.Start("127.0.0.1:18011") { + t.Skip("redis not available") + } + defer d1.Stop() + + // 实例 2 + d2 := discover.NewDiscoverer() + d2.Config.App = "app2" + d2.Config.Registry = registry + if !d2.Start("127.0.0.1:18012") { + t.Skip("redis not available") + } + defer d2.Stop() + + // 实例 1 发现并调用自己 + d1.AddExternalApp("app1", "1") + time.Sleep(200 * time.Millisecond) // 等待同步 + c1 := d1.NewCaller(nil, nil) + res1 := c1.Get("app1", "/") + if res1.Error != nil || res1.String() != "OK1" { + t.Errorf("d1 call app1 failed: %v, %s", res1.Error, res1.String()) + } + + // 实例 2 发现并调用 实例 1 + d2.AddExternalApp("app1", "1") + time.Sleep(200 * time.Millisecond) // 等待同步 + c2 := d2.NewCaller(nil, nil) + res2 := c2.Get("app1", "/") + if res2.Error != nil || res2.String() != "OK1" { + t.Errorf("d2 call app1 failed: %v, %s", res2.Error, res2.String()) + } + + // 验证独立性:d1 不应该能直接调用 app2 (除非手动 AddExternalApp) + res3 := c1.Get("app2", "/") + if res3.Error == nil { + t.Error("d1 should not find app2 without AddExternalApp") + } + + fmt.Println("Multiple Discoverer instances verified") +} diff --git a/Route.go b/Route.go index c339f78..a261399 100644 --- a/Route.go +++ b/Route.go @@ -2,7 +2,12 @@ package discover import "net/http" -// SetRoute 设置全局路由规则,可以在请求前修改 App、Method、Path 等信息 +// SetRoute 设置全局路由规则 func SetRoute(route func(appClient *AppClient, request *http.Request)) { - settedRoute = route + DefaultDiscoverer.SetRoute(route) +} + +// SetRoute 设置路由规则 +func (d *Discoverer) SetRoute(route func(appClient *AppClient, request *http.Request)) { + d.settedRoute = route }