Source file
src/net/http/clientserver_test.go
1
2
3
4
5
6
7 package http_test
8
9 import (
10 "bytes"
11 "compress/gzip"
12 "context"
13 "crypto/rand"
14 "crypto/sha1"
15 "crypto/tls"
16 "fmt"
17 "hash"
18 "io"
19 "log"
20 "maps"
21 "net"
22 . "net/http"
23 "net/http/httptest"
24 "net/http/httptrace"
25 "net/http/httputil"
26 "net/textproto"
27 "net/url"
28 "os"
29 "reflect"
30 "runtime"
31 "slices"
32 "strings"
33 "sync"
34 "sync/atomic"
35 "testing"
36 "testing/synctest"
37 "time"
38 )
39
// testMode names the protocol configuration a client/server test runs under.
type testMode string

// The modes understood by newClientServerTest.
const (
	http1Mode            = testMode("h1")            // server started with Start, no HTTP/2 setup
	https1Mode           = testMode("https1")        // server started with StartTLS, no HTTP/2 setup
	http2Mode            = testMode("h2")            // HTTP/2 configured, server started with StartTLS
	http2UnencryptedMode = testMode("h2unencrypted") // HTTP/2 configured, server started with Start
)
48
// testNotParallelOpt is the type of the testNotParallel option.
type testNotParallelOpt struct{}

var (
	// testNotParallel, when passed as an option to run, keeps run from
	// calling setParallel on the test and its per-mode subtests.
	testNotParallel = testNotParallelOpt{}
)
54
// TBRun is a testing.TB that additionally has a Run method whose subtest
// function receives a T. It is the constraint on run's test-handle type.
type TBRun[T any] interface {
	testing.TB
	Run(string, func(T)) bool
}
59
60
61
62
63
64
65
66
67 func run[T TBRun[T]](t T, f func(t T, mode testMode), opts ...any) {
68 t.Helper()
69 modes := []testMode{http1Mode, http2Mode}
70 parallel := true
71 for _, opt := range opts {
72 switch opt := opt.(type) {
73 case []testMode:
74 modes = opt
75 case testNotParallelOpt:
76 parallel = false
77 default:
78 t.Fatalf("unknown option type %T", opt)
79 }
80 }
81 if t, ok := any(t).(*testing.T); ok && parallel {
82 setParallel(t)
83 }
84 for _, mode := range modes {
85 t.Run(string(mode), func(t T) {
86 t.Helper()
87 if t, ok := any(t).(*testing.T); ok && parallel {
88 setParallel(t)
89 }
90 t.Cleanup(func() {
91 afterTest(t)
92 })
93 f(t, mode)
94 })
95 }
96 }
97
98
99
100
// runSynctest is like run, except that for each mode f is executed inside
// synctest.Test. opts are passed through to run unchanged.
func runSynctest(t *testing.T, f func(t *testing.T, mode testMode), opts ...any) {
	run(t, func(t *testing.T, mode testMode) {
		synctest.Test(t, func(t *testing.T) {
			f(t, mode)
		})
	}, opts...)
}
108
// clientServerTest bundles a test HTTP server with a client and transport
// set up to talk to it. Values are created by newClientServerTest.
type clientServerTest struct {
	t  testing.TB        // test handle used to report failures
	h2 bool              // true only when created in http2Mode
	h  Handler           // handler served by ts
	ts *httptest.Server  // the test server
	tr *Transport        // the Transport underlying c
	c  *Client           // client obtained from ts.Client()
	li *fakeNetListener  // set only when the optFakeNet option was given
}
118
// close closes the transport's idle connections and then closes the
// test server.
func (t *clientServerTest) close() {
	t.tr.CloseIdleConnections()
	t.ts.Close()
}
123
124 func (t *clientServerTest) getURL(u string) string {
125 res, err := t.c.Get(u)
126 if err != nil {
127 t.t.Fatal(err)
128 }
129 defer res.Body.Close()
130 slurp, err := io.ReadAll(res.Body)
131 if err != nil {
132 t.t.Fatal(err)
133 }
134 return string(slurp)
135 }
136
137 func (t *clientServerTest) scheme() string {
138 if t.h2 {
139 return "https"
140 }
141 return "http"
142 }
143
// optQuietLog is a server option for newClientServerTest that sets the
// server's ErrorLog to quietLog.
var optQuietLog = func(ts *httptest.Server) {
	ts.Config.ErrorLog = quietLog
}
147
// optWithServerLog returns a server option for newClientServerTest that
// sets the server's ErrorLog to lg.
func optWithServerLog(lg *log.Logger) func(*httptest.Server) {
	return func(ts *httptest.Server) {
		ts.Config.ErrorLog = lg
	}
}
153
// optFakeNet is a sentinel option for newClientServerTest. When it is
// present, the server is given a listener from fakeNetListen and the
// transport's DialContext connects to that listener directly, instead of
// using httptest.NewUnstartedServer.
var optFakeNet = new(struct{})
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170 func newClientServerTest(t testing.TB, mode testMode, h Handler, opts ...any) *clientServerTest {
171 if mode == http2Mode {
172 CondSkipHTTP2(t)
173 }
174 cst := &clientServerTest{
175 t: t,
176 h2: mode == http2Mode,
177 h: h,
178 }
179
180 var transportFuncs []func(*Transport)
181
182 if idx := slices.Index(opts, any(optFakeNet)); idx >= 0 {
183 opts = slices.Delete(opts, idx, idx+1)
184 cst.li = fakeNetListen()
185 cst.ts = &httptest.Server{
186 Config: &Server{Handler: h},
187 Listener: cst.li,
188 }
189 transportFuncs = append(transportFuncs, func(tr *Transport) {
190 tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
191 return cst.li.connect(), nil
192 }
193 })
194 } else {
195 cst.ts = httptest.NewUnstartedServer(h)
196 }
197
198 if mode == http2UnencryptedMode {
199 p := &Protocols{}
200 p.SetUnencryptedHTTP2(true)
201 cst.ts.Config.Protocols = p
202 }
203
204 for _, opt := range opts {
205 switch opt := opt.(type) {
206 case func(*Transport):
207 transportFuncs = append(transportFuncs, opt)
208 case func(*httptest.Server):
209 opt(cst.ts)
210 default:
211 t.Fatalf("unhandled option type %T", opt)
212 }
213 }
214
215 if cst.ts.Config.ErrorLog == nil {
216 cst.ts.Config.ErrorLog = log.New(testLogWriter{t}, "", 0)
217 }
218
219 switch mode {
220 case http1Mode:
221 cst.ts.Start()
222 case https1Mode:
223 cst.ts.StartTLS()
224 case http2UnencryptedMode:
225 ExportHttp2ConfigureServer(cst.ts.Config, nil)
226 cst.ts.Start()
227 case http2Mode:
228 ExportHttp2ConfigureServer(cst.ts.Config, nil)
229 cst.ts.TLS = cst.ts.Config.TLSConfig
230 cst.ts.StartTLS()
231 default:
232 t.Fatalf("unknown test mode %v", mode)
233 }
234 cst.c = cst.ts.Client()
235 cst.tr = cst.c.Transport.(*Transport)
236 if mode == http2Mode || mode == http2UnencryptedMode {
237 if err := ExportHttp2ConfigureTransport(cst.tr); err != nil {
238 t.Fatal(err)
239 }
240 }
241 for _, f := range transportFuncs {
242 f(cst.tr)
243 }
244
245 if mode == http2UnencryptedMode {
246 p := &Protocols{}
247 p.SetUnencryptedHTTP2(true)
248 cst.tr.Protocols = p
249 }
250
251 t.Cleanup(func() {
252 cst.close()
253 })
254 return cst
255 }
256
// testLogWriter is an io.Writer that forwards everything written to it to
// the Logf method of t.
type testLogWriter struct {
	t testing.TB
}

