Fix panic due to early transition to Closed.
The code in rcv.consumeSegment incorrectly transitions to CLOSED state from LAST-ACK before the final ACK for the FIN. Further if receiving a segment changes a socket to a closed state then we should not invoke the sender as the socket is now closed and sending any segments is incorrect. PiperOrigin-RevId: 283625300
This commit is contained in:
parent
43643752f0
commit
27e2c4ddca
|
@ -953,20 +953,6 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) {
|
|||
func (e *endpoint) handleSegments() *tcpip.Error {
|
||||
checkRequeue := true
|
||||
for i := 0; i < maxSegmentsPerWake; i++ {
|
||||
e.mu.RLock()
|
||||
state := e.state
|
||||
e.mu.RUnlock()
|
||||
if state == StateClose {
|
||||
// When we get into StateClose while processing from the queue,
|
||||
// return immediately and let the protocolMainloop handle it.
|
||||
//
|
||||
// We can reach StateClose only while processing a previous segment
|
||||
// or a notification from the protocolMainLoop (caller goroutine).
|
||||
// This means that with this return, the segment dequeue below can
|
||||
// never occur on a closed endpoint.
|
||||
return nil
|
||||
}
|
||||
|
||||
s := e.segmentQueue.dequeue()
|
||||
if s == nil {
|
||||
checkRequeue = false
|
||||
|
@ -1024,6 +1010,24 @@ func (e *endpoint) handleSegments() *tcpip.Error {
|
|||
s.decRef()
|
||||
continue
|
||||
}
|
||||
|
||||
// Now check if the received segment has caused us to transition
|
||||
// to a CLOSED state, if yes then terminate processing and do
|
||||
// not invoke the sender.
|
||||
e.mu.RLock()
|
||||
state := e.state
|
||||
e.mu.RUnlock()
|
||||
if state == StateClose {
|
||||
// When we get into StateClose while processing from the queue,
|
||||
// return immediately and let the protocolMainloop handle it.
|
||||
//
|
||||
// We can reach StateClose only while processing a previous segment
|
||||
// or a notification from the protocolMainLoop (caller goroutine).
|
||||
// This means that with this return, the segment dequeue below can
|
||||
// never occur on a closed endpoint.
|
||||
s.decRef()
|
||||
return nil
|
||||
}
|
||||
e.snd.handleRcvdSegment(s)
|
||||
}
|
||||
s.decRef()
|
||||
|
|
|
@ -205,7 +205,7 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum
|
|||
|
||||
// Handle ACK (not FIN-ACK, which we handled above) during one of the
|
||||
// shutdown states.
|
||||
if s.flagIsSet(header.TCPFlagAck) {
|
||||
if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.sndNxt {
|
||||
r.ep.mu.Lock()
|
||||
switch r.ep.state {
|
||||
case StateFinWait1:
|
||||
|
|
|
@ -5632,6 +5632,7 @@ func TestTCPTimeWaitRSTIgnored(t *testing.T) {
|
|||
DstPort: context.StackPort,
|
||||
Flags: header.TCPFlagSyn,
|
||||
SeqNum: iss,
|
||||
RcvWnd: 30000,
|
||||
})
|
||||
|
||||
// Receive the SYN-ACK reply.
|
||||
|
@ -5750,6 +5751,7 @@ func TestTCPTimeWaitOutOfOrder(t *testing.T) {
|
|||
DstPort: context.StackPort,
|
||||
Flags: header.TCPFlagSyn,
|
||||
SeqNum: iss,
|
||||
RcvWnd: 30000,
|
||||
})
|
||||
|
||||
// Receive the SYN-ACK reply.
|
||||
|
@ -5856,6 +5858,7 @@ func TestTCPTimeWaitNewSyn(t *testing.T) {
|
|||
DstPort: context.StackPort,
|
||||
Flags: header.TCPFlagSyn,
|
||||
SeqNum: iss,
|
||||
RcvWnd: 30000,
|
||||
})
|
||||
|
||||
// Receive the SYN-ACK reply.
|
||||
|
@ -5929,6 +5932,7 @@ func TestTCPTimeWaitNewSyn(t *testing.T) {
|
|||
DstPort: context.StackPort,
|
||||
Flags: header.TCPFlagSyn,
|
||||
SeqNum: iss,
|
||||
RcvWnd: 30000,
|
||||
})
|
||||
|
||||
c.CheckNoPacketTimeout("unexpected packet received in response to SYN", 1*time.Second)
|
||||
|
@ -5941,6 +5945,7 @@ func TestTCPTimeWaitNewSyn(t *testing.T) {
|
|||
DstPort: context.StackPort,
|
||||
Flags: header.TCPFlagSyn,
|
||||
SeqNum: iss,
|
||||
RcvWnd: 30000,
|
||||
})
|
||||
|
||||
// Receive the SYN-ACK reply.
|
||||
|
@ -6007,6 +6012,7 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) {
|
|||
DstPort: context.StackPort,
|
||||
Flags: header.TCPFlagSyn,
|
||||
SeqNum: iss,
|
||||
RcvWnd: 30000,
|
||||
})
|
||||
|
||||
// Receive the SYN-ACK reply.
|
||||
|
@ -6115,3 +6121,176 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) {
|
|||
checker.AckNum(uint32(ackHeaders.SeqNum)),
|
||||
checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck)))
|
||||
}
|
||||
|
||||
func TestTCPCloseWithData(t *testing.T) {
|
||||
c := context.New(t, defaultMTU)
|
||||
defer c.Cleanup()
|
||||
|
||||
// Set TCPTimeWaitTimeout to 5 seconds so that sockets are marked closed
|
||||
// after 5 seconds in TIME_WAIT state.
|
||||
tcpTimeWaitTimeout := 5 * time.Second
|
||||
if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil {
|
||||
t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpTimeWaitTimeout, err)
|
||||
}
|
||||
|
||||
wq := &waiter.Queue{}
|
||||
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
|
||||
if err != nil {
|
||||
t.Fatalf("NewEndpoint failed: %s", err)
|
||||
}
|
||||
if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
|
||||
t.Fatalf("Bind failed: %s", err)
|
||||
}
|
||||
|
||||
if err := ep.Listen(10); err != nil {
|
||||
t.Fatalf("Listen failed: %s", err)
|
||||
}
|
||||
|
||||
// Send a SYN request.
|
||||
iss := seqnum.Value(789)
|
||||
c.SendPacket(nil, &context.Headers{
|
||||
SrcPort: context.TestPort,
|
||||
DstPort: context.StackPort,
|
||||
Flags: header.TCPFlagSyn,
|
||||
SeqNum: iss,
|
||||
RcvWnd: 30000,
|
||||
})
|
||||
|
||||
// Receive the SYN-ACK reply.
|
||||
b := c.GetPacket()
|
||||
tcpHdr := header.TCP(header.IPv4(b).Payload())
|
||||
c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
|
||||
|
||||
ackHeaders := &context.Headers{
|
||||
SrcPort: context.TestPort,
|
||||
DstPort: context.StackPort,
|
||||
Flags: header.TCPFlagAck,
|
||||
SeqNum: iss + 1,
|
||||
AckNum: c.IRS + 1,
|
||||
RcvWnd: 30000,
|
||||
}
|
||||
|
||||
// Send ACK.
|
||||
c.SendPacket(nil, ackHeaders)
|
||||
|
||||
// Try to accept the connection.
|
||||
we, ch := waiter.NewChannelEntry(nil)
|
||||
wq.EventRegister(&we, waiter.EventIn)
|
||||
defer wq.EventUnregister(&we)
|
||||
|
||||
c.EP, _, err = ep.Accept()
|
||||
if err == tcpip.ErrWouldBlock {
|
||||
// Wait for connection to be established.
|
||||
select {
|
||||
case <-ch:
|
||||
c.EP, _, err = ep.Accept()
|
||||
if err != nil {
|
||||
t.Fatalf("Accept failed: %s", err)
|
||||
}
|
||||
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Fatalf("Timed out waiting for accept")
|
||||
}
|
||||
}
|
||||
|
||||
// Now trigger a passive close by sending a FIN.
|
||||
finHeaders := &context.Headers{
|
||||
SrcPort: context.TestPort,
|
||||
DstPort: context.StackPort,
|
||||
Flags: header.TCPFlagAck | header.TCPFlagFin,
|
||||
SeqNum: iss + 1,
|
||||
AckNum: c.IRS + 2,
|
||||
RcvWnd: 30000,
|
||||
}
|
||||
|
||||
c.SendPacket(nil, finHeaders)
|
||||
|
||||
// Get the ACK to the FIN we just sent.
|
||||
checker.IPv4(t, c.GetPacket(), checker.TCP(
|
||||
checker.SrcPort(context.StackPort),
|
||||
checker.DstPort(context.TestPort),
|
||||
checker.SeqNum(uint32(c.IRS+1)),
|
||||
checker.AckNum(uint32(iss)+2),
|
||||
checker.TCPFlags(header.TCPFlagAck)))
|
||||
|
||||
// Now write a few bytes and then close the endpoint.
|
||||
data := []byte{1, 2, 3}
|
||||
view := buffer.NewView(len(data))
|
||||
copy(view, data)
|
||||
|
||||
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
|
||||
t.Fatalf("Write failed: %s", err)
|
||||
}
|
||||
|
||||
// Check that data is received.
|
||||
b = c.GetPacket()
|
||||
checker.IPv4(t, b,
|
||||
checker.PayloadLen(len(data)+header.TCPMinimumSize),
|
||||
checker.TCP(
|
||||
checker.DstPort(context.TestPort),
|
||||
checker.SeqNum(uint32(c.IRS)+1),
|
||||
checker.AckNum(uint32(iss)+2), // Acknum is initial sequence number + 1
|
||||
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
|
||||
),
|
||||
)
|
||||
|
||||
if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) {
|
||||
t.Errorf("got data = %x, want = %x", p, data)
|
||||
}
|
||||
|
||||
c.EP.Close()
|
||||
// Check the FIN.
|
||||
checker.IPv4(t, c.GetPacket(), checker.TCP(
|
||||
checker.SrcPort(context.StackPort),
|
||||
checker.DstPort(context.TestPort),
|
||||
checker.SeqNum(uint32(c.IRS+1)+uint32(len(data))),
|
||||
checker.AckNum(uint32(iss+2)),
|
||||
checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
|
||||
|
||||
// First send a partial ACK.
|
||||
ackHeaders = &context.Headers{
|
||||
SrcPort: context.TestPort,
|
||||
DstPort: context.StackPort,
|
||||
Flags: header.TCPFlagAck,
|
||||
SeqNum: iss + 2,
|
||||
AckNum: c.IRS + 1 + seqnum.Value(len(data)-1),
|
||||
RcvWnd: 30000,
|
||||
}
|
||||
c.SendPacket(nil, ackHeaders)
|
||||
|
||||
// Now send a full ACK.
|
||||
ackHeaders = &context.Headers{
|
||||
SrcPort: context.TestPort,
|
||||
DstPort: context.StackPort,
|
||||
Flags: header.TCPFlagAck,
|
||||
SeqNum: iss + 2,
|
||||
AckNum: c.IRS + 1 + seqnum.Value(len(data)),
|
||||
RcvWnd: 30000,
|
||||
}
|
||||
c.SendPacket(nil, ackHeaders)
|
||||
|
||||
// Now ACK the FIN.
|
||||
ackHeaders.AckNum++
|
||||
c.SendPacket(nil, ackHeaders)
|
||||
|
||||
// Now send an ACK and we should get a RST back as the endpoint should
|
||||
// be in CLOSED state.
|
||||
ackHeaders = &context.Headers{
|
||||
SrcPort: context.TestPort,
|
||||
DstPort: context.StackPort,
|
||||
Flags: header.TCPFlagAck,
|
||||
SeqNum: iss + 2,
|
||||
AckNum: c.IRS + 1 + seqnum.Value(len(data)),
|
||||
RcvWnd: 30000,
|
||||
}
|
||||
c.SendPacket(nil, ackHeaders)
|
||||
|
||||
// Check the RST.
|
||||
checker.IPv4(t, c.GetPacket(), checker.TCP(
|
||||
checker.SrcPort(context.StackPort),
|
||||
checker.DstPort(context.TestPort),
|
||||
checker.SeqNum(uint32(ackHeaders.AckNum)),
|
||||
checker.AckNum(uint32(ackHeaders.SeqNum)),
|
||||
checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck)))
|
||||
|
||||
}
|
||||
|
|
|
@ -2142,6 +2142,7 @@ cc_library(
|
|||
":socket_test_util",
|
||||
"//test/util:test_util",
|
||||
"//test/util:thread_util",
|
||||
"@com_google_absl//absl/time",
|
||||
"@com_google_googletest//:gtest",
|
||||
],
|
||||
alwayslink = 1,
|
||||
|
|
|
@ -24,6 +24,8 @@
|
|||
#include <sys/un.h>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "absl/time/clock.h"
|
||||
#include "absl/time/time.h"
|
||||
#include "test/syscalls/linux/socket_test_util.h"
|
||||
#include "test/util/test_util.h"
|
||||
#include "test/util/thread_util.h"
|
||||
|
@ -789,5 +791,26 @@ TEST_P(TCPSocketPairTest, SetTCPLingerTimeout) {
|
|||
EXPECT_EQ(get, kTCPLingerTimeout);
|
||||
}
|
||||
|
||||
TEST_P(TCPSocketPairTest, TestTCPCloseWithData) {
|
||||
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
|
||||
|
||||
ScopedThread t([&]() {
|
||||
// Close one end to trigger sending of a FIN.
|
||||
ASSERT_THAT(shutdown(sockets->second_fd(), SHUT_WR), SyscallSucceeds());
|
||||
char buf[3];
|
||||
ASSERT_THAT(read(sockets->second_fd(), buf, 3),
|
||||
SyscallSucceedsWithValue(3));
|
||||
absl::SleepFor(absl::Milliseconds(50));
|
||||
ASSERT_THAT(close(sockets->release_second_fd()), SyscallSucceeds());
|
||||
});
|
||||
|
||||
absl::SleepFor(absl::Milliseconds(50));
|
||||
// Send some data then close.
|
||||
constexpr char kStr[] = "abc";
|
||||
ASSERT_THAT(write(sockets->first_fd(), kStr, 3), SyscallSucceedsWithValue(3));
|
||||
t.Join();
|
||||
ASSERT_THAT(close(sockets->release_first_fd()), SyscallSucceeds());
|
||||
}
|
||||
|
||||
} // namespace testing
|
||||
} // namespace gvisor
|
||||
|
|
Loading…
Reference in New Issue