package service import ( _ "embed" "encoding/json" "errors" "fmt" "reflect" "regexp" "strings" "sync" "time" "apigo.cc/gojs" "apigo.cc/gojs/goja" "github.com/gorilla/websocket" "github.com/ssgo/config" "github.com/ssgo/discover" "github.com/ssgo/log" "github.com/ssgo/redis" "github.com/ssgo/s" "github.com/ssgo/standard" "github.com/ssgo/u" ) //go:embed service.ts var serviceTS string //go:embed README.md var serviceMD string var server *s.AsyncServer var pools = map[string]*gojs.Pool{} var poolExists = map[string]bool{} var poolActionRegistered = map[string]bool{} var poolsLock = sync.RWMutex{} var waitChan chan bool type LimiterConfig struct { From string Time int Times int } type Config struct { SessionKey string DeviceKey string ClientKey string UserIdKey string SessionProvider string SessionTimeout int64 AuthFieldMessage string VerifyFieldMessage string LimitedMessage string Limiters map[string]*LimiterConfig LimiterRedis string Proxy map[string]string Rewrite map[string]string Static map[string]string } var serviceConfig Config var onStop goja.Callable var limiters = map[string]*s.Limiter{} func init() { obj := map[string]any{ "config": func(argsIn goja.FunctionCall, vm *goja.Runtime) goja.Value { // 处理配置 args := gojs.MakeArgs(&argsIn, vm) serviceConfig = Config{"Session", "Device", "Client", "userId", "", 3600, "auth failed", "verify failed", "too many requests", nil, "", map[string]string{}, map[string]string{}, map[string]string{}} if errs := config.LoadConfig("service", &serviceConfig); errs != nil && len(errs) > 0 { panic(vm.NewGoError(errs[0])) } config.LoadConfig("service", &discover.Config) // var auth goja.Callable if conf := args.Obj(0); conf != nil { u.Convert(conf.O.Export(), &serviceConfig) u.Convert(conf.O.Export(), &s.Config) u.Convert(conf.O.Export(), &discover.Config) // auth = conf.Func("auth") onStop = conf.Func("onStop") if serviceConfig.SessionProvider != "" { sessionRedis = redis.GetRedis(serviceConfig.SessionProvider, args.Logger) } sessionTimeout = serviceConfig.SessionTimeout if sessionTimeout < 0 { sessionTimeout = 0 } } // 身份验证和Session authAccessToken := len(s.Config.AccessTokens) > 0 s.SetClientKeys(serviceConfig.DeviceKey, serviceConfig.ClientKey, serviceConfig.SessionKey) if serviceConfig.SessionKey != "" { s.SetAuthChecker(func(authLevel int, logger *log.Logger, url *string, args map[string]any, request *s.Request, response *s.Response, options *s.WebServiceOptions) (pass bool, object any) { var session *Session setAuthLevel := 0 if serviceConfig.SessionKey != "" { sessionID := request.GetSessionId() if sessionID != "" { session = NewSession(sessionID, logger) } if userId, ok := session.data[serviceConfig.UserIdKey]; ok { request.SetUserId(u.String(userId)) } // 优先使用session中的authLevel if authLevel > 0 { if authLevelBySession, ok := session.data["_authLevel"]; ok { setAuthLevel = u.Int(authLevelBySession) } } } // 如果没有session中的authLevel验证失败,则使用Access-Token中的authLevel(服务间调用) if authAccessToken && setAuthLevel < authLevel { setAuthLevel = s.GetAuthTokenLevel(request.Header.Get("Access-Token")) } if setAuthLevel >= authLevel { return true, session } else { msg := serviceConfig.AuthFieldMessage if strings.Contains(msg, "{{") { msg = strings.ReplaceAll(msg, "{{TARGET_AUTHLEVEL}}", u.String(authLevel)) msg = strings.ReplaceAll(msg, "{{USER_AUTHLEVEL}}", u.String(setAuthLevel)) } var obj any if json.Unmarshal([]byte(msg), &obj) == nil { return false, obj } else { return false, msg } } }) s.Init() } // 限流器 if serviceConfig.Limiters != nil { var limiterRedis *redis.Redis if serviceConfig.LimiterRedis != "" { limiterRedis = redis.GetRedis(serviceConfig.LimiterRedis, args.Logger) } for name, limiter := range serviceConfig.Limiters { switch limiter.From { case "ip": limiter.From = "header." + standard.DiscoverHeaderClientIp case "user": limiter.From = "header." + standard.DiscoverHeaderUserId case "device": limiter.From = "header." + standard.DiscoverHeaderDeviceId } if limiterRedis != nil { limiters[name] = s.NewLimiter(name, limiter.From, time.Duration(limiter.Time)*time.Millisecond, limiter.Times, limiterRedis) } else { limiters[name] = s.NewLocalLimiter(name, limiter.From, time.Duration(limiter.Time)*time.Millisecond, limiter.Times) } } } return nil }, "start": func(argsIn goja.FunctionCall, vm *goja.Runtime) goja.Value { if server != nil { panic(vm.NewGoError(errors.New("server already started"))) } // 处理静态文件 if len(serviceConfig.Static) > 0 { UpdateStatic(serviceConfig.Static) } if len(serviceConfig.Rewrite) > 0 { UpdateRewrite(serviceConfig.Rewrite) s.SetRewriteBy(rewrite) } if len(serviceConfig.Proxy) > 0 { UpdateProxy(serviceConfig.Proxy) s.SetProxyBy(proxy) } // 启动服务 server = s.AsyncStart() waitChan = make(chan bool, 1) server.OnStop(func() { if onStop != nil { onStop(nil) } }) server.OnStopped(func() { if waitChan != nil { waitChan <- true } }) return vm.ToValue(server.Addr) }, "stop": func(argsIn goja.FunctionCall, vm *goja.Runtime) goja.Value { if server == nil { panic(vm.NewGoError(errors.New("server not started"))) } server.Stop() ClearRewritesAndProxies() pools = map[string]*gojs.Pool{} server = nil return nil }, "register": func(argsIn goja.FunctionCall, vm *goja.Runtime) goja.Value { args := gojs.MakeArgs(&argsIn, vm).Check(2) o := args.Obj(0) action := args.Func(1) if action == nil { panic(vm.NewGoError(errors.New("action must be a callback function"))) } authLevel := o.Int("authLevel") host := o.Str("host") method := strings.ToUpper(o.Str("method")) path := o.Str("path") memo := o.Str("memo") var usedLimiters []*s.Limiter for _, limiterName := range o.Array("limiters") { if limiter1, ok := limiters[u.String(limiterName)]; ok { usedLimiters = append(usedLimiters, limiter1) } } if verifiesObj := o.Obj("verifies"); verifiesObj != nil { verifiesSet := map[string]func(any, *goja.Runtime) bool{} for _, field := range verifiesObj.Keys() { v := verifiesObj.Get(field) // 根据类型设置验证器,存储到vm中,在请求到达时进行有效性验证 switch v.ExportType().Kind() { case reflect.String: verifiesSet[field] = verifyRegexp(u.String(v.Export())) case reflect.Int, reflect.Int64: verifiesSet[field] = verifyLen(u.Int(v.Export())) case reflect.Bool: verifiesSet[field] = verifyRequire() case reflect.Slice: list := make([]string, 0) vm.ForOf(v, func(v1 goja.Value) bool { list = append(list, u.String(v1.Export())) return true }) verifiesSet[field] = verifyIn(list) default: if fn, ok := goja.AssertFunction(v); ok { verifiesSet[field] = verifyFunc(fn, verifiesObj) } else if obj := v.ToObject(vm); obj != nil { // 支持传入js的正则表达式对象 if testV := obj.Get("test"); testV != nil { if testFn, ok := goja.AssertFunction(testV); ok { verifiesSet[field] = verifyFunc(testFn, obj) } } } } } vm.GoData[fmt.Sprint("VERIFY_"+host, method, path)] = verifiesSet } opt := s.WebServiceOptions{ NoBody: o.Bool("noBody"), NoLog200: o.Bool("noLog200"), Host: host, //Ext: nil, Limiters: usedLimiters, } startFile := u.String(vm.GoData["startFile"]) poolsLock.RLock() poolExist := poolExists[startFile] poolsLock.RUnlock() if poolExist { // 从对象调用(支持并发) actionKey := "REGISTER_" + host + method + path vm.GoData[actionKey] = action vm.GoData[actionKey+"This"] = args.This if method == "WS" { vm.GoData[actionKey+"onMessage"] = o.Func("onMessage") vm.GoData[actionKey+"onClose"] = o.Func("onClose") } poolsLock.Lock() actionRegistered := poolActionRegistered[actionKey] if !actionRegistered { poolActionRegistered[actionKey] = true } poolsLock.Unlock() if !actionRegistered { if strings.ToUpper(method) == "WS" { s.RegisterWebsocketWithOptions(authLevel, path, &websocket.Upgrader{}, makeWSAction(startFile, actionKey), nil, nil, nil, true, "", opt) } else { s.RestfulWithOptions(authLevel, method, path, makeOuterAction(startFile, actionKey), memo, opt) } } } else { // 无对象池,直接调用(单线程) s.RestfulWithOptions(authLevel, method, path, makeInnerAction(action, vm, args.This), memo, opt) } return nil }, "load": func(argsIn goja.FunctionCall, vm *goja.Runtime) goja.Value { args := gojs.MakeArgs(&argsIn, vm).Check(1) actionFile := args.Path(0) opt := args.Obj(1) var mi, ma, idle uint //var num uint mainArgs := make([]any, 0) debug := false if opt != nil { mi = uint(opt.Int("min")) ma = uint(opt.Int("max")) idle = uint(opt.Int("idle")) //num = uint(opt.Int("num")) debug = opt.Bool("debug") if mainArgs1 := opt.Array("args"); mainArgs1 != nil { mainArgs = mainArgs1 } } if !u.FileExists(actionFile) { panic(vm.NewGoError(errors.New("actionFile must be a js file path"))) } actionCode := u.ReadFileN(actionFile) if !strings.Contains(actionCode, "function main(") || !strings.Contains(actionCode, ".register(") { panic(vm.NewGoError(errors.New("actionFile must be a js file with main function and call service.register"))) } poolsLock.Lock() poolExists[actionFile] = true poolsLock.Unlock() p := gojs.NewPoolByCode(actionCode, actionFile, gojs.PoolConfig{ Min: mi, Max: ma, Idle: idle, Debug: debug, Args: mainArgs, }, args.Logger) //p := gojs.NewLBByCode(actionCode, actionFile, gojs.LBConfig{ // Num: num, // Debug: debug, // Args: mainArgs, //}, args.Logger) poolsLock.Lock() pools[actionFile] = p poolsLock.Unlock() return nil }, "task": func(argsIn goja.FunctionCall, vm *goja.Runtime) goja.Value { args := gojs.MakeArgs(&argsIn, vm).Check(1) taskFile := args.Path(0) interval := args.Int(1) if interval == 0 { interval = 1000 } if interval < 100 { interval = 100 } if !u.FileExists(taskFile) { panic(vm.NewGoError(errors.New("taskFile must be a js file path"))) } rt := gojs.New() _, err := rt.RunFile(taskFile) if err != nil { panic(vm.NewGoError(err)) } println(u.BMagenta("taskFile: "), taskFile, interval) s.NewTimerServer(taskFile, time.Duration(interval)*time.Millisecond, func(isRunning *bool) { rt.RunCode("if(onRun)onRun()") }, func() { rt.RunCode("if(onStart)onStart()") }, func() { rt.RunCode("if(onStop)onStop()") }) return nil }, "dataSet": DataSet, "dataGet": DataGet, "dataKeys": DataKeys, "dataCount": DataCount, "dataFetch": DataFetch, "dataRemove": DataRemove, "listPop": ListPop, "listPush": ListPush, "listCount": ListCount, "listRemove": ListRemove, "newCaller": NewCaller, } gojs.Register("apigo.cc/gojs/service", gojs.Module{ Object: obj, TsCode: serviceTS, Example: serviceMD, WaitForStop: func() { if waitChan != nil { <-waitChan } }, }) } func verifyRegexp(regexpStr string) func(any, *goja.Runtime) bool { if rx, err := regexp.Compile(regexpStr); err != nil { return func(value any, vm *goja.Runtime) bool { return rx.MatchString(u.String(value)) } } return nil } func verifyLen(checkLength int) func(any, *goja.Runtime) bool { return func(value any, vm *goja.Runtime) bool { v := u.FinalType(reflect.ValueOf(value)) if v.Kind() == reflect.Slice || v.Kind() == reflect.Map { return v.Len() >= checkLength } return len(u.String(value)) >= checkLength } } func verifyRequire() func(any, *goja.Runtime) bool { return func(value any, vm *goja.Runtime) bool { value = u.FixPtr(value) switch realValue := value.(type) { case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64: return u.Int64(realValue) != 0 case bool: return realValue default: v := u.FinalType(reflect.ValueOf(value)) if v.Kind() == reflect.Slice || v.Kind() == reflect.Map { return v.Len() >= 0 } return len(u.String(value)) >= 0 } } } func verifyIn(list []string) func(any, *goja.Runtime) bool { return func(value any, vm *goja.Runtime) bool { return u.StringIn(list, u.String(value)) } } func verifyFunc(callback goja.Callable, thisObj goja.Value) func(any, *goja.Runtime) bool { return func(value any, vm *goja.Runtime) bool { if r, err := callback(thisObj, vm.ToValue(value)); err == nil { return u.Bool(r.Export()) } else { log.DefaultLogger.Error(err.Error()) } return false } } func makeRequestParams(args map[string]any, headers map[string]string, request *s.Request, response *s.Response, client *websocket.Conn, caller *discover.Caller, session *Session, logger *log.Logger) (gojs.Map, *Response) { var resp *Response params := gojs.Map{ "args": args, "logger": gojs.MakeLogger(logger), "request": MakeRequest(request, args, headers), "client": MakeWSClient(client, request.Id), } if response != nil { resp = &Response{ resp: response, endCh: make(chan bool, 1), Id: response.Id, } params["response"] = gojs.MakeMap(resp) } if headers != nil { params["headers"] = headers } if session != nil { params["session"] = gojs.MakeMap(session) } if caller != nil { params["caller"] = gojs.MakeMap(&Caller{client: caller}) } return params, resp } func runAction(action goja.Callable, vm *goja.Runtime, thisArg goja.Value, args map[string]any, headers map[string]string, request *s.Request, response *s.Response, client *websocket.Conn, caller *discover.Caller, session *Session, logger *log.Logger) (any, error) { vm.CallbackLocker.Lock() defer vm.CallbackLocker.Unlock() // 验证请求参数的有效性 if verifies, ok := vm.GoData["VERIFY_"+u.String(request.Get("registerTag"))].(map[string]func(any, *goja.Runtime) bool); ok { failedFields := make([]string, 0) for k, v := range args { if verifier, ok1 := verifies[k]; ok1 { if !verifier(v, vm) { failedFields = append(failedFields, k) } } } // 数据有效性验证失败 if len(failedFields) > 0 { response.WriteHeader(400) msg := serviceConfig.VerifyFieldMessage if strings.Contains(msg, "{{") { msg = strings.ReplaceAll(msg, "{{FAILED_FIELDS}}", strings.Join(failedFields, ", ")) } var obj any if json.Unmarshal([]byte(msg), &obj) == nil { return obj, nil } else { return msg, nil } } } requestParams, resp := makeRequestParams(args, headers, request, response, client, caller, session, logger) var r any r1, err := action(thisArg, vm.ToValue(requestParams)) if err == nil && r1 != nil { r = r1.Export() } if err != nil { logger.Error(err.Error()) } if response != nil && r == nil && err == nil { <-resp.endCh r = resp.result } return r, err } func makeInnerAction(action goja.Callable, vm *goja.Runtime, thisArg goja.Value) any { return func(args map[string]any, headers map[string]string, request *s.Request, response *s.Response, caller *discover.Caller, session *Session, logger *log.Logger) any { r, _ := runAction(action, vm, thisArg, args, headers, request, response, nil, caller, session, logger) return r } } func getPool(startFile string) *gojs.Pool { var pool *gojs.Pool for i := 0; i < 10; i++ { poolsLock.RLock() pool = pools[startFile] poolsLock.RUnlock() if pool != nil { return pool } time.Sleep(time.Millisecond * 100) } return nil } func makeOuterAction(startFile string, actionKey string) any { return func(args map[string]any, headers map[string]string, request *s.Request, response *s.Response, caller *discover.Caller, session *Session, logger *log.Logger) any { if pool := getPool(startFile); pool != nil { rt := pool.Get() defer pool.Put(rt) if action, ok := rt.GetGoData(actionKey).(goja.Callable); ok { var thisArg goja.Value if thisArgV, ok := rt.GetGoData(actionKey + "This").(goja.Value); ok { thisArg = thisArgV } r, _ := rt.RunVM(func(vm *goja.Runtime) (any, error) { r2, err2 := runAction(action, vm, thisArg, args, headers, request, response, nil, caller, session, logger) return r2, err2 }) return r } } return nil } } type PoolStatus struct { Total uint MaxTotal uint MaxWaiting uint CreateTimes uint } func GetPoolStatus() map[string]PoolStatus { out := map[string]PoolStatus{} for k, v := range pools { total, maxTotal, maxWaiting, createTimes := v.Count() out[k] = PoolStatus{ Total: total, MaxTotal: maxTotal, MaxWaiting: maxWaiting, CreateTimes: createTimes, } } return out }