1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26 package http2
27
28 import (
29 "bufio"
30 "bytes"
31 "context"
32 "crypto/rand"
33 "crypto/tls"
34 "errors"
35 "fmt"
36 "io"
37 "log"
38 "math"
39 "net"
40 "net/http/internal"
41 "net/http/internal/httpcommon"
42 "net/textproto"
43 "net/url"
44 "os"
45 "reflect"
46 "runtime"
47 "slices"
48 "strconv"
49 "strings"
50 "sync"
51 "time"
52
53 "golang.org/x/net/http/httpguts"
54 "golang.org/x/net/http2/hpack"
55 )
56
const (
	prefaceTimeout        = 10 * time.Second // how long readPreface waits for the client connection preface
	firstSettingsTimeout  = 2 * time.Second  // how long serve waits for the client's first frame, which must be SETTINGS
	handlerChunkWriteSize = 4 << 10          // size of the bufio.Writer each pooled responseWriterState writes through
	defaultMaxStreams     = 250              // NOTE(review): not referenced in this chunk; presumably a default MAX_CONCURRENT_STREAMS — confirm

	// maxQueuedControlFrames bounds the number of control frames sitting in
	// the write scheduler. serve closes the connection once the count goes
	// above it, so a peer that keeps provoking control frames while our
	// writes back up cannot grow the queue without limit.
	maxQueuedControlFrames = 10000
)

var (
	// errClientDisconnected is reported to writers once the serve loop has
	// exited, and used as the close error for streams still open when the
	// connection ends.
	errClientDisconnected = errors.New("client disconnected")
	// NOTE(review): errClosedBody is not referenced in this chunk.
	errClosedBody = errors.New("body closed by handler")
	// errHandlerComplete closes a half-closed-remote stream after the frame
	// that ends the response stream has been written.
	errHandlerComplete = errors.New("http2: request body closed due to handler exiting")
	// errStreamClosed is returned to a handler whose stream closed while it
	// was waiting for a DATA write result.
	errStreamClosed = errors.New("http2: stream closed")
)

// responseWriterStatePool recycles responseWriterState values. Each new
// value comes with a bufio.Writer of handlerChunkWriteSize bytes that writes
// through chunkWriter bound to that same state.
var responseWriterStatePool = sync.Pool{
	New: func() any {
		rws := &responseWriterState{}
		rws.bw = bufio.NewWriterSize(chunkWriter{rws}, handlerChunkWriteSize)
		return rws
	},
}

// Test hooks. When testHookOnPanic is set, notePanic recovers panics from
// the serve goroutine and passes them to it; returning rePanic=true
// re-raises the panic. testHookOnPanicMu, if set, serializes those calls.
// NOTE(review): testHookOnConn is not referenced in this chunk.
var (
	testHookOnConn    func()
	testHookOnPanicMu *sync.Mutex
	testHookOnPanic   func(sc *serverConn, panicVal any) (rePanic bool)
)
90
91
// Server holds HTTP/2 state shared by all connections it serves (each
// serverConn keeps a pointer back to it in srv).
type Server struct {
	// mu guards activeConns.
	mu sync.Mutex
	// activeConns is the set of live connections. registerConn and
	// unregisterConn maintain it, and startGracefulShutdown walks it.
	activeConns map[*serverConn]struct{}

	// errChanPool recycles the buffered (capacity 1) error channels that
	// handlers wait on for frame-write results. Configure installs its New
	// func.
	errChanPool sync.Pool
}
100
101 func (s *Server) registerConn(sc *serverConn) {
102 if s == nil {
103 return
104 }
105 s.mu.Lock()
106 s.activeConns[sc] = struct{}{}
107 s.mu.Unlock()
108 }
109
110 func (s *Server) unregisterConn(sc *serverConn) {
111 if s == nil {
112 return
113 }
114 s.mu.Lock()
115 delete(s.activeConns, sc)
116 s.mu.Unlock()
117 }
118
119 func (s *Server) startGracefulShutdown() {
120 if s == nil {
121 return
122 }
123 s.mu.Lock()
124 for sc := range s.activeConns {
125 sc.startGracefulShutdown()
126 }
127 s.mu.Unlock()
128 }
129
130
131
132 var errChanPool = sync.Pool{
133 New: func() any { return make(chan error, 1) },
134 }
135
136 func (s *Server) getErrChan() chan error {
137 if s == nil {
138 return errChanPool.Get().(chan error)
139 }
140 return s.errChanPool.Get().(chan error)
141 }
142
143 func (s *Server) putErrChan(ch chan error) {
144 if s == nil {
145 errChanPool.Put(ch)
146 return
147 }
148 s.errChanPool.Put(ch)
149 }
150
151 func (s *Server) Configure(conf ServerConfig, tcfg *tls.Config) error {
152 s.activeConns = make(map[*serverConn]struct{})
153 s.errChanPool = sync.Pool{New: func() any { return make(chan error, 1) }}
154
155 if tcfg.CipherSuites != nil && tcfg.MinVersion < tls.VersionTLS13 {
156
157
158
159 haveRequired := false
160 for _, cs := range tcfg.CipherSuites {
161 switch cs {
162 case tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
163
164
165 tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256:
166 haveRequired = true
167 }
168 }
169 if !haveRequired {
170 return fmt.Errorf("http2: TLSConfig.CipherSuites is missing an HTTP/2-required AES_128_GCM_SHA256 cipher (need at least one of TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 or TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256)")
171 }
172 }
173
174
175
176
177
178
179
180
181 return nil
182 }
183
184 func (s *Server) GracefulShutdown() {
185 s.startGracefulShutdown()
186 }
187
188
// ServeConnOpts are options for Server.ServeConn.
type ServeConnOpts struct {
	// Context is the parent of the connection's base context. When nil,
	// context.Background is used (see the context method).
	Context context.Context

	// BaseConfig supplies the server-level settings (timeouts, header
	// limits, ConnState hook, error log) used for the connection.
	BaseConfig ServerConfig

	// Handler is the handler stored on the connection for serving
	// requests.
	Handler Handler

	// Settings, if non-nil, is a SETTINGS frame payload applied before
	// serving (presumably from an h2c upgrade's HTTP2-Settings header —
	// confirm). serveConn clears it once consumed.
	Settings []byte

	// UpgradeRequest, if non-nil, is passed to upgradeRequest before the
	// serve loop starts. serveConn clears it once consumed.
	UpgradeRequest *ServerRequest

	// SawClientPreface reports that the caller has already consumed the
	// client connection preface, so readPreface skips reading it.
	SawClientPreface bool
}
213
214 func (o *ServeConnOpts) context() context.Context {
215 if o != nil && o.Context != nil {
216 return o.Context
217 }
218 return context.Background()
219 }
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235 func (s *Server) ServeConn(c net.Conn, opts *ServeConnOpts) {
236 if opts == nil {
237 opts = &ServeConnOpts{}
238 }
239
240 var newf func(*serverConn)
241 if inTests {
242
243 newf, _ = opts.Context.Value(NewConnContextKey).(func(*serverConn))
244 }
245
246 s.serveConn(c, opts, newf)
247 }
248
// contextKey is a string type for context keys.
// NOTE(review): the keys below are created with new("..."), which yields a
// *string rather than a *contextKey, so contextKey appears unused here —
// confirm.
type contextKey string

// Context keys read only when inTests is set. NewConnContextKey maps to a
// func(*serverConn) run on each new connection. ConnectionStateContextKey
// maps to a func() tls.ConnectionState that overrides the connection's TLS
// state. Each key is a distinct pointer, so it cannot collide with keys
// from other packages.
var (
	NewConnContextKey         = new("NewConnContextKey")
	ConnectionStateContextKey = new("ConnectionStateContextKey")
)
255
256 func (s *Server) serveConn(c net.Conn, opts *ServeConnOpts, newf func(*serverConn)) {
257 baseCtx, cancel := serverConnBaseContext(c, opts)
258 defer cancel()
259
260 conf := configFromServer(opts.BaseConfig)
261 sc := &serverConn{
262 srv: s,
263 hs: opts.BaseConfig,
264 conn: c,
265 baseCtx: baseCtx,
266 remoteAddrStr: c.RemoteAddr().String(),
267 bw: newBufferedWriter(c, conf.WriteByteTimeout),
268 handler: opts.Handler,
269 streams: make(map[uint32]*stream),
270 readFrameCh: make(chan readFrameResult),
271 wantWriteFrameCh: make(chan FrameWriteRequest, 8),
272 serveMsgCh: make(chan any, 8),
273 wroteFrameCh: make(chan frameWriteResult, 1),
274 bodyReadCh: make(chan bodyReadMsg),
275 doneServing: make(chan struct{}),
276 clientMaxStreams: math.MaxUint32,
277 advMaxStreams: uint32(conf.MaxConcurrentStreams),
278 initialStreamSendWindowSize: initialWindowSize,
279 initialStreamRecvWindowSize: int32(conf.MaxReceiveBufferPerStream),
280 maxFrameSize: initialMaxFrameSize,
281 pingTimeout: conf.PingTimeout,
282 countErrorFunc: conf.CountError,
283 serveG: newGoroutineLock(),
284 pushEnabled: true,
285 sawClientPreface: opts.SawClientPreface,
286 }
287 if newf != nil {
288 newf(sc)
289 }
290
291 s.registerConn(sc)
292 defer s.unregisterConn(sc)
293
294
295
296
297
298
299 if sc.hs.WriteTimeout() > 0 {
300 sc.conn.SetWriteDeadline(time.Time{})
301 }
302
303 switch {
304 case sc.hs.DisableClientPriority():
305 sc.writeSched = newRoundRobinWriteScheduler()
306 default:
307 sc.writeSched = newPriorityWriteSchedulerRFC9218()
308 }
309
310
311
312
313 sc.flow.add(initialWindowSize)
314 sc.inflow.init(initialWindowSize)
315 sc.hpackEncoder = hpack.NewEncoder(&sc.headerWriteBuf)
316 sc.hpackEncoder.SetMaxDynamicTableSizeLimit(uint32(conf.MaxEncoderHeaderTableSize))
317
318 fr := NewFramer(sc.bw, c)
319 if conf.CountError != nil {
320 fr.countError = conf.CountError
321 }
322 fr.ReadMetaHeaders = hpack.NewDecoder(uint32(conf.MaxDecoderHeaderTableSize), nil)
323 fr.MaxHeaderListSize = sc.maxHeaderListSize()
324 fr.SetMaxReadFrameSize(uint32(conf.MaxReadFrameSize))
325 sc.framer = fr
326
327 if tc, ok := c.(connectionStater); ok {
328 sc.tlsState = new(tls.ConnectionState)
329 *sc.tlsState = tc.ConnectionState()
330
331
332 if inTests {
333 f, ok := opts.Context.Value(ConnectionStateContextKey).(func() tls.ConnectionState)
334 if ok {
335 *sc.tlsState = f()
336 }
337 }
338
339
340
341
342
343
344
345
346
347
348
349 if sc.tlsState.Version < tls.VersionTLS12 {
350 sc.rejectConn(ErrCodeInadequateSecurity, "TLS version too low")
351 return
352 }
353
354 if sc.tlsState.ServerName == "" {
355
356
357
358
359
360
361
362
363
364 }
365
366 if !conf.PermitProhibitedCipherSuites && isBadCipher(sc.tlsState.CipherSuite) {
367
368
369
370
371
372
373
374
375
376
377 sc.rejectConn(ErrCodeInadequateSecurity, fmt.Sprintf("Prohibited TLS 1.2 Cipher Suite: %x", sc.tlsState.CipherSuite))
378 return
379 }
380 }
381
382 if opts.Settings != nil {
383 fr := &SettingsFrame{
384 FrameHeader: FrameHeader{valid: true},
385 p: opts.Settings,
386 }
387 if err := fr.ForeachSetting(sc.processSetting); err != nil {
388 sc.rejectConn(ErrCodeProtocol, "invalid settings")
389 return
390 }
391 opts.Settings = nil
392 }
393
394 if opts.UpgradeRequest != nil {
395 sc.upgradeRequest(opts.UpgradeRequest)
396 opts.UpgradeRequest = nil
397 }
398
399 sc.serve(conf)
400 }
401
402 func serverConnBaseContext(c net.Conn, opts *ServeConnOpts) (ctx context.Context, cancel func()) {
403 return context.WithCancel(opts.context())
404 }
405
406 func (sc *serverConn) rejectConn(err ErrCode, debug string) {
407 sc.vlogf("http2: server rejecting conn: %v, %s", err, debug)
408
409 sc.framer.WriteGoAway(0, err, []byte(debug))
410 sc.bw.Flush()
411 sc.conn.Close()
412 }
413
// serverConn is the state of one HTTP/2 server connection.
type serverConn struct {
	// Set up in serveConn. The channels connect the serve goroutine to the
	// reader, async-writer and handler goroutines.
	srv              *Server
	hs               ServerConfig
	conn             net.Conn
	bw               *bufferedWriter // buffered writer in front of conn; the Framer writes into it
	handler          Handler
	baseCtx          context.Context
	framer           *Framer
	doneServing      chan struct{}          // closed when serve returns
	readFrameCh      chan readFrameResult   // frames from the readFrames goroutine
	wantWriteFrameCh chan FrameWriteRequest // write requests from handler goroutines
	wroteFrameCh     chan frameWriteResult  // results from writeFrameAsync
	bodyReadCh       chan bodyReadMsg       // request-body read notifications, handled by noteBodyRead
	serveMsgCh       chan any               // timer events, shutdown requests and funcs to run on serve
	flow             outflow                // connection-level send flow control
	inflow           inflow                 // connection-level receive flow control
	tlsState         *tls.ConnectionState   // nil unless conn implements connectionStater
	remoteAddrStr    string
	writeSched       WriteScheduler
	countErrorFunc   func(errType string)

	// The fields below are used from the serve goroutine; methods that
	// touch them call serveG.check().
	serveG                      goroutineLock // asserts calls run on the serve goroutine
	pushEnabled                 bool
	sawClientPreface            bool // preface already consumed by the caller
	sawFirstSettings            bool // the first frame was SETTINGS, as processFrame requires
	needToSendSettingsAck       bool // a SETTINGS ACK is pending in scheduleFrameWrite
	unackedSettings             int  // SETTINGS frames we sent that are not yet acknowledged
	queuedControlFrames         int  // control frames in writeSched; capped by maxQueuedControlFrames
	clientMaxStreams            uint32 // starts at math.MaxUint32 (no limit)
	advMaxStreams               uint32 // our advertised SETTINGS_MAX_CONCURRENT_STREAMS
	curClientStreams            uint32
	curPushedStreams            uint32
	curHandlers                 uint32
	maxClientStreamID           uint32 // highest client stream ID seen; lower untracked odd IDs are closed
	maxPushPromiseID            uint32 // highest pushed stream ID; lower untracked even IDs are closed
	streams                     map[uint32]*stream
	unstartedHandlers           []unstartedHandler
	initialStreamSendWindowSize int32
	initialStreamRecvWindowSize int32 // advertised as SETTINGS_INITIAL_WINDOW_SIZE
	maxFrameSize                int32
	peerMaxHeaderListSize       uint32
	canonHeader                 map[string]string // per-connection cache used by canonicalHeader
	canonHeaderKeysSize         int               // estimated bytes held in canonHeader
	writingFrame                bool              // a frame write started and wroteFrame has not run yet
	writingFrameAsync           bool              // that write runs in a writeFrameAsync goroutine
	needsFrameFlush             bool              // frames were written since the last flush
	inGoAway                    bool              // goAway has been called
	inFrameScheduleLoop         bool              // guards scheduleFrameWrite against re-entry
	needToSendGoAway            bool              // a GOAWAY is pending; scheduleFrameWrite sends it first
	pingSent                    bool              // a keepalive PING is awaiting its ACK
	sentPingData                [8]byte           // payload of that PING
	goAwayCode                  ErrCode
	shutdownTimer               *time.Timer // fires shutdownTimerMsg after goAwayTimeout
	idleTimer                   *time.Timer
	readIdleTimeout             time.Duration
	pingTimeout                 time.Duration
	readIdleTimer               *time.Timer

	// HPACK encoder state for outgoing header blocks; exposed through
	// HeaderEncoder.
	headerWriteBuf bytes.Buffer
	hpackEncoder   *hpack.Encoder

	// shutdownOnce ensures startGracefulShutdown sends its message once.
	shutdownOnce sync.Once

	// NOTE(review): not referenced in this chunk.
	hasIntermediary bool
	priorityAware   bool
}
485
486 func (sc *serverConn) writeSchedIgnoresRFC7540() bool {
487 switch sc.writeSched.(type) {
488 case *priorityWriteSchedulerRFC9218:
489 return true
490 case *roundRobinWriteScheduler:
491 return true
492 default:
493 return false
494 }
495 }
496
497 const DefaultMaxHeaderBytes = 1 << 20
498
499 func (sc *serverConn) maxHeaderListSize() uint32 {
500 n := sc.hs.MaxHeaderBytes()
501 if n <= 0 {
502 n = DefaultMaxHeaderBytes
503 }
504 return uint32(adjustHTTP1MaxHeaderSize(int64(n)))
505 }
506
507 func (sc *serverConn) curOpenStreams() uint32 {
508 sc.serveG.check()
509 return sc.curClientStreams + sc.curPushedStreams
510 }
511
512
513
514
515
516
517
518
// stream is the server's state for a single HTTP/2 stream.
type stream struct {
	// Set when the stream is created:
	sc        *serverConn
	id        uint32
	body      *pipe       // request body; may be nil (closeStream checks it)
	cw        closeWaiter // writeDataFromHandler returns errStreamClosed when this fires
	ctx       context.Context
	cancelCtx func() // cancels ctx; called e.g. when the peer sends RST_STREAM

	// Used from the serve goroutine:
	bodyBytes        int64   // NOTE(review): presumably request-body bytes received so far — confirm
	declBodyBytes    int64   // NOTE(review): presumably the declared Content-Length — confirm
	flow             outflow // stream-level send window, grown by WINDOW_UPDATE
	inflow           inflow  // stream-level receive window
	state            streamState
	resetQueued      bool // resetStream queued an RST_STREAM for this stream
	gotTrailerHeader bool
	wroteHeaders     bool        // response HEADERS were queued; later 100-continue writes are dropped
	readDeadline     *time.Timer // stopped when the stream closes
	writeDeadline    *time.Timer // stopped when the stream closes
	closeErr         error

	trailer    Header // NOTE(review): presumably response trailers — confirm
	reqTrailer Header // NOTE(review): presumably request trailers — confirm
}
544
545 func (sc *serverConn) Framer() *Framer { return sc.framer }
546 func (sc *serverConn) CloseConn() error { return sc.conn.Close() }
547 func (sc *serverConn) Flush() error { return sc.bw.Flush() }
548 func (sc *serverConn) HeaderEncoder() (*hpack.Encoder, *bytes.Buffer) {
549 return sc.hpackEncoder, &sc.headerWriteBuf
550 }
551
552 func (sc *serverConn) state(streamID uint32) (streamState, *stream) {
553 sc.serveG.check()
554
555 if st, ok := sc.streams[streamID]; ok {
556 return st.state, st
557 }
558
559
560
561
562
563
564 if streamID%2 == 1 {
565 if streamID <= sc.maxClientStreamID {
566 return stateClosed, nil
567 }
568 } else {
569 if streamID <= sc.maxPushPromiseID {
570 return stateClosed, nil
571 }
572 }
573 return stateIdle, nil
574 }
575
576
577
578
579 func (sc *serverConn) setConnState(state ConnState) {
580 sc.hs.ConnState(sc.conn, state)
581 }
582
583 func (sc *serverConn) vlogf(format string, args ...any) {
584 if VerboseLogs {
585 sc.logf(format, args...)
586 }
587 }
588
589 func (sc *serverConn) logf(format string, args ...any) {
590 if lg := sc.hs.ErrorLog(); lg != nil {
591 lg.Printf(format, args...)
592 } else {
593 log.Printf(format, args...)
594 }
595 }
596
597
598
599
600
// errno returns the numeric value of an error whose dynamic type has
// uintptr kind (for example syscall.Errno). It returns 0 for any other
// error, including nil.
func errno(v error) uintptr {
	rv := reflect.ValueOf(v)
	if rv.Kind() != reflect.Uintptr {
		return 0
	}
	return uintptr(rv.Uint())
}

