Fix for a panic due to writing to a closed accept channel.

This can happen because endpoint.Close() closes the accept channel first and
then drains/resets any accepted but not delivered connections. But there can be
connections that are connected but not delivered to the channel as the channel
was full. But closing the channel can cause these writes to fail with a write to
a closed channel.

The correct solution is to abort any connections in SYN-RCVD state and
drain/abort all completed connections before closing the accept channel.

PiperOrigin-RevId: 261951132
This commit is contained in:
Bhasker Hariharan 2019-08-06 10:59:49 -07:00 committed by gVisor bot
parent 704f9610f3
commit dfbc0b0a4c
5 changed files with 261 additions and 28 deletions

View File

@ -96,6 +96,17 @@ type listenContext struct {
hasher hash.Hash
v6only bool
netProto tcpip.NetworkProtocolNumber
// pendingMu protects pendingEndpoints. This should only be accessed
// by the listening endpoint's worker goroutine.
//
// Lock Ordering: listenEP.workerMu -> pendingMu
pendingMu sync.Mutex
// pending is used to wait for all pendingEndpoints to finish when
// a socket is closed.
pending sync.WaitGroup
// pendingEndpoints is a map of all endpoints for which a handshake is
// in progress.
pendingEndpoints map[stack.TransportEndpointID]*endpoint
}
// timeStamp returns an 8-bit timestamp with a granularity of 64 seconds.
@ -133,14 +144,15 @@ func decSynRcvdCount() {
}
// newListenContext creates a new listen context.
func newListenContext(stack *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size, v6only bool, netProto tcpip.NetworkProtocolNumber) *listenContext {
func newListenContext(stk *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size, v6only bool, netProto tcpip.NetworkProtocolNumber) *listenContext {
l := &listenContext{
stack: stack,
rcvWnd: rcvWnd,
hasher: sha1.New(),
v6only: v6only,
netProto: netProto,
listenEP: listenEP,
stack: stk,
rcvWnd: rcvWnd,
hasher: sha1.New(),
v6only: v6only,
netProto: netProto,
listenEP: listenEP,
pendingEndpoints: make(map[stack.TransportEndpointID]*endpoint),
}
rand.Read(l.nonce[0][:])
@ -253,6 +265,17 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
return nil, err
}
// listenEP is nil when listenContext is used by tcp.Forwarder.
if l.listenEP != nil {
l.listenEP.mu.Lock()
if l.listenEP.state != StateListen {
l.listenEP.mu.Unlock()
return nil, tcpip.ErrConnectionAborted
}
l.addPendingEndpoint(ep)
l.listenEP.mu.Unlock()
}
// Perform the 3-way handshake.
h := newHandshake(ep, seqnum.Size(ep.initialReceiveWindow()))
@ -260,6 +283,9 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
if err := h.execute(); err != nil {
ep.stack.Stats().TCP.FailedConnectionAttempts.Increment()
ep.Close()
if l.listenEP != nil {
l.removePendingEndpoint(ep)
}
return nil, err
}
ep.mu.Lock()
@ -274,15 +300,41 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
return ep, nil
}
func (l *listenContext) addPendingEndpoint(n *endpoint) {
l.pendingMu.Lock()
l.pendingEndpoints[n.id] = n
l.pending.Add(1)
l.pendingMu.Unlock()
}
func (l *listenContext) removePendingEndpoint(n *endpoint) {
l.pendingMu.Lock()
delete(l.pendingEndpoints, n.id)
l.pending.Done()
l.pendingMu.Unlock()
}
func (l *listenContext) closeAllPendingEndpoints() {
l.pendingMu.Lock()
for _, n := range l.pendingEndpoints {
n.notifyProtocolGoroutine(notifyClose)
}
l.pendingMu.Unlock()
l.pending.Wait()
}
// deliverAccepted delivers the newly-accepted endpoint to the listener. If the
// endpoint has transitioned out of the listen state, the new endpoint is closed
// instead.
func (e *endpoint) deliverAccepted(n *endpoint) {
e.mu.RLock()
e.mu.Lock()
state := e.state
e.mu.RUnlock()
e.pendingAccepted.Add(1)
defer e.pendingAccepted.Done()
acceptedChan := e.acceptedChan
e.mu.Unlock()
if state == StateListen {
e.acceptedChan <- n
acceptedChan <- n
e.waiterQueue.Notify(waiter.EventIn)
} else {
n.Close()
@ -304,7 +356,7 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
return
}
ctx.removePendingEndpoint(n)
e.deliverAccepted(n)
}
@ -451,6 +503,11 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
// protocolListenLoop is the main loop of a listening TCP endpoint. It runs in
// its own goroutine and is responsible for handling connection requests.
func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
e.mu.Lock()
v6only := e.v6only
e.mu.Unlock()
ctx := newListenContext(e.stack, e, rcvWnd, v6only, e.netProto)
defer func() {
// Mark endpoint as closed. This will prevent goroutines running
// handleSynSegment() from attempting to queue new connections
@ -458,6 +515,9 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
e.mu.Lock()
e.state = StateClose
// close any endpoints in SYN-RCVD state.
ctx.closeAllPendingEndpoints()
// Do cleanup if needed.
e.completeWorkerLocked()
@ -470,12 +530,6 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut)
}()
e.mu.Lock()
v6only := e.v6only
e.mu.Unlock()
ctx := newListenContext(e.stack, e, rcvWnd, v6only, e.netProto)
s := sleep.Sleeper{}
s.AddWaker(&e.notificationWaker, wakerForNotification)
s.AddWaker(&e.newSegmentWaker, wakerForNewSegment)
@ -492,7 +546,6 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
e.handleListenSegment(ctx, s)
s.decRef()
}
synRcvdCount.pending.Wait()
close(e.drainDone)
<-e.undrain
}

