From ad02b380c2b420c381b0b6726b1ca19f23f45dc5 Mon Sep 17 00:00:00 2001 From: AI Engineer Date: Fri, 8 May 2026 21:56:55 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81=20Multipart=20?= =?UTF-8?q?=E7=B1=BB=E5=9E=8B=E4=B8=8A=E4=BC=A0=EF=BC=8C=E9=87=8D=E6=9E=84?= =?UTF-8?q?=20Form=20=E7=B1=BB=E5=9E=8B=EF=BC=8C=E5=AF=B9=E9=BD=90?= =?UTF-8?q?=E5=9F=BA=E7=A1=80=E8=AE=BE=E6=96=BD=20v1.0.7=20(by=20AI)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 11 ++- client.go | 185 ++++++++++++++++++++++--------------------------- client_test.go | 36 ++++++++-- go.mod | 6 +- 4 files changed, 124 insertions(+), 114 deletions(-) diff --git a/README.md b/README.md index e8da437..1b4d011 100644 --- a/README.md +++ b/README.md @@ -66,11 +66,18 @@ c.Download("local_file.zip", "https://example.com/large_file.zip", func(start, e ### 请求方法 - `func (c *Client) Get(url string, headers ...string) *Result` -- `func (c *Client) Post(url string, data any, headers ...string) *Result` +- `func (c *Client) Post(url string, data any, headers ...string) *Result`: 支持多种数据类型(JSON, Form, Multipart)。 - `func (c *Client) Put(url string, data any, headers ...string) *Result` - `func (c *Client) Delete(url string, data any, headers ...string) *Result` - `func (c *Client) Head(url string, headers ...string) *Result` -- `func (c *Client) PostMultipart(url string, formData map[string]string, files map[string]any, headers ...string) (*Result, []error)`: 多部分表单提交(支持文件与流)。 + +### 特殊类型 +- `type Form map[string]string`: 用于 `Post/Put` 等方法,显式指定为 `application/x-www-form-urlencoded` 格式。 + - 注意:直接传入 `map[string]string` 会被默认识别为 `application/json`。 +- `type Multipart map[string]any`: 用于 `Post/Put` 等方法,支持混合表单字段与文件上传。 + - 如果值为 `string` 且指向有效文件路径,则作为文件上传。 + - 如果值为 `[]byte` 或 `io.Reader`,则作为文件上传。 + - 其他类型将作为普通表单字段(复杂类型会自动转为 JSON)。 ### 响应处理 (Result) - `func (rs *Result) String() string`: 返回响应体字符串。 diff --git a/client.go b/client.go index 6bafdd2..cf6d554 100644 --- a/client.go +++ b/client.go @@ -42,7 +42,8 @@ type Result struct { data []byte } -type Form = map[string]string +type Form map[string]string +type Multipart map[string]any var bufferPool = sync.Pool{ New: func() any { @@ -162,6 +163,18 @@ func (client *Client) doByRequest(manualDo bool, request *http.Request, method, } } + // 确保 Request-ID 存在 + foundID := false + for i := 1; i < len(headers); i += 2 { + if headers[i-1] == HeaderRequestID { + foundID = true + break + } + } + if !foundID { + headers = append(headers, HeaderRequestID, string(encoding.Hex(rand.Bytes(16)))) + } + // 续传 X-Forwarded-For xForwardFor := request.Header.Get(HeaderForwardedFor) remoteIP, _, err := net.SplitHostPort(request.RemoteAddr) @@ -176,18 +189,6 @@ func (client *Client) doByRequest(manualDo bool, request *http.Request, method, } headers = append(headers, HeaderForwardedFor, xForwardFor) - // 处理请求唯一编号 - foundID := false - for i := 1; i < len(headers); i += 2 { - if headers[i-1] == HeaderRequestID { - foundID = true - break - } - } - if !foundID { - headers = append(headers, HeaderRequestID, string(encoding.Hex(rand.Bytes(16)))) - } - headers = append(headers, settedHeaders...) if manualDo { @@ -323,72 +324,61 @@ func (client *Client) Download(filename, url string, callback func(start, end in return result, err } -func (client *Client) PostMultipart(url string, formData map[string]string, files map[string]any, headers ...string) (*Result, []error) { +func (client *Client) buildMultipart(writer *multipart.Writer, data map[string]any) []error { errs := make([]error, 0) - buf := bufferPool.Get().(*bytes.Buffer) - buf.Reset() - defer bufferPool.Put(buf) - - writer := multipart.NewWriter(buf) - - if formData != nil { - for key, value := range formData { - if err := writer.WriteField(key, value); err != nil { - errs = append(errs, err) - } - } - } - - if files != nil { - for key, value := range files { - if filename, ok := value.(string); ok && file.Exists(filename) { - var reader io.Reader - var closer io.Closer - if mf := file.ReadFileFromMemory(filename); mf != nil { - reader = bytes.NewReader(mf.GetData()) - } else { - if fp, err := os.Open(filename); err == nil { - reader = fp - closer = fp - } else { - errs = append(errs, err) - continue - } - } - - if part, err := writer.CreateFormFile(key, filepath.Base(filename)); err == nil { - if _, err = io.Copy(part, reader); err != nil { - errs = append(errs, err) - } + for key, value := range data { + if filename, ok := value.(string); ok && file.Exists(filename) { + var r io.Reader + var closer io.Closer + if mf := file.ReadFileFromMemory(filename); mf != nil { + r = bytes.NewReader(mf.GetData()) + } else { + if fp, err := os.Open(filename); err == nil { + r = fp + closer = fp } else { errs = append(errs, err) + continue } + } - if closer != nil { - _ = closer.Close() + if part, err := writer.CreateFormFile(key, filepath.Base(filename)); err == nil { + if _, err = io.Copy(part, r); err != nil { + errs = append(errs, err) } } else { - h := make(textproto.MIMEHeader) - var dataBytes []byte - switch t := value.(type) { - case io.Reader: - dataBytes, _ = io.ReadAll(t) - h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, key, key)) - h.Set("Content-Type", "application/octet-stream") - case []byte: - dataBytes = t - h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, key, key)) - h.Set("Content-Type", "application/octet-stream") - case string: - dataBytes = []byte(t) - h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s.txt"`, key, key)) - h.Set("Content-Type", "text/plain") - default: - dataBytes = cast.As(cast.ToJSONBytes(value)) - h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s.json"`, key, key)) - h.Set("Content-Type", "application/json") - } + errs = append(errs, err) + } + if closer != nil { + _ = closer.Close() + } + } else { + var dataBytes []byte + h := make(textproto.MIMEHeader) + isField := false + switch t := value.(type) { + case io.Reader: + dataBytes, _ = io.ReadAll(t) + h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, key, key)) + h.Set("Content-Type", "application/octet-stream") + case []byte: + dataBytes = t + h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, key, key)) + h.Set("Content-Type", "application/octet-stream") + case string: + isField = true + dataBytes = []byte(t) + default: + isField = true + dataBytes, _ = cast.ToJSONBytes(value) + } + + if isField { + if err := writer.WriteField(key, string(dataBytes)); err != nil { + errs = append(errs, err) + } + } else { if part, err := writer.CreatePart(h); err == nil { if _, err = part.Write(dataBytes); err != nil { errs = append(errs, err) @@ -399,21 +389,7 @@ func (client *Client) PostMultipart(url string, formData map[string]string, file } } } - - if err := writer.Close(); err != nil { - errs = append(errs, err) - } - - if len(errs) > 0 { - return nil, errs - } - - headers = append(headers, "Content-Type", writer.FormDataContentType()) - result := client.Post(url, buf.Bytes(), headers...) - if result.Error != nil { - errs = append(errs, result.Error) - } - return result, errs + return errs } func (client *Client) do(fetchBody bool, method, url string, data any, headers ...string) *Result { @@ -438,13 +414,7 @@ func (client *Client) do(fetchBody bool, method, url string, data any, headers . reader = strings.NewReader(encoded) contentType = "application/x-www-form-urlencoded" contentLength = len(encoded) - case map[string][]string: - values := url2.Values(t) - encoded := values.Encode() - reader = strings.NewReader(encoded) - contentType = "application/x-www-form-urlencoded" - contentLength = len(encoded) - case map[string]string: + case Form: values := url2.Values{} for k, v := range t { values.Set(k, v) @@ -453,6 +423,22 @@ func (client *Client) do(fetchBody bool, method, url string, data any, headers . reader = strings.NewReader(encoded) contentType = "application/x-www-form-urlencoded" contentLength = len(encoded) + case Multipart: + buf := bufferPool.Get().(*bytes.Buffer) + buf.Reset() + defer bufferPool.Put(buf) + writer := multipart.NewWriter(buf) + errs := client.buildMultipart(writer, t) + if err := writer.Close(); err != nil { + errs = append(errs, err) + } + if len(errs) > 0 { + return &Result{Error: errors.Join(errs...)} + } + bytesData := buf.Bytes() + reader = bytes.NewReader(bytesData) + contentType = writer.FormDataContentType() + contentLength = len(bytesData) default: bytesData, _ := cast.ToJSONBytes(data) if len(bytesData) > 0 && string(bytesData) != "null" { @@ -558,19 +544,10 @@ func (result *Result) Slice() []any { } func (result *Result) To(v any) error { - if result.data == nil { + if len(result.data) == 0 { return errors.New("no data") } - err := cast.UnmarshalJSON(result.data, v) - if err != nil { - // 如果 cast 直接解不出来,尝试通过 convert 做深度映射(处理 struct 字段匹配等) - var tmp any - if err2 := cast.UnmarshalJSON(result.data, &tmp); err2 == nil { - cast.Convert(v, tmp) - return nil - } - } - return err + return cast.UnmarshalJSON(result.data, v) } // To 使用泛型获取结果 diff --git a/client_test.go b/client_test.go index 0dfbbaa..d5bb7ea 100644 --- a/client_test.go +++ b/client_test.go @@ -178,7 +178,7 @@ func TestDownload(t *testing.T) { } } -func TestMPost(t *testing.T) { +func TestMultipartDo(t *testing.T) { handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _ = r.ParseMultipartForm(10 << 20) f := r.FormValue("foo") @@ -189,17 +189,43 @@ func TestMPost(t *testing.T) { } fmt.Fprintf(w, "foo=%s,file=%s", f, string(fileContent)) }) - server := &http.Server{Addr: ":18085", Handler: handler} + server := &http.Server{Addr: ":18086", Handler: handler} go func() { _ = server.ListenAndServe() }() defer server.Close() time.Sleep(100 * time.Millisecond) c := ah.NewClient(time.Second) - r, errs := c.PostMultipart("http://127.0.0.1:18085/", map[string]string{"foo": "bar"}, map[string]any{"file": []byte("baz")}) - if len(errs) > 0 { - t.Fatalf("PostMultipart failed: %v", errs) + r := c.Post("http://127.0.0.1:18086/", ah.Multipart{"foo": "bar", "file": []byte("baz")}) + if r.Error != nil { + t.Fatalf("Post with Multipart failed: %v", r.Error) } if r.String() != "foo=bar,file=baz" { t.Errorf("expected foo=bar,file=baz, got %s", r.String()) } } + +func TestFormAndMap(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ct := r.Header.Get("Content-Type") + body, _ := io.ReadAll(r.Body) + fmt.Fprintf(w, "ct=%s,body=%s", ct, string(body)) + }) + server := &http.Server{Addr: ":18087", Handler: handler} + go func() { _ = server.ListenAndServe() }() + defer server.Close() + time.Sleep(100 * time.Millisecond) + + c := ah.NewClient(time.Second) + + // Test Form (urlencoded) + r1 := c.Post("http://127.0.0.1:18087/", ah.Form{"foo": "bar"}) + if r1.String() != "ct=application/x-www-form-urlencoded,body=foo=bar" { + t.Errorf("Form failed, got: %s", r1.String()) + } + + // Test map[string]string (JSON) + r2 := c.Post("http://127.0.0.1:18087/", map[string]string{"foo": "bar"}) + if r2.String() != `ct=application/json,body={"foo":"bar"}` { + t.Errorf("map[string]string failed, got: %s", r2.String()) + } +} diff --git a/go.mod b/go.mod index 7776e42..f4bb3d2 100644 --- a/go.mod +++ b/go.mod @@ -3,10 +3,10 @@ module apigo.cc/go/http go 1.25.0 require ( - apigo.cc/go/cast v1.2.7 + apigo.cc/go/cast v1.2.8 apigo.cc/go/encoding v1.0.5 - apigo.cc/go/file v1.0.6 - apigo.cc/go/log v1.1.5 + apigo.cc/go/file v1.0.7 + apigo.cc/go/log v1.1.9 apigo.cc/go/rand v1.0.5 golang.org/x/net v0.53.0 )