// isClosedConnError reports whether err means the connection was closed.
// That is the case when err wraps net.ErrClosed, or, on Windows, when it is
// a wsarecv read failure with WSAECONNABORTED or WSAECONNRESET.
func isClosedConnError(err error) bool {
	switch {
	case err == nil:
		return false
	case errors.Is(err, net.ErrClosed):
		return true
	case runtime.GOOS != "windows":
		return false
	}
	oe, ok := err.(*net.OpError)
	if !ok || oe.Op != "read" {
		return false
	}
	se, ok := oe.Err.(*os.SyscallError)
	if !ok || se.Syscall != "wsarecv" {
		return false
	}
	const (
		wsaeConnAborted = 10053
		wsaeConnReset   = 10054
	)
	switch errno(se.Err) {
	case wsaeConnReset, wsaeConnAborted:
		return true
	}
	return false
}
636
637 func (sc *serverConn) condlogf(err error, format string, args ...any) {
638 if err == nil {
639 return
640 }
641 if err == io.EOF || err == io.ErrUnexpectedEOF || isClosedConnError(err) || err == errPrefaceTimeout {
642
643 sc.vlogf(format, args...)
644 } else {
645 sc.logf(format, args...)
646 }
647 }
648
649
650
651
652
653
654 const maxCachedCanonicalHeadersKeysSize = 2048
655
656 func (sc *serverConn) canonicalHeader(v string) string {
657 sc.serveG.check()
658 cv, ok := httpcommon.CachedCanonicalHeader(v)
659 if ok {
660 return cv
661 }
662 cv, ok = sc.canonHeader[v]
663 if ok {
664 return cv
665 }
666 if sc.canonHeader == nil {
667 sc.canonHeader = make(map[string]string)
668 }
669 cv = textproto.CanonicalMIMEHeaderKey(v)
670 size := 100 + len(v)*2
671 if sc.canonHeaderKeysSize+size <= maxCachedCanonicalHeadersKeysSize {
672 sc.canonHeader[v] = cv
673 sc.canonHeaderKeysSize += size
674 }
675 return cv
676 }
677
// readFrameResult is what readFrames sends to the serve loop for each
// ReadFrame call.
type readFrameResult struct {
	f   Frame
	err error

	// readMore lets readFrames continue. The reader goroutine does not call
	// ReadFrame again until readMore is invoked, so the serve loop must
	// call it once it is done with f (presumably because the Framer may
	// reuse f's storage — confirm).
	readMore func()
}
687
688
689
690
691
692 func (sc *serverConn) readFrames() {
693 gate := make(chan struct{})
694 gateDone := func() { gate <- struct{}{} }
695 for {
696 f, err := sc.framer.ReadFrame()
697 select {
698 case sc.readFrameCh <- readFrameResult{f, err, gateDone}:
699 case <-sc.doneServing:
700 return
701 }
702 select {
703 case <-gate:
704 case <-sc.doneServing:
705 return
706 }
707 if terminalReadFrameError(err) {
708 return
709 }
710 }
711 }
712
713
// frameWriteResult is the outcome of one frame write. writeFrameAsync
// sends it to the serve loop on wroteFrameCh; synchronous writes pass it
// straight to wroteFrame.
type frameWriteResult struct {
	_   incomparable
	wr  FrameWriteRequest // the request that was written
	err error             // write error, nil on success
}
719
720
721
722
723
724 func (sc *serverConn) writeFrameAsync(wr FrameWriteRequest, wd *writeData) {
725 var err error
726 if wd == nil {
727 err = wr.write.writeFrame(sc)
728 } else {
729 err = sc.framer.endWrite()
730 }
731 sc.wroteFrameCh <- frameWriteResult{wr: wr, err: err}
732 }
733
734 func (sc *serverConn) closeAllStreamsOnConnClose() {
735 sc.serveG.check()
736 for _, st := range sc.streams {
737 sc.closeStream(st, errClientDisconnected)
738 }
739 }
740
741 func (sc *serverConn) stopShutdownTimer() {
742 sc.serveG.check()
743 if t := sc.shutdownTimer; t != nil {
744 t.Stop()
745 }
746 }
747
// notePanic is deferred by serve and gives tests a hook on panics in the
// serve goroutine. When testHookOnPanic is set, it recovers the panic,
// reports it to the hook, and re-panics if the hook returns true. recover
// only works because notePanic itself is the deferred function; do not move
// the recover call into a helper.
func (sc *serverConn) notePanic() {
	// Serialize hook invocations when tests supply a mutex.
	if testHookOnPanicMu != nil {
		testHookOnPanicMu.Lock()
		defer testHookOnPanicMu.Unlock()
	}
	if testHookOnPanic != nil {
		if e := recover(); e != nil {
			if testHookOnPanic(sc, e) {
				panic(e)
			}
		}
	}
}
762
// serve runs the connection's main loop on the serve goroutine.
//
// It first sends the initial SETTINGS (plus a connection WINDOW_UPDATE when
// the configured receive buffer exceeds the default) and reads the client
// preface. It then starts the timers and the readFrames goroutine, and
// multiplexes the following until the connection ends:
//   - frame reads
//   - write requests and write results
//   - body-read notifications
//   - control messages
//
// The deferred calls run in reverse order on return:
//  1. doneServing is closed, releasing goroutines that select on it.
//  2. The shutdown timer is stopped.
//  3. All streams are closed.
//  4. The conn is closed.
//  5. notePanic runs.
func (sc *serverConn) serve(conf Config) {
	sc.serveG.check()
	defer sc.notePanic()
	defer sc.conn.Close()
	defer sc.closeAllStreamsOnConnClose()
	defer sc.stopShutdownTimer()
	defer close(sc.doneServing) // releases sendServeMsg / writeFrameFromHandler / readFrames

	if VerboseLogs {
		sc.vlogf("http2: server connection from %v on %p", sc.conn.RemoteAddr(), sc.hs)
	}

	// Advertise our limits in the initial SETTINGS frame.
	settings := writeSettings{
		{SettingMaxFrameSize, uint32(conf.MaxReadFrameSize)},
		{SettingMaxConcurrentStreams, sc.advMaxStreams},
		{SettingMaxHeaderListSize, sc.maxHeaderListSize()},
		{SettingHeaderTableSize, uint32(conf.MaxDecoderHeaderTableSize)},
		{SettingInitialWindowSize, uint32(sc.initialStreamRecvWindowSize)},
	}
	if !disableExtendedConnectProtocol {
		settings = append(settings, Setting{SettingEnableConnectProtocol, 1})
	}
	if sc.writeSchedIgnoresRFC7540() {
		settings = append(settings, Setting{SettingNoRFC7540Priorities, 1})
	}
	sc.writeFrame(FrameWriteRequest{
		write: settings,
	})
	sc.unackedSettings++

	// The connection receive window starts at initialWindowSize (set in
	// serveConn); grant the difference if a larger buffer is configured.
	if diff := conf.MaxReceiveBufferPerConnection - initialWindowSize; diff > 0 {
		sc.sendWindowUpdate(nil, int(diff))
	}

	if err := sc.readPreface(); err != nil {
		sc.condlogf(err, "http2: server: error reading preface from client %v: %v", sc.conn.RemoteAddr(), err)
		return
	}
	// The preface has arrived, so leave the "new" state. The transition
	// goes through Active to Idle rather than straight to Idle.
	// NOTE(review): presumably ConnState consumers expect new -> active
	// first — confirm.
	sc.setConnState(ConnStateActive)
	sc.setConnState(ConnStateIdle)

	if idle := sc.hs.IdleTimeout(); idle > 0 {
		sc.idleTimer = time.AfterFunc(idle, sc.onIdleTimer)
		defer sc.idleTimer.Stop()
	}

	// Keepalive: when the read-idle timer fires, handlePingTimer decides
	// whether to send a PING or give up on the peer.
	if conf.SendPingTimeout > 0 {
		sc.readIdleTimeout = conf.SendPingTimeout
		sc.readIdleTimer = time.AfterFunc(conf.SendPingTimeout, sc.onReadIdleTimer)
		defer sc.readIdleTimer.Stop()
	}

	// readFrames stops after doneServing and the conn are closed (deferred above).
	go sc.readFrames()

	// The client's first frame must arrive within firstSettingsTimeout.
	// The timer is stopped once the first frame has been processed.
	settingsTimer := time.AfterFunc(firstSettingsTimeout, sc.onSettingsTimer)
	defer settingsTimer.Stop()

	lastFrameTime := time.Now()
	loopNum := 0
	for {
		loopNum++
		select {
		case wr := <-sc.wantWriteFrameCh:
			// A StreamError from a handler is a request to reset its stream.
			if se, ok := wr.write.(StreamError); ok {
				sc.resetStream(se)
				break // leaves the select, not the for loop
			}
			sc.writeFrame(wr)
		case res := <-sc.wroteFrameCh:
			sc.wroteFrame(res)
		case res := <-sc.readFrameCh:
			lastFrameTime = time.Now()
			// If an async write has already finished, apply its result
			// before processing the newly read frame.
			if sc.writingFrameAsync {
				select {
				case wroteRes := <-sc.wroteFrameCh:
					sc.wroteFrame(wroteRes)
				default:
				}
			}
			if !sc.processFrameFromReader(res) {
				return
			}
			// Allow readFrames to read the next frame.
			res.readMore()
			if settingsTimer != nil {
				settingsTimer.Stop()
				settingsTimer = nil
			}
		case m := <-sc.bodyReadCh:
			sc.noteBodyRead(m.st, m.n)
		case msg := <-sc.serveMsgCh:
			switch v := msg.(type) {
			case func(int):
				v(loopNum) // receives the loop iteration count (presumably a test hook)
			case *serverMessage:
				switch v {
				case settingsTimerMsg:
					sc.logf("timeout waiting for SETTINGS frames from %v", sc.conn.RemoteAddr())
					return
				case idleTimerMsg:
					sc.vlogf("connection is idle")
					sc.goAway(ErrCodeNo)
				case readIdleTimerMsg:
					sc.handlePingTimer(lastFrameTime)
				case shutdownTimerMsg:
					sc.vlogf("GOAWAY close timer fired; closing conn from %v", sc.conn.RemoteAddr())
					return
				case gracefulShutdownMsg:
					sc.startGracefulShutdownInternal()
				case handlerDoneMsg:
					sc.handlerDone()
				default:
					panic("unknown timer")
				}
			case *startPushRequest:
				sc.startPush(v)
			case func(*serverConn):
				v(sc)
			default:
				panic(fmt.Sprintf("unexpected type %T", v))
			}
		}

		// Too many queued control frames means our writes are not keeping
		// up with what the peer provokes; drop the connection rather than
		// let the queue grow without bound.
		if sc.queuedControlFrames > maxQueuedControlFrames {
			sc.vlogf("http2: too many control frames in send queue, closing connection")
			return
		}

		// Once a GOAWAY has been written, arm the close timer. For a
		// graceful (NO_ERROR) GOAWAY, wait until every open stream has
		// finished.
		sentGoAway := sc.inGoAway && !sc.needToSendGoAway && !sc.writingFrame
		gracefulShutdownComplete := sc.goAwayCode == ErrCodeNo && sc.curOpenStreams() == 0
		if sentGoAway && sc.shutdownTimer == nil && (sc.goAwayCode != ErrCodeNo || gracefulShutdownComplete) {
			sc.shutDownIn(goAwayTimeout)
		}
	}
}
911
912 func (sc *serverConn) handlePingTimer(lastFrameReadTime time.Time) {
913 if sc.pingSent {
914 sc.logf("timeout waiting for PING response")
915 if f := sc.countErrorFunc; f != nil {
916 f("conn_close_lost_ping")
917 }
918 sc.conn.Close()
919 return
920 }
921
922 pingAt := lastFrameReadTime.Add(sc.readIdleTimeout)
923 now := time.Now()
924 if pingAt.After(now) {
925
926
927 sc.readIdleTimer.Reset(pingAt.Sub(now))
928 return
929 }
930
931 sc.pingSent = true
932
933
934 _, _ = rand.Read(sc.sentPingData[:])
935 sc.writeFrame(FrameWriteRequest{
936 write: &writePing{data: sc.sentPingData},
937 })
938 sc.readIdleTimer.Reset(sc.pingTimeout)
939 }
940
// serverMessage values are sent on serveMsgCh (via sendServeMsg) to deliver
// timer and control events to the serve loop.
type serverMessage int