View File

@ -570,3 +570,89 @@ func TestV4AcceptOnV4(t *testing.T) {
// Test acceptance.
testV4Accept(t, c)
}
func testV4ListenClose(t *testing.T, c *context.Context) {
// Set the SynRcvd threshold to zero to force a syn cookie based accept
// to happen.
saved := tcp.SynRcvdCountThreshold
defer func() {
tcp.SynRcvdCountThreshold = saved
}()
tcp.SynRcvdCountThreshold = 0
const n = uint16(32)
// Start listening.
if err := c.EP.Listen(int(tcp.SynRcvdCountThreshold + 1)); err != nil {
t.Fatalf("Listen failed: %v", err)
}
irs := seqnum.Value(789)
for i := uint16(0); i < n; i++ {
// Send a SYN request.
c.SendPacket(nil, &context.Headers{
SrcPort: context.TestPort + i,
DstPort: context.StackPort,
Flags: header.TCPFlagSyn,
SeqNum: irs,
RcvWnd: 30000,
})
}
// Each of these ACK's will cause a syn-cookie based connection to be
// accepted and delivered to the listening endpoint.
for i := uint16(0); i < n; i++ {
b := c.GetPacket()
tcp := header.TCP(header.IPv4(b).Payload())
iss := seqnum.Value(tcp.SequenceNumber())
// Send ACK.
c.SendPacket(nil, &context.Headers{
SrcPort: tcp.DestinationPort(),
DstPort: context.StackPort,
Flags: header.TCPFlagAck,
SeqNum: irs + 1,
AckNum: iss + 1,
RcvWnd: 30000,
})
}
// Try to accept the connection.
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
defer c.WQ.EventUnregister(&we)
nep, _, err := c.EP.Accept()
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
nep, _, err = c.EP.Accept()
if err != nil {
t.Fatalf("Accept failed: %v", err)
}
case <-time.After(10 * time.Second):
t.Fatalf("Timed out waiting for accept")
}
}
nep.Close()
c.EP.Close()
}
func TestV4ListenCloseOnV4(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
// Create TCP endpoint.
var err *tcpip.Error
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
if err != nil {
t.Fatalf("NewEndpoint failed: %v", err)
}
// Bind to wildcard.
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
t.Fatalf("Bind failed: %v", err)
}
// Test acceptance.
testV4ListenClose(t, c)
}

View File