// Write logs b, with leading and trailing whitespace trimmed, prefixed by
// "server log: ". It always reports len(b) bytes written and a nil error.
func (w testLogWriter) Write(b []byte) (int, error) {
	w.t.Logf("server log: %v", strings.TrimSpace(string(b)))
	return len(b), nil
}
265
266
// TestNewClientServerTest exercises newClientServerTest in the h1, https1
// and h2 modes: once through run ("realnet"), and once through runSynctest
// with the optFakeNet option ("synctest").
func TestNewClientServerTest(t *testing.T) {
	modes := []testMode{http1Mode, https1Mode, http2Mode}
	t.Run("realnet", func(t *testing.T) {
		run(t, func(t *testing.T, mode testMode) {
			testNewClientServerTest(t, mode)
		}, modes)
	})
	t.Run("synctest", func(t *testing.T) {
		runSynctest(t, func(t *testing.T, mode testMode) {
			testNewClientServerTest(t, mode, optFakeNet)
		}, modes)
	})
}
280 func testNewClientServerTest(t *testing.T, mode testMode, opts ...any) {
281 var got struct {
282 sync.Mutex
283 proto string
284 hasTLS bool
285 }
286 h := HandlerFunc(func(w ResponseWriter, r *Request) {
287 got.Lock()
288 defer got.Unlock()
289 got.proto = r.Proto
290 got.hasTLS = r.TLS != nil
291 })
292 cst := newClientServerTest(t, mode, h, opts...)
293 if _, err := cst.c.Head(cst.ts.URL); err != nil {
294 t.Fatal(err)
295 }
296 var wantProto string
297 var wantTLS bool
298 switch mode {
299 case http1Mode:
300 wantProto = "HTTP/1.1"
301 wantTLS = false
302 case https1Mode:
303 wantProto = "HTTP/1.1"
304 wantTLS = true
305 case http2Mode:
306 wantProto = "HTTP/2.0"
307 wantTLS = true
308 }
309 if got.proto != wantProto {
310 t.Errorf("req.Proto = %q, want %q", got.proto, wantProto)
311 }
312 if got.hasTLS != wantTLS {
313 t.Errorf("req.TLS set: %v, want %v", got.hasTLS, wantTLS)
314 }
315 }
316
317 func TestChunkedResponseHeaders(t *testing.T) { run(t, testChunkedResponseHeaders) }
318 func testChunkedResponseHeaders(t *testing.T, mode testMode) {
319 log.SetOutput(io.Discard)
320 defer log.SetOutput(os.Stderr)
321 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
322 w.Header().Set("Content-Length", "intentional gibberish")
323 w.(Flusher).Flush()
324 fmt.Fprintf(w, "I am a chunked response.")
325 }))
326
327 res, err := cst.c.Get(cst.ts.URL)
328 if err != nil {
329 t.Fatalf("Get error: %v", err)
330 }
331 defer res.Body.Close()
332 if g, e := res.ContentLength, int64(-1); g != e {
333 t.Errorf("expected ContentLength of %d; got %d", e, g)
334 }
335 wantTE := []string{"chunked"}
336 if mode == http2Mode {
337 wantTE = nil
338 }
339 if !slices.Equal(res.TransferEncoding, wantTE) {
340 t.Errorf("TransferEncoding = %v; want %v", res.TransferEncoding, wantTE)
341 }
342 if got, haveCL := res.Header["Content-Length"]; haveCL {
343 t.Errorf("Unexpected Content-Length: %q", got)
344 }
345 }
346
// reqFunc issues a request to url using c and returns the response.
// (*Client).Get and (*Client).Head have this shape.
type reqFunc func(c *Client, url string) (*Response, error)
348
349
350
// h12Compare describes a test that serves the same Handler from an HTTP/1
// server and an HTTP/2 server and compares the two responses; see its run
// method.
type h12Compare struct {
	Handler            func(ResponseWriter, *Request)    // handler served by both servers
	ReqFunc            reqFunc                           // request to issue; nil means (*Client).Get
	CheckResponse      func(proto string, res *Response) // optional; called per response after the comparison
	EarlyCheckResponse func(proto string, res *Response) // optional; called per response before normalization
	Opts               []any                             // options passed on to newClientServerTest
}
358
359 func (tt h12Compare) reqFunc() reqFunc {
360 if tt.ReqFunc == nil {
361 return (*Client).Get
362 }
363 return tt.ReqFunc
364 }
365
366 func (tt h12Compare) run(t *testing.T) {
367 setParallel(t)
368 cst1 := newClientServerTest(t, http1Mode, HandlerFunc(tt.Handler), tt.Opts...)
369 defer cst1.close()
370 cst2 := newClientServerTest(t, http2Mode, HandlerFunc(tt.Handler), tt.Opts...)
371 defer cst2.close()
372
373 res1, err := tt.reqFunc()(cst1.c, cst1.ts.URL)
374 if err != nil {
375 t.Errorf("HTTP/1 request: %v", err)
376 return
377 }
378 res2, err := tt.reqFunc()(cst2.c, cst2.ts.URL)
379 if err != nil {
380 t.Errorf("HTTP/2 request: %v", err)
381 return
382 }
383
384 if fn := tt.EarlyCheckResponse; fn != nil {
385 fn("HTTP/1.1", res1)
386 fn("HTTP/2.0", res2)
387 }
388
389 tt.normalizeRes(t, res1, "HTTP/1.1")
390 tt.normalizeRes(t, res2, "HTTP/2.0")
391 res1body, res2body := res1.Body, res2.Body
392
393 eres1 := mostlyCopy(res1)
394 eres2 := mostlyCopy(res2)
395 if !reflect.DeepEqual(eres1, eres2) {
396 t.Errorf("Response headers to handler differed:\nhttp/1 (%v):\n\t%#v\nhttp/2 (%v):\n\t%#v",
397 cst1.ts.URL, eres1, cst2.ts.URL, eres2)
398 }
399 if !reflect.DeepEqual(res1body, res2body) {
400 t.Errorf("Response bodies to handler differed.\nhttp1: %v\nhttp2: %v\n", res1body, res2body)
401 }
402 if fn := tt.CheckResponse; fn != nil {
403 res1.Body, res2.Body = res1body, res2body
404 fn("HTTP/1.1", res1)
405 fn("HTTP/2.0", res2)
406 }
407 }
408
// mostlyCopy returns a shallow copy of r with the Body, TransferEncoding,
// TLS and Request fields cleared. r itself is not modified.
func mostlyCopy(r *Response) *Response {
	dup := new(Response)
	*dup = *r
	dup.Body, dup.TransferEncoding, dup.TLS, dup.Request = nil, nil, nil, nil
	return dup
}
417
// slurpResult is a response body that has already been read in full. It
// keeps the bytes read and the error from reading them alongside a
// ReadCloser; normalizeRes installs it as the response's Body.
type slurpResult struct {
	io.ReadCloser
	body []byte // everything read from the original body
	err  error  // error returned while reading the original body, if any
}

// String formats the captured body (quoted) and read error.
func (sr slurpResult) String() string { return fmt.Sprintf("body %q; err %v", sr.body, sr.err) }
425
426 func (tt h12Compare) normalizeRes(t *testing.T, res *Response, wantProto string) {
427 if res.Proto == wantProto || res.Proto == "HTTP/IGNORE" {
428 res.Proto, res.ProtoMajor, res.ProtoMinor = "", 0, 0
429 } else {
430 t.Errorf("got %q response; want %q", res.Proto, wantProto)
431 }
432 slurp, err := io.ReadAll(res.Body)
433
434 res.Body.Close()
435 res.Body = slurpResult{
436 ReadCloser: io.NopCloser(bytes.NewReader(slurp)),
437 body: slurp,
438 err: err,
439 }
440 for i, v := range res.Header["Date"] {
441 res.Header["Date"][i] = strings.Repeat("x", len(v))
442 }
443 if res.Request == nil {
444 t.Errorf("for %s, no request", wantProto)
445 }
446 if (res.TLS != nil) != (wantProto == "HTTP/2.0") {
447 t.Errorf("TLS set = %v; want %v", res.TLS != nil, res.TLS == nil)
448 }
449 }
450
451
// TestH12_HeadContentLengthNoBody compares the HTTP/1 and HTTP/2 responses
// to a HEAD request whose handler writes nothing.
func TestH12_HeadContentLengthNoBody(t *testing.T) {
	h12Compare{
		ReqFunc: (*Client).Head,
		Handler: func(w ResponseWriter, r *Request) {
		},
	}.run(t)
}
459
// TestH12_HeadContentLengthSmallBody compares the HTTP/1 and HTTP/2
// responses to a HEAD request whose handler writes a short body.
func TestH12_HeadContentLengthSmallBody(t *testing.T) {
	h12Compare{
		ReqFunc: (*Client).Head,
		Handler: func(w ResponseWriter, r *Request) {
			io.WriteString(w, "small")
		},
	}.run(t)
}
468
469 func TestH12_HeadContentLengthLargeBody(t *testing.T) {
470 h12Compare{
471 ReqFunc: (*Client).Head,
472 Handler: func(w ResponseWriter, r *Request) {
473 chunk := strings.Repeat("x", 512<<10)
474 for i := 0; i < 10; i++ {
475 io.WriteString(w, chunk)
476 }
477 },
478 }.run(t)
479 }
480
// TestH12_200NoBody compares the HTTP/1 and HTTP/2 responses to a GET
// whose handler does nothing.
func TestH12_200NoBody(t *testing.T) {
	h12Compare{Handler: func(w ResponseWriter, r *Request) {}}.run(t)
}
484
// The following tests compare HTTP/1 and HTTP/2 responses for handlers that
// only write a status code. Despite the TestH2_ prefix they go through
// testH12_noBody, which exercises both protocol versions.
func TestH2_204NoBody(t *testing.T) { testH12_noBody(t, 204) }
func TestH2_304NoBody(t *testing.T) { testH12_noBody(t, 304) }
func TestH2_404NoBody(t *testing.T) { testH12_noBody(t, 404) }

