Fix panic due to early transition to Closed.

The code in rcv.consumeSegment incorrectly transitions to the
CLOSED state from LAST-ACK before the final ACK for the FIN has been received.

Further, if receiving a segment moves a socket to a closed state,
we should not invoke the sender: the socket is now closed, so
sending any further segments is incorrect.

PiperOrigin-RevId: 283625300
This commit is contained in:
Bhasker Hariharan 2019-12-03 14:40:22 -08:00 committed by gVisor bot
parent 43643752f0
commit 27e2c4ddca
5 changed files with 222 additions and 15 deletions

View File

@ -953,20 +953,6 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) {
func (e *endpoint) handleSegments() *tcpip.Error {
checkRequeue := true
for i := 0; i < maxSegmentsPerWake; i++ {
e.mu.RLock()
state := e.state
e.mu.RUnlock()
if state == StateClose {
// When we get into StateClose while processing from the queue,
// return immediately and let the protocolMainloop handle it.
//
// We can reach StateClose only while processing a previous segment
// or a notification from the protocolMainLoop (caller goroutine).
// This means that with this return, the segment dequeue below can
// never occur on a closed endpoint.
return nil
}
s := e.segmentQueue.dequeue()
if s == nil {
checkRequeue = false
@ -1024,6 +1010,24 @@ func (e *endpoint) handleSegments() *tcpip.Error {
s.decRef()
continue
}
// Now check if the received segment has caused us to transition
// to a CLOSED state, if yes then terminate processing and do
// not invoke the sender.
e.mu.RLock()
state := e.state
e.mu.RUnlock()
if state == StateClose {
// When we get into StateClose while processing from the queue,
// return immediately and let the protocolMainloop handle it.
//
// We can reach StateClose only while processing a previous segment
// or a notification from the protocolMainLoop (caller goroutine).
// This means that with this return, the segment dequeue below can
// never occur on a closed endpoint.
s.decRef()
return nil
}
e.snd.handleRcvdSegment(s)
}
s.decRef()

View File

@ -205,7 +205,7 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum
// Handle ACK (not FIN-ACK, which we handled above) during one of the
// shutdown states.
if s.flagIsSet(header.TCPFlagAck) {
if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.sndNxt {
r.ep.mu.Lock()
switch r.ep.state {
case StateFinWait1:

View File

@ -5632,6 +5632,7 @@ func TestTCPTimeWaitRSTIgnored(t *testing.T) {
DstPort: context.StackPort,
Flags: header.TCPFlagSyn,
SeqNum: iss,
RcvWnd: 30000,
})
// Receive the SYN-ACK reply.
@ -5750,6 +5751,7 @@ func TestTCPTimeWaitOutOfOrder(t *testing.T) {
DstPort: context.StackPort,
Flags: header.TCPFlagSyn,
SeqNum: iss,
RcvWnd: 30000,
})
// Receive the SYN-ACK reply.
@ -5856,6 +5858,7 @@ func TestTCPTimeWaitNewSyn(t *testing.T) {
DstPort: context.StackPort,
Flags: header.TCPFlagSyn,
SeqNum: iss,
RcvWnd: 30000,
})
// Receive the SYN-ACK reply.
@ -5929,6 +5932,7 @@ func TestTCPTimeWaitNewSyn(t *testing.T) {
DstPort: context.StackPort,
Flags: header.TCPFlagSyn,
SeqNum: iss,
RcvWnd: 30000,
})
c.CheckNoPacketTimeout("unexpected packet received in response to SYN", 1*time.Second)
@ -5941,6 +5945,7 @@ func TestTCPTimeWaitNewSyn(t *testing.T) {
DstPort: context.StackPort,
Flags: header.TCPFlagSyn,
SeqNum: iss,
RcvWnd: 30000,
})
// Receive the SYN-ACK reply.
@ -6007,6 +6012,7 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) {
DstPort: context.StackPort,
Flags: header.TCPFlagSyn,
SeqNum: iss,
RcvWnd: 30000,
})
// Receive the SYN-ACK reply.
@ -6115,3 +6121,176 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) {
checker.AckNum(uint32(ackHeaders.SeqNum)),
checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck)))
}
// TestTCPCloseWithData exercises a passive close in which the stack still has
// data to send: the peer sends a FIN, the stack writes data and then closes,
// and the peer ACKs the data partially, then fully, and finally ACKs the FIN.
// A subsequent ACK is expected to be answered with a RST, which the test
// treats as evidence that the endpoint has reached the CLOSED state.
func TestTCPCloseWithData(t *testing.T) {
	c := context.New(t, defaultMTU)
	defer c.Cleanup()

	// Set TCPTimeWaitTimeout to 5 seconds so that sockets are marked closed
	// after 5 seconds in TIME_WAIT state.
	tcpTimeWaitTimeout := 5 * time.Second
	if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil {
		t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPTimeWaitTimeoutOption(%d)) failed: %s", tcpTimeWaitTimeout, err)
	}

	wq := &waiter.Queue{}
	ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
	if err != nil {
		t.Fatalf("NewEndpoint failed: %s", err)
	}
	if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
		t.Fatalf("Bind failed: %s", err)
	}
	if err := ep.Listen(10); err != nil {
		t.Fatalf("Listen failed: %s", err)
	}

	// Send a SYN request.
	iss := seqnum.Value(789)
	c.SendPacket(nil, &context.Headers{
		SrcPort: context.TestPort,
		DstPort: context.StackPort,
		Flags:   header.TCPFlagSyn,
		SeqNum:  iss,
		RcvWnd:  30000,
	})

	// Receive the SYN-ACK reply and record the stack's initial sequence
	// number.
	b := c.GetPacket()
	tcpHdr := header.TCP(header.IPv4(b).Payload())
	c.IRS = seqnum.Value(tcpHdr.SequenceNumber())

	ackHeaders := &context.Headers{
		SrcPort: context.TestPort,
		DstPort: context.StackPort,
		Flags:   header.TCPFlagAck,
		SeqNum:  iss + 1,
		AckNum:  c.IRS + 1,
		RcvWnd:  30000,
	}

	// Send ACK.
	c.SendPacket(nil, ackHeaders)

	// Try to accept the connection.
	we, ch := waiter.NewChannelEntry(nil)
	wq.EventRegister(&we, waiter.EventIn)
	defer wq.EventUnregister(&we)

	c.EP, _, err = ep.Accept()
	if err == tcpip.ErrWouldBlock {
		// Wait for connection to be established.
		select {
		case <-ch:
			c.EP, _, err = ep.Accept()
			if err != nil {
				t.Fatalf("Accept failed: %s", err)
			}

		case <-time.After(1 * time.Second):
			t.Fatalf("Timed out waiting for accept")
		}
	} else if err != nil {
		// Any other failure would leave c.EP unset and make the rest of
		// the test dereference a nil endpoint, so fail here instead.
		t.Fatalf("Accept failed: %s", err)
	}

	// Now trigger a passive close by sending a FIN.
	//
	// NOTE(review): AckNum is c.IRS + 2 although the stack has only sent its
	// SYN-ACK at this point — confirm this is intended.
	finHeaders := &context.Headers{
		SrcPort: context.TestPort,
		DstPort: context.StackPort,
		Flags:   header.TCPFlagAck | header.TCPFlagFin,
		SeqNum:  iss + 1,
		AckNum:  c.IRS + 2,
		RcvWnd:  30000,
	}
	c.SendPacket(nil, finHeaders)

	// Get the ACK to the FIN we just sent.
	checker.IPv4(t, c.GetPacket(), checker.TCP(
		checker.SrcPort(context.StackPort),
		checker.DstPort(context.TestPort),
		checker.SeqNum(uint32(c.IRS+1)),
		checker.AckNum(uint32(iss)+2),
		checker.TCPFlags(header.TCPFlagAck)))

	// Now write a few bytes and then close the endpoint.
	data := []byte{1, 2, 3}
	view := buffer.NewView(len(data))
	copy(view, data)

	if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
		t.Fatalf("Write failed: %s", err)
	}

	// Check that data is received.
	b = c.GetPacket()
	checker.IPv4(t, b,
		checker.PayloadLen(len(data)+header.TCPMinimumSize),
		checker.TCP(
			checker.DstPort(context.TestPort),
			checker.SeqNum(uint32(c.IRS)+1),
			// The peer's SYN and FIN each consumed one sequence number,
			// so the expected acknum is the initial sequence number + 2.
			checker.AckNum(uint32(iss)+2),
			checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
		),
	)

	if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) {
		t.Errorf("got data = %x, want = %x", p, data)
	}

	c.EP.Close()

	// Check the FIN.
	checker.IPv4(t, c.GetPacket(), checker.TCP(
		checker.SrcPort(context.StackPort),
		checker.DstPort(context.TestPort),
		checker.SeqNum(uint32(c.IRS+1)+uint32(len(data))),
		checker.AckNum(uint32(iss+2)),
		checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))

	// First send a partial ACK that covers all but the last data byte.
	ackHeaders = &context.Headers{
		SrcPort: context.TestPort,
		DstPort: context.StackPort,
		Flags:   header.TCPFlagAck,
		SeqNum:  iss + 2,
		AckNum:  c.IRS + 1 + seqnum.Value(len(data)-1),
		RcvWnd:  30000,
	}
	c.SendPacket(nil, ackHeaders)

	// Now send a full ACK covering all of the data but not the FIN.
	ackHeaders = &context.Headers{
		SrcPort: context.TestPort,
		DstPort: context.StackPort,
		Flags:   header.TCPFlagAck,
		SeqNum:  iss + 2,
		AckNum:  c.IRS + 1 + seqnum.Value(len(data)),
		RcvWnd:  30000,
	}
	c.SendPacket(nil, ackHeaders)

	// Now ACK the FIN.
	ackHeaders.AckNum++
	c.SendPacket(nil, ackHeaders)

	// Now send an ACK and we should get a RST back as the endpoint should
	// be in CLOSED state.
	ackHeaders = &context.Headers{
		SrcPort: context.TestPort,
		DstPort: context.StackPort,
		Flags:   header.TCPFlagAck,
		SeqNum:  iss + 2,
		AckNum:  c.IRS + 1 + seqnum.Value(len(data)),
		RcvWnd:  30000,
	}
	c.SendPacket(nil, ackHeaders)

	// Check the RST.
	checker.IPv4(t, c.GetPacket(), checker.TCP(
		checker.SrcPort(context.StackPort),
		checker.DstPort(context.TestPort),
		checker.SeqNum(uint32(ackHeaders.AckNum)),
		checker.AckNum(uint32(ackHeaders.SeqNum)),
		checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck)))
}

