diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 91f89a781..52a5df691 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -241,12 +241,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi // connectRoute establishes a route to the specified interface or the // configured multicast interface if no interface is specified and the // specified address is a multicast address. -func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress) (stack.Route, tcpip.NICID, tcpip.NetworkProtocolNumber, *tcpip.Error) { - netProto, err := e.checkV4Mapped(&addr, false) - if err != nil { - return stack.Route{}, 0, 0, err - } - +func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (stack.Route, tcpip.NICID, *tcpip.Error) { localAddr := e.id.LocalAddress if header.IsV4MulticastAddress(addr.Addr) || header.IsV6MulticastAddress(addr.Addr) { if nicid == 0 { @@ -260,9 +255,9 @@ func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress) (stac // Find a route to the desired destination. r, err := e.stack.FindRoute(nicid, localAddr, addr.Addr, netProto, e.multicastLoop) if err != nil { - return stack.Route{}, 0, 0, err + return stack.Route{}, 0, err } - return r, nicid, netProto, nil + return r, nicid, nil } // Write writes data to the endpoint's peer. This method does not block @@ -336,7 +331,12 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-c return 0, nil, tcpip.ErrBroadcastDisabled } - r, _, _, err := e.connectRoute(nicid, *to) + netProto, err := e.checkV4Mapped(to, false) + if err != nil { + return 0, nil, err + } + + r, _, err := e.connectRoute(nicid, *to, netProto) if err != nil { return 0, nil, err } @@ -740,6 +740,10 @@ func (e *endpoint) disconnect() *tcpip.Error { // Connect connects the endpoint to its peer. Specifying a NIC is optional. func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { + netProto, err := e.checkV4Mapped(&addr, false) + if err != nil { + return err + } if addr.Addr == "" { return e.disconnect() } @@ -770,7 +774,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { return tcpip.ErrInvalidEndpointState } - r, nicid, netProto, err := e.connectRoute(nicid, addr) + r, nicid, err := e.connectRoute(nicid, addr, netProto) if err != nil { return err } diff --git a/test/syscalls/linux/udp_socket.cc b/test/syscalls/linux/udp_socket.cc index 1bb0307c4..6ffb65168 100644 --- a/test/syscalls/linux/udp_socket.cc +++ b/test/syscalls/linux/udp_socket.cc @@ -39,7 +39,7 @@ constexpr int TestPort = 40000; // Fixture for tests parameterized by the address family to use (AF_INET and // AF_INET6) when creating sockets. -class UdpSocketTest : public ::testing::TestWithParam { +class UdpSocketTest : public ::testing::TestWithParam { protected: // Creates two sockets that will be used by test cases. void SetUp() override; @@ -97,31 +97,32 @@ uint16_t* Port(struct sockaddr_storage* addr) { } void UdpSocketTest::SetUp() { - ASSERT_THAT(s_ = socket(GetParam(), SOCK_DGRAM, IPPROTO_UDP), - SyscallSucceeds()); + int type; + if (GetParam() == AddressFamily::kIpv4) { + type = AF_INET; + auto sin = reinterpret_cast(&anyaddr_storage_); + addrlen_ = sizeof(*sin); + sin->sin_addr.s_addr = htonl(INADDR_ANY); + } else { + type = AF_INET6; + auto sin6 = reinterpret_cast(&anyaddr_storage_); + addrlen_ = sizeof(*sin6); + if (GetParam() == AddressFamily::kIpv6) { + sin6->sin6_addr = IN6ADDR_ANY_INIT; + } else { + TestAddress const& v4_mapped_any = V4MappedAny(); + sin6->sin6_addr = + reinterpret_cast(&v4_mapped_any.addr) + ->sin6_addr; + } + } + ASSERT_THAT(s_ = socket(type, SOCK_DGRAM, IPPROTO_UDP), SyscallSucceeds()); - ASSERT_THAT(t_ = socket(GetParam(), SOCK_DGRAM, IPPROTO_UDP), - SyscallSucceeds()); + ASSERT_THAT(t_ = socket(type, SOCK_DGRAM, IPPROTO_UDP), SyscallSucceeds()); memset(&anyaddr_storage_, 0, sizeof(anyaddr_storage_)); anyaddr_ = reinterpret_cast(&anyaddr_storage_); - anyaddr_->sa_family = GetParam(); - - // Initialize address-family-specific values. - switch (GetParam()) { - case AF_INET: { - auto sin = reinterpret_cast(&anyaddr_storage_); - addrlen_ = sizeof(*sin); - sin->sin_addr.s_addr = htonl(INADDR_ANY); - break; - } - case AF_INET6: { - auto sin6 = reinterpret_cast(&anyaddr_storage_); - addrlen_ = sizeof(*sin6); - sin6->sin6_addr = in6addr_any; - break; - } - } + anyaddr_->sa_family = type; if (gvisor::testing::IsRunningOnGvisor()) { for (size_t i = 0; i < ABSL_ARRAYSIZE(ports_); ++i) { @@ -154,9 +155,9 @@ void UdpSocketTest::SetUp() { memset(&addr_storage_[i], 0, sizeof(addr_storage_[i])); addr_[i] = reinterpret_cast(&addr_storage_[i]); - addr_[i]->sa_family = GetParam(); + addr_[i]->sa_family = type; - switch (GetParam()) { + switch (type) { case AF_INET: { auto sin = reinterpret_cast(addr_[i]); sin->sin_addr.s_addr = htonl(INADDR_LOOPBACK); @@ -174,17 +175,20 @@ void UdpSocketTest::SetUp() { } TEST_P(UdpSocketTest, Creation) { + int type = AF_INET6; + if (GetParam() == AddressFamily::kIpv4) { + type = AF_INET; + } + int s_; - ASSERT_THAT(s_ = socket(GetParam(), SOCK_DGRAM, IPPROTO_UDP), - SyscallSucceeds()); + ASSERT_THAT(s_ = socket(type, SOCK_DGRAM, IPPROTO_UDP), SyscallSucceeds()); EXPECT_THAT(close(s_), SyscallSucceeds()); - ASSERT_THAT(s_ = socket(GetParam(), SOCK_DGRAM, 0), SyscallSucceeds()); + ASSERT_THAT(s_ = socket(type, SOCK_DGRAM, 0), SyscallSucceeds()); EXPECT_THAT(close(s_), SyscallSucceeds()); - ASSERT_THAT(s_ = socket(GetParam(), SOCK_STREAM, IPPROTO_UDP), - SyscallFails()); + ASSERT_THAT(s_ = socket(type, SOCK_STREAM, IPPROTO_UDP), SyscallFails()); } TEST_P(UdpSocketTest, Getsockname) { @@ -374,6 +378,92 @@ TEST_P(UdpSocketTest, Connect) { EXPECT_EQ(memcmp(&peer, addr_[2], addrlen_), 0); } +TEST_P(UdpSocketTest, ConnectAny) { + struct sockaddr_storage addr = {}; + + // Precondition check. + { + socklen_t addrlen = sizeof(addr); + EXPECT_THAT(getsockname(s_, reinterpret_cast(&addr), &addrlen), + SyscallSucceeds()); + + if (GetParam() == AddressFamily::kIpv4) { + auto addr_out = reinterpret_cast(&addr); + EXPECT_EQ(addrlen, sizeof(*addr_out)); + EXPECT_EQ(addr_out->sin_addr.s_addr, htonl(INADDR_ANY)); + } else { + auto addr_out = reinterpret_cast(&addr); + EXPECT_EQ(addrlen, sizeof(*addr_out)); + struct in6_addr any = IN6ADDR_ANY_INIT; + EXPECT_EQ(memcmp(&addr_out->sin6_addr, &any, sizeof(in6_addr)), 0); + } + + { + socklen_t addrlen = sizeof(addr); + EXPECT_THAT(getpeername(s_, reinterpret_cast(&addr), &addrlen), + SyscallFailsWithErrno(ENOTCONN)); + } + + struct sockaddr_storage baddr = {}; + if (GetParam() == AddressFamily::kIpv4) { + auto addr_in = reinterpret_cast(&baddr); + addrlen = sizeof(*addr_in); + addr_in->sin_family = AF_INET; + addr_in->sin_addr.s_addr = htonl(INADDR_ANY); + } else { + auto addr_in = reinterpret_cast(&baddr); + addrlen = sizeof(*addr_in); + addr_in->sin6_family = AF_INET6; + if (GetParam() == AddressFamily::kIpv6) { + addr_in->sin6_addr = IN6ADDR_ANY_INIT; + } else { + TestAddress const& v4_mapped_any = V4MappedAny(); + addr_in->sin6_addr = + reinterpret_cast(&v4_mapped_any.addr) + ->sin6_addr; + } + } + + ASSERT_THAT(connect(s_, reinterpret_cast(&baddr), addrlen), + SyscallSucceeds()); + } + + // TODO(b/138658473): gVisor doesn't return the correct local address after + // connecting to the any address. + SKIP_IF(IsRunningOnGvisor()); + + // Postcondition check. + { + socklen_t addrlen = sizeof(addr); + EXPECT_THAT(getsockname(s_, reinterpret_cast(&addr), &addrlen), + SyscallSucceeds()); + + if (GetParam() == AddressFamily::kIpv4) { + auto addr_out = reinterpret_cast(&addr); + EXPECT_EQ(addrlen, sizeof(*addr_out)); + EXPECT_EQ(addr_out->sin_addr.s_addr, htonl(INADDR_LOOPBACK)); + } else { + auto addr_out = reinterpret_cast(&addr); + EXPECT_EQ(addrlen, sizeof(*addr_out)); + struct in6_addr loopback; + if (GetParam() == AddressFamily::kIpv6) { + loopback = IN6ADDR_LOOPBACK_INIT; + } else { + TestAddress const& v4_mapped_loopback = V4MappedLoopback(); + loopback = reinterpret_cast( + &v4_mapped_loopback.addr) + ->sin6_addr; + } + + EXPECT_EQ(memcmp(&addr_out->sin6_addr, &loopback, sizeof(in6_addr)), 0); + } + + addrlen = sizeof(addr); + EXPECT_THAT(getpeername(s_, reinterpret_cast(&addr), &addrlen), + SyscallFailsWithErrno(ENOTCONN)); + } +} + TEST_P(UdpSocketTest, DisconnectAfterBind) { ASSERT_THAT(bind(s_, addr_[1], addrlen_), SyscallSucceeds()); // Connect the socket. @@ -402,19 +492,17 @@ TEST_P(UdpSocketTest, DisconnectAfterBindToAny) { struct sockaddr_storage baddr = {}; socklen_t addrlen; auto port = *Port(reinterpret_cast(addr_[1])); - if (addr_[0]->sa_family == AF_INET) { + if (GetParam() == AddressFamily::kIpv4) { auto addr_in = reinterpret_cast(&baddr); addr_in->sin_family = AF_INET; addr_in->sin_port = port; - inet_pton(AF_INET, "0.0.0.0", - reinterpret_cast(&addr_in->sin_addr.s_addr)); + addr_in->sin_addr.s_addr = htonl(INADDR_ANY); } else { auto addr_in = reinterpret_cast(&baddr); addr_in->sin6_family = AF_INET6; addr_in->sin6_port = port; - inet_pton(AF_INET6, - "::", reinterpret_cast(&addr_in->sin6_addr.s6_addr)); addr_in->sin6_scope_id = 0; + addr_in->sin6_addr = IN6ADDR_ANY_INIT; } ASSERT_THAT(bind(s_, reinterpret_cast(&baddr), addrlen_), SyscallSucceeds()); @@ -1165,7 +1253,9 @@ TEST_P(UdpSocketTest, TimestampIoctlPersistence) { } INSTANTIATE_TEST_SUITE_P(AllInetTests, UdpSocketTest, - ::testing::Values(AF_INET, AF_INET6)); + ::testing::Values(AddressFamily::kIpv4, + AddressFamily::kIpv6, + AddressFamily::kDualStack)); } // namespace