// testH12_noBody compares the HTTP/1 and HTTP/2 responses from a handler
// that writes the given status and no body.
func testH12_noBody(t *testing.T, status int) {
	h12Compare{Handler: func(w ResponseWriter, r *Request) {
		w.WriteHeader(status)
	}}.run(t)
}
494
// TestH12_SmallBody compares the HTTP/1 and HTTP/2 responses from a handler
// that writes a short body.
func TestH12_SmallBody(t *testing.T) {
	h12Compare{Handler: func(w ResponseWriter, r *Request) {
		io.WriteString(w, "small body")
	}}.run(t)
}
500
// TestH12_ExplicitContentLength compares the HTTP/1 and HTTP/2 responses
// from a handler that sets Content-Length itself and writes exactly that
// many bytes.
func TestH12_ExplicitContentLength(t *testing.T) {
	h12Compare{Handler: func(w ResponseWriter, r *Request) {
		w.Header().Set("Content-Length", "3")
		io.WriteString(w, "foo")
	}}.run(t)
}
507
// TestH12_FlushBeforeBody compares the HTTP/1 and HTTP/2 responses from a
// handler that flushes before writing any body.
func TestH12_FlushBeforeBody(t *testing.T) {
	h12Compare{Handler: func(w ResponseWriter, r *Request) {
		w.(Flusher).Flush()
		io.WriteString(w, "foo")
	}}.run(t)
}
514
// TestH12_FlushMidBody compares the HTTP/1 and HTTP/2 responses from a
// handler that flushes between two body writes.
func TestH12_FlushMidBody(t *testing.T) {
	h12Compare{Handler: func(w ResponseWriter, r *Request) {
		io.WriteString(w, "foo")
		w.(Flusher).Flush()
		io.WriteString(w, "bar")
	}}.run(t)
}
522
// TestH12_Head_ExplicitLen compares the HTTP/1 and HTTP/2 responses to a
// HEAD request whose handler sets Content-Length explicitly and writes no
// body. The handler also verifies that it was invoked with method HEAD.
func TestH12_Head_ExplicitLen(t *testing.T) {
	h12Compare{
		ReqFunc: (*Client).Head,
		Handler: func(w ResponseWriter, r *Request) {
			if r.Method != "HEAD" {
				t.Errorf("unexpected method %q", r.Method)
			}
			w.Header().Set("Content-Length", "1235")
		},
	}.run(t)
}
534
// TestH12_Head_ImplicitLen compares the HTTP/1 and HTTP/2 responses to a
// HEAD request whose handler writes a body without setting Content-Length.
// The handler also verifies that it was invoked with method HEAD.
func TestH12_Head_ImplicitLen(t *testing.T) {
	h12Compare{
		ReqFunc: (*Client).Head,
		Handler: func(w ResponseWriter, r *Request) {
			if r.Method != "HEAD" {
				t.Errorf("unexpected method %q", r.Method)
			}
			io.WriteString(w, "foo")
		},
	}.run(t)
}
546
// TestH12_HandlerWritesTooLittle has the handler declare a Content-Length
// of 3 but write only 2 bytes. For both protocol versions it expects the
// client to have read "12" and hit io.ErrUnexpectedEOF.
func TestH12_HandlerWritesTooLittle(t *testing.T) {
	h12Compare{
		Handler: func(w ResponseWriter, r *Request) {
			w.Header().Set("Content-Length", "3")
			io.WriteString(w, "12")
		},
		CheckResponse: func(proto string, res *Response) {
			// By now the body has been replaced by a slurpResult holding
			// the bytes read and the read error.
			sr, ok := res.Body.(slurpResult)
			if !ok {
				t.Errorf("%s body is %T; want slurpResult", proto, res.Body)
				return
			}
			if sr.err != io.ErrUnexpectedEOF {
				t.Errorf("%s read error = %v; want io.ErrUnexpectedEOF", proto, sr.err)
			}
			if string(sr.body) != "12" {
				t.Errorf("%s body = %q; want %q", proto, sr.body, "12")
			}
		},
	}.run(t)
}
568
569
570
571
572
573
574
// TestHandlerWritesTooMuch runs testHandlerWritesTooMuch in each default mode.
func TestHandlerWritesTooMuch(t *testing.T) { run(t, testHandlerWritesTooMuch) }

// testHandlerWritesTooMuch has the handler declare a Content-Length equal
// to len(wantBody), write exactly that body, and then attempt to write one
// more byte. It expects the extra write, or the flush that follows it, to
// return an error, and the client to receive exactly wantBody.
func testHandlerWritesTooMuch(t *testing.T, mode testMode) {
	wantBody := []byte("123")
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		rc := NewResponseController(w)
		w.Header().Set("Content-Length", fmt.Sprintf("%v", len(wantBody)))
		rc.Flush()
		w.Write(wantBody)
		rc.Flush()
		// One byte beyond the declared length.
		n, err := io.WriteString(w, "x")
		if err == nil {
			// The write itself reported no error; see whether the flush does.
			err = rc.Flush()
		}

		if err == nil {
			t.Errorf("for proto %q, final write = %v, %v; want _, some error", r.Proto, n, err)
		}
	}))

	res, err := cst.c.Get(cst.ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	defer res.Body.Close()

	// The read error is ignored; only the bytes received are checked.
	gotBody, _ := io.ReadAll(res.Body)
	if !bytes.Equal(gotBody, wantBody) {
		t.Fatalf("got response body: %q; want %q", gotBody, wantBody)
	}
}
605
606
607
// TestH12_AutoGzip compares the HTTP/1 and HTTP/2 responses from a handler
// that replies with gzip-encoded content. The handler also checks that the
// request arrived with "Accept-Encoding: gzip".
func TestH12_AutoGzip(t *testing.T) {
	h12Compare{
		Handler: func(w ResponseWriter, r *Request) {
			if ae := r.Header.Get("Accept-Encoding"); ae != "gzip" {
				t.Errorf("%s Accept-Encoding = %q; want gzip", r.Proto, ae)
			}
			w.Header().Set("Content-Encoding", "gzip")
			gz := gzip.NewWriter(w)
			io.WriteString(gz, "I am some gzipped content. Go go go go go go go go go go go go should compress well.")
			gz.Close()
		},
	}.run(t)
}
621
// TestH12_AutoGzip_Disabled sets DisableCompression on the transport and
// checks, for both protocol versions, that the request reaches the handler
// with an empty Accept-Encoding header. The handler echoes the header
// values in the body so the two responses can be compared.
func TestH12_AutoGzip_Disabled(t *testing.T) {
	h12Compare{
		Opts: []any{
			func(tr *Transport) { tr.DisableCompression = true },
		},
		Handler: func(w ResponseWriter, r *Request) {
			fmt.Fprintf(w, "%q", r.Header["Accept-Encoding"])
			if ae := r.Header.Get("Accept-Encoding"); ae != "" {
				t.Errorf("%s Accept-Encoding = %q; want empty", r.Proto, ae)
			}
		},
	}.run(t)
}
635
636
637
638
639 func Test304Responses(t *testing.T) { run(t, test304Responses) }
640 func test304Responses(t *testing.T, mode testMode) {
641 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
642 w.WriteHeader(StatusNotModified)
643 _, err := w.Write([]byte("illegal body"))
644 if err != ErrBodyNotAllowed {
645 t.Errorf("on Write, expected ErrBodyNotAllowed, got %v", err)
646 }
647 }))
648 defer cst.close()
649 res, err := cst.c.Get(cst.ts.URL)
650 if err != nil {
651 t.Fatal(err)
652 }
653 if len(res.TransferEncoding) > 0 {
654 t.Errorf("expected no TransferEncoding; got %v", res.TransferEncoding)
655 }
656 body, err := io.ReadAll(res.Body)
657 if err != nil {
658 t.Error(err)
659 }
660 if len(body) > 0 {
661 t.Errorf("got unexpected body %q", string(body))
662 }
663 }
664
// TestH12_ServerEmptyContentLength compares the HTTP/1 and HTTP/2 responses
// from a handler that sets a header to a single empty value and then writes
// an HTML body.
// NOTE(review): despite the test's name, the header set to an empty value
// here is Content-Type, not Content-Length.
func TestH12_ServerEmptyContentLength(t *testing.T) {
	h12Compare{
		Handler: func(w ResponseWriter, r *Request) {
			w.Header()["Content-Type"] = []string{""}
			io.WriteString(w, "<html><body>hi</body></html>")
		},
	}.run(t)
}
673
// TestH12_RequestContentLength_Known_NonZero posts a strings.Reader body of
// four bytes and expects the handler to see ContentLength 4.
func TestH12_RequestContentLength_Known_NonZero(t *testing.T) {
	h12requestContentLength(t, func() io.Reader { return strings.NewReader("FOUR") }, 4)
}

