starter/starter.go

485 lines
11 KiB
Go
Raw Normal View History

2026-05-10 15:53:17 +08:00
package starter
import (
"context"
"flag"
"fmt"
"io"
"net"
2026-05-10 15:53:17 +08:00
"os"
"os/exec"
"os/signal"
"path/filepath"
"sort"
"sync"
2026-05-10 15:53:17 +08:00
"syscall"
"time"
"apigo.cc/go/crypto"
"apigo.cc/go/id"
2026-05-10 15:53:17 +08:00
"apigo.cc/go/log"
)
var (
// Default configuration
appName = filepath.Base(os.Args[0])
appVersion = "1.0.1"
appUsage = ""
2026-05-10 15:53:17 +08:00
// Internal state
commands = make(map[string]*command)
// New Service registry
services = make(map[int][]*managedService)
startedPriorities []int
2026-05-10 15:53:17 +08:00
// Flags
flagSet = flag.NewFlagSet(appName, flag.ContinueOnError)
// IPC Security
ipcSecret = "apigo-starter-secret-2026"
2026-06-05 08:37:22 +08:00
// Starter Logger
starterLogger *log.Logger
// Signal handling
sigs chan os.Signal
cancelCtx context.CancelFunc
2026-05-10 15:53:17 +08:00
)
type managedService struct {
Name string
svc Service
priority int
startTimeout time.Duration
stopTimeout time.Duration
onStarting []func()
onStarted []func()
onStopping []func()
onStopped []func()
}
func (ms *managedService) OnStarting(fn func()) *managedService {
ms.onStarting = append(ms.onStarting, fn)
return ms
}
func (ms *managedService) OnStarted(fn func()) *managedService {
ms.onStarted = append(ms.onStarted, fn)
return ms
}
func (ms *managedService) OnStopping(fn func()) *managedService {
ms.onStopping = append(ms.onStopping, fn)
return ms
}
func (ms *managedService) OnStopped(fn func()) *managedService {
ms.onStopped = append(ms.onStopped, fn)
return ms
2026-05-10 15:53:17 +08:00
}
// command describes one subcommand that Start can dispatch to.
type command struct {
	name string // word matched against the first command-line argument
	desc string // one-line description printed by showHelp
	fn   func() // handler invoked when the command is selected
}
func init() {
AddCommand("start", "Start the service in background", startCmd)
AddCommand("stop", "Stop the service", stopCmd)
AddCommand("restart", "Restart the service", restartCmd)
AddCommand("status", "Show service status", statusCmd)
AddCommand("kill", "Send signal to a specific service: kill <svc_name> <signal_num>", killCmd)
2026-06-05 08:37:22 +08:00
}
2026-06-05 08:37:22 +08:00
// getStarterLogger returns the singleton logger for the starter context.
func getStarterLogger() *log.Logger {
if starterLogger == nil {
starterLogger = log.DefaultLogger.New(id.Get8Bytes4KPerSecond())
}
return starterLogger
2026-05-10 15:53:17 +08:00
}
// Register adds a service to be managed by the starter.
func Register(name string, svc Service, priority int, startTimeout, stopTimeout time.Duration) *managedService {
ms := &managedService{
Name: name,
svc: svc,
priority: priority,
startTimeout: startTimeout,
stopTimeout: stopTimeout,
}
services[priority] = append(services[priority], ms)
return ms
2026-05-10 15:53:17 +08:00
}
// SetAppInfo sets the application name and version.
func SetAppInfo(name, version string) {
appName = name
appVersion = version
2026-05-10 15:53:17 +08:00
}
// SetUsage sets custom usage text to be displayed in help.
// showHelp prints the text below the name/version banner when it is non-empty.
func SetUsage(text string) {
appUsage = text
}
// AddCommand adds a custom command.
func AddCommand(name, desc string, fn func()) {
commands[name] = &command{name: name, desc: desc, fn: fn}
2026-05-10 15:53:17 +08:00
}
// TODO 使用 Start / Wait 代替 Run方便在启动后做初始化操作或者支持注册某服务的 OnStarting / OnStarted / OnStopping / OnStopped 事件
// Start parses arguments and starts the services.
func Start() error {
2026-05-10 15:53:17 +08:00
flagSet.Usage = showHelp
if len(os.Args) > 1 {
arg := os.Args[1]
if cmd, ok := commands[arg]; ok {
// Subcommand detected, parse flags after the command
_ = flagSet.Parse(os.Args[2:])
cmd.fn()
os.Exit(0)
2026-05-10 15:53:17 +08:00
}
// Check for help/version
switch arg {
case "-h", "--help", "help":
showHelp()
os.Exit(0)
2026-05-10 15:53:17 +08:00
case "-v", "--version", "version":
fmt.Printf("%s version %s\n", appName, appVersion)
os.Exit(0)
2026-05-10 15:53:17 +08:00
}
}
// No starter command, treat all as app flags
_ = flagSet.Parse(os.Args[1:])
2026-06-05 08:37:22 +08:00
// Setup signal handling for graceful shutdown
sigs = make(chan os.Signal, 1)
2026-06-05 08:37:22 +08:00
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGUSR1, syscall.SIGUSR2)
var ctx context.Context
ctx, cancelCtx = context.WithCancel(context.Background())
2026-05-10 15:53:17 +08:00
if err := startServices(ctx); err != nil {
2026-06-05 08:37:22 +08:00
getStarterLogger().Error(fmt.Sprintf("[starter] start services failed: %v", err))
return err
}
return nil
}
// Wait blocks until a termination signal is received.
func Wait() {
if sigs == nil {
return
}
2026-05-10 15:53:17 +08:00
2026-06-05 08:37:22 +08:00
for {
sig := <-sigs
2026-05-10 15:53:17 +08:00
if sig == syscall.SIGHUP {
reloadServices()
continue
}
// Handle user custom signals
if sig == syscall.SIGUSR1 || sig == syscall.SIGUSR2 {
if !handleUserSignal(nil, sig) {
2026-06-05 08:37:22 +08:00
getStarterLogger().Info(fmt.Sprintf("[starter] received signal %v, but no service handled it.", sig))
2026-05-10 15:53:17 +08:00
}
continue
}
2026-06-05 08:37:22 +08:00
getStarterLogger().Info(fmt.Sprintf("[starter] received signal: %v, shutting down...", sig))
2026-05-10 15:53:17 +08:00
break
}
if cancelCtx != nil {
cancelCtx() // Trigger context cancellation
}
stopServices()
2026-06-05 08:37:22 +08:00
getStarterLogger().Info("[starter] shutdown complete")
}
2026-05-10 15:53:17 +08:00
func startServices(ctx context.Context) error {
var priorities []int
for p := range services {
priorities = append(priorities, p)
}
sort.Ints(priorities)
for _, p := range priorities {
svcs := services[p]
var wg sync.WaitGroup
errChan := make(chan error, len(svcs))
for _, ms := range svcs {
wg.Add(1)
go func(ms *managedService) {
defer wg.Done()
sctx := ctx
if ms.startTimeout > 0 {
var cancel context.CancelFunc
sctx, cancel = context.WithTimeout(ctx, ms.startTimeout)
defer cancel()
}
2026-06-05 08:37:22 +08:00
// Each service gets its own unique 10-byte trace ID for its internal logs
serviceTraceId := id.Get10Bytes14MPerSecond()
serviceLogger := log.DefaultLogger.New(serviceTraceId)
// Log using starter's logger (8-byte ID) but include service's trace ID in extra fields
getStarterLogger().Info(fmt.Sprintf("service [%s] starting", ms.Name), "trace", serviceTraceId)
for _, fn := range ms.onStarting {
fn()
}
2026-06-05 08:37:22 +08:00
if err := ms.svc.Start(sctx, serviceLogger); err != nil {
errChan <- fmt.Errorf("service [%s] start error: %w", ms.Name, err)
2026-06-05 01:12:22 +08:00
} else {
2026-06-05 08:37:22 +08:00
getStarterLogger().Info(fmt.Sprintf("service [%s] started", ms.Name), "trace", serviceTraceId)
for _, fn := range ms.onStarted {
fn()
}
}
}(ms)
}
wg.Wait()
close(errChan)
for err := range errChan {
if err != nil {
return err
}
}
startedPriorities = append(startedPriorities, p)
}
return nil
}
func stopServices() {
sort.Slice(startedPriorities, func(i, j int) bool {
return startedPriorities[i] > startedPriorities[j]
})
2026-06-05 01:12:22 +08:00
for _, p := range startedPriorities {
svcs := services[p]
var wg sync.WaitGroup
for _, ms := range svcs {
wg.Add(1)
go func(ms *managedService) {
defer wg.Done()
sctx := context.Background()
if ms.stopTimeout > 0 {
var cancel context.CancelFunc
sctx, cancel = context.WithTimeout(sctx, ms.stopTimeout)
defer cancel()
}
2026-06-05 08:37:22 +08:00
getStarterLogger().Info(fmt.Sprintf("service [%s] stopping", ms.Name))
for _, fn := range ms.onStopping {
fn()
}
if err := ms.svc.Stop(sctx); err != nil {
2026-06-05 08:37:22 +08:00
getStarterLogger().Error(fmt.Sprintf("service [%s] stop error: %v", ms.Name, err))
2026-06-05 01:12:22 +08:00
} else {
2026-06-05 08:37:22 +08:00
getStarterLogger().Info(fmt.Sprintf("service [%s] stopped", ms.Name))
for _, fn := range ms.onStopped {
fn()
}
}
}(ms)
}
wg.Wait()
}
startedPriorities = nil
}
func reloadServices() {
2026-06-05 08:37:22 +08:00
getStarterLogger().Info("[starter] reloading all services...")
for _, p := range startedPriorities {
for _, ms := range services[p] {
if r, ok := ms.svc.(Reloader); ok {
if err := r.Reload(); err != nil {
2026-06-05 08:37:22 +08:00
getStarterLogger().Error(fmt.Sprintf("service [%s] reload error: %v", ms.Name, err))
}
}
}
2026-05-10 15:53:17 +08:00
}
}
func handleUserSignal(svcName *string, sig os.Signal) bool {
handled := false
for _, p := range startedPriorities {
for _, ms := range services[p] {
2026-06-05 08:37:22 +08:00
if svcName != nil && *svcName != ms.Name {
continue
}
if h, ok := ms.svc.(UserSignalHandler); ok {
if h.HandleUserSignal(sig) {
handled = true
}
}
}
}
return handled
}
2026-06-05 08:37:22 +08:00
func showHelp() {
fmt.Printf("%s (%s)\n\n", appName, appVersion)
if appUsage != "" {
fmt.Printf("%s\n\n", appUsage)
}
2026-06-05 08:37:22 +08:00
fmt.Printf("Usage:\n %s [command] [options]\n\nCommands:\n", filepath.Base(os.Args[0]))
2026-06-05 08:37:22 +08:00
var names []string
for cmdName := range commands {
names = append(names, cmdName)
}
2026-06-05 08:37:22 +08:00
sort.Strings(names)
2026-06-05 08:37:22 +08:00
for _, name := range names {
fmt.Printf(" %-10s %s\n", name, commands[name].desc)
}
2026-06-05 08:37:22 +08:00
fmt.Println()
2026-05-10 15:53:17 +08:00
}
func startCmd() {
2026-06-05 08:37:22 +08:00
pid := getPid()
if pid != 0 && isProcessRunning(pid) {
getStarterLogger().Info(fmt.Sprintf("[starter] %s is already running (PID %d)", appName, pid))
2026-05-10 15:53:17 +08:00
return
}
2026-06-05 08:37:22 +08:00
cmd := exec.Command(os.Args[0])
cmd.Args = append([]string{os.Args[0]}, flagSet.Args()...)
cmd.Stdout = nil
cmd.Stderr = nil
cmd.SysProcAttr = &syscall.SysProcAttr{Setsid: true}
2026-05-10 15:53:17 +08:00
2026-06-05 08:37:22 +08:00
if err := cmd.Start(); err != nil {
getStarterLogger().Error(fmt.Sprintf("[starter] failed to start %s: %v", appName, err))
return
2026-05-10 15:53:17 +08:00
}
2026-06-05 08:37:22 +08:00
_ = os.WriteFile(getPidPath(), []byte(fmt.Sprintf("%d", cmd.Process.Pid)), 0644)
getStarterLogger().Info(fmt.Sprintf("[starter] %s started (PID %d)", appName, cmd.Process.Pid))
2026-05-10 15:53:17 +08:00
}
func stopCmd() {
2026-06-05 08:37:22 +08:00
pid := getPid()
if pid == 0 || !isProcessRunning(pid) {
getStarterLogger().Info(fmt.Sprintf("[starter] %s is not running", appName))
2026-05-10 15:53:17 +08:00
return
}
2026-06-05 08:37:22 +08:00
getStarterLogger().Info(fmt.Sprintf("[starter] stopping %s (PID %d)...", appName, pid))
_ = syscall.Kill(pid, syscall.SIGTERM)
2026-05-10 15:53:17 +08:00
2026-06-05 08:37:22 +08:00
for i := 0; i < 30; i++ {
if !isProcessRunning(pid) {
removePid()
getStarterLogger().Info(fmt.Sprintf("[starter] %s stopped", appName))
return
2026-05-10 15:53:17 +08:00
}
2026-06-05 08:37:22 +08:00
time.Sleep(500 * time.Millisecond)
2026-05-10 15:53:17 +08:00
}
2026-06-05 08:37:22 +08:00
getStarterLogger().Warning(fmt.Sprintf("[starter] %s failed to stop gracefully, killing...", appName))
_ = syscall.Kill(pid, syscall.SIGKILL)
2026-05-10 15:53:17 +08:00
removePid()
}
// restartCmd implements the "restart" command: it runs stopCmd and then
// startCmd, in that order.
func restartCmd() {
stopCmd()
startCmd()
}
func statusCmd() {
2026-06-05 08:37:22 +08:00
pid := getPid()
if pid == 0 || !isProcessRunning(pid) {
fmt.Printf("%s is NOT running\n", appName)
return
}
fmt.Printf("%s is running (PID %d)\n\n", appName, pid)
res, err := callIPC(pid, "status")
if err == nil {
fmt.Println("Services Status:")
fmt.Println(res)
}
}
func killCmd() {
2026-06-05 08:37:22 +08:00
if flagSet.NArg() < 2 {
fmt.Println("Usage: kill <service_name> <signal_num>")
return
}
svcName := flagSet.Arg(0)
sigNum := flagSet.Arg(1)
2026-06-05 08:37:22 +08:00
pid := getPid()
if pid == 0 || !isProcessRunning(pid) {
fmt.Println("Application is not running")
return
}
res, err := callIPC(pid, fmt.Sprintf("kill %s %s", svcName, sigNum))
if err != nil {
fmt.Printf("Error: %v\n", err)
} else {
fmt.Println(res)
}
}
func callIPC(pid int, cmd string) (string, error) {
2026-06-05 08:37:22 +08:00
sockPath := getSockPath()
conn, err := net.Dial("unix", sockPath)
if err != nil {
return "", err
}
defer conn.Close()
2026-06-05 08:37:22 +08:00
token := generateToken(pid)
_, _ = conn.Write([]byte(fmt.Sprintf("%s %s", token, cmd)))
2026-06-05 08:37:22 +08:00
buf, err := io.ReadAll(conn)
if err != nil {
return "", err
2026-05-10 15:53:17 +08:00
}
2026-06-05 08:37:22 +08:00
return string(buf), nil
}
2026-06-05 08:37:22 +08:00
func generateToken(pid int) string {
return crypto.Sha256ToHex([]byte(fmt.Sprintf("%s:%d", ipcSecret, pid)))
}
// getPidPath returns the location of the PID file:
// "<appName>-<appVersion>.pid" inside os.TempDir().
func getPidPath() string {
	fileName := fmt.Sprintf("%s-%s.pid", appName, appVersion)
	return filepath.Join(os.TempDir(), fileName)
}
func getSockPath() string {
return filepath.Join(os.TempDir(), fmt.Sprintf("%s-%s.sock", appName, appVersion))
2026-05-10 15:53:17 +08:00
}
2026-06-05 08:37:22 +08:00
func getPid() int {
data, err := os.ReadFile(getPidPath())
2026-05-10 15:53:17 +08:00
if err != nil {
return 0
}
2026-06-05 08:37:22 +08:00
var pid int
fmt.Sscanf(string(data), "%d", &pid)
return pid
2026-05-10 15:53:17 +08:00
}
func removePid() {
_ = os.Remove(getPidPath())
2026-05-10 15:53:17 +08:00
}
// isProcessRunning reports whether a process with PID p exists, probing it
// with signal 0 (which performs the existence and permission checks without
// delivering a signal).
//
// Non-positive PIDs are rejected up front because they do not name a single
// process. EPERM counts as running: it means the process exists but belongs
// to a user this process may not signal.
func isProcessRunning(p int) bool {
	if p <= 0 {
		return false
	}
	process, err := os.FindProcess(p)
	if err != nil {
		return false
	}
	err = process.Signal(syscall.Signal(0))
	return err == nil || err == syscall.EPERM
}