Remove synRcvdCount
This is redundant with listenContext.pendingEndpoints.
PiperOrigin-RevId: 399722472
parent 65698b627e
commit 5aa37994c1
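In short: the listening endpoint no longer keeps a separate atomic synRcvdCount; the number of in-flight handshakes is just the size of listenContext.pendingEndpoints, which is guarded by pendingMu and paired with a sync.WaitGroup so close can wait for the handshake goroutines to drain. Below is a minimal, self-contained Go sketch of that bookkeeping pattern; the listener/conn types, method names, and backlog numbers are simplified stand-ins for illustration, not gvisor's actual types.

package main

import (
	"fmt"
	"sync"
)

// conn stands in for an endpoint whose handshake is still in flight.
type conn struct{ id int }

// listener sketches the bookkeeping this commit settles on: instead of a
// separate atomic counter of SYN-RCVD connections, the number of in-flight
// handshakes is the size of the pending set, and a WaitGroup lets close
// wait for the matching goroutines to finish.
type listener struct {
	mu      sync.Mutex
	pending map[*conn]struct{} // guarded by mu (analogous to +checklocks:pendingMu)
	wg      sync.WaitGroup
	backlog int
}

func newListener(backlog int) *listener {
	return &listener{pending: make(map[*conn]struct{}), backlog: backlog}
}

// synRcvdFull reports whether another handshake may start; len(pending)
// plays the role the removed synRcvdCount used to play.
func (l *listener) synRcvdFull() bool {
	l.mu.Lock()
	defer l.mu.Unlock()
	return len(l.pending) >= l.backlog
}

// startHandshake admits a connection if the SYN-RCVD "queue" has room,
// records it in the pending set, and completes the handshake asynchronously.
func (l *listener) startHandshake(c *conn, complete func()) bool {
	l.mu.Lock()
	if len(l.pending) >= l.backlog {
		l.mu.Unlock()
		return false // caller would have to fall back to SYN cookies
	}
	l.pending[c] = struct{}{}
	l.wg.Add(1)
	l.mu.Unlock()

	go func() {
		defer func() {
			l.mu.Lock()
			delete(l.pending, c)
			l.mu.Unlock()
			l.wg.Done()
		}()
		complete()
	}()
	return true
}

// close waits for every in-flight handshake goroutine to finish, the job the
// pending WaitGroup does for the listening endpoint's worker.
func (l *listener) close() { l.wg.Wait() }

func main() {
	l := newListener(2)
	for i := 0; i < 3; i++ {
		c := &conn{id: i}
		ok := l.startHandshake(c, func() { fmt.Println("handshake done for", c.id) })
		fmt.Printf("conn %d admitted: %v\n", i, ok)
	}
	l.close()
}

The point of the pattern is that admission control, shutdown, and the count all hang off one mutex-protected set, so there is no second counter to keep in sync.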
@@ -20,7 +20,6 @@ import (
"fmt"
"hash"
"io"
"sync/atomic"
"time"

"gvisor.dev/gvisor/pkg/sleep"
@@ -103,14 +102,14 @@ type listenContext struct {

// pendingMu protects pendingEndpoints. This should only be accessed
// by the listening endpoint's worker goroutine.
//
// Lock Ordering: listenEP.workerMu -> pendingMu
pendingMu sync.Mutex
// pending is used to wait for all pendingEndpoints to finish when
// a socket is closed.
pending sync.WaitGroup
// pendingEndpoints is a set of all endpoints for which a handshake is
// in progress.
//
// +checklocks:pendingMu
pendingEndpoints map[*endpoint]struct{}
}

@@ -265,7 +264,6 @@ func (l *listenContext) startHandshake(s *segment, opts header.TCPSynOptions, qu

return nil, &tcpip.ErrConnectionAborted{}
}
l.addPendingEndpoint(ep)

// Propagate any inheritable options from the listening endpoint
// to the newly created endpoint.
@@ -275,8 +273,6 @@ func (l *listenContext) startHandshake(s *segment, opts header.TCPSynOptions, qu
ep.mu.Unlock()
ep.Close()

l.removePendingEndpoint(ep)

return nil, &tcpip.ErrConnectionAborted{}
}

@@ -295,10 +291,6 @@ func (l *listenContext) startHandshake(s *segment, opts header.TCPSynOptions, qu
ep.mu.Unlock()
ep.Close()

if l.listenEP != nil {
l.removePendingEndpoint(ep)
}

ep.drainClosingSegmentQueue()

return nil, err
@@ -336,38 +328,12 @@ func (l *listenContext) performHandshake(s *segment, opts header.TCPSynOptions,
return ep, nil
}

func (l *listenContext) addPendingEndpoint(n *endpoint) {
l.pendingMu.Lock()
l.pendingEndpoints[n] = struct{}{}
l.pending.Add(1)
l.pendingMu.Unlock()
}

func (l *listenContext) removePendingEndpoint(n *endpoint) {
l.pendingMu.Lock()
delete(l.pendingEndpoints, n)
l.pending.Done()
l.pendingMu.Unlock()
}

func (l *listenContext) closeAllPendingEndpoints() {
l.pendingMu.Lock()
for n := range l.pendingEndpoints {
n.notifyProtocolGoroutine(notifyClose)
}
l.pendingMu.Unlock()
l.pending.Wait()
}

// +checklocks:h.ep.mu
func (l *listenContext) cleanupFailedHandshake(h *handshake) {
e := h.ep
e.mu.Unlock()
e.Close()
e.notifyAborted()
if l.listenEP != nil {
l.removePendingEndpoint(e)
}
e.drainClosingSegmentQueue()
e.h = nil
}

@@ -378,9 +344,6 @@ func (l *listenContext) cleanupFailedHandshake(h *handshake) {
// +checklocks:h.ep.mu
func (l *listenContext) cleanupCompletedHandshake(h *handshake) {
e := h.ep
if l.listenEP != nil {
l.removePendingEndpoint(e)
}
e.isConnectNotified = true

// Update the receive window scaling. We can't do it before the
@@ -444,21 +407,6 @@ func (e *endpoint) notifyAborted() {
e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents)
}

func (e *endpoint) synRcvdBacklogFull() bool {
e.acceptMu.Lock()
acceptedCap := e.accepted.cap
e.acceptMu.Unlock()
// The capacity of the accepted queue would always be one greater than the
// listen backlog. But, the SYNRCVD connections count is always checked
// against the listen backlog value for Linux parity reason.
// https://github.com/torvalds/linux/blob/7acac4b3196/include/net/inet_connection_sock.h#L280
//
// We maintain an equality check here as the synRcvdCount is incremented
// and compared only from a single listener context and the capacity of
// the accepted queue can only increase by a new listen call.
return int(atomic.LoadInt32(&e.synRcvdCount)) == acceptedCap-1
}

func (e *endpoint) acceptQueueIsFull() bool {
e.acceptMu.Lock()
full := e.accepted.acceptQueueIsFullLocked()
@@ -500,34 +448,53 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
return nil
}

alwaysUseSynCookies := func() bool {
opts := parseSynSegmentOptions(s)

useSynCookies, err := func() (bool, tcpip.Error) {
var alwaysUseSynCookies tcpip.TCPAlwaysUseSynCookies
if err := e.stack.TransportProtocolOption(header.TCPProtocolNumber, &alwaysUseSynCookies); err != nil {
panic(fmt.Sprintf("TransportProtocolOption(%d, %T) = %s", header.TCPProtocolNumber, alwaysUseSynCookies, err))
}
return bool(alwaysUseSynCookies)
}()
if alwaysUseSynCookies {
return true, nil
}
e.acceptMu.Lock()
defer e.acceptMu.Unlock()

opts := parseSynSegmentOptions(s)
if !alwaysUseSynCookies && !e.synRcvdBacklogFull() {
atomic.AddInt32(&e.synRcvdCount, 1)
ctx.pendingMu.Lock()
defer ctx.pendingMu.Unlock()
// The capacity of the accepted queue would always be one greater than the
// listen backlog. But, the SYNRCVD connections count is always checked
// against the listen backlog value for Linux parity reason.
// https://github.com/torvalds/linux/blob/7acac4b3196/include/net/inet_connection_sock.h#L280
if len(ctx.pendingEndpoints) == e.accepted.cap-1 {
return true, nil
}

h, err := ctx.startHandshake(s, opts, &waiter.Queue{}, e.owner)
if err != nil {
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
e.stats.FailedConnectionAttempts.Increment()
atomic.AddInt32(&e.synRcvdCount, -1)
return err
return false, err
}

ctx.pendingEndpoints[h.ep] = struct{}{}
ctx.pending.Add(1)

go func() {
defer func() {
ctx.pendingMu.Lock()
defer ctx.pendingMu.Unlock()
delete(ctx.pendingEndpoints, h.ep)
ctx.pending.Done()
}()

// Note that startHandshake returns a locked endpoint. The force call
// here just makes it so.
if err := h.complete(); err != nil { // +checklocksforce
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
e.stats.FailedConnectionAttempts.Increment()
ctx.cleanupFailedHandshake(h)
atomic.AddInt32(&e.synRcvdCount, -1)
return
}
ctx.cleanupCompletedHandshake(h)
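The capacity comment in the hunk above is easiest to see with numbers: the accepted queue's capacity is one greater than the listen backlog, but for Linux parity the SYN-RCVD count is compared against the backlog itself, which is why the new check is len(ctx.pendingEndpoints) == e.accepted.cap-1. A tiny standalone Go illustration of that arithmetic, using a hypothetical backlog value rather than gvisor code:

package main

import "fmt"

func main() {
	// acceptedCap = backlog + 1, and the SYN-RCVD queue counts as full once
	// the number of pending handshakes reaches acceptedCap-1, i.e. backlog.
	backlog := 3
	acceptedCap := backlog + 1
	for pending := 0; pending <= backlog; pending++ {
		full := pending == acceptedCap-1 // equivalent to pending == backlog
		fmt.Printf("pending=%d synRcvdFull=%v\n", pending, full)
	}
}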
@@ -558,7 +525,6 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
}

e.accepted.endpoints.PushBack(h.ep)
atomic.AddInt32(&e.synRcvdCount, -1)
return true
}
}()

@@ -570,6 +536,12 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
}
}()

return false, nil
}()
if err != nil {
return err
}
if !useSynCookies {
return nil
}

@@ -780,7 +752,12 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) {
e.setEndpointState(StateClose)

// Close any endpoints in SYN-RCVD state.
ctx.closeAllPendingEndpoints()
ctx.pendingMu.Lock()
for n := range ctx.pendingEndpoints {
n.notifyProtocolGoroutine(notifyClose)
}
ctx.pendingMu.Unlock()
ctx.pending.Wait()

// Do cleanup if needed.
e.completeWorkerLocked()

@@ -508,10 +508,6 @@ type endpoint struct {
// and dropped when it is.
segmentQueue segmentQueue `state:"wait"`

// synRcvdCount is the number of connections for this endpoint that are
// in SYN-RCVD state; this is only accessed atomically.
synRcvdCount int32

// userMSS if non-zero is the MSS value explicitly set by the user
// for this endpoint using the TCP_MAXSEG setsockopt.
userMSS uint16