// TestH12_RequestContentLength_Known_Zero posts a nil body and expects the
// handler to see ContentLength 0.
func TestH12_RequestContentLength_Known_Zero(t *testing.T) {
	h12requestContentLength(t, func() io.Reader { return nil }, 0)
}

// TestH12_RequestContentLength_Unknown posts a body wrapped in an anonymous
// struct that exposes only io.Reader, and expects the handler to see
// ContentLength -1.
func TestH12_RequestContentLength_Unknown(t *testing.T) {
	h12requestContentLength(t, func() io.Reader { return struct{ io.Reader }{strings.NewReader("Stuff")} }, -1)
}
685
// h12requestContentLength posts the body produced by bodyfn to an HTTP/1
// and an HTTP/2 server and checks that, in both, the handler observed a
// Request.ContentLength of wantLen. The handler reports the value it saw in
// the Got-Length response header and in the body.
//
// bodyfn is called once per request so that each request gets its own reader.
func h12requestContentLength(t *testing.T, bodyfn func() io.Reader, wantLen int64) {
	h12Compare{
		Handler: func(w ResponseWriter, r *Request) {
			w.Header().Set("Got-Length", fmt.Sprint(r.ContentLength))
			fmt.Fprintf(w, "Req.ContentLength=%v", r.ContentLength)
		},
		ReqFunc: func(c *Client, url string) (*Response, error) {
			return c.Post(url, "text/plain", bodyfn())
		},
		CheckResponse: func(proto string, res *Response) {
			if got, want := res.Header.Get("Got-Length"), fmt.Sprint(wantLen); got != want {
				t.Errorf("Proto %q got length %q; want %q", proto, got, want)
			}
		},
	}.run(t)
}
702
703
704
// TestCancelRequestMidBody runs testCancelRequestMidBody in each default mode.
func TestCancelRequestMidBody(t *testing.T) { run(t, testCancelRequestMidBody) }

// testCancelRequestMidBody cancels a request through Request.Cancel after
// the first part of the response body has been read. It expects the client
// to have received only "Hello" and the remaining read to fail with
// ExportErrRequestCanceled.
func testCancelRequestMidBody(t *testing.T, mode testMode) {
	unblock := make(chan bool)
	didFlush := make(chan bool, 1)
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		io.WriteString(w, "Hello")
		w.(Flusher).Flush()
		didFlush <- true
		// Held here until the test function returns (deferred close below).
		<-unblock
		io.WriteString(w, ", world.")
	}))
	defer close(unblock)

	req, _ := NewRequest("GET", cst.ts.URL, nil)
	cancel := make(chan struct{})
	req.Cancel = cancel

	res, err := cst.c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	defer res.Body.Close()
	// Wait until the handler has flushed the first part of the body.
	<-didFlush

	// Read the first part before canceling.
	firstRead := make([]byte, 10)
	n, err := res.Body.Read(firstRead)
	if err != nil {
		t.Fatal(err)
	}
	firstRead = firstRead[:n]

	close(cancel)

	rest, err := io.ReadAll(res.Body)
	all := string(firstRead) + string(rest)
	if all != "Hello" {
		t.Errorf("Read %q (%q + %q); want Hello", all, firstRead, rest)
	}
	if err != ExportErrRequestCanceled {
		t.Errorf("ReadAll error = %v; want %v", err, ExportErrRequestCanceled)
	}
}
749
750
// TestTrailersClientToServer runs testTrailersClientToServer in each
// default mode.
func TestTrailersClientToServer(t *testing.T) { run(t, testTrailersClientToServer) }

// testTrailersClientToServer sends a POST whose trailers are declared up
// front with nil values and filled in while the body is being read. The
// handler reads the body and echoes back the declared trailer names and
// the values it received; the test checks that echo.
func testTrailersClientToServer(t *testing.T, mode testMode) {
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		slurp, err := io.ReadAll(r.Body)
		if err != nil {
			t.Errorf("Server reading request body: %v", err)
		}
		if string(slurp) != "foo" {
			t.Errorf("Server read request body %q; want foo", slurp)
		}
		if r.Trailer == nil {
			io.WriteString(w, "nil Trailer")
		} else {
			// Sort the keys so the output does not depend on map order.
			decl := slices.Sorted(maps.Keys(r.Trailer))
			fmt.Fprintf(w, "decl: %v, vals: %s, %s",
				decl,
				r.Trailer.Get("Client-Trailer-A"),
				r.Trailer.Get("Client-Trailer-B"))
		}
	}))

	// req is declared before the body is built because the body's callbacks
	// assign into req.Trailer.
	// NOTE(review): eofReaderFunc is defined elsewhere; presumably it runs
	// the given func when read and reports EOF — confirm.
	var req *Request
	req, _ = NewRequest("POST", cst.ts.URL, io.MultiReader(
		eofReaderFunc(func() {
			req.Trailer["Client-Trailer-A"] = []string{"valuea"}
		}),
		strings.NewReader("foo"),
		eofReaderFunc(func() {
			req.Trailer["Client-Trailer-B"] = []string{"valueb"}
		}),
	))
	req.Trailer = Header{
		"Client-Trailer-A": nil,
		"Client-Trailer-B": nil,
	}
	req.ContentLength = -1
	res, err := cst.c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	if err := wantBody(res, err, "decl: [Client-Trailer-A Client-Trailer-B], vals: valuea, valueb"); err != nil {
		t.Error(err)
	}
}
795
796
// TestTrailersServerToClient runs testTrailersServerToClient in each
// default mode with the handler not flushing after writing the body.
func TestTrailersServerToClient(t *testing.T) {
	run(t, func(t *testing.T, mode testMode) {
		testTrailersServerToClient(t, mode, false)
	})
}

// TestTrailersServerToClientFlush runs testTrailersServerToClient in each
// default mode with the handler flushing after writing the body.
func TestTrailersServerToClientFlush(t *testing.T) {
	run(t, func(t *testing.T, mode testMode) {
		testTrailersServerToClient(t, mode, true)
	})
}
807
// testTrailersServerToClient checks trailers sent from a handler to the
// client. The handler declares three trailers in the Trailer header, writes
// the body (flushing afterwards when flush is true), and then sets values
// for two declared trailers plus one undeclared header.
//
// The client is expected to see the three declared trailer keys with nil
// values before reading the body. After reading it, A and C carry their
// values, B stays nil, and the undeclared header does not appear.
func testTrailersServerToClient(t *testing.T, mode testMode, flush bool) {
	const body = "Some body"
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set("Trailer", "Server-Trailer-A, Server-Trailer-B")
		w.Header().Add("Trailer", "Server-Trailer-C")

		io.WriteString(w, body)
		if flush {
			w.(Flusher).Flush()
		}

		// Set after the body has been written: values for declared trailers,
		// and one header that was never declared as a trailer.
		w.Header().Set("Server-Trailer-A", "valuea")
		w.Header().Set("Server-Trailer-C", "valuec")
		w.Header().Set("Server-Trailer-NotDeclared", "should be omitted")
	}))

	res, err := cst.c.Get(cst.ts.URL)
	if err != nil {
		t.Fatal(err)
	}

	wantHeader := Header{
		"Content-Type": {"text/plain; charset=utf-8"},
	}
	wantLen := -1
	if mode == http2Mode && !flush {
		// Only in this combination does the test expect a known
		// Content-Length on the response.
		wantLen = len(body)
		wantHeader["Content-Length"] = []string{fmt.Sprint(wantLen)}
	}
	if res.ContentLength != int64(wantLen) {
		t.Errorf("ContentLength = %v; want %v", res.ContentLength, wantLen)
	}

	// Date varies from run to run, so drop it before comparing headers.
	delete(res.Header, "Date")
	if !reflect.DeepEqual(res.Header, wantHeader) {
		t.Errorf("Header = %v; want %v", res.Header, wantHeader)
	}

	if got, want := res.Trailer, (Header{
		"Server-Trailer-A": nil,
		"Server-Trailer-B": nil,
		"Server-Trailer-C": nil,
	}); !reflect.DeepEqual(got, want) {
		t.Errorf("Trailer before body read = %v; want %v", got, want)
	}

	if err := wantBody(res, nil, body); err != nil {
		t.Fatal(err)
	}

	if got, want := res.Trailer, (Header{
		"Server-Trailer-A": {"valuea"},
		"Server-Trailer-B": nil,
		"Server-Trailer-C": {"valuec"},
	}); !reflect.DeepEqual(got, want) {
		t.Errorf("Trailer after body read = %v; want %v", got, want)
	}
}
875
876
877 func TestResponseBodyReadAfterClose(t *testing.T) { run(t, testResponseBodyReadAfterClose) }
878 func testResponseBodyReadAfterClose(t *testing.T, mode testMode) {
879 const body = "Some body"
880 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
881 io.WriteString(w, body)
882 }))
883 res, err := cst.c.Get(cst.ts.URL)
884 if err != nil {
885 t.Fatal(err)
886 }
887 res.Body.Close()
888 data, err := io.ReadAll(res.Body)
889 if len(data) != 0 || err == nil {
890 t.Fatalf("ReadAll returned %q, %v; want error", data, err)
891 }
892 }
893
// TestConcurrentReadWriteReqBody runs testConcurrentReadWriteReqBody in
// each default mode.
func TestConcurrentReadWriteReqBody(t *testing.T) { run(t, testConcurrentReadWriteReqBody) }

