Avoid ignoring incoming packet by demuxer on endpoint lookup failure

This fixes a race that occurs while the endpoint is being unregistered
and the transport demuxer attempts to match the incoming packet to any
endpoint. The race specifically occurs when the unregistration (and
deletion of the endpoint) occurs, after a successful endpointsByNIC
lookup and before the endpoints map is further looked up with ingress
NICID of the packet.

The fix is to notify the caller of lookup-with-NICID failure, so that
the logic falls through to handling unknown destination packets.
For TCP this can mean replying back with RST.

The syscall test in this CL catches this race as the ACK completing the
handshake could get silently dropped on a listener close, causing no
RST sent to the peer and timing out the poll waiting for POLLHUP.

Fixes #5850

PiperOrigin-RevId: 369023779
This commit is contained in:
Mithun Iyer 2021-04-17 11:30:36 -07:00 committed by gVisor bot
parent 3b685753b4
commit 9b4cc3d43b
2 changed files with 49 additions and 45 deletions

View File

@ -150,16 +150,17 @@ func (epsByNIC *endpointsByNIC) transportEndpoints() []TransportEndpoint {
return eps
}
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
func (epsByNIC *endpointsByNIC) handlePacket(id TransportEndpointID, pkt *PacketBuffer) {
// handlePacket is called by the stack when new packets arrive to this transport
// endpoint. It returns false if the packet could not be matched to any
// transport endpoint, true otherwise.
func (epsByNIC *endpointsByNIC) handlePacket(id TransportEndpointID, pkt *PacketBuffer) bool {
epsByNIC.mu.RLock()
mpep, ok := epsByNIC.endpoints[pkt.NICID]
if !ok {
if mpep, ok = epsByNIC.endpoints[0]; !ok {
epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
return
return false
}
}
@ -168,18 +169,19 @@ func (epsByNIC *endpointsByNIC) handlePacket(id TransportEndpointID, pkt *Packet
if isInboundMulticastOrBroadcast(pkt, id.LocalAddress) {
mpep.handlePacketAll(id, pkt)
epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
return
return true
}
// multiPortEndpoints are guaranteed to have at least one element.
transEP := selectEndpoint(id, mpep, epsByNIC.seed)
if queuedProtocol, mustQueue := mpep.demux.queuedProtocols[protocolIDs{mpep.netProto, mpep.transProto}]; mustQueue {
queuedProtocol.QueuePacket(transEP, id, pkt)
epsByNIC.mu.RUnlock()
return
return true
}
transEP.HandlePacket(id, pkt)
epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
return true
}
// handleError delivers an error to the transport endpoint identified by id.
@ -567,8 +569,7 @@ func (d *transportDemuxer) deliverPacket(protocol tcpip.TransportProtocolNumber,
}
return false
}
ep.handlePacket(id, pkt)
return true
return ep.handlePacket(id, pkt)
}
// deliverRawPacket attempts to deliver the given packet and returns whether it

View File

@ -477,47 +477,50 @@ void TestHangupDuringConnect(const TestParam& param,
TestAddress const& listener = param.listener;
TestAddress const& connector = param.connector;
// Create the listening socket.
FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
sockaddr_storage listen_addr = listener.addr;
ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
listener.addr_len),
SyscallSucceeds());
ASSERT_THAT(listen(listen_fd.get(), 1), SyscallSucceeds());
for (int i = 0; i < 100; i++) {
// Create the listening socket.
FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
sockaddr_storage listen_addr = listener.addr;
ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
listener.addr_len),
SyscallSucceeds());
ASSERT_THAT(listen(listen_fd.get(), 0), SyscallSucceeds());
// Get the port bound by the listening socket.
socklen_t addrlen = listener.addr_len;
ASSERT_THAT(getsockname(listen_fd.get(),
reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
SyscallSucceeds());
uint16_t const port =
ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
// Get the port bound by the listening socket.
socklen_t addrlen = listener.addr_len;
ASSERT_THAT(
getsockname(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
&addrlen),
SyscallSucceeds());
uint16_t const port =
ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
sockaddr_storage conn_addr = connector.addr;
ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
sockaddr_storage conn_addr = connector.addr;
ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
// Connect asynchronously and immediately hang up the listener.
FileDescriptor client = ASSERT_NO_ERRNO_AND_VALUE(
Socket(connector.family(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP));
int ret = connect(client.get(), reinterpret_cast<sockaddr*>(&conn_addr),
connector.addr_len);
if (ret != 0) {
EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS));
// Connect asynchronously and immediately hang up the listener.
FileDescriptor client = ASSERT_NO_ERRNO_AND_VALUE(
Socket(connector.family(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP));
int ret = connect(client.get(), reinterpret_cast<sockaddr*>(&conn_addr),
connector.addr_len);
if (ret != 0) {
EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS));
}
hangup(listen_fd);
// Wait for the connection to close.
struct pollfd pfd = {
.fd = client.get(),
};
constexpr int kTimeout = 10000;
int n = poll(&pfd, 1, kTimeout);
ASSERT_GE(n, 0) << strerror(errno);
ASSERT_EQ(n, 1);
ASSERT_EQ(pfd.revents, POLLHUP | POLLERR);
ASSERT_EQ(close(client.release()), 0) << strerror(errno);
}
hangup(listen_fd);
// Wait for the connection to close.
struct pollfd pfd = {
.fd = client.get(),
};
constexpr int kTimeout = 10000;
int n = poll(&pfd, 1, kTimeout);
ASSERT_GE(n, 0) << strerror(errno);
ASSERT_EQ(n, 1);
ASSERT_EQ(pfd.revents, POLLHUP | POLLERR);
ASSERT_EQ(close(client.release()), 0) << strerror(errno);
}
TEST_P(SocketInetLoopbackTest, TCPListenCloseDuringConnect) {