1
2
3
4
5
6
7 package tls
8
9 import (
10 "bytes"
11 "context"
12 "crypto/cipher"
13 "crypto/subtle"
14 "crypto/x509"
15 "errors"
16 "fmt"
17 "hash"
18 "internal/godebug"
19 "io"
20 "net"
21 "sync"
22 "sync/atomic"
23 "time"
24 )
25
26
27
// Conn represents a secured connection layered over conn.
//
// NOTE(review): field descriptions below are derived from how the fields are
// used in this part of the file only; fields not touched here are described
// conservatively.
type Conn struct {
	// conn is the underlying transport; addresses, deadlines and raw reads
	// and writes all go through it.
	conn net.Conn
	// isClient reports whether this side is the client; handleRenegotiation
	// refuses renegotiation when it is false.
	isClient bool
	// handshakeFn is invoked by handshakeContext (with handshakeMutex and
	// c.in held) to run the handshake.
	handshakeFn func(context.Context) error
	// quic is non-nil when the connection is driven by a QUIC transport; the
	// TLS record layer is then bypassed or rejected.
	quic *quicState

	// isHandshakeComplete is read without locks as a fast path by
	// handshakeContext and cleared before a renegotiation.
	isHandshakeComplete atomic.Bool

	// handshakeMutex serializes handshakes (initial and renegotiation).
	handshakeMutex sync.Mutex
	// handshakeErr is the result of the last handshake; it is returned to
	// every later handshakeContext call once set.
	handshakeErr error
	// vers is the record-layer protocol version; 0 before it is chosen.
	vers uint16
	// haveVers reports whether vers has been set; record version checks
	// are only enforced once it is true.
	haveVers bool
	config   *Config

	// handshakes counts successfully completed handshakes, including
	// renegotiations; RenegotiateOnceAsClient consults it.
	handshakes int
	// Negotiated session state. These are not read in this part of the
	// file (NOTE(review): populated by the handshake code — confirm).
	extMasterSecret  bool
	didResume        bool
	didHRR           bool
	// cipherSuite is the negotiated suite ID; used to look up the TLS 1.3
	// suite for KeyUpdate and for the QUIC application read secret.
	cipherSuite      uint16
	curveID          CurveID
	peerSigAlg       SignatureScheme
	ocspResponse     []byte
	scts             [][]byte
	peerCertificates []*x509.Certificate

	verifiedChains [][]*x509.Certificate

	serverName string

	secureRenegotiation bool

	ekm func(label string, context []byte, length int) ([]byte, error)

	resumptionSecret []byte
	echAccepted      bool

	ticketKeys []ticketKey

	clientFinishedIsFirst bool

	// closeNotifyErr is the result of the first close_notify attempt,
	// returned by every later closeNotify call.
	closeNotifyErr error
	// closeNotifySent is set once close_notify has been attempted; Write
	// fails with errShutdown afterwards.
	closeNotifySent bool

	// NOTE(review): presumably the Finished verify_data of each side —
	// confirm against the handshake code.
	clientFinished [12]byte
	serverFinished [12]byte

	clientProtocol string

	// in and out are the read and write halves of the record layer; each
	// carries its own mutex, held by Read/Write respectively.
	in, out halfConn
	// rawInput buffers raw bytes read from conn; records are sliced from it.
	rawInput bytes.Buffer
	// input holds the application data of the current record; it aliases
	// rawInput's memory until drained.
	input bytes.Reader
	// hand accumulates handshake message bytes across records.
	hand bytes.Buffer
	// buffering, when set, makes write append to sendBuf instead of
	// writing to conn; flush sends and clears it.
	buffering bool
	sendBuf   []byte

	// bytesSent and packetsSent drive dynamic record sizing in
	// maxPayloadSizeForWrite.
	bytesSent   int64
	packetsSent int64

	// retryCount counts consecutive records (or TLS 1.3 post-handshake
	// messages) that did not advance the state; capped by maxUselessRecords.
	retryCount int

	// activeCall interlocks Write and Close: bit 0 is set by Close, and
	// each in-flight Write adds 2.
	activeCall atomic.Int32

	// tmp is scratch space; sendAlertLocked builds the 2-byte alert in it.
	tmp [16]byte
}
125
126
127
128
129
130
131 func (c *Conn) LocalAddr() net.Addr {
132 return c.conn.LocalAddr()
133 }
134
135
136 func (c *Conn) RemoteAddr() net.Addr {
137 return c.conn.RemoteAddr()
138 }
139
140
141
142
143 func (c *Conn) SetDeadline(t time.Time) error {
144 return c.conn.SetDeadline(t)
145 }
146
147
148
149 func (c *Conn) SetReadDeadline(t time.Time) error {
150 return c.conn.SetReadDeadline(t)
151 }
152
153
154
155
156 func (c *Conn) SetWriteDeadline(t time.Time) error {
157 return c.conn.SetWriteDeadline(t)
158 }
159
160
161
162
163 func (c *Conn) NetConn() net.Conn {
164 return c.conn
165 }
166
167
168
// halfConn holds the record-layer state for one direction of a Conn
// (c.in for reading, c.out for writing).
type halfConn struct {
	// The embedded mutex guards the fields below; methods with a "Locked"
	// suffix expect the caller to hold it.
	sync.Mutex

	// err is a sticky error: once set, reads (or writes) on this half fail
	// with it. See setErrorLocked.
	err     error
	version uint16 // protocol version, set by prepareCipherSpec
	// cipher is nil, a cipher.Stream, an aead, or a cbcMode; any other type
	// makes the record functions panic.
	cipher any
	// mac is the record MAC for stream and CBC suites; nil for AEADs.
	mac hash.Hash
	// seq is the 64-bit big-endian record sequence number (see incSeq).
	seq [8]byte

	// scratchBuf avoids allocations when building additional data / MAC
	// input: 8 bytes of seq plus the 5-byte record header.
	scratchBuf [13]byte

	// nextCipher and nextMac are staged by prepareCipherSpec and activated
	// by changeCipherSpec.
	nextCipher any
	nextMac    hash.Hash

	// level and trafficSecret record the TLS 1.3 traffic secret installed
	// by setTrafficSecret; the secret is advanced on KeyUpdate.
	level         QUICEncryptionLevel
	trafficSecret []byte
}
186
// permanentError wraps a net.Error so that it is never reported as
// temporary. Its message, timeout status, and identity (via Unwrap) are
// those of the wrapped error.
type permanentError struct {
	err net.Error
}

// Error returns the wrapped error's message unchanged.
func (e *permanentError) Error() string {
	inner := e.err
	return inner.Error()
}

// Unwrap exposes the wrapped error to errors.Is and errors.As.
func (e *permanentError) Unwrap() error {
	inner := e.err
	return inner
}

// Timeout forwards to the wrapped error.
func (e *permanentError) Timeout() bool {
	inner := e.err
	return inner.Timeout()
}

