netstack/udp: connect with the AF_UNSPEC address family means disconnect

PiperOrigin-RevId: 256433283
This commit is contained in:
Andrei Vagin 2019-07-03 13:57:24 -07:00 committed by gVisor bot
parent f10862696c
commit 116cac053e
12 changed files with 299 additions and 91 deletions

View File

@ -285,14 +285,14 @@ func bytesToIPAddress(addr []byte) tcpip.Address {
// GetAddress reads an sockaddr struct from the given address and converts it
// to the FullAddress format. It supports AF_UNIX, AF_INET and AF_INET6
// addresses.
func GetAddress(sfamily int, addr []byte) (tcpip.FullAddress, *syserr.Error) {
func GetAddress(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, *syserr.Error) {
// Make sure we have at least 2 bytes for the address family.
if len(addr) < 2 {
return tcpip.FullAddress{}, syserr.ErrInvalidArgument
}
family := usermem.ByteOrder.Uint16(addr)
if family != uint16(sfamily) {
if family != uint16(sfamily) && (!strict && family != linux.AF_UNSPEC) {
return tcpip.FullAddress{}, syserr.ErrAddressFamilyNotSupported
}
@ -317,7 +317,7 @@ func GetAddress(sfamily int, addr []byte) (tcpip.FullAddress, *syserr.Error) {
case linux.AF_INET:
var a linux.SockAddrInet
if len(addr) < sockAddrInetSize {
return tcpip.FullAddress{}, syserr.ErrBadAddress
return tcpip.FullAddress{}, syserr.ErrInvalidArgument
}
binary.Unmarshal(addr[:sockAddrInetSize], usermem.ByteOrder, &a)
@ -330,7 +330,7 @@ func GetAddress(sfamily int, addr []byte) (tcpip.FullAddress, *syserr.Error) {
case linux.AF_INET6:
var a linux.SockAddrInet6
if len(addr) < sockAddrInet6Size {
return tcpip.FullAddress{}, syserr.ErrBadAddress
return tcpip.FullAddress{}, syserr.ErrInvalidArgument
}
binary.Unmarshal(addr[:sockAddrInet6Size], usermem.ByteOrder, &a)
@ -343,6 +343,9 @@ func GetAddress(sfamily int, addr []byte) (tcpip.FullAddress, *syserr.Error) {
}
return out, nil
case linux.AF_UNSPEC:
return tcpip.FullAddress{}, nil
default:
return tcpip.FullAddress{}, syserr.ErrAddressFamilyNotSupported
}
@ -465,7 +468,7 @@ func (s *SocketOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
// Connect implements the linux syscall connect(2) for sockets backed by
// tpcip.Endpoint.
func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error {
addr, err := GetAddress(s.family, sockaddr)
addr, err := GetAddress(s.family, sockaddr, false /* strict */)
if err != nil {
return err
}
@ -498,7 +501,7 @@ func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking boo
// Bind implements the linux syscall bind(2) for sockets backed by
// tcpip.Endpoint.
func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
addr, err := GetAddress(s.family, sockaddr)
addr, err := GetAddress(s.family, sockaddr, true /* strict */)
if err != nil {
return err
}
@ -1922,7 +1925,7 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
var addr *tcpip.FullAddress
if len(to) > 0 {
addrBuf, err := GetAddress(s.family, to)
addrBuf, err := GetAddress(s.family, to, true /* strict */)
if err != nil {
return 0, err
}

View File

@ -110,7 +110,7 @@ func (s *SocketOperations) Endpoint() transport.Endpoint {
// extractPath extracts and validates the address.
func extractPath(sockaddr []byte) (string, *syserr.Error) {
addr, err := epsocket.GetAddress(linux.AF_UNIX, sockaddr)
addr, err := epsocket.GetAddress(linux.AF_UNIX, sockaddr, true /* strict */)
if err != nil {
return "", err
}

View File

@ -332,7 +332,7 @@ func sockAddr(t *kernel.Task, addr usermem.Addr, length uint32) string {
switch family {
case linux.AF_INET, linux.AF_INET6, linux.AF_UNIX:
fa, err := epsocket.GetAddress(int(family), b)
fa, err := epsocket.GetAddress(int(family), b, true /* strict */)
if err != nil {
return fmt.Sprintf("%#x {Family: %s, error extracting address: %v}", addr, familyStr, err)
}

View File

@ -49,43 +49,44 @@ var (
)
var netstackErrorTranslations = map[*tcpip.Error]*Error{
tcpip.ErrUnknownProtocol: ErrUnknownProtocol,
tcpip.ErrUnknownNICID: ErrUnknownNICID,
tcpip.ErrUnknownDevice: ErrUnknownDevice,
tcpip.ErrUnknownProtocolOption: ErrUnknownProtocolOption,
tcpip.ErrDuplicateNICID: ErrDuplicateNICID,
tcpip.ErrDuplicateAddress: ErrDuplicateAddress,
tcpip.ErrNoRoute: ErrNoRoute,
tcpip.ErrBadLinkEndpoint: ErrBadLinkEndpoint,
tcpip.ErrAlreadyBound: ErrAlreadyBound,
tcpip.ErrInvalidEndpointState: ErrInvalidEndpointState,
tcpip.ErrAlreadyConnecting: ErrAlreadyConnecting,
tcpip.ErrAlreadyConnected: ErrAlreadyConnected,
tcpip.ErrNoPortAvailable: ErrNoPortAvailable,
tcpip.ErrPortInUse: ErrPortInUse,
tcpip.ErrBadLocalAddress: ErrBadLocalAddress,
tcpip.ErrClosedForSend: ErrClosedForSend,
tcpip.ErrClosedForReceive: ErrClosedForReceive,
tcpip.ErrWouldBlock: ErrWouldBlock,
tcpip.ErrConnectionRefused: ErrConnectionRefused,
tcpip.ErrTimeout: ErrTimeout,
tcpip.ErrAborted: ErrAborted,
tcpip.ErrConnectStarted: ErrConnectStarted,
tcpip.ErrDestinationRequired: ErrDestinationRequired,
tcpip.ErrNotSupported: ErrNotSupported,
tcpip.ErrQueueSizeNotSupported: ErrQueueSizeNotSupported,
tcpip.ErrNotConnected: ErrNotConnected,
tcpip.ErrConnectionReset: ErrConnectionReset,
tcpip.ErrConnectionAborted: ErrConnectionAborted,
tcpip.ErrNoSuchFile: ErrNoSuchFile,
tcpip.ErrInvalidOptionValue: ErrInvalidOptionValue,
tcpip.ErrNoLinkAddress: ErrHostDown,
tcpip.ErrBadAddress: ErrBadAddress,
tcpip.ErrNetworkUnreachable: ErrNetworkUnreachable,
tcpip.ErrMessageTooLong: ErrMessageTooLong,
tcpip.ErrNoBufferSpace: ErrNoBufferSpace,
tcpip.ErrBroadcastDisabled: ErrBroadcastDisabled,
tcpip.ErrNotPermitted: ErrNotPermittedNet,
tcpip.ErrUnknownProtocol: ErrUnknownProtocol,
tcpip.ErrUnknownNICID: ErrUnknownNICID,
tcpip.ErrUnknownDevice: ErrUnknownDevice,
tcpip.ErrUnknownProtocolOption: ErrUnknownProtocolOption,
tcpip.ErrDuplicateNICID: ErrDuplicateNICID,
tcpip.ErrDuplicateAddress: ErrDuplicateAddress,
tcpip.ErrNoRoute: ErrNoRoute,
tcpip.ErrBadLinkEndpoint: ErrBadLinkEndpoint,
tcpip.ErrAlreadyBound: ErrAlreadyBound,
tcpip.ErrInvalidEndpointState: ErrInvalidEndpointState,
tcpip.ErrAlreadyConnecting: ErrAlreadyConnecting,
tcpip.ErrAlreadyConnected: ErrAlreadyConnected,
tcpip.ErrNoPortAvailable: ErrNoPortAvailable,
tcpip.ErrPortInUse: ErrPortInUse,
tcpip.ErrBadLocalAddress: ErrBadLocalAddress,
tcpip.ErrClosedForSend: ErrClosedForSend,
tcpip.ErrClosedForReceive: ErrClosedForReceive,
tcpip.ErrWouldBlock: ErrWouldBlock,
tcpip.ErrConnectionRefused: ErrConnectionRefused,
tcpip.ErrTimeout: ErrTimeout,
tcpip.ErrAborted: ErrAborted,
tcpip.ErrConnectStarted: ErrConnectStarted,
tcpip.ErrDestinationRequired: ErrDestinationRequired,
tcpip.ErrNotSupported: ErrNotSupported,
tcpip.ErrQueueSizeNotSupported: ErrQueueSizeNotSupported,
tcpip.ErrNotConnected: ErrNotConnected,
tcpip.ErrConnectionReset: ErrConnectionReset,
tcpip.ErrConnectionAborted: ErrConnectionAborted,
tcpip.ErrNoSuchFile: ErrNoSuchFile,
tcpip.ErrInvalidOptionValue: ErrInvalidOptionValue,
tcpip.ErrNoLinkAddress: ErrHostDown,
tcpip.ErrBadAddress: ErrBadAddress,
tcpip.ErrNetworkUnreachable: ErrNetworkUnreachable,
tcpip.ErrMessageTooLong: ErrMessageTooLong,
tcpip.ErrNoBufferSpace: ErrNoBufferSpace,
tcpip.ErrBroadcastDisabled: ErrBroadcastDisabled,
tcpip.ErrNotPermitted: ErrNotPermittedNet,
tcpip.ErrAddressFamilyNotSupported: ErrAddressFamilyNotSupported,
}
// TranslateNetstackError converts an error from the tcpip package to a sentry

View File

@ -66,43 +66,44 @@ func (e *Error) IgnoreStats() bool {
// Errors that can be returned by the network stack.
var (
ErrUnknownProtocol = &Error{msg: "unknown protocol"}
ErrUnknownNICID = &Error{msg: "unknown nic id"}
ErrUnknownDevice = &Error{msg: "unknown device"}
ErrUnknownProtocolOption = &Error{msg: "unknown option for protocol"}
ErrDuplicateNICID = &Error{msg: "duplicate nic id"}
ErrDuplicateAddress = &Error{msg: "duplicate address"}
ErrNoRoute = &Error{msg: "no route"}
ErrBadLinkEndpoint = &Error{msg: "bad link layer endpoint"}
ErrAlreadyBound = &Error{msg: "endpoint already bound", ignoreStats: true}
ErrInvalidEndpointState = &Error{msg: "endpoint is in invalid state"}
ErrAlreadyConnecting = &Error{msg: "endpoint is already connecting", ignoreStats: true}
ErrAlreadyConnected = &Error{msg: "endpoint is already connected", ignoreStats: true}
ErrNoPortAvailable = &Error{msg: "no ports are available"}
ErrPortInUse = &Error{msg: "port is in use"}
ErrBadLocalAddress = &Error{msg: "bad local address"}
ErrClosedForSend = &Error{msg: "endpoint is closed for send"}
ErrClosedForReceive = &Error{msg: "endpoint is closed for receive"}
ErrWouldBlock = &Error{msg: "operation would block", ignoreStats: true}
ErrConnectionRefused = &Error{msg: "connection was refused"}
ErrTimeout = &Error{msg: "operation timed out"}
ErrAborted = &Error{msg: "operation aborted"}
ErrConnectStarted = &Error{msg: "connection attempt started", ignoreStats: true}
ErrDestinationRequired = &Error{msg: "destination address is required"}
ErrNotSupported = &Error{msg: "operation not supported"}
ErrQueueSizeNotSupported = &Error{msg: "queue size querying not supported"}
ErrNotConnected = &Error{msg: "endpoint not connected"}
ErrConnectionReset = &Error{msg: "connection reset by peer"}
ErrConnectionAborted = &Error{msg: "connection aborted"}
ErrNoSuchFile = &Error{msg: "no such file"}
ErrInvalidOptionValue = &Error{msg: "invalid option value specified"}
ErrNoLinkAddress = &Error{msg: "no remote link address"}
ErrBadAddress = &Error{msg: "bad address"}
ErrNetworkUnreachable = &Error{msg: "network is unreachable"}
ErrMessageTooLong = &Error{msg: "message too long"}
ErrNoBufferSpace = &Error{msg: "no buffer space available"}
ErrBroadcastDisabled = &Error{msg: "broadcast socket option disabled"}
ErrNotPermitted = &Error{msg: "operation not permitted"}
ErrUnknownProtocol = &Error{msg: "unknown protocol"}
ErrUnknownNICID = &Error{msg: "unknown nic id"}
ErrUnknownDevice = &Error{msg: "unknown device"}
ErrUnknownProtocolOption = &Error{msg: "unknown option for protocol"}
ErrDuplicateNICID = &Error{msg: "duplicate nic id"}
ErrDuplicateAddress = &Error{msg: "duplicate address"}
ErrNoRoute = &Error{msg: "no route"}
ErrBadLinkEndpoint = &Error{msg: "bad link layer endpoint"}
ErrAlreadyBound = &Error{msg: "endpoint already bound", ignoreStats: true}
ErrInvalidEndpointState = &Error{msg: "endpoint is in invalid state"}
ErrAlreadyConnecting = &Error{msg: "endpoint is already connecting", ignoreStats: true}
ErrAlreadyConnected = &Error{msg: "endpoint is already connected", ignoreStats: true}
ErrNoPortAvailable = &Error{msg: "no ports are available"}
ErrPortInUse = &Error{msg: "port is in use"}
ErrBadLocalAddress = &Error{msg: "bad local address"}
ErrClosedForSend = &Error{msg: "endpoint is closed for send"}
ErrClosedForReceive = &Error{msg: "endpoint is closed for receive"}
ErrWouldBlock = &Error{msg: "operation would block", ignoreStats: true}
ErrConnectionRefused = &Error{msg: "connection was refused"}
ErrTimeout = &Error{msg: "operation timed out"}
ErrAborted = &Error{msg: "operation aborted"}
ErrConnectStarted = &Error{msg: "connection attempt started", ignoreStats: true}
ErrDestinationRequired = &Error{msg: "destination address is required"}
ErrNotSupported = &Error{msg: "operation not supported"}
ErrQueueSizeNotSupported = &Error{msg: "queue size querying not supported"}
ErrNotConnected = &Error{msg: "endpoint not connected"}
ErrConnectionReset = &Error{msg: "connection reset by peer"}
ErrConnectionAborted = &Error{msg: "connection aborted"}
ErrNoSuchFile = &Error{msg: "no such file"}
ErrInvalidOptionValue = &Error{msg: "invalid option value specified"}
ErrNoLinkAddress = &Error{msg: "no remote link address"}
ErrBadAddress = &Error{msg: "bad address"}
ErrNetworkUnreachable = &Error{msg: "network is unreachable"}
ErrMessageTooLong = &Error{msg: "message too long"}
ErrNoBufferSpace = &Error{msg: "no buffer space available"}
ErrBroadcastDisabled = &Error{msg: "broadcast socket option disabled"}
ErrNotPermitted = &Error{msg: "operation not permitted"}
ErrAddressFamilyNotSupported = &Error{msg: "address family not supported by protocol"}
)
// Errors related to Subnet
@ -339,6 +340,10 @@ type Endpoint interface {
// get the actual result. The first call to Connect after the socket has
// connected returns nil. Calling connect again results in ErrAlreadyConnected.
// Anything else -- the attempt to connect failed.
//
// If address.Addr is empty, this means that Enpoint has to be
// disconnected if this is supported, otherwise
// ErrAddressFamilyNotSupported must be returned.
Connect(address FullAddress) *Error
// Shutdown closes the read and/or write end of the endpoint connection

View File

@ -422,6 +422,11 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
if addr.Addr == "" {
// AF_UNSPEC isn't supported.
return tcpip.ErrAddressFamilyNotSupported
}
nicid := addr.NIC
localPort := uint16(0)
switch e.state {

View File

@ -298,6 +298,11 @@ func (ep *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
ep.mu.Lock()
defer ep.mu.Unlock()
if addr.Addr == "" {
// AF_UNSPEC isn't supported.
return tcpip.ErrAddressFamilyNotSupported
}
if ep.closed {
return tcpip.ErrInvalidEndpointState
}

View File

@ -1271,6 +1271,11 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocol
// Connect connects the endpoint to its peer.
func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
if addr.Addr == "" && addr.Port == 0 {
// AF_UNSPEC isn't supported.
return tcpip.ErrAddressFamilyNotSupported
}
return e.connect(addr, true, true)
}

View File

@ -342,6 +342,7 @@ func loadError(s string) *tcpip.Error {
tcpip.ErrNoBufferSpace,
tcpip.ErrBroadcastDisabled,
tcpip.ErrNotPermitted,
tcpip.ErrAddressFamilyNotSupported,
}
messageToError = make(map[string]*tcpip.Error)

View File

@ -698,8 +698,44 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (t
return netProto, nil
}
func (e *endpoint) disconnect() *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
if e.state != stateConnected {
return nil
}
id := stack.TransportEndpointID{}
// Exclude ephemerally bound endpoints.
if e.bindNICID != 0 || e.id.LocalAddress == "" {
var err *tcpip.Error
id = stack.TransportEndpointID{
LocalPort: e.id.LocalPort,
LocalAddress: e.id.LocalAddress,
}
id, err = e.registerWithStack(e.regNICID, e.effectiveNetProtos, id)
if err != nil {
return err
}
e.state = stateBound
} else {
e.state = stateInitial
}
e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e)
e.id = id
e.route.Release()
e.route = stack.Route{}
e.dstPort = 0
return nil
}
// Connect connects the endpoint to its peer. Specifying a NIC is optional.
func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
if addr.Addr == "" {
return e.disconnect()
}
if addr.Port == 0 {
// We don't support connecting to port zero.
return tcpip.ErrInvalidEndpointState
@ -734,12 +770,16 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
defer r.Release()
id := stack.TransportEndpointID{
LocalAddress: r.LocalAddress,
LocalAddress: e.id.LocalAddress,
LocalPort: localPort,
RemotePort: addr.Port,
RemoteAddress: r.RemoteAddress,
}
if e.state == stateInitial {
id.LocalAddress = r.LocalAddress
}
// Even if we're connected, this endpoint can still be used to send
// packets on a different network protocol, so we register both even if
// v6only is set to false and this is an ipv6 endpoint.

View File

@ -92,8 +92,6 @@ func (e *endpoint) afterLoad() {
if err != nil {
panic(*err)
}
e.id.LocalAddress = e.route.LocalAddress
} else if len(e.id.LocalAddress) != 0 { // stateBound
if e.stack.CheckLocalAddress(e.regNICID, netProto, e.id.LocalAddress) == 0 {
panic(tcpip.ErrBadLocalAddress)

View File

@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <arpa/inet.h>
#include <fcntl.h>
#include <linux/errqueue.h>
#include <netinet/in.h>
@ -304,12 +305,50 @@ TEST_P(UdpSocketTest, ReceiveAfterConnect) {
SyscallSucceedsWithValue(sizeof(buf)));
// Receive the data.
char received[512];
char received[sizeof(buf)];
EXPECT_THAT(recv(s_, received, sizeof(received), 0),
SyscallSucceedsWithValue(sizeof(received)));
EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0);
}
TEST_P(UdpSocketTest, ReceiveAfterDisconnect) {
// Connect s_ to loopback:TestPort, and bind t_ to loopback:TestPort.
ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds());
ASSERT_THAT(bind(t_, addr_[0], addrlen_), SyscallSucceeds());
ASSERT_THAT(connect(t_, addr_[1], addrlen_), SyscallSucceeds());
// Get the address s_ was bound to during connect.
struct sockaddr_storage addr;
socklen_t addrlen = sizeof(addr);
EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen),
SyscallSucceeds());
EXPECT_EQ(addrlen, addrlen_);
for (int i = 0; i < 2; i++) {
// Send from t_ to s_.
char buf[512];
RandomizeBuffer(buf, sizeof(buf));
EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen),
SyscallSucceeds());
ASSERT_THAT(sendto(t_, buf, sizeof(buf), 0,
reinterpret_cast<sockaddr*>(&addr), addrlen),
SyscallSucceedsWithValue(sizeof(buf)));
// Receive the data.
char received[sizeof(buf)];
EXPECT_THAT(recv(s_, received, sizeof(received), 0),
SyscallSucceedsWithValue(sizeof(received)));
EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0);
// Disconnect s_.
struct sockaddr addr = {};
addr.sa_family = AF_UNSPEC;
ASSERT_THAT(connect(s_, &addr, sizeof(addr.sa_family)), SyscallSucceeds());
// Connect s_ loopback:TestPort.
ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds());
}
}
TEST_P(UdpSocketTest, Connect) {
ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds());
@ -335,6 +374,112 @@ TEST_P(UdpSocketTest, Connect) {
EXPECT_EQ(memcmp(&peer, addr_[2], addrlen_), 0);
}
TEST_P(UdpSocketTest, DisconnectAfterBind) {
ASSERT_THAT(bind(s_, addr_[1], addrlen_), SyscallSucceeds());
// Connect the socket.
ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds());
struct sockaddr_storage addr = {};
addr.ss_family = AF_UNSPEC;
EXPECT_THAT(
connect(s_, reinterpret_cast<sockaddr*>(&addr), sizeof(addr.ss_family)),
SyscallSucceeds());
// Check that we're still bound.
socklen_t addrlen = sizeof(addr);
EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen),
SyscallSucceeds());
EXPECT_EQ(addrlen, addrlen_);
EXPECT_EQ(memcmp(&addr, addr_[1], addrlen_), 0);
addrlen = sizeof(addr);
EXPECT_THAT(getpeername(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen),
SyscallFailsWithErrno(ENOTCONN));
}
TEST_P(UdpSocketTest, DisconnectAfterBindToAny) {
struct sockaddr_storage baddr = {};
socklen_t addrlen;
auto port = *Port(reinterpret_cast<struct sockaddr_storage*>(addr_[1]));
if (addr_[0]->sa_family == AF_INET) {
auto addr_in = reinterpret_cast<struct sockaddr_in*>(&baddr);
addr_in->sin_family = AF_INET;
addr_in->sin_port = port;
inet_pton(AF_INET, "0.0.0.0",
reinterpret_cast<void*>(&addr_in->sin_addr.s_addr));
} else {
auto addr_in = reinterpret_cast<struct sockaddr_in6*>(&baddr);
addr_in->sin6_family = AF_INET6;
addr_in->sin6_port = port;
inet_pton(AF_INET6,
"::", reinterpret_cast<void*>(&addr_in->sin6_addr.s6_addr));
addr_in->sin6_scope_id = 0;
}
ASSERT_THAT(bind(s_, reinterpret_cast<sockaddr*>(&baddr), addrlen_),
SyscallSucceeds());
// Connect the socket.
ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds());
struct sockaddr_storage addr = {};
addr.ss_family = AF_UNSPEC;
EXPECT_THAT(
connect(s_, reinterpret_cast<sockaddr*>(&addr), sizeof(addr.ss_family)),
SyscallSucceeds());
// Check that we're still bound.
addrlen = sizeof(addr);
EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen),
SyscallSucceeds());
EXPECT_EQ(addrlen, addrlen_);
EXPECT_EQ(memcmp(&addr, &baddr, addrlen), 0);
addrlen = sizeof(addr);
EXPECT_THAT(getpeername(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen),
SyscallFailsWithErrno(ENOTCONN));
}
TEST_P(UdpSocketTest, Disconnect) {
for (int i = 0; i < 2; i++) {
// Try to connect again.
EXPECT_THAT(connect(s_, addr_[2], addrlen_), SyscallSucceeds());
// Check that we're connected to the right peer.
struct sockaddr_storage peer;
socklen_t peerlen = sizeof(peer);
EXPECT_THAT(getpeername(s_, reinterpret_cast<sockaddr*>(&peer), &peerlen),
SyscallSucceeds());
EXPECT_EQ(peerlen, addrlen_);
EXPECT_EQ(memcmp(&peer, addr_[2], addrlen_), 0);
// Try to disconnect.
struct sockaddr_storage addr = {};
addr.ss_family = AF_UNSPEC;
EXPECT_THAT(
connect(s_, reinterpret_cast<sockaddr*>(&addr), sizeof(addr.ss_family)),
SyscallSucceeds());
peerlen = sizeof(peer);
EXPECT_THAT(getpeername(s_, reinterpret_cast<sockaddr*>(&peer), &peerlen),
SyscallFailsWithErrno(ENOTCONN));
// Check that we're still bound.
socklen_t addrlen = sizeof(addr);
EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen),
SyscallSucceeds());
EXPECT_EQ(addrlen, addrlen_);
EXPECT_EQ(*Port(&addr), 0);
}
}
TEST_P(UdpSocketTest, ConnectBadAddress) {
struct sockaddr addr = {};
addr.sa_family = addr_[0]->sa_family;
ASSERT_THAT(connect(s_, &addr, sizeof(addr.sa_family)),
SyscallFailsWithErrno(EINVAL));
}
TEST_P(UdpSocketTest, SendToAddressOtherThanConnected) {
ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds());
@ -397,7 +542,7 @@ TEST_P(UdpSocketTest, SendAndReceiveNotConnected) {
SyscallSucceedsWithValue(sizeof(buf)));
// Receive the data.
char received[512];
char received[sizeof(buf)];
EXPECT_THAT(recv(s_, received, sizeof(received), 0),
SyscallSucceedsWithValue(sizeof(received)));
EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0);
@ -419,7 +564,7 @@ TEST_P(UdpSocketTest, SendAndReceiveConnected) {
SyscallSucceedsWithValue(sizeof(buf)));
// Receive the data.
char received[512];
char received[sizeof(buf)];
EXPECT_THAT(recv(s_, received, sizeof(received), 0),
SyscallSucceedsWithValue(sizeof(received)));
EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0);
@ -462,7 +607,7 @@ TEST_P(UdpSocketTest, ReceiveBeforeConnect) {
ASSERT_THAT(connect(s_, addr_[1], addrlen_), SyscallSucceeds());
// Receive the data. It works because it was sent before the connect.
char received[512];
char received[sizeof(buf)];
EXPECT_THAT(recv(s_, received, sizeof(received), 0),
SyscallSucceedsWithValue(sizeof(received)));
EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0);
@ -491,7 +636,7 @@ TEST_P(UdpSocketTest, ReceiveFrom) {
SyscallSucceedsWithValue(sizeof(buf)));
// Receive the data and sender address.
char received[512];
char received[sizeof(buf)];
struct sockaddr_storage addr;
socklen_t addrlen = sizeof(addr);
EXPECT_THAT(recvfrom(s_, received, sizeof(received), 0,