diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 7115d0a12..7348bb7a9 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -109,9 +109,9 @@ type listenContext struct { // pending is used to wait for all pendingEndpoints to finish when // a socket is closed. pending sync.WaitGroup - // pendingEndpoints is a map of all endpoints for which a handshake is + // pendingEndpoints is a set of all endpoints for which a handshake is // in progress. - pendingEndpoints map[stack.TransportEndpointID]*endpoint + pendingEndpoints map[*endpoint]struct{} } // timeStamp returns an 8-bit timestamp with a granularity of 64 seconds. @@ -129,7 +129,7 @@ func newListenContext(stk *stack.Stack, protocol *protocol, listenEP *endpoint, v6Only: v6Only, netProto: netProto, listenEP: listenEP, - pendingEndpoints: make(map[stack.TransportEndpointID]*endpoint), + pendingEndpoints: make(map[*endpoint]struct{}), } for i := range l.nonce { @@ -338,21 +338,21 @@ func (l *listenContext) performHandshake(s *segment, opts header.TCPSynOptions, func (l *listenContext) addPendingEndpoint(n *endpoint) { l.pendingMu.Lock() - l.pendingEndpoints[n.TransportEndpointInfo.ID] = n + l.pendingEndpoints[n] = struct{}{} l.pending.Add(1) l.pendingMu.Unlock() } func (l *listenContext) removePendingEndpoint(n *endpoint) { l.pendingMu.Lock() - delete(l.pendingEndpoints, n.TransportEndpointInfo.ID) + delete(l.pendingEndpoints, n) l.pending.Done() l.pendingMu.Unlock() } func (l *listenContext) closeAllPendingEndpoints() { l.pendingMu.Lock() - for _, n := range l.pendingEndpoints { + for n := range l.pendingEndpoints { n.notifyProtocolGoroutine(notifyClose) } l.pendingMu.Unlock()