Source file src/net/http/internal/http2/client_conn_pool.go

     1  // Copyright 2015 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Transport code's client connection pooling.
     6  
     7  package http2
     8  
     9  import (
    10  	"context"
    11  	"errors"
    12  	"net"
    13  	"slices"
    14  	"sync"
    15  )
    16  
// clientConnPool caches HTTP/2 ClientConns for a Transport, keyed by
// address, and coalesces concurrent dials and addConnIfNeeded calls for the
// same key. All map fields are guarded by mu and are allocated lazily on
// first use.
//
// TODO: use singleflight for dialing and addConnCalls?
type clientConnPool struct {
	t *Transport // transport used to dial and construct connections

	mu sync.Mutex // TODO: maybe switch to RWMutex
	// TODO: add support for sharing conns based on cert names
	// (e.g. share conn for googleapis.com and appspot.com)
	conns        map[string][]*ClientConn // key is host:port
	dialing      map[string]*dialCall     // currently in-flight dials
	keys         map[*ClientConn][]string // reverse index: every key a conn is registered under
	addConnCalls map[string]*addConnCall  // in-flight addConnIfNeeded calls
}
    29  
    30  func (p *clientConnPool) GetClientConn(req *ClientRequest, addr string) (*ClientConn, error) {
    31  	return p.getClientConn(req, addr, dialOnMiss)
    32  }
    33  
    34  const (
    35  	dialOnMiss   = true
    36  	noDialOnMiss = false
    37  )
    38  
// getClientConn returns a ClientConn that req can be sent on to addr.
//
// If req is a connection-close request (per isConnectionCloseRequest) and
// dialOnMiss is true, it always dials a dedicated single-use connection and
// never consults the pool.
//
// Otherwise it scans the pooled connections for addr and returns the first
// one on which ReserveNewRequest succeeds. On a miss it returns
// ErrNoCachedConn if dialOnMiss is false. If dialOnMiss is true, it joins (or
// starts) the shared in-flight dial for addr, waits for it, and tries to
// reserve a request on the result.
//
// The loop starts over in two cases:
//   - shouldRetryDial says the shared dial should be retried.
//   - A request cannot be reserved on the freshly dialed connection.
func (p *clientConnPool) getClientConn(req *ClientRequest, addr string, dialOnMiss bool) (*ClientConn, error) {
	// TODO(dneil): Dial a new connection when t.DisableKeepAlives is set?
	if isConnectionCloseRequest(req) && dialOnMiss {
		// It gets its own connection.
		traceGetConn(req, addr)
		const singleUse = true
		cc, err := p.t.dialClientConn(req.Context, addr, singleUse)
		if err != nil {
			return nil, err
		}
		return cc, nil
	}
	for {
		p.mu.Lock()
		for _, cc := range p.conns[addr] {
			if cc.ReserveNewRequest() {
				// When a connection is presented to us by the net/http package,
				// the GetConn hook has already been called.
				// Don't call it a second time here.
				if !cc.getConnCalled {
					traceGetConn(req, addr)
				}
				// Clear the flag so later requests on this conn are traced.
				cc.getConnCalled = false
				p.mu.Unlock()
				return cc, nil
			}
		}
		if !dialOnMiss {
			p.mu.Unlock()
			return nil, ErrNoCachedConn
		}
		traceGetConn(req, addr)
		call := p.getStartDialLocked(req.Context, addr)
		p.mu.Unlock()
		// Wait without holding p.mu: the dial goroutine takes p.mu to
		// publish its result before closing done.
		<-call.done
		if shouldRetryDial(call, req) {
			// The shared dial failed because a different request's context
			// was canceled or expired; start over.
			continue
		}
		cc, err := call.res, call.err
		if err != nil {
			return nil, err
		}
		if cc.ReserveNewRequest() {
			return cc, nil
		}
		// Couldn't reserve a request on the new conn; loop and try again.
	}
}
    86  