View File

@ -2142,6 +2142,7 @@ cc_library(
":socket_test_util",
"//test/util:test_util",
"//test/util:thread_util",
"@com_google_absl//absl/time",
"@com_google_googletest//:gtest",
],
alwayslink = 1,

View File

@ -24,6 +24,8 @@
#include <sys/un.h>
#include "gtest/gtest.h"
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "test/syscalls/linux/socket_test_util.h"
#include "test/util/test_util.h"
#include "test/util/thread_util.h"
@ -789,5 +791,26 @@ TEST_P(TCPSocketPairTest, SetTCPLingerTimeout) {
EXPECT_EQ(get, kTCPLingerTimeout);
}
// Closes a TCP socket pair while data is still being exchanged: a helper
// thread shuts down the write side of the second socket, reads three bytes
// and closes it, while the main thread writes three bytes on the first
// socket, waits for the helper and then closes its own end.
TEST_P(TCPSocketPairTest, TestTCPCloseWithData) {
  auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());

  ScopedThread peer([&]() {
    // Shut down the write half of the second socket to trigger sending of a
    // FIN.
    ASSERT_THAT(shutdown(sockets->second_fd(), SHUT_WR), SyscallSucceeds());

    // Read the three bytes written by the main thread.
    char received[3];
    ASSERT_THAT(read(sockets->second_fd(), received, sizeof(received)),
                SyscallSucceedsWithValue(3));

    // Pause briefly, then fully close this end.
    absl::SleepFor(absl::Milliseconds(50));
    ASSERT_THAT(close(sockets->release_second_fd()), SyscallSucceeds());
  });

  // NOTE(review): presumably this delay lets the helper's shutdown happen
  // before the write below — confirm.
  absl::SleepFor(absl::Milliseconds(50));

  // Send some data, wait for the helper thread, then close this end.
  constexpr char kPayload[] = "abc";
  ASSERT_THAT(write(sockets->first_fd(), kPayload, 3),
              SyscallSucceedsWithValue(3));
  peer.Join();
  ASSERT_THAT(close(sockets->release_first_fd()), SyscallSucceeds());
}
} // namespace testing
} // namespace gvisor