Fix for panic in endpoint.Close().

When sending a RST on shutdown we need to re-check the endpoint
state after acquiring the work mutex, because the endpoint could
have transitioned out of a connected state between the time we
first checked it and the time we acquired the work mutex.

I added two tests, but sadly neither reproduces the panic. I am
going to leave the tests in, as they are good to have anyway.

PiperOrigin-RevId: 292393800
This commit is contained in:
Bhasker Hariharan 2020-01-30 11:48:36 -08:00 committed by gVisor bot
parent 757b2b87fe
commit 4ee64a248e
5 changed files with 98 additions and 2 deletions

View File

@ -91,6 +91,7 @@ go_test(
tags = ["flaky"],
deps = [
":tcp",
"//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/checker",

View File

@ -2047,8 +2047,14 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
// work mutex is available.
if e.workMu.TryLock() {
e.mu.Lock()
e.resetConnectionLocked(tcpip.ErrConnectionAborted)
e.notifyProtocolGoroutine(notifyTickleWorker)
// We need to double check here to make
// sure worker has not transitioned the
// endpoint out of a connected state
// before trying to send a reset.
if e.EndpointState().connected() {
e.resetConnectionLocked(tcpip.ErrConnectionAborted)
e.notifyProtocolGoroutine(notifyTickleWorker)
}
e.mu.Unlock()
e.workMu.Unlock()
} else {

View File

@ -21,6 +21,7 @@ import (
"testing"
"time"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
@ -6913,3 +6914,57 @@ func TestTCPDeferAcceptTimeout(t *testing.T) {
checker.SeqNum(uint32(iss+1)),
checker.AckNum(uint32(irs+5))))
}
// TestResetDuringClose races an incoming RST from the peer against a local
// Close() on an endpoint that still has unread data queued. The test makes no
// assertion about which side wins the race; it only has to complete without
// panicking.
func TestResetDuringClose(t *testing.T) {
	c := context.New(t, defaultMTU)
	defer c.Cleanup()

	const rcvWnd = 30000
	iss := seqnum.Value(789)
	c.CreateConnected(iss, rcvWnd, -1 /* epRecvBuf */)

	// Queue four bytes on the endpoint and never read them, so that
	// c.EP.Close() has unread data and is expected to trigger a reset.
	irs := c.IRS
	dataHdr := &context.Headers{
		SrcPort: context.TestPort,
		DstPort: c.Port,
		Flags:   header.TCPFlagAck,
		SeqNum:  iss.Add(1),
		AckNum:  irs.Add(1),
		RcvWnd:  rcvWnd,
	}
	c.SendPacket([]byte{1, 2, 3, 4}, dataHdr)

	// The endpoint must acknowledge the four bytes before the race starts.
	checker.IPv4(t, c.GetPacket(), checker.TCP(
		checker.DstPort(context.TestPort),
		checker.TCPFlags(header.TCPFlagAck),
		checker.SeqNum(uint32(irs.Add(1))),
		checker.AckNum(uint32(iss.Add(5)))))

	// Run the peer's RST and the local Close() concurrently. Depending on
	// scheduling, either Close() sends an active RST itself or the worker
	// handles the injected RST first; neither ordering may panic because
	// the route has already been released.
	racers := []func(){
		func() {
			c.SendPacket(nil, &context.Headers{
				SrcPort: context.TestPort,
				DstPort: c.Port,
				SeqNum:  iss.Add(5),
				AckNum:  c.IRS.Add(5),
				RcvWnd:  rcvWnd,
				Flags:   header.TCPFlagRst,
			})
		},
		func() {
			c.EP.Close()
		},
	}

	var wg sync.WaitGroup
	wg.Add(len(racers))
	for _, racer := range racers {
		go func(run func()) {
			defer wg.Done()
			run()
		}(racer)
	}
	wg.Wait()
}

View File

@ -2173,6 +2173,7 @@ cc_library(
":socket_test_util",
"//test/util:test_util",
"//test/util:thread_util",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/time",
"@com_google_googletest//:gtest",
],

View File

@ -24,6 +24,7 @@
#include <sys/un.h>
#include "gtest/gtest.h"
#include "absl/memory/memory.h"
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "test/syscalls/linux/socket_test_util.h"
@ -875,5 +876,37 @@ TEST_P(TCPSocketPairTest, SetTCPUserTimeoutAboveZero) {
EXPECT_EQ(get, kAbove);
}
// Stress test: spins up many socket pairs at once and, on each pair, closes
// both ends around the same time while written data is still unread on the
// second socket. The test only checks that every syscall involved succeeds.
TEST_P(TCPSocketPairTest, TCPResetDuringClose_NoRandomSave) {
  DisableSave ds;  // Too many syscalls.
  constexpr int kThreadCount = 1000;
  std::unique_ptr<ScopedThread> workers[kThreadCount];
  for (auto& worker : workers) {
    worker = absl::make_unique<ScopedThread>([&]() {
      auto pair = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());

      ScopedThread closer([&]() {
        // Block until the second socket reports readable data or a hangup
        // (waiting at most 20 seconds), then close it without reading.
        constexpr int kPollTimeoutMs = 20000;
        struct pollfd pfd = {pair->second_fd(), POLLIN | POLLHUP, 0};
        ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, kPollTimeoutMs),
                    SyscallSucceedsWithValue(1));
        ASSERT_THAT(close(pair->release_second_fd()), SyscallSucceeds());
      });

      // Write a few bytes on the first socket, pause briefly, and then close
      // it while the closer thread is shutting down the other end.
      constexpr char kStr[] = "abc";
      ASSERT_THAT(write(pair->first_fd(), kStr, 3),
                  SyscallSucceedsWithValue(3));
      absl::SleepFor(absl::Milliseconds(10));
      ASSERT_THAT(close(pair->release_first_fd()), SyscallSucceeds());
      closer.Join();
    });
  }
  // Wait for every socket pair to finish before the test returns.
  for (auto& worker : workers) {
    worker->Join();
  }
}
} // namespace testing
} // namespace gvisor