feat: upgrade bridge with safeMode and magic injection (by AI)
This commit is contained in:
parent
7632cea6f6
commit
65228a6707
79
bridge.go
79
bridge.go
@ -3,6 +3,8 @@ package js
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"reflect"
|
||||
|
||||
"apigo.cc/go/cast"
|
||||
@ -10,8 +12,8 @@ import (
|
||||
)
|
||||
|
||||
// wrapGoFunc converts a standard Go function into a goja.Callable.
|
||||
// It handles context injection and automatic type conversion via go/cast.
|
||||
func wrapGoFunc(vm *goja.Runtime, fn any) goja.Value {
|
||||
// It handles context/logger injection, safeMode enforcement, and automatic type conversion.
|
||||
func wrapGoFunc(vm *goja.Runtime, fn any, isUnsafe bool) goja.Value {
|
||||
v := reflect.ValueOf(fn)
|
||||
if v.Kind() != reflect.Func {
|
||||
panic(fmt.Sprintf("js.bridge: expected func, got %T", fn))
|
||||
@ -20,50 +22,76 @@ func wrapGoFunc(vm *goja.Runtime, fn any) goja.Value {
|
||||
t := v.Type()
|
||||
|
||||
return vm.ToValue(func(call goja.FunctionCall) goja.Value {
|
||||
// 1. Prepare Arguments
|
||||
// 1. Safety Check
|
||||
if isUnsafe {
|
||||
safeMode := true // Default to safe mode
|
||||
smVal := vm.Get("__safeMode__")
|
||||
if smVal != nil && !goja.IsUndefined(smVal) {
|
||||
if sm, ok := smVal.Export().(bool); ok {
|
||||
safeMode = sm
|
||||
}
|
||||
}
|
||||
if safeMode {
|
||||
panic(vm.NewGoError(fmt.Errorf("unauthorized: unsafe operation blocked by safeMode")))
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Prepare Arguments
|
||||
numIn := t.NumIn()
|
||||
goArgs := make([]reflect.Value, numIn)
|
||||
jsArgs := call.Arguments
|
||||
jsArgIdx := 0
|
||||
|
||||
// Handle context.Context injection
|
||||
startIdx := 0
|
||||
if numIn > 0 && t.In(0).Implements(reflect.TypeOf((*context.Context)(nil)).Elem()) {
|
||||
// Inject context from VM's current execution context if available,
|
||||
// otherwise use Background. (We can improve this by storing ctx in VM's data)
|
||||
for i := 0; i < numIn; i++ {
|
||||
argType := t.In(i)
|
||||
|
||||
// Magic Injection: context.Context
|
||||
if argType.Implements(reflect.TypeOf((*context.Context)(nil)).Elem()) {
|
||||
ctx := context.Background()
|
||||
if c, ok := vm.Get("__ctx__").Export().(context.Context); ok {
|
||||
ctxVal := vm.Get("__ctx__")
|
||||
if ctxVal != nil && !goja.IsUndefined(ctxVal) {
|
||||
if c, ok := ctxVal.Export().(context.Context); ok {
|
||||
ctx = c
|
||||
}
|
||||
goArgs[0] = reflect.ValueOf(ctx)
|
||||
startIdx = 1
|
||||
}
|
||||
goArgs[i] = reflect.ValueOf(ctx)
|
||||
continue
|
||||
}
|
||||
|
||||
for i := startIdx; i < numIn; i++ {
|
||||
argType := t.In(i)
|
||||
goArgs[i] = reflect.New(argType).Elem()
|
||||
// Magic Injection: *log.Logger
|
||||
if argType == reflect.TypeOf((*log.Logger)(nil)) {
|
||||
var logger *log.Logger
|
||||
logVal := vm.Get("__logger__")
|
||||
if logVal != nil && !goja.IsUndefined(logVal) {
|
||||
if l, ok := logVal.Export().(*log.Logger); ok {
|
||||
logger = l
|
||||
}
|
||||
}
|
||||
if logger == nil {
|
||||
// Fallback to a discard logger if none provided to avoid nil panic in Go side
|
||||
logger = log.New(io.Discard, "", 0)
|
||||
}
|
||||
goArgs[i] = reflect.ValueOf(logger)
|
||||
continue
|
||||
}
|
||||
|
||||
// Normal JS Argument with go/cast
|
||||
goArgs[i] = reflect.New(argType).Elem()
|
||||
if jsArgIdx < len(jsArgs) {
|
||||
jsVal := jsArgs[jsArgIdx]
|
||||
// Use goja's Export() to get a Go-compatible value
|
||||
exported := jsVal.Export()
|
||||
|
||||
// First, try direct assignment to preserve pointer identity (Host Object fidelity)
|
||||
expV := reflect.ValueOf(exported)
|
||||
if expV.IsValid() && expV.Type().AssignableTo(argType) {
|
||||
goArgs[i].Set(expV)
|
||||
} else {
|
||||
// Otherwise, use go/cast to convert to the target Go type (frictionless)
|
||||
cast.Convert(goArgs[i].Addr().Interface(), exported)
|
||||
}
|
||||
jsArgIdx++
|
||||
} else {
|
||||
// If JS args are missing, cast will keep it as zero value (frictionless)
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Call the Go function
|
||||
// We use recover to catch Go panics and turn them into JS errors
|
||||
// 3. Call the Go function
|
||||
var results []reflect.Value
|
||||
var recovered any
|
||||
func() {
|
||||
@ -75,33 +103,28 @@ func wrapGoFunc(vm *goja.Runtime, fn any) goja.Value {
|
||||
panic(vm.NewGoError(fmt.Errorf("go panic: %v", recovered)))
|
||||
}
|
||||
|
||||
// 3. Process Results
|
||||
// 4. Process Results
|
||||
if len(results) == 0 {
|
||||
return goja.Undefined()
|
||||
}
|
||||
|
||||
// If the last return value is an error, check it
|
||||
if len(results) > 0 {
|
||||
// Check for error return
|
||||
last := results[len(results)-1]
|
||||
if last.Type().Implements(reflect.TypeOf((*error)(nil)).Elem()) {
|
||||
if !last.IsNil() {
|
||||
err := last.Interface().(error)
|
||||
panic(vm.NewGoError(err))
|
||||
}
|
||||
// If it's an error but nil, exclude it from normal results if it's the only result
|
||||
if len(results) == 1 {
|
||||
return goja.Undefined()
|
||||
}
|
||||
// Otherwise, we take results up to len-1
|
||||
results = results[:len(results)-1]
|
||||
}
|
||||
}
|
||||
|
||||
if len(results) == 1 {
|
||||
return vm.ToValue(results[0].Interface())
|
||||
}
|
||||
|
||||
// Multiple return values (other than the handled error) are returned as a JS array
|
||||
resSlice := make([]any, len(results))
|
||||
for i, r := range results {
|
||||
resSlice[i] = r.Interface()
|
||||
|
||||
161
bridge_test.go
161
bridge_test.go
@ -1,162 +1,89 @@
|
||||
package js
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"log"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"apigo.cc/go/jsmod"
|
||||
"github.com/dop251/goja"
|
||||
)
|
||||
|
||||
type User struct {
|
||||
ID int
|
||||
Name string
|
||||
}
|
||||
|
||||
func (u *User) GetInfo() string {
|
||||
return u.Name
|
||||
}
|
||||
|
||||
func TestBridgeDataFidelity(t *testing.T) {
|
||||
func TestBridgeSafeMode(t *testing.T) {
|
||||
vm := goja.New()
|
||||
|
||||
// 1. Setup Go functions
|
||||
originalUser := &User{ID: 1, Name: "Star"}
|
||||
unsafeFn := func() string { return "danger" }
|
||||
|
||||
getUser := func() *User {
|
||||
return originalUser
|
||||
// Register with isUnsafe = true
|
||||
vm.Set("danger", wrapGoFunc(vm, unsafeFn, true))
|
||||
|
||||
// 1. Default (SafeMode = true)
|
||||
_, err := vm.RunString(`danger()`)
|
||||
if err == nil || !strings.Contains(err.Error(), "blocked by safeMode") {
|
||||
t.Fatalf("Expected safeMode block, got: %v", err)
|
||||
}
|
||||
|
||||
verifyUser := func(u *User) bool {
|
||||
// Verify pointer address remains the same (Host Object fidelity)
|
||||
return u == originalUser
|
||||
}
|
||||
|
||||
// Register functions manually for testing bridge
|
||||
vm.Set("getUser", wrapGoFunc(vm, getUser))
|
||||
vm.Set("verifyUser", wrapGoFunc(vm, verifyUser))
|
||||
|
||||
// 2. JS Execution
|
||||
script := `
|
||||
let u = getUser();
|
||||
if (u.Name !== "Star") throw "Name mismatch: " + u.Name;
|
||||
if (u.ID !== 1) throw "ID mismatch: " + u.ID;
|
||||
|
||||
// Host Object method call (if exported)
|
||||
// Note: goja requires methods to be exported and usually works better with struct pointers
|
||||
|
||||
let isSame = verifyUser(u);
|
||||
if (!isSame) throw "Pointer mismatch in Go side";
|
||||
|
||||
"ok"
|
||||
`
|
||||
val, err := vm.RunString(script)
|
||||
if err != nil {
|
||||
t.Fatalf("JS execution failed: %v", err)
|
||||
}
|
||||
if val.Export() != "ok" {
|
||||
t.Errorf("expected 'ok', got %v", val.Export())
|
||||
}
|
||||
}
|
||||
|
||||
func TestBridgeCasting(t *testing.T) {
|
||||
vm := goja.New()
|
||||
|
||||
sum := func(a, b int) int {
|
||||
return a + b
|
||||
}
|
||||
|
||||
vm.Set("sum", wrapGoFunc(vm, sum))
|
||||
|
||||
// Test passing string as number (frictionless casting via go/cast)
|
||||
script := `sum("10", 20)`
|
||||
val, err := vm.RunString(script)
|
||||
// 2. Disable SafeMode
|
||||
vm.Set("__safeMode__", false)
|
||||
val, err := vm.RunString(`danger()`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if val.Export().(int64) != 30 {
|
||||
t.Errorf("expected 30, got %v", val.Export())
|
||||
if val.Export() != "danger" {
|
||||
t.Errorf("Expected 'danger', got %v", val.Export())
|
||||
}
|
||||
}
|
||||
|
||||
func TestBridgeErrorHandling(t *testing.T) {
|
||||
func TestBridgeLoggerInjection(t *testing.T) {
|
||||
vm := goja.New()
|
||||
var buf bytes.Buffer
|
||||
logger := log.New(&buf, "", 0)
|
||||
|
||||
failFunc := func() (string, error) {
|
||||
return "", errors.New("go_error")
|
||||
vm.Set("__logger__", vm.ToValue(logger))
|
||||
|
||||
logFn := func(l *log.Logger, msg string) {
|
||||
l.Print(msg)
|
||||
}
|
||||
|
||||
vm.Set("failFunc", wrapGoFunc(vm, failFunc))
|
||||
vm.Set("logMsg", wrapGoFunc(vm, logFn, false))
|
||||
|
||||
script := `
|
||||
try {
|
||||
failFunc();
|
||||
} catch (e) {
|
||||
e.message;
|
||||
}
|
||||
`
|
||||
val, err := vm.RunString(script)
|
||||
// JS only passes the 'msg' argument, logger is injected
|
||||
_, err := vm.RunString(`logMsg("hello from js")`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if val.Export() != "go_error" {
|
||||
t.Errorf("expected 'go_error', got %v", val.Export())
|
||||
|
||||
if !strings.Contains(buf.String(), "hello from js") {
|
||||
t.Errorf("Logger injection failed, buffer: %s", buf.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestBridgeContextInjection(t *testing.T) {
|
||||
func TestBridgeMixedInjection(t *testing.T) {
|
||||
vm := goja.New()
|
||||
ctx := context.WithValue(context.Background(), "key", "value")
|
||||
ctx := context.WithValue(context.Background(), "k", "v")
|
||||
var buf bytes.Buffer
|
||||
logger := log.New(&buf, "", 0)
|
||||
|
||||
// Inject context into VM
|
||||
vm.Set("__ctx__", vm.ToValue(ctx))
|
||||
vm.Set("__logger__", vm.ToValue(logger))
|
||||
|
||||
checkCtx := func(c context.Context) string {
|
||||
return c.Value("key").(string)
|
||||
mixedFn := func(c context.Context, l *log.Logger, a int) string {
|
||||
l.Printf("val: %d", a)
|
||||
return c.Value("k").(string)
|
||||
}
|
||||
|
||||
vm.Set("checkCtx", wrapGoFunc(vm, checkCtx))
|
||||
vm.Set("mixed", wrapGoFunc(vm, mixedFn, false))
|
||||
|
||||
val, err := vm.RunString(`checkCtx()`)
|
||||
val, err := vm.RunString(`mixed(42)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if val.Export() != "value" {
|
||||
t.Errorf("expected 'value', got %v", val.Export())
|
||||
}
|
||||
}
|
||||
|
||||
func TestBridgeComplexStruct(t *testing.T) {
|
||||
vm := goja.New()
|
||||
|
||||
type Complex struct {
|
||||
Data map[string]any
|
||||
Tags []string
|
||||
if val.Export() != "v" {
|
||||
t.Errorf("Context injection failed")
|
||||
}
|
||||
|
||||
process := func(c Complex) int {
|
||||
return len(c.Tags) + len(c.Data)
|
||||
}
|
||||
|
||||
vm.Set("process", wrapGoFunc(vm, process))
|
||||
|
||||
script := `
|
||||
process({
|
||||
Tags: ["a", "b"],
|
||||
Data: { "x": 1, "y": 2, "z": 3 }
|
||||
})
|
||||
`
|
||||
val, err := vm.RunString(script)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if val.Export().(int64) != 5 {
|
||||
t.Errorf("expected 5, got %v", val.Export())
|
||||
if !strings.Contains(buf.String(), "val: 42") {
|
||||
t.Errorf("Logger injection failed")
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure jsmod is used to avoid unused import if needed
|
||||
func init() {
|
||||
_ = jsmod.GetModules()
|
||||
}
|
||||
|
||||
33
doc.go
33
doc.go
@ -24,17 +24,21 @@ func Doc() string {
|
||||
|
||||
sb.WriteString("declare namespace go {\n")
|
||||
for _, modName := range keys {
|
||||
exports := modules[modName]
|
||||
mod := modules[modName]
|
||||
sb.WriteString(fmt.Sprintf(" namespace %s {\n", modName))
|
||||
|
||||
expKeys := make([]string, 0, len(exports))
|
||||
for k := range exports {
|
||||
expKeys := make([]string, 0, len(mod.Exports))
|
||||
for k := range mod.Exports {
|
||||
expKeys = append(expKeys, k)
|
||||
}
|
||||
sort.Strings(expKeys)
|
||||
|
||||
for _, name := range expKeys {
|
||||
val := exports[name]
|
||||
val := mod.Exports[name]
|
||||
isUnsafe := mod.UnsafeList[name]
|
||||
if isUnsafe {
|
||||
sb.WriteString(" /** @unsafe */\n")
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf(" %s\n", formatExport(name, val)))
|
||||
}
|
||||
sb.WriteString(" }\n")
|
||||
@ -60,14 +64,18 @@ func formatExport(name string, val any) string {
|
||||
func formatFunc(t reflect.Type) string {
|
||||
var params []string
|
||||
numIn := t.NumIn()
|
||||
startIdx := 0
|
||||
// Skip context.Context in TS doc as it's injected automatically
|
||||
if numIn > 0 && t.In(0).String() == "context.Context" {
|
||||
startIdx = 1
|
||||
jsArgIdx := 0
|
||||
|
||||
for i := 0; i < numIn; i++ {
|
||||
argType := t.In(i)
|
||||
// Skip Context and Logger in TS doc
|
||||
if argType.Implements(reflect.TypeOf((*interface{ Done() <-chan struct{} })(nil)).Elem()) ||
|
||||
argType.String() == "*log.Logger" {
|
||||
continue
|
||||
}
|
||||
|
||||
for i := startIdx; i < numIn; i++ {
|
||||
params = append(params, fmt.Sprintf("arg%d: %s", i-startIdx, goTypeToTS(t.In(i))))
|
||||
params = append(params, fmt.Sprintf("arg%d: %s", jsArgIdx, goTypeToTS(argType)))
|
||||
jsArgIdx++
|
||||
}
|
||||
|
||||
// Handle return values
|
||||
@ -78,7 +86,7 @@ func formatFunc(t reflect.Type) string {
|
||||
} else {
|
||||
// If last return is error, we only care about the first part for TS doc
|
||||
realOut := numOut
|
||||
if numOut > 0 && t.Out(numOut-1).String() == "error" {
|
||||
if numOut > 0 && t.Out(numOut-1).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
|
||||
realOut--
|
||||
}
|
||||
|
||||
@ -118,9 +126,6 @@ func goTypeToTS(t reflect.Type) string {
|
||||
case reflect.Map:
|
||||
return "Record<string, any>"
|
||||
case reflect.Struct:
|
||||
// For structs, we could recursively list fields, but for a concise AI doc,
|
||||
// "any" or the struct name is often sufficient.
|
||||
// Let's at least show it's an object.
|
||||
return "{ [key: string]: any }"
|
||||
case reflect.Interface:
|
||||
return "any"
|
||||
|
||||
38
pool.go
38
pool.go
@ -3,6 +3,7 @@ package js
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
@ -38,11 +39,12 @@ func createNewRuntime() *goja.Runtime {
|
||||
_ = vm.Set("go", goObj)
|
||||
|
||||
modules := jsmod.GetModules()
|
||||
for modName, exports := range modules {
|
||||
for modName, mod := range modules {
|
||||
modObj := vm.NewObject()
|
||||
for name, val := range exports {
|
||||
for name, val := range mod.Exports {
|
||||
isUnsafe := mod.UnsafeList[name]
|
||||
if reflectType := fmt.Sprintf("%T", val); reflectType == "func" || (len(reflectType) > 4 && reflectType[:4] == "func") {
|
||||
_ = modObj.Set(name, wrapGoFunc(vm, val))
|
||||
_ = modObj.Set(name, wrapGoFunc(vm, val, isUnsafe))
|
||||
} else {
|
||||
_ = modObj.Set(name, vm.ToValue(val))
|
||||
}
|
||||
@ -63,9 +65,26 @@ func Define(code string) {
|
||||
atomic.AddInt32(&globalVersion, 1)
|
||||
}
|
||||
|
||||
// CallOption allows configuring the JS execution environment.
|
||||
type CallOption func(vm *goja.Runtime)
|
||||
|
||||
// WithSafeMode enables or disables safe mode for the call.
|
||||
func WithSafeMode(enabled bool) CallOption {
|
||||
return func(vm *goja.Runtime) {
|
||||
_ = vm.Set("__safeMode__", enabled)
|
||||
}
|
||||
}
|
||||
|
||||
// WithLogger injects a custom logger for the call.
|
||||
func WithLogger(logger *log.Logger) CallOption {
|
||||
return func(vm *goja.Runtime) {
|
||||
_ = vm.Set("__logger__", vm.ToValue(logger))
|
||||
}
|
||||
}
|
||||
|
||||
// Call executes a JS function from the pool.
|
||||
// It automatically synchronizes the VM to the latest version.
|
||||
func Call(ctx context.Context, funcName string, args ...any) (any, error) {
|
||||
func Call(ctx context.Context, funcName string, args []any, opts ...CallOption) (any, error) {
|
||||
instance := pool.Get().(*vmInstance)
|
||||
defer pool.Put(instance)
|
||||
|
||||
@ -86,8 +105,15 @@ func Call(ctx context.Context, funcName string, args ...any) (any, error) {
|
||||
scriptsMu.RUnlock()
|
||||
}
|
||||
|
||||
// 2. Set Context
|
||||
// 2. Set Context and default state
|
||||
_ = vm.Set("__ctx__", vm.ToValue(ctx))
|
||||
_ = vm.Set("__safeMode__", true) // Default is safe
|
||||
_ = vm.Set("__logger__", goja.Undefined())
|
||||
|
||||
// Apply Options
|
||||
for _, opt := range opts {
|
||||
opt(vm)
|
||||
}
|
||||
|
||||
// 3. Get and Call JS Function
|
||||
fnVal := vm.Get(funcName)
|
||||
@ -128,7 +154,5 @@ func Call(ctx context.Context, funcName string, args ...any) (any, error) {
|
||||
func FuncList() []string {
|
||||
scriptsMu.RLock()
|
||||
defer scriptsMu.RUnlock()
|
||||
// In a real implementation, we would extract function names from scripts.
|
||||
// For now, this is a placeholder.
|
||||
return []string{}
|
||||
}
|
||||
|
||||
@ -9,7 +9,7 @@ func TestPoolVersioning(t *testing.T) {
|
||||
// 1. Define initial function
|
||||
Define(`function hello(name) { return "Hello " + name; }`)
|
||||
|
||||
res, err := Call(context.Background(), "hello", "World")
|
||||
res, err := Call(context.Background(), "hello", []any{"World"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -20,7 +20,7 @@ func TestPoolVersioning(t *testing.T) {
|
||||
// 2. Define new function (incremental update)
|
||||
Define(`function add(a, b) { return a + b; }`)
|
||||
|
||||
res, err = Call(context.Background(), "add", 1, 2)
|
||||
res, err = Call(context.Background(), "add", []any{1, 2})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -29,7 +29,7 @@ func TestPoolVersioning(t *testing.T) {
|
||||
}
|
||||
|
||||
// 3. Ensure old function still works
|
||||
res, err = Call(context.Background(), "hello", "Again")
|
||||
res, err = Call(context.Background(), "hello", []any{"Again"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -49,7 +49,7 @@ func TestPoolConcurrent(t *testing.T) {
|
||||
t.Parallel()
|
||||
for i := 0; i < 10; i++ {
|
||||
go func() {
|
||||
_, _ = Call(context.Background(), "heavy", 1000)
|
||||
_, _ = Call(context.Background(), "heavy", []any{1000})
|
||||
}()
|
||||
}
|
||||
})
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user