ai_old/js/js.go
2024-09-17 18:44:21 +08:00

239 lines
5.7 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package js
import (
	"bytes"
	_ "embed"
	"encoding/json"
	"errors"
	"fmt"
	"path/filepath"
	"regexp"
	"sort"
	"strings"
	"text/template"

	"apigo.cc/ai/ai"
	"apigo.cc/ai/ai/llm"
	"github.com/dop251/goja"
	"github.com/dop251/goja_nodejs/require"
	"github.com/ssgo/u"
)
// aiTS is the embedded text/template source of lib/ai.ts; ExportForDev renders
// it with the configured LLM names and writes the result to lib/ai.ts.
//
//go:embed lib/ai.ts
var aiTS string
// consoleTS is the embedded contents of lib/console.ts, written verbatim to
// lib/console.ts by ExportForDev.
//
//go:embed lib/console.ts
var consoleTS string
// fileTS is the embedded contents of lib/file.ts, written verbatim to
// lib/file.ts by ExportForDev.
//
//go:embed lib/file.ts
var fileTS string
func RunFile(file string, args ...any) (any, error) {
return Run(u.ReadFileN(file), file, args...)
}
// Run prepares code with StartFromCode (refFile locates the script for config
// loading and relative module paths) and immediately calls its main function
// with args, returning main's result.
func Run(code string, refFile string, args ...any) (any, error) {
	js, err := StartFromCode(code, refFile)
	if err != nil {
		return nil, err
	}
	return js.Run(args...)
}
var (
	// importModMatcher matches any `import <binding> from '<module>'` line.
	// Groups: 1 binding, 2 module path.
	importModMatcher = regexp.MustCompile(`(?im)^\s*import\s+(.+?)\s+from\s+['"](.+?)['"]`)

	// importLibMatcher matches `import <binding> from '<path>lib/<name>[.ts]'`.
	// Groups: 1 the keyword "import", 2 binding, 3 lib module name, 4 optional ".ts".
	importLibMatcher = regexp.MustCompile(`(?im)^\s*(import)\s+(.+?)\s+from\s+['"][./\\\w:]+lib[/\\](.+?)(\.ts)?['"]`)

	// requireLibMatcher matches `const|let|var <binding> = require('<path>lib/<name>[.ts]')`.
	// Groups: 1 declaration keyword, 2 binding, 3 lib module name, 4 optional ".ts".
	requireLibMatcher = regexp.MustCompile(`(?im)^\s*(const|let|var)\s+(.+?)\s*=\s*require\s*\(\s*['"][./\\\w:]+lib[/\\](.+?)(\.ts)?['"]\s*\)`)

	// checkMainMatcher reports whether a script already declares `function main(`.
	checkMainMatcher = regexp.MustCompile(`(?im)^\s*function\s+main\s*\(`)
)
// JS is a prepared script: a goja runtime plus the script text loaded into it.
type JS struct {
	// vm is the goja runtime the script is evaluated in.
	vm *goja.Runtime
	// required records which built-in modules have already been installed as VM globals.
	required map[string]bool
	// file is the script path (made absolute when possible); relative module
	// paths are resolved against its directory.
	file string
	// srcCode is the script as supplied; code is the text after import
	// rewriting and main wrapping.
	srcCode, code string
}
// requireMod installs the built-in module name ("console", "file" or "ai") as
// a global of the same name in the VM, at most once per JS. An empty name
// installs all three, in that order, stopping at the first error; any other
// name is ignored and yields nil.
//
// The "ai" global is a map from each configured LLM name to its script binding.
func (js *JS) requireMod(name string) error {
	// loadOnce installs mod if it was requested and not yet installed.
	loadOnce := func(mod string, build func() any) error {
		if name != "" && name != mod {
			return nil
		}
		if js.required[mod] {
			return nil
		}
		js.required[mod] = true
		return js.vm.Set(mod, build())
	}

	if err := loadOnce("console", func() any { return requireConsole() }); err != nil {
		return err
	}
	if err := loadOnce("file", func() any { return requireFile() }); err != nil {
		return err
	}
	return loadOnce("ai", func() any {
		models := make(map[string]any)
		for modelName, lm := range llm.List() {
			models[modelName] = requireAI(lm)
		}
		return models
	})
}
// makeImport rewrites every statement in code that matches matcher
// (importLibMatcher or requireLibMatcher) and installs the referenced built-in
// module through requireMod.
//
// matcher must expose at least three groups:
//   - 1: the declaration keyword
//   - 2: the binding
//   - 3: the lib module name
//
// Each match becomes `<keyword> <binding> = <module>`, with the keyword
// "import" turned into "let". A match is dropped entirely when the binding
// equals the module name, because the module is already a global of that
// name. A namespace binding `* as x` is treated as the plain binding x.
//
// It returns the rewritten code, the number of matches, and the first error
// reported by requireMod. Later statements are still rewritten and counted,
// but no further modules are installed after an error.
func (js *JS) makeImport(matcher *regexp.Regexp, code string) (string, int, error) {
	var modErr error
	importCount := 0
	code = matcher.ReplaceAllStringFunc(code, func(str string) string {
		// The matchers start with `^\s*`, which can swallow indentation and
		// preceding blank lines. Keep that prefix (RE2 `\s` is [\t\n\f\r ]) so the
		// rewritten script keeps the original line numbering and JS errors point
		// at the right source line.
		lead := str[:len(str)-len(strings.TrimLeft(str, "\t\n\f\r "))]

		m := matcher.FindStringSubmatch(str)
		if len(m) <= 3 {
			return lead
		}
		optName, varName, modName := m[1], m[2], m[3]
		if optName == "import" {
			optName = "let"
		}
		// `import * as x from '.../lib/mod'` binds the whole module object to x;
		// emitting `let * as x = mod` would be a syntax error.
		if f := strings.Fields(varName); len(f) == 3 && f[0] == "*" && f[1] == "as" {
			varName = f[2]
		}

		importCount++
		if modErr == nil {
			modErr = js.requireMod(modName)
		}

		if varName == modName {
			// The module is already reachable as a global with this name.
			return lead
		}
		return lead + fmt.Sprintf("%s %s = %s", optName, varName, modName)
	})
	return code, importCount, modErr
}
// StartFromFile reads the script at file and prepares it with StartFromCode.
//
// A read failure is returned as an error; previously an unreadable or missing
// file was silently prepared as an empty script.
func StartFromFile(file string) (*JS, error) {
	code, err := u.ReadFile(file)
	if err != nil {
		return nil, err
	}
	return StartFromCode(code, file)
}
// StartFromCode prepares code for execution without calling main.
//
// refFile is the script's path. It defaults to "main.js" and is made absolute
// when possible. Its directory is passed to ai.InitFrom and used to resolve
// relative require() paths. Preparation:
//  1. import/require statements that reference a `lib/<name>` module are
//     rewritten and the matching built-in module is installed as a global;
//  2. remaining `import ... from '...'` lines become require() calls;
//  3. if no lib module was referenced, all built-in modules are installed;
//  4. a require() registry is enabled whose loader applies step 1 to every
//     module file it loads;
//  5. the code is wrapped in `function main(...args){...}` unless it already
//     declares `function main(`, and is then evaluated.
func StartFromCode(code, refFile string) (*JS, error) {
	if refFile == "" {
		refFile = "main.js"
	}
	if absFile, err := filepath.Abs(refFile); err == nil {
		refFile = absFile
	}
	ai.InitFrom(filepath.Dir(refFile))
	js := &JS{
		vm:       goja.New(),
		required: map[string]bool{},
		file:     refFile,
		srcCode:  code,
		code:     code,
	}

	// Install only the built-in modules the script references.
	var importCount, requireCount int
	var modErr error
	js.code, importCount, modErr = js.makeImport(importLibMatcher, js.code)
	if modErr == nil {
		js.code, requireCount, modErr = js.makeImport(requireLibMatcher, js.code)
	}
	if modErr != nil {
		return nil, modErr
	}
	importCount += requireCount

	// Turn the remaining ES imports into CommonJS require() calls.
	js.code = importModMatcher.ReplaceAllString(js.code, "let $1 = require('$2')")

	// Without any explicit lib import, expose every built-in module.
	if importCount == 0 {
		if err := js.requireMod(""); err != nil {
			return nil, err
		}
	}

	// Resolve require() paths against the script's directory, rewriting lib
	// imports inside each loaded module the same way as in the main script.
	require.NewRegistryWithLoader(func(path string) ([]byte, error) {
		refPath := filepath.Join(filepath.Dir(js.file), path)
		if !strings.HasSuffix(refPath, ".js") && !u.FileExists(refPath) {
			refPath += ".js"
		}
		modCode, err := u.ReadFile(refPath)
		if err != nil {
			return nil, err
		}
		// Report module-install failures instead of silently loading a module
		// whose lib globals are missing.
		if modCode, _, err = js.makeImport(importLibMatcher, modCode); err != nil {
			return nil, err
		}
		if modCode, _, err = js.makeImport(requireLibMatcher, modCode); err != nil {
			return nil, err
		}
		return []byte(modCode), nil
	}).Enable(js.vm)

	// Wrap top-level code in main unless the script defines its own.
	if !checkMainMatcher.MatchString(js.code) {
		js.code = "function main(...args){" + js.code + "}"
	}
	// The script is registered under the relative name "main".
	// NOTE(review): goja_nodejs appears to resolve relative require() paths
	// against the calling script's name, and the loader above already prefixes
	// the script directory. An absolute name here would likely double that
	// directory — confirm before changing.
	if _, err := js.vm.RunScript("main", js.code); err != nil {
		return nil, err
	}
	return js, nil
}
// Run calls the script's main function with args. It returns main's exported
// result, or nil when main returns undefined.
//
// Each string argument that parses as JSON is passed to the script as the
// decoded value, so "42" arrives as a number and `{"a":1}` as an object.
// Other arguments are passed unchanged. The caller's slice is never modified.
func (js *JS) Run(args ...any) (any, error) {
	// Decode into a copy: when called as js.Run(slice...), args aliases the
	// caller's slice, which the previous in-place decoding overwrote.
	callArgs := make([]any, len(args))
	copy(callArgs, args)
	for i, arg := range callArgs {
		if str, ok := arg.(string); ok {
			var v any
			if err := json.Unmarshal([]byte(str), &v); err == nil {
				callArgs[i] = v
			}
		}
	}
	if err := js.vm.Set("__args", callArgs); err != nil {
		return nil, err
	}
	jsResult, err := js.vm.RunScript(js.file, "main(...__args)")
	if err != nil {
		return nil, err
	}
	if jsResult == nil || jsResult.Equals(goja.Undefined()) {
		return nil, nil
	}
	return jsResult.Export(), nil
}
// Exports is the data ExportForDev feeds to the embedded lib/ai.ts template.
type Exports struct {
	LLMList []string // names of the configured LLMs
}
// ExportForDev writes TypeScript helper files for the built-in modules into
// ./lib and returns the import statements a script can use to reference them.
// It writes:
//   - lib/ai.ts, rendered from the embedded template with the configured LLM names;
//   - lib/console.ts and lib/file.ts, copied verbatim.
//
// It fails when no LLM is configured and none of env.yml, env.json, llm.yml or
// llm.json exists in the working directory.
func ExportForDev() (string, error) {
	ai.Init()
	llms := llm.List()
	if len(llms) == 0 && !u.FileExists("env.yml") && !u.FileExists("env.json") && !u.FileExists("llm.yml") && !u.FileExists("llm.json") {
		return "", errors.New("no llm config found, please run `ai -e` on env.yml or llm.yml path")
	}

	exports := Exports{LLMList: make([]string, 0, len(llms))}
	for name := range llms {
		exports.LLMList = append(exports.LLMList, name)
	}
	// Map iteration order is random; sort so the generated ai.ts and the
	// returned import line are identical between runs.
	sort.Strings(exports.LLMList)

	exportFile := filepath.Join("lib", "ai.ts")
	tpl, err := template.New(exportFile).Parse(aiTS)
	if err != nil {
		return "", err
	}
	var buf bytes.Buffer
	if err := tpl.Execute(&buf, exports); err != nil {
		return "", err
	}
	if err := u.WriteFileBytes(exportFile, buf.Bytes()); err != nil {
		return "", err
	}

	// Best-effort: write failures for these two files are deliberately ignored.
	// NOTE(review): the returned imports reference these files, so consider
	// propagating the errors.
	_ = u.WriteFile(filepath.Join("lib", "console.ts"), consoleTS)
	_ = u.WriteFile(filepath.Join("lib", "file.ts"), fileTS)

	return `import {` + strings.Join(exports.LLMList, ", ") + `} from './lib/ai'
import console from './lib/console'
import file from './lib/file'`, nil
}