Remove syncRcvdCount

This is redundant with listenContext.pendingEndpoints

PiperOrigin-RevId: 399722472
This commit is contained in:
Tamir Duberstein 2021-09-29 10:47:42 -07:00 committed by gVisor bot
parent 65698b627e
commit 5aa37994c1
2 changed files with 42 additions and 69 deletions

View File

@ -20,7 +20,6 @@ import (
"fmt"
"hash"
"io"
"sync/atomic"
"time"
"gvisor.dev/gvisor/pkg/sleep"
@ -103,14 +102,14 @@ type listenContext struct {
// pendingMu protects pendingEndpoints. This should only be accessed
// by the listening endpoint's worker goroutine.
//
// Lock Ordering: listenEP.workerMu -> pendingMu
pendingMu sync.Mutex
// pending is used to wait for all pendingEndpoints to finish when
// a socket is closed.
pending sync.WaitGroup
// pendingEndpoints is a set of all endpoints for which a handshake is
// in progress.
//
// +checklocks:pendingMu
pendingEndpoints map[*endpoint]struct{}
}
@ -265,7 +264,6 @@ func (l *listenContext) startHandshake(s *segment, opts header.TCPSynOptions, qu
return nil, &tcpip.ErrConnectionAborted{}
}
l.addPendingEndpoint(ep)
// Propagate any inheritable options from the listening endpoint
// to the newly created endpoint.
@ -275,8 +273,6 @@ func (l *listenContext) startHandshake(s *segment, opts header.TCPSynOptions, qu
ep.mu.Unlock()
ep.Close()
l.removePendingEndpoint(ep)
return nil, &tcpip.ErrConnectionAborted{}
}
@ -295,10 +291,6 @@ func (l *listenContext) startHandshake(s *segment, opts header.TCPSynOptions, qu
ep.mu.Unlock()
ep.Close()
if l.listenEP != nil {
l.removePendingEndpoint(ep)
}
ep.drainClosingSegmentQueue()
return nil, err
@ -336,38 +328,12 @@ func (l *listenContext) performHandshake(s *segment, opts header.TCPSynOptions,
return ep, nil
}
func (l *listenContext) addPendingEndpoint(n *endpoint) {
l.pendingMu.Lock()
l.pendingEndpoints[n] = struct{}{}
l.pending.Add(1)
l.pendingMu.Unlock()
}
func (l *listenContext) removePendingEndpoint(n *endpoint) {
l.pendingMu.Lock()
delete(l.pendingEndpoints, n)
l.pending.Done()
l.pendingMu.Unlock()
}
func (l *listenContext) closeAllPendingEndpoints() {
l.pendingMu.Lock()
for n := range l.pendingEndpoints {
n.notifyProtocolGoroutine(notifyClose)
}
l.pendingMu.Unlock()
l.pending.Wait()
}
// +checklocks:h.ep.mu
func (l *listenContext) cleanupFailedHandshake(h *handshake) {
e := h.ep
e.mu.Unlock()
e.Close()
e.notifyAborted()
if l.listenEP != nil {
l.removePendingEndpoint(e)
}
e.drainClosingSegmentQueue()
e.h = nil
}
@ -378,9 +344,6 @@ func (l *listenContext) cleanupFailedHandshake(h *handshake) {
// +checklocks:h.ep.mu
func (l *listenContext) cleanupCompletedHandshake(h *handshake) {
e := h.ep
if l.listenEP != nil {
l.removePendingEndpoint(e)
}
e.isConnectNotified = true
// Update the receive window scaling. We can't do it before the
@ -444,21 +407,6 @@ func (e *endpoint) notifyAborted() {
e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents)
}
func (e *endpoint) synRcvdBacklogFull() bool {
e.acceptMu.Lock()
acceptedCap := e.accepted.cap
e.acceptMu.Unlock()
// The capacity of the accepted queue would always be one greater than the
// listen backlog. But, the SYNRCVD connections count is always checked
// against the listen backlog value for Linux parity reason.
// https://github.com/torvalds/linux/blob/7acac4b3196/include/net/inet_connection_sock.h#L280
//
// We maintain an equality check here as the synRcvdCount is incremented
// and compared only from a single listener context and the capacity of
// the accepted queue can only increase by a new listen call.
return int(atomic.LoadInt32(&e.synRcvdCount)) == acceptedCap-1
}
func (e *endpoint) acceptQueueIsFull() bool {
e.acceptMu.Lock()
full := e.accepted.acceptQueueIsFullLocked()
@ -500,34 +448,53 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
return nil
}
alwaysUseSynCookies := func() bool {
opts := parseSynSegmentOptions(s)
useSynCookies, err := func() (bool, tcpip.Error) {
var alwaysUseSynCookies tcpip.TCPAlwaysUseSynCookies
if err := e.stack.TransportProtocolOption(header.TCPProtocolNumber, &alwaysUseSynCookies); err != nil {
panic(fmt.Sprintf("TransportProtocolOption(%d, %T) = %s", header.TCPProtocolNumber, alwaysUseSynCookies, err))
}
return bool(alwaysUseSynCookies)
}()
if alwaysUseSynCookies {
return true, nil
}
e.acceptMu.Lock()
defer e.acceptMu.Unlock()
opts := parseSynSegmentOptions(s)
if !alwaysUseSynCookies && !e.synRcvdBacklogFull() {
atomic.AddInt32(&e.synRcvdCount, 1)
ctx.pendingMu.Lock()
defer ctx.pendingMu.Unlock()
// The capacity of the accepted queue would always be one greater than the
// listen backlog. But, the SYNRCVD connections count is always checked
// against the listen backlog value for Linux parity reason.
// https://github.com/torvalds/linux/blob/7acac4b3196/include/net/inet_connection_sock.h#L280
if len(ctx.pendingEndpoints) == e.accepted.cap-1 {
return true, nil
}
h, err := ctx.startHandshake(s, opts, &waiter.Queue{}, e.owner)
if err != nil {
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
e.stats.FailedConnectionAttempts.Increment()
atomic.AddInt32(&e.synRcvdCount, -1)
return err
return false, err
}
ctx.pendingEndpoints[h.ep] = struct{}{}
ctx.pending.Add(1)
go func() {
defer func() {
ctx.pendingMu.Lock()
defer ctx.pendingMu.Unlock()
delete(ctx.pendingEndpoints, h.ep)
ctx.pending.Done()
}()
// Note that startHandshake returns a locked endpoint. The force call
// here just makes it so.
if err := h.complete(); err != nil { // +checklocksforce
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
e.stats.FailedConnectionAttempts.Increment()
ctx.cleanupFailedHandshake(h)
atomic.AddInt32(&e.synRcvdCount, -1)
return
}
ctx.cleanupCompletedHandshake(h)
@ -558,7 +525,6 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
}
e.accepted.endpoints.PushBack(h.ep)
atomic.AddInt32(&e.synRcvdCount, -1)
return true
}
}()
@ -570,6 +536,12 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
}
}()
return false, nil
}()
if err != nil {
return err
}
if !useSynCookies {
return nil
}
@ -780,7 +752,12 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) {
e.setEndpointState(StateClose)
// Close any endpoints in SYN-RCVD state.
ctx.closeAllPendingEndpoints()
ctx.pendingMu.Lock()
for n := range ctx.pendingEndpoints {
n.notifyProtocolGoroutine(notifyClose)
}
ctx.pendingMu.Unlock()
ctx.pending.Wait()
// Do cleanup if needed.
e.completeWorkerLocked()

View File

@ -508,10 +508,6 @@ type endpoint struct {
// and dropped when it is.
segmentQueue segmentQueue `state:"wait"`
// synRcvdCount is the number of connections for this endpoint that are
// in SYN-RCVD state; this is only accessed atomically.
synRcvdCount int32
// userMSS if non-zero is the MSS value explicitly set by the user
// for this endpoint using the TCP_MAXSEG setsockopt.
userMSS uint16