Fix data-race when reading/writing e.amss.
PiperOrigin-RevId: 298451319
This commit is contained in:
parent
8821a7104f
commit
3310175250
|
@ -295,6 +295,7 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error {
|
|||
h.state = handshakeSynRcvd
|
||||
h.ep.mu.Lock()
|
||||
ttl := h.ep.ttl
|
||||
amss := h.ep.amss
|
||||
h.ep.setEndpointState(StateSynRecv)
|
||||
h.ep.mu.Unlock()
|
||||
synOpts := header.TCPSynOptions{
|
||||
|
@ -307,7 +308,7 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error {
|
|||
// permits SACK. This is not explicitly defined in the RFC but
|
||||
// this is the behaviour implemented by Linux.
|
||||
SACKPermitted: rcvSynOpts.SACKPermitted,
|
||||
MSS: h.ep.amss,
|
||||
MSS: amss,
|
||||
}
|
||||
if ttl == 0 {
|
||||
ttl = s.route.DefaultTTL()
|
||||
|
@ -356,6 +357,10 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error {
|
|||
return tcpip.ErrInvalidEndpointState
|
||||
}
|
||||
|
||||
h.ep.mu.RLock()
|
||||
amss := h.ep.amss
|
||||
h.ep.mu.RUnlock()
|
||||
|
||||
h.resetState()
|
||||
synOpts := header.TCPSynOptions{
|
||||
WS: h.rcvWndScale,
|
||||
|
@ -363,7 +368,7 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error {
|
|||
TSVal: h.ep.timestamp(),
|
||||
TSEcr: h.ep.recentTimestamp(),
|
||||
SACKPermitted: h.ep.sackPermitted,
|
||||
MSS: h.ep.amss,
|
||||
MSS: amss,
|
||||
}
|
||||
h.ep.sendSynTCP(&s.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
|
||||
return nil
|
||||
|
@ -530,6 +535,7 @@ func (h *handshake) execute() *tcpip.Error {
|
|||
|
||||
// Send the initial SYN segment and loop until the handshake is
|
||||
// completed.
|
||||
h.ep.mu.Lock()
|
||||
h.ep.amss = calculateAdvertisedMSS(h.ep.userMSS, h.ep.route)
|
||||
|
||||
synOpts := header.TCPSynOptions{
|
||||
|
@ -540,6 +546,7 @@ func (h *handshake) execute() *tcpip.Error {
|
|||
SACKPermitted: bool(sackEnabled),
|
||||
MSS: h.ep.amss,
|
||||
}
|
||||
h.ep.mu.Unlock()
|
||||
|
||||
// Execute is also called in a listen context so we want to make sure we
|
||||
// only send the TS/SACK option when we received the TS/SACK in the
|
||||
|
|
|
@ -959,15 +959,18 @@ func (e *endpoint) initialReceiveWindow() int {
|
|||
// ModerateRecvBuf adjusts the receive buffer and the advertised window
|
||||
// based on the number of bytes copied to user space.
|
||||
func (e *endpoint) ModerateRecvBuf(copied int) {
|
||||
e.mu.RLock()
|
||||
e.rcvListMu.Lock()
|
||||
if e.rcvAutoParams.disabled {
|
||||
e.rcvListMu.Unlock()
|
||||
e.mu.RUnlock()
|
||||
return
|
||||
}
|
||||
now := time.Now()
|
||||
if rtt := e.rcvAutoParams.rtt; rtt == 0 || now.Sub(e.rcvAutoParams.measureTime) < rtt {
|
||||
e.rcvAutoParams.copied += copied
|
||||
e.rcvListMu.Unlock()
|
||||
e.mu.RUnlock()
|
||||
return
|
||||
}
|
||||
prevRTTCopied := e.rcvAutoParams.copied + copied
|
||||
|
@ -1008,7 +1011,7 @@ func (e *endpoint) ModerateRecvBuf(copied int) {
|
|||
e.rcvBufSize = rcvWnd
|
||||
availAfter := e.receiveBufferAvailableLocked()
|
||||
mask := uint32(notifyReceiveWindowChanged)
|
||||
if crossed, above := e.windowCrossedACKThreshold(availAfter - availBefore); crossed && above {
|
||||
if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above {
|
||||
mask |= notifyNonZeroReceiveWindow
|
||||
}
|
||||
e.notifyProtocolGoroutine(mask)
|
||||
|
@ -1023,6 +1026,7 @@ func (e *endpoint) ModerateRecvBuf(copied int) {
|
|||
e.rcvAutoParams.measureTime = now
|
||||
e.rcvAutoParams.copied = 0
|
||||
e.rcvListMu.Unlock()
|
||||
e.mu.RUnlock()
|
||||
}
|
||||
|
||||
// IPTables implements tcpip.Endpoint.IPTables.
|
||||
|
@ -1052,7 +1056,6 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages,
|
|||
|
||||
v, err := e.readLocked()
|
||||
e.rcvListMu.Unlock()
|
||||
|
||||
e.mu.RUnlock()
|
||||
|
||||
if err == tcpip.ErrClosedForReceive {
|
||||
|
@ -1085,7 +1088,7 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) {
|
|||
// enough buffer space, to either fit an aMSS or half a receive buffer
|
||||
// (whichever smaller), then notify the protocol goroutine to send a
|
||||
// window update.
|
||||
if crossed, above := e.windowCrossedACKThreshold(len(v)); crossed && above {
|
||||
if crossed, above := e.windowCrossedACKThresholdLocked(len(v)); crossed && above {
|
||||
e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow)
|
||||
}
|
||||
|
||||
|
@ -1303,9 +1306,9 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro
|
|||
return num, tcpip.ControlMessages{}, nil
|
||||
}
|
||||
|
||||
// windowCrossedACKThreshold checks if the receive window to be announced now
|
||||
// would be under aMSS or under half receive buffer, whichever smaller. This is
|
||||
// useful as a receive side silly window syndrome prevention mechanism. If
|
||||
// windowCrossedACKThresholdLocked checks if the receive window to be announced
|
||||
// now would be under aMSS or under half receive buffer, whichever smaller. This
|
||||
// is useful as a receive side silly window syndrome prevention mechanism. If
|
||||
// window grows to reasonable value, we should send ACK to the sender to inform
|
||||
// the rx space is now large. We also want ensure a series of small read()'s
|
||||
// won't trigger a flood of spurious tiny ACK's.
|
||||
|
@ -1316,7 +1319,9 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro
|
|||
// crossed will be true if the window size crossed the ACK threshold.
|
||||
// above will be true if the new window is >= ACK threshold and false
|
||||
// otherwise.
|
||||
func (e *endpoint) windowCrossedACKThreshold(deltaBefore int) (crossed bool, above bool) {
|
||||
//
|
||||
// Precondition: e.mu and e.rcvListMu must be held.
|
||||
func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int) (crossed bool, above bool) {
|
||||
newAvail := e.receiveBufferAvailableLocked()
|
||||
oldAvail := newAvail - deltaBefore
|
||||
if oldAvail < 0 {
|
||||
|
@ -1379,6 +1384,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
|
|||
|
||||
mask := uint32(notifyReceiveWindowChanged)
|
||||
|
||||
e.mu.RLock()
|
||||
e.rcvListMu.Lock()
|
||||
|
||||
// Make sure the receive buffer size allows us to send a
|
||||
|
@ -1405,11 +1411,11 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
|
|||
// Immediately send an ACK to uncork the sender silly window
|
||||
// syndrome prevetion, when our available space grows above aMSS
|
||||
// or half receive buffer, whichever smaller.
|
||||
if crossed, above := e.windowCrossedACKThreshold(availAfter - availBefore); crossed && above {
|
||||
if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above {
|
||||
mask |= notifyNonZeroReceiveWindow
|
||||
}
|
||||
e.rcvListMu.Unlock()
|
||||
|
||||
e.mu.RUnlock()
|
||||
e.notifyProtocolGoroutine(mask)
|
||||
return nil
|
||||
|
||||
|
@ -2414,13 +2420,14 @@ func (e *endpoint) updateSndBufferUsage(v int) {
|
|||
// to be read, or when the connection is closed for receiving (in which case
|
||||
// s will be nil).
|
||||
func (e *endpoint) readyToRead(s *segment) {
|
||||
e.mu.RLock()
|
||||
e.rcvListMu.Lock()
|
||||
if s != nil {
|
||||
s.incRef()
|
||||
e.rcvBufUsed += s.data.Size()
|
||||
// Increase counter if the receive window falls down below MSS
|
||||
// or half receive buffer size, whichever smaller.
|
||||
if crossed, above := e.windowCrossedACKThreshold(-s.data.Size()); crossed && !above {
|
||||
if crossed, above := e.windowCrossedACKThresholdLocked(-s.data.Size()); crossed && !above {
|
||||
e.stats.ReceiveErrors.ZeroRcvWindowState.Increment()
|
||||
}
|
||||
e.rcvList.PushBack(s)
|
||||
|
@ -2428,7 +2435,7 @@ func (e *endpoint) readyToRead(s *segment) {
|
|||
e.rcvClosed = true
|
||||
}
|
||||
e.rcvListMu.Unlock()
|
||||
|
||||
e.mu.RUnlock()
|
||||
e.waiterQueue.Notify(waiter.EventIn)
|
||||
}
|
||||
|
||||
|
|
|
@ -1349,6 +1349,21 @@ TEST_P(SimpleTcpSocketTest, RecvOnClosedSocket) {
|
|||
SyscallFailsWithErrno(ENOTCONN));
|
||||
}
|
||||
|
||||
TEST_P(SimpleTcpSocketTest, TCPConnectSoRcvBufRace) {
|
||||
auto s = ASSERT_NO_ERRNO_AND_VALUE(
|
||||
Socket(GetParam(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP));
|
||||
sockaddr_storage addr =
|
||||
ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam()));
|
||||
socklen_t addrlen = sizeof(addr);
|
||||
|
||||
RetryEINTR(connect)(s.get(), reinterpret_cast<struct sockaddr*>(&addr),
|
||||
addrlen);
|
||||
int buf_sz = 1 << 18;
|
||||
EXPECT_THAT(
|
||||
setsockopt(s.get(), SOL_SOCKET, SO_RCVBUF, &buf_sz, sizeof(buf_sz)),
|
||||
SyscallSucceedsWithValue(0));
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(AllInetTests, SimpleTcpSocketTest,
|
||||
::testing::Values(AF_INET, AF_INET6));
|
||||
|
||||
|
|
Loading…
Reference in New Issue