// testConcurrentReadWriteReqBody has the handler read the request body and
// write the response from two separate goroutines. In http2Mode the two run
// without coordination; in the other modes the writer waits until the
// reader has finished. The request carries "Expect: 100-continue".
func testConcurrentReadWriteReqBody(t *testing.T, mode testMode) {
	const reqBody = "some request body"
	const resBody = "some response body"
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		var wg sync.WaitGroup
		wg.Add(2)
		didRead := make(chan bool, 1)

		// Reader goroutine.
		go func() {
			defer wg.Done()
			data, err := io.ReadAll(r.Body)
			if string(data) != reqBody {
				t.Errorf("Handler read %q; want %q", data, reqBody)
			}
			if err != nil {
				t.Errorf("Handler Read: %v", err)
			}
			didRead <- true
		}()

		// Writer goroutine.
		go func() {
			defer wg.Done()
			if mode != http2Mode {
				// Outside http2Mode, write only once the request body has
				// been read in full.
				<-didRead
			}
			io.WriteString(w, resBody)
		}()
		// Keep the handler alive until both goroutines are done.
		wg.Wait()
	}))
	req, _ := NewRequest("POST", cst.ts.URL, strings.NewReader(reqBody))
	req.Header.Add("Expect", "100-continue")
	res, err := cst.c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	data, err := io.ReadAll(res.Body)
	defer res.Body.Close()
	if err != nil {
		t.Fatal(err)
	}
	if string(data) != resBody {
		t.Errorf("read %q; want %q", data, resBody)
	}
}
943
// TestConnectRequest runs testConnectRequest in each default mode.
func TestConnectRequest(t *testing.T) { run(t, testConnectRequest) }

// testConnectRequest sends CONNECT requests and checks what the handler
// receives: the method must be CONNECT, and both Host and URL.Host must be
// the request's Host field when one is set, or the URL's host otherwise.
func testConnectRequest(t *testing.T, mode testMode) {
	// The handler hands each request it receives back to the test.
	gotc := make(chan *Request, 1)
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		gotc <- r
	}))

	u, err := url.Parse(cst.ts.URL)
	if err != nil {
		t.Fatal(err)
	}

	tests := []struct {
		req  *Request
		want string
	}{
		// No Host set: the URL's host is expected.
		{
			req: &Request{
				Method: "CONNECT",
				Header: Header{},
				URL:    u,
			},
			want: u.Host,
		},
		// Explicit Host: it is expected in place of the URL's host.
		{
			req: &Request{
				Method: "CONNECT",
				Header: Header{},
				URL:    u,
				Host:   "example.com:123",
			},
			want: "example.com:123",
		},
	}

	for i, tt := range tests {
		res, err := cst.c.Do(tt.req)
		if err != nil {
			t.Errorf("%d. RoundTrip = %v", i, err)
			continue
		}
		res.Body.Close()
		req := <-gotc
		if req.Method != "CONNECT" {
			t.Errorf("method = %q; want CONNECT", req.Method)
		}
		if req.Host != tt.want {
			t.Errorf("Host = %q; want %q", req.Host, tt.want)
		}
		if req.URL.Host != tt.want {
			t.Errorf("URL.Host = %q; want %q", req.URL.Host, tt.want)
		}
	}
}
998
// TestTransportUserAgent runs testTransportUserAgent in each default mode.
func TestTransportUserAgent(t *testing.T) { run(t, testTransportUserAgent) }