// Temporary always reports false: the failure must not be retried.
func (e *permanentError) Temporary() bool {
	const retryable = false
	return retryable
}
195
196 func (hc *halfConn) setErrorLocked(err error) error {
197 if e, ok := err.(net.Error); ok {
198 hc.err = &permanentError{err: e}
199 } else {
200 hc.err = err
201 }
202 return hc.err
203 }
204
205
206
207 func (hc *halfConn) prepareCipherSpec(version uint16, cipher any, mac hash.Hash) {
208 hc.version = version
209 hc.nextCipher = cipher
210 hc.nextMac = mac
211 }
212
213
214
215 func (hc *halfConn) changeCipherSpec() error {
216 if hc.nextCipher == nil || hc.version == VersionTLS13 {
217 return alertInternalError
218 }
219 hc.cipher = hc.nextCipher
220 hc.mac = hc.nextMac
221 hc.nextCipher = nil
222 hc.nextMac = nil
223 for i := range hc.seq {
224 hc.seq[i] = 0
225 }
226 return nil
227 }
228
229 func (hc *halfConn) setTrafficSecret(suite *cipherSuiteTLS13, level QUICEncryptionLevel, secret []byte) {
230 hc.trafficSecret = secret
231 hc.level = level
232 key, iv := suite.trafficKey(secret)
233 hc.cipher = suite.aead(key, iv)
234 for i := range hc.seq {
235 hc.seq[i] = 0
236 }
237 }
238
239
240 func (hc *halfConn) incSeq() {
241 for i := 7; i >= 0; i-- {
242 hc.seq[i]++
243 if hc.seq[i] != 0 {
244 return
245 }
246 }
247
248
249
250
251 panic("TLS: sequence number wraparound")
252 }
253
254
255
256
257 func (hc *halfConn) explicitNonceLen() int {
258 if hc.cipher == nil {
259 return 0
260 }
261
262 switch c := hc.cipher.(type) {
263 case cipher.Stream:
264 return 0
265 case aead:
266 return c.explicitNonceLen()
267 case cbcMode:
268
269 if hc.version >= VersionTLS11 {
270 return c.BlockSize()
271 }
272 return 0
273 default:
274 panic("unknown cipher type")
275 }
276 }
277
278
279
280
// extractPadding inspects the CBC padding at the end of a decrypted payload
// without data-dependent branches on the (secret) padding length.
//
// It returns toRemove, the number of trailing bytes to strip (the padding
// bytes plus the padding-length byte itself), and good, which is 255 if the
// padding is well formed and 0 otherwise. On bad padding the padding length
// is forced to zero, so toRemove is 1.
func extractPadding(payload []byte) (toRemove int, good byte) {
	// The payload length is public, so branching on it is fine.
	if len(payload) < 1 {
		return 0, 0
	}

	paddingLen := payload[len(payload)-1]
	t := uint(len(payload)-1) - uint(paddingLen)
	// t wraps to a huge value (high bits set) iff paddingLen > len(payload)-1,
	// so good starts as 0xFF iff the claimed padding fits in the payload.
	good = byte(int32(^t) >> 31)

	// At most 255 padding bytes plus the length byte can be involved.
	toCheck := 256
	// Again, len(payload) is public, so this branch leaks nothing.
	if toCheck > len(payload) {
		toCheck = len(payload)
	}

	// Always scan the same number of bytes; only bytes at positions
	// i <= paddingLen (selected by mask) must equal paddingLen.
	for i := 0; i < toCheck; i++ {
		t := uint(paddingLen) - uint(i)
		// mask is 0xFF iff i <= paddingLen.
		mask := byte(int32(^t) >> 31)
		b := payload[len(payload)-1-i]
		good &^= mask&paddingLen ^ mask&b
	}

	// AND all bits of good together into the top bit, then replicate it
	// across the byte: the result is 0xFF iff every bit was still set.
	good &= good << 4
	good &= good << 2
	good &= good << 1
	good = uint8(int8(good) >> 7)

	// Zero the padding length on error. The caller then checks the MAC over
	// (almost) the whole payload, and the MAC failure is reported in the
	// same way as a padding failure, so the two cannot be distinguished.
	paddingLen &= good

	toRemove = int(paddingLen) + 1
	return
}
327
// roundUp returns the smallest multiple of b that is >= a, for non-negative
// a and positive b. b must be non-zero.
func roundUp(a, b int) int {
	gap := (b - a%b) % b
	return a + gap
}
331
332
// cbcMode is a CBC cipher.BlockMode whose IV can be replaced between
// records. The record layer calls SetIV with a record's explicit IV before
// running CryptBlocks over that record.
type cbcMode interface {
	SetIV([]byte)
	cipher.BlockMode
}
337
338
339
// decrypt authenticates and decrypts record (a complete record, header
// included) in place. It returns the plaintext and the effective record
// type; for TLS 1.3 the real content type is recovered from the end of the
// decrypted inner plaintext.
//
// Errors are alert values: alertBadRecordMAC for authentication, length and
// padding failures, and alertUnexpectedMessage / alertRecordOverflow for a
// malformed TLS 1.3 inner plaintext. On success the read sequence number is
// incremented.
//
// The returned plaintext aliases record's memory. For MAC-based suites,
// record[3:5] is overwritten with the plaintext length fed to the MAC.
func (hc *halfConn) decrypt(record []byte) ([]byte, recordType, error) {
	var plaintext []byte
	typ := recordType(record[0])
	payload := record[recordHeaderLen:]

	// In TLS 1.3, ChangeCipherSpec records are sent unprotected and are
	// passed through without decryption (RFC 8446, Appendix D.4).
	if hc.version == VersionTLS13 && typ == recordTypeChangeCipherSpec {
		return payload, typ, nil
	}

	paddingGood := byte(255)
	paddingLen := 0

	explicitNonceLen := hc.explicitNonceLen()

	if hc.cipher != nil {
		switch c := hc.cipher.(type) {
		case cipher.Stream:
			c.XORKeyStream(payload, payload)
		case aead:
			if len(payload) < explicitNonceLen {
				return nil, 0, alertBadRecordMAC
			}
			// Use the explicit nonce from the record if the suite has one,
			// otherwise the implicit sequence number.
			nonce := payload[:explicitNonceLen]
			if len(nonce) == 0 {
				nonce = hc.seq[:]
			}
			payload = payload[explicitNonceLen:]

			var additionalData []byte
			if hc.version == VersionTLS13 {
				// TLS 1.3: the additional data is the record header as sent.
				additionalData = record[:recordHeaderLen]
			} else {
				// TLS 1.2: seq || type || version || plaintext length.
				additionalData = append(hc.scratchBuf[:0], hc.seq[:]...)
				additionalData = append(additionalData, record[:3]...)
				n := len(payload) - c.Overhead()
				additionalData = append(additionalData, byte(n>>8), byte(n))
			}

			var err error
			plaintext, err = c.Open(payload[:0], nonce, payload, additionalData)
			if err != nil {
				return nil, 0, alertBadRecordMAC
			}
		case cbcMode:
			blockSize := c.BlockSize()
			// Must hold at least the explicit IV plus enough whole blocks for
			// the MAC and one padding byte; lengths here are public.
			minPayload := explicitNonceLen + roundUp(hc.mac.Size()+1, blockSize)
			if len(payload)%blockSize != 0 || len(payload) < minPayload {
				return nil, 0, alertBadRecordMAC
			}

			if explicitNonceLen > 0 {
				c.SetIV(payload[:explicitNonceLen])
				payload = payload[explicitNonceLen:]
			}
			c.CryptBlocks(payload, payload)

			// Padding is checked in constant time; the verdict is only acted
			// on together with the MAC check below. The bytes past the
			// (secret) padding length are also passed to tls10MAC below,
			// presumably so MAC work does not depend on the padding length
			// (Lucky13 mitigation) — see tls10MAC.
			paddingLen, paddingGood = extractPadding(payload)
		default:
			panic("unknown cipher type")
		}

		if hc.version == VersionTLS13 {
			// All protected TLS 1.3 records carry an outer application_data type.
			if typ != recordTypeApplicationData {
				return nil, 0, alertUnexpectedMessage
			}
			// The inner plaintext may be one byte longer: the content type.
			if len(plaintext) > maxPlaintext+1 {
				return nil, 0, alertRecordOverflow
			}
			// Strip zero padding and take the last non-zero byte as the real
			// content type. An all-zero inner plaintext is invalid.
			for i := len(plaintext) - 1; i >= 0; i-- {
				if plaintext[i] != 0 {
					typ = recordType(plaintext[i])
					plaintext = plaintext[:i]
					break
				}
				if i == 0 {
					return nil, 0, alertUnexpectedMessage
				}
			}
		}
	} else {
		plaintext = payload
	}

	if hc.mac != nil {
		macSize := hc.mac.Size()
		if len(payload) < macSize {
			return nil, 0, alertBadRecordMAC
		}

		// n is the plaintext length; clamp to 0 in constant time if the
		// (possibly bogus) padding made it negative.
		n := len(payload) - macSize - paddingLen
		n = subtle.ConstantTimeSelect(int(uint32(n)>>31), 0, n)
		// The MAC covers a header carrying the plaintext length.
		record[3] = byte(n >> 8)
		record[4] = byte(n)
		remoteMAC := payload[n : n+macSize]
		localMAC := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload[:n], payload[n+macSize:])

		// Combine the MAC and padding verdicts in constant time so that a
		// padding failure is indistinguishable from a MAC failure.
		macAndPaddingGood := subtle.ConstantTimeCompare(localMAC, remoteMAC) & int(paddingGood)
		if macAndPaddingGood != 1 {
			return nil, 0, alertBadRecordMAC
		}

		plaintext = payload[:n]
	}

	hc.incSeq()
	return plaintext, typ, nil
}
463
464
465
466
// sliceForAppend extends in by n bytes. The backing array is reused when its
// capacity allows; otherwise the contents are copied into a new array.
// head is the extended slice and tail is its final n bytes, ready to fill.
func sliceForAppend(in []byte, n int) (head, tail []byte) {
	total := len(in) + n
	if cap(in) < total {
		head = make([]byte, total)
		copy(head, in)
	} else {
		head = in[:total]
	}
	return head, head[len(in):]
}
477
478
479
// encrypt protects payload and appends the result to record, which must
// already hold the 5-byte record header, returning the extended record.
//
// With no cipher, payload is appended unchanged, and neither the header
// length nor the sequence number is touched. Otherwise record[3:5] is
// rewritten with the final protected length and the write sequence number
// is incremented. rand supplies explicit CBC IVs and any explicit nonce of
// 16 bytes or more; shorter AEAD explicit nonces use the sequence number.
func (hc *halfConn) encrypt(record, payload []byte, rand io.Reader) ([]byte, error) {
	if hc.cipher == nil {
		return append(record, payload...), nil
	}

	var explicitNonce []byte
	if explicitNonceLen := hc.explicitNonceLen(); explicitNonceLen > 0 {
		record, explicitNonce = sliceForAppend(record, explicitNonceLen)
		if _, isCBC := hc.cipher.(cbcMode); !isCBC && explicitNonceLen < 16 {
			// Short AEAD explicit nonces (e.g. 8 bytes in TLS 1.2 AES-GCM)
			// are too small to be chosen randomly without collision risk, so
			// the unique sequence number is used instead. CBC IVs, by
			// contrast, must be unpredictable (RFC 5246, Appendix F.3), so
			// they always come from rand.
			copy(explicitNonce, hc.seq[:])
		} else {
			if _, err := io.ReadFull(rand, explicitNonce); err != nil {
				return nil, err
			}
		}
	}

	var dst []byte
	switch c := hc.cipher.(type) {
	case cipher.Stream:
		// MAC-then-encrypt: append the MAC and encrypt payload||MAC.
		mac := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload, nil)
		record, dst = sliceForAppend(record, len(payload)+len(mac))
		c.XORKeyStream(dst[:len(payload)], payload)
		c.XORKeyStream(dst[len(payload):], mac)
	case aead:
		nonce := explicitNonce
		if len(nonce) == 0 {
			nonce = hc.seq[:]
		}

		if hc.version == VersionTLS13 {
			record = append(record, payload...)

			// Encrypt the real content type by appending it to the inner
			// plaintext, and label the outer record application_data.
			record = append(record, record[0])
			record[0] = byte(recordTypeApplicationData)

			// The header (used as additional data) must carry the final
			// ciphertext length: payload + content type + AEAD overhead.
			n := len(payload) + 1 + c.Overhead()
			record[3] = byte(n >> 8)
			record[4] = byte(n)

			record = c.Seal(record[:recordHeaderLen],
				nonce, record[recordHeaderLen:], record[:recordHeaderLen])
		} else {
			// TLS 1.2 additional data: seq || header.
			additionalData := append(hc.scratchBuf[:0], hc.seq[:]...)
			additionalData = append(additionalData, record[:recordHeaderLen]...)
			record = c.Seal(record, nonce, payload, additionalData)
		}
	case cbcMode:
		// MAC-then-encrypt with padding to a whole number of blocks; every
		// padding byte, including the final length byte, is paddingLen-1.
		mac := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload, nil)
		blockSize := c.BlockSize()
		plaintextLen := len(payload) + len(mac)
		paddingLen := blockSize - plaintextLen%blockSize
		record, dst = sliceForAppend(record, plaintextLen+paddingLen)
		copy(dst, payload)
		copy(dst[len(payload):], mac)
		for i := plaintextLen; i < len(dst); i++ {
			dst[i] = byte(paddingLen - 1)
		}
		if len(explicitNonce) > 0 {
			c.SetIV(explicitNonce)
		}
		c.CryptBlocks(dst, dst)
	default:
		panic("unknown cipher type")
	}

	// Update the header length to include nonce, MAC and any padding.
	n := len(record) - recordHeaderLen
	record[3] = byte(n >> 8)
	record[4] = byte(n)
	hc.incSeq()

	return record, nil
}
564
565
// RecordHeaderError is returned when a received TLS record header is
// invalid.
type RecordHeaderError struct {
	// Msg describes what was wrong with the record header.
	Msg string

	// RecordHeader holds the first bytes of buffered raw input at the time
	// of the error, i.e. the offending record header.
	RecordHeader [5]byte

	// Conn is the underlying connection when the very first record did not
	// look like TLS at all; it is nil for other header errors.
	Conn net.Conn
}

