feat: upgrade bridge with safeMode and magic injection (by AI)
This commit is contained in:
parent
7632cea6f6
commit
65228a6707
79
bridge.go
79
bridge.go
@ -3,6 +3,8 @@ package js
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
|
||||||
"apigo.cc/go/cast"
|
"apigo.cc/go/cast"
|
||||||
@ -10,8 +12,8 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// wrapGoFunc converts a standard Go function into a goja.Callable.
|
// wrapGoFunc converts a standard Go function into a goja.Callable.
|
||||||
// It handles context injection and automatic type conversion via go/cast.
|
// It handles context/logger injection, safeMode enforcement, and automatic type conversion.
|
||||||
func wrapGoFunc(vm *goja.Runtime, fn any) goja.Value {
|
func wrapGoFunc(vm *goja.Runtime, fn any, isUnsafe bool) goja.Value {
|
||||||
v := reflect.ValueOf(fn)
|
v := reflect.ValueOf(fn)
|
||||||
if v.Kind() != reflect.Func {
|
if v.Kind() != reflect.Func {
|
||||||
panic(fmt.Sprintf("js.bridge: expected func, got %T", fn))
|
panic(fmt.Sprintf("js.bridge: expected func, got %T", fn))
|
||||||
@ -20,50 +22,76 @@ func wrapGoFunc(vm *goja.Runtime, fn any) goja.Value {
|
|||||||
t := v.Type()
|
t := v.Type()
|
||||||
|
|
||||||
return vm.ToValue(func(call goja.FunctionCall) goja.Value {
|
return vm.ToValue(func(call goja.FunctionCall) goja.Value {
|
||||||
// 1. Prepare Arguments
|
// 1. Safety Check
|
||||||
|
if isUnsafe {
|
||||||
|
safeMode := true // Default to safe mode
|
||||||
|
smVal := vm.Get("__safeMode__")
|
||||||
|
if smVal != nil && !goja.IsUndefined(smVal) {
|
||||||
|
if sm, ok := smVal.Export().(bool); ok {
|
||||||
|
safeMode = sm
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if safeMode {
|
||||||
|
panic(vm.NewGoError(fmt.Errorf("unauthorized: unsafe operation blocked by safeMode")))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Prepare Arguments
|
||||||
numIn := t.NumIn()
|
numIn := t.NumIn()
|
||||||
goArgs := make([]reflect.Value, numIn)
|
goArgs := make([]reflect.Value, numIn)
|
||||||
jsArgs := call.Arguments
|
jsArgs := call.Arguments
|
||||||
jsArgIdx := 0
|
jsArgIdx := 0
|
||||||
|
|
||||||
// Handle context.Context injection
|
for i := 0; i < numIn; i++ {
|
||||||
startIdx := 0
|
argType := t.In(i)
|
||||||
if numIn > 0 && t.In(0).Implements(reflect.TypeOf((*context.Context)(nil)).Elem()) {
|
|
||||||
// Inject context from VM's current execution context if available,
|
// Magic Injection: context.Context
|
||||||
// otherwise use Background. (We can improve this by storing ctx in VM's data)
|
if argType.Implements(reflect.TypeOf((*context.Context)(nil)).Elem()) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
if c, ok := vm.Get("__ctx__").Export().(context.Context); ok {
|
ctxVal := vm.Get("__ctx__")
|
||||||
|
if ctxVal != nil && !goja.IsUndefined(ctxVal) {
|
||||||
|
if c, ok := ctxVal.Export().(context.Context); ok {
|
||||||
ctx = c
|
ctx = c
|
||||||
}
|
}
|
||||||
goArgs[0] = reflect.ValueOf(ctx)
|
}
|
||||||
startIdx = 1
|
goArgs[i] = reflect.ValueOf(ctx)
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := startIdx; i < numIn; i++ {
|
// Magic Injection: *log.Logger
|
||||||
argType := t.In(i)
|
if argType == reflect.TypeOf((*log.Logger)(nil)) {
|
||||||
goArgs[i] = reflect.New(argType).Elem()
|
var logger *log.Logger
|
||||||
|
logVal := vm.Get("__logger__")
|
||||||
|
if logVal != nil && !goja.IsUndefined(logVal) {
|
||||||
|
if l, ok := logVal.Export().(*log.Logger); ok {
|
||||||
|
logger = l
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if logger == nil {
|
||||||
|
// Fallback to a discard logger if none provided to avoid nil panic in Go side
|
||||||
|
logger = log.New(io.Discard, "", 0)
|
||||||
|
}
|
||||||
|
goArgs[i] = reflect.ValueOf(logger)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Normal JS Argument with go/cast
|
||||||
|
goArgs[i] = reflect.New(argType).Elem()
|
||||||
if jsArgIdx < len(jsArgs) {
|
if jsArgIdx < len(jsArgs) {
|
||||||
jsVal := jsArgs[jsArgIdx]
|
jsVal := jsArgs[jsArgIdx]
|
||||||
// Use goja's Export() to get a Go-compatible value
|
|
||||||
exported := jsVal.Export()
|
exported := jsVal.Export()
|
||||||
|
|
||||||
// First, try direct assignment to preserve pointer identity (Host Object fidelity)
|
|
||||||
expV := reflect.ValueOf(exported)
|
expV := reflect.ValueOf(exported)
|
||||||
if expV.IsValid() && expV.Type().AssignableTo(argType) {
|
if expV.IsValid() && expV.Type().AssignableTo(argType) {
|
||||||
goArgs[i].Set(expV)
|
goArgs[i].Set(expV)
|
||||||
} else {
|
} else {
|
||||||
// Otherwise, use go/cast to convert to the target Go type (frictionless)
|
|
||||||
cast.Convert(goArgs[i].Addr().Interface(), exported)
|
cast.Convert(goArgs[i].Addr().Interface(), exported)
|
||||||
}
|
}
|
||||||
jsArgIdx++
|
jsArgIdx++
|
||||||
} else {
|
|
||||||
// If JS args are missing, cast will keep it as zero value (frictionless)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. Call the Go function
|
// 3. Call the Go function
|
||||||
// We use recover to catch Go panics and turn them into JS errors
|
|
||||||
var results []reflect.Value
|
var results []reflect.Value
|
||||||
var recovered any
|
var recovered any
|
||||||
func() {
|
func() {
|
||||||
@ -75,33 +103,28 @@ func wrapGoFunc(vm *goja.Runtime, fn any) goja.Value {
|
|||||||
panic(vm.NewGoError(fmt.Errorf("go panic: %v", recovered)))
|
panic(vm.NewGoError(fmt.Errorf("go panic: %v", recovered)))
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. Process Results
|
// 4. Process Results
|
||||||
if len(results) == 0 {
|
if len(results) == 0 {
|
||||||
return goja.Undefined()
|
return goja.Undefined()
|
||||||
}
|
}
|
||||||
|
|
||||||
// If the last return value is an error, check it
|
// Check for error return
|
||||||
if len(results) > 0 {
|
|
||||||
last := results[len(results)-1]
|
last := results[len(results)-1]
|
||||||
if last.Type().Implements(reflect.TypeOf((*error)(nil)).Elem()) {
|
if last.Type().Implements(reflect.TypeOf((*error)(nil)).Elem()) {
|
||||||
if !last.IsNil() {
|
if !last.IsNil() {
|
||||||
err := last.Interface().(error)
|
err := last.Interface().(error)
|
||||||
panic(vm.NewGoError(err))
|
panic(vm.NewGoError(err))
|
||||||
}
|
}
|
||||||
// If it's an error but nil, exclude it from normal results if it's the only result
|
|
||||||
if len(results) == 1 {
|
if len(results) == 1 {
|
||||||
return goja.Undefined()
|
return goja.Undefined()
|
||||||
}
|
}
|
||||||
// Otherwise, we take results up to len-1
|
|
||||||
results = results[:len(results)-1]
|
results = results[:len(results)-1]
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if len(results) == 1 {
|
if len(results) == 1 {
|
||||||
return vm.ToValue(results[0].Interface())
|
return vm.ToValue(results[0].Interface())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Multiple return values (other than the handled error) are returned as a JS array
|
|
||||||
resSlice := make([]any, len(results))
|
resSlice := make([]any, len(results))
|
||||||
for i, r := range results {
|
for i, r := range results {
|
||||||
resSlice[i] = r.Interface()
|
resSlice[i] = r.Interface()
|
||||||
|
|||||||
161
bridge_test.go
161
bridge_test.go
@ -1,162 +1,89 @@
|
|||||||
package js
|
package js
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"log"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"apigo.cc/go/jsmod"
|
|
||||||
"github.com/dop251/goja"
|
"github.com/dop251/goja"
|
||||||
)
|
)
|
||||||
|
|
||||||
type User struct {
|
func TestBridgeSafeMode(t *testing.T) {
|
||||||
ID int
|
|
||||||
Name string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *User) GetInfo() string {
|
|
||||||
return u.Name
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestBridgeDataFidelity(t *testing.T) {
|
|
||||||
vm := goja.New()
|
vm := goja.New()
|
||||||
|
|
||||||
// 1. Setup Go functions
|
unsafeFn := func() string { return "danger" }
|
||||||
originalUser := &User{ID: 1, Name: "Star"}
|
|
||||||
|
|
||||||
getUser := func() *User {
|
// Register with isUnsafe = true
|
||||||
return originalUser
|
vm.Set("danger", wrapGoFunc(vm, unsafeFn, true))
|
||||||
|
|
||||||
|
// 1. Default (SafeMode = true)
|
||||||
|
_, err := vm.RunString(`danger()`)
|
||||||
|
if err == nil || !strings.Contains(err.Error(), "blocked by safeMode") {
|
||||||
|
t.Fatalf("Expected safeMode block, got: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
verifyUser := func(u *User) bool {
|
// 2. Disable SafeMode
|
||||||
// Verify pointer address remains the same (Host Object fidelity)
|
vm.Set("__safeMode__", false)
|
||||||
return u == originalUser
|
val, err := vm.RunString(`danger()`)
|
||||||
}
|
|
||||||
|
|
||||||
// Register functions manually for testing bridge
|
|
||||||
vm.Set("getUser", wrapGoFunc(vm, getUser))
|
|
||||||
vm.Set("verifyUser", wrapGoFunc(vm, verifyUser))
|
|
||||||
|
|
||||||
// 2. JS Execution
|
|
||||||
script := `
|
|
||||||
let u = getUser();
|
|
||||||
if (u.Name !== "Star") throw "Name mismatch: " + u.Name;
|
|
||||||
if (u.ID !== 1) throw "ID mismatch: " + u.ID;
|
|
||||||
|
|
||||||
// Host Object method call (if exported)
|
|
||||||
// Note: goja requires methods to be exported and usually works better with struct pointers
|
|
||||||
|
|
||||||
let isSame = verifyUser(u);
|
|
||||||
if (!isSame) throw "Pointer mismatch in Go side";
|
|
||||||
|
|
||||||
"ok"
|
|
||||||
`
|
|
||||||
val, err := vm.RunString(script)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("JS execution failed: %v", err)
|
|
||||||
}
|
|
||||||
if val.Export() != "ok" {
|
|
||||||
t.Errorf("expected 'ok', got %v", val.Export())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestBridgeCasting(t *testing.T) {
|
|
||||||
vm := goja.New()
|
|
||||||
|
|
||||||
sum := func(a, b int) int {
|
|
||||||
return a + b
|
|
||||||
}
|
|
||||||
|
|
||||||
vm.Set("sum", wrapGoFunc(vm, sum))
|
|
||||||
|
|
||||||
// Test passing string as number (frictionless casting via go/cast)
|
|
||||||
script := `sum("10", 20)`
|
|
||||||
val, err := vm.RunString(script)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if val.Export().(int64) != 30 {
|
if val.Export() != "danger" {
|
||||||
t.Errorf("expected 30, got %v", val.Export())
|
t.Errorf("Expected 'danger', got %v", val.Export())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBridgeErrorHandling(t *testing.T) {
|
func TestBridgeLoggerInjection(t *testing.T) {
|
||||||
vm := goja.New()
|
vm := goja.New()
|
||||||
|
var buf bytes.Buffer
|
||||||
|
logger := log.New(&buf, "", 0)
|
||||||
|
|
||||||
failFunc := func() (string, error) {
|
vm.Set("__logger__", vm.ToValue(logger))
|
||||||
return "", errors.New("go_error")
|
|
||||||
|
logFn := func(l *log.Logger, msg string) {
|
||||||
|
l.Print(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
vm.Set("failFunc", wrapGoFunc(vm, failFunc))
|
vm.Set("logMsg", wrapGoFunc(vm, logFn, false))
|
||||||
|
|
||||||
script := `
|
// JS only passes the 'msg' argument, logger is injected
|
||||||
try {
|
_, err := vm.RunString(`logMsg("hello from js")`)
|
||||||
failFunc();
|
|
||||||
} catch (e) {
|
|
||||||
e.message;
|
|
||||||
}
|
|
||||||
`
|
|
||||||
val, err := vm.RunString(script)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if val.Export() != "go_error" {
|
|
||||||
t.Errorf("expected 'go_error', got %v", val.Export())
|
if !strings.Contains(buf.String(), "hello from js") {
|
||||||
|
t.Errorf("Logger injection failed, buffer: %s", buf.String())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBridgeContextInjection(t *testing.T) {
|
func TestBridgeMixedInjection(t *testing.T) {
|
||||||
vm := goja.New()
|
vm := goja.New()
|
||||||
ctx := context.WithValue(context.Background(), "key", "value")
|
ctx := context.WithValue(context.Background(), "k", "v")
|
||||||
|
var buf bytes.Buffer
|
||||||
|
logger := log.New(&buf, "", 0)
|
||||||
|
|
||||||
// Inject context into VM
|
|
||||||
vm.Set("__ctx__", vm.ToValue(ctx))
|
vm.Set("__ctx__", vm.ToValue(ctx))
|
||||||
|
vm.Set("__logger__", vm.ToValue(logger))
|
||||||
|
|
||||||
checkCtx := func(c context.Context) string {
|
mixedFn := func(c context.Context, l *log.Logger, a int) string {
|
||||||
return c.Value("key").(string)
|
l.Printf("val: %d", a)
|
||||||
|
return c.Value("k").(string)
|
||||||
}
|
}
|
||||||
|
|
||||||
vm.Set("checkCtx", wrapGoFunc(vm, checkCtx))
|
vm.Set("mixed", wrapGoFunc(vm, mixedFn, false))
|
||||||
|
|
||||||
val, err := vm.RunString(`checkCtx()`)
|
val, err := vm.RunString(`mixed(42)`)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if val.Export() != "value" {
|
|
||||||
t.Errorf("expected 'value', got %v", val.Export())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestBridgeComplexStruct(t *testing.T) {
|
if val.Export() != "v" {
|
||||||
vm := goja.New()
|
t.Errorf("Context injection failed")
|
||||||
|
|
||||||
type Complex struct {
|
|
||||||
Data map[string]any
|
|
||||||
Tags []string
|
|
||||||
}
|
}
|
||||||
|
if !strings.Contains(buf.String(), "val: 42") {
|
||||||
process := func(c Complex) int {
|
t.Errorf("Logger injection failed")
|
||||||
return len(c.Tags) + len(c.Data)
|
|
||||||
}
|
|
||||||
|
|
||||||
vm.Set("process", wrapGoFunc(vm, process))
|
|
||||||
|
|
||||||
script := `
|
|
||||||
process({
|
|
||||||
Tags: ["a", "b"],
|
|
||||||
Data: { "x": 1, "y": 2, "z": 3 }
|
|
||||||
})
|
|
||||||
`
|
|
||||||
val, err := vm.RunString(script)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if val.Export().(int64) != 5 {
|
|
||||||
t.Errorf("expected 5, got %v", val.Export())
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure jsmod is used to avoid unused import if needed
|
|
||||||
func init() {
|
|
||||||
_ = jsmod.GetModules()
|
|
||||||
}
|
|
||||||
|
|||||||
33
doc.go
33
doc.go
@ -24,17 +24,21 @@ func Doc() string {
|
|||||||
|
|
||||||
sb.WriteString("declare namespace go {\n")
|
sb.WriteString("declare namespace go {\n")
|
||||||
for _, modName := range keys {
|
for _, modName := range keys {
|
||||||
exports := modules[modName]
|
mod := modules[modName]
|
||||||
sb.WriteString(fmt.Sprintf(" namespace %s {\n", modName))
|
sb.WriteString(fmt.Sprintf(" namespace %s {\n", modName))
|
||||||
|
|
||||||
expKeys := make([]string, 0, len(exports))
|
expKeys := make([]string, 0, len(mod.Exports))
|
||||||
for k := range exports {
|
for k := range mod.Exports {
|
||||||
expKeys = append(expKeys, k)
|
expKeys = append(expKeys, k)
|
||||||
}
|
}
|
||||||
sort.Strings(expKeys)
|
sort.Strings(expKeys)
|
||||||
|
|
||||||
for _, name := range expKeys {
|
for _, name := range expKeys {
|
||||||
val := exports[name]
|
val := mod.Exports[name]
|
||||||
|
isUnsafe := mod.UnsafeList[name]
|
||||||
|
if isUnsafe {
|
||||||
|
sb.WriteString(" /** @unsafe */\n")
|
||||||
|
}
|
||||||
sb.WriteString(fmt.Sprintf(" %s\n", formatExport(name, val)))
|
sb.WriteString(fmt.Sprintf(" %s\n", formatExport(name, val)))
|
||||||
}
|
}
|
||||||
sb.WriteString(" }\n")
|
sb.WriteString(" }\n")
|
||||||
@ -60,14 +64,18 @@ func formatExport(name string, val any) string {
|
|||||||
func formatFunc(t reflect.Type) string {
|
func formatFunc(t reflect.Type) string {
|
||||||
var params []string
|
var params []string
|
||||||
numIn := t.NumIn()
|
numIn := t.NumIn()
|
||||||
startIdx := 0
|
jsArgIdx := 0
|
||||||
// Skip context.Context in TS doc as it's injected automatically
|
|
||||||
if numIn > 0 && t.In(0).String() == "context.Context" {
|
for i := 0; i < numIn; i++ {
|
||||||
startIdx = 1
|
argType := t.In(i)
|
||||||
|
// Skip Context and Logger in TS doc
|
||||||
|
if argType.Implements(reflect.TypeOf((*interface{ Done() <-chan struct{} })(nil)).Elem()) ||
|
||||||
|
argType.String() == "*log.Logger" {
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := startIdx; i < numIn; i++ {
|
params = append(params, fmt.Sprintf("arg%d: %s", jsArgIdx, goTypeToTS(argType)))
|
||||||
params = append(params, fmt.Sprintf("arg%d: %s", i-startIdx, goTypeToTS(t.In(i))))
|
jsArgIdx++
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle return values
|
// Handle return values
|
||||||
@ -78,7 +86,7 @@ func formatFunc(t reflect.Type) string {
|
|||||||
} else {
|
} else {
|
||||||
// If last return is error, we only care about the first part for TS doc
|
// If last return is error, we only care about the first part for TS doc
|
||||||
realOut := numOut
|
realOut := numOut
|
||||||
if numOut > 0 && t.Out(numOut-1).String() == "error" {
|
if numOut > 0 && t.Out(numOut-1).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
|
||||||
realOut--
|
realOut--
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -118,9 +126,6 @@ func goTypeToTS(t reflect.Type) string {
|
|||||||
case reflect.Map:
|
case reflect.Map:
|
||||||
return "Record<string, any>"
|
return "Record<string, any>"
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
// For structs, we could recursively list fields, but for a concise AI doc,
|
|
||||||
// "any" or the struct name is often sufficient.
|
|
||||||
// Let's at least show it's an object.
|
|
||||||
return "{ [key: string]: any }"
|
return "{ [key: string]: any }"
|
||||||
case reflect.Interface:
|
case reflect.Interface:
|
||||||
return "any"
|
return "any"
|
||||||
|
|||||||
38
pool.go
38
pool.go
@ -3,6 +3,7 @@ package js
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
|
||||||
@ -38,11 +39,12 @@ func createNewRuntime() *goja.Runtime {
|
|||||||
_ = vm.Set("go", goObj)
|
_ = vm.Set("go", goObj)
|
||||||
|
|
||||||
modules := jsmod.GetModules()
|
modules := jsmod.GetModules()
|
||||||
for modName, exports := range modules {
|
for modName, mod := range modules {
|
||||||
modObj := vm.NewObject()
|
modObj := vm.NewObject()
|
||||||
for name, val := range exports {
|
for name, val := range mod.Exports {
|
||||||
|
isUnsafe := mod.UnsafeList[name]
|
||||||
if reflectType := fmt.Sprintf("%T", val); reflectType == "func" || (len(reflectType) > 4 && reflectType[:4] == "func") {
|
if reflectType := fmt.Sprintf("%T", val); reflectType == "func" || (len(reflectType) > 4 && reflectType[:4] == "func") {
|
||||||
_ = modObj.Set(name, wrapGoFunc(vm, val))
|
_ = modObj.Set(name, wrapGoFunc(vm, val, isUnsafe))
|
||||||
} else {
|
} else {
|
||||||
_ = modObj.Set(name, vm.ToValue(val))
|
_ = modObj.Set(name, vm.ToValue(val))
|
||||||
}
|
}
|
||||||
@ -63,9 +65,26 @@ func Define(code string) {
|
|||||||
atomic.AddInt32(&globalVersion, 1)
|
atomic.AddInt32(&globalVersion, 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CallOption allows configuring the JS execution environment.
|
||||||
|
type CallOption func(vm *goja.Runtime)
|
||||||
|
|
||||||
|
// WithSafeMode enables or disables safe mode for the call.
|
||||||
|
func WithSafeMode(enabled bool) CallOption {
|
||||||
|
return func(vm *goja.Runtime) {
|
||||||
|
_ = vm.Set("__safeMode__", enabled)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithLogger injects a custom logger for the call.
|
||||||
|
func WithLogger(logger *log.Logger) CallOption {
|
||||||
|
return func(vm *goja.Runtime) {
|
||||||
|
_ = vm.Set("__logger__", vm.ToValue(logger))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Call executes a JS function from the pool.
|
// Call executes a JS function from the pool.
|
||||||
// It automatically synchronizes the VM to the latest version.
|
// It automatically synchronizes the VM to the latest version.
|
||||||
func Call(ctx context.Context, funcName string, args ...any) (any, error) {
|
func Call(ctx context.Context, funcName string, args []any, opts ...CallOption) (any, error) {
|
||||||
instance := pool.Get().(*vmInstance)
|
instance := pool.Get().(*vmInstance)
|
||||||
defer pool.Put(instance)
|
defer pool.Put(instance)
|
||||||
|
|
||||||
@ -86,8 +105,15 @@ func Call(ctx context.Context, funcName string, args ...any) (any, error) {
|
|||||||
scriptsMu.RUnlock()
|
scriptsMu.RUnlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. Set Context
|
// 2. Set Context and default state
|
||||||
_ = vm.Set("__ctx__", vm.ToValue(ctx))
|
_ = vm.Set("__ctx__", vm.ToValue(ctx))
|
||||||
|
_ = vm.Set("__safeMode__", true) // Default is safe
|
||||||
|
_ = vm.Set("__logger__", goja.Undefined())
|
||||||
|
|
||||||
|
// Apply Options
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(vm)
|
||||||
|
}
|
||||||
|
|
||||||
// 3. Get and Call JS Function
|
// 3. Get and Call JS Function
|
||||||
fnVal := vm.Get(funcName)
|
fnVal := vm.Get(funcName)
|
||||||
@ -128,7 +154,5 @@ func Call(ctx context.Context, funcName string, args ...any) (any, error) {
|
|||||||
func FuncList() []string {
|
func FuncList() []string {
|
||||||
scriptsMu.RLock()
|
scriptsMu.RLock()
|
||||||
defer scriptsMu.RUnlock()
|
defer scriptsMu.RUnlock()
|
||||||
// In a real implementation, we would extract function names from scripts.
|
|
||||||
// For now, this is a placeholder.
|
|
||||||
return []string{}
|
return []string{}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -9,7 +9,7 @@ func TestPoolVersioning(t *testing.T) {
|
|||||||
// 1. Define initial function
|
// 1. Define initial function
|
||||||
Define(`function hello(name) { return "Hello " + name; }`)
|
Define(`function hello(name) { return "Hello " + name; }`)
|
||||||
|
|
||||||
res, err := Call(context.Background(), "hello", "World")
|
res, err := Call(context.Background(), "hello", []any{"World"})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -20,7 +20,7 @@ func TestPoolVersioning(t *testing.T) {
|
|||||||
// 2. Define new function (incremental update)
|
// 2. Define new function (incremental update)
|
||||||
Define(`function add(a, b) { return a + b; }`)
|
Define(`function add(a, b) { return a + b; }`)
|
||||||
|
|
||||||
res, err = Call(context.Background(), "add", 1, 2)
|
res, err = Call(context.Background(), "add", []any{1, 2})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -29,7 +29,7 @@ func TestPoolVersioning(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 3. Ensure old function still works
|
// 3. Ensure old function still works
|
||||||
res, err = Call(context.Background(), "hello", "Again")
|
res, err = Call(context.Background(), "hello", []any{"Again"})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -49,7 +49,7 @@ func TestPoolConcurrent(t *testing.T) {
|
|||||||
t.Parallel()
|
t.Parallel()
|
||||||
for i := 0; i < 10; i++ {
|
for i := 0; i < 10; i++ {
|
||||||
go func() {
|
go func() {
|
||||||
_, _ = Call(context.Background(), "heavy", 1000)
|
_, _ = Call(context.Background(), "heavy", []any{1000})
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user