// Event markers. serve compares them by pointer identity, so each must stay
// a distinct allocation.
var (
	settingsTimerMsg    = new(serverMessage) // first SETTINGS did not arrive in time
	idleTimerMsg        = new(serverMessage) // idle timeout elapsed
	readIdleTimerMsg    = new(serverMessage) // read-idle (keepalive) timer fired
	shutdownTimerMsg    = new(serverMessage) // post-GOAWAY close timer fired
	gracefulShutdownMsg = new(serverMessage) // graceful shutdown requested
	handlerDoneMsg      = new(serverMessage) // a handler finished
)
952
953 func (sc *serverConn) onSettingsTimer() { sc.sendServeMsg(settingsTimerMsg) }
954 func (sc *serverConn) onIdleTimer() { sc.sendServeMsg(idleTimerMsg) }
955 func (sc *serverConn) onReadIdleTimer() { sc.sendServeMsg(readIdleTimerMsg) }
956 func (sc *serverConn) onShutdownTimer() { sc.sendServeMsg(shutdownTimerMsg) }
957
958 func (sc *serverConn) sendServeMsg(msg any) {
959 sc.serveG.checkNotOn()
960 select {
961 case sc.serveMsgCh <- msg:
962 case <-sc.doneServing:
963 }
964 }
965
966 var errPrefaceTimeout = errors.New("timeout waiting for client preface")
967
968
969
970
971 func (sc *serverConn) readPreface() error {
972 if sc.sawClientPreface {
973 return nil
974 }
975 errc := make(chan error, 1)
976 go func() {
977
978 buf := make([]byte, len(ClientPreface))
979 if _, err := io.ReadFull(sc.conn, buf); err != nil {
980 errc <- err
981 } else if !bytes.Equal(buf, clientPreface) {
982 errc <- fmt.Errorf("bogus greeting %q", buf)
983 } else {
984 errc <- nil
985 }
986 }()
987 timer := time.NewTimer(prefaceTimeout)
988 defer timer.Stop()
989 select {
990 case <-timer.C:
991 return errPrefaceTimeout
992 case err := <-errc:
993 if err == nil {
994 if VerboseLogs {
995 sc.vlogf("http2: server: client %v said hello", sc.conn.RemoteAddr())
996 }
997 }
998 return err
999 }
1000 }
1001
1002 var writeDataPool = sync.Pool{
1003 New: func() any { return new(writeData) },
1004 }
1005
1006
1007
1008 func (sc *serverConn) writeDataFromHandler(stream *stream, data []byte, endStream bool) error {
1009 ch := sc.srv.getErrChan()
1010 writeArg := writeDataPool.Get().(*writeData)
1011 *writeArg = writeData{stream.id, data, endStream}
1012 err := sc.writeFrameFromHandler(FrameWriteRequest{
1013 write: writeArg,
1014 stream: stream,
1015 done: ch,
1016 })
1017 if err != nil {
1018 return err
1019 }
1020 var frameWriteDone bool
1021 select {
1022 case err = <-ch:
1023 frameWriteDone = true
1024 case <-sc.doneServing:
1025 return errClientDisconnected
1026 case <-stream.cw:
1027
1028
1029
1030
1031
1032
1033
1034 select {
1035 case err = <-ch:
1036 frameWriteDone = true
1037 default:
1038 return errStreamClosed
1039 }
1040 }
1041 sc.srv.putErrChan(ch)
1042 if frameWriteDone {
1043 writeDataPool.Put(writeArg)
1044 }
1045 return err
1046 }
1047
1048
1049
1050
1051
1052
1053
1054
1055 func (sc *serverConn) writeFrameFromHandler(wr FrameWriteRequest) error {
1056 sc.serveG.checkNotOn()
1057 select {
1058 case sc.wantWriteFrameCh <- wr:
1059 return nil
1060 case <-sc.doneServing:
1061
1062
1063 return errClientDisconnected
1064 }
1065 }
1066
1067
1068
1069
1070
1071
1072
1073
1074
// writeFrame queues wr with the write scheduler and then runs
// scheduleFrameWrite. It must be called on the serve goroutine.
//
// Two kinds of write are dropped:
//   - a frame for a stream already in the closed state, unless it is an
//     RST_STREAM (StreamError);
//   - a 100-continue after the response headers were written.
//
// A dropped write is never pushed to the scheduler, so wr.done is not
// signaled for it.
func (sc *serverConn) writeFrame(wr FrameWriteRequest) {
	sc.serveG.check()

	// When true, wr is not queued and wr.done is not signaled.
	var ignoreWrite bool

	// Frames must not be sent on closed streams. A stream can be closed
	// while its handler is still writing; such writes are silently dropped.
	// RST_STREAM is exempt, which allows rejecting streams we never tracked.
	if wr.StreamID() != 0 {
		_, isReset := wr.write.(StreamError)
		if state, _ := sc.state(wr.StreamID()); state == stateClosed && !isReset {
			ignoreWrite = true
		}
	}

	// Record that response headers were queued, and drop a later
	// 100-continue for the same stream.
	switch wr.write.(type) {
	case *writeResHeaders:
		wr.stream.wroteHeaders = true
	case write100ContinueHeadersFrame:
		if wr.stream.wroteHeaders {
			// A 100-continue is never queued with a done channel, so
			// there is nobody to notify about the drop.
			if wr.done != nil {
				panic("wr.done != nil for write100ContinueHeadersFrame")
			}
			ignoreWrite = true
		}
	}

	if !ignoreWrite {
		if wr.isControl() {
			sc.queuedControlFrames++
			// Defensive: close the connection if the counter ever wraps.
			if sc.queuedControlFrames < 0 {
				sc.conn.Close()
			}
		}
		sc.writeSched.Push(wr)
	}
	sc.scheduleFrameWrite()
}
1135
1136
1137
1138
// startFrameWrite begins writing wr on the serve goroutine. Only one frame
// write may be in flight at a time (writingFrame).
//
// How the write proceeds depends on the frame:
//   - If it fits in the buffered writer's free space, it is written
//     synchronously and wroteFrame is called immediately.
//   - A larger DATA frame is encoded here, and a writeFrameAsync goroutine
//     only finishes it with endWrite.
//   - Any other larger frame is written entirely by writeFrameAsync.
func (sc *serverConn) startFrameWrite(wr FrameWriteRequest) {
	sc.serveG.check()
	if sc.writingFrame {
		panic("internal error: can only be writing one frame at a time")
	}

	// Sanity-check the stream state. In half-closed (local) only
	// RST_STREAM-type and WINDOW_UPDATE frames may be sent, and nothing may
	// be sent on a closed stream.
	st := wr.stream
	if st != nil {
		switch st.state {
		case stateHalfClosedLocal:
			switch wr.write.(type) {
			case StreamError, handlerPanicRST, writeWindowUpdate:
				// Permitted in this state.
			default:
				panic(fmt.Sprintf("internal error: attempt to send frame on a half-closed-local stream: %v", wr))
			}
		case stateClosed:
			panic(fmt.Sprintf("internal error: attempt to send frame on a closed stream: %v", wr))
		}
	}
	// A PUSH_PROMISE needs its promised stream ID allocated now. On
	// failure the writer gets the error and nothing is written.
	if wpp, ok := wr.write.(*writePushPromise); ok {
		var err error
		wpp.promisedID, err = wpp.allocatePromisedID()
		if err != nil {
			sc.writingFrameAsync = false
			wr.replyToWriter(err)
			return
		}
	}

	sc.writingFrame = true
	sc.needsFrameFlush = true
	if wr.write.staysWithinBuffer(sc.bw.Available()) {
		sc.writingFrameAsync = false
		err := wr.write.writeFrame(sc)
		sc.wroteFrame(frameWriteResult{wr: wr, err: err})
	} else if wd, ok := wr.write.(*writeData); ok {
		// Encode the DATA frame on the serve goroutine; the async
		// goroutine only completes it. NOTE(review): presumably this
		// prevents asynchronous use of the handler's data slice after
		// Write returns — confirm.
		sc.framer.startWriteDataPadded(wd.streamID, wd.endStream, wd.p, nil)
		sc.writingFrameAsync = true
		go sc.writeFrameAsync(wr, wd)
	} else {
		sc.writingFrameAsync = true
		go sc.writeFrameAsync(wr, nil)
	}
}
1188
1189
1190
1191
// errHandlerPanicked is the close error for a stream whose handler
// panicked, used once its handlerPanicRST frame has been written.
var errHandlerPanicked = errors.New("http2: handler panicked")