// testTransportUserAgent checks which User-Agent header values reach the
// handler for several ways of setting (or not setting) the header on the
// outgoing request. The handler echoes the values it received with %q.
func testTransportUserAgent(t *testing.T, mode testMode) {
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		fmt.Fprintf(w, "%q", r.Header["User-Agent"])
	}))

	// either returns b in http2Mode and a otherwise.
	either := func(a, b string) string {
		if mode == http2Mode {
			return b
		}
		return a
	}

	tests := []struct {
		setup func(*Request)
		want  string
	}{
		// Header untouched: a default value that depends on the mode.
		{
			func(r *Request) {},
			either(`["Go-http-client/1.1"]`, `["Go-http-client/2.0"]`),
		},
		// Explicit value: passed through unchanged.
		{
			func(r *Request) { r.Header.Set("User-Agent", "foo/1.2.3") },
			`["foo/1.2.3"]`,
		},
		// Several values: only the first is expected at the handler.
		{
			func(r *Request) { r.Header["User-Agent"] = []string{"single", "or", "multiple"} },
			`["single"]`,
		},
		// Explicit empty value: no User-Agent is expected.
		{
			func(r *Request) { r.Header.Set("User-Agent", "") },
			`[]`,
		},
		// Key present with nil values: no User-Agent is expected.
		{
			func(r *Request) { r.Header["User-Agent"] = nil },
			`[]`,
		},
	}
	for i, tt := range tests {
		req, _ := NewRequest("GET", cst.ts.URL, nil)
		tt.setup(req)
		res, err := cst.c.Do(req)
		if err != nil {
			t.Errorf("%d. RoundTrip = %v", i, err)
			continue
		}
		slurp, err := io.ReadAll(res.Body)
		res.Body.Close()
		if err != nil {
			t.Errorf("%d. read body = %v", i, err)
			continue
		}
		if string(slurp) != tt.want {
			t.Errorf("%d. body mismatch.\n got: %s\nwant: %s\n", i, slurp, tt.want)
		}
	}
}
1056
// TestStarRequestMethod runs testStarRequest for the methods "FOO" and
// "OPTIONS", each in every default mode.
func TestStarRequestMethod(t *testing.T) {
	for _, method := range []string{"FOO", "OPTIONS"} {
		t.Run(method, func(t *testing.T) {
			run(t, func(t *testing.T, mode testMode) {
				testStarRequest(t, method, mode)
			})
		})
	}
}
// testStarRequest sends a request with the given method and the path "*".
//
// For "OPTIONS" it expects a 200 response with content length 0 and without
// the handler's "foo" header, and it tolerates the handler never being
// invoked. For any other method it expects a 200 response with unknown
// content length (-1) and the handler's "foo: bar" header, and requires
// that the handler saw the request with the same method, URL.Path "*" and
// RequestURI "*".
func testStarRequest(t *testing.T, method string, mode testMode) {
	gotc := make(chan *Request, 1)
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set("foo", "bar")
		gotc <- r
		w.(Flusher).Flush()
	}))

	u, err := url.Parse(cst.ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	u.Path = "*"

	req := &Request{
		Method: method,
		Header: Header{},
		URL:    u,
	}

	res, err := cst.c.Do(req)
	if err != nil {
		t.Fatalf("RoundTrip = %v", err)
	}
	res.Body.Close()

	wantFoo := "bar"
	wantLen := int64(-1)
	if method == "OPTIONS" {
		wantFoo = ""
		wantLen = 0
	}
	if res.StatusCode != 200 {
		t.Errorf("status code = %v; want %d", res.Status, 200)
	}
	if res.ContentLength != wantLen {
		t.Errorf("content length = %v; want %d", res.ContentLength, wantLen)
	}
	if got := res.Header.Get("foo"); got != wantFoo {
		t.Errorf("response \"foo\" header = %q; want %q", got, wantFoo)
	}
	// Non-blocking receive: req ends up nil if the handler was never called.
	select {
	case req = <-gotc:
	default:
		req = nil
	}
	if req == nil {
		if method != "OPTIONS" {
			t.Fatalf("handler never got request")
		}
		return
	}
	if req.Method != method {
		t.Errorf("method = %q; want %q", req.Method, method)
	}
	if req.URL.Path != "*" {
		t.Errorf("URL.Path = %q; want *", req.URL.Path)
	}
	if req.RequestURI != "*" {
		t.Errorf("RequestURI = %q; want *", req.RequestURI)
	}
}
1128
1129
1130 func TestTransportDiscardsUnneededConns(t *testing.T) {
1131 run(t, testTransportDiscardsUnneededConns, []testMode{http2Mode})
1132 }
1133 func testTransportDiscardsUnneededConns(t *testing.T, mode testMode) {
1134 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1135 fmt.Fprintf(w, "Hello, %v", r.RemoteAddr)
1136 }))
1137 defer cst.close()
1138
1139 var numOpen, numClose int32
1140
1141 tlsConfig := &tls.Config{InsecureSkipVerify: true}
1142 tr := &Transport{
1143 TLSClientConfig: tlsConfig,
1144 DialTLS: func(_, addr string) (net.Conn, error) {
1145 time.Sleep(10 * time.Millisecond)
1146 rc, err := net.Dial("tcp", addr)
1147 if err != nil {
1148 return nil, err
1149 }
1150 atomic.AddInt32(&numOpen, 1)
1151 c := noteCloseConn{rc, func() { atomic.AddInt32(&numClose, 1) }}
1152 return tls.Client(c, tlsConfig), nil
1153 },
1154 }
1155 if err := ExportHttp2ConfigureTransport(tr); err != nil {
1156 t.Fatal(err)
1157 }
1158 defer tr.CloseIdleConnections()
1159
1160 c := &Client{Transport: tr}
1161
1162 const N = 10
1163 gotBody := make(chan string, N)
1164 var wg sync.WaitGroup
1165 for i := 0; i < N; i++ {
1166 wg.Add(1)
1167 go func() {
1168 defer wg.Done()
1169 resp, err := c.Get(cst.ts.URL)
1170 if err != nil {
1171
1172
1173 time.Sleep(10 * time.Millisecond)
1174 resp, err = c.Get(cst.ts.URL)
1175 if err != nil {
1176 t.Errorf("Get: %v", err)
1177 return
1178 }
1179 }
1180 defer resp.Body.Close()
1181 slurp, err := io.ReadAll(resp.Body)
1182 if err != nil {
1183 t.Error(err)
1184 }
1185 gotBody <- string(slurp)
1186 }()
1187 }
1188 wg.Wait()
1189 close(gotBody)
1190
1191 var last string
1192 for got := range gotBody {
1193 if last == "" {
1194 last = got
1195 continue
1196 }
1197 if got != last {
1198 t.Errorf("Response body changed: %q -> %q", last, got)
1199 }
1200 }
1201
1202 var open, close int32
1203 for i := 0; i < 150; i++ {
1204 open, close = atomic.LoadInt32(&numOpen), atomic.LoadInt32(&numClose)
1205 if open < 1 {
1206 t.Fatalf("open = %d; want at least", open)
1207 }
1208 if close == open-1 {
1209
1210 return
1211 }
1212 time.Sleep(10 * time.Millisecond)
1213 }
1214 t.Errorf("%d connections opened, %d closed; want %d to close", open, close, open-1)
1215 }
1216
1217
// TestTransportGCRequest runs testTransportGCRequest in each default mode,
// once with a handler that writes a response body ("Body") and once with
// one that does not ("NoBody").
func TestTransportGCRequest(t *testing.T) {
	run(t, func(t *testing.T, mode testMode) {
		t.Run("Body", func(t *testing.T) { testTransportGCRequest(t, mode, true) })
		t.Run("NoBody", func(t *testing.T) { testTransportGCRequest(t, mode, false) })
	})
}
// testTransportGCRequest checks that a Request becomes collectable once its
// round trip has completed and the response body has been read and closed.
// A cleanup attached to the request closes didGC when the request is
// collected. The test keeps forcing garbage collections until that happens
// and has no failure path of its own for the case where it never does.
//
// body selects whether the handler writes a response body.
func testTransportGCRequest(t *testing.T, mode testMode, body bool) {
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		io.ReadAll(r.Body)
		if body {
			io.WriteString(w, "Hello.")
		}
	}))

	didGC := make(chan struct{})
	// The request is created and used inside this function literal so that
	// no variable in the enclosing function refers to it afterwards.
	(func() {
		body := strings.NewReader("some body")
		req, _ := NewRequest("POST", cst.ts.URL, body)
		runtime.AddCleanup(req, func(ch chan struct{}) { close(ch) }, didGC)
		res, err := cst.c.Do(req)
		if err != nil {
			t.Fatal(err)
		}
		if _, err := io.ReadAll(res.Body); err != nil {
			t.Fatal(err)
		}
		if err := res.Body.Close(); err != nil {
			t.Fatal(err)
		}
	})()
	for {
		select {
		case <-didGC:
			return
		case <-time.After(1 * time.Millisecond):
			runtime.GC()
		}
	}
}
1257
1258 func TestTransportRejectsInvalidHeaders(t *testing.T) { run(t, testTransportRejectsInvalidHeaders) }
1259 func testTransportRejectsInvalidHeaders(t *testing.T, mode testMode) {
1260 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1261 fmt.Fprintf(w, "Handler saw headers: %q", r.Header)
1262 }), optQuietLog)
1263 cst.tr.DisableKeepAlives = true
1264
1265 tests := []struct {
1266 key, val string
1267 ok bool
1268 }{
1269 {"Foo", "capital-key", true},
1270 {"Foo", "foo\x00bar", false},
1271 {"Foo", "two\nlines", false},
1272 {"bogus\nkey", "v", false},
1273 {"A space", "v", false},
1274 {"имя", "v", false},
1275 {"name", "валю", true},
1276 {"", "v", false},
1277 {"k", "", true},
1278 }
1279 for _, tt := range tests {
1280 dialedc := make(chan bool, 1)
1281 cst.tr.Dial = func(netw, addr string) (net.Conn, error) {
1282 dialedc <- true
1283 return net.Dial(netw, addr)
1284 }
1285 req, _ := NewRequest("GET", cst.ts.URL, nil)
1286 req.Header[tt.key] = []string{tt.val}
1287 res, err := cst.c.Do(req)
1288 var body []byte
1289 if err == nil {
1290 body, _ = io.ReadAll(res.Body)
1291 res.Body.Close()
1292 }
1293 var dialed bool
1294 select {
1295 case <-dialedc:
1296 dialed = true
1297 default:
1298 }
1299
1300 if !tt.ok && dialed {
1301 t.Errorf("For key %q, value %q, transport dialed. Expected local failure. Response was: (%v, %v)\nServer replied with: %s", tt.key, tt.val, res, err, body)
1302 } else if (err == nil) != tt.ok {
1303 t.Errorf("For key %q, value %q; got err = %v; want ok=%v", tt.key, tt.val, err, tt.ok)
1304 }
1305 }
1306 }
1307
1308 func TestInterruptWithPanic(t *testing.T) {
1309 run(t, func(t *testing.T, mode testMode) {
1310 t.Run("boom", func(t *testing.T) { testInterruptWithPanic(t, mode, "boom") })
1311 t.Run("nil", func(t *testing.T) { t.Setenv("GODEBUG", "panicnil=1"); testInterruptWithPanic(t, mode, nil) })
1312 t.Run("ErrAbortHandler", func(t *testing.T) { testInterruptWithPanic(t, mode, ErrAbortHandler) })
1313 }, testNotParallel)
1314 }
1315 func testInterruptWithPanic(t *testing.T, mode testMode, panicValue any) {
1316 const msg = "hello"
1317
1318 testDone := make(chan struct{})
1319 defer close(testDone)
1320
1321 var errorLog lockedBytesBuffer
1322 gotHeaders := make(chan bool, 1)
1323 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1324 io.WriteString(w, msg)
1325 w.(Flusher).Flush()
1326
1327 select {
1328 case <-gotHeaders:
1329 case <-testDone:
1330 }
1331 panic(panicValue)
1332 }), func(ts *httptest.Server) {
1333 ts.Config.ErrorLog = log.New(&errorLog, "", 0)
1334 })
1335 res, err := cst.c.Get(cst.ts.URL)
1336 if err != nil {
1337 t.Fatal(err)
1338 }
1339 gotHeaders <- true
1340 defer res.Body.Close()
1341 slurp, err := io.ReadAll(res.Body)
1342 if string(slurp) != msg {
1343 t.Errorf("client read %q; want %q", slurp, msg)
1344 }
1345 if err == nil {
1346 t.Errorf("client read all successfully; want some error")
1347 }
1348 logOutput := func() string {
1349 errorLog.Lock()
1350 defer errorLog.Unlock()
1351 return errorLog.String()
1352 }
1353 wantStackLogged := panicValue != nil && panicValue != ErrAbortHandler
1354
1355 waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
1356 gotLog := logOutput()
1357 if !wantStackLogged {
1358 if gotLog == "" {
1359 return true
1360 }
1361 t.Fatalf("want no log output; got: %s", gotLog)
1362 }
1363 if gotLog == "" {
1364 if d > 0 {
1365 t.Logf("wanted a stack trace logged; got nothing after %v", d)
1366 }
1367 return false
1368 }
1369 if !strings.Contains(gotLog, "created by ") && strings.Count(gotLog, "\n") < 6 {
1370 if d > 0 {
1371 t.Logf("output doesn't look like a panic stack trace after %v. Got: %s", d, gotLog)
1372 }
1373 return false
1374 }
1375 return true
1376 })
1377 }
1378
// lockedBytesBuffer is a bytes.Buffer paired with a mutex. Only Write takes
// the lock on its own; callers using any other Buffer method concurrently
// with Write must hold the embedded Mutex themselves.
type lockedBytesBuffer struct {
	sync.Mutex
	bytes.Buffer
}

