netstack/udp: connect with the AF_UNSPEC address family means disconnect
PiperOrigin-RevId: 256433283
This commit is contained in:
parent
f10862696c
commit
116cac053e
|
@ -285,14 +285,14 @@ func bytesToIPAddress(addr []byte) tcpip.Address {
|
|||
// GetAddress reads an sockaddr struct from the given address and converts it
|
||||
// to the FullAddress format. It supports AF_UNIX, AF_INET and AF_INET6
|
||||
// addresses.
|
||||
func GetAddress(sfamily int, addr []byte) (tcpip.FullAddress, *syserr.Error) {
|
||||
func GetAddress(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, *syserr.Error) {
|
||||
// Make sure we have at least 2 bytes for the address family.
|
||||
if len(addr) < 2 {
|
||||
return tcpip.FullAddress{}, syserr.ErrInvalidArgument
|
||||
}
|
||||
|
||||
family := usermem.ByteOrder.Uint16(addr)
|
||||
if family != uint16(sfamily) {
|
||||
if family != uint16(sfamily) && (!strict && family != linux.AF_UNSPEC) {
|
||||
return tcpip.FullAddress{}, syserr.ErrAddressFamilyNotSupported
|
||||
}
|
||||
|
||||
|
@ -317,7 +317,7 @@ func GetAddress(sfamily int, addr []byte) (tcpip.FullAddress, *syserr.Error) {
|
|||
case linux.AF_INET:
|
||||
var a linux.SockAddrInet
|
||||
if len(addr) < sockAddrInetSize {
|
||||
return tcpip.FullAddress{}, syserr.ErrBadAddress
|
||||
return tcpip.FullAddress{}, syserr.ErrInvalidArgument
|
||||
}
|
||||
binary.Unmarshal(addr[:sockAddrInetSize], usermem.ByteOrder, &a)
|
||||
|
||||
|
@ -330,7 +330,7 @@ func GetAddress(sfamily int, addr []byte) (tcpip.FullAddress, *syserr.Error) {
|
|||
case linux.AF_INET6:
|
||||
var a linux.SockAddrInet6
|
||||
if len(addr) < sockAddrInet6Size {
|
||||
return tcpip.FullAddress{}, syserr.ErrBadAddress
|
||||
return tcpip.FullAddress{}, syserr.ErrInvalidArgument
|
||||
}
|
||||
binary.Unmarshal(addr[:sockAddrInet6Size], usermem.ByteOrder, &a)
|
||||
|
||||
|
@ -343,6 +343,9 @@ func GetAddress(sfamily int, addr []byte) (tcpip.FullAddress, *syserr.Error) {
|
|||
}
|
||||
return out, nil
|
||||
|
||||
case linux.AF_UNSPEC:
|
||||
return tcpip.FullAddress{}, nil
|
||||
|
||||
default:
|
||||
return tcpip.FullAddress{}, syserr.ErrAddressFamilyNotSupported
|
||||
}
|
||||
|
@ -465,7 +468,7 @@ func (s *SocketOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
|
|||
// Connect implements the linux syscall connect(2) for sockets backed by
|
||||
// tpcip.Endpoint.
|
||||
func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error {
|
||||
addr, err := GetAddress(s.family, sockaddr)
|
||||
addr, err := GetAddress(s.family, sockaddr, false /* strict */)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -498,7 +501,7 @@ func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking boo
|
|||
// Bind implements the linux syscall bind(2) for sockets backed by
|
||||
// tcpip.Endpoint.
|
||||
func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
|
||||
addr, err := GetAddress(s.family, sockaddr)
|
||||
addr, err := GetAddress(s.family, sockaddr, true /* strict */)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -1922,7 +1925,7 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
|
|||
|
||||
var addr *tcpip.FullAddress
|
||||
if len(to) > 0 {
|
||||
addrBuf, err := GetAddress(s.family, to)
|
||||
addrBuf, err := GetAddress(s.family, to, true /* strict */)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
|
|
@ -110,7 +110,7 @@ func (s *SocketOperations) Endpoint() transport.Endpoint {
|
|||
|
||||
// extractPath extracts and validates the address.
|
||||
func extractPath(sockaddr []byte) (string, *syserr.Error) {
|
||||
addr, err := epsocket.GetAddress(linux.AF_UNIX, sockaddr)
|
||||
addr, err := epsocket.GetAddress(linux.AF_UNIX, sockaddr, true /* strict */)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
|
|
@ -332,7 +332,7 @@ func sockAddr(t *kernel.Task, addr usermem.Addr, length uint32) string {
|
|||
|
||||
switch family {
|
||||
case linux.AF_INET, linux.AF_INET6, linux.AF_UNIX:
|
||||
fa, err := epsocket.GetAddress(int(family), b)
|
||||
fa, err := epsocket.GetAddress(int(family), b, true /* strict */)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("%#x {Family: %s, error extracting address: %v}", addr, familyStr, err)
|
||||
}
|
||||
|
|
|
@ -49,43 +49,44 @@ var (
|
|||
)
|
||||
|
||||
var netstackErrorTranslations = map[*tcpip.Error]*Error{
|
||||
tcpip.ErrUnknownProtocol: ErrUnknownProtocol,
|
||||
tcpip.ErrUnknownNICID: ErrUnknownNICID,
|
||||
tcpip.ErrUnknownDevice: ErrUnknownDevice,
|
||||
tcpip.ErrUnknownProtocolOption: ErrUnknownProtocolOption,
|
||||
tcpip.ErrDuplicateNICID: ErrDuplicateNICID,
|
||||
tcpip.ErrDuplicateAddress: ErrDuplicateAddress,
|
||||
tcpip.ErrNoRoute: ErrNoRoute,
|
||||
tcpip.ErrBadLinkEndpoint: ErrBadLinkEndpoint,
|
||||
tcpip.ErrAlreadyBound: ErrAlreadyBound,
|
||||
tcpip.ErrInvalidEndpointState: ErrInvalidEndpointState,
|
||||
tcpip.ErrAlreadyConnecting: ErrAlreadyConnecting,
|
||||
tcpip.ErrAlreadyConnected: ErrAlreadyConnected,
|
||||
tcpip.ErrNoPortAvailable: ErrNoPortAvailable,
|
||||
tcpip.ErrPortInUse: ErrPortInUse,
|
||||
tcpip.ErrBadLocalAddress: ErrBadLocalAddress,
|
||||
tcpip.ErrClosedForSend: ErrClosedForSend,
|
||||
tcpip.ErrClosedForReceive: ErrClosedForReceive,
|
||||
tcpip.ErrWouldBlock: ErrWouldBlock,
|
||||
tcpip.ErrConnectionRefused: ErrConnectionRefused,
|
||||
tcpip.ErrTimeout: ErrTimeout,
|
||||
tcpip.ErrAborted: ErrAborted,
|
||||
tcpip.ErrConnectStarted: ErrConnectStarted,
|
||||
tcpip.ErrDestinationRequired: ErrDestinationRequired,
|
||||
tcpip.ErrNotSupported: ErrNotSupported,
|
||||
tcpip.ErrQueueSizeNotSupported: ErrQueueSizeNotSupported,
|
||||
tcpip.ErrNotConnected: ErrNotConnected,
|
||||
tcpip.ErrConnectionReset: ErrConnectionReset,
|
||||
tcpip.ErrConnectionAborted: ErrConnectionAborted,
|
||||
tcpip.ErrNoSuchFile: ErrNoSuchFile,
|
||||
tcpip.ErrInvalidOptionValue: ErrInvalidOptionValue,
|
||||
tcpip.ErrNoLinkAddress: ErrHostDown,
|
||||
tcpip.ErrBadAddress: ErrBadAddress,
|
||||
tcpip.ErrNetworkUnreachable: ErrNetworkUnreachable,
|
||||
tcpip.ErrMessageTooLong: ErrMessageTooLong,
|
||||
tcpip.ErrNoBufferSpace: ErrNoBufferSpace,
|
||||
tcpip.ErrBroadcastDisabled: ErrBroadcastDisabled,
|
||||
tcpip.ErrNotPermitted: ErrNotPermittedNet,
|
||||
tcpip.ErrUnknownProtocol: ErrUnknownProtocol,
|
||||
tcpip.ErrUnknownNICID: ErrUnknownNICID,
|
||||
tcpip.ErrUnknownDevice: ErrUnknownDevice,
|
||||
tcpip.ErrUnknownProtocolOption: ErrUnknownProtocolOption,
|
||||
tcpip.ErrDuplicateNICID: ErrDuplicateNICID,
|
||||
tcpip.ErrDuplicateAddress: ErrDuplicateAddress,
|
||||
tcpip.ErrNoRoute: ErrNoRoute,
|
||||
tcpip.ErrBadLinkEndpoint: ErrBadLinkEndpoint,
|
||||
tcpip.ErrAlreadyBound: ErrAlreadyBound,
|
||||
tcpip.ErrInvalidEndpointState: ErrInvalidEndpointState,
|
||||
tcpip.ErrAlreadyConnecting: ErrAlreadyConnecting,
|
||||
tcpip.ErrAlreadyConnected: ErrAlreadyConnected,
|
||||
tcpip.ErrNoPortAvailable: ErrNoPortAvailable,
|
||||
tcpip.ErrPortInUse: ErrPortInUse,
|
||||
tcpip.ErrBadLocalAddress: ErrBadLocalAddress,
|
||||
tcpip.ErrClosedForSend: ErrClosedForSend,
|
||||
tcpip.ErrClosedForReceive: ErrClosedForReceive,
|
||||
tcpip.ErrWouldBlock: ErrWouldBlock,
|
||||
tcpip.ErrConnectionRefused: ErrConnectionRefused,
|
||||
tcpip.ErrTimeout: ErrTimeout,
|
||||
tcpip.ErrAborted: ErrAborted,
|
||||
tcpip.ErrConnectStarted: ErrConnectStarted,
|
||||
tcpip.ErrDestinationRequired: ErrDestinationRequired,
|
||||
tcpip.ErrNotSupported: ErrNotSupported,
|
||||
tcpip.ErrQueueSizeNotSupported: ErrQueueSizeNotSupported,
|
||||
tcpip.ErrNotConnected: ErrNotConnected,
|
||||
tcpip.ErrConnectionReset: ErrConnectionReset,
|
||||
tcpip.ErrConnectionAborted: ErrConnectionAborted,
|
||||
tcpip.ErrNoSuchFile: ErrNoSuchFile,
|
||||
tcpip.ErrInvalidOptionValue: ErrInvalidOptionValue,
|
||||
tcpip.ErrNoLinkAddress: ErrHostDown,
|
||||
tcpip.ErrBadAddress: ErrBadAddress,
|
||||
tcpip.ErrNetworkUnreachable: ErrNetworkUnreachable,
|
||||
tcpip.ErrMessageTooLong: ErrMessageTooLong,
|
||||
tcpip.ErrNoBufferSpace: ErrNoBufferSpace,
|
||||
tcpip.ErrBroadcastDisabled: ErrBroadcastDisabled,
|
||||
tcpip.ErrNotPermitted: ErrNotPermittedNet,
|
||||
tcpip.ErrAddressFamilyNotSupported: ErrAddressFamilyNotSupported,
|
||||
}
|
||||
|
||||
// TranslateNetstackError converts an error from the tcpip package to a sentry
|
||||
|
|
|
@ -66,43 +66,44 @@ func (e *Error) IgnoreStats() bool {
|
|||
|
||||
// Errors that can be returned by the network stack.
|
||||
var (
|
||||
ErrUnknownProtocol = &Error{msg: "unknown protocol"}
|
||||
ErrUnknownNICID = &Error{msg: "unknown nic id"}
|
||||
ErrUnknownDevice = &Error{msg: "unknown device"}
|
||||
ErrUnknownProtocolOption = &Error{msg: "unknown option for protocol"}
|
||||
ErrDuplicateNICID = &Error{msg: "duplicate nic id"}
|
||||
ErrDuplicateAddress = &Error{msg: "duplicate address"}
|
||||
ErrNoRoute = &Error{msg: "no route"}
|
||||
ErrBadLinkEndpoint = &Error{msg: "bad link layer endpoint"}
|
||||
ErrAlreadyBound = &Error{msg: "endpoint already bound", ignoreStats: true}
|
||||
ErrInvalidEndpointState = &Error{msg: "endpoint is in invalid state"}
|
||||
ErrAlreadyConnecting = &Error{msg: "endpoint is already connecting", ignoreStats: true}
|
||||
ErrAlreadyConnected = &Error{msg: "endpoint is already connected", ignoreStats: true}
|
||||
ErrNoPortAvailable = &Error{msg: "no ports are available"}
|
||||
ErrPortInUse = &Error{msg: "port is in use"}
|
||||
ErrBadLocalAddress = &Error{msg: "bad local address"}
|
||||
ErrClosedForSend = &Error{msg: "endpoint is closed for send"}
|
||||
ErrClosedForReceive = &Error{msg: "endpoint is closed for receive"}
|
||||
ErrWouldBlock = &Error{msg: "operation would block", ignoreStats: true}
|
||||
ErrConnectionRefused = &Error{msg: "connection was refused"}
|
||||
ErrTimeout = &Error{msg: "operation timed out"}
|
||||
ErrAborted = &Error{msg: "operation aborted"}
|
||||
ErrConnectStarted = &Error{msg: "connection attempt started", ignoreStats: true}
|
||||
ErrDestinationRequired = &Error{msg: "destination address is required"}
|
||||
ErrNotSupported = &Error{msg: "operation not supported"}
|
||||
ErrQueueSizeNotSupported = &Error{msg: "queue size querying not supported"}
|
||||
ErrNotConnected = &Error{msg: "endpoint not connected"}
|
||||
ErrConnectionReset = &Error{msg: "connection reset by peer"}
|
||||
ErrConnectionAborted = &Error{msg: "connection aborted"}
|
||||
ErrNoSuchFile = &Error{msg: "no such file"}
|
||||
ErrInvalidOptionValue = &Error{msg: "invalid option value specified"}
|
||||
ErrNoLinkAddress = &Error{msg: "no remote link address"}
|
||||
ErrBadAddress = &Error{msg: "bad address"}
|
||||
ErrNetworkUnreachable = &Error{msg: "network is unreachable"}
|
||||
ErrMessageTooLong = &Error{msg: "message too long"}
|
||||
ErrNoBufferSpace = &Error{msg: "no buffer space available"}
|
||||
ErrBroadcastDisabled = &Error{msg: "broadcast socket option disabled"}
|
||||
ErrNotPermitted = &Error{msg: "operation not permitted"}
|
||||
ErrUnknownProtocol = &Error{msg: "unknown protocol"}
|
||||
ErrUnknownNICID = &Error{msg: "unknown nic id"}
|
||||
ErrUnknownDevice = &Error{msg: "unknown device"}
|
||||
ErrUnknownProtocolOption = &Error{msg: "unknown option for protocol"}
|
||||
ErrDuplicateNICID = &Error{msg: "duplicate nic id"}
|
||||
ErrDuplicateAddress = &Error{msg: "duplicate address"}
|
||||
ErrNoRoute = &Error{msg: "no route"}
|
||||
ErrBadLinkEndpoint = &Error{msg: "bad link layer endpoint"}
|
||||
ErrAlreadyBound = &Error{msg: "endpoint already bound", ignoreStats: true}
|
||||
ErrInvalidEndpointState = &Error{msg: "endpoint is in invalid state"}
|
||||
ErrAlreadyConnecting = &Error{msg: "endpoint is already connecting", ignoreStats: true}
|
||||
ErrAlreadyConnected = &Error{msg: "endpoint is already connected", ignoreStats: true}
|
||||
ErrNoPortAvailable = &Error{msg: "no ports are available"}
|
||||
ErrPortInUse = &Error{msg: "port is in use"}
|
||||
ErrBadLocalAddress = &Error{msg: "bad local address"}
|
||||
ErrClosedForSend = &Error{msg: "endpoint is closed for send"}
|
||||
ErrClosedForReceive = &Error{msg: "endpoint is closed for receive"}
|
||||
ErrWouldBlock = &Error{msg: "operation would block", ignoreStats: true}
|
||||
ErrConnectionRefused = &Error{msg: "connection was refused"}
|
||||
ErrTimeout = &Error{msg: "operation timed out"}
|
||||
ErrAborted = &Error{msg: "operation aborted"}
|
||||
ErrConnectStarted = &Error{msg: "connection attempt started", ignoreStats: true}
|
||||
ErrDestinationRequired = &Error{msg: "destination address is required"}
|
||||
ErrNotSupported = &Error{msg: "operation not supported"}
|
||||
ErrQueueSizeNotSupported = &Error{msg: "queue size querying not supported"}
|
||||
ErrNotConnected = &Error{msg: "endpoint not connected"}
|
||||
ErrConnectionReset = &Error{msg: "connection reset by peer"}
|
||||
ErrConnectionAborted = &Error{msg: "connection aborted"}
|
||||
ErrNoSuchFile = &Error{msg: "no such file"}
|
||||
ErrInvalidOptionValue = &Error{msg: "invalid option value specified"}
|
||||
ErrNoLinkAddress = &Error{msg: "no remote link address"}
|
||||
ErrBadAddress = &Error{msg: "bad address"}
|
||||
ErrNetworkUnreachable = &Error{msg: "network is unreachable"}
|
||||
ErrMessageTooLong = &Error{msg: "message too long"}
|
||||
ErrNoBufferSpace = &Error{msg: "no buffer space available"}
|
||||
ErrBroadcastDisabled = &Error{msg: "broadcast socket option disabled"}
|
||||
ErrNotPermitted = &Error{msg: "operation not permitted"}
|
||||
ErrAddressFamilyNotSupported = &Error{msg: "address family not supported by protocol"}
|
||||
)
|
||||
|
||||
// Errors related to Subnet
|
||||
|
@ -339,6 +340,10 @@ type Endpoint interface {
|
|||
// get the actual result. The first call to Connect after the socket has
|
||||
// connected returns nil. Calling connect again results in ErrAlreadyConnected.
|
||||
// Anything else -- the attempt to connect failed.
|
||||
//
|
||||
// If address.Addr is empty, this means that Enpoint has to be
|
||||
// disconnected if this is supported, otherwise
|
||||
// ErrAddressFamilyNotSupported must be returned.
|
||||
Connect(address FullAddress) *Error
|
||||
|
||||
// Shutdown closes the read and/or write end of the endpoint connection
|
||||
|
|
|
@ -422,6 +422,11 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
|
|||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
|
||||
if addr.Addr == "" {
|
||||
// AF_UNSPEC isn't supported.
|
||||
return tcpip.ErrAddressFamilyNotSupported
|
||||
}
|
||||
|
||||
nicid := addr.NIC
|
||||
localPort := uint16(0)
|
||||
switch e.state {
|
||||
|
|
|
@ -298,6 +298,11 @@ func (ep *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
|
|||
ep.mu.Lock()
|
||||
defer ep.mu.Unlock()
|
||||
|
||||
if addr.Addr == "" {
|
||||
// AF_UNSPEC isn't supported.
|
||||
return tcpip.ErrAddressFamilyNotSupported
|
||||
}
|
||||
|
||||
if ep.closed {
|
||||
return tcpip.ErrInvalidEndpointState
|
||||
}
|
||||
|
|
|
@ -1271,6 +1271,11 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocol
|
|||
|
||||
// Connect connects the endpoint to its peer.
|
||||
func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
|
||||
if addr.Addr == "" && addr.Port == 0 {
|
||||
// AF_UNSPEC isn't supported.
|
||||
return tcpip.ErrAddressFamilyNotSupported
|
||||
}
|
||||
|
||||
return e.connect(addr, true, true)
|
||||
}
|
||||
|
||||
|
|
|
@ -342,6 +342,7 @@ func loadError(s string) *tcpip.Error {
|
|||
tcpip.ErrNoBufferSpace,
|
||||
tcpip.ErrBroadcastDisabled,
|
||||
tcpip.ErrNotPermitted,
|
||||
tcpip.ErrAddressFamilyNotSupported,
|
||||
}
|
||||
|
||||
messageToError = make(map[string]*tcpip.Error)
|
||||
|
|
|
@ -698,8 +698,44 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (t
|
|||
return netProto, nil
|
||||
}
|
||||
|
||||
func (e *endpoint) disconnect() *tcpip.Error {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
|
||||
if e.state != stateConnected {
|
||||
return nil
|
||||
}
|
||||
id := stack.TransportEndpointID{}
|
||||
// Exclude ephemerally bound endpoints.
|
||||
if e.bindNICID != 0 || e.id.LocalAddress == "" {
|
||||
var err *tcpip.Error
|
||||
id = stack.TransportEndpointID{
|
||||
LocalPort: e.id.LocalPort,
|
||||
LocalAddress: e.id.LocalAddress,
|
||||
}
|
||||
id, err = e.registerWithStack(e.regNICID, e.effectiveNetProtos, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
e.state = stateBound
|
||||
} else {
|
||||
e.state = stateInitial
|
||||
}
|
||||
|
||||
e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e)
|
||||
e.id = id
|
||||
e.route.Release()
|
||||
e.route = stack.Route{}
|
||||
e.dstPort = 0
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Connect connects the endpoint to its peer. Specifying a NIC is optional.
|
||||
func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
|
||||
if addr.Addr == "" {
|
||||
return e.disconnect()
|
||||
}
|
||||
if addr.Port == 0 {
|
||||
// We don't support connecting to port zero.
|
||||
return tcpip.ErrInvalidEndpointState
|
||||
|
@ -734,12 +770,16 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
|
|||
defer r.Release()
|
||||
|
||||
id := stack.TransportEndpointID{
|
||||
LocalAddress: r.LocalAddress,
|
||||
LocalAddress: e.id.LocalAddress,
|
||||
LocalPort: localPort,
|
||||
RemotePort: addr.Port,
|
||||
RemoteAddress: r.RemoteAddress,
|
||||
}
|
||||
|
||||
if e.state == stateInitial {
|
||||
id.LocalAddress = r.LocalAddress
|
||||
}
|
||||
|
||||
// Even if we're connected, this endpoint can still be used to send
|
||||
// packets on a different network protocol, so we register both even if
|
||||
// v6only is set to false and this is an ipv6 endpoint.
|
||||
|
|
|
@ -92,8 +92,6 @@ func (e *endpoint) afterLoad() {
|
|||
if err != nil {
|
||||
panic(*err)
|
||||
}
|
||||
|
||||
e.id.LocalAddress = e.route.LocalAddress
|
||||
} else if len(e.id.LocalAddress) != 0 { // stateBound
|
||||
if e.stack.CheckLocalAddress(e.regNICID, netProto, e.id.LocalAddress) == 0 {
|
||||
panic(tcpip.ErrBadLocalAddress)
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <arpa/inet.h>
|
||||
#include <fcntl.h>
|
||||
#include <linux/errqueue.h>
|
||||
#include <netinet/in.h>
|
||||
|
@ -304,12 +305,50 @@ TEST_P(UdpSocketTest, ReceiveAfterConnect) {
|
|||
SyscallSucceedsWithValue(sizeof(buf)));
|
||||
|
||||
// Receive the data.
|
||||
char received[512];
|
||||
char received[sizeof(buf)];
|
||||
EXPECT_THAT(recv(s_, received, sizeof(received), 0),
|
||||
SyscallSucceedsWithValue(sizeof(received)));
|
||||
EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0);
|
||||
}
|
||||
|
||||
TEST_P(UdpSocketTest, ReceiveAfterDisconnect) {
|
||||
// Connect s_ to loopback:TestPort, and bind t_ to loopback:TestPort.
|
||||
ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds());
|
||||
ASSERT_THAT(bind(t_, addr_[0], addrlen_), SyscallSucceeds());
|
||||
ASSERT_THAT(connect(t_, addr_[1], addrlen_), SyscallSucceeds());
|
||||
|
||||
// Get the address s_ was bound to during connect.
|
||||
struct sockaddr_storage addr;
|
||||
socklen_t addrlen = sizeof(addr);
|
||||
EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen),
|
||||
SyscallSucceeds());
|
||||
EXPECT_EQ(addrlen, addrlen_);
|
||||
|
||||
for (int i = 0; i < 2; i++) {
|
||||
// Send from t_ to s_.
|
||||
char buf[512];
|
||||
RandomizeBuffer(buf, sizeof(buf));
|
||||
EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen),
|
||||
SyscallSucceeds());
|
||||
ASSERT_THAT(sendto(t_, buf, sizeof(buf), 0,
|
||||
reinterpret_cast<sockaddr*>(&addr), addrlen),
|
||||
SyscallSucceedsWithValue(sizeof(buf)));
|
||||
|
||||
// Receive the data.
|
||||
char received[sizeof(buf)];
|
||||
EXPECT_THAT(recv(s_, received, sizeof(received), 0),
|
||||
SyscallSucceedsWithValue(sizeof(received)));
|
||||
EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0);
|
||||
|
||||
// Disconnect s_.
|
||||
struct sockaddr addr = {};
|
||||
addr.sa_family = AF_UNSPEC;
|
||||
ASSERT_THAT(connect(s_, &addr, sizeof(addr.sa_family)), SyscallSucceeds());
|
||||
// Connect s_ loopback:TestPort.
|
||||
ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds());
|
||||
}
|
||||
}
|
||||
|
||||
TEST_P(UdpSocketTest, Connect) {
|
||||
ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds());
|
||||
|
||||
|
@ -335,6 +374,112 @@ TEST_P(UdpSocketTest, Connect) {
|
|||
EXPECT_EQ(memcmp(&peer, addr_[2], addrlen_), 0);
|
||||
}
|
||||
|
||||
TEST_P(UdpSocketTest, DisconnectAfterBind) {
|
||||
ASSERT_THAT(bind(s_, addr_[1], addrlen_), SyscallSucceeds());
|
||||
// Connect the socket.
|
||||
ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds());
|
||||
|
||||
struct sockaddr_storage addr = {};
|
||||
addr.ss_family = AF_UNSPEC;
|
||||
EXPECT_THAT(
|
||||
connect(s_, reinterpret_cast<sockaddr*>(&addr), sizeof(addr.ss_family)),
|
||||
SyscallSucceeds());
|
||||
|
||||
// Check that we're still bound.
|
||||
socklen_t addrlen = sizeof(addr);
|
||||
EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen),
|
||||
SyscallSucceeds());
|
||||
|
||||
EXPECT_EQ(addrlen, addrlen_);
|
||||
EXPECT_EQ(memcmp(&addr, addr_[1], addrlen_), 0);
|
||||
|
||||
addrlen = sizeof(addr);
|
||||
EXPECT_THAT(getpeername(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen),
|
||||
SyscallFailsWithErrno(ENOTCONN));
|
||||
}
|
||||
|
||||
TEST_P(UdpSocketTest, DisconnectAfterBindToAny) {
|
||||
struct sockaddr_storage baddr = {};
|
||||
socklen_t addrlen;
|
||||
auto port = *Port(reinterpret_cast<struct sockaddr_storage*>(addr_[1]));
|
||||
if (addr_[0]->sa_family == AF_INET) {
|
||||
auto addr_in = reinterpret_cast<struct sockaddr_in*>(&baddr);
|
||||
addr_in->sin_family = AF_INET;
|
||||
addr_in->sin_port = port;
|
||||
inet_pton(AF_INET, "0.0.0.0",
|
||||
reinterpret_cast<void*>(&addr_in->sin_addr.s_addr));
|
||||
} else {
|
||||
auto addr_in = reinterpret_cast<struct sockaddr_in6*>(&baddr);
|
||||
addr_in->sin6_family = AF_INET6;
|
||||
addr_in->sin6_port = port;
|
||||
inet_pton(AF_INET6,
|
||||
"::", reinterpret_cast<void*>(&addr_in->sin6_addr.s6_addr));
|
||||
addr_in->sin6_scope_id = 0;
|
||||
}
|
||||
ASSERT_THAT(bind(s_, reinterpret_cast<sockaddr*>(&baddr), addrlen_),
|
||||
SyscallSucceeds());
|
||||
// Connect the socket.
|
||||
ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds());
|
||||
|
||||
struct sockaddr_storage addr = {};
|
||||
addr.ss_family = AF_UNSPEC;
|
||||
EXPECT_THAT(
|
||||
connect(s_, reinterpret_cast<sockaddr*>(&addr), sizeof(addr.ss_family)),
|
||||
SyscallSucceeds());
|
||||
|
||||
// Check that we're still bound.
|
||||
addrlen = sizeof(addr);
|
||||
EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen),
|
||||
SyscallSucceeds());
|
||||
|
||||
EXPECT_EQ(addrlen, addrlen_);
|
||||
EXPECT_EQ(memcmp(&addr, &baddr, addrlen), 0);
|
||||
|
||||
addrlen = sizeof(addr);
|
||||
EXPECT_THAT(getpeername(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen),
|
||||
SyscallFailsWithErrno(ENOTCONN));
|
||||
}
|
||||
|
||||
TEST_P(UdpSocketTest, Disconnect) {
|
||||
for (int i = 0; i < 2; i++) {
|
||||
// Try to connect again.
|
||||
EXPECT_THAT(connect(s_, addr_[2], addrlen_), SyscallSucceeds());
|
||||
|
||||
// Check that we're connected to the right peer.
|
||||
struct sockaddr_storage peer;
|
||||
socklen_t peerlen = sizeof(peer);
|
||||
EXPECT_THAT(getpeername(s_, reinterpret_cast<sockaddr*>(&peer), &peerlen),
|
||||
SyscallSucceeds());
|
||||
EXPECT_EQ(peerlen, addrlen_);
|
||||
EXPECT_EQ(memcmp(&peer, addr_[2], addrlen_), 0);
|
||||
|
||||
// Try to disconnect.
|
||||
struct sockaddr_storage addr = {};
|
||||
addr.ss_family = AF_UNSPEC;
|
||||
EXPECT_THAT(
|
||||
connect(s_, reinterpret_cast<sockaddr*>(&addr), sizeof(addr.ss_family)),
|
||||
SyscallSucceeds());
|
||||
|
||||
peerlen = sizeof(peer);
|
||||
EXPECT_THAT(getpeername(s_, reinterpret_cast<sockaddr*>(&peer), &peerlen),
|
||||
SyscallFailsWithErrno(ENOTCONN));
|
||||
|
||||
// Check that we're still bound.
|
||||
socklen_t addrlen = sizeof(addr);
|
||||
EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen),
|
||||
SyscallSucceeds());
|
||||
EXPECT_EQ(addrlen, addrlen_);
|
||||
EXPECT_EQ(*Port(&addr), 0);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_P(UdpSocketTest, ConnectBadAddress) {
|
||||
struct sockaddr addr = {};
|
||||
addr.sa_family = addr_[0]->sa_family;
|
||||
ASSERT_THAT(connect(s_, &addr, sizeof(addr.sa_family)),
|
||||
SyscallFailsWithErrno(EINVAL));
|
||||
}
|
||||
|
||||
TEST_P(UdpSocketTest, SendToAddressOtherThanConnected) {
|
||||
ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds());
|
||||
|
||||
|
@ -397,7 +542,7 @@ TEST_P(UdpSocketTest, SendAndReceiveNotConnected) {
|
|||
SyscallSucceedsWithValue(sizeof(buf)));
|
||||
|
||||
// Receive the data.
|
||||
char received[512];
|
||||
char received[sizeof(buf)];
|
||||
EXPECT_THAT(recv(s_, received, sizeof(received), 0),
|
||||
SyscallSucceedsWithValue(sizeof(received)));
|
||||
EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0);
|
||||
|
@ -419,7 +564,7 @@ TEST_P(UdpSocketTest, SendAndReceiveConnected) {
|
|||
SyscallSucceedsWithValue(sizeof(buf)));
|
||||
|
||||
// Receive the data.
|
||||
char received[512];
|
||||
char received[sizeof(buf)];
|
||||
EXPECT_THAT(recv(s_, received, sizeof(received), 0),
|
||||
SyscallSucceedsWithValue(sizeof(received)));
|
||||
EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0);
|
||||
|
@ -462,7 +607,7 @@ TEST_P(UdpSocketTest, ReceiveBeforeConnect) {
|
|||
ASSERT_THAT(connect(s_, addr_[1], addrlen_), SyscallSucceeds());
|
||||
|
||||
// Receive the data. It works because it was sent before the connect.
|
||||
char received[512];
|
||||
char received[sizeof(buf)];
|
||||
EXPECT_THAT(recv(s_, received, sizeof(received), 0),
|
||||
SyscallSucceedsWithValue(sizeof(received)));
|
||||
EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0);
|
||||
|
@ -491,7 +636,7 @@ TEST_P(UdpSocketTest, ReceiveFrom) {
|
|||
SyscallSucceedsWithValue(sizeof(buf)));
|
||||
|
||||
// Receive the data and sender address.
|
||||
char received[512];
|
||||
char received[sizeof(buf)];
|
||||
struct sockaddr_storage addr;
|
||||
socklen_t addrlen = sizeof(addr);
|
||||
EXPECT_THAT(recvfrom(s_, received, sizeof(received), 0,
|
||||
|
|
Loading…
Reference in New Issue