package service import ( "bytes" _ "embed" "encoding/json" "errors" "fmt" "os" "path/filepath" "reflect" "regexp" "strings" "sync" "syscall" "text/template" "time" "apigo.cc/gojs" "apigo.cc/gojs/goja" "github.com/gorilla/websocket" "github.com/ssgo/config" "github.com/ssgo/discover" "github.com/ssgo/log" "github.com/ssgo/redis" "github.com/ssgo/s" "github.com/ssgo/standard" "github.com/ssgo/u" ) //go:embed service.ts var serviceTS string //go:embed README.md var serviceMD string var server *s.AsyncServer var pools = map[string]*gojs.Pool{} var poolExists = map[string]bool{} var poolActionRegistered = map[string]bool{} var poolsLock = sync.RWMutex{} var waitChan chan bool type TplCache struct { FileModTime map[string]int64 Tpl *template.Template } var tplFunc = map[string]any{} var tplFuncLock = sync.RWMutex{} var tplCache = map[string]*TplCache{} var tplCacheLock = sync.RWMutex{} type LimiterConfig struct { From string Time int Times int } type Config struct { SessionKey string DeviceKey string ClientKey string UserIdKey string SessionProvider string SessionTimeout int64 AuthFieldMessage string VerifyFieldMessage string LimitedMessage string Limiters map[string]*LimiterConfig LimiterRedis string Proxy map[string]string Rewrite map[string]string Static map[string]string } var serviceConfig Config var onStop goja.Callable var limiters = map[string]*s.Limiter{} var configed = false func initConfig(opt *gojs.Obj, logger *log.Logger, vm *goja.Runtime) { configed = true s.InitConfig() if startPath, ok := vm.GoData["startPath"]; ok { s.SetWorkPath(u.String(startPath)) } // 处理配置 serviceConfig = Config{"Session", "Device", "Client", "userId", "", 3600, "auth failed", "verify failed", "too many requests", nil, "", map[string]string{}, map[string]string{}, map[string]string{}} if errs := config.LoadConfig("service", &serviceConfig); errs != nil && len(errs) > 0 { panic(vm.NewGoError(errs[0])) } config.LoadConfig("service", &discover.Config) // var auth goja.Callable if opt != nil { u.Convert(opt.O.Export(), &serviceConfig) u.Convert(opt.O.Export(), &s.Config) u.Convert(opt.O.Export(), &discover.Config) // auth = conf.Func("auth") onStop = opt.Func("onStop") if serviceConfig.SessionProvider != "" { sessionRedis = redis.GetRedis(serviceConfig.SessionProvider, logger) } sessionTimeout = serviceConfig.SessionTimeout if sessionTimeout < 0 { sessionTimeout = 0 } } // 身份验证和Session authAccessToken := len(s.Config.AccessTokens) > 0 s.SetClientKeys(serviceConfig.DeviceKey, serviceConfig.ClientKey, serviceConfig.SessionKey) if serviceConfig.SessionKey != "" { s.SetAuthChecker(func(authLevel int, logger *log.Logger, url *string, args map[string]any, request *s.Request, response *s.Response, options *s.WebServiceOptions) (pass bool, object any) { var session *Session setAuthLevel := 0 if serviceConfig.SessionKey != "" { sessionID := request.GetSessionId() if sessionID != "" { session = NewSession(sessionID, logger) } if userId, ok := session.data[serviceConfig.UserIdKey]; ok { request.SetUserId(u.String(userId)) } // 优先使用session中的authLevel if authLevel > 0 { if authLevelBySession, ok := session.data["_authLevel"]; ok { setAuthLevel = u.Int(authLevelBySession) } } } // 如果没有session中的authLevel验证失败,则使用Access-Token中的authLevel(服务间调用) if authAccessToken && setAuthLevel < authLevel { setAuthLevel = s.GetAuthTokenLevel(request.Header.Get("Access-Token")) } if setAuthLevel >= authLevel { return true, session } else { msg := serviceConfig.AuthFieldMessage if strings.Contains(msg, "{{") { msg = strings.ReplaceAll(msg, "{{TARGET_AUTHLEVEL}}", u.String(authLevel)) msg = strings.ReplaceAll(msg, "{{USER_AUTHLEVEL}}", u.String(setAuthLevel)) } var obj any if json.Unmarshal([]byte(msg), &obj) == nil { return false, obj } else { return false, msg } } }) s.Init() } // 限流器 if serviceConfig.Limiters != nil { var limiterRedis *redis.Redis if serviceConfig.LimiterRedis != "" { limiterRedis = redis.GetRedis(serviceConfig.LimiterRedis, logger) } for name, limiter := range serviceConfig.Limiters { switch limiter.From { case "ip": limiter.From = "header." + standard.DiscoverHeaderClientIp case "user": limiter.From = "header." + standard.DiscoverHeaderUserId case "device": limiter.From = "header." + standard.DiscoverHeaderDeviceId } if limiterRedis != nil { limiters[name] = s.NewLimiter(name, limiter.From, time.Duration(limiter.Time)*time.Millisecond, limiter.Times, limiterRedis) } else { limiters[name] = s.NewLocalLimiter(name, limiter.From, time.Duration(limiter.Time)*time.Millisecond, limiter.Times) } } } } func init() { s.Config.KeepKeyCase = true obj := map[string]any{ "config": func(argsIn goja.FunctionCall, vm *goja.Runtime) goja.Value { args := gojs.MakeArgs(&argsIn, vm) initConfig(args.Obj(0), args.Logger, vm) return nil }, "start": func(argsIn goja.FunctionCall, vm *goja.Runtime) goja.Value { if !configed { initConfig(nil, gojs.GetLogger(vm), vm) // panic(vm.NewGoError(errors.New("must run service.config frist"))) } if server != nil { panic(vm.NewGoError(errors.New("server already started"))) } // 处理静态文件 if len(serviceConfig.Static) > 0 { UpdateStatic(serviceConfig.Static) } if len(serviceConfig.Rewrite) > 0 { UpdateRewrite(serviceConfig.Rewrite) s.SetRewriteBy(rewrite) } if len(serviceConfig.Proxy) > 0 { UpdateProxy(serviceConfig.Proxy) s.SetProxyBy(proxy) } // 处理Watch if vm.GoData["inWatch"] == true { onWatchConn := map[string]*websocket.Conn{} onWatchLock := sync.Mutex{} vm.GoData["onWatch"] = func(filename string) { onWatchLock.Lock() defer onWatchLock.Unlock() for id, conn := range onWatchConn { if err := conn.WriteMessage(websocket.TextMessage, []byte(filename)); err != nil { delete(onWatchConn, id) } } } s.AddShutdownHook(func() { for _, conn := range onWatchConn { conn.Close() } }) s.RegisterSimpleWebsocket(0, "/_watch", func(request *s.Request, conn *websocket.Conn) { onWatchLock.Lock() onWatchConn[request.Id] = conn onWatchLock.Unlock() }, "") s.SetOutFilter(func(in map[string]any, request *s.Request, response *s.Response, out any, logger *log.Logger) (newOut any, isOver bool) { if strings.HasPrefix(response.Header().Get("Content-Type"), "text/html") { outStr := u.String(out) // 注入自动刷新的代码 outStr = strings.ReplaceAll(outStr, "", ` `) return []byte(outStr), false } return nil, false }) } // 启动服务 server = s.AsyncStart() waitChan = make(chan bool, 1) server.OnStop(func() { if onStop != nil { onStop(nil) } }) server.OnStopped(func() { ClearRewritesAndProxies() pools = map[string]*gojs.Pool{} poolExists = map[string]bool{} poolActionRegistered = map[string]bool{} server = nil if waitChan != nil { waitChan <- true } }) return vm.ToValue(server.Addr) }, "stop": func(argsIn goja.FunctionCall, vm *goja.Runtime) goja.Value { if server == nil { panic(vm.NewGoError(errors.New("server not started"))) } server.Stop() return nil }, "uniqueId": func(argsIn goja.FunctionCall, vm *goja.Runtime) goja.Value { args := gojs.MakeArgs(&argsIn, vm) size := args.Int(0) var id string if size >= 20 { id = s.UniqueId20() } else if size >= 16 { id = s.UniqueId16() } else if size >= 14 { id = s.UniqueId14() } else if size >= 12 { id = s.UniqueId14()[0:12] } else { id = s.UniqueId() } return vm.ToValue(id) }, "uniqueIdL": func(argsIn goja.FunctionCall, vm *goja.Runtime) goja.Value { args := gojs.MakeArgs(&argsIn, vm) size := args.Int(0) var id string if size >= 20 { id = s.UniqueId20() } else if size >= 16 { id = s.UniqueId16() } else if size >= 14 { id = s.UniqueId14() } else if size >= 12 { id = s.UniqueId14()[0:12] } else { id = s.UniqueId() } return vm.ToValue(strings.ToLower(id)) }, "id": func(argsIn goja.FunctionCall, vm *goja.Runtime) goja.Value { args := gojs.MakeArgs(&argsIn, vm).Check(1) space := args.Str(0) size := args.Int(1) var id string if size >= 12 { id = s.Id12(space) } else if size >= 10 { id = s.Id10(space) } else if size >= 8 { id = s.Id8(space) } else { id = s.Id6(space) } return vm.ToValue(id) }, "idL": func(argsIn goja.FunctionCall, vm *goja.Runtime) goja.Value { args := gojs.MakeArgs(&argsIn, vm).Check(1) space := args.Str(0) size := args.Int(1) var id string if size >= 12 { id = s.Id12(space) } else if size >= 10 { id = s.Id10(space) } else if size >= 8 { id = s.Id8(space) } else { id = s.Id6(space) } return vm.ToValue(strings.ToLower(id)) }, "register": func(argsIn goja.FunctionCall, vm *goja.Runtime) goja.Value { args := gojs.MakeArgs(&argsIn, vm).Check(2) o := args.Obj(0) action := args.Func(1) if action == nil { panic(vm.NewGoError(errors.New("action must be a callback function"))) } authLevel := o.Int("authLevel") host := o.Str("host") method := strings.ToUpper(o.Str("method")) path := o.Str("path") memo := o.Str("memo") var usedLimiters []*s.Limiter for _, limiterName := range o.Array("limiters") { if limiter1, ok := limiters[u.String(limiterName)]; ok { usedLimiters = append(usedLimiters, limiter1) } } if requiresObj := o.Array("requires"); requiresObj != nil { requires := make([]string, len(requiresObj)) for i, require := range requiresObj { requires[i] = u.String(require) } vm.GoData[fmt.Sprint("REQUIRE_"+host, method, path)] = requires } if verifiesObj := o.Obj("verifies"); verifiesObj != nil { verifiesSet := map[string]func(any, *goja.Runtime) bool{} for _, field := range verifiesObj.Keys() { v := verifiesObj.Get(field) // 根据类型设置验证器,存储到vm中,在请求到达时进行有效性验证 switch v.ExportType().Kind() { case reflect.String: verifiesSet[field] = verifyRegexp(u.String(v.Export())) case reflect.Int, reflect.Int64: verifiesSet[field] = verifyLen(u.Int(v.Export())) case reflect.Bool: verifiesSet[field] = verifyRequire() case reflect.Slice: list := make([]string, 0) vm.ForOf(v, func(v1 goja.Value) bool { list = append(list, u.String(v1.Export())) return true }) verifiesSet[field] = verifyIn(list) default: if fn, ok := goja.AssertFunction(v); ok { verifiesSet[field] = verifyFunc(fn, verifiesObj) } else if obj := v.ToObject(vm); obj != nil { // 支持传入js的正则表达式对象 if testV := obj.Get("test"); testV != nil { if testFn, ok := goja.AssertFunction(testV); ok { verifiesSet[field] = verifyFunc(testFn, obj) } } } } } vm.GoData[fmt.Sprint("VERIFY_"+host, method, path)] = verifiesSet } opt := s.WebServiceOptions{ NoBody: o.Bool("noBody"), NoLog200: o.Bool("noLog200"), Host: host, //Ext: nil, Limiters: usedLimiters, } startFile := u.String(vm.GoData["startFile"]) poolsLock.RLock() poolExist := poolExists[startFile] poolsLock.RUnlock() if poolExist { // 从对象调用(支持并发) actionKey := "REGISTER_" + host + method + path vm.GoData[actionKey] = action vm.GoData[actionKey+"This"] = args.This if method == "WS" { vm.GoData[actionKey+"onMessage"] = o.Func("onMessage") vm.GoData[actionKey+"onClose"] = o.Func("onClose") } poolsLock.Lock() actionRegistered := poolActionRegistered[actionKey] if !actionRegistered { poolActionRegistered[actionKey] = true } poolsLock.Unlock() if !actionRegistered { if strings.ToUpper(method) == "WS" { s.RegisterWebsocketWithOptions(authLevel, path, &websocket.Upgrader{}, makeWSAction(startFile, actionKey), nil, nil, nil, true, "", opt) } else { s.RestfulWithOptions(authLevel, method, path, makeOuterAction(startFile, actionKey), memo, opt) } } } else { // 无对象池,直接调用(单线程) s.RestfulWithOptions(authLevel, method, path, makeInnerAction(action, vm, args.This), memo, opt) } return nil }, "load": func(argsIn goja.FunctionCall, vm *goja.Runtime) goja.Value { args := gojs.MakeArgs(&argsIn, vm).Check(1) actionFile := args.Path(0) opt := args.Obj(1) var mi, ma, idle uint //var num uint mainArgs := make([]any, 0) debug := false if opt != nil { mi = uint(opt.Int("min")) ma = uint(opt.Int("max")) idle = uint(opt.Int("idle")) //num = uint(opt.Int("num")) debug = opt.Bool("debug") if mainArgs1 := opt.Array("args"); mainArgs1 != nil { mainArgs = mainArgs1 } } if !u.FileExists(actionFile) { fullActionFile, _ := filepath.Abs(actionFile) panic(vm.NewGoError(errors.New("actionFile must be a js file path: " + fullActionFile))) } actionCode := u.ReadFileN(actionFile) if !strings.Contains(actionCode, "function main(") || !strings.Contains(actionCode, ".register(") { panic(vm.NewGoError(errors.New("actionFile must be a js file with main function and call service.register"))) } poolsLock.Lock() poolExists[actionFile] = true poolsLock.Unlock() p := gojs.NewPoolByCode(actionCode, actionFile, gojs.PoolConfig{ Min: mi, Max: ma, Idle: idle, Debug: debug, Args: mainArgs, }, args.Logger) //p := gojs.NewLBByCode(actionCode, actionFile, gojs.LBConfig{ // Num: num, // Debug: debug, // Args: mainArgs, //}, args.Logger) poolsLock.Lock() pools[actionFile] = p poolsLock.Unlock() return nil }, "task": func(argsIn goja.FunctionCall, vm *goja.Runtime) goja.Value { args := gojs.MakeArgs(&argsIn, vm).Check(1) taskFile := args.Path(0) interval := args.Int(1) if interval == 0 { interval = 1000 } if interval < 100 { interval = 100 } if !u.FileExists(taskFile) { panic(vm.NewGoError(errors.New("taskFile must be a js file path"))) } rt := gojs.New() _, err := rt.RunFile(taskFile) if err != nil { panic(vm.NewGoError(err)) } s.NewTimerServer(taskFile, time.Duration(interval)*time.Millisecond, func(isRunning *bool) { rt.RunCode("if(onRun)onRun()") }, func() { rt.RunCode("if(onStart)onStart()") }, func() { rt.RunCode("if(onStop)onStop()") }) return nil }, "setTplFunc": func(argsIn goja.FunctionCall, vm *goja.Runtime) goja.Value { args := gojs.MakeArgs(&argsIn, vm).Check(1) fnObj := args.Obj(0) if fnObj != nil { fnList := map[string]any{} for _, k := range fnObj.O.Keys() { if jsFunc := fnObj.Func(k); jsFunc != nil { fn := func(args ...any) any { jsArgs := make([]goja.Value, len(args)) for i := 0; i < len(args); i++ { jsArgs[i] = vm.ToValue(args[i]) } if r, err := jsFunc(argsIn.This, jsArgs...); err == nil { return r.Export() } else { panic(vm.NewGoError(err)) } } fnList[k] = fn } } if len(fnList) > 0 { tplFuncLock.Lock() for k, v := range fnList { tplFunc[k] = v } tplFuncLock.Unlock() } } return nil }, "tpl": func(argsIn goja.FunctionCall, vm *goja.Runtime) goja.Value { args := gojs.MakeArgs(&argsIn, vm).Check(2) filename := args.Path(0) info := u.GetFileInfo(filename) if info == nil { panic(vm.NewGoError(errors.New("tpl file " + filename + " not exists"))) } data := args.Any(1) tplCacheLock.RLock() t := tplCache[filename] tplCacheLock.RUnlock() if t != nil { for f, tm := range t.FileModTime { info := u.GetFileInfo(f) if info == nil || info.ModTime.UnixMilli() != tm { t = nil break } } } if t == nil { tpl := template.New("main") if len(tplFunc) > 0 { tpl = tpl.Funcs(tplFunc) } fileModTime := map[string]int64{ filename: info.ModTime.UnixMilli(), } var err error for _, m := range tplIncludeMatcher.FindAllStringSubmatch(u.ReadFileN(filename), -1) { includeFilename := m[1] info2 := u.GetFileInfo(includeFilename) if info2 == nil { includeFilename = filepath.Join(filepath.Dir(filename), m[1]) info2 = u.GetFileInfo(includeFilename) } if info2 != nil { tpl, err = tpl.Parse(`{{ define "` + m[1] + `" }}` + u.ReadFileN(includeFilename) + `{{ end }}`) if err != nil { panic(vm.NewGoError(err)) } fileModTime[includeFilename] = info2.ModTime.UnixMilli() } } tpl, err = tpl.ParseFiles(filename) if err != nil { panic(vm.NewGoError(err)) } t = &TplCache{ Tpl: tpl, FileModTime: fileModTime, } } buf := bytes.NewBuffer(make([]byte, 0)) err := t.Tpl.ExecuteTemplate(buf, filepath.Base(filename), data) if err != nil { panic(vm.NewGoError(err)) } return vm.ToValue(buf.String()) }, "dataSet": DataSet, "dataGet": DataGet, "dataKeys": DataKeys, "dataCount": DataCount, "dataFetch": DataFetch, "dataRemove": DataRemove, "listPop": ListPop, "listPush": ListPush, "listCount": ListCount, "listRemove": ListRemove, "newCaller": NewCaller, } gojs.Register("apigo.cc/gojs/service", gojs.Module{ Object: obj, TsCode: serviceTS, Example: serviceMD, OnKill: func() { if server != nil { server.Stop() } }, OnSignal: func(sig os.Signal) { switch sig { case syscall.SIGUSR1: } }, WaitForStop: func() { if waitChan != nil { <-waitChan } }, }) } var tplIncludeMatcher = regexp.MustCompile(`{{\s*template\s+"([^"]+)"`) func verifyRegexp(regexpStr string) func(any, *goja.Runtime) bool { if rx, err := regexp.Compile(regexpStr); err != nil { return func(value any, vm *goja.Runtime) bool { return rx.MatchString(u.String(value)) } } return nil } func verifyLen(checkLength int) func(any, *goja.Runtime) bool { return func(value any, vm *goja.Runtime) bool { v := u.FinalType(reflect.ValueOf(value)) if v.Kind() == reflect.Slice || v.Kind() == reflect.Map { return v.Len() >= checkLength } return len(u.String(value)) >= checkLength } } func verifyRequire() func(any, *goja.Runtime) bool { return func(value any, vm *goja.Runtime) bool { value = u.FixPtr(value) switch realValue := value.(type) { case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64: return u.Int64(realValue) != 0 case bool: return realValue default: v := u.FinalType(reflect.ValueOf(value)) if v.Kind() == reflect.Slice || v.Kind() == reflect.Map { return v.Len() >= 0 } return len(u.String(value)) >= 0 } } } func verifyIn(list []string) func(any, *goja.Runtime) bool { return func(value any, vm *goja.Runtime) bool { return u.StringIn(list, u.String(value)) } } func verifyFunc(callback goja.Callable, thisObj goja.Value) func(any, *goja.Runtime) bool { return func(value any, vm *goja.Runtime) bool { if r, err := callback(thisObj, vm.ToValue(value)); err == nil { return u.Bool(r.Export()) } else { log.DefaultLogger.Error(err.Error()) } return false } } func makeRequestParams(args map[string]any, headers map[string]string, request *s.Request, response *s.Response, client *websocket.Conn, caller *discover.Caller, session *Session, logger *log.Logger) (gojs.Map, *Response) { var resp *Response params := gojs.Map{ "args": args, "logger": gojs.MakeLogger(logger), "request": MakeRequest(request, args, headers), "client": MakeWSClient(client, request.Id), } if response != nil { resp = &Response{ resp: response, endCh: make(chan bool, 1), Id: response.Id, } params["response"] = gojs.MakeMap(resp) } if headers != nil { params["headers"] = headers } if session != nil { params["session"] = gojs.MakeMap(session) } if caller != nil { params["caller"] = gojs.MakeMap(&Caller{client: caller}) } return params, resp } func runAction(action goja.Callable, vm *goja.Runtime, thisArg goja.Value, args map[string]any, headers map[string]string, request *s.Request, response *s.Response, client *websocket.Conn, caller *discover.Caller, session *Session, logger *log.Logger) (any, error) { vm.CallbackLocker.Lock() defer vm.CallbackLocker.Unlock() // 验证请求参数的有效性 if verifies, ok := vm.GoData["VERIFY_"+u.String(request.Get("registerTag"))].(map[string]func(any, *goja.Runtime) bool); ok { failedFields := make([]string, 0) // 检查必填字段 if requires, ok := vm.GoData["REQUIRE_"+u.String(request.Get("registerTag"))].([]string); ok { for _, requireField := range requires { if _, ok := args[requireField]; !ok { failedFields = append(failedFields, requireField) } } } // 验证数据有效性 for k, v := range args { if verifier, ok1 := verifies[k]; ok1 { if !verifier(v, vm) { failedFields = append(failedFields, k) } } } // 数据有效性验证失败 if len(failedFields) > 0 { response.WriteHeader(400) msg := serviceConfig.VerifyFieldMessage if strings.Contains(msg, "{{") { msg = strings.ReplaceAll(msg, "{{FAILED_FIELDS}}", strings.Join(failedFields, ", ")) } var obj any if json.Unmarshal([]byte(msg), &obj) == nil { return obj, nil } else { return msg, nil } } } requestParams, resp := makeRequestParams(args, headers, request, response, client, caller, session, logger) var r any r1, err := action(thisArg, vm.ToValue(requestParams)) if err == nil && r1 != nil { r = r1.Export() } if err != nil { logger.Error(err.Error()) } if response != nil && r == nil && err == nil { <-resp.endCh r = resp.result } return r, err } func makeInnerAction(action goja.Callable, vm *goja.Runtime, thisArg goja.Value) any { return func(args map[string]any, headers map[string]string, request *s.Request, response *s.Response, caller *discover.Caller, session *Session, logger *log.Logger) any { r, _ := runAction(action, vm, thisArg, args, headers, request, response, nil, caller, session, logger) return r } } func getPool(startFile string) *gojs.Pool { var pool *gojs.Pool for i := 0; i < 10; i++ { poolsLock.RLock() pool = pools[startFile] poolsLock.RUnlock() if pool != nil { return pool } time.Sleep(time.Millisecond * 100) } return nil } func makeOuterAction(startFile string, actionKey string) any { return func(args map[string]any, headers map[string]string, request *s.Request, response *s.Response, caller *discover.Caller, session *Session, logger *log.Logger) any { if pool := getPool(startFile); pool != nil { rt := pool.Get() defer pool.Put(rt) if action, ok := rt.GetGoData(actionKey).(goja.Callable); ok { var thisArg goja.Value if thisArgV, ok := rt.GetGoData(actionKey + "This").(goja.Value); ok { thisArg = thisArgV } r, _ := rt.RunVM(func(vm *goja.Runtime) (any, error) { r2, err2 := runAction(action, vm, thisArg, args, headers, request, response, nil, caller, session, logger) return r2, err2 }) return r } } return nil } } type PoolStatus struct { Total uint MaxTotal uint MaxWaiting uint CreateTimes uint } func GetPoolStatus() map[string]PoolStatus { out := map[string]PoolStatus{} for k, v := range pools { total, maxTotal, maxWaiting, createTimes := v.Count() out[k] = PoolStatus{ Total: total, MaxTotal: maxTotal, MaxWaiting: maxWaiting, CreateTimes: createTimes, } } return out }