// Write appends p to the buffer while holding the mutex.
func (b *lockedBytesBuffer) Write(p []byte) (int, error) {
	b.Mutex.Lock()
	defer b.Mutex.Unlock()
	return b.Buffer.Write(p)
}
1389
1390
1391 func TestH12_AutoGzipWithDumpResponse(t *testing.T) {
1392 h12Compare{
1393 Handler: func(w ResponseWriter, r *Request) {
1394 h := w.Header()
1395 h.Set("Content-Encoding", "gzip")
1396 h.Set("Content-Length", "23")
1397 io.WriteString(w, "\x1f\x8b\b\x00\x00\x00\x00\x00\x00\x00s\xf3\xf7\a\x00\xab'\xd4\x1a\x03\x00\x00\x00")
1398 },
1399 EarlyCheckResponse: func(proto string, res *Response) {
1400 if !res.Uncompressed {
1401 t.Errorf("%s: expected Uncompressed to be set", proto)
1402 }
1403 dump, err := httputil.DumpResponse(res, true)
1404 if err != nil {
1405 t.Errorf("%s: DumpResponse: %v", proto, err)
1406 return
1407 }
1408 if strings.Contains(string(dump), "Connection: close") {
1409 t.Errorf("%s: should not see \"Connection: close\" in dump; got:\n%s", proto, dump)
1410 }
1411 if !strings.Contains(string(dump), "FOO") {
1412 t.Errorf("%s: should see \"FOO\" in response; got:\n%s", proto, dump)
1413 }
1414 },
1415 }.run(t)
1416 }
1417
1418
1419 func TestCloseIdleConnections(t *testing.T) { run(t, testCloseIdleConnections) }
1420 func testCloseIdleConnections(t *testing.T, mode testMode) {
1421 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1422 w.Header().Set("X-Addr", r.RemoteAddr)
1423 }))
1424 get := func() string {
1425 res, err := cst.c.Get(cst.ts.URL)
1426 if err != nil {
1427 t.Fatal(err)
1428 }
1429 res.Body.Close()
1430 v := res.Header.Get("X-Addr")
1431 if v == "" {
1432 t.Fatal("didn't get X-Addr")
1433 }
1434 return v
1435 }
1436 a1 := get()
1437 cst.tr.CloseIdleConnections()
1438 a2 := get()
1439 if a1 == a2 {
1440 t.Errorf("didn't close connection")
1441 }
1442 }
1443
// noteCloseConn wraps a net.Conn and invokes closeFunc each time Close is
// called on it.
type noteCloseConn struct {
	net.Conn
	closeFunc func()
}

// Close calls closeFunc first and then closes the wrapped connection,
// returning the wrapped connection's error.
func (x noteCloseConn) Close() error {
	notify, conn := x.closeFunc, x.Conn
	notify()
	return conn.Close()
}
1453
// testErrorReader is an io.Reader that must never be read: any Read call
// marks the associated test as failed.
type testErrorReader struct{ t *testing.T }

