ai/gojs.go

283 lines
6.2 KiB
Go

package ai
import (
"bytes"
_ "embed"
"encoding/binary"
"html/template"
"math"
"strings"
"apigo.cc/gojs"
"apigo.cc/gojs/goja"
"github.com/ssgo/config"
)
// exportTS is the content of export.ts, embedded at build time. makeTsCode
// parses it as a template to produce the module's TypeScript code.
//
//go:embed export.ts
var exportTS string
// readmeMD is the content of README.md, embedded at build time. init passes
// it to gojs.Register as the module's Example.
//
//go:embed README.md
var readmeMD string
// init registers this package with gojs under the module name "apigo.cc/ai",
// supplying the JavaScript object maker, the TypeScript code maker, the
// embedded README as the example text and the SetSSKey hook.
func init() {
	module := gojs.Module{
		Desc:        "ai plugin for gojs(http://apigo.cc/gojs)",
		Example:     readmeMD,
		ObjectMaker: makeJsObj,
		TsCodeMaker: makeTsCode,
		SetSSKey:    SetSSKey,
	}
	gojs.Register("apigo.cc/ai", module)
}
// aiList holds the AI configurations read from the "ai" config section,
// keyed by AI name. It is nil until makeAIList has run.
var aiList map[string]*AILoadConfig

// makeAIList fills aiList the first time it is called; once aiList is
// non-nil, further calls return immediately.
//
// For every loaded entry it:
//   - defaults Agent to the entry's own name when Agent is empty;
//   - decrypts a non-empty ApiKey into the unexported apiKey field and then
//     blanks the exported ApiKey;
//   - if the named agent exists in agents, copies each of the agent's
//     Chat/Embedding/Image/Video/Edit/Scan/Asr/Tts configurations into the
//     entry under the same name, but only where the entry has no (or a nil)
//     value of its own.
func makeAIList() {
	if aiList != nil {
		return
	}
	aiList = map[string]*AILoadConfig{}
	// The error returned by LoadConfig is discarded.
	_ = config.LoadConfig("ai", &aiList)

	for name, entry := range aiList {
		if entry.Agent == "" {
			entry.Agent = name
		}
		if entry.ApiKey != "" {
			entry.apiKey = confAes.DecryptUrlBase64ToString(entry.ApiKey)
			entry.ApiKey = ""
		}

		agent := agents[entry.Agent]
		if agent == nil {
			// Unknown agent: nothing to inherit from.
			continue
		}

		// Chat defaults.
		if entry.Chat == nil {
			entry.Chat = map[string]*ChatConfig{}
		}
		for key, def := range agent.ChatConfigs {
			if entry.Chat[key] != nil {
				continue
			}
			entry.Chat[key] = def
		}

		// Embedding defaults.
		if entry.Embedding == nil {
			entry.Embedding = map[string]*EmbeddingConfig{}
		}
		for key, def := range agent.EmbeddingConfigs {
			if entry.Embedding[key] != nil {
				continue
			}
			entry.Embedding[key] = def
		}

		// Image defaults.
		if entry.Image == nil {
			entry.Image = map[string]*ImageConfig{}
		}
		for key, def := range agent.ImageConfigs {
			if entry.Image[key] != nil {
				continue
			}
			entry.Image[key] = def
		}

		// Video defaults.
		if entry.Video == nil {
			entry.Video = map[string]*VideoConfig{}
		}
		for key, def := range agent.VideoConfigs {
			if entry.Video[key] != nil {
				continue
			}
			entry.Video[key] = def
		}

		// Edit defaults.
		if entry.Edit == nil {
			entry.Edit = map[string]*map[string]any{}
		}
		for key, def := range agent.EditConfigs {
			if entry.Edit[key] != nil {
				continue
			}
			entry.Edit[key] = def
		}

		// Scan defaults.
		if entry.Scan == nil {
			entry.Scan = map[string]*map[string]any{}
		}
		for key, def := range agent.ScanConfigs {
			if entry.Scan[key] != nil {
				continue
			}
			entry.Scan[key] = def
		}

		// Asr defaults.
		if entry.Asr == nil {
			entry.Asr = map[string]*AsrConfig{}
		}
		for key, def := range agent.AsrConfigs {
			if entry.Asr[key] != nil {
				continue
			}
			entry.Asr[key] = def
		}

		// Tts defaults.
		if entry.Tts == nil {
			entry.Tts = map[string]*TtsConfig{}
		}
		for key, def := range agent.TtsConfigs {
			if entry.Tts[key] != nil {
				continue
			}
			entry.Tts[key] = def
		}
	}
}
// makeTsCode renders the embedded export.ts template against aiList and
// returns the resulting text.
//
// Every "//----" marker is removed from the template source before it is
// parsed. If parsing or execution fails, the error message is written with
// println and an empty string is returned.
//
// NOTE(review): the template package in scope is html/template, which applies
// HTML-aware escaping to interpolated values — confirm that is what is wanted
// for TypeScript output.
func makeTsCode() string {
	makeAIList()

	source := strings.ReplaceAll(exportTS, "//----", "")
	tpl, err := template.New("").Parse(source)
	if err != nil {
		println(err.Error())
		return ""
	}

	var rendered bytes.Buffer
	if err = tpl.Execute(&rendered, aiList); err != nil {
		println(err.Error())
		return ""
	}
	return rendered.String()
}
// jsObj caches the map built by makeJsObj so that it is constructed once.
var jsObj gojs.Map

