Internal change.

PiperOrigin-RevId: 308940886
This commit is contained in:
gVisor bot 2020-04-28 18:49:19 -07:00
parent f93f2fda74
commit 24abccbc1c
8 changed files with 749 additions and 87 deletions

View File

@ -71,6 +71,7 @@ const (
// Values for ICMP code as defined in RFC 792.
const (
ICMPv4TTLExceeded = 0
ICMPv4PortUnreachable = 3
ICMPv4FragmentationNeeded = 4
)

View File

@ -60,6 +60,45 @@
return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "Unknown Sockaddr");
}
::grpc::Status proto_to_sockaddr(const posix_server::Sockaddr &sockaddr_proto,
sockaddr_storage *addr) {
switch (sockaddr_proto.sockaddr_case()) {
case posix_server::Sockaddr::SockaddrCase::kIn: {
auto proto_in = sockaddr_proto.in();
if (proto_in.addr().size() != 4) {
return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT,
"IPv4 address must be 4 bytes");
}
auto addr_in = reinterpret_cast<sockaddr_in *>(addr);
addr_in->sin_family = proto_in.family();
addr_in->sin_port = htons(proto_in.port());
proto_in.addr().copy(reinterpret_cast<char *>(&addr_in->sin_addr.s_addr),
4);
break;
}
case posix_server::Sockaddr::SockaddrCase::kIn6: {
auto proto_in6 = sockaddr_proto.in6();
if (proto_in6.addr().size() != 16) {
return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT,
"IPv6 address must be 16 bytes");
}
auto addr_in6 = reinterpret_cast<sockaddr_in6 *>(addr);
addr_in6->sin6_family = proto_in6.family();
addr_in6->sin6_port = htons(proto_in6.port());
addr_in6->sin6_flowinfo = htonl(proto_in6.flowinfo());
proto_in6.addr().copy(
reinterpret_cast<char *>(&addr_in6->sin6_addr.s6_addr), 16);
addr_in6->sin6_scope_id = htonl(proto_in6.scope_id());
break;
}
case posix_server::Sockaddr::SockaddrCase::SOCKADDR_NOT_SET:
default:
return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT,
"Unknown Sockaddr");
}
return ::grpc::Status::OK;
}
class PosixImpl final : public posix_server::Posix::Service {
::grpc::Status Accept(grpc_impl::ServerContext *context,
const ::posix_server::AcceptRequest *request,
@ -79,42 +118,13 @@ class PosixImpl final : public posix_server::Posix::Service {
return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT,
"Missing address");
}
sockaddr_storage addr;
switch (request->addr().sockaddr_case()) {
case posix_server::Sockaddr::SockaddrCase::kIn: {
auto request_in = request->addr().in();
if (request_in.addr().size() != 4) {
return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT,
"IPv4 address must be 4 bytes");
}
auto addr_in = reinterpret_cast<sockaddr_in *>(&addr);
addr_in->sin_family = request_in.family();
addr_in->sin_port = htons(request_in.port());
request_in.addr().copy(
reinterpret_cast<char *>(&addr_in->sin_addr.s_addr), 4);
break;
}
case posix_server::Sockaddr::SockaddrCase::kIn6: {
auto request_in6 = request->addr().in6();
if (request_in6.addr().size() != 16) {
return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT,
"IPv6 address must be 16 bytes");
}
auto addr_in6 = reinterpret_cast<sockaddr_in6 *>(&addr);
addr_in6->sin6_family = request_in6.family();
addr_in6->sin6_port = htons(request_in6.port());
addr_in6->sin6_flowinfo = htonl(request_in6.flowinfo());
request_in6.addr().copy(
reinterpret_cast<char *>(&addr_in6->sin6_addr.s6_addr), 16);
addr_in6->sin6_scope_id = htonl(request_in6.scope_id());
break;
}
case posix_server::Sockaddr::SockaddrCase::SOCKADDR_NOT_SET:
default:
return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT,
"Unknown Sockaddr");
sockaddr_storage addr;
auto err = proto_to_sockaddr(request->addr(), &addr);
if (!err.ok()) {
return err;
}
response->set_ret(bind(request->sockfd(),
reinterpret_cast<sockaddr *>(&addr), sizeof(addr)));
response->set_errno_(errno);
@ -129,6 +139,25 @@ class PosixImpl final : public posix_server::Posix::Service {
return ::grpc::Status::OK;
}
::grpc::Status Connect(grpc_impl::ServerContext *context,
const ::posix_server::ConnectRequest *request,
::posix_server::ConnectResponse *response) override {
if (!request->has_addr()) {
return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT,
"Missing address");
}
sockaddr_storage addr;
auto err = proto_to_sockaddr(request->addr(), &addr);
if (!err.ok()) {
return err;
}
response->set_ret(connect(
request->sockfd(), reinterpret_cast<sockaddr *>(&addr), sizeof(addr)));
response->set_errno_(errno);
return ::grpc::Status::OK;
}
::grpc::Status GetSockName(
grpc_impl::ServerContext *context,
const ::posix_server::GetSockNameRequest *request,
@ -141,6 +170,48 @@ class PosixImpl final : public posix_server::Posix::Service {
return sockaddr_to_proto(addr, addrlen, response->mutable_addr());
}
::grpc::Status GetSockOpt(
grpc_impl::ServerContext *context,
const ::posix_server::GetSockOptRequest *request,
::posix_server::GetSockOptResponse *response) override {
socklen_t optlen = request->optlen();
std::vector<char> buf(optlen);
response->set_ret(::getsockopt(request->sockfd(), request->level(),
request->optname(), buf.data(), &optlen));
response->set_errno_(errno);
if (optlen >= 0) {
response->set_optval(buf.data(), optlen);
}
return ::grpc::Status::OK;
}
::grpc::Status GetSockOptInt(
::grpc::ServerContext *context,
const ::posix_server::GetSockOptIntRequest *request,
::posix_server::GetSockOptIntResponse *response) override {
int opt = 0;
socklen_t optlen = sizeof(opt);
response->set_ret(::getsockopt(request->sockfd(), request->level(),
request->optname(), &opt, &optlen));
response->set_errno_(errno);
response->set_intval(opt);
return ::grpc::Status::OK;
}
::grpc::Status GetSockOptTimeval(
::grpc::ServerContext *context,
const ::posix_server::GetSockOptTimevalRequest *request,
::posix_server::GetSockOptTimevalResponse *response) override {
timeval tv;
socklen_t optlen = sizeof(tv);
response->set_ret(::getsockopt(request->sockfd(), request->level(),
request->optname(), &tv, &optlen));
response->set_errno_(errno);
response->mutable_timeval()->set_seconds(tv.tv_sec);
response->mutable_timeval()->set_microseconds(tv.tv_usec);
return ::grpc::Status::OK;
}
::grpc::Status Listen(grpc_impl::ServerContext *context,
const ::posix_server::ListenRequest *request,
::posix_server::ListenResponse *response) override {
@ -158,6 +229,26 @@ class PosixImpl final : public posix_server::Posix::Service {
return ::grpc::Status::OK;
}
::grpc::Status SendTo(::grpc::ServerContext *context,
const ::posix_server::SendToRequest *request,
::posix_server::SendToResponse *response) override {
if (!request->has_dest_addr()) {
return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT,
"Missing address");
}
sockaddr_storage addr;
auto err = proto_to_sockaddr(request->dest_addr(), &addr);
if (!err.ok()) {
return err;
}
response->set_ret(::sendto(
request->sockfd(), request->buf().data(), request->buf().size(),
request->flags(), reinterpret_cast<sockaddr *>(&addr), sizeof(addr)));
response->set_errno_(errno);
return ::grpc::Status::OK;
}
::grpc::Status SetSockOpt(
grpc_impl::ServerContext *context,
const ::posix_server::SetSockOptRequest *request,
@ -208,8 +299,10 @@ class PosixImpl final : public posix_server::Posix::Service {
std::vector<char> buf(request->len());
response->set_ret(
recv(request->sockfd(), buf.data(), buf.size(), request->flags()));
response->set_errno_(errno);
if (response->ret() >= 0) {
response->set_buf(buf.data(), response->ret());
}
response->set_errno_(errno);
return ::grpc::Status::OK;
}
};

View File

@ -73,6 +73,16 @@ message CloseResponse {
int32 errno_ = 2; // "errno" may fail to compile in c++.
}
message ConnectRequest {
int32 sockfd = 1;
Sockaddr addr = 2;
}
message ConnectResponse {
int32 ret = 1;
int32 errno_ = 2; // "errno" may fail to compile in c++.
}
message GetSockNameRequest {
int32 sockfd = 1;
}
@ -83,6 +93,43 @@ message GetSockNameResponse {
Sockaddr addr = 3;
}
message GetSockOptRequest {
int32 sockfd = 1;
int32 level = 2;
int32 optname = 3;
int32 optlen = 4;
}
message GetSockOptResponse {
int32 ret = 1;
int32 errno_ = 2; // "errno" may fail to compile in c++.
bytes optval = 3;
}
message GetSockOptIntRequest {
int32 sockfd = 1;
int32 level = 2;
int32 optname = 3;
}
message GetSockOptIntResponse {
int32 ret = 1;
int32 errno_ = 2; // "errno" may fail to compile in c++.
int32 intval = 3;
}
message GetSockOptTimevalRequest {
int32 sockfd = 1;
int32 level = 2;
int32 optname = 3;
}
message GetSockOptTimevalResponse {
int32 ret = 1;
int32 errno_ = 2; // "errno" may fail to compile in c++.
Timeval timeval = 3;
}
message ListenRequest {
int32 sockfd = 1;
int32 backlog = 2;
@ -104,6 +151,18 @@ message SendResponse {
int32 errno_ = 2;
}
message SendToRequest {
int32 sockfd = 1;
bytes buf = 2;
int32 flags = 3;
Sockaddr dest_addr = 4;
}
message SendToResponse {
int32 ret = 1;
int32 errno_ = 2; // "errno" may fail to compile in c++.
}
message SetSockOptRequest {
int32 sockfd = 1;
int32 level = 2;
@ -170,12 +229,26 @@ service Posix {
rpc Bind(BindRequest) returns (BindResponse);
// Call close() on the DUT.
rpc Close(CloseRequest) returns (CloseResponse);
// Call connect() on the DUT.
rpc Connect(ConnectRequest) returns (ConnectResponse);
// Call getsockname() on the DUT.
rpc GetSockName(GetSockNameRequest) returns (GetSockNameResponse);
// Call getsockopt() on the DUT. You should prefer one of the other
// GetSockOpt* functions with a more structured optval or else you may get the
// encoding wrong, such as making a bad assumption about the server's word
// sizes or endianness.
rpc GetSockOpt(GetSockOptRequest) returns (GetSockOptResponse);
// Call getsockopt() on the DUT with an int optval.
rpc GetSockOptInt(GetSockOptIntRequest) returns (GetSockOptIntResponse);
// Call getsockopt() on the DUT with a Timeval optval.
rpc GetSockOptTimeval(GetSockOptTimevalRequest)
returns (GetSockOptTimevalResponse);
// Call listen() on the DUT.
rpc Listen(ListenRequest) returns (ListenResponse);
// Call send() on the DUT.
rpc Send(SendRequest) returns (SendResponse);
// Call sendto() on the DUT.
rpc SendTo(SendToRequest) returns (SendToResponse);
// Call setsockopt() on the DUT. You should prefer one of the other
// SetSockOpt* functions with a more structured optval or else you may get the
// encoding wrong, such as making a bad assumption about the server's word

View File

@ -39,20 +39,28 @@ var remoteIPv6 = flag.String("remote_ipv6", "", "remote IPv6 address for test pa
var localMAC = flag.String("local_mac", "", "local mac address for test packets")
var remoteMAC = flag.String("remote_mac", "", "remote mac address for test packets")
// pickPort makes a new socket and returns the socket FD and port. The domain
// should be AF_INET or AF_INET6. The caller must close the FD when done with
func portFromSockaddr(sa unix.Sockaddr) (uint16, error) {
switch sa := sa.(type) {
case *unix.SockaddrInet4:
return uint16(sa.Port), nil
case *unix.SockaddrInet6:
return uint16(sa.Port), nil
}
return 0, fmt.Errorf("sockaddr type %T does not contain port", sa)
}
// pickPort makes a new socket and returns the socket FD and port. The domain should be AF_INET or AF_INET6. The caller must close the FD when done with
// the port if there is no error.
func pickPort(domain, typ int) (fd int, port uint16, err error) {
func pickPort(domain, typ int) (fd int, sa unix.Sockaddr, err error) {
fd, err = unix.Socket(domain, typ, 0)
if err != nil {
return -1, 0, err
return -1, nil, err
}
defer func() {
if err != nil {
err = multierr.Append(err, unix.Close(fd))
}
}()
var sa unix.Sockaddr
switch domain {
case unix.AF_INET:
var sa4 unix.SockaddrInet4
@ -63,31 +71,16 @@ func pickPort(domain, typ int) (fd int, port uint16, err error) {
copy(sa6.Addr[:], net.ParseIP(*localIPv6).To16())
sa = &sa6
default:
return -1, 0, fmt.Errorf("invalid domain %d, it should be one of unix.AF_INET or unix.AF_INET6", domain)
return -1, nil, fmt.Errorf("invalid domain %d, it should be one of unix.AF_INET or unix.AF_INET6", domain)
}
if err = unix.Bind(fd, sa); err != nil {
return -1, 0, err
return -1, nil, err
}
newSockAddr, err := unix.Getsockname(fd)
sa, err = unix.Getsockname(fd)
if err != nil {
return -1, 0, err
}
switch domain {
case unix.AF_INET:
newSockAddrInet4, ok := newSockAddr.(*unix.SockaddrInet4)
if !ok {
return -1, 0, fmt.Errorf("can't cast Getsockname result %T to SockaddrInet4", newSockAddr)
}
return fd, uint16(newSockAddrInet4.Port), nil
case unix.AF_INET6:
newSockAddrInet6, ok := newSockAddr.(*unix.SockaddrInet6)
if !ok {
return -1, 0, fmt.Errorf("can't cast Getsockname result %T to SockaddrInet6", newSockAddr)
}
return fd, uint16(newSockAddrInet6.Port), nil
default:
return -1, 0, fmt.Errorf("invalid domain %d, it should be one of unix.AF_INET or unix.AF_INET6", domain)
return -1, nil, err
}
return fd, sa, nil
}
// layerState stores the state of a layer of a connection.
@ -282,7 +275,11 @@ func SeqNumValue(v seqnum.Value) *seqnum.Value {
// newTCPState creates a new TCPState.
func newTCPState(domain int, out, in TCP) (*tcpState, error) {
portPickerFD, localPort, err := pickPort(domain, unix.SOCK_STREAM)
portPickerFD, localAddr, err := pickPort(domain, unix.SOCK_STREAM)
if err != nil {
return nil, err
}
localPort, err := portFromSockaddr(localAddr)
if err != nil {
return nil, err
}
@ -385,10 +382,14 @@ type udpState struct {
var _ layerState = (*udpState)(nil)
// newUDPState creates a new udpState.
func newUDPState(domain int, out, in UDP) (*udpState, error) {
portPickerFD, localPort, err := pickPort(domain, unix.SOCK_DGRAM)
func newUDPState(domain int, out, in UDP) (*udpState, unix.Sockaddr, error) {
portPickerFD, localAddr, err := pickPort(domain, unix.SOCK_DGRAM)
if err != nil {
return nil, err
return nil, nil, err
}
localPort, err := portFromSockaddr(localAddr)
if err != nil {
return nil, nil, err
}
s := udpState{
out: UDP{SrcPort: &localPort},
@ -396,12 +397,12 @@ func newUDPState(domain int, out, in UDP) (*udpState, error) {
portPickerFD: portPickerFD,
}
if err := s.out.merge(&out); err != nil {
return nil, err
return nil, nil, err
}
if err := s.in.merge(&in); err != nil {
return nil, err
return nil, nil, err
}
return &s, nil
return &s, localAddr, nil
}
func (s *udpState) outgoing() Layer {
@ -436,6 +437,7 @@ type Connection struct {
layerStates []layerState
injector Injector
sniffer Sniffer
localAddr unix.Sockaddr
t *testing.T
}
@ -499,7 +501,7 @@ func (conn *Connection) CreateFrame(layer Layer, additionalLayers ...Layer) Laye
func (conn *Connection) SendFrame(frame Layers) {
outBytes, err := frame.ToBytes()
if err != nil {
conn.t.Fatalf("can't build outgoing TCP packet: %s", err)
conn.t.Fatalf("can't build outgoing packet: %s", err)
}
conn.injector.Send(outBytes)
@ -545,8 +547,9 @@ func (e *layersError) Error() string {
return e.got.diff(e.want)
}
// Expect a frame with the final layerStates layer matching the provided Layer
// within the timeout specified. If it doesn't arrive in time, it returns nil.
// Expect expects a frame with the final layerStates layer matching the
// provided Layer within the timeout specified. If it doesn't arrive in time,
// an error is returned.
func (conn *Connection) Expect(layer Layer, timeout time.Duration) (Layer, error) {
// Make a frame that will ignore all but the final layer.
layers := make([]Layer, len(conn.layerStates))
@ -671,8 +674,8 @@ func (conn *TCPIPv4) Close() {
(*Connection)(conn).Close()
}
// Expect a frame with the TCP layer matching the provided TCP within the
// timeout specified. If it doesn't arrive in time, it returns nil.
// Expect expects a frame with the TCP layer matching the provided TCP within
// the timeout specified. If it doesn't arrive in time, an error is returned.
func (conn *TCPIPv4) Expect(tcp TCP, timeout time.Duration) (*TCP, error) {
layer, err := (*Connection)(conn).Expect(&tcp, timeout)
if layer == nil {
@ -756,7 +759,7 @@ func (conn *IPv6Conn) Close() {
}
// ExpectFrame expects a frame that matches the provided Layers within the
// timeout specified. If it doesn't arrive in time, it returns nil.
// timeout specified. If it doesn't arrive in time, an error is returned.
func (conn *IPv6Conn) ExpectFrame(frame Layers, timeout time.Duration) (Layers, error) {
return (*Connection)(conn).ExpectFrame(frame, timeout)
}
@ -780,7 +783,7 @@ func NewUDPIPv4(t *testing.T, outgoingUDP, incomingUDP UDP) UDPIPv4 {
if err != nil {
t.Fatalf("can't make ipv4State: %s", err)
}
tcpState, err := newUDPState(unix.AF_INET, outgoingUDP, incomingUDP)
udpState, localAddr, err := newUDPState(unix.AF_INET, outgoingUDP, incomingUDP)
if err != nil {
t.Fatalf("can't make udpState: %s", err)
}
@ -794,24 +797,61 @@ func NewUDPIPv4(t *testing.T, outgoingUDP, incomingUDP UDP) UDPIPv4 {
}
return UDPIPv4{
layerStates: []layerState{etherState, ipv4State, tcpState},
layerStates: []layerState{etherState, ipv4State, udpState},
injector: injector,
sniffer: sniffer,
localAddr: localAddr,
t: t,
}
}
// LocalAddr gets the local socket address of this connection.
func (conn *UDPIPv4) LocalAddr() unix.Sockaddr {
return conn.localAddr
}
// CreateFrame builds a frame for the connection with layer overriding defaults
// of the innermost layer and additionalLayers added after it.
func (conn *UDPIPv4) CreateFrame(layer Layer, additionalLayers ...Layer) Layers {
return (*Connection)(conn).CreateFrame(layer, additionalLayers...)
}
// Send a packet with reasonable defaults. Potentially override the UDP layer in
// the connection with the provided layer and add additionLayers.
func (conn *UDPIPv4) Send(udp UDP, additionalLayers ...Layer) {
(*Connection)(conn).Send(&udp, additionalLayers...)
}
// SendFrame sends a frame on the wire and updates the state of all layers.
func (conn *UDPIPv4) SendFrame(frame Layers) {
(*Connection)(conn).SendFrame(frame)
}
// SendIP sends a packet with additionalLayers following the IP layer in the
// connection.
func (conn *UDPIPv4) SendIP(additionalLayers ...Layer) {
var layersToSend Layers
for _, s := range conn.layerStates[:len(conn.layerStates)-1] {
layersToSend = append(layersToSend, s.outgoing())
}
layersToSend = append(layersToSend, additionalLayers...)
conn.SendFrame(layersToSend)
}
// Expect expects a frame with the UDP layer matching the provided UDP within
// the timeout specified. If it doesn't arrive in time, an error is returned.
func (conn *UDPIPv4) Expect(udp UDP, timeout time.Duration) (*UDP, error) {
layer, err := (*Connection)(conn).Expect(&udp, timeout)
if layer == nil {
return nil, err
}
gotUDP, ok := layer.(*UDP)
if !ok {
conn.t.Fatalf("expected %s to be UDP", layer)
}
return gotUDP, err
}
// Close frees associated resources held by the UDPIPv4 connection.
func (conn *UDPIPv4) Close() {
(*Connection)(conn).Close()

View File

@ -237,6 +237,33 @@ func (dut *DUT) CloseWithErrno(ctx context.Context, fd int32) (int32, error) {
return resp.GetRet(), syscall.Errno(resp.GetErrno_())
}
// Connect calls connect on the DUT and causes a fatal test failure if it
// doesn't succeed. If more control over the timeout or error handling is
// needed, use ConnectWithErrno.
func (dut *DUT) Connect(fd int32, sa unix.Sockaddr) {
dut.t.Helper()
ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
defer cancel()
ret, err := dut.ConnectWithErrno(ctx, fd, sa)
if ret != 0 {
dut.t.Fatalf("failed to connect socket: %s", err)
}
}
// ConnectWithErrno calls bind on the DUT.
func (dut *DUT) ConnectWithErrno(ctx context.Context, fd int32, sa unix.Sockaddr) (int32, error) {
dut.t.Helper()
req := pb.ConnectRequest{
Sockfd: fd,
Addr: dut.sockaddrToProto(sa),
}
resp, err := dut.posixServer.Connect(ctx, &req)
if err != nil {
dut.t.Fatalf("failed to call Connect: %s", err)
}
return resp.GetRet(), syscall.Errno(resp.GetErrno_())
}
// GetSockName calls getsockname on the DUT and causes a fatal test failure if
// it doesn't succeed. If more control over the timeout or error handling is
// needed, use GetSockNameWithErrno.
@ -264,6 +291,102 @@ func (dut *DUT) GetSockNameWithErrno(ctx context.Context, sockfd int32) (int32,
return resp.GetRet(), dut.protoToSockaddr(resp.GetAddr()), syscall.Errno(resp.GetErrno_())
}
// GetSockOpt calls getsockopt on the DUT and causes a fatal test failure if it
// doesn't succeed. If more control over the timeout or error handling is
// needed, use GetSockOptWithErrno. Because endianess and the width of values
// might differ between the testbench and DUT architectures, prefer to use a
// more specific GetSockOptXxx function.
func (dut *DUT) GetSockOpt(sockfd, level, optname, optlen int32) []byte {
dut.t.Helper()
ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
defer cancel()
ret, optval, err := dut.GetSockOptWithErrno(ctx, sockfd, level, optname, optlen)
if ret != 0 {
dut.t.Fatalf("failed to GetSockOpt: %s", err)
}
return optval
}
// GetSockOptWithErrno calls getsockopt on the DUT. Because endianess and the
// width of values might differ between the testbench and DUT architectures,
// prefer to use a more specific GetSockOptXxxWithErrno function.
func (dut *DUT) GetSockOptWithErrno(ctx context.Context, sockfd, level, optname, optlen int32) (int32, []byte, error) {
dut.t.Helper()
req := pb.GetSockOptRequest{
Sockfd: sockfd,
Level: level,
Optname: optname,
Optlen: optlen,
}
resp, err := dut.posixServer.GetSockOpt(ctx, &req)
if err != nil {
dut.t.Fatalf("failed to call GetSockOpt: %s", err)
}
return resp.GetRet(), resp.GetOptval(), syscall.Errno(resp.GetErrno_())
}
// GetSockOptInt calls getsockopt on the DUT and causes a fatal test failure
// if it doesn't succeed. If more control over the int optval or error handling
// is needed, use GetSockOptIntWithErrno.
func (dut *DUT) GetSockOptInt(sockfd, level, optname int32) int32 {
dut.t.Helper()
ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
defer cancel()
ret, intval, err := dut.GetSockOptIntWithErrno(ctx, sockfd, level, optname)
if ret != 0 {
dut.t.Fatalf("failed to GetSockOptInt: %s", err)
}
return intval
}
// GetSockOptIntWithErrno calls getsockopt with an integer optval.
func (dut *DUT) GetSockOptIntWithErrno(ctx context.Context, sockfd, level, optname int32) (int32, int32, error) {
dut.t.Helper()
req := pb.GetSockOptIntRequest{
Sockfd: sockfd,
Level: level,
Optname: optname,
}
resp, err := dut.posixServer.GetSockOptInt(ctx, &req)
if err != nil {
dut.t.Fatalf("failed to call GetSockOptInt: %s", err)
}
return resp.GetRet(), resp.GetIntval(), syscall.Errno(resp.GetErrno_())
}
// GetSockOptTimeval calls getsockopt on the DUT and causes a fatal test failure
// if it doesn't succeed. If more control over the timeout or error handling is
// needed, use GetSockOptTimevalWithErrno.
func (dut *DUT) GetSockOptTimeval(sockfd, level, optname int32) unix.Timeval {
dut.t.Helper()
ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
defer cancel()
ret, timeval, err := dut.GetSockOptTimevalWithErrno(ctx, sockfd, level, optname)
if ret != 0 {
dut.t.Fatalf("failed to GetSockOptTimeval: %s", err)
}
return timeval
}
// GetSockOptTimevalWithErrno calls getsockopt and returns a timeval.
func (dut *DUT) GetSockOptTimevalWithErrno(ctx context.Context, sockfd, level, optname int32) (int32, unix.Timeval, error) {
dut.t.Helper()
req := pb.GetSockOptTimevalRequest{
Sockfd: sockfd,
Level: level,
Optname: optname,
}
resp, err := dut.posixServer.GetSockOptTimeval(ctx, &req)
if err != nil {
dut.t.Fatalf("failed to call GetSockOptTimeval: %s", err)
}
timeval := unix.Timeval{
Sec: resp.GetTimeval().Seconds,
Usec: resp.GetTimeval().Microseconds,
}
return resp.GetRet(), timeval, syscall.Errno(resp.GetErrno_())
}
// Listen calls listen on the DUT and causes a fatal test failure if it doesn't
// succeed. If more control over the timeout or error handling is needed, use
// ListenWithErrno.
@ -320,6 +443,36 @@ func (dut *DUT) SendWithErrno(ctx context.Context, sockfd int32, buf []byte, fla
return resp.GetRet(), syscall.Errno(resp.GetErrno_())
}
// SendTo calls sendto on the DUT and causes a fatal test failure if it doesn't
// succeed. If more control over the timeout or error handling is needed, use
// SendToWithErrno.
func (dut *DUT) SendTo(sockfd int32, buf []byte, flags int32, destAddr unix.Sockaddr) int32 {
dut.t.Helper()
ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
defer cancel()
ret, err := dut.SendToWithErrno(ctx, sockfd, buf, flags, destAddr)
if ret == -1 {
dut.t.Fatalf("failed to sendto: %s", err)
}
return ret
}
// SendToWithErrno calls sendto on the DUT.
func (dut *DUT) SendToWithErrno(ctx context.Context, sockfd int32, buf []byte, flags int32, destAddr unix.Sockaddr) (int32, error) {
dut.t.Helper()
req := pb.SendToRequest{
Sockfd: sockfd,
Buf: buf,
Flags: flags,
DestAddr: dut.sockaddrToProto(destAddr),
}
resp, err := dut.posixServer.SendTo(ctx, &req)
if err != nil {
dut.t.Fatalf("faled to call SendTo: %s", err)
}
return resp.GetRet(), syscall.Errno(resp.GetErrno_())
}
// SetSockOpt calls setsockopt on the DUT and causes a fatal test failure if it
// doesn't succeed. If more control over the timeout or error handling is
// needed, use SetSockOptWithErrno. Because endianess and the width of values

View File

@ -58,7 +58,7 @@ type Layer interface {
next() Layer
// prev gets a pointer to the Layer encapsulating this one.
prev() Layer
Prev() Layer
// setNext sets the pointer to the encapsulated Layer.
setNext(Layer)
@ -80,7 +80,8 @@ func (lb *LayerBase) next() Layer {
return lb.nextLayer
}
func (lb *LayerBase) prev() Layer {
// Prev returns the previous layer.
func (lb *LayerBase) Prev() Layer {
return lb.prevLayer
}
@ -340,6 +341,8 @@ func (l *IPv4) ToBytes() ([]byte, error) {
fields.Protocol = uint8(header.TCPProtocolNumber)
case *UDP:
fields.Protocol = uint8(header.UDPProtocolNumber)
case *ICMPv4:
fields.Protocol = uint8(header.ICMPv4ProtocolNumber)
default:
// TODO(b/150301488): Support more protocols as needed.
return nil, fmt.Errorf("ipv4 header's next layer is unrecognized: %#v", n)
@ -403,6 +406,8 @@ func parseIPv4(b []byte) (Layer, layerParser) {
nextParser = parseTCP
case header.UDPProtocolNumber:
nextParser = parseUDP
case header.ICMPv4ProtocolNumber:
nextParser = parseICMPv4
default:
// Assume that the rest is a payload.
nextParser = parsePayload
@ -562,7 +567,7 @@ func (l *ICMPv6) ToBytes() ([]byte, error) {
if l.Checksum != nil {
h.SetChecksum(*l.Checksum)
} else {
ipv6 := l.prev().(*IPv6)
ipv6 := l.Prev().(*IPv6)
h.SetChecksum(header.ICMPv6Checksum(h, *ipv6.SrcAddr, *ipv6.DstAddr, buffer.VectorisedView{}))
}
return h, nil
@ -606,6 +611,72 @@ func (l *ICMPv6) merge(other Layer) error {
return mergeLayer(l, other)
}
// ICMPv4Type is a helper routine that allocates a new header.ICMPv4Type value
// to store t and returns a pointer to it.
func ICMPv4Type(t header.ICMPv4Type) *header.ICMPv4Type {
return &t
}
// ICMPv4 can construct and match an ICMPv4 encapsulation.
type ICMPv4 struct {
LayerBase
Type *header.ICMPv4Type
Code *uint8
Checksum *uint16
}
func (l *ICMPv4) String() string {
return stringLayer(l)
}
// ToBytes implements Layer.ToBytes.
func (l *ICMPv4) ToBytes() ([]byte, error) {
b := make([]byte, header.ICMPv4MinimumSize)
h := header.ICMPv4(b)
if l.Type != nil {
h.SetType(*l.Type)
}
if l.Code != nil {
h.SetCode(byte(*l.Code))
}
if l.Checksum != nil {
h.SetChecksum(*l.Checksum)
return h, nil
}
payload, err := payload(l)
if err != nil {
return nil, err
}
h.SetChecksum(header.ICMPv4Checksum(h, payload))
return h, nil
}
// parseICMPv4 parses the bytes as an ICMPv4 header, returning a Layer and a
// parser for the encapsulated payload.
func parseICMPv4(b []byte) (Layer, layerParser) {
h := header.ICMPv4(b)
icmpv4 := ICMPv4{
Type: ICMPv4Type(h.Type()),
Code: Uint8(h.Code()),
Checksum: Uint16(h.Checksum()),
}
return &icmpv4, parsePayload
}
func (l *ICMPv4) match(other Layer) bool {
return equalLayer(l, other)
}
func (l *ICMPv4) length() int {
return header.ICMPv4MinimumSize
}
// merge overrides the values in l with the values from other but only in fields
// where the value is not nil.
func (l *ICMPv4) merge(other Layer) error {
return mergeLayer(l, other)
}
// TCP can construct and match a TCP encapsulation.
type TCP struct {
LayerBase
@ -676,25 +747,34 @@ func totalLength(l Layer) int {
return totalLength
}
// payload returns a buffer.VectorisedView of l's payload.
func payload(l Layer) (buffer.VectorisedView, error) {
var payloadBytes buffer.VectorisedView
for current := l.next(); current != nil; current = current.next() {
payload, err := current.ToBytes()
if err != nil {
return buffer.VectorisedView{}, fmt.Errorf("can't get bytes for next header: %s", payload)
}
payloadBytes.AppendView(payload)
}
return payloadBytes, nil
}
// layerChecksum calculates the checksum of the Layer header, including the
// peusdeochecksum of the layer before it and all the bytes after it..
// peusdeochecksum of the layer before it and all the bytes after it.
func layerChecksum(l Layer, protoNumber tcpip.TransportProtocolNumber) (uint16, error) {
totalLength := uint16(totalLength(l))
var xsum uint16
switch s := l.prev().(type) {
switch s := l.Prev().(type) {
case *IPv4:
xsum = header.PseudoHeaderChecksum(protoNumber, *s.SrcAddr, *s.DstAddr, totalLength)
default:
// TODO(b/150301488): Support more protocols, like IPv6.
return 0, fmt.Errorf("can't get src and dst addr from previous layer: %#v", s)
}
var payloadBytes buffer.VectorisedView
for current := l.next(); current != nil; current = current.next() {
payload, err := current.ToBytes()
payloadBytes, err := payload(l)
if err != nil {
return 0, fmt.Errorf("can't get bytes for next header: %s", payload)
}
payloadBytes.AppendView(payload)
return 0, err
}
xsum = header.ChecksumVV(payloadBytes, xsum)
return xsum, nil

View File

@ -28,6 +28,19 @@ packetimpact_go_test(
],
)
packetimpact_go_test(
name = "udp_icmp_error_propagation",
srcs = ["udp_icmp_error_propagation_test.go"],
# TODO(b/153926291): Fix netstack then remove the line below.
netstack = False,
deps = [
"//pkg/tcpip",
"//pkg/tcpip/header",
"//test/packetimpact/testbench",
"@org_golang_x_sys//unix:go_default_library",
],
)
packetimpact_go_test(
name = "tcp_window_shrink",
srcs = ["tcp_window_shrink_test.go"],

View File

@ -0,0 +1,209 @@
// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package udp_icmp_error_propagation_test
import (
"context"
"fmt"
"net"
"syscall"
"testing"
"time"
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/tcpip/header"
tb "gvisor.dev/gvisor/test/packetimpact/testbench"
)
type connected bool
func (c connected) String() string {
if c {
return "Connected"
}
return "Connectionless"
}
type icmpError int
const (
portUnreachable icmpError = iota
timeToLiveExceeded
)
func (e icmpError) String() string {
switch e {
case portUnreachable:
return "PortUnreachable"
case timeToLiveExceeded:
return "TimeToLiveExpired"
}
return "Unknown ICMP error"
}
func (e icmpError) ToICMPv4() *tb.ICMPv4 {
switch e {
case portUnreachable:
return &tb.ICMPv4{Type: tb.ICMPv4Type(header.ICMPv4DstUnreachable), Code: tb.Uint8(header.ICMPv4PortUnreachable)}
case timeToLiveExceeded:
return &tb.ICMPv4{Type: tb.ICMPv4Type(header.ICMPv4TimeExceeded), Code: tb.Uint8(header.ICMPv4TTLExceeded)}
}
return nil
}
type errorDetectionFunc func(context.Context, *tb.DUT, *tb.UDPIPv4, int32, syscall.Errno) error
// testRecv tests observing the ICMP error through the recv syscall.
// A packet is sent to the DUT, and if wantErrno is non-zero, then the first
// recv should fail and the second should succeed. Otherwise if wantErrno is
// zero then the first recv should succeed immediately.
func testRecv(ctx context.Context, dut *tb.DUT, conn *tb.UDPIPv4, remoteFD int32, wantErrno syscall.Errno) error {
conn.Send(tb.UDP{})
if wantErrno != syscall.Errno(0) {
ctx, cancel := context.WithTimeout(ctx, time.Second)
defer cancel()
ret, _, err := dut.RecvWithErrno(ctx, remoteFD, 100, 0)
if ret != -1 {
return fmt.Errorf("recv after ICMP error succeeded unexpectedly, expected (%[1]d) %[1]v", wantErrno)
}
if err != wantErrno {
return fmt.Errorf("recv after ICMP error resulted in error (%[1]d) %[1]v, expected (%[2]d) %[2]v", err, wantErrno)
}
}
dut.Recv(remoteFD, 100, 0)
return nil
}
// testSendTo tests observing the ICMP error through the send syscall.
// If wantErrno is non-zero, the first send should fail and a subsequent send
// should suceed; while if wantErrno is zero then the first send should just
// succeed.
func testSendTo(ctx context.Context, dut *tb.DUT, conn *tb.UDPIPv4, remoteFD int32, wantErrno syscall.Errno) error {
if wantErrno != syscall.Errno(0) {
ctx, cancel := context.WithTimeout(ctx, time.Second)
defer cancel()
ret, err := dut.SendToWithErrno(ctx, remoteFD, nil, 0, conn.LocalAddr())
if ret != -1 {
return fmt.Errorf("sendto after ICMP error succeeded unexpectedly, expected (%[1]d) %[1]v", wantErrno)
}
if err != wantErrno {
return fmt.Errorf("sendto after ICMP error resulted in error (%[1]d) %[1]v, expected (%[2]d) %[2]v", err, wantErrno)
}
}
dut.SendTo(remoteFD, nil, 0, conn.LocalAddr())
if _, err := conn.Expect(tb.UDP{}, time.Second); err != nil {
return fmt.Errorf("did not receive UDP packet as expected: %s", err)
}
return nil
}
func testSockOpt(_ context.Context, dut *tb.DUT, conn *tb.UDPIPv4, remoteFD int32, wantErrno syscall.Errno) error {
errno := syscall.Errno(dut.GetSockOptInt(remoteFD, unix.SOL_SOCKET, unix.SO_ERROR))
if errno != wantErrno {
return fmt.Errorf("SO_ERROR sockopt after ICMP error is (%[1]d) %[1]v, expected (%[2]d) %[2]v", errno, wantErrno)
}
// Check that after clearing socket error, sending doesn't fail.
dut.SendTo(remoteFD, nil, 0, conn.LocalAddr())
if _, err := conn.Expect(tb.UDP{}, time.Second); err != nil {
return fmt.Errorf("did not receive UDP packet as expected: %s", err)
}
return nil
}
type testParameters struct {
connected connected
icmpErr icmpError
wantErrno syscall.Errno
f errorDetectionFunc
fName string
}
// TestUDPICMPErrorPropagation tests that ICMP PortUnreachable error messages
// destined for a "connected" UDP socket are observable on said socket by:
// 1. causing the next send to fail with ECONNREFUSED,
// 2. causing the next recv to fail with ECONNREFUSED, or
// 3. returning ECONNREFUSED through the SO_ERROR socket option.
func TestUDPICMPErrorPropagation(t *testing.T) {
var testCases []testParameters
for _, c := range []connected{true, false} {
for _, i := range []icmpError{portUnreachable, timeToLiveExceeded} {
e := syscall.Errno(0)
if c && i == portUnreachable {
e = unix.ECONNREFUSED
}
for _, f := range []struct {
name string
f errorDetectionFunc
}{
{"SendTo", testSendTo},
{"Recv", testRecv},
{"SockOpt", testSockOpt},
} {
testCases = append(testCases, testParameters{c, i, e, f.f, f.name})
}
}
}
for _, tt := range testCases {
t.Run(fmt.Sprintf("%s/%s/%s", tt.connected, tt.icmpErr, tt.fName), func(t *testing.T) {
dut := tb.NewDUT(t)
defer dut.TearDown()
remoteFD, remotePort := dut.CreateBoundSocket(unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP("0.0.0.0"))
defer dut.Close(remoteFD)
conn := tb.NewUDPIPv4(t, tb.UDP{DstPort: &remotePort}, tb.UDP{SrcPort: &remotePort})
defer conn.Close()
if tt.connected {
dut.Connect(remoteFD, conn.LocalAddr())
}
dut.SendTo(remoteFD, nil, 0, conn.LocalAddr())
udp, err := conn.Expect(tb.UDP{}, time.Second)
if err != nil {
t.Fatalf("did not receive message from DUT: %s", err)
}
if tt.icmpErr == timeToLiveExceeded {
ip, ok := udp.Prev().(*tb.IPv4)
if !ok {
t.Fatalf("expected %s to be IPv4", udp.Prev())
}
*ip.TTL = 1
// Let serialization recalculate the checksum since we set the
// TTL to 1.
ip.Checksum = nil
// Note that the ICMP payload is valid in this case because the UDP
// payload is empty. If the UDP payload were not empty, the packet
// length during serialization may not be calculated correctly,
// resulting in a mal-formed packet.
conn.SendIP(tt.icmpErr.ToICMPv4(), ip, udp)
} else {
conn.SendIP(tt.icmpErr.ToICMPv4(), udp.Prev(), udp)
}
if err := tt.f(context.Background(), &dut, &conn, remoteFD, tt.wantErrno); err != nil {
t.Fatal(err)
}
})
}
}