Fix data-race when reading/writing e.amss.
PiperOrigin-RevId: 298451319
This commit is contained in:
parent
8821a7104f
commit
3310175250
|
@ -295,6 +295,7 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error {
|
||||||
h.state = handshakeSynRcvd
|
h.state = handshakeSynRcvd
|
||||||
h.ep.mu.Lock()
|
h.ep.mu.Lock()
|
||||||
ttl := h.ep.ttl
|
ttl := h.ep.ttl
|
||||||
|
amss := h.ep.amss
|
||||||
h.ep.setEndpointState(StateSynRecv)
|
h.ep.setEndpointState(StateSynRecv)
|
||||||
h.ep.mu.Unlock()
|
h.ep.mu.Unlock()
|
||||||
synOpts := header.TCPSynOptions{
|
synOpts := header.TCPSynOptions{
|
||||||
|
@ -307,7 +308,7 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error {
|
||||||
// permits SACK. This is not explicitly defined in the RFC but
|
// permits SACK. This is not explicitly defined in the RFC but
|
||||||
// this is the behaviour implemented by Linux.
|
// this is the behaviour implemented by Linux.
|
||||||
SACKPermitted: rcvSynOpts.SACKPermitted,
|
SACKPermitted: rcvSynOpts.SACKPermitted,
|
||||||
MSS: h.ep.amss,
|
MSS: amss,
|
||||||
}
|
}
|
||||||
if ttl == 0 {
|
if ttl == 0 {
|
||||||
ttl = s.route.DefaultTTL()
|
ttl = s.route.DefaultTTL()
|
||||||
|
@ -356,6 +357,10 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error {
|
||||||
return tcpip.ErrInvalidEndpointState
|
return tcpip.ErrInvalidEndpointState
|
||||||
}
|
}
|
||||||
|
|
||||||
|
h.ep.mu.RLock()
|
||||||
|
amss := h.ep.amss
|
||||||
|
h.ep.mu.RUnlock()
|
||||||
|
|
||||||
h.resetState()
|
h.resetState()
|
||||||
synOpts := header.TCPSynOptions{
|
synOpts := header.TCPSynOptions{
|
||||||
WS: h.rcvWndScale,
|
WS: h.rcvWndScale,
|
||||||
|
@ -363,7 +368,7 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error {
|
||||||
TSVal: h.ep.timestamp(),
|
TSVal: h.ep.timestamp(),
|
||||||
TSEcr: h.ep.recentTimestamp(),
|
TSEcr: h.ep.recentTimestamp(),
|
||||||
SACKPermitted: h.ep.sackPermitted,
|
SACKPermitted: h.ep.sackPermitted,
|
||||||
MSS: h.ep.amss,
|
MSS: amss,
|
||||||
}
|
}
|
||||||
h.ep.sendSynTCP(&s.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
|
h.ep.sendSynTCP(&s.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
|
||||||
return nil
|
return nil
|
||||||
|
@ -530,6 +535,7 @@ func (h *handshake) execute() *tcpip.Error {
|
||||||
|
|
||||||
// Send the initial SYN segment and loop until the handshake is
|
// Send the initial SYN segment and loop until the handshake is
|
||||||
// completed.
|
// completed.
|
||||||
|
h.ep.mu.Lock()
|
||||||
h.ep.amss = calculateAdvertisedMSS(h.ep.userMSS, h.ep.route)
|
h.ep.amss = calculateAdvertisedMSS(h.ep.userMSS, h.ep.route)
|
||||||
|
|
||||||
synOpts := header.TCPSynOptions{
|
synOpts := header.TCPSynOptions{
|
||||||
|
@ -540,6 +546,7 @@ func (h *handshake) execute() *tcpip.Error {
|
||||||
SACKPermitted: bool(sackEnabled),
|
SACKPermitted: bool(sackEnabled),
|
||||||
MSS: h.ep.amss,
|
MSS: h.ep.amss,
|
||||||
}
|
}
|
||||||
|
h.ep.mu.Unlock()
|
||||||
|
|
||||||
// Execute is also called in a listen context so we want to make sure we
|
// Execute is also called in a listen context so we want to make sure we
|
||||||
// only send the TS/SACK option when we received the TS/SACK in the
|
// only send the TS/SACK option when we received the TS/SACK in the
|
||||||
|
|
|
@ -959,15 +959,18 @@ func (e *endpoint) initialReceiveWindow() int {
|
||||||
// ModerateRecvBuf adjusts the receive buffer and the advertised window
|
// ModerateRecvBuf adjusts the receive buffer and the advertised window
|
||||||
// based on the number of bytes copied to user space.
|
// based on the number of bytes copied to user space.
|
||||||
func (e *endpoint) ModerateRecvBuf(copied int) {
|
func (e *endpoint) ModerateRecvBuf(copied int) {
|
||||||
|
e.mu.RLock()
|
||||||
e.rcvListMu.Lock()
|
e.rcvListMu.Lock()
|
||||||
if e.rcvAutoParams.disabled {
|
if e.rcvAutoParams.disabled {
|
||||||
e.rcvListMu.Unlock()
|
e.rcvListMu.Unlock()
|
||||||
|
e.mu.RUnlock()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
if rtt := e.rcvAutoParams.rtt; rtt == 0 || now.Sub(e.rcvAutoParams.measureTime) < rtt {
|
if rtt := e.rcvAutoParams.rtt; rtt == 0 || now.Sub(e.rcvAutoParams.measureTime) < rtt {
|
||||||
e.rcvAutoParams.copied += copied
|
e.rcvAutoParams.copied += copied
|
||||||
e.rcvListMu.Unlock()
|
e.rcvListMu.Unlock()
|
||||||
|
e.mu.RUnlock()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
prevRTTCopied := e.rcvAutoParams.copied + copied
|
prevRTTCopied := e.rcvAutoParams.copied + copied
|
||||||
|
@ -1008,7 +1011,7 @@ func (e *endpoint) ModerateRecvBuf(copied int) {
|
||||||
e.rcvBufSize = rcvWnd
|
e.rcvBufSize = rcvWnd
|
||||||
availAfter := e.receiveBufferAvailableLocked()
|
availAfter := e.receiveBufferAvailableLocked()
|
||||||
mask := uint32(notifyReceiveWindowChanged)
|
mask := uint32(notifyReceiveWindowChanged)
|
||||||
if crossed, above := e.windowCrossedACKThreshold(availAfter - availBefore); crossed && above {
|
if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above {
|
||||||
mask |= notifyNonZeroReceiveWindow
|
mask |= notifyNonZeroReceiveWindow
|
||||||
}
|
}
|
||||||
e.notifyProtocolGoroutine(mask)
|
e.notifyProtocolGoroutine(mask)
|
||||||
|
@ -1023,6 +1026,7 @@ func (e *endpoint) ModerateRecvBuf(copied int) {
|
||||||
e.rcvAutoParams.measureTime = now
|
e.rcvAutoParams.measureTime = now
|
||||||
e.rcvAutoParams.copied = 0
|
e.rcvAutoParams.copied = 0
|
||||||
e.rcvListMu.Unlock()
|
e.rcvListMu.Unlock()
|
||||||
|
e.mu.RUnlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
// IPTables implements tcpip.Endpoint.IPTables.
|
// IPTables implements tcpip.Endpoint.IPTables.
|
||||||
|
@ -1052,7 +1056,6 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages,
|
||||||
|
|
||||||
v, err := e.readLocked()
|
v, err := e.readLocked()
|
||||||
e.rcvListMu.Unlock()
|
e.rcvListMu.Unlock()
|
||||||
|
|
||||||
e.mu.RUnlock()
|
e.mu.RUnlock()
|
||||||
|
|
||||||
if err == tcpip.ErrClosedForReceive {
|
if err == tcpip.ErrClosedForReceive {
|
||||||
|
@ -1085,7 +1088,7 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) {
|
||||||
// enough buffer space, to either fit an aMSS or half a receive buffer
|
// enough buffer space, to either fit an aMSS or half a receive buffer
|
||||||
// (whichever smaller), then notify the protocol goroutine to send a
|
// (whichever smaller), then notify the protocol goroutine to send a
|
||||||
// window update.
|
// window update.
|
||||||
if crossed, above := e.windowCrossedACKThreshold(len(v)); crossed && above {
|
if crossed, above := e.windowCrossedACKThresholdLocked(len(v)); crossed && above {
|
||||||
e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow)
|
e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1303,9 +1306,9 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro
|
||||||
return num, tcpip.ControlMessages{}, nil
|
return num, tcpip.ControlMessages{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// windowCrossedACKThreshold checks if the receive window to be announced now
|
// windowCrossedACKThresholdLocked checks if the receive window to be announced
|
||||||
// would be under aMSS or under half receive buffer, whichever smaller. This is
|
// now would be under aMSS or under half receive buffer, whichever smaller. This
|
||||||
// useful as a receive side silly window syndrome prevention mechanism. If
|
// is useful as a receive side silly window syndrome prevention mechanism. If
|
||||||
// window grows to reasonable value, we should send ACK to the sender to inform
|
// window grows to reasonable value, we should send ACK to the sender to inform
|
||||||
// the rx space is now large. We also want ensure a series of small read()'s
|
// the rx space is now large. We also want ensure a series of small read()'s
|
||||||
// won't trigger a flood of spurious tiny ACK's.
|
// won't trigger a flood of spurious tiny ACK's.
|
||||||
|
@ -1316,7 +1319,9 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro
|
||||||
// crossed will be true if the window size crossed the ACK threshold.
|
// crossed will be true if the window size crossed the ACK threshold.
|
||||||
// above will be true if the new window is >= ACK threshold and false
|
// above will be true if the new window is >= ACK threshold and false
|
||||||
// otherwise.
|
// otherwise.
|
||||||
func (e *endpoint) windowCrossedACKThreshold(deltaBefore int) (crossed bool, above bool) {
|
//
|
||||||
|
// Precondition: e.mu and e.rcvListMu must be held.
|
||||||
|
func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int) (crossed bool, above bool) {
|
||||||
newAvail := e.receiveBufferAvailableLocked()
|
newAvail := e.receiveBufferAvailableLocked()
|
||||||
oldAvail := newAvail - deltaBefore
|
oldAvail := newAvail - deltaBefore
|
||||||
if oldAvail < 0 {
|
if oldAvail < 0 {
|
||||||
|
@ -1379,6 +1384,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
|
||||||
|
|
||||||
mask := uint32(notifyReceiveWindowChanged)
|
mask := uint32(notifyReceiveWindowChanged)
|
||||||
|
|
||||||
|
e.mu.RLock()
|
||||||
e.rcvListMu.Lock()
|
e.rcvListMu.Lock()
|
||||||
|
|
||||||
// Make sure the receive buffer size allows us to send a
|
// Make sure the receive buffer size allows us to send a
|
||||||
|
@ -1405,11 +1411,11 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
|
||||||
// Immediately send an ACK to uncork the sender silly window
|
// Immediately send an ACK to uncork the sender silly window
|
||||||
// syndrome prevetion, when our available space grows above aMSS
|
// syndrome prevetion, when our available space grows above aMSS
|
||||||
// or half receive buffer, whichever smaller.
|
// or half receive buffer, whichever smaller.
|
||||||
if crossed, above := e.windowCrossedACKThreshold(availAfter - availBefore); crossed && above {
|
if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above {
|
||||||
mask |= notifyNonZeroReceiveWindow
|
mask |= notifyNonZeroReceiveWindow
|
||||||
}
|
}
|
||||||
e.rcvListMu.Unlock()
|
e.rcvListMu.Unlock()
|
||||||
|
e.mu.RUnlock()
|
||||||
e.notifyProtocolGoroutine(mask)
|
e.notifyProtocolGoroutine(mask)
|
||||||
return nil
|
return nil
|
||||||
|
|
||||||
|
@ -2414,13 +2420,14 @@ func (e *endpoint) updateSndBufferUsage(v int) {
|
||||||
// to be read, or when the connection is closed for receiving (in which case
|
// to be read, or when the connection is closed for receiving (in which case
|
||||||
// s will be nil).
|
// s will be nil).
|
||||||
func (e *endpoint) readyToRead(s *segment) {
|
func (e *endpoint) readyToRead(s *segment) {
|
||||||
|
e.mu.RLock()
|
||||||
e.rcvListMu.Lock()
|
e.rcvListMu.Lock()
|
||||||
if s != nil {
|
if s != nil {
|
||||||
s.incRef()
|
s.incRef()
|
||||||
e.rcvBufUsed += s.data.Size()
|
e.rcvBufUsed += s.data.Size()
|
||||||
// Increase counter if the receive window falls down below MSS
|
// Increase counter if the receive window falls down below MSS
|
||||||
// or half receive buffer size, whichever smaller.
|
// or half receive buffer size, whichever smaller.
|
||||||
if crossed, above := e.windowCrossedACKThreshold(-s.data.Size()); crossed && !above {
|
if crossed, above := e.windowCrossedACKThresholdLocked(-s.data.Size()); crossed && !above {
|
||||||
e.stats.ReceiveErrors.ZeroRcvWindowState.Increment()
|
e.stats.ReceiveErrors.ZeroRcvWindowState.Increment()
|
||||||
}
|
}
|
||||||
e.rcvList.PushBack(s)
|
e.rcvList.PushBack(s)
|
||||||
|
@ -2428,7 +2435,7 @@ func (e *endpoint) readyToRead(s *segment) {
|
||||||
e.rcvClosed = true
|
e.rcvClosed = true
|
||||||
}
|
}
|
||||||
e.rcvListMu.Unlock()
|
e.rcvListMu.Unlock()
|
||||||
|
e.mu.RUnlock()
|
||||||
e.waiterQueue.Notify(waiter.EventIn)
|
e.waiterQueue.Notify(waiter.EventIn)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1349,6 +1349,21 @@ TEST_P(SimpleTcpSocketTest, RecvOnClosedSocket) {
|
||||||
SyscallFailsWithErrno(ENOTCONN));
|
SyscallFailsWithErrno(ENOTCONN));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_P(SimpleTcpSocketTest, TCPConnectSoRcvBufRace) {
|
||||||
|
auto s = ASSERT_NO_ERRNO_AND_VALUE(
|
||||||
|
Socket(GetParam(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP));
|
||||||
|
sockaddr_storage addr =
|
||||||
|
ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam()));
|
||||||
|
socklen_t addrlen = sizeof(addr);
|
||||||
|
|
||||||
|
RetryEINTR(connect)(s.get(), reinterpret_cast<struct sockaddr*>(&addr),
|
||||||
|
addrlen);
|
||||||
|
int buf_sz = 1 << 18;
|
||||||
|
EXPECT_THAT(
|
||||||
|
setsockopt(s.get(), SOL_SOCKET, SO_RCVBUF, &buf_sz, sizeof(buf_sz)),
|
||||||
|
SyscallSucceedsWithValue(0));
|
||||||
|
}
|
||||||
|
|
||||||
INSTANTIATE_TEST_SUITE_P(AllInetTests, SimpleTcpSocketTest,
|
INSTANTIATE_TEST_SUITE_P(AllInetTests, SimpleTcpSocketTest,
|
||||||
::testing::Values(AF_INET, AF_INET6));
|
::testing::Values(AF_INET, AF_INET6));
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue