ai_old/js.go
Star e69b9a3a12 add db、log support
support embedding for llm
support watch run for js
many other updates
2024-09-29 21:20:28 +08:00

429 lines
10 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package ai
import (
	"bytes"
	_ "embed"
	"encoding/json"
	"errors"
	"fmt"
	"os"
	"os/signal"
	"path/filepath"
	"regexp"
	"sort"
	"strings"
	"syscall"
	"text/template"
	"time"

	"apigo.cc/ai/ai/goja"
	"apigo.cc/ai/ai/goja_nodejs/require"
	"apigo.cc/ai/ai/interface/llm"
	"apigo.cc/ai/ai/js"
	"apigo.cc/ai/ai/watcher"
	"github.com/ssgo/log"
	"github.com/ssgo/u"
)
// TypeScript sources for the builtin script modules, embedded at build time.
// ExportForDev writes them into the lib directory: ai.ts is parsed as a
// text/template and rendered with the configured LLM names, the others are
// written out verbatim.
//go:embed js/lib/ai.ts
var aiTS string
//go:embed js/lib/console.ts
var consoleTS string
//go:embed js/lib/file.ts
var fileTS string
//go:embed js/lib/util.ts
var utilTS string
//go:embed js/lib/http.ts
var httpTS string
//go:embed js/lib/log.ts
var logTS string
//go:embed js/lib/db.ts
var dbTS string
func RunFile(file string, args ...any) (any, error) {
return Run(u.ReadFileN(file), file, args...)
}
// Run executes code in a brand-new Runtime. refFile names the script for
// error positions and module resolution. The script's top-level code is
// started first; if that succeeds, main is invoked with args and its exported
// result is returned.
func Run(code string, refFile string, args ...any) (any, error) {
	rt := New()
	if _, err := rt.StartFromCode(code, refFile); err != nil {
		return nil, err
	}
	return rt.RunMain(args...)
}
// Patterns used by the runtime to rewrite script source before execution.
var (
	// importModMatcher captures the binding (group 1) and module path
	// (group 2) of an ES `import X from 'path'` statement.
	importModMatcher = regexp.MustCompile(`(?im)^\s*import\s+(.+?)\s+from\s+['"](.+?)['"]`)
	// importLibMatcher captures the keyword, the binding and the module name
	// of an `import X from '.../lib/name[.ts]'` statement.
	importLibMatcher = regexp.MustCompile(`(?im)^\s*(import)\s+(.+?)\s+from\s+['"][./\\\w:]+lib[/\\](.+?)(\.ts)?['"]`)
	// requireLibMatcher captures the declaration keyword, the binding and the
	// module name of a `const|let|var X = require('.../lib/name[.ts]')`.
	requireLibMatcher = regexp.MustCompile(`(?im)^\s*(const|let|var)\s+(.+?)\s*=\s*require\s*\(\s*['"][./\\\w:]+lib[/\\](.+?)(\.ts)?['"]\s*\)`)
	// checkMainMatcher reports whether the script declares `function main(`.
	checkMainMatcher = regexp.MustCompile(`(?im)^\s*function\s+main\s*\(`)
)
// Runtime is a single script execution environment built around a goja VM.
//
// Fields:
//   - vm: the JavaScript VM the script runs in.
//   - required: names of builtin modules already installed as VM globals.
//   - file: path of the running script (made absolute when possible), used
//     for error positions and for resolving relative module paths.
//   - srcCode: the first source handed to StartFromCode, before rewriting.
//   - code: the most recently rewritten source that was executed.
//   - moduleLoader: optional provider of module source; nil means read from disk.
type Runtime struct {
	vm                  *goja.Runtime
	required            map[string]bool
	file, srcCode, code string
	moduleLoader        func(filename string) string
}
// SetModuleLoader registers fn as the source provider for modules loaded via
// require(). fn receives the resolved module file path. If it returns an
// empty string, the runtime falls back to reading that file from disk.
func (rt *Runtime) SetModuleLoader(fn func(filename string) string) {
	rt.moduleLoader = fn
}
// GetCallStack returns the source positions of the frames the VM currently
// reports on its call stack, each formatted as a string. The returned slice
// is never nil, even when the stack is empty.
func (rt *Runtime) GetCallStack() []string {
	frames := rt.vm.CaptureCallStack(0, nil)
	positions := make([]string, len(frames))
	for i := range frames {
		positions[i] = frames[i].Position().String()
	}
	return positions
}
// requireMod installs builtin script modules as VM globals named after the
// module.
//
// An empty name installs every builtin module, in the order console, file,
// util, http, log, db, ai. Any other name installs only the matching module;
// unrecognised names are ignored.
//
// Each module is installed at most once per Runtime. It is marked as required
// before it is set, so a failed installation is not retried. The first
// installation error stops processing and is returned.
func (rt *Runtime) requireMod(name string) error {
	builtins := [...]struct {
		global string
		load   func() any
	}{
		{"console", func() any { return js.RequireConsole() }},
		{"file", func() any { return js.RequireFile() }},
		{"util", func() any { return js.RequireUtil() }},
		{"http", func() any { return js.RequireHTTP() }},
		{"log", func() any { return js.RequireLog() }},
		{"db", func() any { return js.RequireDB() }},
		{"ai", func() any { return js.RequireAI() }},
	}
	for _, mod := range builtins {
		wanted := name == "" || name == mod.global
		if !wanted || rt.required[mod.global] {
			continue
		}
		rt.required[mod.global] = true
		if err := rt.vm.Set(mod.global, mod.load()); err != nil {
			return err
		}
	}
	return nil
}
// makeImport rewrites every statement in code matched by matcher
// (importLibMatcher or requireLibMatcher). Those statements reference a
// builtin module through a `lib/` path.
//
// For each match it installs the builtin module through requireMod and
// replaces the statement with a plain binding to the module's global:
//   - `let x = mod` for import forms;
//   - the original const/let/var keyword for require forms;
//   - nothing when the binding name equals the module name, because the
//     global already has that name.
//
// The pattern's leading `^\s*` can swallow the newlines of preceding blank
// lines. That whitespace is kept in the output so line numbers reported by
// the VM still match the source file.
//
// It returns the rewritten code, the number of matches, and the first
// requireMod error. Later matches are still rewritten after an error.
func (rt *Runtime) makeImport(matcher *regexp.Regexp, code string) (string, int, error) {
	var modErr error
	importCount := 0
	code = matcher.ReplaceAllStringFunc(code, func(stmt string) string {
		// " \t\n\f\r" is exactly RE2's \s class used by the patterns.
		body := strings.TrimLeft(stmt, " \t\n\f\r")
		lead := stmt[:len(stmt)-len(body)]
		m := matcher.FindStringSubmatch(stmt)
		if len(m) <= 3 {
			return lead
		}
		optName := m[1]
		if optName == "import" {
			optName = "let"
		}
		varName := m[2]
		modName := m[3]
		importCount++
		if modErr == nil {
			if err := rt.requireMod(modName); err != nil {
				modErr = err
			}
		}
		if varName == modName {
			return lead
		}
		return lead + fmt.Sprintf("%s %s = %s", optName, varName, modName)
	})
	return code, importCount, modErr
}
// New creates a Runtime backed by a fresh goja VM with no builtin modules
// installed yet. The VM's GoData starts with a single "logger" entry created
// by log.New(u.ShortUniqueId()).
func New() *Runtime {
	rt := &Runtime{
		vm:       goja.New(),
		required: make(map[string]bool),
	}
	rt.vm.GoData = map[string]any{
		"logger": log.New(u.ShortUniqueId()),
	}
	return rt
}
// StartFromFile reads file and passes its contents to StartFromCode, with
// file as the reference path.
//
// A read failure is returned as an error instead of starting an empty script.
func (rt *Runtime) StartFromFile(file string) (any, error) {
	code, err := u.ReadFile(file)
	if err != nil {
		return nil, err
	}
	return rt.StartFromCode(code, file)
}
// StartFromCode prepares code for execution and runs its top-level statements.
//
// refFile names the script. When it is empty, the previously set file is
// reused, or "main.js" if none was set. The steps are:
//  1. Make rt.file absolute and initialise from the script's directory
//     (also stored in GoData["startPath"]).
//  2. Install the builtin modules referenced through `lib/` imports or
//     requires. If the script references none, install all of them.
//  3. Convert the remaining ES `import` statements into require() calls.
//  4. Register a require() loader that applies the same rewriting to modules.
//  5. Wrap the code in `function main(...args){...}` unless it declares main.
//  6. Run the script and return its completion value.
func (rt *Runtime) StartFromCode(code, refFile string) (any, error) {
	if refFile != "" {
		rt.file = refFile
	}
	if rt.file == "" {
		rt.file = "main.js"
	}
	if absFile, err := filepath.Abs(rt.file); err == nil {
		rt.file = absFile
	}
	// Derive the directory from the resolved script path, not the raw refFile
	// argument, so an empty or relative refFile still yields the real directory.
	refPath := filepath.Dir(rt.file)
	rt.vm.GoData["startPath"] = refPath
	InitFrom(refPath)
	if rt.srcCode == "" {
		rt.srcCode = code
	}

	var importCount int
	var err error
	rt.code, importCount, err = rt.rewriteScriptImports(code)
	// A script that references no builtin module gets all of them.
	if err == nil && importCount == 0 {
		err = rt.requireMod("")
	}
	if err != nil {
		return nil, err
	}

	require.NewRegistryWithLoader(rt.loadScriptModule).Enable(rt.vm)

	if !checkMainMatcher.MatchString(rt.code) {
		// Kept on one line so source line numbers are unchanged.
		rt.code = "function main(...args){" + rt.code + "}"
	}
	r, err := rt.vm.RunScript(rt.file, rt.code)
	if err != nil {
		return nil, err
	}
	return r, nil
}

// rewriteScriptImports installs the builtin modules referenced by `lib/`
// imports and requires in code, rewriting those statements. It then converts
// every other ES `import X from 'path'` into `let X = require('path')`.
//
// It returns the rewritten code, the number of builtin references found, and
// the first module-installation error. The require pass is skipped after an
// error in the import pass. Whitespace swallowed by the patterns' leading
// `^\s*` is preserved so VM line numbers match the source.
func (rt *Runtime) rewriteScriptImports(code string) (string, int, error) {
	code, count, err := rt.makeImport(importLibMatcher, code)
	if err == nil {
		var requireCount int
		code, requireCount, err = rt.makeImport(requireLibMatcher, code)
		count += requireCount
	}
	code = importModMatcher.ReplaceAllStringFunc(code, func(stmt string) string {
		m := importModMatcher.FindStringSubmatch(stmt)
		if m == nil {
			return stmt
		}
		// " \t\n\f\r" is exactly RE2's \s class.
		body := strings.TrimLeft(stmt, " \t\n\f\r")
		return stmt[:len(stmt)-len(body)] + "let " + m[1] + " = require('" + m[2] + "')"
	})
	return code, count, err
}

// loadScriptModule is the source loader given to the require() registry.
//
// Resolution and loading:
//   - Relative paths resolve against the running script's directory.
//   - ".js" is appended when the path lacks it and does not exist as given.
//   - Source comes from rt.moduleLoader when set and non-empty, otherwise
//     from disk.
//
// The source then receives the same import rewriting as the main script.
// Rewrite errors are returned to require() instead of being discarded.
func (rt *Runtime) loadScriptModule(path string) ([]byte, error) {
	modFile := path
	if !filepath.IsAbs(modFile) {
		modFile = filepath.Join(filepath.Dir(rt.file), modFile)
	}
	if !strings.HasSuffix(modFile, ".js") && !u.FileExists(modFile) {
		modFile += ".js"
	}
	modCode := ""
	if rt.moduleLoader != nil {
		modCode = rt.moduleLoader(modFile)
	}
	if modCode == "" {
		var err error
		if modCode, err = u.ReadFile(modFile); err != nil {
			return nil, err
		}
	}
	modCode, _, err := rt.rewriteScriptImports(modCode)
	if err != nil {
		return nil, err
	}
	return []byte(modCode), nil
}
// RunMain calls the script's main function with args. It returns main's
// result exported to a Go value, or nil when main returns undefined or the
// call fails.
//
// String arguments that hold valid JSON are decoded first, so "123" arrives
// as a number and `{"a":1}` as an object. Other values are passed unchanged.
// The caller's args slice is not modified.
func (rt *Runtime) RunMain(args ...any) (any, error) {
	// Decode into a copy. Callers such as WatchRun pass the same slice on
	// every run, and decoding in place would re-decode already-decoded
	// strings (e.g. `"\"123\""` -> "123" -> 123). A nil/empty args is
	// passed through as-is.
	mainArgs := args
	if len(args) > 0 {
		mainArgs = make([]any, len(args))
		copy(mainArgs, args)
	}
	for i, arg := range mainArgs {
		if str, ok := arg.(string); ok {
			var v any
			if err := json.Unmarshal([]byte(str), &v); err == nil {
				mainArgs[i] = v
			}
		}
	}
	if err := rt.vm.Set("__args", mainArgs); err != nil {
		return nil, err
	}
	jsResult, err := rt.vm.RunScript("main", "main(...__args)")
	if err != nil {
		return nil, err
	}
	if jsResult == nil || jsResult.Equals(goja.Undefined()) {
		return nil, nil
	}
	return jsResult.Export(), nil
}
// RunCode evaluates code in the runtime's existing VM, attributing positions
// to the current script file. It returns the completion value exported to a
// Go value, or nil when the value is undefined or evaluation fails.
func (rt *Runtime) RunCode(code string) (any, error) {
	value, err := rt.vm.RunScript(rt.file, code)
	switch {
	case err != nil:
		return nil, err
	case value == nil, value.Equals(goja.Undefined()):
		return nil, nil
	default:
		return value.Export(), nil
	}
}
// Exports is the data passed to the embedded ai.ts template by ExportForDev.
type Exports struct {
	// LLMList holds the names (keys of llm.List()) of the configured LLMs,
	// sorted by ExportForDev.
	LLMList []string
}

// ExportForDev writes TypeScript sources for the builtin script modules into
// ./lib:
//   - ai.ts is rendered from its template with the configured LLM names;
//   - console, file, util, http, log and db are written verbatim.
//
// It returns the import statements a script can use to reference them.
//
// It fails when no LLM is configured and none of env.yml, env.json, llm.yml
// or llm.json exists, or when any file cannot be rendered or written.
func ExportForDev() (string, error) {
	Init()
	llms := llm.List()
	if len(llms) == 0 && !u.FileExists("env.yml") && !u.FileExists("env.json") && !u.FileExists("llm.yml") && !u.FileExists("llm.json") {
		return "", errors.New("no llm config found, please run `ai -e` on env.yml or llm.yml path")
	}
	exports := Exports{}
	for name := range llms {
		exports.LLMList = append(exports.LLMList, name)
	}
	// Map iteration order is random. Sort so the generated ai.ts and the
	// returned import line are identical between runs.
	sort.Strings(exports.LLMList)

	exportFile := filepath.Join("lib", "ai.ts")
	tpl, err := template.New(exportFile).Parse(aiTS)
	if err != nil {
		return "", err
	}
	var buf bytes.Buffer
	if err := tpl.Execute(&buf, exports); err != nil {
		return "", err
	}
	if err := u.WriteFileBytes(exportFile, buf.Bytes()); err != nil {
		return "", err
	}

	staticFiles := []struct{ name, content string }{
		{"console.ts", consoleTS},
		{"file.ts", fileTS},
		{"util.ts", utilTS},
		{"http.ts", httpTS},
		{"log.ts", logTS},
		{"db.ts", dbTS},
	}
	for _, f := range staticFiles {
		if err := u.WriteFile(filepath.Join("lib", f.name), f.content); err != nil {
			return "", err
		}
	}
	return `import {` + strings.Join(exports.LLMList, ", ") + `} from './lib/ai'
import console from './lib/console'
import log from './lib/log'
import util from './lib/util'
import http from './lib/http'
import db from './lib/db'
import file from './lib/file'`, nil
}
// WatchRunner is the handle returned by WatchRun. It owns the file watcher
// that triggers re-runs.
type WatchRunner struct {
	w *watcher.Watcher
}