// dialCall is an in-flight Transport dial call to a host.
// Concurrent getClientConn callers for the same address share one dialCall
// (see getStartDialLocked). They wait for done to be closed and then read
// res and err.
type dialCall struct {
	_ incomparable
	p *clientConnPool // pool that started the dial and receives the new conn
	// the context associated with the request
	// that created this dialCall; the dial runs under it, and
	// shouldRetryDial consults it
	ctx  context.Context
	done chan struct{} // closed when done
	res  *ClientConn   // valid after done is closed
	err  error         // valid after done is closed
}
    98  
    99  // requires p.mu is held.
   100  func (p *clientConnPool) getStartDialLocked(ctx context.Context, addr string) *dialCall {
   101  	if call, ok := p.dialing[addr]; ok {
   102  		// A dial is already in-flight. Don't start another.
   103  		return call
   104  	}
   105  	call := &dialCall{p: p, done: make(chan struct{}), ctx: ctx}
   106  	if p.dialing == nil {
   107  		p.dialing = make(map[string]*dialCall)
   108  	}
   109  	p.dialing[addr] = call
   110  	go call.dial(call.ctx, addr)
   111  	return call
   112  }
   113  
   114  // run in its own goroutine.
   115  func (c *dialCall) dial(ctx context.Context, addr string) {
   116  	const singleUse = false // shared conn
   117  	c.res, c.err = c.p.t.dialClientConn(ctx, addr, singleUse)
   118  
   119  	c.p.mu.Lock()
   120  	delete(c.p.dialing, addr)
   121  	if c.err == nil {
   122  		c.p.addConnLocked(addr, c.res)
   123  	}
   124  	c.p.mu.Unlock()
   125  
   126  	close(c.done)
   127  }
   128  
   129  // addConnIfNeeded makes a NewClientConn out of c if a connection for key doesn't
   130  // already exist. It coalesces concurrent calls with the same key.
   131  // This is used by the http1 Transport code when it creates a new connection. Because
   132  // the http1 Transport doesn't de-dup TCP dials to outbound hosts (because it doesn't know
   133  // the protocol), it can get into a situation where it has multiple TLS connections.
   134  // This code decides which ones live or die.
   135  // The return value used is whether c was used.
   136  // c is never closed.
   137  func (p *clientConnPool) addConnIfNeeded(key string, t *Transport, c net.Conn) (used bool, err error) {
   138  	p.mu.Lock()
   139  	for _, cc := range p.conns[key] {
   140  		if cc.CanTakeNewRequest() {
   141  			p.mu.Unlock()
   142  			return false, nil
   143  		}
   144  	}
   145  	call, dup := p.addConnCalls[key]
   146  	if !dup {
   147  		if p.addConnCalls == nil {
   148  			p.addConnCalls = make(map[string]*addConnCall)
   149  		}
   150  		call = &addConnCall{
   151  			p:    p,
   152  			done: make(chan struct{}),
   153  		}
   154  		p.addConnCalls[key] = call
   155  		go call.run(t, key, c)
   156  	}
   157  	p.mu.Unlock()
   158  
   159  	<-call.done
   160  	if call.err != nil {
   161  		return false, call.err
   162  	}
   163  	return !dup, nil
   164  }
   165  