// Error implements the error interface, prefixing Msg with "tls: ".
func (e RecordHeaderError) Error() string {
	const prefix = "tls: "
	return prefix + e.Msg
}
580
581 func (c *Conn) newRecordHeaderError(conn net.Conn, msg string) (err RecordHeaderError) {
582 err.Msg = msg
583 err.Conn = conn
584 copy(err.RecordHeader[:], c.rawInput.Bytes())
585 return err
586 }
587
588 func (c *Conn) readRecord() error {
589 return c.readRecordOrCCS(false)
590 }
591
592 func (c *Conn) readChangeCipherSpec() error {
593 return c.readRecordOrCCS(true)
594 }
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
// readRecordOrCCS reads one record from c.conn, decrypts it with c.in, and
// dispatches it by type:
//   - application data is exposed through c.input,
//   - handshake data is appended to c.hand,
//   - ChangeCipherSpec switches c.in to the pending cipher (pre-TLS 1.3),
//   - alerts end the connection or are skipped (warning level).
//
// expectChangeCipherSpec is true when the caller requires the next record
// to be a ChangeCipherSpec. Fatal problems are recorded as c.in's sticky
// error (usually after sending an alert). c.in must be held.
func (c *Conn) readRecordOrCCS(expectChangeCipherSpec bool) error {
	if c.in.err != nil {
		return c.in.err
	}
	handshakeComplete := c.isHandshakeComplete.Load()

	// c.input aliases c.rawInput memory, so rawInput must not be touched
	// while application data is still pending.
	if c.input.Len() != 0 {
		return c.in.setErrorLocked(errors.New("tls: internal error: attempted to read record with pending application data"))
	}
	c.input.Reset(nil)

	if c.quic != nil {
		return c.in.setErrorLocked(errors.New("tls: internal error: attempted to read record with QUIC transport"))
	}

	// Read the record header.
	if err := c.readFromUntil(c.conn, recordHeaderLen); err != nil {
		// An EOF exactly at a record boundary is reported as a clean io.EOF
		// even without close_notify; mid-header it stays unexpected.
		if err == io.ErrUnexpectedEOF && c.rawInput.Len() == 0 {
			err = io.EOF
		}
		// Temporary network errors are returned without becoming sticky.
		if e, ok := err.(net.Error); !ok || !e.Temporary() {
			c.in.setErrorLocked(err)
		}
		return err
	}
	hdr := c.rawInput.Bytes()[:recordHeaderLen]
	typ := recordType(hdr[0])

	// No TLS record type is 0x80; during the handshake it indicates an
	// SSLv2-style ClientHello, which is rejected.
	if !handshakeComplete && typ == 0x80 {
		c.sendAlert(alertProtocolVersion)
		return c.in.setErrorLocked(c.newRecordHeaderError(nil, "unsupported SSLv2 handshake received"))
	}

	vers := uint16(hdr[1])<<8 | uint16(hdr[2])
	expectedVers := c.vers
	if expectedVers == VersionTLS13 {
		// TLS 1.3 records carry the legacy 1.2 version (RFC 8446, Section 5.1).
		expectedVers = VersionTLS12
	}
	n := int(hdr[3])<<8 | int(hdr[4])
	if c.haveVers && vers != expectedVers {
		c.sendAlert(alertProtocolVersion)
		msg := fmt.Sprintf("received record with version %x when expecting version %x", vers, expectedVers)
		return c.in.setErrorLocked(c.newRecordHeaderError(nil, msg))
	}
	if !c.haveVers {
		// First record: be suspicious and bail out before reading a body if
		// it does not look like a TLS alert/handshake. Versions >= 0x1000
		// are implausible. The error carries c.conn for the caller.
		if (typ != recordTypeAlert && typ != recordTypeHandshake) || vers >= 0x1000 {
			return c.in.setErrorLocked(c.newRecordHeaderError(c.conn, "first record does not look like a TLS handshake"))
		}
	}
	if c.vers == VersionTLS13 && n > maxCiphertextTLS13 || n > maxCiphertext {
		c.sendAlert(alertRecordOverflow)
		msg := fmt.Sprintf("oversized record received with length %d", n)
		return c.in.setErrorLocked(c.newRecordHeaderError(nil, msg))
	}
	// Read the record body.
	if err := c.readFromUntil(c.conn, recordHeaderLen+n); err != nil {
		if e, ok := err.(net.Error); !ok || !e.Temporary() {
			c.in.setErrorLocked(err)
		}
		return err
	}

	// Process the record; data aliases c.rawInput's memory.
	record := c.rawInput.Next(recordHeaderLen + n)
	data, typ, err := c.in.decrypt(record)
	if err != nil {
		return c.in.setErrorLocked(c.sendAlert(err.(alert)))
	}
	if len(data) > maxPlaintext {
		return c.in.setErrorLocked(c.sendAlert(alertRecordOverflow))
	}

	// Application data is never accepted unprotected.
	if c.in.cipher == nil && typ == recordTypeApplicationData {
		return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
	}

	if typ != recordTypeAlert && typ != recordTypeChangeCipherSpec && len(data) > 0 {
		// A state-advancing record: reset the useless-record counter.
		c.retryCount = 0
	}

	// In TLS 1.3, handshake messages must not be interleaved with other
	// record types.
	if c.vers == VersionTLS13 && typ != recordTypeHandshake && c.hand.Len() > 0 {
		return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
	}

	switch typ {
	default:
		return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))

	case recordTypeAlert:
		if c.quic != nil {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		if len(data) != 2 {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		if alert(data[1]) == alertCloseNotify {
			return c.in.setErrorLocked(io.EOF)
		}
		if c.vers == VersionTLS13 {
			// TLS 1.3 has no warning-level alerts except user_canceled,
			// which is ignored like a TLS 1.2 warning; every other alert is
			// fatal regardless of the level byte.
			if alert(data[1]) == alertUserCanceled {
				return c.retryReadRecord(expectChangeCipherSpec)
			}
			return c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])})
		}
		switch data[0] {
		case alertLevelWarning:
			// Drop warning alerts and read the next record.
			return c.retryReadRecord(expectChangeCipherSpec)
		case alertLevelError:
			return c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])})
		default:
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}

	case recordTypeChangeCipherSpec:
		if len(data) != 1 || data[0] != 1 {
			return c.in.setErrorLocked(c.sendAlert(alertDecodeError))
		}
		// A handshake message must not span a ChangeCipherSpec.
		if c.hand.Len() > 0 {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		// In TLS 1.3, ChangeCipherSpec records are ignored (RFC 8446,
		// Appendix D.4). One arriving before the version is known is not
		// ignored and falls through to the checks below.
		if c.vers == VersionTLS13 {
			return c.retryReadRecord(expectChangeCipherSpec)
		}
		if !expectChangeCipherSpec {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		if err := c.in.changeCipherSpec(); err != nil {
			return c.in.setErrorLocked(c.sendAlert(err.(alert)))
		}

	case recordTypeApplicationData:
		if !handshakeComplete || expectChangeCipherSpec {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		// Some peers send empty records (e.g. to randomize CBC IVs); skip a
		// bounded number of them.
		if len(data) == 0 {
			return c.retryReadRecord(expectChangeCipherSpec)
		}
		// data is owned by c.rawInput; this avoids a copy and is safe
		// because rawInput is not modified until c.input is drained.
		c.input.Reset(data)

	case recordTypeHandshake:
		if len(data) == 0 || expectChangeCipherSpec {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		c.hand.Write(data)
	}

	return nil
}
792
793
794
795 func (c *Conn) retryReadRecord(expectChangeCipherSpec bool) error {
796 c.retryCount++
797 if c.retryCount > maxUselessRecords {
798 c.sendAlert(alertUnexpectedMessage)
799 return c.in.setErrorLocked(errors.New("tls: too many ignored records"))
800 }
801 return c.readRecordOrCCS(expectChangeCipherSpec)
802 }
803
804
805
806
// atLeastReader wraps R and reports io.EOF as soon as at least N bytes have
// been read from it. This lets bytes.Buffer.ReadFrom stop once enough data
// has arrived. An EOF from R before N bytes is turned into
// io.ErrUnexpectedEOF.
type atLeastReader struct {
	R io.Reader
	N int64
}

// Read reads from R into p and decrements N by the bytes read.
func (r *atLeastReader) Read(p []byte) (int, error) {
	if r.N <= 0 {
		return 0, io.EOF
	}
	n, err := r.R.Read(p)
	r.N -= int64(n)
	switch {
	case r.N > 0 && err == io.EOF:
		return n, io.ErrUnexpectedEOF
	case r.N <= 0 && err == nil:
		return n, io.EOF
	default:
		return n, err
	}
}
826
827
828
829 func (c *Conn) readFromUntil(r io.Reader, n int) error {
830 if c.rawInput.Len() >= n {
831 return nil
832 }
833 needs := n - c.rawInput.Len()
834
835
836
837 c.rawInput.Grow(needs + bytes.MinRead)
838 _, err := c.rawInput.ReadFrom(&atLeastReader{r, int64(needs)})
839 return err
840 }
841
842
843 func (c *Conn) sendAlertLocked(err alert) error {
844 if c.quic != nil {
845 return c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err})
846 }
847
848 switch err {
849 case alertNoRenegotiation, alertCloseNotify:
850 c.tmp[0] = alertLevelWarning
851 default:
852 c.tmp[0] = alertLevelError
853 }
854 c.tmp[1] = byte(err)
855
856 _, writeErr := c.writeRecordLocked(recordTypeAlert, c.tmp[0:2])
857 if err == alertCloseNotify {
858
859 return writeErr
860 }
861
862 return c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err})
863 }
864
865
866 func (c *Conn) sendAlert(err alert) error {
867 c.out.Lock()
868 defer c.out.Unlock()
869 return c.sendAlertLocked(err)
870 }
871
const (
	// tcpMSSEstimate is a fixed, conservative estimate of the TCP maximum
	// segment size, used to size early application-data records so each
	// fits in one segment. NOTE(review): 1208 appears to be the IPv6 minimum
	// MTU (1280) minus a 40-byte IPv6 header and a 32-byte TCP header with
	// options — confirm.
	tcpMSSEstimate = 1208

	// recordSizeBoostThreshold is the number of bytes written to the
	// connection after which records are always sized at maxPlaintext.
	recordSizeBoostThreshold = 128 << 10
)
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902 func (c *Conn) maxPayloadSizeForWrite(typ recordType) int {
903 if c.config.DynamicRecordSizingDisabled || typ != recordTypeApplicationData {
904 return maxPlaintext
905 }
906
907 if c.bytesSent >= recordSizeBoostThreshold {
908 return maxPlaintext
909 }
910
911
912 payloadBytes := tcpMSSEstimate - recordHeaderLen - c.out.explicitNonceLen()
913 if c.out.cipher != nil {
914 switch ciph := c.out.cipher.(type) {
915 case cipher.Stream:
916 payloadBytes -= c.out.mac.Size()
917 case cipher.AEAD:
918 payloadBytes -= ciph.Overhead()
919 case cbcMode:
920 blockSize := ciph.BlockSize()
921
922
923 payloadBytes = (payloadBytes & ^(blockSize - 1)) - 1
924
925
926 payloadBytes -= c.out.mac.Size()
927 default:
928 panic("unknown cipher type")
929 }
930 }
931 if c.vers == VersionTLS13 {
932 payloadBytes--
933 }
934
935
936 pkt := c.packetsSent
937 c.packetsSent++
938 if pkt > 1000 {
939 return maxPlaintext
940 }
941
942 n := payloadBytes * int(pkt+1)
943 if n > maxPlaintext {
944 n = maxPlaintext
945 }
946 return n
947 }
948
949 func (c *Conn) write(data []byte) (int, error) {
950 if c.buffering {
951 c.sendBuf = append(c.sendBuf, data...)
952 return len(data), nil
953 }
954
955 n, err := c.conn.Write(data)
956 c.bytesSent += int64(n)
957 return n, err
958 }
959
960 func (c *Conn) flush() (int, error) {
961 if len(c.sendBuf) == 0 {
962 return 0, nil
963 }
964
965 n, err := c.conn.Write(c.sendBuf)
966 c.bytesSent += int64(n)
967 c.sendBuf = nil
968 c.buffering = false
969 return n, err
970 }
971
972
// outBufPool recycles the buffers writeRecordLocked assembles records in.
// It holds *[]byte rather than []byte because putting a bare slice into the
// pool would box its header on every Put.
var outBufPool = sync.Pool{
	New: func() any {
		var buf []byte
		return &buf
	},
}
978
979
980
// writeRecordLocked writes data as one or more records of type typ. Each
// record is at most maxPayloadSizeForWrite bytes of data and is encrypted
// with c.out. It returns how many bytes of data were consumed.
//
// After a ChangeCipherSpec record (before TLS 1.3), the pending write
// cipher is activated. On QUIC connections only handshake data is
// accepted, and it is handed to the QUIC layer instead of being framed.
// Callers hold c.out.
func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) {
	if c.quic != nil {
		if typ != recordTypeHandshake {
			return 0, errors.New("tls: internal error: sending non-handshake message to QUIC transport")
		}
		c.quicWriteCryptoData(c.out.level, data)
		if !c.buffering {
			if _, err := c.flush(); err != nil {
				return 0, err
			}
		}
		return len(data), nil
	}

	outBufPtr := outBufPool.Get().(*[]byte)
	outBuf := *outBufPtr
	defer func() {
		// Write the (possibly grown) slice back through the pointer obtained
		// from Get instead of passing &outBuf to Put. Taking &outBuf would
		// make this local slice header escape to the heap and allocate on
		// every call.
		*outBufPtr = outBuf
		outBufPool.Put(outBufPtr)
	}()

	var n int
	for len(data) > 0 {
		m := len(data)
		if maxPayload := c.maxPayloadSizeForWrite(typ); m > maxPayload {
			m = maxPayload
		}

		_, outBuf = sliceForAppend(outBuf[:0], recordHeaderLen)
		outBuf[0] = byte(typ)
		vers := c.vers
		if vers == 0 {
			// Before a version is chosen, advertise TLS 1.0 in the record
			// header; some servers reject a higher record version on the
			// initial ClientHello.
			vers = VersionTLS10
		} else if vers == VersionTLS13 {
			// TLS 1.3 freezes the record-layer version at 1.2
			// (RFC 8446, Section 5.1).
			vers = VersionTLS12
		}
		outBuf[1] = byte(vers >> 8)
		outBuf[2] = byte(vers)
		outBuf[3] = byte(m >> 8)
		outBuf[4] = byte(m)

		var err error
		outBuf, err = c.out.encrypt(outBuf, data[:m], c.config.rand())
		if err != nil {
			return n, err
		}
		if _, err := c.write(outBuf); err != nil {
			return n, err
		}
		n += m
		data = data[m:]
	}

	if typ == recordTypeChangeCipherSpec && c.vers != VersionTLS13 {
		if err := c.out.changeCipherSpec(); err != nil {
			return n, c.sendAlertLocked(err.(alert))
		}
	}

	return n, nil
}
1051
1052
1053
1054
1055 func (c *Conn) writeHandshakeRecord(msg handshakeMessage, transcript transcriptHash) (int, error) {
1056 c.out.Lock()
1057 defer c.out.Unlock()
1058
1059 data, err := msg.marshal()
1060 if err != nil {
1061 return 0, err
1062 }
1063 if transcript != nil {
1064 transcript.Write(data)
1065 }
1066
1067 return c.writeRecordLocked(recordTypeHandshake, data)
1068 }
1069
1070
1071
1072 func (c *Conn) writeChangeCipherRecord() error {
1073 c.out.Lock()
1074 defer c.out.Unlock()
1075 _, err := c.writeRecordLocked(recordTypeChangeCipherSpec, []byte{1})
1076 return err
1077 }
1078
1079
1080 func (c *Conn) readHandshakeBytes(n int) error {
1081 if c.quic != nil {
1082 return c.quicReadHandshakeBytes(n)
1083 }
1084 for c.hand.Len() < n {
1085 if err := c.readRecord(); err != nil {
1086 return err
1087 }
1088 }
1089 return nil
1090 }
1091
1092
1093
1094
1095 func (c *Conn) readHandshake(transcript transcriptHash) (any, error) {
1096 if err := c.readHandshakeBytes(4); err != nil {
1097 return nil, err
1098 }
1099 data := c.hand.Bytes()
1100
1101 maxHandshakeSize := maxHandshake
1102
1103
1104
1105 if c.haveVers && data[0] == typeCertificate {
1106
1107
1108
1109 maxHandshakeSize = maxHandshakeCertificateMsg
1110 }
1111
1112 n := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
1113 if n > maxHandshakeSize {
1114 c.sendAlertLocked(alertInternalError)
1115 return nil, c.in.setErrorLocked(fmt.Errorf("tls: handshake message of length %d bytes exceeds maximum of %d bytes", n, maxHandshakeSize))
1116 }
1117 if err := c.readHandshakeBytes(4 + n); err != nil {
1118 return nil, err
1119 }
1120 data = c.hand.Next(4 + n)
1121 return c.unmarshalHandshakeMessage(data, transcript)
1122 }
1123
1124 func (c *Conn) unmarshalHandshakeMessage(data []byte, transcript transcriptHash) (handshakeMessage, error) {
1125 var m handshakeMessage
1126 switch data[0] {
1127 case typeHelloRequest:
1128 m = new(helloRequestMsg)
1129 case typeClientHello:
1130 m = new(clientHelloMsg)
1131 case typeServerHello:
1132 m = new(serverHelloMsg)
1133 case typeNewSessionTicket:
1134 if c.vers == VersionTLS13 {
1135 m = new(newSessionTicketMsgTLS13)
1136 } else {
1137 m = new(newSessionTicketMsg)
1138 }
1139 case typeCertificate:
1140 if c.vers == VersionTLS13 {
1141 m = new(certificateMsgTLS13)
1142 } else {
1143 m = new(certificateMsg)
1144 }
1145 case typeCertificateRequest:
1146 if c.vers == VersionTLS13 {
1147 m = new(certificateRequestMsgTLS13)
1148 } else {
1149 m = &certificateRequestMsg{
1150 hasSignatureAlgorithm: c.vers >= VersionTLS12,
1151 }
1152 }
1153 case typeCertificateStatus:
1154 m = new(certificateStatusMsg)
1155 case typeServerKeyExchange:
1156 m = new(serverKeyExchangeMsg)
1157 case typeServerHelloDone:
1158 m = new(serverHelloDoneMsg)
1159 case typeClientKeyExchange:
1160 m = new(clientKeyExchangeMsg)
1161 case typeCertificateVerify:
1162 m = &certificateVerifyMsg{
1163 hasSignatureAlgorithm: c.vers >= VersionTLS12,
1164 }
1165 case typeFinished:
1166 m = new(finishedMsg)
1167 case typeEncryptedExtensions:
1168 m = new(encryptedExtensionsMsg)
1169 case typeEndOfEarlyData:
1170 m = new(endOfEarlyDataMsg)
1171 case typeKeyUpdate:
1172 m = new(keyUpdateMsg)
1173 default:
1174 return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
1175 }
1176
1177
1178
1179
1180 data = append([]byte(nil), data...)
1181
1182 if !m.unmarshal(data) {
1183 return nil, c.in.setErrorLocked(c.sendAlert(alertDecodeError))
1184 }
1185
1186 if transcript != nil {
1187 transcript.Write(data)
1188 }
1189
1190 return m, nil
1191 }
1192
// errShutdown is returned by Write once a close_notify alert has been sent.
var errShutdown = errors.New("tls: protocol is shutdown")
1196
1197
1198
1199
1200
1201
1202
// Write encrypts b and sends it as application data, running the handshake
// first if it has not completed. It fails with net.ErrClosed after Close,
// with errShutdown after close_notify was sent, and with c.out's sticky
// error once a previous write failed. A write error from this call is
// itself recorded as sticky.
func (c *Conn) Write(b []byte) (int, error) {
	// Interlock with Close: bail out if bit 0 (closed) is set, otherwise
	// register this call by adding 2 to activeCall.
	for {
		x := c.activeCall.Load()
		if x&1 != 0 {
			return 0, net.ErrClosed
		}
		if c.activeCall.CompareAndSwap(x, x+2) {
			break
		}
	}
	defer c.activeCall.Add(-2)

	if err := c.Handshake(); err != nil {
		return 0, err
	}

	c.out.Lock()
	defer c.out.Unlock()

	if err := c.out.err; err != nil {
		return 0, err
	}

	if !c.isHandshakeComplete.Load() {
		return 0, alertInternalError
	}

	if c.closeNotifySent {
		return 0, errShutdown
	}

	// TLS 1.0 CBC has no explicit per-record IV (explicitNonceLen is 0
	// below TLS 1.1), so the IV of each record is the predictable last
	// ciphertext block of the previous one. That enables a
	// chosen-plaintext attack (BEAST). Sending the first byte in its own
	// record effectively randomizes the IV for the remainder ("1/n-1"
	// record splitting).
	var m int
	if len(b) > 1 && c.vers == VersionTLS10 {
		if _, ok := c.out.cipher.(cipher.BlockMode); ok {
			n, err := c.writeRecordLocked(recordTypeApplicationData, b[:1])
			if err != nil {
				return n, c.out.setErrorLocked(err)
			}
			m, b = 1, b[1:]
		}
	}

	n, err := c.writeRecordLocked(recordTypeApplicationData, b)
	return n + m, c.out.setErrorLocked(err)
}
1258
1259
1260 func (c *Conn) handleRenegotiation() error {
1261 if c.vers == VersionTLS13 {
1262 return errors.New("tls: internal error: unexpected renegotiation")
1263 }
1264
1265 msg, err := c.readHandshake(nil)
1266 if err != nil {
1267 return err
1268 }
1269
1270 helloReq, ok := msg.(*helloRequestMsg)
1271 if !ok {
1272 c.sendAlert(alertUnexpectedMessage)
1273 return unexpectedMessageError(helloReq, msg)
1274 }
1275
1276 if !c.isClient {
1277 return c.sendAlert(alertNoRenegotiation)
1278 }
1279
1280 switch c.config.Renegotiation {
1281 case RenegotiateNever:
1282 return c.sendAlert(alertNoRenegotiation)
1283 case RenegotiateOnceAsClient:
1284 if c.handshakes > 1 {
1285 return c.sendAlert(alertNoRenegotiation)
1286 }
1287 case RenegotiateFreelyAsClient:
1288
1289 default:
1290 c.sendAlert(alertInternalError)
1291 return errors.New("tls: unknown Renegotiation value")
1292 }
1293
1294 c.handshakeMutex.Lock()
1295 defer c.handshakeMutex.Unlock()
1296
1297 c.isHandshakeComplete.Store(false)
1298 if c.handshakeErr = c.clientHandshake(context.Background()); c.handshakeErr == nil {
1299 c.handshakes++
1300 }
1301 return c.handshakeErr
1302 }
1303
1304
1305
1306 func (c *Conn) handlePostHandshakeMessage() error {
1307 if c.vers != VersionTLS13 {
1308 return c.handleRenegotiation()
1309 }
1310
1311 msg, err := c.readHandshake(nil)
1312 if err != nil {
1313 return err
1314 }
1315 c.retryCount++
1316 if c.retryCount > maxUselessRecords {
1317 c.sendAlert(alertUnexpectedMessage)
1318 return c.in.setErrorLocked(errors.New("tls: too many non-advancing records"))
1319 }
1320
1321 switch msg := msg.(type) {
1322 case *newSessionTicketMsgTLS13:
1323 return c.handleNewSessionTicket(msg)
1324 case *keyUpdateMsg:
1325 return c.handleKeyUpdate(msg)
1326 }
1327
1328
1329
1330
1331 c.sendAlert(alertUnexpectedMessage)
1332 return fmt.Errorf("tls: received unexpected handshake message of type %T", msg)
1333 }
1334
1335 func (c *Conn) handleKeyUpdate(keyUpdate *keyUpdateMsg) error {
1336 if c.quic != nil {
1337 c.sendAlert(alertUnexpectedMessage)
1338 return c.in.setErrorLocked(errors.New("tls: received unexpected key update message"))
1339 }
1340
1341 cipherSuite := cipherSuiteTLS13ByID(c.cipherSuite)
1342 if cipherSuite == nil {
1343 return c.in.setErrorLocked(c.sendAlert(alertInternalError))
1344 }
1345
1346 newSecret := cipherSuite.nextTrafficSecret(c.in.trafficSecret)
1347 c.in.setTrafficSecret(cipherSuite, QUICEncryptionLevelInitial, newSecret)
1348
1349 if keyUpdate.updateRequested {
1350 c.out.Lock()
1351 defer c.out.Unlock()
1352
1353 msg := &keyUpdateMsg{}
1354 msgBytes, err := msg.marshal()
1355 if err != nil {
1356 return err
1357 }
1358 _, err = c.writeRecordLocked(recordTypeHandshake, msgBytes)
1359 if err != nil {
1360
1361 c.out.setErrorLocked(err)
1362 return nil
1363 }
1364
1365 newSecret := cipherSuite.nextTrafficSecret(c.out.trafficSecret)
1366 c.out.setTrafficSecret(cipherSuite, QUICEncryptionLevelInitial, newSecret)
1367 }
1368
1369 return nil
1370 }
1371
1372
1373
1374
1375
1376
1377
1378 func (c *Conn) Read(b []byte) (int, error) {
1379 if err := c.Handshake(); err != nil {
1380 return 0, err
1381 }
1382 if len(b) == 0 {
1383
1384
1385 return 0, nil
1386 }
1387
1388 c.in.Lock()
1389 defer c.in.Unlock()
1390
1391 for c.input.Len() == 0 {
1392 if err := c.readRecord(); err != nil {
1393 return 0, err
1394 }
1395 for c.hand.Len() > 0 {
1396 if err := c.handlePostHandshakeMessage(); err != nil {
1397 return 0, err
1398 }
1399 }
1400 }
1401
1402 n, _ := c.input.Read(b)
1403
1404
1405
1406
1407
1408
1409
1410
1411 if n != 0 && c.input.Len() == 0 && c.rawInput.Len() > 0 &&
1412 recordType(c.rawInput.Bytes()[0]) == recordTypeAlert {
1413 if err := c.readRecord(); err != nil {
1414 return n, err
1415 }
1416 }
1417
1418 return n, nil
1419 }
1420
1421
// Close closes the connection. If the handshake completed and no Write is
// in flight, a close_notify alert is sent first. Failure to send it is
// returned only if closing the underlying connection succeeds. A second
// Close returns net.ErrClosed.
func (c *Conn) Close() error {
	// Interlock with Write: set bit 0 of activeCall, remembering whether
	// any Write calls (counted in steps of 2) were in flight.
	var x int32
	for {
		x = c.activeCall.Load()
		if x&1 != 0 {
			return net.ErrClosed
		}
		if c.activeCall.CompareAndSwap(x, x|1) {
			break
		}
	}
	if x != 0 {
		// A Write is in progress. Treat this Close as a request to abort it:
		// close the transport directly and skip close_notify, which would
		// have to wait for c.out (held by Write).
		return c.conn.Close()
	}

	var alertErr error
	if c.isHandshakeComplete.Load() {
		if err := c.closeNotify(); err != nil {
			alertErr = fmt.Errorf("tls: failed to send closeNotify alert (but connection was closed anyway): %w", err)
		}
	}

	if err := c.conn.Close(); err != nil {
		return err
	}
	return alertErr
}
1456
var (
	// errEarlyCloseWrite is returned by CloseWrite when the handshake has
	// not completed yet.
	errEarlyCloseWrite = errors.New("tls: CloseWrite called before handshake complete")
)
1458
1459
1460
1461
1462 func (c *Conn) CloseWrite() error {
1463 if !c.isHandshakeComplete.Load() {
1464 return errEarlyCloseWrite
1465 }
1466
1467 return c.closeNotify()
1468 }
1469
1470 func (c *Conn) closeNotify() error {
1471 c.out.Lock()
1472 defer c.out.Unlock()
1473
1474 if !c.closeNotifySent {
1475
1476 c.SetWriteDeadline(time.Now().Add(time.Second * 5))
1477 c.closeNotifyErr = c.sendAlertLocked(alertCloseNotify)
1478 c.closeNotifySent = true
1479
1480 c.SetWriteDeadline(time.Now())
1481 }
1482 return c.closeNotifyErr
1483 }
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498 func (c *Conn) Handshake() error {
1499 return c.HandshakeContext(context.Background())
1500 }
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512 func (c *Conn) HandshakeContext(ctx context.Context) error {
1513
1514
1515 return c.handshakeContext(ctx)
1516 }
1517
// handshakeContext runs the handshake at most once per Conn and returns its
// result. A failed handshake's error is stored and returned again by later
// calls. The named result ret lets a deferred function replace the returned
// error with the context's error when ctx ends mid-handshake.
func (c *Conn) handshakeContext(ctx context.Context) (ret error) {
	// Fast lock-free exit when a previous handshake already succeeded. This
	// skips the context setup and the mutex below.
	if c.isHandshakeComplete.Load() {
		return nil
	}

	handshakeCtx, cancel := context.WithCancel(ctx)
	// cancel is deferred BEFORE the interrupter's deferred func below.
	// Because defers run LIFO, cancel runs only after the interrupter
	// goroutine has exited. The goroutine therefore observes
	// handshakeCtx.Done() only when the caller's ctx ends, never because of
	// this function's own cancel.
	defer cancel()

	if c.quic != nil {
		// QUIC: instead of starting an interrupter, hand the cancellation
		// channel and function to the QUIC state. How they are used lives
		// outside this function.
		c.quic.cancelc = handshakeCtx.Done()
		c.quic.cancel = cancel
	} else if ctx.Done() != nil {
		// Start the "interrupter" goroutine only if ctx can be canceled
		// (Done returns nil for contexts that never can be, such as
		// context.Background). It closes the underlying connection if ctx
		// ends before this function returns, which unblocks handshake I/O.
		done := make(chan struct{})
		// Buffered so the goroutine's single send never blocks.
		interruptRes := make(chan error, 1)
		defer func() {
			// Signal that we are returning, then wait for the goroutine to
			// report and exit.
			close(done)
			if ctxErr := <-interruptRes; ctxErr != nil {
				// The connection was closed because ctx ended. Report the
				// context's error in place of whatever the handshake returned.
				ret = ctxErr
			}
		}()
		go func() {
			select {
			case <-handshakeCtx.Done():
				// Close the connection, discarding the error.
				_ = c.conn.Close()
				interruptRes <- handshakeCtx.Err()
			case <-done:
				interruptRes <- nil
			}
		}()
	}

	// Serialize handshakes. Re-check under the lock, because a concurrent
	// caller may have completed or failed the handshake while we waited.
	c.handshakeMutex.Lock()
	defer c.handshakeMutex.Unlock()

	if err := c.handshakeErr; err != nil {
		return err
	}
	if c.isHandshakeComplete.Load() {
		return nil
	}

	// Hold the input half-connection lock for the whole handshake.
	c.in.Lock()
	defer c.in.Unlock()

	c.handshakeErr = c.handshakeFn(handshakeCtx)
	if c.handshakeErr == nil {
		c.handshakes++
	} else {
		// On failure, try to flush buffered output, such as an alert the
		// failed handshake may have queued. Any flush result is ignored.
		c.flush()
	}

	// Invariant checks: handshakeFn must mark the Conn complete on success,
	// and must never do so on failure.
	if c.handshakeErr == nil && !c.isHandshakeComplete.Load() {
		c.handshakeErr = errors.New("tls: internal error: handshake should have had a result")
	}
	if c.handshakeErr != nil && c.isHandshakeComplete.Load() {
		panic("tls: internal error: handshake returned an error but is marked successful")
	}

	if c.quic != nil {
		if c.handshakeErr == nil {
			c.quicHandshakeComplete()
			// Install the application-level read secret only now that the
			// handshake is complete (see RFC 9001, Section 5.7: 1-RTT packets
			// must not be processed before the handshake completes).
			c.quicSetReadSecret(QUICEncryptionLevelApplication, c.cipherSuite, c.in.trafficSecret)
		} else {
			// Pick the alert to report: the alert recorded in c.out.err, or
			// alertInternalError if no alert was recorded there.
			var a alert
			c.out.Lock()
			if !errors.As(c.out.err, &a) {
				a = alertInternalError
			}
			c.out.Unlock()
			// Wrap both errors. The "%.0w" verb truncates the alert's text to
			// zero characters, so the message is just the handshake error,
			// while errors.As/Is can still find the AlertError.
			c.handshakeErr = fmt.Errorf("%w%.0w", c.handshakeErr, AlertError(a))
		}
		// NOTE(review): closing these channels presumably wakes QUIC-side
		// waiters; their receivers live outside this function.
		close(c.quic.blockedc)
		close(c.quic.signalc)
	}

	return c.handshakeErr
}
1617
1618
1619 func (c *Conn) ConnectionState() ConnectionState {
1620 c.handshakeMutex.Lock()
1621 defer c.handshakeMutex.Unlock()
1622 return c.connectionStateLocked()
1623 }
1624
1625 var tlsunsafeekm = godebug.New("tlsunsafeekm")
1626
1627 func (c *Conn) connectionStateLocked() ConnectionState {
1628 var state ConnectionState
1629 state.HandshakeComplete = c.isHandshakeComplete.Load()
1630 state.Version = c.vers
1631 state.NegotiatedProtocol = c.clientProtocol
1632 state.DidResume = c.didResume
1633 state.testingOnlyDidHRR = c.didHRR
1634 state.testingOnlyPeerSignatureAlgorithm = c.peerSigAlg
1635 state.CurveID = c.curveID
1636 state.NegotiatedProtocolIsMutual = true
1637 state.ServerName = c.serverName
1638 state.CipherSuite = c.cipherSuite
1639 state.PeerCertificates = c.peerCertificates
1640 state.VerifiedChains = c.verifiedChains
1641 state.SignedCertificateTimestamps = c.scts
1642 state.OCSPResponse = c.ocspResponse
1643 if (!c.didResume || c.extMasterSecret) && c.vers != VersionTLS13 {
1644 if c.clientFinishedIsFirst {
1645 state.TLSUnique = c.clientFinished[:]
1646 } else {
1647 state.TLSUnique = c.serverFinished[:]
1648 }
1649 }
1650 if c.config.Renegotiation != RenegotiateNever {
1651 state.ekm = noEKMBecauseRenegotiation
1652 } else if c.vers != VersionTLS13 && !c.extMasterSecret {
1653 state.ekm = func(label string, context []byte, length int) ([]byte, error) {
1654 if tlsunsafeekm.Value() == "1" {
1655 tlsunsafeekm.IncNonDefault()
1656 return c.ekm(label, context, length)
1657 }
1658 return noEKMBecauseNoEMS(label, context, length)
1659 }
1660 } else {
1661 state.ekm = c.ekm
1662 }
1663 state.ECHAccepted = c.echAccepted
1664 return state
1665 }
1666
1667
1668
1669 func (c *Conn) OCSPResponse() []byte {
1670 c.handshakeMutex.Lock()
1671 defer c.handshakeMutex.Unlock()
1672
1673 return c.ocspResponse
1674 }
1675
1676
1677
1678
1679 func (c *Conn) VerifyHostname(host string) error {
1680 c.handshakeMutex.Lock()
1681 defer c.handshakeMutex.Unlock()
1682 if !c.isClient {
1683 return errors.New("tls: VerifyHostname called on TLS server connection")
1684 }
1685 if !c.isHandshakeComplete.Load() {
1686 return errors.New("tls: handshake has not yet been performed")
1687 }
1688 if len(c.verifiedChains) == 0 {
1689 return errors.New("tls: handshake did not verify certificate chain")
1690 }
1691 return c.peerCertificates[0].VerifyHostname(host)
1692 }
1693
View as plain text