// wroteFrame completes a frame write started by startFrameWrite. It runs on
// the serve goroutine and does the following:
//   - A write error closes the connection.
//   - A frame that ends the stream moves an open stream to half-closed
//     (local) and queues RST_STREAM(NO_ERROR), or closes a half-closed
//     (remote) stream.
//   - A written RST_STREAM or handler-panic reset closes its stream.
//   - The waiting writer, if any, is notified.
//   - More writes are scheduled.
func (sc *serverConn) wroteFrame(res frameWriteResult) {
	sc.serveG.check()
	if !sc.writingFrame {
		panic("internal error: expected to be already writing a frame")
	}
	sc.writingFrame = false
	sc.writingFrameAsync = false

	if res.err != nil {
		sc.conn.Close()
	}

	wr := res.wr

	if writeEndsStream(wr.write) {
		st := wr.stream
		if st == nil {
			panic("internal error: expecting non-nil stream")
		}
		switch st.state {
		case stateOpen:
			// The response is complete but the request side is still
			// open. Mark the stream half-closed (local), then ask the
			// client to stop sending with RST_STREAM(NO_ERROR), which
			// RFC 9113 Section 8.1 allows after a complete response. The
			// stream closes once that reset is written.
			st.state = stateHalfClosedLocal
			sc.resetStream(streamError(st.id, ErrCodeNo))
		case stateHalfClosedRemote:
			sc.closeStream(st, errHandlerComplete)
		}
	} else {
		switch v := wr.write.(type) {
		case StreamError:
			// The stream may be untracked, e.g. when the RST_STREAM
			// rejected a stream that was never created.
			if st, ok := sc.streams[v.StreamID]; ok {
				sc.closeStream(st, v)
			}
		case handlerPanicRST:
			sc.closeStream(wr.stream, errHandlerPanicked)
		}
	}

	// Notify the waiting writer (e.g. a handler), if any.
	wr.replyToWriter(res.err)

	sc.scheduleFrameWrite()
}
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
// scheduleFrameWrite starts as many frame writes as possible on the serve
// goroutine. It stops when a write goes asynchronous or nothing is left.
//
// Frames are chosen in this order:
//  1. a pending GOAWAY;
//  2. a pending SETTINGS ACK;
//  3. frames from writeSched (suppressed after a GOAWAY with an error code);
//  4. a flush, if anything was written since the last one.
//
// It does nothing if a write is already in flight, or if it is already
// running further up the stack. That happens because startFrameWrite's
// synchronous path calls wroteFrame, which calls back into this function.
func (sc *serverConn) scheduleFrameWrite() {
	sc.serveG.check()
	if sc.writingFrame || sc.inFrameScheduleLoop {
		return
	}
	sc.inFrameScheduleLoop = true
	for !sc.writingFrameAsync {
		if sc.needToSendGoAway {
			sc.needToSendGoAway = false
			sc.startFrameWrite(FrameWriteRequest{
				write: &writeGoAway{
					maxStreamID: sc.maxClientStreamID,
					code:        sc.goAwayCode,
				},
			})
			continue
		}
		if sc.needToSendSettingsAck {
			sc.needToSendSettingsAck = false
			sc.startFrameWrite(FrameWriteRequest{write: writeSettingsAck{}})
			continue
		}
		if !sc.inGoAway || sc.goAwayCode == ErrCodeNo {
			if wr, ok := sc.writeSched.Pop(); ok {
				if wr.isControl() {
					sc.queuedControlFrames--
				}
				sc.startFrameWrite(wr)
				continue
			}
		}
		if sc.needsFrameFlush {
			sc.startFrameWrite(FrameWriteRequest{write: flushFrameWriter{}})
			sc.needsFrameFlush = false // cleared after startFrameWrite, which sets it to true
			continue
		}
		break
	}
	sc.inFrameScheduleLoop = false
}
1304
1305
1306
1307
1308
1309
1310
1311
1312 func (sc *serverConn) startGracefulShutdown() {
1313 sc.serveG.checkNotOn()
1314 sc.shutdownOnce.Do(func() { sc.sendServeMsg(gracefulShutdownMsg) })
1315 }
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333 var goAwayTimeout = 1 * time.Second
1334
1335 func (sc *serverConn) startGracefulShutdownInternal() {
1336 sc.goAway(ErrCodeNo)
1337 }
1338
1339 func (sc *serverConn) goAway(code ErrCode) {
1340 sc.serveG.check()
1341 if sc.inGoAway {
1342 if sc.goAwayCode == ErrCodeNo {
1343 sc.goAwayCode = code
1344 }
1345 return
1346 }
1347 sc.inGoAway = true
1348 sc.needToSendGoAway = true
1349 sc.goAwayCode = code
1350 sc.scheduleFrameWrite()
1351 }
1352
1353 func (sc *serverConn) shutDownIn(d time.Duration) {
1354 sc.serveG.check()
1355 sc.shutdownTimer = time.AfterFunc(d, sc.onShutdownTimer)
1356 }
1357
1358 func (sc *serverConn) resetStream(se StreamError) {
1359 sc.serveG.check()
1360 sc.writeFrame(FrameWriteRequest{write: se})
1361 if st, ok := sc.streams[se.StreamID]; ok {
1362 st.resetQueued = true
1363 }
1364 }
1365
1366
1367
1368
// processFrameFromReader handles one readFrames result on the serve
// goroutine. It reports whether the serve loop should keep running.
//
// Read errors:
//   - an oversized frame queues GOAWAY(FRAME_SIZE_ERROR);
//   - EOF or a closed connection ends the loop.
//
// Processing errors:
//   - StreamError: RST_STREAM for that stream.
//   - goAwayFlowError: GOAWAY(FLOW_CONTROL_ERROR).
//   - ConnectionError: GOAWAY with that code. maxClientStreamID is first
//     raised to the offending frame's stream so GOAWAY's last-stream-ID
//     covers it.
//   - Anything else ends the loop.
func (sc *serverConn) processFrameFromReader(res readFrameResult) bool {
	sc.serveG.check()
	err := res.err
	if err != nil {
		if err == ErrFrameTooLarge {
			sc.goAway(ErrCodeFrameSize)
			return true // the GOAWAY path will end the connection
		}
		clientGone := err == io.EOF || err == io.ErrUnexpectedEOF || isClosedConnError(err)
		if clientGone {
			// The peer closed or reset the connection; stop serving.
			return false
		}
	} else {
		f := res.f
		if VerboseLogs {
			sc.vlogf("http2: server read frame %v", summarizeFrame(f))
		}
		err = sc.processFrame(f)
		if err == nil {
			return true
		}
	}

	switch ev := err.(type) {
	case StreamError:
		sc.resetStream(ev)
		return true
	case goAwayFlowError:
		sc.goAway(ErrCodeFlowControl)
		return true
	case ConnectionError:
		if res.f != nil {
			if id := res.f.Header().StreamID; id > sc.maxClientStreamID {
				sc.maxClientStreamID = id
			}
		}
		sc.logf("http2: server connection error from %v: %v", sc.conn.RemoteAddr(), ev)
		sc.goAway(ErrCode(ev))
		return true // the GOAWAY path will end the connection
	default:
		if res.err != nil {
			sc.vlogf("http2: server closing client connection; error reading frame from client %s: %v", sc.conn.RemoteAddr(), err)
		} else {
			sc.logf("http2: server closing client connection: %v", err)
		}
		return false
	}
}
1425
// processFrame dispatches one frame read from the client, on the serve
// goroutine.
//
// The client's first frame must be SETTINGS (RFC 9113 Section 3.4). Once a
// GOAWAY is in progress, frames are discarded if the GOAWAY carried an
// error code or if they belong to streams above maxClientStreamID.
// Discarded DATA frames still consume connection flow-control credit, and
// that credit is returned with a WINDOW_UPDATE.
func (sc *serverConn) processFrame(f Frame) error {
	sc.serveG.check()

	// The first frame received must be SETTINGS.
	if !sc.sawFirstSettings {
		if _, ok := f.(*SettingsFrame); !ok {
			return sc.countError("first_settings", ConnectionError(ErrCodeProtocol))
		}
		sc.sawFirstSettings = true
	}

	// Discard frames after an error GOAWAY, or for streams the GOAWAY did
	// not cover; keep connection-level flow control in balance for DATA.
	if sc.inGoAway && (sc.goAwayCode != ErrCodeNo || f.Header().StreamID > sc.maxClientStreamID) {

		if f, ok := f.(*DataFrame); ok {
			if !sc.inflow.take(f.Length) {
				return sc.countError("data_flow", streamError(f.Header().StreamID, ErrCodeFlowControl))
			}
			sc.sendWindowUpdate(nil, int(f.Length)) // connection-level
		}
		return nil
	}

	switch f := f.(type) {
	case *SettingsFrame:
		return sc.processSettings(f)
	case *MetaHeadersFrame:
		return sc.processHeaders(f)
	case *WindowUpdateFrame:
		return sc.processWindowUpdate(f)
	case *PingFrame:
		return sc.processPing(f)
	case *DataFrame:
		return sc.processData(f)
	case *RSTStreamFrame:
		return sc.processResetStream(f)
	case *PriorityFrame:
		return sc.processPriority(f)
	case *GoAwayFrame:
		return sc.processGoAway(f)
	case *PushPromiseFrame:
		// Clients cannot push, so a PUSH_PROMISE from one is a connection
		// PROTOCOL_ERROR (RFC 9113 Section 8.4).
		return sc.countError("push_promise", ConnectionError(ErrCodeProtocol))
	case *PriorityUpdateFrame:
		return sc.processPriorityUpdate(f)
	default:
		sc.vlogf("http2: server ignoring frame: %v", f.Header())
		return nil
	}
}
1480
1481 func (sc *serverConn) processPing(f *PingFrame) error {
1482 sc.serveG.check()
1483 if f.IsAck() {
1484 if sc.pingSent && sc.sentPingData == f.Data {
1485
1486 sc.pingSent = false
1487 sc.readIdleTimer.Reset(sc.readIdleTimeout)
1488 }
1489
1490
1491 return nil
1492 }
1493 if f.StreamID != 0 {
1494
1495
1496
1497
1498
1499 return sc.countError("ping_on_stream", ConnectionError(ErrCodeProtocol))
1500 }
1501 sc.writeFrame(FrameWriteRequest{write: writePingAck{f}})
1502 return nil
1503 }
1504
1505 func (sc *serverConn) processWindowUpdate(f *WindowUpdateFrame) error {
1506 sc.serveG.check()
1507 switch {
1508 case f.StreamID != 0:
1509 state, st := sc.state(f.StreamID)
1510 if state == stateIdle {
1511
1512
1513
1514
1515 return sc.countError("stream_idle", ConnectionError(ErrCodeProtocol))
1516 }
1517 if st == nil {
1518
1519
1520
1521
1522
1523 return nil
1524 }
1525 if !st.flow.add(int32(f.Increment)) {
1526 return sc.countError("bad_flow", streamError(f.StreamID, ErrCodeFlowControl))
1527 }
1528 default:
1529 if !sc.flow.add(int32(f.Increment)) {
1530 return goAwayFlowError{}
1531 }
1532 }
1533 sc.scheduleFrameWrite()
1534 return nil
1535 }
1536
1537 func (sc *serverConn) processResetStream(f *RSTStreamFrame) error {
1538 sc.serveG.check()
1539
1540 state, st := sc.state(f.StreamID)
1541 if state == stateIdle {
1542
1543
1544
1545
1546
1547 return sc.countError("reset_idle_stream", ConnectionError(ErrCodeProtocol))
1548 }
1549 if st != nil {
1550 st.cancelCtx()
1551 sc.closeStream(st, streamError(f.StreamID, f.ErrCode))
1552 }
1553 return nil
1554 }
1555
// closeStream transitions st to stateClosed and releases everything the
// connection holds for it. err is the reason the stream ended and is handed
// to the request body pipe and (after unwrapping) stored as st.closeErr.
//
// It panics if st is already idle or closed.
func (sc *serverConn) closeStream(st *stream, err error) {
	sc.serveG.check()
	if st.state == stateIdle || st.state == stateClosed {
		panic(fmt.Sprintf("invariant; can't close stream in state %v", st.state))
	}
	st.state = stateClosed
	// Stop any pending per-stream deadline timers.
	if st.readDeadline != nil {
		st.readDeadline.Stop()
	}
	if st.writeDeadline != nil {
		st.writeDeadline.Stop()
	}
	if st.isPushed() {
		sc.curPushedStreams--
	} else {
		sc.curClientStreams--
	}
	delete(sc.streams, st.id)
	// When the last stream goes away, the connection becomes idle: re-arm
	// the idle timer and, if keep-alives are disabled, begin shutting down.
	if len(sc.streams) == 0 {
		sc.setConnState(ConnStateIdle)
		idleTimeout := sc.hs.IdleTimeout()
		if idleTimeout > 0 && sc.idleTimer != nil {
			sc.idleTimer.Reset(idleTimeout)
		}
		if h1ServerKeepAlivesDisabled(sc.hs) {
			sc.startGracefulShutdownInternal()
		}
	}
	if p := st.body; p != nil {
		// Return connection-level credit for body bytes still sitting in the
		// pipe. NOTE(review): assumes p.Len() is the unread buffered byte
		// count — confirm against pipe.Len.
		sc.sendWindowUpdate(nil, p.Len())

		p.CloseWithError(err)
	}
	// Store the underlying cause for handlers rather than the StreamError
	// wrapper; a StreamError without a cause becomes errStreamClosed.
	if e, ok := err.(StreamError); ok {
		if e.Cause != nil {
			err = e.Cause
		} else {
			err = errStreamClosed
		}
	}
	st.closeErr = err
	st.cancelCtx()
	// Closing cw wakes anything waiting on the stream (e.g. writeHeaders,
	// FlushError, CloseNotify in this file select on or wait for it).
	st.cw.Close()
	sc.writeSched.CloseStream(st.id)
}
1603
1604 func (sc *serverConn) processSettings(f *SettingsFrame) error {
1605 sc.serveG.check()
1606 if f.IsAck() {
1607 sc.unackedSettings--
1608 if sc.unackedSettings < 0 {
1609
1610
1611
1612 return sc.countError("ack_mystery", ConnectionError(ErrCodeProtocol))
1613 }
1614 return nil
1615 }
1616 if f.NumSettings() > 100 || f.HasDuplicates() {
1617
1618
1619
1620 return sc.countError("settings_big_or_dups", ConnectionError(ErrCodeProtocol))
1621 }
1622 if err := f.ForeachSetting(sc.processSetting); err != nil {
1623 return err
1624 }
1625
1626
1627 sc.needToSendSettingsAck = true
1628 sc.scheduleFrameWrite()
1629 return nil
1630 }
1631
// processSetting applies one setting from a peer SETTINGS frame. A setting
// that fails s.Valid() is returned as an error, and unknown setting IDs are
// ignored.
func (sc *serverConn) processSetting(s Setting) error {
	sc.serveG.check()
	if err := s.Valid(); err != nil {
		return err
	}
	if VerboseLogs {
		sc.vlogf("http2: server processing setting %v", s)
	}
	switch s.ID {
	case SettingHeaderTableSize:
		// Bounds the dynamic table of our HPACK encoder.
		sc.hpackEncoder.SetMaxDynamicTableSize(s.Val)
	case SettingEnablePush:
		sc.pushEnabled = s.Val != 0
	case SettingMaxConcurrentStreams:
		// NOTE(review): presumably the peer's limit on streams this server
		// initiates (pushes) — confirm at the use site of clientMaxStreams.
		sc.clientMaxStreams = s.Val
	case SettingInitialWindowSize:
		return sc.processSettingInitialWindowSize(s.Val)
	case SettingMaxFrameSize:
		// NOTE(review): the int32 conversion assumes s.Valid() bounded the
		// value below 2^31 — confirm in Setting.Valid.
		sc.maxFrameSize = int32(s.Val)
	case SettingMaxHeaderListSize:
		sc.peerMaxHeaderListSize = s.Val
	case SettingEnableConnectProtocol:
		// Nothing to do when the server receives this setting.
	case SettingNoRFC7540Priorities:
		// Only 0 and 1 are accepted.
		if s.Val > 1 {
			return ConnectionError(ErrCodeProtocol)
		}
	default:
		// Unknown settings are ignored.
		if VerboseLogs {
			sc.vlogf("http2: server ignoring unknown setting %v", s)
		}
	}
	return nil
}
1670
1671 func (sc *serverConn) processSettingInitialWindowSize(val uint32) error {
1672 sc.serveG.check()
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682 old := sc.initialStreamSendWindowSize
1683 sc.initialStreamSendWindowSize = int32(val)
1684 growth := int32(val) - old
1685 for _, st := range sc.streams {
1686 if !st.flow.add(growth) {
1687
1688
1689
1690
1691
1692
1693 return sc.countError("setting_win_size", ConnectionError(ErrCodeFlowControl))
1694 }
1695 }
1696 return nil
1697 }
1698
// processData handles a DATA frame from the peer.
//
// DATA on stream 0 or on an idle stream is a connection PROTOCOL_ERROR.
//
// If the frame cannot be delivered (unknown stream, stream not open,
// trailers already received, or a reset already queued), the connection
// credit is taken and immediately returned. A STREAM_CLOSED stream error is
// then reported, unless a reset is already queued.
//
// Otherwise the payload is checked against the declared Content-Length,
// charged to both the connection and the stream receive windows, and
// written to the request body pipe. END_STREAM ends the request body.
func (sc *serverConn) processData(f *DataFrame) error {
	sc.serveG.check()
	id := f.Header().StreamID

	// data is the payload proper; f.Length - len(data) is treated as
	// padding further down.
	data := f.Data()
	state, st := sc.state(id)
	if id == 0 || state == stateIdle {
		return sc.countError("data_on_idle", ConnectionError(ErrCodeProtocol))
	}

	// Undeliverable DATA: the stream is unknown, not open, has already
	// received trailers, or is being reset.
	if st == nil || state != stateOpen || st.gotTrailerHeader || st.resetQueued {
		// The frame still counted against the connection window; consume
		// and return that credit so the connection does not stall.
		if !sc.inflow.take(f.Length) {
			return sc.countError("data_flow", streamError(id, ErrCodeFlowControl))
		}
		sc.sendWindowUpdate(nil, int(f.Length))

		if st != nil && st.resetQueued {
			// A reset is already on its way; don't report a second error.
			return nil
		}
		return sc.countError("closed", streamError(id, ErrCodeStreamClosed))
	}
	if st.body == nil {
		panic("internal error: should have a body in this state")
	}

	// The peer sent more body bytes than its Content-Length declared: fail
	// the body and reset the stream, returning connection credit.
	if st.declBodyBytes != -1 && st.bodyBytes+int64(len(data)) > st.declBodyBytes {
		if !sc.inflow.take(f.Length) {
			return sc.countError("data_flow", streamError(id, ErrCodeFlowControl))
		}
		sc.sendWindowUpdate(nil, int(f.Length))

		st.body.CloseWithError(fmt.Errorf("sender tried to send more than declared Content-Length of %d bytes", st.declBodyBytes))
		return sc.countError("send_too_much", streamError(id, ErrCodeProtocol))
	}
	if f.Length > 0 {
		// Charge both the connection and stream receive windows.
		if !takeInflows(&sc.inflow, &st.inflow, f.Length) {
			return sc.countError("flow_on_data_length", streamError(id, ErrCodeFlowControl))
		}

		if len(data) > 0 {
			st.bodyBytes += int64(len(data))
			wrote, err := st.body.Write(data)
			if err != nil {
				// The body pipe no longer accepts data. Give back the
				// connection-level credit for the unwritten remainder only;
				// no stream-level credit is returned here.
				sc.sendWindowUpdate(nil, int(f.Length)-wrote)
				return nil
			}
			if wrote != len(data) {
				panic("internal error: bad Writer")
			}
		}

		// Padding is never read by the handler, so return its credit now.
		// This is called even when pad is 0; sendWindowUpdate decides
		// whether anything is actually announced.
		pad := int32(f.Length) - int32(len(data))
		sc.sendWindowUpdate32(nil, pad)
		sc.sendWindowUpdate32(st, pad)
	}
	if f.StreamEnded() {
		st.endStream()
	}
	return nil
}
1795
1796 func (sc *serverConn) processGoAway(f *GoAwayFrame) error {
1797 sc.serveG.check()
1798 if f.ErrCode != ErrCodeNo {
1799 sc.logf("http2: received GOAWAY %+v, starting graceful shutdown", f)
1800 } else {
1801 sc.vlogf("http2: received GOAWAY %+v, starting graceful shutdown", f)
1802 }
1803 sc.startGracefulShutdownInternal()
1804
1805
1806 sc.pushEnabled = false
1807 return nil
1808 }
1809
1810
1811 func (st *stream) isPushed() bool {
1812 return st.id%2 == 0
1813 }
1814
1815
1816
1817 func (st *stream) endStream() {
1818 sc := st.sc
1819 sc.serveG.check()
1820
1821 if st.declBodyBytes != -1 && st.declBodyBytes != st.bodyBytes {
1822 st.body.CloseWithError(fmt.Errorf("request declared a Content-Length of %d but only wrote %d bytes",
1823 st.declBodyBytes, st.bodyBytes))
1824 } else {
1825 st.body.closeWithErrorAndCode(io.EOF, st.copyTrailersToHandlerRequest)
1826 st.body.CloseWithError(io.EOF)
1827 }
1828 st.state = stateHalfClosedRemote
1829 }
1830
1831
1832
1833 func (st *stream) copyTrailersToHandlerRequest() {
1834 for k, vv := range st.trailer {
1835 if _, ok := st.reqTrailer[k]; ok {
1836
1837 st.reqTrailer[k] = vv
1838 }
1839 }
1840 }
1841
1842
1843
1844 func (st *stream) onReadTimeout() {
1845 if st.body != nil {
1846
1847
1848 st.body.CloseWithError(fmt.Errorf("%w", os.ErrDeadlineExceeded))
1849 }
1850 }
1851
1852
1853
1854 func (st *stream) onWriteTimeout() {
1855 st.sc.writeFrameFromHandler(FrameWriteRequest{write: StreamError{
1856 StreamID: st.id,
1857 Code: ErrCodeInternal,
1858 Cause: os.ErrDeadlineExceeded,
1859 }})
1860 }
1861
// processHeaders handles a (meta) HEADERS frame from the client.
//
// On a stream we already track, the frame is treated as trailers (or
// dropped/rejected as described below).
//
// Otherwise it opens a new client stream. The ID must be odd and larger than
// any previous client stream ID, and the advertised concurrent-stream limit
// must not be exceeded. The request is then built and its handler
// scheduled.
func (sc *serverConn) processHeaders(f *MetaHeadersFrame) error {
	sc.serveG.check()
	id := f.StreamID

	// Client-initiated streams use odd IDs; an even ID is a connection
	// PROTOCOL_ERROR.
	if id%2 != 1 {
		return sc.countError("headers_even", ConnectionError(ErrCodeProtocol))
	}

	// HEADERS on an existing stream: trailers, or an error.
	if st := sc.streams[f.StreamID]; st != nil {
		if st.resetQueued {
			// A reset for this stream is already queued; drop the frame.
			return nil
		}
		// The peer already ended its side, so more HEADERS are a
		// STREAM_CLOSED stream error.
		if st.state == stateHalfClosedRemote {
			return sc.countError("headers_half_closed", streamError(id, ErrCodeStreamClosed))
		}
		return st.processTrailerHeaders(f)
	}

	// New client stream IDs must strictly increase.
	if id <= sc.maxClientStreamID {
		return sc.countError("stream_went_down", ConnectionError(ErrCodeProtocol))
	}
	sc.maxClientStreamID = id

	// The connection is no longer idle.
	if sc.idleTimer != nil {
		sc.idleTimer.Stop()
	}

	// Enforce our advertised stream limit.
	if sc.curClientStreams+1 > sc.advMaxStreams {
		if sc.unackedSettings == 0 {
			// All our SETTINGS are acknowledged, so the peer knows the
			// limit: stream PROTOCOL_ERROR.
			return sc.countError("over_max_streams", streamError(id, ErrCodeProtocol))
		}
		// The peer may not have seen our latest SETTINGS yet: refuse the
		// stream instead.
		return sc.countError("over_max_streams_race", streamError(id, ErrCodeRefusedStream))
	}

	// END_STREAM on HEADERS means there is no request body.
	initialState := stateOpen
	if f.StreamEnded() {
		initialState = stateHalfClosedRemote
	}

	// Initial RFC 9218 priority. The default depends on whether the peer is
	// known to be priority-aware and no intermediary has been detected.
	// Only with the RFC 9218 scheduler (and no intermediary seen so far) is
	// the priority taken from the frame, which may also update the
	// connection's priorityAware/hasIntermediary flags.
	// NOTE(review): presumably f.rfc9218Priority parses the request's
	// priority header field — confirm.
	initialPriority := defaultRFC9218Priority(sc.priorityAware && !sc.hasIntermediary)
	if _, ok := sc.writeSched.(*priorityWriteSchedulerRFC9218); ok && !sc.hasIntermediary {
		headerPriority, priorityAware, hasIntermediary := f.rfc9218Priority(sc.priorityAware)
		initialPriority = headerPriority
		sc.hasIntermediary = hasIntermediary
		if priorityAware {
			sc.priorityAware = true
		}
	}
	st := sc.newStream(id, 0, initialState, initialPriority)

	// RFC 7540-style priority on the HEADERS frame: always validated, but
	// only applied if the scheduler does not ignore RFC 7540 priorities.
	if f.HasPriority() {
		if err := sc.checkPriority(f.StreamID, f.Priority); err != nil {
			return err
		}
		if !sc.writeSchedIgnoresRFC7540() {
			sc.writeSched.AdjustStream(st.id, f.Priority)
		}
	}

	rw, req, err := sc.newWriterAndRequest(st, f)
	if err != nil {
		return err
	}
	st.reqTrailer = req.Trailer
	if st.reqTrailer != nil {
		st.trailer = make(Header)
	}
	// nil when the request has no body (newWriterAndRequest only attaches a
	// pipe when the stream is not already ended).
	st.body = req.Body.(*requestBody).pipe
	st.declBodyBytes = req.ContentLength

	// Choose the handler: a 431 responder for a truncated header list, a
	// 400 responder for headers invalid in HTTP/2, else the server handler.
	handler := sc.handler.ServeHTTP
	if f.Truncated {
		handler = handleHeaderListTooLong
	} else if err := checkValidHTTP2RequestHeaders(req.Header); err != nil {
		handler = serve400Handler{err}.ServeHTTP
	}

	// With a ReadTimeout configured, clear the connection-wide read deadline
	// and enforce the timeout per stream instead.
	if sc.hs.ReadTimeout() > 0 {
		sc.conn.SetReadDeadline(time.Time{})
		st.readDeadline = time.AfterFunc(sc.hs.ReadTimeout(), st.onReadTimeout)
	}

	return sc.scheduleHandler(id, rw, req, handler)
}
1992
// upgradeRequest serves req as stream 1, created already half-closed
// (remote). NOTE(review): presumably req is the HTTP/1.1 request that
// triggered an upgrade to HTTP/2 — confirm at the caller.
//
// The request is copied into the response writer's state, and the handler
// receives that copy.
func (sc *serverConn) upgradeRequest(req *ServerRequest) {
	sc.serveG.check()
	id := uint32(1)
	sc.maxClientStreamID = id
	st := sc.newStream(id, 0, stateHalfClosedRemote, defaultRFC9218Priority(sc.priorityAware && !sc.hasIntermediary))
	st.reqTrailer = req.Trailer
	if st.reqTrailer != nil {
		st.trailer = make(Header)
	}
	rw := sc.newResponseWriter(st)
	rw.rws.req = *req
	req = &rw.rws.req

	// Clear any connection read deadline that was set before the upgrade.
	if sc.hs.ReadTimeout() > 0 {
		sc.conn.SetReadDeadline(time.Time{})
	}

	// Start the handler directly, bypassing scheduleHandler, while still
	// counting it in curHandlers.
	sc.curHandlers++
	go sc.runHandler(rw, req, sc.handler.ServeHTTP)
}
2018
// processTrailerHeaders handles a HEADERS frame on an already open stream,
// which carries the request trailers.
//
// Errors:
//   - a second trailer block is a connection PROTOCOL_ERROR;
//   - trailers that do not end the stream are a stream PROTOCOL_ERROR;
//   - trailers containing pseudo-headers are a stream PROTOCOL_ERROR.
//
// If the request declared trailers (st.trailer != nil), each field is
// validated and collected. The request body is then ended.
func (st *stream) processTrailerHeaders(f *MetaHeadersFrame) error {
	sc := st.sc
	sc.serveG.check()
	if st.gotTrailerHeader {
		return sc.countError("dup_trailers", ConnectionError(ErrCodeProtocol))
	}
	// Set before the remaining checks so later DATA/HEADERS are rejected
	// even if this frame is invalid.
	st.gotTrailerHeader = true
	if !f.StreamEnded() {
		return sc.countError("trailers_not_ended", streamError(st.id, ErrCodeProtocol))
	}

	if len(f.PseudoFields()) > 0 {
		return sc.countError("trailers_pseudo", streamError(st.id, ErrCodeProtocol))
	}
	if st.trailer != nil {
		for _, hf := range f.RegularFields() {
			key := sc.canonicalHeader(hf.Name)
			// Fields not permitted as trailers reset the stream.
			if !httpguts.ValidTrailerHeader(key) {
				return sc.countError("trailers_bogus", streamError(st.id, ErrCodeProtocol))
			}
			st.trailer[key] = append(st.trailer[key], hf.Value)
		}
	}
	st.endStream()
	return nil
}
2048
2049 func (sc *serverConn) checkPriority(streamID uint32, p PriorityParam) error {
2050 if streamID == p.StreamDep {
2051
2052
2053
2054
2055 return sc.countError("priority", streamError(streamID, ErrCodeProtocol))
2056 }
2057 return nil
2058 }
2059
2060 func (sc *serverConn) processPriority(f *PriorityFrame) error {
2061 if err := sc.checkPriority(f.StreamID, f.PriorityParam); err != nil {
2062 return err
2063 }
2064
2065
2066
2067
2068
2069 if sc.writeSchedIgnoresRFC7540() {
2070 return nil
2071 }
2072 sc.writeSched.AdjustStream(f.StreamID, f.PriorityParam)
2073 return nil
2074 }
2075
2076 func (sc *serverConn) processPriorityUpdate(f *PriorityUpdateFrame) error {
2077 sc.priorityAware = true
2078 if _, ok := sc.writeSched.(*priorityWriteSchedulerRFC9218); !ok {
2079 return nil
2080 }
2081 p, ok := parseRFC9218Priority(f.Priority, sc.priorityAware)
2082 if !ok {
2083 return sc.countError("unparsable_priority_update", streamError(f.PrioritizedStreamID, ErrCodeProtocol))
2084 }
2085 sc.writeSched.AdjustStream(f.PrioritizedStreamID, p)
2086 return nil
2087 }
2088
// newStream creates and registers a stream with the given ID, pusher ID,
// initial state and scheduling priority, and returns it. It panics if id is 0.
//
// The new stream gets:
//   - a cancellable context derived from sc.baseCtx;
//   - a send window linked to the connection's window and seeded with the
//     current initial stream send window;
//   - a receive window of initialStreamRecvWindowSize;
//   - a write-deadline timer, if the server has a WriteTimeout.
//
// The stream is added to sc.streams and the write scheduler, and counted as
// pushed or client-initiated by ID parity. The connection becomes active
// when this is its first open stream.
func (sc *serverConn) newStream(id, pusherID uint32, state streamState, priority PriorityParam) *stream {
	sc.serveG.check()
	if id == 0 {
		panic("internal error: cannot create stream with id 0")
	}

	ctx, cancelCtx := context.WithCancel(sc.baseCtx)
	st := &stream{
		sc:        sc,
		id:        id,
		state:     state,
		ctx:       ctx,
		cancelCtx: cancelCtx,
	}
	st.cw.Init()
	// Stream-level send window chained to the connection-level one.
	st.flow.conn = &sc.flow
	st.flow.add(sc.initialStreamSendWindowSize)
	st.inflow.init(sc.initialStreamRecvWindowSize)
	if writeTimeout := sc.hs.WriteTimeout(); writeTimeout > 0 {
		st.writeDeadline = time.AfterFunc(writeTimeout, st.onWriteTimeout)
	}

	sc.streams[id] = st
	sc.writeSched.OpenStream(st.id, OpenStreamOptions{PusherID: pusherID, priority: priority})
	if st.isPushed() {
		sc.curPushedStreams++
	} else {
		sc.curClientStreams++
	}
	if sc.curOpenStreams() == 1 {
		sc.setConnState(ConnStateActive)
	}

	return st
}
2124
// newWriterAndRequest validates the pseudo-header fields of a request
// HEADERS frame and builds the request header map. It returns a new
// responseWriter and request for st, created via newWriterAndRequestNoBody.
//
// These malformed requests are rejected with a stream PROTOCOL_ERROR:
//   - :protocol present while extended CONNECT is disabled;
//   - CONNECT without :protocol that has :path or :scheme, or lacks
//     :authority;
//   - any other method missing :method or :path, or with a :scheme other
//     than "http" or "https".
//
// If the frame does not end the stream, a body pipe is attached. Its
// ContentLength comes from Content-Length: 0 if unparsable, -1 if absent.
func (sc *serverConn) newWriterAndRequest(st *stream, f *MetaHeadersFrame) (*responseWriter, *ServerRequest, error) {
	sc.serveG.check()

	rp := httpcommon.ServerRequestParam{
		Method:    f.PseudoValue("method"),
		Scheme:    f.PseudoValue("scheme"),
		Authority: f.PseudoValue("authority"),
		Path:      f.PseudoValue("path"),
		Protocol:  f.PseudoValue("protocol"),
	}

	// With extended CONNECT disabled, :protocol must not appear at all.
	if disableExtendedConnectProtocol && rp.Protocol != "" {
		return nil, nil, sc.countError("bad_connect", streamError(f.StreamID, ErrCodeProtocol))
	}

	isConnect := rp.Method == "CONNECT"
	if isConnect {
		if rp.Protocol == "" && (rp.Path != "" || rp.Scheme != "" || rp.Authority == "") {
			return nil, nil, sc.countError("bad_connect", streamError(f.StreamID, ErrCodeProtocol))
		}
	} else if rp.Method == "" || rp.Path == "" || (rp.Scheme != "https" && rp.Scheme != "http") {
		return nil, nil, sc.countError("bad_path_method", streamError(f.StreamID, ErrCodeProtocol))
	}

	header := make(Header)
	rp.Header = header
	for _, hf := range f.RegularFields() {
		header.Add(sc.canonicalHeader(hf.Name), hf.Value)
	}
	// :authority falls back to the Host header.
	if rp.Authority == "" {
		rp.Authority = header.Get("Host")
	}
	// Expose :protocol to handlers as a header.
	if rp.Protocol != "" {
		header.Set(":protocol", rp.Protocol)
	}

	rw, req, err := sc.newWriterAndRequestNoBody(st, rp)
	if err != nil {
		return nil, nil, err
	}
	bodyOpen := !f.StreamEnded()
	if bodyOpen {
		// Only the first Content-Length value is consulted.
		if vv, ok := rp.Header["Content-Length"]; ok {
			if cl, err := strconv.ParseUint(vv[0], 10, 63); err == nil {
				req.ContentLength = int64(cl)
			} else {
				req.ContentLength = 0
			}
		} else {
			req.ContentLength = -1
		}
		req.Body.(*requestBody).pipe = &pipe{
			b: &dataBuffer{expected: req.ContentLength},
		}
	}
	return rw, req, nil
}
2193
// newWriterAndRequestNoBody builds the ServerRequest for st from rp, plus a
// responseWriter whose state holds that request. The returned request
// points into the writer's state. No body pipe is attached here; the caller
// adds one if needed.
//
// Connection TLS state is attached only for the "https" scheme. If
// httpcommon.NewServerRequest reports the request as invalid, a stream
// PROTOCOL_ERROR is returned, counted under that reason.
func (sc *serverConn) newWriterAndRequestNoBody(st *stream, rp httpcommon.ServerRequestParam) (*responseWriter, *ServerRequest, error) {
	sc.serveG.check()

	var tlsState *tls.ConnectionState
	if rp.Scheme == "https" {
		tlsState = sc.tlsState
	}

	res := httpcommon.NewServerRequest(rp)
	if res.InvalidReason != "" {
		return nil, nil, sc.countError(res.InvalidReason, streamError(st.id, ErrCodeProtocol))
	}

	body := &requestBody{
		conn:          sc,
		stream:        st,
		needsContinue: res.NeedsContinue,
	}
	rw := sc.newResponseWriter(st)
	rw.rws.req = ServerRequest{
		Context:    st.ctx,
		Method:     rp.Method,
		URL:        res.URL,
		RemoteAddr: sc.remoteAddrStr,
		Header:     rp.Header,
		RequestURI: res.RequestURI,
		Proto:      "HTTP/2.0",
		ProtoMajor: 2,
		ProtoMinor: 0,
		TLS:        tlsState,
		Host:       rp.Authority,
		Body:       body,
		Trailer:    res.Trailer,
	}
	return rw, &rw.rws.req, nil
}
2230
// newResponseWriter returns a responseWriter for st whose state comes from
// responseWriterStatePool.
//
// The pooled state is reset to its zero value except for its bufio.Writer.
// The writer is kept to reuse its buffer and is re-pointed at a chunkWriter
// for the fresh state.
func (sc *serverConn) newResponseWriter(st *stream) *responseWriter {
	rws := responseWriterStatePool.Get().(*responseWriterState)
	bwSave := rws.bw
	*rws = responseWriterState{}
	rws.conn = sc
	rws.bw = bwSave
	rws.bw.Reset(chunkWriter{rws})
	rws.stream = st
	return &responseWriter{rws: rws}
}
2241
// unstartedHandler is a handler invocation queued by scheduleHandler because
// the handler limit was reached. handlerDone starts it later.
type unstartedHandler struct {
	streamID uint32 // stream the request arrived on; checked against sc.streams before starting
	rw       *responseWriter
	req      *ServerRequest
	handler  func(*ResponseWriter, *ServerRequest)
}
2248
2249
2250
2251 func (sc *serverConn) scheduleHandler(streamID uint32, rw *responseWriter, req *ServerRequest, handler func(*ResponseWriter, *ServerRequest)) error {
2252 sc.serveG.check()
2253 maxHandlers := sc.advMaxStreams
2254 if sc.curHandlers < maxHandlers {
2255 sc.curHandlers++
2256 go sc.runHandler(rw, req, handler)
2257 return nil
2258 }
2259 if len(sc.unstartedHandlers) > int(4*sc.advMaxStreams) {
2260 return sc.countError("too_many_early_resets", ConnectionError(ErrCodeEnhanceYourCalm))
2261 }
2262 sc.unstartedHandlers = append(sc.unstartedHandlers, unstartedHandler{
2263 streamID: streamID,
2264 rw: rw,
2265 req: req,
2266 handler: handler,
2267 })
2268 return nil
2269 }
2270
2271 func (sc *serverConn) handlerDone() {
2272 sc.serveG.check()
2273 sc.curHandlers--
2274 i := 0
2275 maxHandlers := sc.advMaxStreams
2276 for ; i < len(sc.unstartedHandlers); i++ {
2277 u := sc.unstartedHandlers[i]
2278 if sc.streams[u.streamID] == nil {
2279
2280 continue
2281 }
2282 if sc.curHandlers >= maxHandlers {
2283 break
2284 }
2285 sc.curHandlers++
2286 go sc.runHandler(u.rw, u.req, u.handler)
2287 sc.unstartedHandlers[i] = unstartedHandler{}
2288 }
2289 sc.unstartedHandlers = sc.unstartedHandlers[i:]
2290 if len(sc.unstartedHandlers) == 0 {
2291 sc.unstartedHandlers = nil
2292 }
2293 }
2294
2295
// runHandler runs one request handler on its own goroutine. It is started by
// scheduleHandler, handlerDone or upgradeRequest.
//
// Cleanup after handler, in execution order:
//   - the stream context is cancelled and any MultipartForm is removed;
//   - if handler did not return normally, any panic is recovered, a
//     handlerPanicRST frame is queued for the stream, and the stack is
//     logged unless the recovered value is nil or ErrAbortHandler;
//   - otherwise the response is finished via rw.handlerDone;
//   - finally handlerDoneMsg is sent to the serve goroutine.
func (sc *serverConn) runHandler(rw *responseWriter, req *ServerRequest, handler func(*ResponseWriter, *ServerRequest)) {
	defer sc.sendServeMsg(handlerDoneMsg)
	// Cleared only when handler returns normally.
	didPanic := true
	defer func() {
		rw.rws.stream.cancelCtx()
		if req.MultipartForm != nil {
			req.MultipartForm.RemoveAll()
		}
		if didPanic {
			e := recover()
			sc.writeFrameFromHandler(FrameWriteRequest{
				write:  handlerPanicRST{rw.rws.stream.id},
				stream: rw.rws.stream,
			})
			// ErrAbortHandler (and a nil value) aborts silently.
			if e != nil && e != ErrAbortHandler {
				const size = 64 << 10
				buf := make([]byte, size)
				buf = buf[:runtime.Stack(buf, false)]
				sc.logf("http2: panic serving %v: %v\n%s", sc.conn.RemoteAddr(), e, buf)
			}
			return
		}
		rw.handlerDone()
	}()
	handler(rw, req)
	didPanic = false
}
2324
2325 func handleHeaderListTooLong(w *ResponseWriter, r *ServerRequest) {
2326
2327
2328
2329
2330 const statusRequestHeaderFieldsTooLarge = 431
2331 w.WriteHeader(statusRequestHeaderFieldsTooLarge)
2332 io.WriteString(w, "<h1>HTTP Error 431</h1><p>Request Header Field(s) Too Large</p>")
2333 }
2334
2335
2336
// writeHeaders queues the HEADERS frame described by headerData for st. It
// is called from handler goroutines and must not run on the serve goroutine.
//
// If headerData carries a header map, it blocks until one of three things
// happens:
//   - the write completes, and its result is returned;
//   - the connection stops serving (errClientDisconnected);
//   - st is closed (errStreamClosed).
//
// NOTE(review): presumably it blocks because the header map is shared with
// the handler, which could change it after this returns — confirm.
func (sc *serverConn) writeHeaders(st *stream, headerData *writeResHeaders) error {
	sc.serveG.checkNotOn()
	var errc chan error
	if headerData.h != nil {
		errc = sc.srv.getErrChan()
	}
	if err := sc.writeFrameFromHandler(FrameWriteRequest{
		write:  headerData,
		stream: st,
		done:   errc,
	}); err != nil {
		return err
	}
	if errc != nil {
		select {
		case err := <-errc:
			// The channel is recycled only after its result was received;
			// on the other paths it is not returned to the pool.
			sc.srv.putErrChan(errc)
			return err
		case <-sc.doneServing:
			return errClientDisconnected
		case <-st.cw:
			return errStreamClosed
		}
	}
	return nil
}
2367
2368
2369 func (sc *serverConn) write100ContinueHeaders(st *stream) {
2370 sc.writeFrameFromHandler(FrameWriteRequest{
2371 write: write100ContinueHeadersFrame{st.id},
2372 stream: st,
2373 })
2374 }
2375
2376
2377
// bodyReadMsg is sent on sc.bodyReadCh by noteBodyReadFromHandler. It
// reports that the handler consumed n bytes of st's request body.
// NOTE(review): presumably handled on the serve goroutine via noteBodyRead.
type bodyReadMsg struct {
	st *stream
	n  int
}
2382
2383
2384
2385
2386 func (sc *serverConn) noteBodyReadFromHandler(st *stream, n int, err error) {
2387 sc.serveG.checkNotOn()
2388 if n > 0 {
2389 select {
2390 case sc.bodyReadCh <- bodyReadMsg{st, n}:
2391 case <-sc.doneServing:
2392 }
2393 }
2394 }
2395
2396 func (sc *serverConn) noteBodyRead(st *stream, n int) {
2397 sc.serveG.check()
2398 sc.sendWindowUpdate(nil, n)
2399 if st.state != stateHalfClosedRemote && st.state != stateClosed {
2400
2401
2402 sc.sendWindowUpdate(st, n)
2403 }
2404 }
2405
2406
// sendWindowUpdate32 is sendWindowUpdate for an int32 amount; st may be nil
// for the connection-level window.
func (sc *serverConn) sendWindowUpdate32(st *stream, n int32) {
	sc.sendWindowUpdate(st, int(n))
}
2410
2411
2412 func (sc *serverConn) sendWindowUpdate(st *stream, n int) {
2413 sc.serveG.check()
2414 var streamID uint32
2415 var send int32
2416 if st == nil {
2417 send = sc.inflow.add(n)
2418 } else {
2419 streamID = st.id
2420 send = st.inflow.add(n)
2421 }
2422 if send == 0 {
2423 return
2424 }
2425 sc.writeFrame(FrameWriteRequest{
2426 write: writeWindowUpdate{streamID: streamID, n: uint32(send)},
2427 stream: st,
2428 })
2429 }
2430
2431
2432
// requestBody is the Body of a ServerRequest served by this connection.
type requestBody struct {
	_             incomparable // NOTE(review): presumably makes the struct non-comparable — confirm
	stream        *stream
	conn          *serverConn
	closeOnce     sync.Once // makes Close's pipe break happen at most once
	sawEOF        bool      // set once the pipe has returned io.EOF; later reads short-circuit
	pipe          *pipe     // nil when the request has no body
	needsContinue bool      // send "100 Continue" on the first Read
}
2442
// Close breaks the body pipe with errClosedBody (only the first call has an
// effect) and always returns nil.
func (b *requestBody) Close() error {
	b.closeOnce.Do(func() {
		if b.pipe != nil {
			b.pipe.BreakWithError(errClosedBody)
		}
	})
	return nil
}
2451
2452 func (b *requestBody) Read(p []byte) (n int, err error) {
2453 if b.needsContinue {
2454 b.needsContinue = false
2455 b.conn.write100ContinueHeaders(b.stream)
2456 }
2457 if b.pipe == nil || b.sawEOF {
2458 return 0, io.EOF
2459 }
2460 n, err = b.pipe.Read(p)
2461 if err == io.EOF {
2462 b.sawEOF = true
2463 }
2464 if b.conn == nil {
2465 return
2466 }
2467 b.conn.noteBodyReadFromHandler(b.stream, n, err)
2468 return
2469 }
2470
2471
2472
2473
2474
2475
2476
// responseWriter is the per-request response writer passed to handlers.
// rws is set to nil once the handler has finished (see handlerDone), and
// its methods panic if used after that.
type responseWriter struct {
	rws *responseWriterState
}
2480
// responseWriterState is the pooled state behind a responseWriter
// (see responseWriterStatePool and newResponseWriter).
type responseWriterState struct {
	// Request context.
	stream *stream
	req    ServerRequest // the request being served; handlers receive a pointer into it
	conn   *serverConn

	// Buffers handler output; flushes into a chunkWriter for this state.
	bw *bufio.Writer

	// Response header and status bookkeeping.
	handlerHeader Header   // map returned by Header(); mutated by the handler
	snapHeader    Header   // copy of handlerHeader taken when the final status is set
	trailers      []string // declared trailer keys, canonicalized and unique
	status        int      // final status code recorded by writeHeader
	wroteHeader   bool     // a final (non-1xx) status has been recorded
	sentHeader    bool     // the response HEADERS frame has been sent by writeChunk
	handlerDone   bool     // the handler has returned

	sentContentLen int64 // Content-Length parsed from the handler's header, 0 if none
	wroteBytes     int64 // body bytes accepted by write

	// Lazily created by CloseNotify.
	closeNotifierMu sync.Mutex
	closeNotifierCh chan bool
}
2505
// chunkWriter adapts responseWriterState.writeChunk to io.Writer; it is the
// destination of rws.bw.
type chunkWriter struct{ rws *responseWriterState }
2507
2508 func (cw chunkWriter) Write(p []byte) (n int, err error) {
2509 n, err = cw.rws.writeChunk(p)
2510 if err == errStreamClosed {
2511
2512
2513 err = cw.rws.stream.closeErr
2514 }
2515 return n, err
2516 }
2517
// hasTrailers reports whether any trailer keys have been declared.
func (rws *responseWriterState) hasTrailers() bool { return len(rws.trailers) > 0 }
2519
2520 func (rws *responseWriterState) hasNonemptyTrailers() bool {
2521 for _, trailer := range rws.trailers {
2522 if _, ok := rws.handlerHeader[trailer]; ok {
2523 return true
2524 }
2525 }
2526 return false
2527 }
2528
2529
2530
2531
2532 func (rws *responseWriterState) declareTrailer(k string) {
2533 k = textproto.CanonicalMIMEHeaderKey(k)
2534 if !httpguts.ValidTrailerHeader(k) {
2535
2536 rws.conn.logf("ignoring invalid trailer %q", k)
2537 return
2538 }
2539 if !slices.Contains(rws.trailers, k) {
2540 rws.trailers = append(rws.trailers, k)
2541 }
2542 }
2543
// TimeFormat is the time layout used for the Date response header; the
// server formats the current UTC time with it. The zone is written as the
// literal "GMT", so only UTC times should be formatted with it.
const TimeFormat = "Mon, 02 Jan 2006 15:04:05 GMT"
2545
2546
2547
2548
2549
2550
2551
// writeChunk is the sink behind rws.bw. It is called (through chunkWriter)
// with buffered handler output, and with p == nil from FlushError to force
// out headers and/or the final frame.
//
// On first use it sends the response HEADERS frame. The snapshot header
// supplies Content-Length, Content-Type (sniffed from p if absent), Date,
// declared trailers, and Connection: close handling.
//
// It then sends p as DATA, ending the stream once the handler is done and no
// trailer has a value; otherwise a trailing HEADERS frame carries the
// trailers. HEAD responses never carry DATA.
func (rws *responseWriterState) writeChunk(p []byte) (n int, err error) {
	if !rws.wroteHeader {
		rws.writeHeader(200)
	}

	// Once the handler is done, "Trailer:"-prefixed keys become trailers.
	if rws.handlerDone {
		rws.promoteUndeclaredTrailers()
	}

	isHeadResp := rws.req.Method == "HEAD"
	if !rws.sentHeader {
		rws.sentHeader = true
		var ctype, clen string
		// A handler-supplied Content-Length is removed from the map and
		// passed separately in clen (also remembered in sentContentLen);
		// an unparsable value is dropped.
		if clen = rws.snapHeader.Get("Content-Length"); clen != "" {
			rws.snapHeader.Del("Content-Length")
			if cl, err := strconv.ParseUint(clen, 10, 63); err == nil {
				rws.sentContentLen = int64(cl)
			} else {
				clen = ""
			}
		}
		// If the handler is done and no length is known, this chunk is the
		// entire body, so its length can be declared.
		_, hasContentLength := rws.snapHeader["Content-Length"]
		if !hasContentLength && clen == "" && rws.handlerDone && bodyAllowedForStatus(rws.status) && (len(p) > 0 || !isHeadResp) {
			clen = strconv.Itoa(len(p))
		}
		_, hasContentType := rws.snapHeader["Content-Type"]
		// Sniff a content type only if none was set and the body is not
		// content-encoded.
		ce := rws.snapHeader.Get("Content-Encoding")
		hasCE := len(ce) > 0
		if !hasCE && !hasContentType && bodyAllowedForStatus(rws.status) && len(p) > 0 {
			ctype = internal.DetectContentType(p)
		}
		var date string
		if _, ok := rws.snapHeader["Date"]; !ok {
			date = time.Now().UTC().Format(TimeFormat)
		}

		// Trailers announced in the "Trailer" header are declared now.
		for _, v := range rws.snapHeader["Trailer"] {
			foreachHeaderElement(v, rws.declareTrailer)
		}

		// The Connection header is never sent; a value of "close" starts a
		// graceful shutdown of the connection instead.
		if _, ok := rws.snapHeader["Connection"]; ok {
			v := rws.snapHeader.Get("Connection")
			delete(rws.snapHeader, "Connection")
			if v == "close" {
				rws.conn.startGracefulShutdown()
			}
		}

		// The HEADERS frame alone ends the stream for HEAD responses, or
		// when the handler finished with no body and no declared trailers.
		endStream := (rws.handlerDone && !rws.hasTrailers() && len(p) == 0) || isHeadResp
		err = rws.conn.writeHeaders(rws.stream, &writeResHeaders{
			streamID:      rws.stream.id,
			httpResCode:   rws.status,
			h:             rws.snapHeader,
			endStream:     endStream,
			contentType:   ctype,
			contentLength: clen,
			date:          date,
		})
		if err != nil {
			return 0, err
		}
		if endStream {
			return 0, nil
		}
	}
	// HEAD bodies are discarded but reported as written.
	if isHeadResp {
		return len(p), nil
	}
	// An empty write before the handler is done has nothing to send.
	if len(p) == 0 && !rws.handlerDone {
		return 0, nil
	}

	// Trailers are sent only if at least one declared trailer has a value.
	hasNonemptyTrailers := rws.hasNonemptyTrailers()
	endStream := rws.handlerDone && !hasNonemptyTrailers
	if len(p) > 0 || endStream {
		// An empty DATA frame is sent only to carry END_STREAM.
		if err := rws.conn.writeDataFromHandler(rws.stream, p, endStream); err != nil {
			return 0, err
		}
	}

	if rws.handlerDone && hasNonemptyTrailers {
		err = rws.conn.writeHeaders(rws.stream, &writeResHeaders{
			streamID:  rws.stream.id,
			h:         rws.handlerHeader,
			trailers:  rws.trailers,
			endStream: true,
		})
		return len(p), err
	}
	return len(p), nil
}
2654
2655
2656
2657
2658
2659
2660
2661
2662
2663
2664
2665
2666
2667
// TrailerPrefix is a magic prefix for response header keys. When the handler
// finishes, a key such as "Trailer:Foo" in its header map is promoted to a
// declared trailer "Foo" with the same values (see
// promoteUndeclaredTrailers). This lets a handler send trailers it did not
// announce before the headers were written.
const TrailerPrefix = "Trailer:"
2669
2670
2671
2672
2673
2674
2675
2676
2677
2678
2679
2680
2681
2682
2683
2684
2685
2686
2687
2688
2689
2690
// promoteUndeclaredTrailers turns handler header keys carrying TrailerPrefix
// into declared trailers. Each key's values are copied to the canonical
// trailer name in handlerHeader; the prefixed key itself is left in place.
// The declared trailer list is then sorted.
func (rws *responseWriterState) promoteUndeclaredTrailers() {
	for k, vv := range rws.handlerHeader {
		if !strings.HasPrefix(k, TrailerPrefix) {
			continue
		}
		trailerKey := strings.TrimPrefix(k, TrailerPrefix)
		rws.declareTrailer(trailerKey)
		rws.handlerHeader[textproto.CanonicalMIMEHeaderKey(trailerKey)] = vv
	}

	// Sorted so the trailer order does not depend on map iteration order.
	if len(rws.trailers) > 1 {
		slices.Sort(rws.trailers)
	}
}
2705
2706 func (w *responseWriter) SetReadDeadline(deadline time.Time) error {
2707 st := w.rws.stream
2708 if !deadline.IsZero() && deadline.Before(time.Now()) {
2709
2710
2711 st.onReadTimeout()
2712 return nil
2713 }
2714 w.rws.conn.sendServeMsg(func(sc *serverConn) {
2715 if st.readDeadline != nil {
2716 if !st.readDeadline.Stop() {
2717
2718 return
2719 }
2720 }
2721 if deadline.IsZero() {
2722 st.readDeadline = nil
2723 } else if st.readDeadline == nil {
2724 st.readDeadline = time.AfterFunc(deadline.Sub(time.Now()), st.onReadTimeout)
2725 } else {
2726 st.readDeadline.Reset(deadline.Sub(time.Now()))
2727 }
2728 })
2729 return nil
2730 }
2731
2732 func (w *responseWriter) SetWriteDeadline(deadline time.Time) error {
2733 st := w.rws.stream
2734 if !deadline.IsZero() && deadline.Before(time.Now()) {
2735
2736
2737 st.onWriteTimeout()
2738 return nil
2739 }
2740 w.rws.conn.sendServeMsg(func(sc *serverConn) {
2741 if st.writeDeadline != nil {
2742 if !st.writeDeadline.Stop() {
2743
2744 return
2745 }
2746 }
2747 if deadline.IsZero() {
2748 st.writeDeadline = nil
2749 } else if st.writeDeadline == nil {
2750 st.writeDeadline = time.AfterFunc(deadline.Sub(time.Now()), st.onWriteTimeout)
2751 } else {
2752 st.writeDeadline.Reset(deadline.Sub(time.Now()))
2753 }
2754 })
2755 return nil
2756 }
2757
// EnableFullDuplex is a no-op that always returns nil.
// NOTE(review): presumably because an HTTP/2 stream already allows reading
// the request body while writing the response — confirm.
func (w *responseWriter) EnableFullDuplex() error {
	return nil
}
2762
// Flush is FlushError with the error discarded.
func (w *responseWriter) Flush() {
	w.FlushError()
}
2766
2767 func (w *responseWriter) FlushError() error {
2768 rws := w.rws
2769 if rws == nil {
2770 panic("Header called after Handler finished")
2771 }
2772 var err error
2773 if rws.bw.Buffered() > 0 {
2774 err = rws.bw.Flush()
2775 } else {
2776
2777
2778
2779
2780 _, err = chunkWriter{rws}.Write(nil)
2781 if err == nil {
2782 select {
2783 case <-rws.stream.cw:
2784 err = rws.stream.closeErr
2785 default:
2786 }
2787 }
2788 }
2789 return err
2790 }
2791
// CloseNotify returns a channel that receives true once this stream is
// closed. The channel (buffered, capacity 1) and its waiting goroutine are
// created on the first call, under closeNotifierMu; later calls return the
// same channel. It panics if called after the handler has returned.
func (w *responseWriter) CloseNotify() <-chan bool {
	rws := w.rws
	if rws == nil {
		panic("CloseNotify called after Handler finished")
	}
	rws.closeNotifierMu.Lock()
	ch := rws.closeNotifierCh
	if ch == nil {
		ch = make(chan bool, 1)
		rws.closeNotifierCh = ch
		cw := rws.stream.cw
		// Waits for the stream's close signal (closeStream closes cw).
		go func() {
			cw.Wait()
			ch <- true
		}()
	}
	rws.closeNotifierMu.Unlock()
	return ch
}
2811
2812 func (w *responseWriter) Header() Header {
2813 rws := w.rws
2814 if rws == nil {
2815 panic("Header called after Handler finished")
2816 }
2817 if rws.handlerHeader == nil {
2818 rws.handlerHeader = make(Header)
2819 }
2820 return rws.handlerHeader
2821 }
2822
2823
// checkWriteHeaderCode panics unless code is a three-digit status code
// (100 through 999 inclusive). Any value in that range is accepted, even
// if it is not a registered HTTP status.
func checkWriteHeaderCode(code int) {
	if code >= 100 && code <= 999 {
		return
	}
	panic(fmt.Sprintf("invalid WriteHeader code %v", code))
}
2839
2840 func (w *responseWriter) WriteHeader(code int) {
2841 rws := w.rws
2842 if rws == nil {
2843 panic("WriteHeader called after Handler finished")
2844 }
2845 rws.writeHeader(code)
2846 }
2847
// writeHeader records a response status code. Once a final (non-1xx) status
// has been recorded, every later call is ignored. Otherwise code is
// validated first with checkWriteHeaderCode, which panics outside 100-999.
//
// A 1xx code is sent at once as an informational HEADERS frame. It uses the
// current handler headers, with Content-Length and Transfer-Encoding
// removed from a copy, and does not mark the header written, so a final
// status can still follow. Any error from that write is ignored.
//
// For other codes the status is stored, and the handler headers are
// snapshotted into snapHeader (left nil when there are none) for writeChunk.
func (rws *responseWriterState) writeHeader(code int) {
	if rws.wroteHeader {
		return
	}

	checkWriteHeaderCode(code)

	if code >= 100 && code <= 199 {
		// Send the handler's live header map; copy it only when fields must
		// be stripped, so the handler's own map is never modified.
		h := rws.handlerHeader

		_, cl := h["Content-Length"]
		_, te := h["Transfer-Encoding"]
		if cl || te {
			h = cloneHeader(h)
			h.Del("Content-Length")
			h.Del("Transfer-Encoding")
		}

		rws.conn.writeHeaders(rws.stream, &writeResHeaders{
			streamID:    rws.stream.id,
			httpResCode: code,
			h:           h,
			endStream:   rws.handlerDone && !rws.hasTrailers(),
		})

		return
	}

	rws.wroteHeader = true
	rws.status = code
	if len(rws.handlerHeader) > 0 {
		rws.snapHeader = cloneHeader(rws.handlerHeader)
	}
}
2884
2885 func cloneHeader(h Header) Header {
2886 h2 := make(Header, len(h))
2887 for k, vv := range h {
2888 vv2 := make([]string, len(vv))
2889 copy(vv2, vv)
2890 h2[k] = vv2
2891 }
2892 return h2
2893 }
2894
2895
2896
2897
2898
2899
2900
2901
2902
2903 func (w *responseWriter) Write(p []byte) (n int, err error) {
2904 return w.write(len(p), p, "")
2905 }
2906
2907 func (w *responseWriter) WriteString(s string) (n int, err error) {
2908 return w.write(len(s), nil, s)
2909 }
2910
2911
2912 func (w *responseWriter) write(lenData int, dataB []byte, dataS string) (n int, err error) {
2913 rws := w.rws
2914 if rws == nil {
2915 panic("Write called after Handler finished")
2916 }
2917 if !rws.wroteHeader {
2918 w.WriteHeader(200)
2919 }
2920 if !bodyAllowedForStatus(rws.status) {
2921 return 0, ErrBodyNotAllowed
2922 }
2923 rws.wroteBytes += int64(len(dataB)) + int64(len(dataS))
2924 if rws.sentContentLen != 0 && rws.wroteBytes > rws.sentContentLen {
2925
2926 return 0, errors.New("http2: handler wrote more than declared Content-Length")
2927 }
2928
2929 if dataB != nil {
2930 return rws.bw.Write(dataB)
2931 } else {
2932 return rws.bw.WriteString(dataS)
2933 }
2934 }
2935
2936 func (w *responseWriter) handlerDone() {
2937 rws := w.rws
2938 rws.handlerDone = true
2939 w.Flush()
2940 w.rws = nil
2941 responseWriterStatePool.Put(rws)
2942 }
2943
2944
var (
	// ErrRecursivePush is returned by Push when it is called on a stream
	// that was itself created by a server push.
	ErrRecursivePush = errors.New("http2: recursive push not allowed")

	// ErrPushLimitReached is returned when a push would exceed the peer's
	// SETTINGS_MAX_CONCURRENT_STREAMS, or when the connection has run out of
	// server-initiated stream IDs.
	ErrPushLimitReached = errors.New("http2: push would exceed peer's SETTINGS_MAX_CONCURRENT_STREAMS")
)
2949
2950 func (w *responseWriter) Push(target, method string, header Header) error {
2951 st := w.rws.stream
2952 sc := st.sc
2953 sc.serveG.checkNotOn()
2954
2955
2956
2957 if st.isPushed() {
2958 return ErrRecursivePush
2959 }
2960
2961
2962 if method == "" {
2963 method = "GET"
2964 }
2965 if header == nil {
2966 header = Header{}
2967 }
2968 wantScheme := "http"
2969 if w.rws.req.TLS != nil {
2970 wantScheme = "https"
2971 }
2972
2973
2974 u, err := url.Parse(target)
2975 if err != nil {
2976 return err
2977 }
2978 if u.Scheme == "" {
2979 if !strings.HasPrefix(target, "/") {
2980 return fmt.Errorf("target must be an absolute URL or an absolute path: %q", target)
2981 }
2982 u.Scheme = wantScheme
2983 u.Host = w.rws.req.Host
2984 } else {
2985 if u.Scheme != wantScheme {
2986 return fmt.Errorf("cannot push URL with scheme %q from request with scheme %q", u.Scheme, wantScheme)
2987 }
2988 if u.Host == "" {
2989 return errors.New("URL must have a host")
2990 }
2991 }
2992 for k := range header {
2993 if strings.HasPrefix(k, ":") {
2994 return fmt.Errorf("promised request headers cannot include pseudo header %q", k)
2995 }
2996
2997
2998
2999
3000 if asciiEqualFold(k, "content-length") ||
3001 asciiEqualFold(k, "content-encoding") ||
3002 asciiEqualFold(k, "trailer") ||
3003 asciiEqualFold(k, "te") ||
3004 asciiEqualFold(k, "expect") ||
3005 asciiEqualFold(k, "host") {
3006 return fmt.Errorf("promised request headers cannot include %q", k)
3007 }
3008 }
3009 if err := checkValidHTTP2RequestHeaders(header); err != nil {
3010 return err
3011 }
3012
3013
3014
3015
3016 if method != "GET" && method != "HEAD" {
3017 return fmt.Errorf("method %q must be GET or HEAD", method)
3018 }
3019
3020 msg := &startPushRequest{
3021 parent: st,
3022 method: method,
3023 url: u,
3024 header: cloneHeader(header),
3025 done: sc.srv.getErrChan(),
3026 }
3027
3028 select {
3029 case <-sc.doneServing:
3030 return errClientDisconnected
3031 case <-st.cw:
3032 return errStreamClosed
3033 case sc.serveMsgCh <- msg:
3034 }
3035
3036 select {
3037 case <-sc.doneServing:
3038 return errClientDisconnected
3039 case <-st.cw:
3040 return errStreamClosed
3041 case err := <-msg.done:
3042 sc.srv.putErrChan(msg.done)
3043 return err
3044 }
3045 }
3046
// startPushRequest is the message responseWriter.Push sends on serveMsgCh
// to ask the serve goroutine to start a server push (see startPush).
type startPushRequest struct {
	parent *stream    // stream the PUSH_PROMISE is sent on
	method string     // method of the promised request ("GET" or "HEAD")
	url    *url.URL   // absolute URL of the promised request
	header Header     // promised request headers (a copy owned by this message)
	done   chan error // receives the outcome of the push attempt
}
3054
// startPush runs on the serve goroutine and handles a startPushRequest sent
// by responseWriter.Push.
//
// It first checks that the parent stream can still carry a PUSH_PROMISE and
// that push is enabled. On failure it reports the error on msg.done directly.
// On success it queues a writePushPromise frame whose done channel is
// msg.done. The promised stream ID is not allocated here. That is deferred
// to allocatePromisedID, which is handed to the frame writer.
//
// NOTE(review): the direct sends on msg.done assume the channel from
// getErrChan has buffer space. Otherwise the serve goroutine could block if
// Push has already given up waiting. Confirm.
func (sc *serverConn) startPush(msg *startPushRequest) {
	sc.serveG.check()

	// A PUSH_PROMISE may only be sent on a stream that is open or
	// half-closed (remote). Push has already rejected pushed parent streams.
	if msg.parent.state != stateOpen && msg.parent.state != stateHalfClosedRemote {
		msg.done <- errStreamClosed
		return
	}

	// Push is not enabled on this connection.
	if !sc.pushEnabled {
		msg.done <- ErrNotSupported
		return
	}

	// allocatePromisedID assigns the promised stream ID, creates the promised
	// stream and starts its handler. It is passed to writePushPromise instead
	// of being called now, presumably so that IDs are assigned in the order
	// the PUSH_PROMISE frames are actually written. TODO confirm in
	// writePushPromise.
	allocatePromisedID := func() (uint32, error) {
		sc.serveG.check()

		// Re-check: push may have been disabled between queueing the frame
		// and writing it.
		if !sc.pushEnabled {
			return 0, ErrNotSupported
		}

		// Respect the peer's limit on concurrent pushed streams.
		if sc.curPushedStreams+1 > sc.clientMaxStreams {
			return 0, ErrPushLimitReached
		}

		// Promised IDs advance by 2 and must stay below 1<<31. When the ID
		// space is exhausted, start a graceful shutdown of the connection
		// and fail this push.
		if sc.maxPushPromiseID+2 >= 1<<31 {
			sc.startGracefulShutdownInternal()
			return 0, ErrPushLimitReached
		}
		sc.maxPushPromiseID += 2
		promisedID := sc.maxPushPromiseID

		// The promised stream is created directly in the half-closed
		// (remote) state, and its request is built without a body.
		promised := sc.newStream(promisedID, msg.parent.id, stateHalfClosedRemote, defaultRFC9218Priority(sc.priorityAware && !sc.hasIntermediary))
		rw, req, err := sc.newWriterAndRequestNoBody(promised, httpcommon.ServerRequestParam{
			Method:    msg.method,
			Scheme:    msg.url.Scheme,
			Authority: msg.url.Host,
			Path:      msg.url.RequestURI(),
			Header:    cloneHeader(msg.header),
		})
		if err != nil {
			// Push has already validated the URL, method and headers, so a
			// failure here is treated as a programming error.
			panic(fmt.Sprintf("newWriterAndRequestNoBody(%+v): %v", msg.url, err))
		}

		// Count the handler and run it on its own goroutine.
		sc.curHandlers++
		go sc.runHandler(rw, req, sc.handler.ServeHTTP)
		return promisedID, nil
	}

	// Queue the PUSH_PROMISE on the parent stream. The push outcome is
	// delivered on msg.done through the frame write's done channel.
	sc.writeFrame(FrameWriteRequest{
		write: &writePushPromise{
			streamID:           msg.parent.id,
			method:             msg.method,
			url:                msg.url,
			h:                  msg.header,
			allocatePromisedID: allocatePromisedID,
		},
		stream: msg.parent,
		done:   msg.done,
	})
}
3135
3136
3137
// foreachHeaderElement splits the comma-separated header value v into its
// elements. It calls fn once per element, in order, after trimming
// surrounding ASCII whitespace. Empty elements, including an entirely blank
// v, are skipped.
func foreachHeaderElement(v string, fn func(string)) {
	rest := textproto.TrimString(v)
	for rest != "" {
		elem, tail, found := strings.Cut(rest, ",")
		if elem = textproto.TrimString(elem); elem != "" {
			fn(elem)
		}
		if !found {
			break
		}
		rest = tail
	}
}
3153
3154
3155 var connHeaders = []string{
3156 "Connection",
3157 "Keep-Alive",
3158 "Proxy-Connection",
3159 "Transfer-Encoding",
3160 "Upgrade",
3161 }
3162
3163
3164
3165
3166 func checkValidHTTP2RequestHeaders(h Header) error {
3167 for _, k := range connHeaders {
3168 if _, ok := h[k]; ok {
3169 return fmt.Errorf("request header %q is not valid in HTTP/2", k)
3170 }
3171 }
3172 te := h["Te"]
3173 if len(te) > 0 && (len(te) > 1 || (te[0] != "trailers" && te[0] != "")) {
3174 return errors.New(`request header "TE" may only be "trailers" in HTTP/2`)
3175 }
3176 return nil
3177 }
3178
// serve400Handler is a handler that replies with status 400 and a plain-text
// body containing err's message (see its ServeHTTP method).
type serve400Handler struct {
	err error // reported to the client as the response body
}
3182
3183 func (handler serve400Handler) ServeHTTP(w *ResponseWriter, r *ServerRequest) {
3184 const statusBadRequest = 400
3185
3186
3187 h := w.Header()
3188 h.Del("Content-Length")
3189 h.Set("Content-Type", "text/plain; charset=utf-8")
3190 h.Set("X-Content-Type-Options", "nosniff")
3191 w.WriteHeader(statusBadRequest)
3192 fmt.Fprintln(w, handler.err.Error())
3193 }
3194
3195
3196
3197
3198 func h1ServerKeepAlivesDisabled(hs ServerConfig) bool {
3199 return !hs.DoKeepAlives()
3200 }
3201
3202 func (sc *serverConn) countError(name string, err error) error {
3203 if sc == nil || sc.srv == nil {
3204 return err
3205 }
3206 f := sc.countErrorFunc
3207 if f == nil {
3208 return err
3209 }
3210 var typ string
3211 var code ErrCode
3212 switch e := err.(type) {
3213 case ConnectionError:
3214 typ = "conn"
3215 code = ErrCode(e)
3216 case StreamError:
3217 typ = "stream"
3218 code = ErrCode(e.Code)
3219 default:
3220 return err
3221 }
3222 codeStr := errCodeName[code]
3223 if codeStr == "" {
3224 codeStr = strconv.Itoa(int(code))
3225 }
3226 f(fmt.Sprintf("%s_%s_%s", typ, codeStr, name))
3227 return err
3228 }
3229
View as plain text