// WaitForKill blocks until the process receives an interrupt, SIGTERM, SIGINT
// or SIGHUP, then stops the watcher.
//
// Signal notification is released before stopping, so a second signal gets
// the default behaviour (e.g. terminating the process) while shutting down.
func (wr *WatchRunner) WaitForKill() {
	exitCh := make(chan os.Signal, 1)
	signal.Notify(exitCh, os.Interrupt, syscall.SIGTERM, syscall.SIGINT, syscall.SIGHUP)
	<-exitCh
	signal.Stop(exitCh)
	wr.Stop()
}

// Stop stops the watcher. It is a no-op when no watcher is attached, for
// example on a zero-value WatchRunner.
func (wr *WatchRunner) Stop() {
	if wr.w != nil {
		wr.w.Stop()
	}
}
// WatchRun clears the terminal and watches for changes:
//   - directories: the one containing file, plus extDirs;
//   - file types: js, json and yml, plus extTypes (presumably extensions —
//     confirm against the watcher package).
//
// It runs the script once in the background. Each change clears the screen
// and, after a 10ms delay, re-runs the script in a fresh Runtime. Changes
// arriving while a re-run is pending are coalesced into it. Directories of
// modules the script loads are added to the watch list.
//
// Script errors (from starting or from main) are printed, not returned.
// Only a failure to start the watcher is returned.
func WatchRun(file string, extDirs, extTypes []string, args ...any) (*WatchRunner, error) {
	wr := &WatchRunner{}
	run := func() {
		rt := New()
		if wr.w != nil {
			rt.SetModuleLoader(func(filename string) string {
				filePath := filepath.Dir(filename)
				needWatch := true
				for _, v := range wr.w.WatchList() {
					if v == filePath {
						needWatch = false
						break
					}
				}
				if needWatch {
					fmt.Println(u.BMagenta("[watching module path]"), filePath)
					_ = wr.w.Add(filePath)
				}
				return u.ReadFileN(filename)
			})
		}
		report := func(err error) {
			fmt.Println(u.BRed(err.Error()))
			fmt.Println(u.Red(" " + strings.Join(rt.GetCallStack(), "\n ")))
		}
		// Report start-up failures (e.g. syntax errors) directly. Previously
		// this error was overwritten by RunMain's.
		if _, err := rt.StartFromFile(file); err != nil {
			report(err)
			return
		}
		result, err := rt.RunMain(args...)
		if err != nil {
			report(err)
		} else if result != nil {
			fmt.Println(u.Cyan(u.JsonP(result)))
		}
	}

	// pending holds a token while a re-run is scheduled. The watcher callback
	// and the run goroutine exchange it through the channel, so no unguarded
	// shared flag is involved.
	pending := make(chan struct{}, 1)
	onChange := func(filename string, event string) {
		select {
		case pending <- struct{}{}:
			_, _ = os.Stdout.WriteString("\x1b[3;J\x1b[H\x1b[2J")
			go func() {
				time.Sleep(time.Millisecond * 10)
				<-pending
				run()
			}()
		default:
			// A re-run is already scheduled; it will pick up this change.
		}
		fmt.Println(u.BYellow("[changed]"), filename)
	}

	_, _ = os.Stdout.WriteString("\x1b[3;J\x1b[H\x1b[2J")
	watchStartPath := filepath.Dir(file)
	fmt.Println(u.BMagenta("[watching root path]"), watchStartPath)
	watchDirs := append([]string{watchStartPath}, extDirs...)
	watchTypes := append([]string{"js", "json", "yml"}, extTypes...)

	w, err := watcher.Start(watchDirs, watchTypes, onChange)
	if err != nil {
		return nil, err
	}
	wr.w = w
	go run()
	return wr, nil
}