Test connecting UDP sockets to the ANY address

This doesn't currently pass on gVisor.

While I'm here, fix a bug where connecting to the v6-mapped v4 address doesn't
work in gVisor.

PiperOrigin-RevId: 260923961
This commit is contained in:
Tamir Duberstein 2019-07-31 07:39:52 -07:00 committed by gVisor bot
parent a7d5e0d254
commit c6e6d92cb1
2 changed files with 139 additions and 45 deletions

View File

@ -241,12 +241,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi
// connectRoute establishes a route to the specified interface or the
// configured multicast interface if no interface is specified and the
// specified address is a multicast address.
func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress) (stack.Route, tcpip.NICID, tcpip.NetworkProtocolNumber, *tcpip.Error) {
netProto, err := e.checkV4Mapped(&addr, false)
if err != nil {
return stack.Route{}, 0, 0, err
}
func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (stack.Route, tcpip.NICID, *tcpip.Error) {
localAddr := e.id.LocalAddress
if header.IsV4MulticastAddress(addr.Addr) || header.IsV6MulticastAddress(addr.Addr) {
if nicid == 0 {
@ -260,9 +255,9 @@ func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress) (stac
// Find a route to the desired destination.
r, err := e.stack.FindRoute(nicid, localAddr, addr.Addr, netProto, e.multicastLoop)
if err != nil {
return stack.Route{}, 0, 0, err
return stack.Route{}, 0, err
}
return r, nicid, netProto, nil
return r, nicid, nil
}
// Write writes data to the endpoint's peer. This method does not block
@ -336,7 +331,12 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-c
return 0, nil, tcpip.ErrBroadcastDisabled
}
r, _, _, err := e.connectRoute(nicid, *to)
netProto, err := e.checkV4Mapped(to, false)
if err != nil {
return 0, nil, err
}
r, _, err := e.connectRoute(nicid, *to, netProto)
if err != nil {
return 0, nil, err
}
@ -740,6 +740,10 @@ func (e *endpoint) disconnect() *tcpip.Error {
// Connect connects the endpoint to its peer. Specifying a NIC is optional.
func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
netProto, err := e.checkV4Mapped(&addr, false)
if err != nil {
return err
}
if addr.Addr == "" {
return e.disconnect()
}
@ -770,7 +774,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
r, nicid, netProto, err := e.connectRoute(nicid, addr)
r, nicid, err := e.connectRoute(nicid, addr, netProto)
if err != nil {
return err
}

View File

@ -39,7 +39,7 @@ constexpr int TestPort = 40000;
// Fixture for tests parameterized by the address family to use (AF_INET and
// AF_INET6) when creating sockets.
class UdpSocketTest : public ::testing::TestWithParam<int> {
class UdpSocketTest : public ::testing::TestWithParam<AddressFamily> {
protected:
// Creates two sockets that will be used by test cases.
void SetUp() override;
@ -97,31 +97,32 @@ uint16_t* Port(struct sockaddr_storage* addr) {
}
void UdpSocketTest::SetUp() {
ASSERT_THAT(s_ = socket(GetParam(), SOCK_DGRAM, IPPROTO_UDP),
SyscallSucceeds());
int type;
if (GetParam() == AddressFamily::kIpv4) {
type = AF_INET;
auto sin = reinterpret_cast<struct sockaddr_in*>(&anyaddr_storage_);
addrlen_ = sizeof(*sin);
sin->sin_addr.s_addr = htonl(INADDR_ANY);
} else {
type = AF_INET6;
auto sin6 = reinterpret_cast<struct sockaddr_in6*>(&anyaddr_storage_);
addrlen_ = sizeof(*sin6);
if (GetParam() == AddressFamily::kIpv6) {
sin6->sin6_addr = IN6ADDR_ANY_INIT;
} else {
TestAddress const& v4_mapped_any = V4MappedAny();
sin6->sin6_addr =
reinterpret_cast<const struct sockaddr_in6*>(&v4_mapped_any.addr)
->sin6_addr;
}
}
ASSERT_THAT(s_ = socket(type, SOCK_DGRAM, IPPROTO_UDP), SyscallSucceeds());
ASSERT_THAT(t_ = socket(GetParam(), SOCK_DGRAM, IPPROTO_UDP),
SyscallSucceeds());
ASSERT_THAT(t_ = socket(type, SOCK_DGRAM, IPPROTO_UDP), SyscallSucceeds());
memset(&anyaddr_storage_, 0, sizeof(anyaddr_storage_));
anyaddr_ = reinterpret_cast<struct sockaddr*>(&anyaddr_storage_);
anyaddr_->sa_family = GetParam();
// Initialize address-family-specific values.
switch (GetParam()) {
case AF_INET: {
auto sin = reinterpret_cast<struct sockaddr_in*>(&anyaddr_storage_);
addrlen_ = sizeof(*sin);
sin->sin_addr.s_addr = htonl(INADDR_ANY);
break;
}
case AF_INET6: {
auto sin6 = reinterpret_cast<struct sockaddr_in6*>(&anyaddr_storage_);
addrlen_ = sizeof(*sin6);
sin6->sin6_addr = in6addr_any;
break;
}
}
anyaddr_->sa_family = type;
if (gvisor::testing::IsRunningOnGvisor()) {
for (size_t i = 0; i < ABSL_ARRAYSIZE(ports_); ++i) {
@ -154,9 +155,9 @@ void UdpSocketTest::SetUp() {
memset(&addr_storage_[i], 0, sizeof(addr_storage_[i]));
addr_[i] = reinterpret_cast<struct sockaddr*>(&addr_storage_[i]);
addr_[i]->sa_family = GetParam();
addr_[i]->sa_family = type;
switch (GetParam()) {
switch (type) {
case AF_INET: {
auto sin = reinterpret_cast<struct sockaddr_in*>(addr_[i]);
sin->sin_addr.s_addr = htonl(INADDR_LOOPBACK);
@ -174,17 +175,20 @@ void UdpSocketTest::SetUp() {
}
TEST_P(UdpSocketTest, Creation) {
int type = AF_INET6;
if (GetParam() == AddressFamily::kIpv4) {
type = AF_INET;
}
int s_;
ASSERT_THAT(s_ = socket(GetParam(), SOCK_DGRAM, IPPROTO_UDP),
SyscallSucceeds());
ASSERT_THAT(s_ = socket(type, SOCK_DGRAM, IPPROTO_UDP), SyscallSucceeds());
EXPECT_THAT(close(s_), SyscallSucceeds());
ASSERT_THAT(s_ = socket(GetParam(), SOCK_DGRAM, 0), SyscallSucceeds());
ASSERT_THAT(s_ = socket(type, SOCK_DGRAM, 0), SyscallSucceeds());
EXPECT_THAT(close(s_), SyscallSucceeds());
ASSERT_THAT(s_ = socket(GetParam(), SOCK_STREAM, IPPROTO_UDP),
SyscallFails());
ASSERT_THAT(s_ = socket(type, SOCK_STREAM, IPPROTO_UDP), SyscallFails());
}
TEST_P(UdpSocketTest, Getsockname) {
@ -374,6 +378,92 @@ TEST_P(UdpSocketTest, Connect) {
EXPECT_EQ(memcmp(&peer, addr_[2], addrlen_), 0);
}
TEST_P(UdpSocketTest, ConnectAny) {
struct sockaddr_storage addr = {};
// Precondition check.
{
socklen_t addrlen = sizeof(addr);
EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen),
SyscallSucceeds());
if (GetParam() == AddressFamily::kIpv4) {
auto addr_out = reinterpret_cast<struct sockaddr_in*>(&addr);
EXPECT_EQ(addrlen, sizeof(*addr_out));
EXPECT_EQ(addr_out->sin_addr.s_addr, htonl(INADDR_ANY));
} else {
auto addr_out = reinterpret_cast<struct sockaddr_in6*>(&addr);
EXPECT_EQ(addrlen, sizeof(*addr_out));
struct in6_addr any = IN6ADDR_ANY_INIT;
EXPECT_EQ(memcmp(&addr_out->sin6_addr, &any, sizeof(in6_addr)), 0);
}
{
socklen_t addrlen = sizeof(addr);
EXPECT_THAT(getpeername(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen),
SyscallFailsWithErrno(ENOTCONN));
}
struct sockaddr_storage baddr = {};
if (GetParam() == AddressFamily::kIpv4) {
auto addr_in = reinterpret_cast<struct sockaddr_in*>(&baddr);
addrlen = sizeof(*addr_in);
addr_in->sin_family = AF_INET;
addr_in->sin_addr.s_addr = htonl(INADDR_ANY);
} else {
auto addr_in = reinterpret_cast<struct sockaddr_in6*>(&baddr);
addrlen = sizeof(*addr_in);
addr_in->sin6_family = AF_INET6;
if (GetParam() == AddressFamily::kIpv6) {
addr_in->sin6_addr = IN6ADDR_ANY_INIT;
} else {
TestAddress const& v4_mapped_any = V4MappedAny();
addr_in->sin6_addr =
reinterpret_cast<const struct sockaddr_in6*>(&v4_mapped_any.addr)
->sin6_addr;
}
}
ASSERT_THAT(connect(s_, reinterpret_cast<sockaddr*>(&baddr), addrlen),
SyscallSucceeds());
}
// TODO(b/138658473): gVisor doesn't return the correct local address after
// connecting to the any address.
SKIP_IF(IsRunningOnGvisor());
// Postcondition check.
{
socklen_t addrlen = sizeof(addr);
EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen),
SyscallSucceeds());
if (GetParam() == AddressFamily::kIpv4) {
auto addr_out = reinterpret_cast<struct sockaddr_in*>(&addr);
EXPECT_EQ(addrlen, sizeof(*addr_out));
EXPECT_EQ(addr_out->sin_addr.s_addr, htonl(INADDR_LOOPBACK));
} else {
auto addr_out = reinterpret_cast<struct sockaddr_in6*>(&addr);
EXPECT_EQ(addrlen, sizeof(*addr_out));
struct in6_addr loopback;
if (GetParam() == AddressFamily::kIpv6) {
loopback = IN6ADDR_LOOPBACK_INIT;
} else {
TestAddress const& v4_mapped_loopback = V4MappedLoopback();
loopback = reinterpret_cast<const struct sockaddr_in6*>(
&v4_mapped_loopback.addr)
->sin6_addr;
}
EXPECT_EQ(memcmp(&addr_out->sin6_addr, &loopback, sizeof(in6_addr)), 0);
}
addrlen = sizeof(addr);
EXPECT_THAT(getpeername(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen),
SyscallFailsWithErrno(ENOTCONN));
}
}
TEST_P(UdpSocketTest, DisconnectAfterBind) {
ASSERT_THAT(bind(s_, addr_[1], addrlen_), SyscallSucceeds());
// Connect the socket.
@ -402,19 +492,17 @@ TEST_P(UdpSocketTest, DisconnectAfterBindToAny) {
struct sockaddr_storage baddr = {};
socklen_t addrlen;
auto port = *Port(reinterpret_cast<struct sockaddr_storage*>(addr_[1]));
if (addr_[0]->sa_family == AF_INET) {
if (GetParam() == AddressFamily::kIpv4) {
auto addr_in = reinterpret_cast<struct sockaddr_in*>(&baddr);
addr_in->sin_family = AF_INET;
addr_in->sin_port = port;
inet_pton(AF_INET, "0.0.0.0",
reinterpret_cast<void*>(&addr_in->sin_addr.s_addr));
addr_in->sin_addr.s_addr = htonl(INADDR_ANY);
} else {
auto addr_in = reinterpret_cast<struct sockaddr_in6*>(&baddr);
addr_in->sin6_family = AF_INET6;
addr_in->sin6_port = port;
inet_pton(AF_INET6,
"::", reinterpret_cast<void*>(&addr_in->sin6_addr.s6_addr));
addr_in->sin6_scope_id = 0;
addr_in->sin6_addr = IN6ADDR_ANY_INIT;
}
ASSERT_THAT(bind(s_, reinterpret_cast<sockaddr*>(&baddr), addrlen_),
SyscallSucceeds());
@ -1165,7 +1253,9 @@ TEST_P(UdpSocketTest, TimestampIoctlPersistence) {
}
INSTANTIATE_TEST_SUITE_P(AllInetTests, UdpSocketTest,
::testing::Values(AF_INET, AF_INET6));
::testing::Values(AddressFamily::kIpv4,
AddressFamily::kIpv6,
AddressFamily::kDualStack));
} // namespace