250 lines
		
	
	
		
			6.1 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			250 lines
		
	
	
		
			6.1 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
package ai
 | 
						||
 | 
						||
import (
	"bytes"
	_ "embed"
	"encoding/json"
	"errors"
	"fmt"
	"path/filepath"
	"regexp"
	"sort"
	"strings"
	"text/template"

	"apigo.cc/ai/ai/js"
	"apigo.cc/ai/ai/llm"
	"github.com/dop251/goja"
	"github.com/dop251/goja_nodejs/require"
	"github.com/ssgo/u"
)
 | 
						||
 | 
						||
//go:embed js/lib/ai.ts
 | 
						||
var aiTS string
 | 
						||
 | 
						||
//go:embed js/lib/console.ts
 | 
						||
var consoleTS string
 | 
						||
 | 
						||
//go:embed js/lib/file.ts
 | 
						||
var fileTS string
 | 
						||
 | 
						||
//go:embed js/lib/util.ts
 | 
						||
var utilTS string
 | 
						||
 | 
						||
func RunFile(file string, args ...any) (any, error) {
 | 
						||
	return Run(u.ReadFileN(file), file, args...)
 | 
						||
}
 | 
						||
 | 
						||
func Run(code string, refFile string, args ...any) (any, error) {
 | 
						||
	var r any
 | 
						||
	rt, err := StartFromCode(code, refFile)
 | 
						||
	if err == nil {
 | 
						||
		r, err = rt.Run(args...)
 | 
						||
	}
 | 
						||
	return r, err
 | 
						||
}
 | 
						||
 | 
						||
// Patterns used to preprocess scripts before they are handed to goja. All of
// them are case-insensitive and multi-line, so `^` anchors at the start of any
// line (and the following `\s*` may also span blank lines).
var (
	// importModMatcher matches an ES import of any module:
	// group 1 = binding, group 2 = module path.
	importModMatcher = regexp.MustCompile(`(?im)^\s*import\s+(.+?)\s+from\s+['"](.+?)['"]`)

	// importLibMatcher matches an ES import whose path has the form
	// `…lib/<name>` (or `…lib\<name>`): group 1 = "import", group 2 = binding,
	// group 3 = <name> with an optional ".ts" extension stripped.
	importLibMatcher = regexp.MustCompile(`(?im)^\s*(import)\s+(.+?)\s+from\s+['"][./\\\w:]+lib[/\\](.+?)(\.ts)?['"]`)

	// requireLibMatcher matches `const|let|var <binding> = require('…lib/<name>')`:
	// group 1 = keyword, group 2 = binding, group 3 = <name> without ".ts".
	requireLibMatcher = regexp.MustCompile(`(?im)^\s*(const|let|var)\s+(.+?)\s*=\s*require\s*\(\s*['"][./\\\w:]+lib[/\\](.+?)(\.ts)?['"]\s*\)`)

	// checkMainMatcher reports whether some line starts (after optional
	// whitespace) with a `function main(` declaration.
	checkMainMatcher = regexp.MustCompile(`(?im)^\s*function\s+main\s*\(`)
)
 | 
						||
 | 
						||
type Runtime struct {
 | 
						||
	vm       *goja.Runtime
 | 
						||
	required map[string]bool
 | 
						||
	file     string
 | 
						||
	srcCode  string
 | 
						||
	code     string
 | 
						||
}
 | 
						||
 | 
						||
func (rt *Runtime) requireMod(name string) error {
 | 
						||
	var err error
 | 
						||
	if name == "console" || name == "" {
 | 
						||
		if !rt.required["console"] {
 | 
						||
			rt.required["console"] = true
 | 
						||
			err = rt.vm.Set("console", js.RequireConsole())
 | 
						||
		}
 | 
						||
	}
 | 
						||
	if err == nil && (name == "file" || name == "") {
 | 
						||
		if !rt.required["file"] {
 | 
						||
			rt.required["file"] = true
 | 
						||
			err = rt.vm.Set("file", js.RequireFile())
 | 
						||
		}
 | 
						||
	}
 | 
						||
	if err == nil && (name == "util" || name == "") {
 | 
						||
		if !rt.required["util"] {
 | 
						||
			rt.required["util"] = true
 | 
						||
			err = rt.vm.Set("util", js.RequireUtil())
 | 
						||
		}
 | 
						||
	}
 | 
						||
	if err == nil && (name == "ai" || name == "") {
 | 
						||
		if !rt.required["ai"] {
 | 
						||
			rt.required["ai"] = true
 | 
						||
			aiList := make(map[string]any)
 | 
						||
			for name, lm := range llm.List() {
 | 
						||
				aiList[name] = js.RequireAI(lm)
 | 
						||
			}
 | 
						||
			err = rt.vm.Set("ai", aiList)
 | 
						||
		}
 | 
						||
	}
 | 
						||
	return err
 | 
						||
}
 | 
						||
 | 
						||
