Store pending endpoints in a set

There's no need for synthetic keys here.

PiperOrigin-RevId: 399263134
This commit is contained in:
Tamir Duberstein 2021-09-27 13:09:18 -07:00 committed by gVisor bot
parent 2e25547e04
commit 455924ee1b
1 changed files with 6 additions and 6 deletions

View File

@ -109,9 +109,9 @@ type listenContext struct {
// pending is used to wait for all pendingEndpoints to finish when // pending is used to wait for all pendingEndpoints to finish when
// a socket is closed. // a socket is closed.
pending sync.WaitGroup pending sync.WaitGroup
// pendingEndpoints is a map of all endpoints for which a handshake is // pendingEndpoints is a set of all endpoints for which a handshake is
// in progress. // in progress.
pendingEndpoints map[stack.TransportEndpointID]*endpoint pendingEndpoints map[*endpoint]struct{}
} }
// timeStamp returns an 8-bit timestamp with a granularity of 64 seconds. // timeStamp returns an 8-bit timestamp with a granularity of 64 seconds.
@ -129,7 +129,7 @@ func newListenContext(stk *stack.Stack, protocol *protocol, listenEP *endpoint,
v6Only: v6Only, v6Only: v6Only,
netProto: netProto, netProto: netProto,
listenEP: listenEP, listenEP: listenEP,
pendingEndpoints: make(map[stack.TransportEndpointID]*endpoint), pendingEndpoints: make(map[*endpoint]struct{}),
} }
for i := range l.nonce { for i := range l.nonce {
@ -338,21 +338,21 @@ func (l *listenContext) performHandshake(s *segment, opts header.TCPSynOptions,
func (l *listenContext) addPendingEndpoint(n *endpoint) { func (l *listenContext) addPendingEndpoint(n *endpoint) {
l.pendingMu.Lock() l.pendingMu.Lock()
l.pendingEndpoints[n.TransportEndpointInfo.ID] = n l.pendingEndpoints[n] = struct{}{}
l.pending.Add(1) l.pending.Add(1)
l.pendingMu.Unlock() l.pendingMu.Unlock()
} }
func (l *listenContext) removePendingEndpoint(n *endpoint) { func (l *listenContext) removePendingEndpoint(n *endpoint) {
l.pendingMu.Lock() l.pendingMu.Lock()
delete(l.pendingEndpoints, n.TransportEndpointInfo.ID) delete(l.pendingEndpoints, n)
l.pending.Done() l.pending.Done()
l.pendingMu.Unlock() l.pendingMu.Unlock()
} }
func (l *listenContext) closeAllPendingEndpoints() { func (l *listenContext) closeAllPendingEndpoints() {
l.pendingMu.Lock() l.pendingMu.Lock()
for _, n := range l.pendingEndpoints { for n := range l.pendingEndpoints {
n.notifyProtocolGoroutine(notifyClose) n.notifyProtocolGoroutine(notifyClose)
} }
l.pendingMu.Unlock() l.pendingMu.Unlock()