// Read records a test error and reports end of input without touching p.
func (r testErrorReader) Read(p []byte) (n int, err error) {
	r.t.Error("unexpected Read call")
	n, err = 0, io.EOF
	return n, err
}
1460
1461 func TestNoSniffExpectRequestBody(t *testing.T) { run(t, testNoSniffExpectRequestBody) }
1462 func testNoSniffExpectRequestBody(t *testing.T, mode testMode) {
1463 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1464 w.WriteHeader(StatusUnauthorized)
1465 }))
1466
1467
1468 cst.tr.ExpectContinueTimeout = 10 * time.Second
1469
1470 req, err := NewRequest("POST", cst.ts.URL, testErrorReader{t})
1471 if err != nil {
1472 t.Fatal(err)
1473 }
1474 req.ContentLength = 0
1475 req.Header.Set("Expect", "100-continue")
1476 res, err := cst.tr.RoundTrip(req)
1477 if err != nil {
1478 t.Fatal(err)
1479 }
1480 defer res.Body.Close()
1481 if res.StatusCode != StatusUnauthorized {
1482 t.Errorf("status code = %v; want %v", res.StatusCode, StatusUnauthorized)
1483 }
1484 }
1485
1486 func TestServerUndeclaredTrailers(t *testing.T) { run(t, testServerUndeclaredTrailers) }
1487 func testServerUndeclaredTrailers(t *testing.T, mode testMode) {
1488 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1489 w.Header().Set("Foo", "Bar")
1490 w.Header().Set("Trailer:Foo", "Baz")
1491 w.(Flusher).Flush()
1492 w.Header().Add("Trailer:Foo", "Baz2")
1493 w.Header().Set("Trailer:Bar", "Quux")
1494 }))
1495 res, err := cst.c.Get(cst.ts.URL)
1496 if err != nil {
1497 t.Fatal(err)
1498 }
1499 if _, err := io.Copy(io.Discard, res.Body); err != nil {
1500 t.Fatal(err)
1501 }
1502 res.Body.Close()
1503 delete(res.Header, "Date")
1504 delete(res.Header, "Content-Type")
1505
1506 if want := (Header{"Foo": {"Bar"}}); !reflect.DeepEqual(res.Header, want) {
1507 t.Errorf("Header = %#v; want %#v", res.Header, want)
1508 }
1509 if want := (Header{"Foo": {"Baz", "Baz2"}, "Bar": {"Quux"}}); !reflect.DeepEqual(res.Trailer, want) {
1510 t.Errorf("Trailer = %#v; want %#v", res.Trailer, want)
1511 }
1512 }
1513
1514 func TestBadResponseAfterReadingBody(t *testing.T) {
1515 run(t, testBadResponseAfterReadingBody, []testMode{http1Mode})
1516 }
1517 func testBadResponseAfterReadingBody(t *testing.T, mode testMode) {
1518 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1519 _, err := io.Copy(io.Discard, r.Body)
1520 if err != nil {
1521 t.Fatal(err)
1522 }
1523 c, _, err := w.(Hijacker).Hijack()
1524 if err != nil {
1525 t.Fatal(err)
1526 }
1527 defer c.Close()
1528 fmt.Fprintln(c, "some bogus crap")
1529 }))
1530
1531 closes := 0
1532 res, err := cst.c.Post(cst.ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")})
1533 if err == nil {
1534 res.Body.Close()
1535 t.Fatal("expected an error to be returned from Post")
1536 }
1537 if closes != 1 {
1538 t.Errorf("closes = %d; want 1", closes)
1539 }
1540 }
1541
1542 func TestWriteHeader0(t *testing.T) { run(t, testWriteHeader0) }
1543 func testWriteHeader0(t *testing.T, mode testMode) {
1544 gotpanic := make(chan bool, 1)
1545 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1546 defer close(gotpanic)
1547 defer func() {
1548 if e := recover(); e != nil {
1549 got := fmt.Sprintf("%T, %v", e, e)
1550 want := "string, invalid WriteHeader code 0"
1551 if got != want {
1552 t.Errorf("unexpected panic value:\n got: %v\nwant: %v\n", got, want)
1553 }
1554 gotpanic <- true
1555
1556
1557
1558
1559 w.WriteHeader(503)
1560 }
1561 }()
1562 w.WriteHeader(0)
1563 }))
1564 res, err := cst.c.Get(cst.ts.URL)
1565 if err != nil {
1566 t.Fatal(err)
1567 }
1568 if res.StatusCode != 503 {
1569 t.Errorf("Response: %v %q; want 503", res.StatusCode, res.Status)
1570 }
1571 if !<-gotpanic {
1572 t.Error("expected panic in handler")
1573 }
1574 }
1575
1576
1577
1578 func TestWriteHeaderNoCodeCheck(t *testing.T) {
1579 run(t, func(t *testing.T, mode testMode) {
1580 testWriteHeaderAfterWrite(t, mode, false)
1581 })
1582 }
1583 func TestWriteHeaderNoCodeCheck_h1hijack(t *testing.T) {
1584 testWriteHeaderAfterWrite(t, http1Mode, true)
1585 }
1586 func testWriteHeaderAfterWrite(t *testing.T, mode testMode, hijack bool) {
1587 var errorLog lockedBytesBuffer
1588 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1589 if hijack {
1590 conn, _, _ := w.(Hijacker).Hijack()
1591 defer conn.Close()
1592 conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 6\r\n\r\nfoo"))
1593 w.WriteHeader(0)
1594 conn.Write([]byte("bar"))
1595 return
1596 }
1597 io.WriteString(w, "foo")
1598 w.(Flusher).Flush()
1599 w.WriteHeader(0)
1600 io.WriteString(w, "bar")
1601 }), func(ts *httptest.Server) {
1602 ts.Config.ErrorLog = log.New(&errorLog, "", 0)
1603 })
1604 res, err := cst.c.Get(cst.ts.URL)
1605 if err != nil {
1606 t.Fatal(err)
1607 }
1608 defer res.Body.Close()
1609 body, err := io.ReadAll(res.Body)
1610 if err != nil {
1611 t.Fatal(err)
1612 }
1613 if got, want := string(body), "foobar"; got != want {
1614 t.Errorf("got = %q; want %q", got, want)
1615 }
1616
1617
1618 if mode == http2Mode {
1619
1620
1621 return
1622 }
1623 gotLog := strings.TrimSpace(errorLog.String())
1624 wantLog := "http: superfluous response.WriteHeader call from net/http_test.testWriteHeaderAfterWrite.func1 (clientserver_test.go:"
1625 if hijack {
1626 wantLog = "http: response.WriteHeader on hijacked connection from net/http_test.testWriteHeaderAfterWrite.func1 (clientserver_test.go:"
1627 }
1628 if !strings.HasPrefix(gotLog, wantLog) {
1629 t.Errorf("stderr output = %q; want %q", gotLog, wantLog)
1630 }
1631 }
1632
1633 func TestBidiStreamReverseProxy(t *testing.T) {
1634 run(t, testBidiStreamReverseProxy, []testMode{http2Mode})
1635 }
1636 func testBidiStreamReverseProxy(t *testing.T, mode testMode) {
1637 backend := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1638 if _, err := io.Copy(w, r.Body); err != nil {
1639 log.Printf("bidi backend copy: %v", err)
1640 }
1641 }))
1642
1643 backURL, err := url.Parse(backend.ts.URL)
1644 if err != nil {
1645 t.Fatal(err)
1646 }
1647 rp := httputil.NewSingleHostReverseProxy(backURL)
1648 rp.Transport = backend.tr
1649 proxy := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1650 rp.ServeHTTP(w, r)
1651 }))
1652
1653 bodyRes := make(chan any, 1)
1654 pr, pw := io.Pipe()
1655 req, _ := NewRequest("PUT", proxy.ts.URL, pr)
1656 const size = 4 << 20
1657 go func() {
1658 h := sha1.New()
1659 _, err := io.CopyN(io.MultiWriter(h, pw), rand.Reader, size)
1660 go pw.Close()
1661 if err != nil {
1662 t.Errorf("body copy: %v", err)
1663 bodyRes <- err
1664 } else {
1665 bodyRes <- h
1666 }
1667 }()
1668 res, err := backend.c.Do(req)
1669 if err != nil {
1670 t.Fatal(err)
1671 }
1672 defer res.Body.Close()
1673 hgot := sha1.New()
1674 n, err := io.Copy(hgot, res.Body)
1675 if err != nil {
1676 t.Fatal(err)
1677 }
1678 if n != size {
1679 t.Fatalf("got %d bytes; want %d", n, size)
1680 }
1681 select {
1682 case v := <-bodyRes:
1683 switch v := v.(type) {
1684 default:
1685 t.Fatalf("body copy: %v", err)
1686 case hash.Hash:
1687 if !bytes.Equal(v.Sum(nil), hgot.Sum(nil)) {
1688 t.Errorf("written bytes didn't match received bytes")
1689 }
1690 }
1691 case <-time.After(10 * time.Second):
1692 t.Fatal("timeout")
1693 }
1694
1695 }
1696
1697
1698 func TestH12_WebSocketUpgrade(t *testing.T) {
1699 h12Compare{
1700 Handler: func(w ResponseWriter, r *Request) {
1701 h := w.Header()
1702 h.Set("Foo", "bar")
1703 },
1704 ReqFunc: func(c *Client, url string) (*Response, error) {
1705 req, _ := NewRequest("GET", url, nil)
1706 req.Header.Set("Connection", "Upgrade")
1707 req.Header.Set("Upgrade", "WebSocket")
1708 return c.Do(req)
1709 },
1710 EarlyCheckResponse: func(proto string, res *Response) {
1711 if res.Proto != "HTTP/1.1" {
1712 t.Errorf("%s: expected HTTP/1.1, got %q", proto, res.Proto)
1713 }
1714 res.Proto = "HTTP/IGNORE"
1715 },
1716 }.run(t)
1717 }
1718
1719 func TestIdentityTransferEncoding(t *testing.T) { run(t, testIdentityTransferEncoding) }
1720 func testIdentityTransferEncoding(t *testing.T, mode testMode) {
1721 const body = "body"
1722 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1723 gotBody, _ := io.ReadAll(r.Body)
1724 if got, want := string(gotBody), body; got != want {
1725 t.Errorf("got request body = %q; want %q", got, want)
1726 }
1727 w.Header().Set("Transfer-Encoding", "identity")
1728 w.WriteHeader(StatusOK)
1729 w.(Flusher).Flush()
1730 io.WriteString(w, body)
1731 }))
1732 req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader(body))
1733 res, err := cst.c.Do(req)
1734 if err != nil {
1735 t.Fatal(err)
1736 }
1737 defer res.Body.Close()
1738 gotBody, err := io.ReadAll(res.Body)
1739 if err != nil {
1740 t.Fatal(err)
1741 }
1742 if got, want := string(gotBody), body; got != want {
1743 t.Errorf("got response body = %q; want %q", got, want)
1744 }
1745 }
1746
1747 func TestEarlyHintsRequest(t *testing.T) { run(t, testEarlyHintsRequest) }
1748 func testEarlyHintsRequest(t *testing.T, mode testMode) {
1749 var wg sync.WaitGroup
1750 wg.Add(1)
1751 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1752 h := w.Header()
1753
1754 h.Add("Content-Length", "123")
1755 h.Add("Link", "</style.css>; rel=preload; as=style")
1756 h.Add("Link", "</script.js>; rel=preload; as=script")
1757 w.WriteHeader(StatusEarlyHints)
1758
1759 wg.Wait()
1760
1761 h.Add("Link", "</foo.js>; rel=preload; as=script")
1762 w.WriteHeader(StatusEarlyHints)
1763
1764 w.Write([]byte("Hello"))
1765 }))
1766
1767 checkLinkHeaders := func(t *testing.T, expected, got []string) {
1768 t.Helper()
1769
1770 if len(expected) != len(got) {
1771 t.Errorf("got %d expected %d", len(got), len(expected))
1772 }
1773
1774 for i := range expected {
1775 if expected[i] != got[i] {
1776 t.Errorf("got %q expected %q", got[i], expected[i])
1777 }
1778 }
1779 }
1780
1781 checkExcludedHeaders := func(t *testing.T, header textproto.MIMEHeader) {
1782 t.Helper()
1783
1784 for _, h := range []string{"Content-Length", "Transfer-Encoding"} {
1785 if v, ok := header[h]; ok {
1786 t.Errorf("%s is %q; must not be sent", h, v)
1787 }
1788 }
1789 }
1790
1791 var respCounter uint8
1792 trace := &httptrace.ClientTrace{
1793 Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
1794 switch respCounter {
1795 case 0:
1796 checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script"}, header["Link"])
1797 checkExcludedHeaders(t, header)
1798
1799 wg.Done()
1800 case 1:
1801 checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, header["Link"])
1802 checkExcludedHeaders(t, header)
1803
1804 default:
1805 t.Error("Unexpected 1xx response")
1806 }
1807
1808 respCounter++
1809
1810 return nil
1811 },
1812 }
1813 req, _ := NewRequestWithContext(httptrace.WithClientTrace(context.Background(), trace), "GET", cst.ts.URL, nil)
1814
1815 res, err := cst.c.Do(req)
1816 if err != nil {
1817 t.Fatal(err)
1818 }
1819 defer res.Body.Close()
1820
1821 checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, res.Header["Link"])
1822 if cl := res.Header.Get("Content-Length"); cl != "123" {
1823 t.Errorf("Content-Length is %q; want 123", cl)
1824 }
1825
1826 body, _ := io.ReadAll(res.Body)
1827 if string(body) != "Hello" {
1828 t.Errorf("Read body %q; want Hello", body)
1829 }
1830 }
1831
View as plain text