func (rt *Runtime) makeImport(matcher *regexp.Regexp, code string) (string, int, error) {
 | 
						||
	var modErr error
 | 
						||
	importCount := 0
 | 
						||
	code = matcher.ReplaceAllStringFunc(code, func(str string) string {
 | 
						||
		if m := matcher.FindStringSubmatch(str); m != nil && len(m) > 3 {
 | 
						||
			optName := m[1]
 | 
						||
			if optName == "import" {
 | 
						||
				optName = "let"
 | 
						||
			}
 | 
						||
			varName := m[2]
 | 
						||
			modName := m[3]
 | 
						||
			importCount++
 | 
						||
			if modErr == nil {
 | 
						||
				if err := rt.requireMod(modName); err != nil {
 | 
						||
					modErr = err
 | 
						||
				}
 | 
						||
			}
 | 
						||
			if varName != modName {
 | 
						||
				return fmt.Sprintf("%s %s = %s", optName, varName, modName)
 | 
						||
			}
 | 
						||
		}
 | 
						||
		return ""
 | 
						||
	})
 | 
						||
	return code, importCount, modErr
 | 
						||
}
 | 
						||
 | 
						||
func StartFromFile(file string) (*Runtime, error) {
 | 
						||
	return StartFromCode(u.ReadFileN(file), file)
 | 
						||
}
 | 
						||
 | 
						||
func StartFromCode(code, refFile string) (*Runtime, error) {
 | 
						||
	if refFile == "" {
 | 
						||
		refFile = "main.js"
 | 
						||
	}
 | 
						||
 | 
						||
	if absFile, err := filepath.Abs(refFile); err == nil {
 | 
						||
		refFile = absFile
 | 
						||
	}
 | 
						||
 | 
						||
	InitFrom(filepath.Dir(refFile))
 | 
						||
 | 
						||
	rt := &Runtime{
 | 
						||
		vm:       goja.New(),
 | 
						||
		required: map[string]bool{},
 | 
						||
		file:     refFile,
 | 
						||
		srcCode:  code,
 | 
						||
		code:     code,
 | 
						||
	}
 | 
						||
 | 
						||
	// 按需加载引用
 | 
						||
	var importCount int
 | 
						||
	var modErr error
 | 
						||
	rt.code, importCount, modErr = rt.makeImport(importLibMatcher, rt.code)
 | 
						||
	if modErr == nil {
 | 
						||
		importCount1 := importCount
 | 
						||
		rt.code, importCount, modErr = rt.makeImport(requireLibMatcher, rt.code)
 | 
						||
		importCount += importCount1
 | 
						||
	}
 | 
						||
 | 
						||
	// 将 import 转换为 require
 | 
						||
	rt.code = importModMatcher.ReplaceAllString(rt.code, "let $1 = require('$2')")
 | 
						||
 | 
						||
	// 如果没有import,默认import所有
 | 
						||
	if modErr == nil && importCount == 0 {
 | 
						||
		modErr = rt.requireMod("")
 | 
						||
	}
 | 
						||
	if modErr != nil {
 | 
						||
		return nil, modErr
 | 
						||
	}
 | 
						||
 | 
						||
	//fmt.Println(u.BCyan(rt.code))
 | 
						||
 | 
						||
	// 处理模块引用
 | 
						||
	require.NewRegistryWithLoader(func(path string) ([]byte, error) {
 | 
						||
		refPath := filepath.Join(filepath.Dir(rt.file), path)
 | 
						||
		if !strings.HasSuffix(refPath, ".js") && !u.FileExists(refPath) {
 | 
						||
			refPath += ".js"
 | 
						||
		}
 | 
						||
		modCode, err := u.ReadFile(refPath)
 | 
						||
		if err != nil {
 | 
						||
			return nil, err
 | 
						||
		}
 | 
						||
		modCode, _, _ = rt.makeImport(importLibMatcher, modCode)
 | 
						||
		modCode, _, _ = rt.makeImport(requireLibMatcher, modCode)
 | 
						||
		return []byte(modCode), modErr
 | 
						||
	}).Enable(rt.vm)
 | 
						||
 | 
						||
	// 初始化主函数
 | 
						||
	if !checkMainMatcher.MatchString(rt.code) {
 | 
						||
		rt.code = "function main(...args){" + rt.code + "}"
 | 
						||
	}
 | 
						||
	if _, err := rt.vm.RunScript("main", rt.code); err != nil {
 | 
						||
		return nil, err
 | 
						||
	}
 | 
						||
 | 
						||
	return rt, nil
 | 
						||
}
 | 
						||
 | 
						||
