Fix panic due to early transition to Closed.
The code in rcv.consumeSegment incorrectly transitions to CLOSED state from LAST-ACK before the final ACK for the FIN. Further if receiving a segment changes a socket to a closed state then we should not invoke the sender as the socket is now closed and sending any segments is incorrect. PiperOrigin-RevId: 283625300
This commit is contained in:
parent
43643752f0
commit
27e2c4ddca
|
@ -953,20 +953,6 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) {
|
||||||
func (e *endpoint) handleSegments() *tcpip.Error {
|
func (e *endpoint) handleSegments() *tcpip.Error {
|
||||||
checkRequeue := true
|
checkRequeue := true
|
||||||
for i := 0; i < maxSegmentsPerWake; i++ {
|
for i := 0; i < maxSegmentsPerWake; i++ {
|
||||||
e.mu.RLock()
|
|
||||||
state := e.state
|
|
||||||
e.mu.RUnlock()
|
|
||||||
if state == StateClose {
|
|
||||||
// When we get into StateClose while processing from the queue,
|
|
||||||
// return immediately and let the protocolMainloop handle it.
|
|
||||||
//
|
|
||||||
// We can reach StateClose only while processing a previous segment
|
|
||||||
// or a notification from the protocolMainLoop (caller goroutine).
|
|
||||||
// This means that with this return, the segment dequeue below can
|
|
||||||
// never occur on a closed endpoint.
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
s := e.segmentQueue.dequeue()
|
s := e.segmentQueue.dequeue()
|
||||||
if s == nil {
|
if s == nil {
|
||||||
checkRequeue = false
|
checkRequeue = false
|
||||||
|
@ -1024,6 +1010,24 @@ func (e *endpoint) handleSegments() *tcpip.Error {
|
||||||
s.decRef()
|
s.decRef()
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Now check if the received segment has caused us to transition
|
||||||
|
// to a CLOSED state, if yes then terminate processing and do
|
||||||
|
// not invoke the sender.
|
||||||
|
e.mu.RLock()
|
||||||
|
state := e.state
|
||||||
|
e.mu.RUnlock()
|
||||||
|
if state == StateClose {
|
||||||
|
// When we get into StateClose while processing from the queue,
|
||||||
|
// return immediately and let the protocolMainloop handle it.
|
||||||
|
//
|
||||||
|
// We can reach StateClose only while processing a previous segment
|
||||||
|
// or a notification from the protocolMainLoop (caller goroutine).
|
||||||
|
// This means that with this return, the segment dequeue below can
|
||||||
|
// never occur on a closed endpoint.
|
||||||
|
s.decRef()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
e.snd.handleRcvdSegment(s)
|
e.snd.handleRcvdSegment(s)
|
||||||
}
|
}
|
||||||
s.decRef()
|
s.decRef()
|
||||||
|
|
|
@ -205,7 +205,7 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum
|
||||||
|
|
||||||
// Handle ACK (not FIN-ACK, which we handled above) during one of the
|
// Handle ACK (not FIN-ACK, which we handled above) during one of the
|
||||||
// shutdown states.
|
// shutdown states.
|
||||||
if s.flagIsSet(header.TCPFlagAck) {
|
if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.sndNxt {
|
||||||
r.ep.mu.Lock()
|
r.ep.mu.Lock()
|
||||||
switch r.ep.state {
|
switch r.ep.state {
|
||||||
case StateFinWait1:
|
case StateFinWait1:
|
||||||
|
|
|
@ -5632,6 +5632,7 @@ func TestTCPTimeWaitRSTIgnored(t *testing.T) {
|
||||||
DstPort: context.StackPort,
|
DstPort: context.StackPort,
|
||||||
Flags: header.TCPFlagSyn,
|
Flags: header.TCPFlagSyn,
|
||||||
SeqNum: iss,
|
SeqNum: iss,
|
||||||
|
RcvWnd: 30000,
|
||||||
})
|
})
|
||||||
|
|
||||||
// Receive the SYN-ACK reply.
|
// Receive the SYN-ACK reply.
|
||||||
|
@ -5750,6 +5751,7 @@ func TestTCPTimeWaitOutOfOrder(t *testing.T) {
|
||||||
DstPort: context.StackPort,
|
DstPort: context.StackPort,
|
||||||
Flags: header.TCPFlagSyn,
|
Flags: header.TCPFlagSyn,
|
||||||
SeqNum: iss,
|
SeqNum: iss,
|
||||||
|
RcvWnd: 30000,
|
||||||
})
|
})
|
||||||
|
|
||||||
// Receive the SYN-ACK reply.
|
// Receive the SYN-ACK reply.
|
||||||
|
@ -5856,6 +5858,7 @@ func TestTCPTimeWaitNewSyn(t *testing.T) {
|
||||||
DstPort: context.StackPort,
|
DstPort: context.StackPort,
|
||||||
Flags: header.TCPFlagSyn,
|
Flags: header.TCPFlagSyn,
|
||||||
SeqNum: iss,
|
SeqNum: iss,
|
||||||
|
RcvWnd: 30000,
|
||||||
})
|
})
|
||||||
|
|
||||||
// Receive the SYN-ACK reply.
|
// Receive the SYN-ACK reply.
|
||||||
|
@ -5929,6 +5932,7 @@ func TestTCPTimeWaitNewSyn(t *testing.T) {
|
||||||
DstPort: context.StackPort,
|
DstPort: context.StackPort,
|
||||||
Flags: header.TCPFlagSyn,
|
Flags: header.TCPFlagSyn,
|
||||||
SeqNum: iss,
|
SeqNum: iss,
|
||||||
|
RcvWnd: 30000,
|
||||||
})
|
})
|
||||||
|
|
||||||
c.CheckNoPacketTimeout("unexpected packet received in response to SYN", 1*time.Second)
|
c.CheckNoPacketTimeout("unexpected packet received in response to SYN", 1*time.Second)
|
||||||
|
@ -5941,6 +5945,7 @@ func TestTCPTimeWaitNewSyn(t *testing.T) {
|
||||||
DstPort: context.StackPort,
|
DstPort: context.StackPort,
|
||||||
Flags: header.TCPFlagSyn,
|
Flags: header.TCPFlagSyn,
|
||||||
SeqNum: iss,
|
SeqNum: iss,
|
||||||
|
RcvWnd: 30000,
|
||||||
})
|
})
|
||||||
|
|
||||||
// Receive the SYN-ACK reply.
|
// Receive the SYN-ACK reply.
|
||||||
|
@ -6007,6 +6012,7 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) {
|
||||||
DstPort: context.StackPort,
|
DstPort: context.StackPort,
|
||||||
Flags: header.TCPFlagSyn,
|
Flags: header.TCPFlagSyn,
|
||||||
SeqNum: iss,
|
SeqNum: iss,
|
||||||
|
RcvWnd: 30000,
|
||||||
})
|
})
|
||||||
|
|
||||||
// Receive the SYN-ACK reply.
|
// Receive the SYN-ACK reply.
|
||||||
|
@ -6115,3 +6121,176 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) {
|
||||||
checker.AckNum(uint32(ackHeaders.SeqNum)),
|
checker.AckNum(uint32(ackHeaders.SeqNum)),
|
||||||
checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck)))
|
checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck)))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTCPCloseWithData(t *testing.T) {
|
||||||
|
c := context.New(t, defaultMTU)
|
||||||
|
defer c.Cleanup()
|
||||||
|
|
||||||
|
// Set TCPTimeWaitTimeout to 5 seconds so that sockets are marked closed
|
||||||
|
// after 5 seconds in TIME_WAIT state.
|
||||||
|
tcpTimeWaitTimeout := 5 * time.Second
|
||||||
|
if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil {
|
||||||
|
t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpTimeWaitTimeout, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
wq := &waiter.Queue{}
|
||||||
|
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewEndpoint failed: %s", err)
|
||||||
|
}
|
||||||
|
if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
|
||||||
|
t.Fatalf("Bind failed: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := ep.Listen(10); err != nil {
|
||||||
|
t.Fatalf("Listen failed: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send a SYN request.
|
||||||
|
iss := seqnum.Value(789)
|
||||||
|
c.SendPacket(nil, &context.Headers{
|
||||||
|
SrcPort: context.TestPort,
|
||||||
|
DstPort: context.StackPort,
|
||||||
|
Flags: header.TCPFlagSyn,
|
||||||
|
SeqNum: iss,
|
||||||
|
RcvWnd: 30000,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Receive the SYN-ACK reply.
|
||||||
|
b := c.GetPacket()
|
||||||
|
tcpHdr := header.TCP(header.IPv4(b).Payload())
|
||||||
|
c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
|
||||||
|
|
||||||
|
ackHeaders := &context.Headers{
|
||||||
|
SrcPort: context.TestPort,
|
||||||
|
DstPort: context.StackPort,
|
||||||
|
Flags: header.TCPFlagAck,
|
||||||
|
SeqNum: iss + 1,
|
||||||
|
AckNum: c.IRS + 1,
|
||||||
|
RcvWnd: 30000,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send ACK.
|
||||||
|
c.SendPacket(nil, ackHeaders)
|
||||||
|
|
||||||
|
// Try to accept the connection.
|
||||||
|
we, ch := waiter.NewChannelEntry(nil)
|
||||||
|
wq.EventRegister(&we, waiter.EventIn)
|
||||||
|
defer wq.EventUnregister(&we)
|
||||||
|
|
||||||
|
c.EP, _, err = ep.Accept()
|
||||||
|
if err == tcpip.ErrWouldBlock {
|
||||||
|
// Wait for connection to be established.
|
||||||
|
select {
|
||||||
|
case <-ch:
|
||||||
|
c.EP, _, err = ep.Accept()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Accept failed: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
case <-time.After(1 * time.Second):
|
||||||
|
t.Fatalf("Timed out waiting for accept")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now trigger a passive close by sending a FIN.
|
||||||
|
finHeaders := &context.Headers{
|
||||||
|
SrcPort: context.TestPort,
|
||||||
|
DstPort: context.StackPort,
|
||||||
|
Flags: header.TCPFlagAck | header.TCPFlagFin,
|
||||||
|
SeqNum: iss + 1,
|
||||||
|
AckNum: c.IRS + 2,
|
||||||
|
RcvWnd: 30000,
|
||||||
|
}
|
||||||
|
|
||||||
|
c.SendPacket(nil, finHeaders)
|
||||||
|
|
||||||
|
// Get the ACK to the FIN we just sent.
|
||||||
|
checker.IPv4(t, c.GetPacket(), checker.TCP(
|
||||||
|
checker.SrcPort(context.StackPort),
|
||||||
|
checker.DstPort(context.TestPort),
|
||||||
|
checker.SeqNum(uint32(c.IRS+1)),
|
||||||
|
checker.AckNum(uint32(iss)+2),
|
||||||
|
checker.TCPFlags(header.TCPFlagAck)))
|
||||||
|
|
||||||
|
// Now write a few bytes and then close the endpoint.
|
||||||
|
data := []byte{1, 2, 3}
|
||||||
|
view := buffer.NewView(len(data))
|
||||||
|
copy(view, data)
|
||||||
|
|
||||||
|
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
|
||||||
|
t.Fatalf("Write failed: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that data is received.
|
||||||
|
b = c.GetPacket()
|
||||||
|
checker.IPv4(t, b,
|
||||||
|
checker.PayloadLen(len(data)+header.TCPMinimumSize),
|
||||||
|
checker.TCP(
|
||||||
|
checker.DstPort(context.TestPort),
|
||||||
|
checker.SeqNum(uint32(c.IRS)+1),
|
||||||
|
checker.AckNum(uint32(iss)+2), // Acknum is initial sequence number + 1
|
||||||
|
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) {
|
||||||
|
t.Errorf("got data = %x, want = %x", p, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.EP.Close()
|
||||||
|
// Check the FIN.
|
||||||
|
checker.IPv4(t, c.GetPacket(), checker.TCP(
|
||||||
|
checker.SrcPort(context.StackPort),
|
||||||
|
checker.DstPort(context.TestPort),
|
||||||
|
checker.SeqNum(uint32(c.IRS+1)+uint32(len(data))),
|
||||||
|
checker.AckNum(uint32(iss+2)),
|
||||||
|
checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
|
||||||
|
|
||||||
|
// First send a partial ACK.
|
||||||
|
ackHeaders = &context.Headers{
|
||||||
|
SrcPort: context.TestPort,
|
||||||
|
DstPort: context.StackPort,
|
||||||
|
Flags: header.TCPFlagAck,
|
||||||
|
SeqNum: iss + 2,
|
||||||
|
AckNum: c.IRS + 1 + seqnum.Value(len(data)-1),
|
||||||
|
RcvWnd: 30000,
|
||||||
|
}
|
||||||
|
c.SendPacket(nil, ackHeaders)
|
||||||
|
|
||||||
|
// Now send a full ACK.
|
||||||
|
ackHeaders = &context.Headers{
|
||||||
|
SrcPort: context.TestPort,
|
||||||
|
DstPort: context.StackPort,
|
||||||
|
Flags: header.TCPFlagAck,
|
||||||
|
SeqNum: iss + 2,
|
||||||
|
AckNum: c.IRS + 1 + seqnum.Value(len(data)),
|
||||||
|
RcvWnd: 30000,
|
||||||
|
}
|
||||||
|
c.SendPacket(nil, ackHeaders)
|
||||||
|
|
||||||
|
// Now ACK the FIN.
|
||||||
|
ackHeaders.AckNum++
|
||||||
|
c.SendPacket(nil, ackHeaders)
|
||||||
|
|
||||||
|
// Now send an ACK and we should get a RST back as the endpoint should
|
||||||
|
// be in CLOSED state.
|
||||||
|
ackHeaders = &context.Headers{
|
||||||
|
SrcPort: context.TestPort,
|
||||||
|
DstPort: context.StackPort,
|
||||||
|
Flags: header.TCPFlagAck,
|
||||||
|
SeqNum: iss + 2,
|
||||||
|
AckNum: c.IRS + 1 + seqnum.Value(len(data)),
|
||||||
|
RcvWnd: 30000,
|
||||||
|
}
|
||||||
|
c.SendPacket(nil, ackHeaders)
|
||||||
|
|
||||||
|
// Check the RST.
|
||||||
|
checker.IPv4(t, c.GetPacket(), checker.TCP(
|
||||||
|
checker.SrcPort(context.StackPort),
|
||||||
|
checker.DstPort(context.TestPort),
|
||||||
|
checker.SeqNum(uint32(ackHeaders.AckNum)),
|
||||||
|
checker.AckNum(uint32(ackHeaders.SeqNum)),
|
||||||
|
checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck)))
|
||||||
|
|
||||||
|
}
|
||||||
|
|
|
@ -2142,6 +2142,7 @@ cc_library(
|
||||||
":socket_test_util",
|
":socket_test_util",
|
||||||
"//test/util:test_util",
|
"//test/util:test_util",
|
||||||
"//test/util:thread_util",
|
"//test/util:thread_util",
|
||||||
|
"@com_google_absl//absl/time",
|
||||||
"@com_google_googletest//:gtest",
|
"@com_google_googletest//:gtest",
|
||||||
],
|
],
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
|
|
|
@ -24,6 +24,8 @@
|
||||||
#include <sys/un.h>
|
#include <sys/un.h>
|
||||||
|
|
||||||
#include "gtest/gtest.h"
|
#include "gtest/gtest.h"
|
||||||
|
#include "absl/time/clock.h"
|
||||||
|
#include "absl/time/time.h"
|
||||||
#include "test/syscalls/linux/socket_test_util.h"
|
#include "test/syscalls/linux/socket_test_util.h"
|
||||||
#include "test/util/test_util.h"
|
#include "test/util/test_util.h"
|
||||||
#include "test/util/thread_util.h"
|
#include "test/util/thread_util.h"
|
||||||
|
@ -789,5 +791,26 @@ TEST_P(TCPSocketPairTest, SetTCPLingerTimeout) {
|
||||||
EXPECT_EQ(get, kTCPLingerTimeout);
|
EXPECT_EQ(get, kTCPLingerTimeout);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_P(TCPSocketPairTest, TestTCPCloseWithData) {
|
||||||
|
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
|
||||||
|
|
||||||
|
ScopedThread t([&]() {
|
||||||
|
// Close one end to trigger sending of a FIN.
|
||||||
|
ASSERT_THAT(shutdown(sockets->second_fd(), SHUT_WR), SyscallSucceeds());
|
||||||
|
char buf[3];
|
||||||
|
ASSERT_THAT(read(sockets->second_fd(), buf, 3),
|
||||||
|
SyscallSucceedsWithValue(3));
|
||||||
|
absl::SleepFor(absl::Milliseconds(50));
|
||||||
|
ASSERT_THAT(close(sockets->release_second_fd()), SyscallSucceeds());
|
||||||
|
});
|
||||||
|
|
||||||
|
absl::SleepFor(absl::Milliseconds(50));
|
||||||
|
// Send some data then close.
|
||||||
|
constexpr char kStr[] = "abc";
|
||||||
|
ASSERT_THAT(write(sockets->first_fd(), kStr, 3), SyscallSucceedsWithValue(3));
|
||||||
|
t.Join();
|
||||||
|
ASSERT_THAT(close(sockets->release_first_fd()), SyscallSucceeds());
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace testing
|
} // namespace testing
|
||||||
} // namespace gvisor
|
} // namespace gvisor
|
||||||
|
|
Loading…
Reference in New Issue