service/service.go
2024-10-18 17:54:37 +08:00

593 lines
17 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package service
import (
_ "embed"
"encoding/json"
"errors"
"fmt"
"reflect"
"regexp"
"strings"
"sync"
"time"
"apigo.cc/gojs"
"apigo.cc/gojs/goja"
"github.com/gorilla/websocket"
"github.com/ssgo/config"
"github.com/ssgo/discover"
"github.com/ssgo/log"
"github.com/ssgo/redis"
"github.com/ssgo/s"
"github.com/ssgo/standard"
"github.com/ssgo/u"
)
// serviceTS is the TypeScript declaration of this module, embedded from
// service.ts and handed to gojs.Register as TsCode.
//go:embed service.ts
var serviceTS string
// serviceMD is the module README, embedded and handed to gojs.Register as
// the module Example.
//go:embed README.md
var serviceMD string
// server is the running async server; nil before "start" and after "stop".
var server *s.AsyncServer
// pools maps an action file path (from "load") to its runtime pool.
// Accessed under poolsLock in most places.
var pools = map[string]*gojs.Pool{}
// poolExists marks action files passed to "load". "register" checks it to
// decide between pooled dispatch and direct single-VM dispatch.
var poolExists = map[string]bool{}
// poolActionRegistered records REGISTER_ action keys whose route has already
// been added. Presumably each pooled runtime re-runs the file's main and
// calls register again, so the route must be added only once.
var poolActionRegistered = map[string]bool{}
// poolsLock guards pools, poolExists and poolActionRegistered.
var poolsLock = sync.RWMutex{}
// waitChan receives a value once the server has stopped; the module's
// WaitForStop hook blocks on it.
var waitChan chan bool
// LimiterConfig describes one named rate limiter from the "service" config.
type LimiterConfig struct {
// From selects the request value to limit by. "ip", "user" and "device" are
// shorthands for the corresponding discover headers; any other value is
// passed to the limiter unchanged.
From string
// Time is the limiter window, in milliseconds.
Time int
// Times is passed to the limiter as its count; presumably the maximum
// number of requests allowed per window (TODO confirm against s.NewLimiter).
Times int
}
// Config is the "service" configuration section, optionally overlaid by the
// object passed to the JS "config" call.
//
// NOTE: "config" initializes this struct with a positional literal, so the
// field order must not change.
type Config struct {
// SessionKey names the session id key (default "Session"); when empty, no
// auth checker is installed.
SessionKey string
// DeviceKey names the device id key (default "Device").
DeviceKey string
// ClientKey names the client id key (default "Client").
ClientKey string
// UserIdKey is the session field copied into the request's user id
// (default "userId").
UserIdKey string
// SessionProvider is the redis name used for session storage; empty by default.
SessionProvider string
// SessionTimeout is copied into sessionTimeout and clamped to >= 0
// (default 3600, presumably seconds).
SessionTimeout int64
// AuthFieldMessage is returned on auth failure (default "auth failed");
// {{TARGET_AUTHLEVEL}} and {{USER_AUTHLEVEL}} are substituted, and text that
// parses as JSON is returned as structured data.
AuthFieldMessage string
// VerifyFieldMessage is returned on field-verification failure
// (default "verify failed"); {{FAILED_FIELDS}} is substituted.
VerifyFieldMessage string
// LimitedMessage defaults to "too many requests"; not read in this chunk,
// presumably used when a limiter rejects a request.
LimitedMessage string
// Limiters defines the named rate limiters, by name.
Limiters map[string]*LimiterConfig
// LimiterRedis is the redis name for shared limiters; local limiters are
// used when empty.
LimiterRedis string
// Proxy holds the proxy rules applied on "start".
Proxy map[string]string
// Rewrite holds the rewrite rules applied on "start".
Rewrite map[string]string
// Static holds the static-file mappings applied on "start".
Static map[string]string
}
// serviceConfig is the effective configuration built by the JS "config" call.
var serviceConfig Config
// onStop is the optional onStop callback from the config object; it is
// invoked from the server's OnStop hook.
var onStop goja.Callable
// limiters holds the limiters built from serviceConfig.Limiters, keyed by
// name; "register" attaches them to routes by name.
var limiters = map[string]*s.Limiter{}
// init registers the "apigo.cc/gojs/service" module with gojs. The module
// object exposes config/start/stop/register/load/task to scripts, plus the
// data*/list* helpers and newCaller defined elsewhere in this package.
// WaitForStop blocks until a started server reports that it has stopped.
func init() {
	obj := map[string]any{
		// config([conf]) loads the "service" config section (and discover's),
		// overlays the optional JS object, then installs the session/auth
		// checker and builds the named rate limiters.
		"config": func(argsIn goja.FunctionCall, vm *goja.Runtime) goja.Value {
			args := gojs.MakeArgs(&argsIn, vm)
			// Defaults. The literal is positional, so its order must match the Config fields.
			serviceConfig = Config{"Session", "Device", "Client", "userId", "", 3600, "auth failed", "verify failed", "too many requests", nil, "", map[string]string{}, map[string]string{}, map[string]string{}}
			if errs := config.LoadConfig("service", &serviceConfig); len(errs) > 0 {
				panic(vm.NewGoError(errs[0]))
			}
			config.LoadConfig("service", &discover.Config)
			// var auth goja.Callable
			if conf := args.Obj(0); conf != nil {
				// The same JS object configures this module, the s server and discover.
				u.Convert(conf.O.Export(), &serviceConfig)
				u.Convert(conf.O.Export(), &s.Config)
				u.Convert(conf.O.Export(), &discover.Config)
				// auth = conf.Func("auth")
				onStop = conf.Func("onStop")
				if serviceConfig.SessionProvider != "" {
					sessionRedis = redis.GetRedis(serviceConfig.SessionProvider, args.Logger)
				}
				sessionTimeout = serviceConfig.SessionTimeout
				if sessionTimeout < 0 {
					sessionTimeout = 0
				}
			}

			// Authentication and sessions.
			authAccessToken := len(s.Config.AccessTokens) > 0
			s.SetClientKeys(serviceConfig.DeviceKey, serviceConfig.ClientKey, serviceConfig.SessionKey)
			if serviceConfig.SessionKey != "" {
				s.SetAuthChecker(func(authLevel int, logger *log.Logger, url *string, args map[string]any, request *s.Request, response *s.Response, options *s.WebServiceOptions) (pass bool, object any) {
					var session *Session
					setAuthLevel := 0
					if serviceConfig.SessionKey != "" {
						if sessionID := request.GetSessionId(); sessionID != "" {
							session = NewSession(sessionID, logger)
						}
						// A request without a session id has no session; guard
						// before touching session.data to avoid a nil dereference.
						if session != nil {
							if userID, ok := session.data[serviceConfig.UserIdKey]; ok {
								request.SetUserId(u.String(userID))
							}
							// Prefer the auth level stored in the session.
							if authLevel > 0 {
								if authLevelBySession, ok := session.data["_authLevel"]; ok {
									setAuthLevel = u.Int(authLevelBySession)
								}
							}
						}
					}
					// If the session level is insufficient, fall back to the
					// level granted by the Access-Token header (service-to-service calls).
					if authAccessToken && setAuthLevel < authLevel {
						setAuthLevel = s.GetAuthTokenLevel(request.Header.Get("Access-Token"))
					}
					if setAuthLevel >= authLevel {
						return true, session
					}
					msg := serviceConfig.AuthFieldMessage
					if strings.Contains(msg, "{{") {
						msg = strings.ReplaceAll(msg, "{{TARGET_AUTHLEVEL}}", u.String(authLevel))
						msg = strings.ReplaceAll(msg, "{{USER_AUTHLEVEL}}", u.String(setAuthLevel))
					}
					// A message that parses as JSON is returned as structured data.
					var obj any
					if json.Unmarshal([]byte(msg), &obj) == nil {
						return false, obj
					}
					return false, msg
				})
				s.Init()
			}

			// Rate limiters.
			if serviceConfig.Limiters != nil {
				var limiterRedis *redis.Redis
				if serviceConfig.LimiterRedis != "" {
					limiterRedis = redis.GetRedis(serviceConfig.LimiterRedis, args.Logger)
				}
				for name, limiter := range serviceConfig.Limiters {
					if limiter == nil {
						// e.g. a null entry in the config: nothing to build.
						continue
					}
					// Shorthands map onto the discover request headers.
					switch limiter.From {
					case "ip":
						limiter.From = "header." + standard.DiscoverHeaderClientIp
					case "user":
						limiter.From = "header." + standard.DiscoverHeaderUserId
					case "device":
						limiter.From = "header." + standard.DiscoverHeaderDeviceId
					}
					window := time.Duration(limiter.Time) * time.Millisecond
					if limiterRedis != nil {
						limiters[name] = s.NewLimiter(name, limiter.From, window, limiter.Times, limiterRedis)
					} else {
						limiters[name] = s.NewLocalLimiter(name, limiter.From, window, limiter.Times)
					}
				}
			}
			return nil
		},
		// start() applies static/rewrite/proxy rules, starts the async server
		// and returns its address. It panics if a server is already running.
		"start": func(argsIn goja.FunctionCall, vm *goja.Runtime) goja.Value {
			if server != nil {
				panic(vm.NewGoError(errors.New("server already started")))
			}
			// Static files, rewrites and proxies.
			if len(serviceConfig.Static) > 0 {
				UpdateStatic(serviceConfig.Static)
			}
			if len(serviceConfig.Rewrite) > 0 {
				UpdateRewrite(serviceConfig.Rewrite)
				s.SetRewriteBy(rewrite)
			}
			if len(serviceConfig.Proxy) > 0 {
				UpdateProxy(serviceConfig.Proxy)
				s.SetProxyBy(proxy)
			}
			// Start the server; waitChan lets WaitForStop block until it has stopped.
			server = s.AsyncStart()
			waitChan = make(chan bool, 1)
			server.OnStop(func() {
				if onStop != nil {
					onStop(nil)
				}
			})
			server.OnStopped(func() {
				if waitChan != nil {
					waitChan <- true
				}
			})
			return vm.ToValue(server.Addr)
		},
		// stop() stops the server, clears rewrites/proxies and drops all
		// runtime pools. It panics if no server is running.
		"stop": func(argsIn goja.FunctionCall, vm *goja.Runtime) goja.Value {
			if server == nil {
				panic(vm.NewGoError(errors.New("server not started")))
			}
			server.Stop()
			ClearRewritesAndProxies()
			// pools is read concurrently (getPool, load); swap it under the lock.
			poolsLock.Lock()
			pools = map[string]*gojs.Pool{}
			poolsLock.Unlock()
			server = nil
			return nil
		},
		// register(options, action) routes a handler. Options may carry
		// authLevel, host, method, path, memo, limiters, verifies, noBody and
		// noLog200 (and onMessage/onClose for method "WS").
		"register": func(argsIn goja.FunctionCall, vm *goja.Runtime) goja.Value {
			args := gojs.MakeArgs(&argsIn, vm).Check(2)
			o := args.Obj(0)
			action := args.Func(1)
			if action == nil {
				panic(vm.NewGoError(errors.New("action must be a callback function")))
			}
			authLevel := o.Int("authLevel")
			host := o.Str("host")
			method := strings.ToUpper(o.Str("method"))
			path := o.Str("path")
			memo := o.Str("memo")
			var usedLimiters []*s.Limiter
			for _, limiterName := range o.Array("limiters") {
				if limiter1, ok := limiters[u.String(limiterName)]; ok {
					usedLimiters = append(usedLimiters, limiter1)
				}
			}
			if verifiesObj := o.Obj("verifies"); verifiesObj != nil {
				verifiesSet := map[string]func(any, *goja.Runtime) bool{}
				for _, field := range verifiesObj.Keys() {
					v := verifiesObj.Get(field)
					if v == nil {
						continue
					}
					// null/undefined export no type (ExportType returns nil);
					// skip them rather than panicking on Kind().
					exportType := v.ExportType()
					if exportType == nil {
						continue
					}
					// Pick a verifier by the JS value's type. The set is stored in
					// vm.GoData and applied by runAction when a request arrives.
					switch exportType.Kind() {
					case reflect.String:
						// Never store a nil verifier: runAction would call it.
						if fn := verifyRegexp(u.String(v.Export())); fn != nil {
							verifiesSet[field] = fn
						}
					case reflect.Int, reflect.Int64:
						verifiesSet[field] = verifyLen(u.Int(v.Export()))
					case reflect.Bool:
						verifiesSet[field] = verifyRequire()
					case reflect.Slice:
						list := make([]string, 0)
						vm.ForOf(v, func(v1 goja.Value) bool {
							list = append(list, u.String(v1.Export()))
							return true
						})
						verifiesSet[field] = verifyIn(list)
					default:
						if fn, ok := goja.AssertFunction(v); ok {
							verifiesSet[field] = verifyFunc(fn, verifiesObj)
						} else if obj := v.ToObject(vm); obj != nil {
							// Accept JS RegExp objects (anything with a callable test method).
							if testV := obj.Get("test"); testV != nil {
								if testFn, ok := goja.AssertFunction(testV); ok {
									verifiesSet[field] = verifyFunc(testFn, obj)
								}
							}
						}
					}
				}
				vm.GoData[fmt.Sprint("VERIFY_"+host, method, path)] = verifiesSet
			}
			opt := s.WebServiceOptions{
				NoBody:   o.Bool("noBody"),
				NoLog200: o.Bool("noLog200"),
				Host:     host,
				//Ext: nil,
				Limiters: usedLimiters,
			}
			startFile := u.String(vm.GoData["startFile"])
			poolsLock.RLock()
			poolExist := poolExists[startFile]
			poolsLock.RUnlock()
			if poolExist {
				// Pooled dispatch (concurrent): each runtime stores its own
				// action under actionKey; the route itself is added once.
				actionKey := "REGISTER_" + host + method + path
				vm.GoData[actionKey] = action
				vm.GoData[actionKey+"This"] = args.This
				if method == "WS" {
					vm.GoData[actionKey+"onMessage"] = o.Func("onMessage")
					vm.GoData[actionKey+"onClose"] = o.Func("onClose")
				}
				poolsLock.Lock()
				actionRegistered := poolActionRegistered[actionKey]
				if !actionRegistered {
					poolActionRegistered[actionKey] = true
				}
				poolsLock.Unlock()
				if !actionRegistered {
					if method == "WS" {
						s.RegisterWebsocketWithOptions(authLevel, path, &websocket.Upgrader{}, makeWSAction(startFile, actionKey), nil, nil, nil, true, "", opt)
					} else {
						s.RestfulWithOptions(authLevel, method, path, makeOuterAction(startFile, actionKey), memo, opt)
					}
				}
			} else {
				// No pool: call directly on this vm (single-threaded).
				s.RestfulWithOptions(authLevel, method, path, makeInnerAction(action, vm, args.This), memo, opt)
			}
			return nil
		},
		// load(file[, opt]) creates a runtime pool for an action file that
		// defines main() and calls service.register; opt may set
		// min/max/idle/debug/args.
		"load": func(argsIn goja.FunctionCall, vm *goja.Runtime) goja.Value {
			args := gojs.MakeArgs(&argsIn, vm).Check(1)
			actionFile := args.Path(0)
			opt := args.Obj(1)
			var mi, ma, idle uint
			//var num uint
			mainArgs := make([]any, 0)
			debug := false
			if opt != nil {
				mi = uint(opt.Int("min"))
				ma = uint(opt.Int("max"))
				idle = uint(opt.Int("idle"))
				//num = uint(opt.Int("num"))
				debug = opt.Bool("debug")
				if mainArgs1 := opt.Array("args"); mainArgs1 != nil {
					mainArgs = mainArgs1
				}
			}
			if !u.FileExists(actionFile) {
				panic(vm.NewGoError(errors.New("actionFile must be a js file path")))
			}
			actionCode := u.ReadFileN(actionFile)
			// Cheap textual sanity check, not a parse.
			if !strings.Contains(actionCode, "function main(") || !strings.Contains(actionCode, ".register(") {
				panic(vm.NewGoError(errors.New("actionFile must be a js file with main function and call service.register")))
			}
			// Mark the file as pooled before the pool exists, so register()
			// calls take the pooled path (presumably main runs while the pool
			// initializes); getPool polls until the pool below is stored.
			poolsLock.Lock()
			poolExists[actionFile] = true
			poolsLock.Unlock()
			p := gojs.NewPoolByCode(actionCode, actionFile, gojs.PoolConfig{
				Min:   mi,
				Max:   ma,
				Idle:  idle,
				Debug: debug,
				Args:  mainArgs,
			}, args.Logger)
			//p := gojs.NewLBByCode(actionCode, actionFile, gojs.LBConfig{
			//	Num:   num,
			//	Debug: debug,
			//	Args:  mainArgs,
			//}, args.Logger)
			poolsLock.Lock()
			pools[actionFile] = p
			poolsLock.Unlock()
			return nil
		},
		// task(file[, intervalMs]) runs a script once, then drives its
		// onStart/onRun/onStop globals from a timer server. The interval
		// defaults to 1000ms and is raised to at least 100ms.
		"task": func(argsIn goja.FunctionCall, vm *goja.Runtime) goja.Value {
			args := gojs.MakeArgs(&argsIn, vm).Check(1)
			taskFile := args.Path(0)
			interval := args.Int(1)
			if interval == 0 {
				interval = 1000
			}
			if interval < 100 {
				interval = 100
			}
			if !u.FileExists(taskFile) {
				panic(vm.NewGoError(errors.New("taskFile must be a js file path")))
			}
			rt := gojs.New()
			_, err := rt.RunFile(taskFile)
			if err != nil {
				panic(vm.NewGoError(err))
			}
			println(u.BMagenta("taskFile: "), taskFile, interval)
			s.NewTimerServer(taskFile, time.Duration(interval)*time.Millisecond, func(isRunning *bool) {
				rt.RunCode("if(onRun)onRun()")
			}, func() {
				rt.RunCode("if(onStart)onStart()")
			}, func() {
				rt.RunCode("if(onStop)onStop()")
			})
			return nil
		},
		"dataSet":    DataSet,
		"dataGet":    DataGet,
		"dataKeys":   DataKeys,
		"dataCount":  DataCount,
		"dataFetch":  DataFetch,
		"dataRemove": DataRemove,
		"listPop":    ListPop,
		"listPush":   ListPush,
		"listCount":  ListCount,
		"listRemove": ListRemove,
		"newCaller":  NewCaller,
	}
	gojs.Register("apigo.cc/gojs/service", gojs.Module{
		Object:  obj,
		TsCode:  serviceTS,
		Example: serviceMD,
		WaitForStop: func() {
			if waitChan != nil {
				<-waitChan
			}
		},
	})
}
// verifyRegexp builds a verifier that passes when the value's u.String form
// matches regexpStr.
//
// The pattern is compiled once, up front. An invalid pattern is logged and
// yields a verifier that rejects every value. A malformed rule therefore
// fails closed, instead of panicking on a nil *Regexp or silently disabling
// validation.
func verifyRegexp(regexpStr string) func(any, *goja.Runtime) bool {
	rx, err := regexp.Compile(regexpStr)
	if err != nil {
		log.DefaultLogger.Error(err.Error())
		return func(value any, vm *goja.Runtime) bool {
			return false
		}
	}
	return func(value any, vm *goja.Runtime) bool {
		return rx.MatchString(u.String(value))
	}
}
// verifyLen builds a minimum-length verifier. Slices and maps (after
// u.FinalType unwrapping) pass when they hold at least checkLength elements.
// Any other value passes when its u.String form is at least checkLength
// bytes long.
func verifyLen(checkLength int) func(any, *goja.Runtime) bool {
	return func(value any, vm *goja.Runtime) bool {
		var size int
		switch final := u.FinalType(reflect.ValueOf(value)); final.Kind() {
		case reflect.Slice, reflect.Map:
			size = final.Len()
		default:
			size = len(u.String(value))
		}
		return size >= checkLength
	}
}
// verifyRequire builds a "required" verifier:
//   - integers must be non-zero;
//   - floats must be non-zero (compared directly, so 0.5 counts as present);
//   - booleans must be true;
//   - slices and maps must be non-empty;
//   - anything else must have a non-empty u.String form.
//
// The previous length checks used ">= 0", which is always true, so empty
// values were never rejected.
func verifyRequire() func(any, *goja.Runtime) bool {
	return func(value any, vm *goja.Runtime) bool {
		value = u.FixPtr(value)
		switch realValue := value.(type) {
		case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
			return u.Int64(realValue) != 0
		case float32:
			return realValue != 0
		case float64:
			return realValue != 0
		case bool:
			return realValue
		default:
			v := u.FinalType(reflect.ValueOf(value))
			if v.Kind() == reflect.Slice || v.Kind() == reflect.Map {
				return v.Len() > 0
			}
			return u.String(value) != ""
		}
	}
}
// verifyIn builds an enumeration verifier: the value's u.String form must
// equal one of the entries in list.
func verifyIn(list []string) func(any, *goja.Runtime) bool {
	allowed := list
	check := func(value any, _ *goja.Runtime) bool {
		return u.StringIn(allowed, u.String(value))
	}
	return check
}
// verifyFunc wraps a JS callback as a verifier. The callback is invoked with
// thisObj as `this` and the value as its only argument, and its result is
// interpreted with u.Bool. A JS error is logged and treated as a failed
// verification.
func verifyFunc(callback goja.Callable, thisObj goja.Value) func(any, *goja.Runtime) bool {
	return func(value any, vm *goja.Runtime) bool {
		result, callErr := callback(thisObj, vm.ToValue(value))
		if callErr != nil {
			log.DefaultLogger.Error(callErr.Error())
			return false
		}
		return u.Bool(result.Export())
	}
}
// makeRequestParams builds the single argument object passed to a JS
// handler.
//
// It always contains args, logger, request and client. It contains response
// only when an HTTP response exists; that wrapper is also returned, so the
// caller can wait on its endCh. It contains headers, session and caller only
// when they are non-nil.
func makeRequestParams(args map[string]any, headers map[string]string, request *s.Request, response *s.Response, client *websocket.Conn, caller *discover.Caller, session *Session, logger *log.Logger) (gojs.Map, *Response) {
	params := gojs.Map{}
	params["args"] = args
	params["logger"] = gojs.MakeLogger(logger)
	params["request"] = MakeRequest(request, args, headers)
	params["client"] = MakeWSClient(client, request.Id)

	var wrapped *Response
	if response != nil {
		wrapped = &Response{
			resp:  response,
			endCh: make(chan bool, 1),
			Id:    response.Id,
		}
		params["response"] = gojs.MakeMap(wrapped)
	}
	if headers != nil {
		params["headers"] = headers
	}
	if session != nil {
		params["session"] = gojs.MakeMap(session)
	}
	if caller != nil {
		params["caller"] = gojs.MakeMap(&Caller{client: caller})
	}
	return params, wrapped
}
// runAction invokes a registered JS handler on vm, holding
// vm.CallbackLocker for the whole call.
//
// Verification: before the call, it applies the field verifiers stored
// under "VERIFY_"+registerTag in vm.GoData to the request args. Only fields
// present in args are checked. If any fail, it writes HTTP 400 (when there
// is a response) and returns serviceConfig.VerifyFieldMessage with
// {{FAILED_FIELDS}} substituted, parsed as JSON when possible.
//
// Result: the handler's result is exported to Go. If the exported result is
// nil, there is an HTTP response, and the call did not error, runAction
// blocks on the response's endCh and returns the result the script stored
// there. NOTE(review): a handler that neither returns a value nor ends the
// response blocks here indefinitely while holding CallbackLocker.
//
// A handler error is logged and also returned.
func runAction(action goja.Callable, vm *goja.Runtime, thisArg goja.Value, args map[string]any, headers map[string]string, request *s.Request, response *s.Response, client *websocket.Conn, caller *discover.Caller, session *Session, logger *log.Logger) (any, error) {
	vm.CallbackLocker.Lock()
	defer vm.CallbackLocker.Unlock()

	// Validate request args against the verifiers registered for this route.
	if verifies, ok := vm.GoData["VERIFY_"+u.String(request.Get("registerTag"))].(map[string]func(any, *goja.Runtime) bool); ok {
		failedFields := make([]string, 0)
		for k, v := range args {
			// A nil verifier is skipped instead of being called.
			if verifier, ok1 := verifies[k]; ok1 && verifier != nil && !verifier(v, vm) {
				failedFields = append(failedFields, k)
			}
		}
		if len(failedFields) > 0 {
			// Callers such as the WebSocket path may pass no HTTP response.
			if response != nil {
				response.WriteHeader(400)
			}
			msg := serviceConfig.VerifyFieldMessage
			if strings.Contains(msg, "{{") {
				msg = strings.ReplaceAll(msg, "{{FAILED_FIELDS}}", strings.Join(failedFields, ", "))
			}
			var obj any
			if json.Unmarshal([]byte(msg), &obj) == nil {
				return obj, nil
			}
			return msg, nil
		}
	}

	requestParams, resp := makeRequestParams(args, headers, request, response, client, caller, session, logger)
	var r any
	r1, err := action(thisArg, vm.ToValue(requestParams))
	if err != nil {
		logger.Error(err.Error())
	} else if r1 != nil {
		r = r1.Export()
	}
	// No synchronous result: wait for the script to end the response.
	if resp != nil && r == nil && err == nil {
		<-resp.endCh
		r = resp.result
	}
	return r, err
}
// makeInnerAction adapts a JS handler into an s web-service handler that
// runs directly on the registering vm. runAction serializes calls on that
// vm, so execution is effectively single-threaded. runAction already logs
// handler errors, so the error is dropped here and only the result is
// returned.
func makeInnerAction(action goja.Callable, vm *goja.Runtime, thisArg goja.Value) any {
	handler := func(args map[string]any, headers map[string]string, request *s.Request, response *s.Response, caller *discover.Caller, session *Session, logger *log.Logger) any {
		result, _ := runAction(action, vm, thisArg, args, headers, request, response, nil, caller, session, logger)
		return result
	}
	return handler
}
// getPool returns the runtime pool for startFile. It polls up to 10 times,
// sleeping 100ms after each miss, because "load" marks a file as pooled
// before the pool itself is stored. It returns nil if the pool never
// appears.
func getPool(startFile string) *gojs.Pool {
	lookup := func() *gojs.Pool {
		poolsLock.RLock()
		defer poolsLock.RUnlock()
		return pools[startFile]
	}
	for remaining := 10; remaining > 0; remaining-- {
		if found := lookup(); found != nil {
			return found
		}
		time.Sleep(100 * time.Millisecond)
	}
	return nil
}
// makeOuterAction builds an s web-service handler for pooled dispatch.
// For each request it borrows a runtime from startFile's pool and looks up
// the handler (and its `this`) stored under actionKey in that runtime. It
// then runs the handler via runAction and returns the runtime to the pool.
// It returns nil if the pool or handler is missing.
func makeOuterAction(startFile string, actionKey string) any {
	return func(args map[string]any, headers map[string]string, request *s.Request, response *s.Response, caller *discover.Caller, session *Session, logger *log.Logger) any {
		pool := getPool(startFile)
		if pool == nil {
			return nil
		}
		rt := pool.Get()
		defer pool.Put(rt)

		action, isCallable := rt.GetGoData(actionKey).(goja.Callable)
		if !isCallable {
			return nil
		}
		// A missing or non-Value entry leaves thisArg nil.
		thisArg, _ := rt.GetGoData(actionKey + "This").(goja.Value)
		result, _ := rt.RunVM(func(vm *goja.Runtime) (any, error) {
			return runAction(action, vm, thisArg, args, headers, request, response, nil, caller, session, logger)
		})
		return result
	}
}
// PoolStatus is a snapshot of the four counters returned by gojs.Pool.Count
// for one pool, in the same order. NOTE(review): the meanings below follow
// the field names; confirm them against gojs.
type PoolStatus struct {
// Total: presumably the current number of runtimes in the pool.
Total uint
// MaxTotal: presumably the peak number of runtimes.
MaxTotal uint
// MaxWaiting: presumably the peak number of callers waiting for a runtime.
MaxWaiting uint
// CreateTimes: presumably how many runtimes have been created.
CreateTimes uint
}
// GetPoolStatus returns the Count() counters of every loaded action file's
// runtime pool, keyed by file path.
//
// The pools map is written by "load" and replaced by "stop", so it is
// copied under poolsLock (a read lock). The counters are then collected
// without holding the lock.
func GetPoolStatus() map[string]PoolStatus {
	poolsLock.RLock()
	snapshot := make(map[string]*gojs.Pool, len(pools))
	for name, pool := range pools {
		snapshot[name] = pool
	}
	poolsLock.RUnlock()

	out := make(map[string]PoolStatus, len(snapshot))
	for name, pool := range snapshot {
		total, maxTotal, maxWaiting, createTimes := pool.Count()
		out[name] = PoolStatus{
			Total:       total,
			MaxTotal:    maxTotal,
			MaxWaiting:  maxWaiting,
			CreateTimes: createTimes,
		}
	}
	return out
}