@ -362,6 +362,12 @@ type endpoint struct {
// without hearing a response, the connection is closed.
keepalive keepalive
// pendingAccepted is a synchronization primitive used to track number
// of connections that are queued up to be delivered to the accepted
// channel. We use this to ensure that all goroutines blocked on writing
// to the acceptedChan below terminate before we close acceptedChan.
pendingAccepted sync.WaitGroup `state:"nosave"`
// acceptedChan is used by a listening endpoint protocol goroutine to
// send newly accepted connections to the endpoint so that they can be
// read by Accept() calls.
@ -375,7 +381,11 @@ type endpoint struct {
// The goroutine drain completion notification channel.
drainDone chan struct{} `state:"nosave"`
// The goroutine undrain notification channel.
// The goroutine undrain notification channel. This is currently used as
// a way to block the worker goroutines. Today nothing closes/writes
// this channel and this causes any goroutines waiting on this to just
// block. This is used during save/restore to prevent worker goroutines
// from mutating state as it's being saved.
undrain chan struct{} `state:"nosave"`
// probe if not nil is invoked on every received segment. It is passed
@ -575,6 +585,34 @@ func (e *endpoint) Close() {
e.mu.Unlock()
}
// closePendingAcceptableConnections closes all connections that have completed
// handshake but not yet been delivered to the application.
func (e *endpoint) closePendingAcceptableConnectionsLocked() {
done := make(chan struct{})
// Spin a goroutine up as ranging on e.acceptedChan will just block when
// there are no more connections in the channel. Using a non-blocking
// select does not work as it can potentially select the default case
// even when there are pending writes but that are not yet written to
// the channel.
go func() {
defer close(done)
for n := range e.acceptedChan {
n.mu.Lock()
n.resetConnectionLocked(tcpip.ErrConnectionAborted)
n.mu.Unlock()
n.Close()
}
}()
// pendingAccepted(see endpoint.deliverAccepted) tracks the number of
// endpoints which have completed handshake but are not yet written to
// the e.acceptedChan. We wait here till the goroutine above can drain
// all such connections from e.acceptedChan.
e.pendingAccepted.Wait()
close(e.acceptedChan)
<-done
e.acceptedChan = nil
}
// cleanupLocked frees all resources associated with the endpoint. It is called
// after Close() is called and the worker goroutine (if any) is done with its
// work.
@ -582,14 +620,7 @@ func (e *endpoint) cleanupLocked() {
// Close all endpoints that might have been accepted by TCP but not by
// the client.
if e.acceptedChan != nil {
close(e.acceptedChan)
for n := range e.acceptedChan {
n.mu.Lock()
n.resetConnectionLocked(tcpip.ErrConnectionAborted)
n.mu.Unlock()
n.Close()
}
e.acceptedChan = nil
e.closePendingAcceptableConnectionsLocked()
}
e.workerCleanup = false

View File

@ -440,7 +440,9 @@ syscall_test(
)
syscall_test(
size = "medium",
size = "large",
parallel = False,
shard_count = 10,
test = "//test/syscalls/linux:socket_inet_loopback_test",
)

View File

@ -145,6 +145,67 @@ TEST_P(SocketInetLoopbackTest, TCP) {
ASSERT_THAT(shutdown(conn_fd.get(), SHUT_RDWR), SyscallSucceeds());
}
TEST_P(SocketInetLoopbackTest, TCPListenClose) {
auto const& param = GetParam();
TestAddress const& listener = param.listener;
TestAddress const& connector = param.connector;
// Create the listening socket.
FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
sockaddr_storage listen_addr = listener.addr;
ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
listener.addr_len),
SyscallSucceeds());
ASSERT_THAT(listen(listen_fd.get(), 1001), SyscallSucceeds());
// Get the port bound by the listening socket.
socklen_t addrlen = listener.addr_len;
ASSERT_THAT(getsockname(listen_fd.get(),
reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
SyscallSucceeds());
uint16_t const port =
ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
DisableSave ds; // Too many system calls.
sockaddr_storage conn_addr = connector.addr;
ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
constexpr int kFDs = 2048;
constexpr int kThreadCount = 4;
constexpr int kFDsPerThread = kFDs / kThreadCount;
FileDescriptor clients[kFDs];
std::unique_ptr<ScopedThread> threads[kThreadCount];
for (int i = 0; i < kFDs; i++) {
clients[i] = ASSERT_NO_ERRNO_AND_VALUE(
Socket(connector.family(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP));
}
for (int i = 0; i < kThreadCount; i++) {
threads[i] = absl::make_unique<ScopedThread>([&connector, &conn_addr,
&clients, i]() {
for (int j = 0; j < kFDsPerThread; j++) {
int k = i * kFDsPerThread + j;
int ret =
connect(clients[k].get(), reinterpret_cast<sockaddr*>(&conn_addr),
connector.addr_len);
if (ret != 0) {
EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS));
}
}
});
}
for (int i = 0; i < kThreadCount; i++) {
threads[i]->Join();
}
for (int i = 0; i < 32; i++) {
auto accepted =
ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr));
}
// TODO(b/138400178): Fix cooperative S/R failure when ds.reset() is invoked
// before function end.
// ds.reset()
}
TEST_P(SocketInetLoopbackTest, TCPbacklog) {
auto const& param = GetParam();