zhipu/zhipu/fine_tune.go
2024-09-07 23:14:12 +08:00

457 lines
11 KiB
Go

package zhipu
import (
"context"
"strconv"
"github.com/go-resty/resty/v2"
)
const (
HyperParameterAuto = "auto"
FineTuneStatusCreate = "create"
FineTuneStatusValidatingFiles = "validating_files"
FineTuneStatusQueued = "queued"
FineTuneStatusRunning = "running"
FineTuneStatusSucceeded = "succeeded"
FineTuneStatusFailed = "failed"
FineTuneStatusCancelled = "cancelled"
)
// FineTuneItem is the item of the FineTune
type FineTuneItem struct {
ID string `json:"id"`
RequestID string `json:"request_id"`
FineTunedModel string `json:"fine_tuned_model"`
Status string `json:"status"`
Object string `json:"object"`
TrainingFile string `json:"training_file"`
ValidationFile string `json:"validation_file"`
Error APIError `json:"error"`
}
// FineTuneCreateService creates a new fine tune
type FineTuneCreateService struct {
client *Client
model string
trainingFile string
validationFile *string
learningRateMultiplier *StringOr[float64]
batchSize *StringOr[int]
nEpochs *StringOr[int]
suffix *string
requestID *string
}
// FineTuneCreateResponse is the response of the FineTuneCreateService
type FineTuneCreateResponse = FineTuneItem
// NewFineTuneCreateService creates a new FineTuneCreateService
func NewFineTuneCreateService(client *Client) *FineTuneCreateService {
return &FineTuneCreateService{
client: client,
}
}
// SetModel sets the model parameter
func (s *FineTuneCreateService) SetModel(model string) *FineTuneCreateService {
s.model = model
return s
}
// SetTrainingFile sets the trainingFile parameter
func (s *FineTuneCreateService) SetTrainingFile(trainingFile string) *FineTuneCreateService {
s.trainingFile = trainingFile
return s
}
// SetValidationFile sets the validationFile parameter
func (s *FineTuneCreateService) SetValidationFile(validationFile string) *FineTuneCreateService {
s.validationFile = &validationFile
return s
}
// SetLearningRateMultiplier sets the learningRateMultiplier parameter
func (s *FineTuneCreateService) SetLearningRateMultiplier(learningRateMultiplier float64) *FineTuneCreateService {
s.learningRateMultiplier = &StringOr[float64]{}
s.learningRateMultiplier.SetValue(learningRateMultiplier)
return s
}
// SetLearningRateMultiplierAuto sets the learningRateMultiplier parameter to auto
func (s *FineTuneCreateService) SetLearningRateMultiplierAuto() *FineTuneCreateService {
s.learningRateMultiplier = &StringOr[float64]{}
s.learningRateMultiplier.SetString(HyperParameterAuto)
return s
}
// SetBatchSize sets the batchSize parameter
func (s *FineTuneCreateService) SetBatchSize(batchSize int) *FineTuneCreateService {
s.batchSize = &StringOr[int]{}
s.batchSize.SetValue(batchSize)
return s
}
// SetBatchSizeAuto sets the batchSize parameter to auto
func (s *FineTuneCreateService) SetBatchSizeAuto() *FineTuneCreateService {
s.batchSize = &StringOr[int]{}
s.batchSize.SetString(HyperParameterAuto)
return s
}
// SetNEpochs sets the nEpochs parameter
func (s *FineTuneCreateService) SetNEpochs(nEpochs int) *FineTuneCreateService {
s.nEpochs = &StringOr[int]{}
s.nEpochs.SetValue(nEpochs)
return s
}
// SetNEpochsAuto sets the nEpochs parameter to auto
func (s *FineTuneCreateService) SetNEpochsAuto() *FineTuneCreateService {
s.nEpochs = &StringOr[int]{}
s.nEpochs.SetString(HyperParameterAuto)
return s
}
// SetSuffix sets the suffix parameter
func (s *FineTuneCreateService) SetSuffix(suffix string) *FineTuneCreateService {
s.suffix = &suffix
return s
}
// SetRequestID sets the requestID parameter
func (s *FineTuneCreateService) SetRequestID(requestID string) *FineTuneCreateService {
s.requestID = &requestID
return s
}
// Do makes the request
func (s *FineTuneCreateService) Do(ctx context.Context) (res FineTuneCreateResponse, err error) {
var (
resp *resty.Response
apiError APIErrorResponse
)
body := M{
"model": s.model,
"training_file": s.trainingFile,
}
if s.validationFile != nil {
body["validation_file"] = *s.validationFile
}
if s.suffix != nil {
body["suffix"] = *s.suffix
}
if s.requestID != nil {
body["request_id"] = *s.requestID
}
if s.learningRateMultiplier != nil || s.batchSize != nil || s.nEpochs != nil {
hp := M{}
if s.learningRateMultiplier != nil {
hp["learning_rate_multiplier"] = s.learningRateMultiplier
}
if s.batchSize != nil {
hp["batch_size"] = s.batchSize
}
if s.nEpochs != nil {
hp["n_epochs"] = s.nEpochs
}
body["hyperparameters"] = hp
}
if resp, err = s.client.request(ctx).
SetBody(body).
SetResult(&res).
SetError(&apiError).
Post("fine_tuning/jobs"); err != nil {
return
}
if resp.IsError() {
err = apiError
return
}
return
}
// FineTuneEventListService creates a new fine tune event list
type FineTuneEventListService struct {
client *Client
jobID string
limit *int
after *string
}
// FineTuneEventData is the data of the FineTuneEventItem
type FineTuneEventData struct {
Acc float64 `json:"acc"`
Loss float64 `json:"loss"`
CurrentSteps int64 `json:"current_steps"`
RemainingTime string `json:"remaining_time"`
ElapsedTime string `json:"elapsed_time"`
TotalSteps int64 `json:"total_steps"`
Epoch int64 `json:"epoch"`
TrainedTokens int64 `json:"trained_tokens"`
LearningRate float64 `json:"learning_rate"`
}
// FineTuneEventItem is the item of the FineTuneEventListResponse
type FineTuneEventItem struct {
ID string `json:"id"`
Type string `json:"type"`
Level string `json:"level"`
Message string `json:"message"`
Object string `json:"object"`
CreatedAt int64 `json:"created_at"`
Data FineTuneEventData `json:"data"`
}
// FineTuneEventListResponse is the response of the FineTuneEventListService
type FineTuneEventListResponse struct {
Data []FineTuneEventItem `json:"data"`
HasMore bool `json:"has_more"`
Object string `json:"object"`
}
// NewFineTuneEventListService creates a new FineTuneEventListService
func NewFineTuneEventListService(client *Client) *FineTuneEventListService {
return &FineTuneEventListService{
client: client,
}
}
// SetJobID sets the jobID parameter
func (s *FineTuneEventListService) SetJobID(jobID string) *FineTuneEventListService {
s.jobID = jobID
return s
}
// SetLimit sets the limit parameter
func (s *FineTuneEventListService) SetLimit(limit int) *FineTuneEventListService {
s.limit = &limit
return s
}
// SetAfter sets the after parameter
func (s *FineTuneEventListService) SetAfter(after string) *FineTuneEventListService {
s.after = &after
return s
}
// Do makes the request
func (s *FineTuneEventListService) Do(ctx context.Context) (res FineTuneEventListResponse, err error) {
var (
resp *resty.Response
apiError APIErrorResponse
)
req := s.client.request(ctx)
if s.limit != nil {
req.SetQueryParam("limit", strconv.Itoa(*s.limit))
}
if s.after != nil {
req.SetQueryParam("after", *s.after)
}
if resp, err = req.
SetPathParam("job_id", s.jobID).
SetResult(&res).
SetError(&apiError).
Get("fine_tuning/jobs/{job_id}/events"); err != nil {
return
}
if resp.IsError() {
err = apiError
return
}
return
}
// FineTuneGetService creates a new fine tune get
type FineTuneGetService struct {
client *Client
jobID string
}
// NewFineTuneGetService creates a new FineTuneGetService
func NewFineTuneGetService(client *Client) *FineTuneGetService {
return &FineTuneGetService{
client: client,
}
}
// SetJobID sets the jobID parameter
func (s *FineTuneGetService) SetJobID(jobID string) *FineTuneGetService {
s.jobID = jobID
return s
}
// Do makes the request
func (s *FineTuneGetService) Do(ctx context.Context) (res FineTuneItem, err error) {
var (
resp *resty.Response
apiError APIErrorResponse
)
if resp, err = s.client.request(ctx).
SetPathParam("job_id", s.jobID).
SetResult(&res).
SetError(&apiError).
Get("fine_tuning/jobs/{job_id}"); err != nil {
return
}
if resp.IsError() {
err = apiError
return
}
return
}
// FineTuneListService creates a new fine tune list
type FineTuneListService struct {
client *Client
limit *int
after *string
}
// FineTuneListResponse is the response of the FineTuneListService
type FineTuneListResponse struct {
Data []FineTuneItem `json:"data"`
Object string `json:"object"`
}
// NewFineTuneListService creates a new FineTuneListService
func NewFineTuneListService(client *Client) *FineTuneListService {
return &FineTuneListService{
client: client,
}
}
// SetLimit sets the limit parameter
func (s *FineTuneListService) SetLimit(limit int) *FineTuneListService {
s.limit = &limit
return s
}
// SetAfter sets the after parameter
func (s *FineTuneListService) SetAfter(after string) *FineTuneListService {
s.after = &after
return s
}
// Do makes the request
func (s *FineTuneListService) Do(ctx context.Context) (res FineTuneListResponse, err error) {
var (
resp *resty.Response
apiError APIErrorResponse
)
req := s.client.request(ctx)
if s.limit != nil {
req.SetQueryParam("limit", strconv.Itoa(*s.limit))
}
if s.after != nil {
req.SetQueryParam("after", *s.after)
}
if resp, err = req.
SetResult(&res).
SetError(&apiError).
Get("fine_tuning/jobs"); err != nil {
return
}
if resp.IsError() {
err = apiError
return
}
return
}
// FineTuneDeleteService creates a new fine tune delete
type FineTuneDeleteService struct {
client *Client
jobID string
}
// NewFineTuneDeleteService creates a new FineTuneDeleteService
func NewFineTuneDeleteService(client *Client) *FineTuneDeleteService {
return &FineTuneDeleteService{
client: client,
}
}
// SetJobID sets the jobID parameter
func (s *FineTuneDeleteService) SetJobID(jobID string) *FineTuneDeleteService {
s.jobID = jobID
return s
}
// Do makes the request
func (s *FineTuneDeleteService) Do(ctx context.Context) (res FineTuneItem, err error) {
var (
resp *resty.Response
apiError APIErrorResponse
)
if resp, err = s.client.request(ctx).
SetPathParam("job_id", s.jobID).
SetResult(&res).
SetError(&apiError).
Delete("fine_tuning/jobs/{job_id}"); err != nil {
return
}
if resp.IsError() {
err = apiError
return
}
return
}
// FineTuneCancelService creates a new fine tune cancel
type FineTuneCancelService struct {
client *Client
jobID string
}
// NewFineTuneCancelService creates a new FineTuneCancelService
func NewFineTuneCancelService(client *Client) *FineTuneCancelService {
return &FineTuneCancelService{
client: client,
}
}
// SetJobID sets the jobID parameter
func (s *FineTuneCancelService) SetJobID(jobID string) *FineTuneCancelService {
s.jobID = jobID
return s
}
// Do makes the request
func (s *FineTuneCancelService) Do(ctx context.Context) (res FineTuneItem, err error) {
var (
resp *resty.Response
apiError APIErrorResponse
)
if resp, err = s.client.request(ctx).
SetPathParam("job_id", s.jobID).
SetResult(&res).
SetError(&apiError).
Post("fine_tuning/jobs/{job_id}/cancel"); err != nil {
return
}
if resp.IsError() {
err = apiError
return
}
return
}