Fix data-race when reading/writing e.amss.

PiperOrigin-RevId: 298451319
This commit is contained in:
Bhasker Hariharan 2020-03-02 14:43:52 -08:00 committed by gVisor bot
parent 8821a7104f
commit 3310175250
3 changed files with 42 additions and 13 deletions

View File

@ -295,6 +295,7 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error {
h.state = handshakeSynRcvd
h.ep.mu.Lock()
ttl := h.ep.ttl
amss := h.ep.amss
h.ep.setEndpointState(StateSynRecv)
h.ep.mu.Unlock()
synOpts := header.TCPSynOptions{
@ -307,7 +308,7 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error {
// permits SACK. This is not explicitly defined in the RFC but
// this is the behaviour implemented by Linux.
SACKPermitted: rcvSynOpts.SACKPermitted,
MSS: h.ep.amss,
MSS: amss,
}
if ttl == 0 {
ttl = s.route.DefaultTTL()
@ -356,6 +357,10 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
h.ep.mu.RLock()
amss := h.ep.amss
h.ep.mu.RUnlock()
h.resetState()
synOpts := header.TCPSynOptions{
WS: h.rcvWndScale,
@ -363,7 +368,7 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error {
TSVal: h.ep.timestamp(),
TSEcr: h.ep.recentTimestamp(),
SACKPermitted: h.ep.sackPermitted,
MSS: h.ep.amss,
MSS: amss,
}
h.ep.sendSynTCP(&s.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
return nil
@ -530,6 +535,7 @@ func (h *handshake) execute() *tcpip.Error {
// Send the initial SYN segment and loop until the handshake is
// completed.
h.ep.mu.Lock()
h.ep.amss = calculateAdvertisedMSS(h.ep.userMSS, h.ep.route)
synOpts := header.TCPSynOptions{
@ -540,6 +546,7 @@ func (h *handshake) execute() *tcpip.Error {
SACKPermitted: bool(sackEnabled),
MSS: h.ep.amss,
}
h.ep.mu.Unlock()
// Execute is also called in a listen context so we want to make sure we
// only send the TS/SACK option when we received the TS/SACK in the

View File

@ -959,15 +959,18 @@ func (e *endpoint) initialReceiveWindow() int {
// ModerateRecvBuf adjusts the receive buffer and the advertised window
// based on the number of bytes copied to user space.
func (e *endpoint) ModerateRecvBuf(copied int) {
e.mu.RLock()
e.rcvListMu.Lock()
if e.rcvAutoParams.disabled {
e.rcvListMu.Unlock()
e.mu.RUnlock()
return
}
now := time.Now()
if rtt := e.rcvAutoParams.rtt; rtt == 0 || now.Sub(e.rcvAutoParams.measureTime) < rtt {
e.rcvAutoParams.copied += copied
e.rcvListMu.Unlock()
e.mu.RUnlock()
return
}
prevRTTCopied := e.rcvAutoParams.copied + copied
@ -1008,7 +1011,7 @@ func (e *endpoint) ModerateRecvBuf(copied int) {
e.rcvBufSize = rcvWnd
availAfter := e.receiveBufferAvailableLocked()
mask := uint32(notifyReceiveWindowChanged)
if crossed, above := e.windowCrossedACKThreshold(availAfter - availBefore); crossed && above {
if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above {
mask |= notifyNonZeroReceiveWindow
}
e.notifyProtocolGoroutine(mask)
@ -1023,6 +1026,7 @@ func (e *endpoint) ModerateRecvBuf(copied int) {
e.rcvAutoParams.measureTime = now
e.rcvAutoParams.copied = 0
e.rcvListMu.Unlock()
e.mu.RUnlock()
}
// IPTables implements tcpip.Endpoint.IPTables.
@ -1052,7 +1056,6 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages,
v, err := e.readLocked()
e.rcvListMu.Unlock()
e.mu.RUnlock()
if err == tcpip.ErrClosedForReceive {
@ -1085,7 +1088,7 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) {
// enough buffer space, to either fit an aMSS or half a receive buffer
// (whichever smaller), then notify the protocol goroutine to send a
// window update.
if crossed, above := e.windowCrossedACKThreshold(len(v)); crossed && above {
if crossed, above := e.windowCrossedACKThresholdLocked(len(v)); crossed && above {
e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow)
}
@ -1303,9 +1306,9 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro
return num, tcpip.ControlMessages{}, nil
}
// windowCrossedACKThreshold checks if the receive window to be announced now
// would be under aMSS or under half receive buffer, whichever smaller. This is
// useful as a receive side silly window syndrome prevention mechanism. If
// windowCrossedACKThresholdLocked checks if the receive window to be announced
// now would be under aMSS or under half receive buffer, whichever smaller. This
// is useful as a receive side silly window syndrome prevention mechanism. If
// window grows to reasonable value, we should send ACK to the sender to inform
// the rx space is now large. We also want ensure a series of small read()'s
// won't trigger a flood of spurious tiny ACK's.
@ -1316,7 +1319,9 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro
// crossed will be true if the window size crossed the ACK threshold.
// above will be true if the new window is >= ACK threshold and false
// otherwise.
func (e *endpoint) windowCrossedACKThreshold(deltaBefore int) (crossed bool, above bool) {
//
// Precondition: e.mu and e.rcvListMu must be held.
func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int) (crossed bool, above bool) {
newAvail := e.receiveBufferAvailableLocked()
oldAvail := newAvail - deltaBefore
if oldAvail < 0 {
@ -1379,6 +1384,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
mask := uint32(notifyReceiveWindowChanged)
e.mu.RLock()
e.rcvListMu.Lock()
// Make sure the receive buffer size allows us to send a
@ -1405,11 +1411,11 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
// Immediately send an ACK to uncork the sender silly window
// syndrome prevetion, when our available space grows above aMSS
// or half receive buffer, whichever smaller.
if crossed, above := e.windowCrossedACKThreshold(availAfter - availBefore); crossed && above {
if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above {
mask |= notifyNonZeroReceiveWindow
}
e.rcvListMu.Unlock()
e.mu.RUnlock()
e.notifyProtocolGoroutine(mask)
return nil
@ -2414,13 +2420,14 @@ func (e *endpoint) updateSndBufferUsage(v int) {
// to be read, or when the connection is closed for receiving (in which case
// s will be nil).
func (e *endpoint) readyToRead(s *segment) {
e.mu.RLock()
e.rcvListMu.Lock()
if s != nil {
s.incRef()
e.rcvBufUsed += s.data.Size()
// Increase counter if the receive window falls down below MSS
// or half receive buffer size, whichever smaller.
if crossed, above := e.windowCrossedACKThreshold(-s.data.Size()); crossed && !above {
if crossed, above := e.windowCrossedACKThresholdLocked(-s.data.Size()); crossed && !above {
e.stats.ReceiveErrors.ZeroRcvWindowState.Increment()
}
e.rcvList.PushBack(s)
@ -2428,7 +2435,7 @@ func (e *endpoint) readyToRead(s *segment) {
e.rcvClosed = true
}
e.rcvListMu.Unlock()
e.mu.RUnlock()
e.waiterQueue.Notify(waiter.EventIn)
}

View File

@ -1349,6 +1349,21 @@ TEST_P(SimpleTcpSocketTest, RecvOnClosedSocket) {
SyscallFailsWithErrno(ENOTCONN));
}
TEST_P(SimpleTcpSocketTest, TCPConnectSoRcvBufRace) {
auto s = ASSERT_NO_ERRNO_AND_VALUE(
Socket(GetParam(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP));
sockaddr_storage addr =
ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam()));
socklen_t addrlen = sizeof(addr);
RetryEINTR(connect)(s.get(), reinterpret_cast<struct sockaddr*>(&addr),
addrlen);
int buf_sz = 1 << 18;
EXPECT_THAT(
setsockopt(s.get(), SOL_SOCKET, SO_RCVBUF, &buf_sz, sizeof(buf_sz)),
SyscallSucceedsWithValue(0));
}
INSTANTIATE_TEST_SUITE_P(AllInetTests, SimpleTcpSocketTest,
::testing::Values(AF_INET, AF_INET6));