ai_old/js.go
2024-09-20 16:50:35 +08:00

293 lines
7.0 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package ai
import (
"apigo.cc/ai/ai/js"
"apigo.cc/ai/ai/llm"
"bytes"
_ "embed"
"encoding/json"
"errors"
"fmt"
"github.com/dop251/goja"
"github.com/dop251/goja_nodejs/require"
"github.com/ssgo/u"
"path/filepath"
"regexp"
"strings"
"text/template"
)
// aiTS is the embedded content of js/lib/ai.ts. ExportForDev parses it as a
// text/template and renders the result into lib/ai.ts.
//go:embed js/lib/ai.ts
var aiTS string
// consoleTS is the embedded content of js/lib/console.ts, written verbatim to
// lib/console.ts by ExportForDev.
//go:embed js/lib/console.ts
var consoleTS string
// fileTS is the embedded content of js/lib/file.ts, written verbatim to
// lib/file.ts by ExportForDev.
//go:embed js/lib/file.ts
var fileTS string
// utilTS is the embedded content of js/lib/util.ts, written verbatim to
// lib/util.ts by ExportForDev.
//go:embed js/lib/util.ts
var utilTS string
// httpTS is the embedded content of js/lib/http.ts, written verbatim to
// lib/http.ts by ExportForDev.
//go:embed js/lib/http.ts
var httpTS string
func RunFile(file string, args ...any) (any, error) {
return Run(u.ReadFileN(file), file, args...)
}
// Run executes code in a brand-new Runtime: the code is first loaded with
// StartFromCode (refFile is the file name the code is attributed to), then its
// main function is invoked with args. The value returned by main is returned;
// if loading fails, that error is returned and main is never called.
func Run(code string, refFile string, args ...any) (any, error) {
	rt := New()
	if _, err := rt.StartFromCode(code, refFile); err != nil {
		return nil, err
	}
	return rt.RunMain(args...)
}
// Matchers used to rewrite script source before it is handed to the JS engine.
// All of them are case-insensitive and anchored at the start of a line.
var (
	// importModMatcher matches `import X from 'path'`.
	// Groups: 1 = imported binding, 2 = module path.
	importModMatcher = regexp.MustCompile(`(?im)^\s*import\s+(.+?)\s+from\s+['"](.+?)['"]`)

	// importLibMatcher matches `import X from '<prefix>lib/<name>[.ts]'`.
	// Groups: 1 = the keyword "import", 2 = imported binding, 3 = library name,
	// 4 = optional ".ts" suffix.
	importLibMatcher = regexp.MustCompile(`(?im)^\s*(import)\s+(.+?)\s+from\s+['"][./\\\w:]+lib[/\\](.+?)(\.ts)?['"]`)

	// requireLibMatcher matches `const|let|var X = require('<prefix>lib/<name>[.ts]')`.
	// Groups: 1 = declaration keyword, 2 = binding, 3 = library name,
	// 4 = optional ".ts" suffix.
	requireLibMatcher = regexp.MustCompile(`(?im)^\s*(const|let|var)\s+(.+?)\s*=\s*require\s*\(\s*['"][./\\\w:]+lib[/\\](.+?)(\.ts)?['"]\s*\)`)

	// checkMainMatcher reports whether the source declares `function main(`.
	checkMainMatcher = regexp.MustCompile(`(?im)^\s*function\s+main\s*\(`)
)
// Runtime wraps a goja JavaScript VM together with the script it runs and the
// bookkeeping needed to expose the built-in libraries to that script.
type Runtime struct {
// vm is the underlying goja interpreter.
vm *goja.Runtime
// required records which built-in modules (console, file, util, http, ai)
// have already been installed as globals on vm, so each is set only once.
required map[string]bool
// file is the script file name; StartFromCode makes it absolute and uses
// its directory to resolve module paths.
file string
// srcCode keeps the first code passed to StartFromCode, before rewriting.
srcCode string
// code is the rewritten source that is actually executed.
code string
// moduleLoader, when set, is asked for a module's source by resolved path;
// an empty result makes the loader fall back to reading the file from disk.
moduleLoader func(string) string
}
// SetModuleLoader installs fn as the custom source provider for required
// modules. StartFromCode calls fn with the resolved module path; when fn
// returns an empty string the module is read from disk instead.
func (rt *Runtime) SetModuleLoader(fn func(filename string) string) {
rt.moduleLoader = fn
}
// requireMod installs the built-in module called name as a global on the VM.
// An empty name installs every built-in module, in the order console, file,
// util, http, ai. A module that was already installed is skipped, and a name
// that matches no built-in module is a no-op. The first error from the VM
// stops processing and is returned.
func (rt *Runtime) requireMod(name string) error {
	// Each entry builds its value lazily so that nothing is constructed for
	// modules that are skipped.
	builtins := []struct {
		mod   string
		build func() any
	}{
		{"console", func() any { return js.RequireConsole() }},
		{"file", func() any { return js.RequireFile() }},
		{"util", func() any { return js.RequireUtil() }},
		{"http", func() any { return js.RequireHTTP() }},
		{"ai", func() any {
			// The ai global maps every configured LLM name to its JS binding.
			aiList := make(map[string]any)
			for llmName, lm := range llm.List() {
				aiList[llmName] = js.RequireAI(lm)
			}
			return aiList
		}},
	}

	for _, b := range builtins {
		if name != "" && name != b.mod {
			continue
		}
		if rt.required[b.mod] {
			continue
		}
		// Marked before Set, so a failed Set is not retried later.
		rt.required[b.mod] = true
		if err := rt.vm.Set(b.mod, b.build()); err != nil {
			return err
		}
	}
	return nil
}
// makeImport rewrites every statement in code that matcher recognises as a
// built-in library import. For each match the referenced module is installed
// via requireMod, and the statement is replaced by a plain alias declaration
// (`<keyword> <binding> = <module>`, with "import" turned into "let"), or
// removed entirely when the binding already has the module's name.
//
// It returns the rewritten code, the number of matched statements, and the
// first error reported by requireMod (later modules are not installed once an
// error has occurred, but rewriting still continues).
func (rt *Runtime) makeImport(matcher *regexp.Regexp, code string) (string, int, error) {
	var (
		firstErr error
		matched  int
	)

	rewrite := func(stmt string) string {
		groups := matcher.FindStringSubmatch(stmt)
		if len(groups) <= 3 {
			return ""
		}
		keyword, binding, module := groups[1], groups[2], groups[3]
		if keyword == "import" {
			keyword = "let"
		}

		matched++
		if firstErr == nil {
			firstErr = rt.requireMod(module)
		}

		if binding == module {
			// The global already carries this name; no alias is needed.
			return ""
		}
		return fmt.Sprintf("%s %s = %s", keyword, binding, module)
	}

	// Run the replacement first so the counters are final before they are read.
	rewritten := matcher.ReplaceAllStringFunc(code, rewrite)
	return rewritten, matched, firstErr
}
// New creates a Runtime backed by a fresh goja VM with no built-in modules
// installed yet.
func New() *Runtime {
	rt := new(Runtime)
	rt.vm = goja.New()
	rt.required = make(map[string]bool)
	return rt
}
// StartFromFile reads the script stored in file and loads it with
// StartFromCode, using file as the reference file name.
//
// A failure to read the file is returned to the caller instead of being
// silently treated as an empty script, matching how the module loader in
// StartFromCode handles read errors.
func (rt *Runtime) StartFromFile(file string) (any, error) {
	code, err := u.ReadFile(file)
	if err != nil {
		return nil, err
	}
	return rt.StartFromCode(code, file)
}
// StartFromCode prepares code for execution and evaluates it in the VM.
//
// refFile, when non-empty, becomes the script's file name (defaulting to
// "main.js" if none was ever set); it is made absolute and its directory is
// the base for resolving required modules. The code is then rewritten:
//   - imports/requires of the built-in libs (paths containing "lib/") install
//     the matching globals and are turned into simple aliases;
//   - remaining `import X from 'p'` statements become `let X = require('p')`;
//   - if no built-in lib was imported at all, every built-in is installed;
//   - code without a `function main(` declaration is wrapped in one.
//
// It returns the value produced by evaluating the rewritten script, or the
// first error encountered while installing modules or running the script.
func (rt *Runtime) StartFromCode(code, refFile string) (any, error) {
	if refFile != "" {
		rt.file = refFile
	}
	if rt.file == "" {
		rt.file = "main.js"
	}
	if absFile, err := filepath.Abs(rt.file); err == nil {
		rt.file = absFile
	}
	InitFrom(filepath.Dir(refFile))

	if rt.srcCode == "" {
		rt.srcCode = code
	}
	rt.code = code

	// Load the built-in libraries that the script references, on demand.
	var importCount int
	var modErr error
	rt.code, importCount, modErr = rt.makeImport(importLibMatcher, rt.code)
	if modErr == nil {
		importCount1 := importCount
		rt.code, importCount, modErr = rt.makeImport(requireLibMatcher, rt.code)
		importCount += importCount1
	}

	// Convert the remaining import statements into require calls.
	rt.code = importModMatcher.ReplaceAllString(rt.code, "let $1 = require('$2')")

	// With no explicit built-in import, install all built-ins by default.
	if modErr == nil && importCount == 0 {
		modErr = rt.requireMod("")
	}
	if modErr != nil {
		return nil, modErr
	}
	// Debug aid: fmt.Println(u.BCyan(rt.code))

	// Resolve required modules relative to the script's directory.
	require.NewRegistryWithLoader(func(path string) ([]byte, error) {
		refPath := filepath.Join(filepath.Dir(rt.file), path)
		if !strings.HasSuffix(refPath, ".js") && !u.FileExists(refPath) {
			refPath += ".js"
		}

		// Prefer the custom loader; an empty result falls back to disk.
		modCode := ""
		if rt.moduleLoader != nil {
			modCode = rt.moduleLoader(refPath)
		}
		if modCode == "" {
			var err error
			modCode, err = u.ReadFile(refPath)
			if err != nil {
				return nil, err
			}
		}

		// Apply the same built-in lib rewriting to the module, and report a
		// failure to install a built-in instead of dropping it.
		var err error
		if modCode, _, err = rt.makeImport(importLibMatcher, modCode); err != nil {
			return nil, err
		}
		if modCode, _, err = rt.makeImport(requireLibMatcher, modCode); err != nil {
			return nil, err
		}
		return []byte(modCode), nil
	}).Enable(rt.vm)

	// Make sure a main function exists. The newline before the closing brace
	// keeps a trailing `//` comment in the script from swallowing the brace.
	if !checkMainMatcher.MatchString(rt.code) {
		rt.code = "function main(...args){" + rt.code + "\n}"
	}

	r, err := rt.vm.RunScript("main", rt.code)
	if err != nil {
		return nil, err
	}
	return r, nil
}
// RunMain calls the script's main function with args and returns its result
// converted to a Go value (nil when main returns undefined).
//
// String arguments that are valid JSON are decoded first, so `"123"` reaches
// the script as a number and `{"a":1}` as an object; any other string is
// passed through unchanged. The caller's slice is never modified.
func (rt *Runtime) RunMain(args ...any) (any, error) {
	// Work on a copy: decoding in place would overwrite elements of a slice
	// the caller passed with `args...`. The copy is also always non-nil, so
	// the script receives an array even when no arguments are given
	// (NOTE(review): goja appears to export a nil []any as null, which
	// `main(...__args)` cannot spread — confirm against goja's ToValue).
	callArgs := make([]any, len(args))
	for i, arg := range args {
		callArgs[i] = arg
		str, ok := arg.(string)
		if !ok {
			continue
		}
		var v any
		if err := json.Unmarshal([]byte(str), &v); err == nil {
			callArgs[i] = v
		}
	}

	if err := rt.vm.Set("__args", callArgs); err != nil {
		return nil, err
	}
	jsResult, err := rt.vm.RunScript(rt.file, "main(...__args)")
	if err != nil {
		return nil, err
	}
	if jsResult == nil || jsResult.Equals(goja.Undefined()) {
		return nil, nil
	}
	return jsResult.Export(), nil
}
// RunCode evaluates code in the runtime's VM, attributing it to the runtime's
// script file name, and returns the resulting value exported to Go. A missing
// or undefined result is reported as nil.
func (rt *Runtime) RunCode(code string) (any, error) {
	value, err := rt.vm.RunScript(rt.file, code)
	if err != nil {
		return nil, err
	}
	if value == nil || value.Equals(goja.Undefined()) {
		return nil, nil
	}
	return value.Export(), nil
}
// Exports is the data passed to the ai.ts template by ExportForDev.
type Exports struct {
// LLMList holds the names of the LLMs returned by llm.List().
LLMList []string
}
// ExportForDev writes the TypeScript helper files for script development into
// the local "lib" directory and returns the import statements a script should
// start with.
//
// lib/ai.ts is rendered from the embedded template with the names of all
// configured LLMs; console.ts, file.ts, util.ts and http.ts are written
// verbatim. An error is returned when no LLM is configured and none of
// env.yml, env.json, llm.yml or llm.json exists, when the template fails, or
// when any of the files cannot be written.
func ExportForDev() (string, error) {
	Init()
	if len(llm.List()) == 0 && !u.FileExists("env.yml") && !u.FileExists("env.json") && !u.FileExists("llm.yml") && !u.FileExists("llm.json") {
		return "", errors.New("no llm config found, please run `ai -e` on env.yml or llm.yml path")
	}

	exports := Exports{}
	for name := range llm.List() {
		exports.LLMList = append(exports.LLMList, name)
	}

	exportFile := filepath.Join("lib", "ai.ts")
	tpl, err := template.New(exportFile).Parse(aiTS)
	if err != nil {
		return "", err
	}
	var buf bytes.Buffer
	if err = tpl.Execute(&buf, exports); err != nil {
		return "", err
	}
	if err = u.WriteFileBytes(exportFile, buf.Bytes()); err != nil {
		return "", err
	}

	// The returned snippet imports every one of these files, so a failed
	// write is reported rather than ignored.
	staticLibs := []struct {
		name    string
		content string
	}{
		{"console.ts", consoleTS},
		{"file.ts", fileTS},
		{"util.ts", utilTS},
		{"http.ts", httpTS},
	}
	for _, lib := range staticLibs {
		if err = u.WriteFileBytes(filepath.Join("lib", lib.name), []byte(lib.content)); err != nil {
			return "", err
		}
	}

	return `import {` + strings.Join(exports.LLMList, ", ") + `} from './lib/ai'
import console from './lib/console'
import util from './lib/util'
import http from './lib/http'
import file from './lib/file'`, nil
}