func (rt *Runtime) Run(args ...any) (any, error) {
 | 
						||
	// 解析参数
 | 
						||
	for i, arg := range args {
 | 
						||
		if str, ok := arg.(string); ok {
 | 
						||
			var v interface{}
 | 
						||
			if err := json.Unmarshal([]byte(str), &v); err == nil {
 | 
						||
				args[i] = v
 | 
						||
			}
 | 
						||
		}
 | 
						||
	}
 | 
						||
 | 
						||
	if err := rt.vm.Set("__args", args); err != nil {
 | 
						||
		return nil, err
 | 
						||
	}
 | 
						||
	jsResult, err := rt.vm.RunScript(rt.file, "main(...__args)")
 | 
						||
 | 
						||
	var result any
 | 
						||
	if err == nil {
 | 
						||
		if jsResult != nil && !jsResult.Equals(goja.Undefined()) {
 | 
						||
			result = jsResult.Export()
 | 
						||
		}
 | 
						||
	}
 | 
						||
	return result, err
 | 
						||
}
 | 
						||
 | 
						||
// Exports is the data ExportForDev passes to the embedded ai.ts template.
type Exports struct {
	// LLMList holds the names of the configured LLMs (keys of llm.List()).
	LLMList []string
}
 | 
						||
 | 
						||
func ExportForDev() (string, error) {
 | 
						||
	Init()
 | 
						||
	if len(llm.List()) == 0 && !u.FileExists("env.yml") && !u.FileExists("env.json") && !u.FileExists("llm.yml") && !u.FileExists("llm.json") {
 | 
						||
		return "", errors.New("no llm config found, please run `ai -e` on env.yml or llm.yml path")
 | 
						||
	}
 | 
						||
	exports := Exports{}
 | 
						||
	for name, _ := range llm.List() {
 | 
						||
		exports.LLMList = append(exports.LLMList, name)
 | 
						||
	}
 | 
						||
 | 
						||
	exportFile := filepath.Join("lib", "ai.ts")
 | 
						||
	var tpl *template.Template
 | 
						||
	var err error
 | 
						||
	if tpl, err = template.New(exportFile).Parse(aiTS); err == nil {
 | 
						||
		buf := bytes.NewBuffer(make([]byte, 0))
 | 
						||
		if err = tpl.Execute(buf, exports); err == nil {
 | 
						||
			err = u.WriteFileBytes(exportFile, buf.Bytes())
 | 
						||
		}
 | 
						||
	}
 | 
						||
	if err != nil {
 | 
						||
		return "", err
 | 
						||
	}
 | 
						||
 | 
						||
	_ = u.WriteFile(filepath.Join("lib", "console.ts"), consoleTS)
 | 
						||
	_ = u.WriteFile(filepath.Join("lib", "file.ts"), fileTS)
 | 
						||
	_ = u.WriteFile(filepath.Join("lib", "util.ts"), utilTS)
 | 
						||
 | 
						||
	return `import {` + strings.Join(exports.LLMList, ", ") + `} from './lib/ai'
 | 
						||
import console from './lib/console'
 | 
						||
import util from './lib/util'
 | 
						||
import file from './lib/file'`, nil
 | 
						||
}
 |