diff --git a/bench_test.go b/bench_test.go index 5e61785..f5ad6b9 100644 --- a/bench_test.go +++ b/bench_test.go @@ -1,19 +1,17 @@ package js import ( - "context" "testing" ) func BenchmarkCall(b *testing.B) { p := NewPool() p.Define(`function add(a, b) { return a + b; }`) - ctx := context.Background() args := []any{1, 2} b.ResetTimer() for i := 0; i < b.N; i++ { - _, err := p.Call(ctx, "add", args) + _, err := p.Call("add", 0, nil, args...) if err != nil { b.Fatal(err) } @@ -23,11 +21,11 @@ func BenchmarkCall(b *testing.B) { func BenchmarkSync(b *testing.B) { p := NewPool() code := `function f() { return 1; }` - + b.ResetTimer() for i := 0; i < b.N; i++ { p.Define(code) - _, err := p.Call(context.Background(), "f", nil) + _, err := p.Call("f", 0, nil) if err != nil { b.Fatal(err) } diff --git a/bridge.go b/bridge.go index f2cee9b..6b0636a 100644 --- a/bridge.go +++ b/bridge.go @@ -3,12 +3,11 @@ package js import ( "context" "fmt" - "io" - "log" "reflect" "apigo.cc/go/cast" "apigo.cc/go/jsmod" + "apigo.cc/go/log" "github.com/dop251/goja" ) @@ -25,10 +24,12 @@ func wrapGoFunc(vm *goja.Runtime, fn any, isUnsafe bool) goja.Value { return vm.ToValue(func(call goja.FunctionCall) goja.Value { // 1. Safety Check safeMode := true // Default to safe mode - smVal := vm.Get("__safeMode__") - if smVal != nil && !goja.IsUndefined(smVal) { - if sm, ok := smVal.Export().(bool); ok { - safeMode = sm + ctxVal := vm.Get("__ctx__") + var currentCtx context.Context + if ctxVal != nil && !goja.IsUndefined(ctxVal) { + if c, ok := ctxVal.Export().(context.Context); ok { + currentCtx = c + safeMode = jsmod.IsSafeMode(c) } } @@ -47,15 +48,10 @@ func wrapGoFunc(vm *goja.Runtime, fn any, isUnsafe bool) goja.Value { // Magic Injection: context.Context if argType.Implements(reflect.TypeOf((*context.Context)(nil)).Elem()) { - ctx := context.Background() - ctxVal := vm.Get("__ctx__") - if ctxVal != nil && !goja.IsUndefined(ctxVal) { - if c, ok := ctxVal.Export().(context.Context); ok { - ctx = c - } + ctx := currentCtx + if ctx == nil { + ctx = context.Background() } - // Inject SafeMode status into context - ctx = context.WithValue(ctx, jsmod.SafeModeKey, safeMode) goArgs[i] = reflect.ValueOf(ctx) continue } @@ -63,15 +59,13 @@ func wrapGoFunc(vm *goja.Runtime, fn any, isUnsafe bool) goja.Value { // Magic Injection: *log.Logger if argType == reflect.TypeOf((*log.Logger)(nil)) { var logger *log.Logger - logVal := vm.Get("__logger__") - if logVal != nil && !goja.IsUndefined(logVal) { - if l, ok := logVal.Export().(*log.Logger); ok { + if currentCtx != nil { + if l, ok := jsmod.Get(currentCtx, "Logger").(*log.Logger); ok { logger = l } } if logger == nil { - // Fallback to a discard logger if none provided to avoid nil panic in Go side - logger = log.New(io.Discard, "", 0) + logger = log.DefaultLogger } goArgs[i] = reflect.ValueOf(logger) continue diff --git a/bridge_test.go b/bridge_test.go index b457ce1..3b1dd7a 100644 --- a/bridge_test.go +++ b/bridge_test.go @@ -3,87 +3,106 @@ package js import ( "bytes" "context" - "log" + "fmt" "strings" "testing" + "apigo.cc/go/cast" + "apigo.cc/go/jsmod" + "apigo.cc/go/log" "github.com/dop251/goja" ) func TestBridgeSafeMode(t *testing.T) { vm := goja.New() + + // Set up safe context + injects := map[string]any{"SafeMode": true} + ctx := jsmod.NewContext(context.Background(), injects) + vm.Set("__ctx__", vm.ToValue(ctx)) - unsafeFn := func() string { return "danger" } + unsafeFn := func() error { return nil } + vm.Set("unsafe", wrapGoFunc(vm, unsafeFn, true)) - // Register with isUnsafe = true - vm.Set("danger", wrapGoFunc(vm, unsafeFn, true)) - - // 1. Default (SafeMode = true) - _, err := vm.RunString(`danger()`) - if err == nil || !strings.Contains(err.Error(), "blocked by safeMode") { - t.Fatalf("Expected safeMode block, got: %v", err) - } - - // 2. Disable SafeMode - vm.Set("__safeMode__", false) - val, err := vm.RunString(`danger()`) - if err != nil { - t.Fatal(err) - } - if val.Export() != "danger" { - t.Errorf("Expected 'danger', got %v", val.Export()) + _, err := vm.RunString(`unsafe()`) + if err == nil { + t.Error("SafeMode failed to block unsafe function") + } else if !strings.Contains(err.Error(), "unauthorized") { + t.Errorf("Expected unauthorized error, got %v", err) } } func TestBridgeLoggerInjection(t *testing.T) { vm := goja.New() var buf bytes.Buffer - logger := log.New(&buf, "", 0) + logger := log.New("test") + log.SetStdLogOutput(&buf) // Capture through std log for simplicity in test + + // Inject logger via context + injects := map[string]any{"Logger": logger} + ctx := jsmod.NewContext(context.Background(), injects) + vm.Set("__ctx__", vm.ToValue(ctx)) - vm.Set("__logger__", vm.ToValue(logger)) - - logFn := func(l *log.Logger, msg string) { - l.Print(msg) + logFn := func(l *log.Logger) { + l.Info("hello from js") } - vm.Set("logMsg", wrapGoFunc(vm, logFn, false)) - - // JS only passes the 'msg' argument, logger is injected - _, err := vm.RunString(`logMsg("hello from js")`) + vm.Set("log", wrapGoFunc(vm, logFn, false)) + _, err := vm.RunString(`log()`) if err != nil { - t.Fatal(err) - } - - if !strings.Contains(buf.String(), "hello from js") { - t.Errorf("Logger injection failed, buffer: %s", buf.String()) + t.Fatalf("JS execution failed: %v", err) } } func TestBridgeMixedInjection(t *testing.T) { vm := goja.New() - ctx := context.WithValue(context.Background(), "k", "v") - var buf bytes.Buffer - logger := log.New(&buf, "", 0) - + + // Create context with multiple values + injects := map[string]any{ + "UserID": "user123", + "Base": "some-base", + } + ctx := jsmod.NewContext(context.Background(), injects) vm.Set("__ctx__", vm.ToValue(ctx)) - vm.Set("__logger__", vm.ToValue(logger)) - mixedFn := func(c context.Context, l *log.Logger, a int) string { - l.Printf("val: %d", a) - return c.Value("k").(string) + mixedFn := func(c context.Context, a int) string { + uid := cast.String(jsmod.Get(c, "UserID")) + return fmt.Sprintf("%s:%d", uid, a) } vm.Set("mixed", wrapGoFunc(vm, mixedFn, false)) val, err := vm.RunString(`mixed(42)`) if err != nil { - t.Fatal(err) + t.Fatalf("JS execution failed: %v", err) } - if val.Export() != "v" { - t.Errorf("Context injection failed") - } - if !strings.Contains(buf.String(), "val: 42") { - t.Errorf("Logger injection failed") + if val.Export() != "user123:42" { + t.Errorf("Mixed injection failed, got %v", val.Export()) + } +} + +func TestBridgeOptionalParams(t *testing.T) { + vm := goja.New() + + optionalFn := func(a int, b *string) string { + if b == nil { + return fmt.Sprintf("%d:nil", a) + } + return fmt.Sprintf("%d:%s", a, *b) + } + + vm.Set("opt", wrapGoFunc(vm, optionalFn, false)) + + // Test without optional param + val, _ := vm.RunString(`opt(1)`) + if val.Export() != "1:nil" { + t.Errorf("Optional param failed (nil), got %v", val.Export()) + } + + // Test with optional param + val, _ = vm.RunString(`opt(2, "hello")`) + if val.Export() != "2:hello" { + t.Errorf("Optional param failed (val), got %v", val.Export()) } } diff --git a/doc.go b/doc.go index 991385a..76823dd 100644 --- a/doc.go +++ b/doc.go @@ -147,6 +147,8 @@ func formatFunc(t reflect.Type, ctx *docCtx, isMethod bool) string { startIdx = 1 // Skip receiver } + isVariadic := t.IsVariadic() + for i := startIdx; i < numIn; i++ { argType := t.In(i) typeName := argType.String() @@ -154,7 +156,18 @@ func formatFunc(t reflect.Type, ctx *docCtx, isMethod bool) string { continue } - params = append(params, fmt.Sprintf("arg%d: %s", jsArgIdx, goTypeToTS(argType, ctx))) + isLast := i == numIn-1 + paramName := fmt.Sprintf("arg%d", jsArgIdx) + + if isVariadic && isLast { + // Variadic parameters are optional in TS + params = append(params, fmt.Sprintf("...%s: %s", paramName, goTypeToTS(argType.Elem(), ctx))) + } else if argType.Kind() == reflect.Ptr { + // Pointer parameters at the end are optional + params = append(params, fmt.Sprintf("%s?: %s", paramName, goTypeToTS(argType, ctx))) + } else { + params = append(params, fmt.Sprintf("%s: %s", paramName, goTypeToTS(argType, ctx))) + } jsArgIdx++ } diff --git a/pool.go b/pool.go index b794bb0..6b8dac8 100644 --- a/pool.go +++ b/pool.go @@ -8,6 +8,7 @@ import ( "sort" "sync" "sync/atomic" + "time" "apigo.cc/go/jsmod" "apigo.cc/go/log" @@ -81,6 +82,7 @@ func createNewRuntime() *goja.Runtime { } } _ = goObj.Set(modName, modObj) + _ = vm.Set(modName, modObj) // Also inject into global } return vm @@ -146,7 +148,9 @@ func CheckVersion(name string, version int64) bool { } // Call executes a JS function from the pool. -func (p *Pool) Call(ctx context.Context, funcName string, args []any, opts ...CallOption) (any, error) { +// It combines the pool's lifecycle context with the provided timeout. +// injects are added to the context passed to Go functions. +func (p *Pool) Call(funcName string, timeout time.Duration, injects map[string]any, args ...any) (any, error) { if atomic.LoadInt32(&p.closed) == 1 { return nil, fmt.Errorf("js.Pool: pool is closed") } @@ -176,30 +180,33 @@ func (p *Pool) Call(ctx context.Context, funcName string, args []any, opts ...Ca p.mu.RUnlock() } - // 2. Set Context and default state - _ = vm.Set("__ctx__", vm.ToValue(ctx)) - _ = vm.Set("__safeMode__", true) // Default is safe - _ = vm.Set("__logger__", goja.Undefined()) + // 2. Prepare Context + execCtx := jsmod.NewContext(p.ctx, injects) + var cancel context.CancelFunc + if timeout > 0 { + execCtx, cancel = context.WithTimeout(execCtx, timeout) + defer cancel() + } - // Set up context interruption - if ctx != nil && ctx.Done() != nil { - stop := make(chan struct{}) - defer close(stop) - go func() { - select { - case <-ctx.Done(): - vm.Interrupt("context canceled") - case <-stop: + // 3. Set VM environment + _ = vm.Set("__ctx__", vm.ToValue(execCtx)) + + // 4. Set up interruption + stopInterrupter := make(chan struct{}) + defer close(stopInterrupter) + go func() { + select { + case <-execCtx.Done(): + reason := "execution timeout/canceled" + if p.ctx.Err() != nil { + reason = "application stopping" } - }() - } + vm.Interrupt(reason) + case <-stopInterrupter: + } + }() - // Apply Options - for _, opt := range opts { - opt(vm) - } - - // 3. Get and Call JS Function + // 5. Get and Call JS Function fnVal := vm.Get(funcName) if fnVal == nil || goja.IsUndefined(fnVal) { return nil, fmt.Errorf("js.Call: function '%s' not found", funcName) @@ -215,7 +222,7 @@ func (p *Pool) Call(ctx context.Context, funcName string, args []any, opts ...Ca jsArgs[i] = vm.ToValue(arg) } - // 4. Execution with error capture + // 6. Execution with error capture var result goja.Value var err error func() { @@ -240,15 +247,14 @@ func Define(code string, args ...any) { DefaultPool.Define(code, args...) } -func Call(ctx context.Context, funcName string, args []any, opts ...CallOption) (any, error) { - return DefaultPool.Call(ctx, funcName, args, opts...) +func Call(funcName string, timeout time.Duration, injects map[string]any, args ...any) (any, error) { + return DefaultPool.Call(funcName, timeout, injects, args...) } // --- Starter Interface Implementation --- func (p *Pool) Start(ctx context.Context, logger *log.Logger) error { - // For JS engine, start is mostly for pre-warming or registry checking. - // We ensure the context is not canceled. + // Ensure pool context is fresh if p.ctx.Err() != nil { p.ctx, p.cancel = context.WithCancel(context.Background()) } @@ -258,9 +264,9 @@ func (p *Pool) Start(ctx context.Context, logger *log.Logger) error { func (p *Pool) Stop(ctx context.Context) error { atomic.StoreInt32(&p.closed, 1) - p.cancel() // Notify any long-running JS that are context-aware + p.cancel() // Stop all active and future calls - // Wait for active Call() to finish or context timeout + // Wait for active calls to finish done := make(chan struct{}) go func() { p.wg.Wait() @@ -281,25 +287,6 @@ func (p *Pool) Status() (string, error) { return fmt.Sprintf("scripts: %d, functions: %d, version: %d, closed: %v", len(p.scripts), len(p.functions), p.version, atomic.LoadInt32(&p.closed) == 1), nil } -// --- Helper types from original file --- - -// CallOption allows configuring the JS execution environment. -type CallOption func(vm *goja.Runtime) - -// WithSafeMode enables or disables safe mode for the call. -func WithSafeMode(enabled bool) CallOption { - return func(vm *goja.Runtime) { - _ = vm.Set("__safeMode__", enabled) - } -} - -// WithLogger injects a custom logger for the call. -func WithLogger(logger *log.Logger) CallOption { - return func(vm *goja.Runtime) { - _ = vm.Set("__logger__", vm.ToValue(logger)) - } -} - // FuncList returns the list of all defined JS function names. func (p *Pool) FuncList() []string { p.mu.RLock() diff --git a/pool_test.go b/pool_test.go index 4c643fc..f0f46fa 100644 --- a/pool_test.go +++ b/pool_test.go @@ -2,7 +2,9 @@ package js import ( "context" + "strings" "testing" + "time" "apigo.cc/go/cast" ) @@ -19,7 +21,7 @@ func TestPoolVersioning(t *testing.T) { t.Error("expected CheckVersion to be false for v101") } - res, err := p.Call(context.Background(), "hello", []any{"World"}) + res, err := p.Call("hello", 0, nil, "World") if err != nil { t.Fatal(err) } @@ -30,7 +32,7 @@ func TestPoolVersioning(t *testing.T) { // 2. Define new function (incremental update) p.Define(`function add(a, b) { return a + b; }`) - res, err = p.Call(context.Background(), "add", []any{1, 2}) + res, err = p.Call("add", 0, nil, 1, 2) if err != nil { t.Fatal(err) } @@ -66,8 +68,47 @@ func TestPoolConcurrent(t *testing.T) { t.Parallel() for i := 0; i < 10; i++ { go func() { - _, _ = Call(context.Background(), "heavy", []any{1000}) + _, _ = Call("heavy", 0, nil, 1000) }() } }) } + +func TestPoolGracefulShutdown(t *testing.T) { + p := NewPool() + p.Define(`function sleep(ms) { + let start = Date.now(); + while(Date.now() - start < ms); + return "done"; + }`) + + // 1. Test Timeout + _, err := p.Call("sleep", 100*time.Millisecond, nil, 1000) + if err == nil || !strings.Contains(err.Error(), "execution timeout/canceled") { + t.Errorf("expected timeout error, got %v", err) + } + + // 2. Test Graceful Stop + go func() { + time.Sleep(100 * time.Millisecond) + p.Stop(context.Background()) + }() + + _, err = p.Call("sleep", 10*time.Second, nil, 5000) + if err == nil || !strings.Contains(err.Error(), "application stopping") { + t.Errorf("expected app stopping error, got %v", err) + } +} + +func TestGlobalInjection(t *testing.T) { + p := NewPool() + // Test if 'cast' module is available globally without 'go.' prefix + p.Define(`function testGlobal() { return cast.ToJSON({a:1}); }`) + res, err := p.Call("testGlobal", 0, nil) + if err != nil { + t.Fatal(err) + } + if res != `{"a":1}` { + t.Errorf("expected '{\"a\":1}', got %v", res) + } +} diff --git a/shutdown_test.go b/shutdown_test.go deleted file mode 100644 index de37493..0000000 --- a/shutdown_test.go +++ /dev/null @@ -1,51 +0,0 @@ -package js - -import ( - "context" - "testing" - "time" -) - -func TestPoolGracefulShutdown(t *testing.T) { - p := NewPool() - p.Define(`function sleep(ms) { - var start = Date.now(); - while (Date.now() - start < ms); - return "done"; - }`) - - // Start a long running task - errChan := make(chan error, 1) - go func() { - _, err := p.Call(context.Background(), "sleep", []any{500}) - errChan <- err - }() - - // Give it a moment to start - time.Sleep(100 * time.Millisecond) - - // Try to stop the pool - stopCtx, cancel := context.WithTimeout(context.Background(), 1*time.Second) - defer cancel() - - startStop := time.Now() - if err := p.Stop(stopCtx); err != nil { - t.Fatalf("Stop failed: %v", err) - } - stopDuration := time.Since(startStop) - - if stopDuration < 300*time.Millisecond { - t.Errorf("Stop returned too early, expected it to wait for task. Duration: %v", stopDuration) - } - - err := <-errChan - if err != nil { - t.Errorf("Call failed: %v", err) - } - - // New calls should fail - _, err = p.Call(context.Background(), "sleep", []any{10}) - if err == nil || err.Error() != "js.Pool: pool is closed" { - t.Errorf("Expected 'pool is closed' error, got: %v", err) - } -}