// makeJsObj returns the object this module exposes to JavaScript, building it
// on the first call and returning the cached map afterwards. The vm argument
// is not used.
//
// The map has one entry per configured AI whose agent is present in agents;
// each entry maps a configuration name to the bound agentObj method for that
// configuration kind. Kinds are processed in the order Chat, Embedding,
// Image, Video, Edit, Scan, Asr, Tts, so a name reused by a later kind
// replaces the earlier binding. The top-level "similarity" key is always set.
func makeJsObj(vm *goja.Runtime) gojs.Map {
	if jsObj != nil {
		return jsObj
	}

	makeAIList()
	jsObj = make(gojs.Map)

	for aiName, loadConf := range aiList {
		agent := agents[loadConf.Agent]
		if agent == nil {
			continue
		}

		// One AIConfig per AI, shared by every method object created below.
		shared := &AIConfig{
			ApiKey:   loadConf.apiKey,
			Endpoint: loadConf.Endpoint,
			Extra:    loadConf.Extra,
		}
		// newObj returns a fresh agentObj carrying the fields common to all kinds.
		newObj := func() *agentObj {
			return &agentObj{config: shared, agent: agent}
		}
		methods := map[string]any{}

		// Chat methods.
		for key, conf := range loadConf.Chat {
			o := newObj()
			o.chatConfig = conf
			methods[key] = o.Chat
		}

		// Embedding methods.
		for key, conf := range loadConf.Embedding {
			o := newObj()
			o.embeddingConfig = conf
			methods[key] = o.Embedding
		}

		// MakeImage methods.
		for key, conf := range loadConf.Image {
			o := newObj()
			o.imageConfig = conf
			methods[key] = o.MakeImage
		}

		// MakeVideo methods; the first video object seen also provides
		// getVideoResult if that key is still unset.
		for key, conf := range loadConf.Video {
			o := newObj()
			o.videoConfig = conf
			methods[key] = o.MakeVideo
			if methods["getVideoResult"] == nil {
				methods["getVideoResult"] = o.GetVideoResult
			}
		}

		// Edit methods.
		for key, conf := range loadConf.Edit {
			o := newObj()
			o.editConfig = conf
			methods[key] = o.Edit
		}

		// Scan methods.
		for key, conf := range loadConf.Scan {
			o := newObj()
			o.scanConfig = conf
			methods[key] = o.Scan
		}

		// Asr methods.
		for key, conf := range loadConf.Asr {
			o := newObj()
			o.asrConfig = conf
			methods[key] = o.Asr
		}

		// Tts methods.
		for key, conf := range loadConf.Tts {
			o := newObj()
			o.ttsConfig = conf
			methods[key] = o.Tts
		}

		jsObj[aiName] = methods
	}

	// The similarity function is exposed at the top level.
	jsObj["similarity"] = similarity
	return jsObj
}
// similarity is the JavaScript-facing wrapper around makeSimilarity: it reads
// the first two call arguments as byte slices and returns their similarity
// score converted to a goja value.
//
// NOTE(review): Check(2) presumably enforces that two arguments were passed —
// confirm against gojs.
func similarity(argsIn goja.FunctionCall, vm *goja.Runtime) goja.Value {
	args := gojs.MakeArgs(&argsIn, vm).Check(2)
	first, second := args.Bytes(0), args.Bytes(1)
	score := makeSimilarity(first, second)
	return vm.ToValue(score)
}
// makeSimilarity returns the cosine similarity of the two vectors encoded in
// buf1 and buf2, each decoded with bin2float64.
//
// It returns 0 when the decoded vectors differ in length or when either one
// has zero magnitude (which includes empty input).
func makeSimilarity(buf1, buf2 []byte) float64 {
	vecA, vecB := bin2float64(buf1), bin2float64(buf2)
	if len(vecA) != len(vecB) {
		return 0
	}

	var dot, sumSqA, sumSqB float64
	for i, x := range vecA {
		y := vecB[i]
		dot += x * y
		sumSqA += x * x
		sumSqB += y * y
	}

	normA := math.Sqrt(sumSqA)
	normB := math.Sqrt(sumSqB)
	if normA == 0 || normB == 0 {
		// Avoid dividing by zero for an all-zero or empty vector.
		return 0
	}
	return dot / (normA * normB)
}
// bin2float64 decodes in as a packed sequence of little-endian 32-bit floats
// and widens each value to float64.
//
// The result has len(in)/4 elements; up to three trailing bytes that do not
// form a complete value are ignored. Nil or too-short input yields an empty,
// non-nil slice.
func bin2float64(in []byte) []float64 {
	out := make([]float64, len(in)/4)
	for i := range out {
		// Decode straight from the slice: this needs no intermediate reader
		// and has no error to discard. in[i*4:] always holds at least four
		// bytes because i < len(in)/4.
		bits := binary.LittleEndian.Uint32(in[i*4:])
		out[i] = float64(math.Float32frombits(bits))
	}
	return out
}