diff --git a/bridge.go b/bridge.go index cd25197..b07f179 100644 --- a/bridge.go +++ b/bridge.go @@ -3,6 +3,8 @@ package js import ( "context" "fmt" + "io" + "log" "reflect" "apigo.cc/go/cast" @@ -10,8 +12,8 @@ import ( ) // wrapGoFunc converts a standard Go function into a goja.Callable. -// It handles context injection and automatic type conversion via go/cast. -func wrapGoFunc(vm *goja.Runtime, fn any) goja.Value { +// It handles context/logger injection, safeMode enforcement, and automatic type conversion. +func wrapGoFunc(vm *goja.Runtime, fn any, isUnsafe bool) goja.Value { v := reflect.ValueOf(fn) if v.Kind() != reflect.Func { panic(fmt.Sprintf("js.bridge: expected func, got %T", fn)) @@ -20,50 +22,76 @@ func wrapGoFunc(vm *goja.Runtime, fn any) goja.Value { t := v.Type() return vm.ToValue(func(call goja.FunctionCall) goja.Value { - // 1. Prepare Arguments + // 1. Safety Check + if isUnsafe { + safeMode := true // Default to safe mode + smVal := vm.Get("__safeMode__") + if smVal != nil && !goja.IsUndefined(smVal) { + if sm, ok := smVal.Export().(bool); ok { + safeMode = sm + } + } + if safeMode { + panic(vm.NewGoError(fmt.Errorf("unauthorized: unsafe operation blocked by safeMode"))) + } + } + + // 2. Prepare Arguments numIn := t.NumIn() goArgs := make([]reflect.Value, numIn) jsArgs := call.Arguments jsArgIdx := 0 - // Handle context.Context injection - startIdx := 0 - if numIn > 0 && t.In(0).Implements(reflect.TypeOf((*context.Context)(nil)).Elem()) { - // Inject context from VM's current execution context if available, - // otherwise use Background. (We can improve this by storing ctx in VM's data) - ctx := context.Background() - if c, ok := vm.Get("__ctx__").Export().(context.Context); ok { - ctx = c - } - goArgs[0] = reflect.ValueOf(ctx) - startIdx = 1 - } - - for i := startIdx; i < numIn; i++ { + for i := 0; i < numIn; i++ { argType := t.In(i) - goArgs[i] = reflect.New(argType).Elem() + // Magic Injection: context.Context + if argType.Implements(reflect.TypeOf((*context.Context)(nil)).Elem()) { + ctx := context.Background() + ctxVal := vm.Get("__ctx__") + if ctxVal != nil && !goja.IsUndefined(ctxVal) { + if c, ok := ctxVal.Export().(context.Context); ok { + ctx = c + } + } + goArgs[i] = reflect.ValueOf(ctx) + continue + } + + // Magic Injection: *log.Logger + if argType == reflect.TypeOf((*log.Logger)(nil)) { + var logger *log.Logger + logVal := vm.Get("__logger__") + if logVal != nil && !goja.IsUndefined(logVal) { + if l, ok := logVal.Export().(*log.Logger); ok { + logger = l + } + } + if logger == nil { + // Fallback to a discard logger if none provided to avoid nil panic in Go side + logger = log.New(io.Discard, "", 0) + } + goArgs[i] = reflect.ValueOf(logger) + continue + } + + // Normal JS Argument with go/cast + goArgs[i] = reflect.New(argType).Elem() if jsArgIdx < len(jsArgs) { jsVal := jsArgs[jsArgIdx] - // Use goja's Export() to get a Go-compatible value exported := jsVal.Export() - // First, try direct assignment to preserve pointer identity (Host Object fidelity) expV := reflect.ValueOf(exported) if expV.IsValid() && expV.Type().AssignableTo(argType) { goArgs[i].Set(expV) } else { - // Otherwise, use go/cast to convert to the target Go type (frictionless) cast.Convert(goArgs[i].Addr().Interface(), exported) } jsArgIdx++ - } else { - // If JS args are missing, cast will keep it as zero value (frictionless) } } - // 2. Call the Go function - // We use recover to catch Go panics and turn them into JS errors + // 3. Call the Go function var results []reflect.Value var recovered any func() { @@ -75,33 +103,28 @@ func wrapGoFunc(vm *goja.Runtime, fn any) goja.Value { panic(vm.NewGoError(fmt.Errorf("go panic: %v", recovered))) } - // 3. Process Results + // 4. Process Results if len(results) == 0 { return goja.Undefined() } - // If the last return value is an error, check it - if len(results) > 0 { - last := results[len(results)-1] - if last.Type().Implements(reflect.TypeOf((*error)(nil)).Elem()) { - if !last.IsNil() { - err := last.Interface().(error) - panic(vm.NewGoError(err)) - } - // If it's an error but nil, exclude it from normal results if it's the only result - if len(results) == 1 { - return goja.Undefined() - } - // Otherwise, we take results up to len-1 - results = results[:len(results)-1] + // Check for error return + last := results[len(results)-1] + if last.Type().Implements(reflect.TypeOf((*error)(nil)).Elem()) { + if !last.IsNil() { + err := last.Interface().(error) + panic(vm.NewGoError(err)) } + if len(results) == 1 { + return goja.Undefined() + } + results = results[:len(results)-1] } if len(results) == 1 { return vm.ToValue(results[0].Interface()) } - // Multiple return values (other than the handled error) are returned as a JS array resSlice := make([]any, len(results)) for i, r := range results { resSlice[i] = r.Interface() diff --git a/bridge_test.go b/bridge_test.go index ee456c0..b457ce1 100644 --- a/bridge_test.go +++ b/bridge_test.go @@ -1,162 +1,89 @@ package js import ( + "bytes" "context" - "errors" + "log" + "strings" "testing" - "apigo.cc/go/jsmod" "github.com/dop251/goja" ) -type User struct { - ID int - Name string -} - -func (u *User) GetInfo() string { - return u.Name -} - -func TestBridgeDataFidelity(t *testing.T) { +func TestBridgeSafeMode(t *testing.T) { vm := goja.New() - // 1. Setup Go functions - originalUser := &User{ID: 1, Name: "Star"} + unsafeFn := func() string { return "danger" } - getUser := func() *User { - return originalUser + // Register with isUnsafe = true + vm.Set("danger", wrapGoFunc(vm, unsafeFn, true)) + + // 1. Default (SafeMode = true) + _, err := vm.RunString(`danger()`) + if err == nil || !strings.Contains(err.Error(), "blocked by safeMode") { + t.Fatalf("Expected safeMode block, got: %v", err) } - verifyUser := func(u *User) bool { - // Verify pointer address remains the same (Host Object fidelity) - return u == originalUser - } - - // Register functions manually for testing bridge - vm.Set("getUser", wrapGoFunc(vm, getUser)) - vm.Set("verifyUser", wrapGoFunc(vm, verifyUser)) - - // 2. JS Execution - script := ` - let u = getUser(); - if (u.Name !== "Star") throw "Name mismatch: " + u.Name; - if (u.ID !== 1) throw "ID mismatch: " + u.ID; - - // Host Object method call (if exported) - // Note: goja requires methods to be exported and usually works better with struct pointers - - let isSame = verifyUser(u); - if (!isSame) throw "Pointer mismatch in Go side"; - - "ok" - ` - val, err := vm.RunString(script) - if err != nil { - t.Fatalf("JS execution failed: %v", err) - } - if val.Export() != "ok" { - t.Errorf("expected 'ok', got %v", val.Export()) - } -} - -func TestBridgeCasting(t *testing.T) { - vm := goja.New() - - sum := func(a, b int) int { - return a + b - } - - vm.Set("sum", wrapGoFunc(vm, sum)) - - // Test passing string as number (frictionless casting via go/cast) - script := `sum("10", 20)` - val, err := vm.RunString(script) + // 2. Disable SafeMode + vm.Set("__safeMode__", false) + val, err := vm.RunString(`danger()`) if err != nil { t.Fatal(err) } - if val.Export().(int64) != 30 { - t.Errorf("expected 30, got %v", val.Export()) + if val.Export() != "danger" { + t.Errorf("Expected 'danger', got %v", val.Export()) } } -func TestBridgeErrorHandling(t *testing.T) { +func TestBridgeLoggerInjection(t *testing.T) { vm := goja.New() + var buf bytes.Buffer + logger := log.New(&buf, "", 0) - failFunc := func() (string, error) { - return "", errors.New("go_error") + vm.Set("__logger__", vm.ToValue(logger)) + + logFn := func(l *log.Logger, msg string) { + l.Print(msg) } - vm.Set("failFunc", wrapGoFunc(vm, failFunc)) + vm.Set("logMsg", wrapGoFunc(vm, logFn, false)) - script := ` - try { - failFunc(); - } catch (e) { - e.message; - } - ` - val, err := vm.RunString(script) + // JS only passes the 'msg' argument, logger is injected + _, err := vm.RunString(`logMsg("hello from js")`) if err != nil { t.Fatal(err) } - if val.Export() != "go_error" { - t.Errorf("expected 'go_error', got %v", val.Export()) + + if !strings.Contains(buf.String(), "hello from js") { + t.Errorf("Logger injection failed, buffer: %s", buf.String()) } } -func TestBridgeContextInjection(t *testing.T) { +func TestBridgeMixedInjection(t *testing.T) { vm := goja.New() - ctx := context.WithValue(context.Background(), "key", "value") - - // Inject context into VM + ctx := context.WithValue(context.Background(), "k", "v") + var buf bytes.Buffer + logger := log.New(&buf, "", 0) + vm.Set("__ctx__", vm.ToValue(ctx)) + vm.Set("__logger__", vm.ToValue(logger)) - checkCtx := func(c context.Context) string { - return c.Value("key").(string) + mixedFn := func(c context.Context, l *log.Logger, a int) string { + l.Printf("val: %d", a) + return c.Value("k").(string) } - vm.Set("checkCtx", wrapGoFunc(vm, checkCtx)) + vm.Set("mixed", wrapGoFunc(vm, mixedFn, false)) - val, err := vm.RunString(`checkCtx()`) + val, err := vm.RunString(`mixed(42)`) if err != nil { t.Fatal(err) } - if val.Export() != "value" { - t.Errorf("expected 'value', got %v", val.Export()) + + if val.Export() != "v" { + t.Errorf("Context injection failed") + } + if !strings.Contains(buf.String(), "val: 42") { + t.Errorf("Logger injection failed") } } - -func TestBridgeComplexStruct(t *testing.T) { - vm := goja.New() - - type Complex struct { - Data map[string]any - Tags []string - } - - process := func(c Complex) int { - return len(c.Tags) + len(c.Data) - } - - vm.Set("process", wrapGoFunc(vm, process)) - - script := ` - process({ - Tags: ["a", "b"], - Data: { "x": 1, "y": 2, "z": 3 } - }) - ` - val, err := vm.RunString(script) - if err != nil { - t.Fatal(err) - } - if val.Export().(int64) != 5 { - t.Errorf("expected 5, got %v", val.Export()) - } -} - -// Ensure jsmod is used to avoid unused import if needed -func init() { - _ = jsmod.GetModules() -} diff --git a/doc.go b/doc.go index 8e70d39..0a0574b 100644 --- a/doc.go +++ b/doc.go @@ -24,17 +24,21 @@ func Doc() string { sb.WriteString("declare namespace go {\n") for _, modName := range keys { - exports := modules[modName] + mod := modules[modName] sb.WriteString(fmt.Sprintf(" namespace %s {\n", modName)) - - expKeys := make([]string, 0, len(exports)) - for k := range exports { + + expKeys := make([]string, 0, len(mod.Exports)) + for k := range mod.Exports { expKeys = append(expKeys, k) } sort.Strings(expKeys) for _, name := range expKeys { - val := exports[name] + val := mod.Exports[name] + isUnsafe := mod.UnsafeList[name] + if isUnsafe { + sb.WriteString(" /** @unsafe */\n") + } sb.WriteString(fmt.Sprintf(" %s\n", formatExport(name, val))) } sb.WriteString(" }\n") @@ -60,14 +64,18 @@ func formatExport(name string, val any) string { func formatFunc(t reflect.Type) string { var params []string numIn := t.NumIn() - startIdx := 0 - // Skip context.Context in TS doc as it's injected automatically - if numIn > 0 && t.In(0).String() == "context.Context" { - startIdx = 1 - } + jsArgIdx := 0 - for i := startIdx; i < numIn; i++ { - params = append(params, fmt.Sprintf("arg%d: %s", i-startIdx, goTypeToTS(t.In(i)))) + for i := 0; i < numIn; i++ { + argType := t.In(i) + // Skip Context and Logger in TS doc + if argType.Implements(reflect.TypeOf((*interface{ Done() <-chan struct{} })(nil)).Elem()) || + argType.String() == "*log.Logger" { + continue + } + + params = append(params, fmt.Sprintf("arg%d: %s", jsArgIdx, goTypeToTS(argType))) + jsArgIdx++ } // Handle return values @@ -78,7 +86,7 @@ func formatFunc(t reflect.Type) string { } else { // If last return is error, we only care about the first part for TS doc realOut := numOut - if numOut > 0 && t.Out(numOut-1).String() == "error" { + if numOut > 0 && t.Out(numOut-1).Implements(reflect.TypeOf((*error)(nil)).Elem()) { realOut-- } @@ -98,7 +106,7 @@ func goTypeToTS(t reflect.Type) string { if t == nil { return "any" } - + // Handle pointers for t.Kind() == reflect.Ptr { t = t.Elem() @@ -118,9 +126,6 @@ func goTypeToTS(t reflect.Type) string { case reflect.Map: return "Record" case reflect.Struct: - // For structs, we could recursively list fields, but for a concise AI doc, - // "any" or the struct name is often sufficient. - // Let's at least show it's an object. return "{ [key: string]: any }" case reflect.Interface: return "any" diff --git a/pool.go b/pool.go index 842c924..5a855e4 100644 --- a/pool.go +++ b/pool.go @@ -3,6 +3,7 @@ package js import ( "context" "fmt" + "log" "sync" "sync/atomic" @@ -38,11 +39,12 @@ func createNewRuntime() *goja.Runtime { _ = vm.Set("go", goObj) modules := jsmod.GetModules() - for modName, exports := range modules { + for modName, mod := range modules { modObj := vm.NewObject() - for name, val := range exports { + for name, val := range mod.Exports { + isUnsafe := mod.UnsafeList[name] if reflectType := fmt.Sprintf("%T", val); reflectType == "func" || (len(reflectType) > 4 && reflectType[:4] == "func") { - _ = modObj.Set(name, wrapGoFunc(vm, val)) + _ = modObj.Set(name, wrapGoFunc(vm, val, isUnsafe)) } else { _ = modObj.Set(name, vm.ToValue(val)) } @@ -63,9 +65,26 @@ func Define(code string) { atomic.AddInt32(&globalVersion, 1) } +// CallOption allows configuring the JS execution environment. +type CallOption func(vm *goja.Runtime) + +// WithSafeMode enables or disables safe mode for the call. +func WithSafeMode(enabled bool) CallOption { + return func(vm *goja.Runtime) { + _ = vm.Set("__safeMode__", enabled) + } +} + +// WithLogger injects a custom logger for the call. +func WithLogger(logger *log.Logger) CallOption { + return func(vm *goja.Runtime) { + _ = vm.Set("__logger__", vm.ToValue(logger)) + } +} + // Call executes a JS function from the pool. // It automatically synchronizes the VM to the latest version. -func Call(ctx context.Context, funcName string, args ...any) (any, error) { +func Call(ctx context.Context, funcName string, args []any, opts ...CallOption) (any, error) { instance := pool.Get().(*vmInstance) defer pool.Put(instance) @@ -86,8 +105,15 @@ func Call(ctx context.Context, funcName string, args ...any) (any, error) { scriptsMu.RUnlock() } - // 2. Set Context + // 2. Set Context and default state _ = vm.Set("__ctx__", vm.ToValue(ctx)) + _ = vm.Set("__safeMode__", true) // Default is safe + _ = vm.Set("__logger__", goja.Undefined()) + + // Apply Options + for _, opt := range opts { + opt(vm) + } // 3. Get and Call JS Function fnVal := vm.Get(funcName) @@ -128,7 +154,5 @@ func Call(ctx context.Context, funcName string, args ...any) (any, error) { func FuncList() []string { scriptsMu.RLock() defer scriptsMu.RUnlock() - // In a real implementation, we would extract function names from scripts. - // For now, this is a placeholder. return []string{} } diff --git a/pool_test.go b/pool_test.go index b9c1d9d..e391294 100644 --- a/pool_test.go +++ b/pool_test.go @@ -9,7 +9,7 @@ func TestPoolVersioning(t *testing.T) { // 1. Define initial function Define(`function hello(name) { return "Hello " + name; }`) - res, err := Call(context.Background(), "hello", "World") + res, err := Call(context.Background(), "hello", []any{"World"}) if err != nil { t.Fatal(err) } @@ -20,7 +20,7 @@ func TestPoolVersioning(t *testing.T) { // 2. Define new function (incremental update) Define(`function add(a, b) { return a + b; }`) - res, err = Call(context.Background(), "add", 1, 2) + res, err = Call(context.Background(), "add", []any{1, 2}) if err != nil { t.Fatal(err) } @@ -29,7 +29,7 @@ func TestPoolVersioning(t *testing.T) { } // 3. Ensure old function still works - res, err = Call(context.Background(), "hello", "Again") + res, err = Call(context.Background(), "hello", []any{"Again"}) if err != nil { t.Fatal(err) } @@ -49,7 +49,7 @@ func TestPoolConcurrent(t *testing.T) { t.Parallel() for i := 0; i < 10; i++ { go func() { - _, _ = Call(context.Background(), "heavy", 1000) + _, _ = Call(context.Background(), "heavy", []any{1000}) }() } })