// addConnCall is an in-flight addConnIfNeeded call. Concurrent callers for
// the same key share one addConnCall and wait for done to be closed.
type addConnCall struct {
	_    incomparable
	p    *clientConnPool // pool the new conn is added to
	done chan struct{}   // closed when done
	err  error           // error from building the ClientConn; valid after done is closed
}
   172  
   173  func (c *addConnCall) run(t *Transport, key string, nc net.Conn) {
   174  	cc, err := t.newClientConn(nc, t.disableKeepAlives(), nil)
   175  
   176  	p := c.p
   177  	p.mu.Lock()
   178  	if err != nil {
   179  		c.err = err
   180  	} else {
   181  		cc.getConnCalled = true // already called by the net/http package
   182  		p.addConnLocked(key, cc)
   183  	}
   184  	delete(p.addConnCalls, key)
   185  	p.mu.Unlock()
   186  	close(c.done)
   187  }
   188  
   189  // p.mu must be held
   190  func (p *clientConnPool) addConnLocked(key string, cc *ClientConn) {
   191  	if slices.Contains(p.conns[key], cc) {
   192  		return
   193  	}
   194  	if p.conns == nil {
   195  		p.conns = make(map[string][]*ClientConn)
   196  	}
   197  	if p.keys == nil {
   198  		p.keys = make(map[*ClientConn][]string)
   199  	}
   200  	p.conns[key] = append(p.conns[key], cc)
   201  	p.keys[cc] = append(p.keys[cc], key)
   202  }
   203  
   204  func (p *clientConnPool) MarkDead(cc *ClientConn) {
   205  	p.mu.Lock()
   206  	defer p.mu.Unlock()
   207  	for _, key := range p.keys[cc] {
   208  		vv, ok := p.conns[key]
   209  		if !ok {
   210  			continue
   211  		}
   212  		newList := filterOutClientConn(vv, cc)
   213  		if len(newList) > 0 {
   214  			p.conns[key] = newList
   215  		} else {
   216  			delete(p.conns, key)
   217  		}
   218  	}
   219  	delete(p.keys, cc)
   220  }
   221  
   222  func (p *clientConnPool) closeIdleConnections() {
   223  	p.mu.Lock()
   224  	defer p.mu.Unlock()
   225  	// TODO: don't close a cc if it was just added to the pool
   226  	// milliseconds ago and has never been used. There's currently
   227  	// a small race window with the HTTP/1 Transport's integration
   228  	// where it can add an idle conn just before using it, and
   229  	// somebody else can concurrently call CloseIdleConns and
   230  	// break some caller's RoundTrip.
   231  	for _, vv := range p.conns {
   232  		for _, cc := range vv {
   233  			cc.closeIfIdle()
   234  		}
   235  	}
   236  }
   237  
   238  func filterOutClientConn(in []*ClientConn, exclude *ClientConn) []*ClientConn {
   239  	out := in[:0]
   240  	for _, v := range in {
   241  		if v != exclude {
   242  			out = append(out, v)
   243  		}
   244  	}
   245  	// If we filtered it out, zero out the last item to prevent
   246  	// the GC from seeing it.
   247  	if len(in) != len(out) {
   248  		in[len(in)-1] = nil
   249  	}
   250  	return out
   251  }
   252  
   253  // noDialClientConnPool is an implementation of http2.ClientConnPool
   254  // which never dials. We let the HTTP/1.1 client dial and use its TLS
   255  // connection instead.
   256  type noDialClientConnPool struct{ *clientConnPool }
   257  
   258  func (p noDialClientConnPool) GetClientConn(req *ClientRequest, addr string) (*ClientConn, error) {
   259  	return p.getClientConn(req, addr, noDialOnMiss)
   260  }
   261  
   262  // shouldRetryDial reports whether the current request should
   263  // retry dialing after the call finished unsuccessfully, for example
   264  // if the dial was canceled because of a context cancellation or
   265  // deadline expiry.
   266  func shouldRetryDial(call *dialCall, req *ClientRequest) bool {
   267  	if call.err == nil {
   268  		// No error, no need to retry
   269  		return false
   270  	}
   271  	if call.ctx == req.Context {
   272  		// If the call has the same context as the request, the dial
   273  		// should not be retried, since any cancellation will have come
   274  		// from this request.
   275  		return false
   276  	}
   277  	if !errors.Is(call.err, context.Canceled) && !errors.Is(call.err, context.DeadlineExceeded) {
   278  		// If the call error is not because of a context cancellation or a deadline expiry,
   279  		// the dial should not be retried.
   280  		return false
   281  	}
   282  	// Only retry if the error is a context cancellation error or deadline expiry
   283  	// and the context associated with the call was canceled or expired.
   284  	return call.ctx.Err() != nil
   285  }
   286  

View as plain text