1
2
3
4
5
6
7 package http2
8
9 import (
10 "context"
11 "errors"
12 "net"
13 "slices"
14 "sync"
15 )
16
17
18 type clientConnPool struct {
19 t *Transport
20
21 mu sync.Mutex
22
23
24 conns map[string][]*ClientConn
25 dialing map[string]*dialCall
26 keys map[*ClientConn][]string
27 addConnCalls map[string]*addConnCall
28 }
29
30 func (p *clientConnPool) GetClientConn(req *ClientRequest, addr string) (*ClientConn, error) {
31 return p.getClientConn(req, addr, dialOnMiss)
32 }
33
// Named values for getClientConn's dialOnMiss argument, for readable call
// sites.
const (
	// dialOnMiss: dial a new connection when no pooled one is usable.
	dialOnMiss = true
	// noDialOnMiss: report ErrNoCachedConn instead of dialing.
	noDialOnMiss = false
)
38
39 func (p *clientConnPool) getClientConn(req *ClientRequest, addr string, dialOnMiss bool) (*ClientConn, error) {
40
41 if isConnectionCloseRequest(req) && dialOnMiss {
42
43 traceGetConn(req, addr)
44 const singleUse = true
45 cc, err := p.t.dialClientConn(req.Context, addr, singleUse)
46 if err != nil {
47 return nil, err
48 }
49 return cc, nil
50 }
51 for {
52 p.mu.Lock()
53 for _, cc := range p.conns[addr] {
54 if cc.ReserveNewRequest() {
55
56
57
58 if !cc.getConnCalled {
59 traceGetConn(req, addr)
60 }
61 cc.getConnCalled = false
62 p.mu.Unlock()
63 return cc, nil
64 }
65 }
66 if !dialOnMiss {
67 p.mu.Unlock()
68 return nil, ErrNoCachedConn
69 }
70 traceGetConn(req, addr)
71 call := p.getStartDialLocked(req.Context, addr)
72 p.mu.Unlock()
73 <-call.done
74 if shouldRetryDial(call, req) {
75 continue
76 }
77 cc, err := call.res, call.err
78 if err != nil {
79 return nil, err
80 }
81 if cc.ReserveNewRequest() {
82 return cc, nil
83 }
84 }
85 }
86
87
88 type dialCall struct {
89 _ incomparable
90 p *clientConnPool
91
92
93 ctx context.Context
94 done chan struct{}
95 res *ClientConn
96 err error
97 }
98
99
100 func (p *clientConnPool) getStartDialLocked(ctx context.Context, addr string) *dialCall {
101 if call, ok := p.dialing[addr]; ok {
102
103 return call
104 }
105 call := &dialCall{p: p, done: make(chan struct{}), ctx: ctx}
106 if p.dialing == nil {
107 p.dialing = make(map[string]*dialCall)
108 }
109 p.dialing[addr] = call
110 go call.dial(call.ctx, addr)
111 return call
112 }
113
114
115 func (c *dialCall) dial(ctx context.Context, addr string) {
116 const singleUse = false
117 c.res, c.err = c.p.t.dialClientConn(ctx, addr, singleUse)
118
119 c.p.mu.Lock()
120 delete(c.p.dialing, addr)
121 if c.err == nil {
122 c.p.addConnLocked(addr, c.res)
123 }
124 c.p.mu.Unlock()
125
126 close(c.done)
127 }
128
129
130
131
132
133
134
135
136
137 func (p *clientConnPool) addConnIfNeeded(key string, t *Transport, c net.Conn) (used bool, err error) {
138 p.mu.Lock()
139 for _, cc := range p.conns[key] {
140 if cc.CanTakeNewRequest() {
141 p.mu.Unlock()
142 return false, nil
143 }
144 }
145 call, dup := p.addConnCalls[key]
146 if !dup {
147 if p.addConnCalls == nil {
148 p.addConnCalls = make(map[string]*addConnCall)
149 }
150 call = &addConnCall{
151 p: p,
152 done: make(chan struct{}),
153 }
154 p.addConnCalls[key] = call
155 go call.run(t, key, c)
156 }
157 p.mu.Unlock()
158
159 <-call.done
160 if call.err != nil {
161 return false, call.err
162 }
163 return !dup, nil
164 }
165
166 type addConnCall struct {
167 _ incomparable
168 p *clientConnPool
169 done chan struct{}
170 err error
171 }
172
173 func (c *addConnCall) run(t *Transport, key string, nc net.Conn) {
174 cc, err := t.newClientConn(nc, t.disableKeepAlives(), nil)
175
176 p := c.p
177 p.mu.Lock()
178 if err != nil {
179 c.err = err
180 } else {
181 cc.getConnCalled = true
182 p.addConnLocked(key, cc)
183 }
184 delete(p.addConnCalls, key)
185 p.mu.Unlock()
186 close(c.done)
187 }
188
189
190 func (p *clientConnPool) addConnLocked(key string, cc *ClientConn) {
191 if slices.Contains(p.conns[key], cc) {
192 return
193 }
194 if p.conns == nil {
195 p.conns = make(map[string][]*ClientConn)
196 }
197 if p.keys == nil {
198 p.keys = make(map[*ClientConn][]string)
199 }
200 p.conns[key] = append(p.conns[key], cc)
201 p.keys[cc] = append(p.keys[cc], key)
202 }
203
204 func (p *clientConnPool) MarkDead(cc *ClientConn) {
205 p.mu.Lock()
206 defer p.mu.Unlock()
207 for _, key := range p.keys[cc] {
208 vv, ok := p.conns[key]
209 if !ok {
210 continue
211 }
212 newList := filterOutClientConn(vv, cc)
213 if len(newList) > 0 {
214 p.conns[key] = newList
215 } else {
216 delete(p.conns, key)
217 }
218 }
219 delete(p.keys, cc)
220 }
221
222 func (p *clientConnPool) closeIdleConnections() {
223 p.mu.Lock()
224 defer p.mu.Unlock()
225
226
227
228
229
230
231 for _, vv := range p.conns {
232 for _, cc := range vv {
233 cc.closeIfIdle()
234 }
235 }
236 }
237
238 func filterOutClientConn(in []*ClientConn, exclude *ClientConn) []*ClientConn {
239 out := in[:0]
240 for _, v := range in {
241 if v != exclude {
242 out = append(out, v)
243 }
244 }
245
246
247 if len(in) != len(out) {
248 in[len(in)-1] = nil
249 }
250 return out
251 }
252
253
254
255
256 type noDialClientConnPool struct{ *clientConnPool }
257
258 func (p noDialClientConnPool) GetClientConn(req *ClientRequest, addr string) (*ClientConn, error) {
259 return p.getClientConn(req, addr, noDialOnMiss)
260 }
261
262
263
264
265
266 func shouldRetryDial(call *dialCall, req *ClientRequest) bool {
267 if call.err == nil {
268
269 return false
270 }
271 if call.ctx == req.Context {
272
273
274
275 return false
276 }
277 if !errors.Is(call.err, context.Canceled) && !errors.Is(call.err, context.DeadlineExceeded) {
278
279
280 return false
281 }
282
283
284 return call.ctx.Err() != nil
285 }
286
View as plain text