Compare commits

...

1 Commits
v1.5.2 ... main

7 changed files with 175 additions and 174 deletions

View File

@ -1,19 +1,17 @@
package js
import (
"context"
"testing"
)
func BenchmarkCall(b *testing.B) {
p := NewPool()
p.Define(`function add(a, b) { return a + b; }`)
ctx := context.Background()
args := []any{1, 2}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := p.Call(ctx, "add", args)
_, err := p.Call("add", 0, nil, args...)
if err != nil {
b.Fatal(err)
}
@ -23,11 +21,11 @@ func BenchmarkCall(b *testing.B) {
func BenchmarkSync(b *testing.B) {
p := NewPool()
code := `function f() { return 1; }`
b.ResetTimer()
for i := 0; i < b.N; i++ {
p.Define(code)
_, err := p.Call(context.Background(), "f", nil)
_, err := p.Call("f", 0, nil)
if err != nil {
b.Fatal(err)
}

View File

@ -3,12 +3,11 @@ package js
import (
"context"
"fmt"
"io"
"log"
"reflect"
"apigo.cc/go/cast"
"apigo.cc/go/jsmod"
"apigo.cc/go/log"
"github.com/dop251/goja"
)
@ -25,10 +24,12 @@ func wrapGoFunc(vm *goja.Runtime, fn any, isUnsafe bool) goja.Value {
return vm.ToValue(func(call goja.FunctionCall) goja.Value {
// 1. Safety Check
safeMode := true // Default to safe mode
smVal := vm.Get("__safeMode__")
if smVal != nil && !goja.IsUndefined(smVal) {
if sm, ok := smVal.Export().(bool); ok {
safeMode = sm
ctxVal := vm.Get("__ctx__")
var currentCtx context.Context
if ctxVal != nil && !goja.IsUndefined(ctxVal) {
if c, ok := ctxVal.Export().(context.Context); ok {
currentCtx = c
safeMode = jsmod.IsSafeMode(c)
}
}
@ -47,15 +48,10 @@ func wrapGoFunc(vm *goja.Runtime, fn any, isUnsafe bool) goja.Value {
// Magic Injection: context.Context
if argType.Implements(reflect.TypeOf((*context.Context)(nil)).Elem()) {
ctx := context.Background()
ctxVal := vm.Get("__ctx__")
if ctxVal != nil && !goja.IsUndefined(ctxVal) {
if c, ok := ctxVal.Export().(context.Context); ok {
ctx = c
}
ctx := currentCtx
if ctx == nil {
ctx = context.Background()
}
// Inject SafeMode status into context
ctx = context.WithValue(ctx, jsmod.SafeModeKey, safeMode)
goArgs[i] = reflect.ValueOf(ctx)
continue
}
@ -63,15 +59,13 @@ func wrapGoFunc(vm *goja.Runtime, fn any, isUnsafe bool) goja.Value {
// Magic Injection: *log.Logger
if argType == reflect.TypeOf((*log.Logger)(nil)) {
var logger *log.Logger
logVal := vm.Get("__logger__")
if logVal != nil && !goja.IsUndefined(logVal) {
if l, ok := logVal.Export().(*log.Logger); ok {
if currentCtx != nil {
if l, ok := jsmod.Get(currentCtx, "Logger").(*log.Logger); ok {
logger = l
}
}
if logger == nil {
// Fallback to a discard logger if none provided to avoid nil panic in Go side
logger = log.New(io.Discard, "", 0)
logger = log.DefaultLogger
}
goArgs[i] = reflect.ValueOf(logger)
continue

View File

@ -3,87 +3,106 @@ package js
import (
"bytes"
"context"
"log"
"fmt"
"strings"
"testing"
"apigo.cc/go/cast"
"apigo.cc/go/jsmod"
"apigo.cc/go/log"
"github.com/dop251/goja"
)
func TestBridgeSafeMode(t *testing.T) {
vm := goja.New()
// Set up safe context
injects := map[string]any{"SafeMode": true}
ctx := jsmod.NewContext(context.Background(), injects)
vm.Set("__ctx__", vm.ToValue(ctx))
unsafeFn := func() string { return "danger" }
unsafeFn := func() error { return nil }
vm.Set("unsafe", wrapGoFunc(vm, unsafeFn, true))
// Register with isUnsafe = true
vm.Set("danger", wrapGoFunc(vm, unsafeFn, true))
// 1. Default (SafeMode = true)
_, err := vm.RunString(`danger()`)
if err == nil || !strings.Contains(err.Error(), "blocked by safeMode") {
t.Fatalf("Expected safeMode block, got: %v", err)
}
// 2. Disable SafeMode
vm.Set("__safeMode__", false)
val, err := vm.RunString(`danger()`)
if err != nil {
t.Fatal(err)
}
if val.Export() != "danger" {
t.Errorf("Expected 'danger', got %v", val.Export())
_, err := vm.RunString(`unsafe()`)
if err == nil {
t.Error("SafeMode failed to block unsafe function")
} else if !strings.Contains(err.Error(), "unauthorized") {
t.Errorf("Expected unauthorized error, got %v", err)
}
}
func TestBridgeLoggerInjection(t *testing.T) {
vm := goja.New()
var buf bytes.Buffer
logger := log.New(&buf, "", 0)
logger := log.New("test")
log.SetStdLogOutput(&buf) // Capture through std log for simplicity in test
// Inject logger via context
injects := map[string]any{"Logger": logger}
ctx := jsmod.NewContext(context.Background(), injects)
vm.Set("__ctx__", vm.ToValue(ctx))
vm.Set("__logger__", vm.ToValue(logger))
logFn := func(l *log.Logger, msg string) {
l.Print(msg)
logFn := func(l *log.Logger) {
l.Info("hello from js")
}
vm.Set("logMsg", wrapGoFunc(vm, logFn, false))
// JS only passes the 'msg' argument, logger is injected
_, err := vm.RunString(`logMsg("hello from js")`)
vm.Set("log", wrapGoFunc(vm, logFn, false))
_, err := vm.RunString(`log()`)
if err != nil {
t.Fatal(err)
}
if !strings.Contains(buf.String(), "hello from js") {
t.Errorf("Logger injection failed, buffer: %s", buf.String())
t.Fatalf("JS execution failed: %v", err)
}
}
func TestBridgeMixedInjection(t *testing.T) {
vm := goja.New()
ctx := context.WithValue(context.Background(), "k", "v")
var buf bytes.Buffer
logger := log.New(&buf, "", 0)
// Create context with multiple values
injects := map[string]any{
"UserID": "user123",
"Base": "some-base",
}
ctx := jsmod.NewContext(context.Background(), injects)
vm.Set("__ctx__", vm.ToValue(ctx))
vm.Set("__logger__", vm.ToValue(logger))
mixedFn := func(c context.Context, l *log.Logger, a int) string {
l.Printf("val: %d", a)
return c.Value("k").(string)
mixedFn := func(c context.Context, a int) string {
uid := cast.String(jsmod.Get(c, "UserID"))
return fmt.Sprintf("%s:%d", uid, a)
}
vm.Set("mixed", wrapGoFunc(vm, mixedFn, false))
val, err := vm.RunString(`mixed(42)`)
if err != nil {
t.Fatal(err)
t.Fatalf("JS execution failed: %v", err)
}
if val.Export() != "v" {
t.Errorf("Context injection failed")
}
if !strings.Contains(buf.String(), "val: 42") {
t.Errorf("Logger injection failed")
if val.Export() != "user123:42" {
t.Errorf("Mixed injection failed, got %v", val.Export())
}
}
func TestBridgeOptionalParams(t *testing.T) {
vm := goja.New()
optionalFn := func(a int, b *string) string {
if b == nil {
return fmt.Sprintf("%d:nil", a)
}
return fmt.Sprintf("%d:%s", a, *b)
}
vm.Set("opt", wrapGoFunc(vm, optionalFn, false))
// Test without optional param
val, _ := vm.RunString(`opt(1)`)
if val.Export() != "1:nil" {
t.Errorf("Optional param failed (nil), got %v", val.Export())
}
// Test with optional param
val, _ = vm.RunString(`opt(2, "hello")`)
if val.Export() != "2:hello" {
t.Errorf("Optional param failed (val), got %v", val.Export())
}
}

15
doc.go
View File

@ -147,6 +147,8 @@ func formatFunc(t reflect.Type, ctx *docCtx, isMethod bool) string {
startIdx = 1 // Skip receiver
}
isVariadic := t.IsVariadic()
for i := startIdx; i < numIn; i++ {
argType := t.In(i)
typeName := argType.String()
@ -154,7 +156,18 @@ func formatFunc(t reflect.Type, ctx *docCtx, isMethod bool) string {
continue
}
params = append(params, fmt.Sprintf("arg%d: %s", jsArgIdx, goTypeToTS(argType, ctx)))
isLast := i == numIn-1
paramName := fmt.Sprintf("arg%d", jsArgIdx)
if isVariadic && isLast {
// Variadic parameters are optional in TS
params = append(params, fmt.Sprintf("...%s: %s", paramName, goTypeToTS(argType.Elem(), ctx)))
} else if argType.Kind() == reflect.Ptr {
// Pointer parameters at the end are optional
params = append(params, fmt.Sprintf("%s?: %s", paramName, goTypeToTS(argType, ctx)))
} else {
params = append(params, fmt.Sprintf("%s: %s", paramName, goTypeToTS(argType, ctx)))
}
jsArgIdx++
}

83
pool.go
View File

@ -8,6 +8,7 @@ import (
"sort"
"sync"
"sync/atomic"
"time"
"apigo.cc/go/jsmod"
"apigo.cc/go/log"
@ -81,6 +82,7 @@ func createNewRuntime() *goja.Runtime {
}
}
_ = goObj.Set(modName, modObj)
_ = vm.Set(modName, modObj) // Also inject into global
}
return vm
@ -146,7 +148,9 @@ func CheckVersion(name string, version int64) bool {
}
// Call executes a JS function from the pool.
func (p *Pool) Call(ctx context.Context, funcName string, args []any, opts ...CallOption) (any, error) {
// It combines the pool's lifecycle context with the provided timeout.
// injects are added to the context passed to Go functions.
func (p *Pool) Call(funcName string, timeout time.Duration, injects map[string]any, args ...any) (any, error) {
if atomic.LoadInt32(&p.closed) == 1 {
return nil, fmt.Errorf("js.Pool: pool is closed")
}
@ -176,30 +180,33 @@ func (p *Pool) Call(ctx context.Context, funcName string, args []any, opts ...Ca
p.mu.RUnlock()
}
// 2. Set Context and default state
_ = vm.Set("__ctx__", vm.ToValue(ctx))
_ = vm.Set("__safeMode__", true) // Default is safe
_ = vm.Set("__logger__", goja.Undefined())
// 2. Prepare Context
execCtx := jsmod.NewContext(p.ctx, injects)
var cancel context.CancelFunc
if timeout > 0 {
execCtx, cancel = context.WithTimeout(execCtx, timeout)
defer cancel()
}
// Set up context interruption
if ctx != nil && ctx.Done() != nil {
stop := make(chan struct{})
defer close(stop)
go func() {
select {
case <-ctx.Done():
vm.Interrupt("context canceled")
case <-stop:
// 3. Set VM environment
_ = vm.Set("__ctx__", vm.ToValue(execCtx))
// 4. Set up interruption
stopInterrupter := make(chan struct{})
defer close(stopInterrupter)
go func() {
select {
case <-execCtx.Done():
reason := "execution timeout/canceled"
if p.ctx.Err() != nil {
reason = "application stopping"
}
}()
}
vm.Interrupt(reason)
case <-stopInterrupter:
}
}()
// Apply Options
for _, opt := range opts {
opt(vm)
}
// 3. Get and Call JS Function
// 5. Get and Call JS Function
fnVal := vm.Get(funcName)
if fnVal == nil || goja.IsUndefined(fnVal) {
return nil, fmt.Errorf("js.Call: function '%s' not found", funcName)
@ -215,7 +222,7 @@ func (p *Pool) Call(ctx context.Context, funcName string, args []any, opts ...Ca
jsArgs[i] = vm.ToValue(arg)
}
// 4. Execution with error capture
// 6. Execution with error capture
var result goja.Value
var err error
func() {
@ -240,15 +247,14 @@ func Define(code string, args ...any) {
DefaultPool.Define(code, args...)
}
func Call(ctx context.Context, funcName string, args []any, opts ...CallOption) (any, error) {
return DefaultPool.Call(ctx, funcName, args, opts...)
func Call(funcName string, timeout time.Duration, injects map[string]any, args ...any) (any, error) {
return DefaultPool.Call(funcName, timeout, injects, args...)
}
// --- Starter Interface Implementation ---
func (p *Pool) Start(ctx context.Context, logger *log.Logger) error {
// For JS engine, start is mostly for pre-warming or registry checking.
// We ensure the context is not canceled.
// Ensure pool context is fresh
if p.ctx.Err() != nil {
p.ctx, p.cancel = context.WithCancel(context.Background())
}
@ -258,9 +264,9 @@ func (p *Pool) Start(ctx context.Context, logger *log.Logger) error {
func (p *Pool) Stop(ctx context.Context) error {
atomic.StoreInt32(&p.closed, 1)
p.cancel() // Notify any long-running JS that are context-aware
p.cancel() // Stop all active and future calls
// Wait for active Call() to finish or context timeout
// Wait for active calls to finish
done := make(chan struct{})
go func() {
p.wg.Wait()
@ -281,25 +287,6 @@ func (p *Pool) Status() (string, error) {
return fmt.Sprintf("scripts: %d, functions: %d, version: %d, closed: %v", len(p.scripts), len(p.functions), p.version, atomic.LoadInt32(&p.closed) == 1), nil
}
// --- Helper types from original file ---
// CallOption allows configuring the JS execution environment.
type CallOption func(vm *goja.Runtime)
// WithSafeMode enables or disables safe mode for the call.
func WithSafeMode(enabled bool) CallOption {
return func(vm *goja.Runtime) {
_ = vm.Set("__safeMode__", enabled)
}
}
// WithLogger injects a custom logger for the call.
func WithLogger(logger *log.Logger) CallOption {
return func(vm *goja.Runtime) {
_ = vm.Set("__logger__", vm.ToValue(logger))
}
}
// FuncList returns the list of all defined JS function names.
func (p *Pool) FuncList() []string {
p.mu.RLock()

View File

@ -2,7 +2,9 @@ package js
import (
"context"
"strings"
"testing"
"time"
"apigo.cc/go/cast"
)
@ -19,7 +21,7 @@ func TestPoolVersioning(t *testing.T) {
t.Error("expected CheckVersion to be false for v101")
}
res, err := p.Call(context.Background(), "hello", []any{"World"})
res, err := p.Call("hello", 0, nil, "World")
if err != nil {
t.Fatal(err)
}
@ -30,7 +32,7 @@ func TestPoolVersioning(t *testing.T) {
// 2. Define new function (incremental update)
p.Define(`function add(a, b) { return a + b; }`)
res, err = p.Call(context.Background(), "add", []any{1, 2})
res, err = p.Call("add", 0, nil, 1, 2)
if err != nil {
t.Fatal(err)
}
@ -66,8 +68,47 @@ func TestPoolConcurrent(t *testing.T) {
t.Parallel()
for i := 0; i < 10; i++ {
go func() {
_, _ = Call(context.Background(), "heavy", []any{1000})
_, _ = Call("heavy", 0, nil, 1000)
}()
}
})
}
func TestPoolGracefulShutdown(t *testing.T) {
p := NewPool()
p.Define(`function sleep(ms) {
let start = Date.now();
while(Date.now() - start < ms);
return "done";
}`)
// 1. Test Timeout
_, err := p.Call("sleep", 100*time.Millisecond, nil, 1000)
if err == nil || !strings.Contains(err.Error(), "execution timeout/canceled") {
t.Errorf("expected timeout error, got %v", err)
}
// 2. Test Graceful Stop
go func() {
time.Sleep(100 * time.Millisecond)
p.Stop(context.Background())
}()
_, err = p.Call("sleep", 10*time.Second, nil, 5000)
if err == nil || !strings.Contains(err.Error(), "application stopping") {
t.Errorf("expected app stopping error, got %v", err)
}
}
func TestGlobalInjection(t *testing.T) {
p := NewPool()
// Test if 'cast' module is available globally without 'go.' prefix
p.Define(`function testGlobal() { return cast.ToJSON({a:1}); }`)
res, err := p.Call("testGlobal", 0, nil)
if err != nil {
t.Fatal(err)
}
if res != `{"a":1}` {
t.Errorf("expected '{\"a\":1}', got %v", res)
}
}

View File

@ -1,51 +0,0 @@
package js
import (
"context"
"testing"
"time"
)
func TestPoolGracefulShutdown(t *testing.T) {
p := NewPool()
p.Define(`function sleep(ms) {
var start = Date.now();
while (Date.now() - start < ms);
return "done";
}`)
// Start a long running task
errChan := make(chan error, 1)
go func() {
_, err := p.Call(context.Background(), "sleep", []any{500})
errChan <- err
}()
// Give it a moment to start
time.Sleep(100 * time.Millisecond)
// Try to stop the pool
stopCtx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
startStop := time.Now()
if err := p.Stop(stopCtx); err != nil {
t.Fatalf("Stop failed: %v", err)
}
stopDuration := time.Since(startStop)
if stopDuration < 300*time.Millisecond {
t.Errorf("Stop returned too early, expected it to wait for task. Duration: %v", stopDuration)
}
err := <-errChan
if err != nil {
t.Errorf("Call failed: %v", err)
}
// New calls should fail
_, err = p.Call(context.Background(), "sleep", []any{10})
if err == nil || err.Error() != "js.Pool: pool is closed" {
t.Errorf("Expected 'pool is closed' error, got: %v", err)
}
}