diff --git a/pkg/tcpip/header/udp.go b/pkg/tcpip/header/udp.go index 74412c894..9339d637f 100644 --- a/pkg/tcpip/header/udp.go +++ b/pkg/tcpip/header/udp.go @@ -99,6 +99,11 @@ func (b UDP) SetChecksum(checksum uint16) { binary.BigEndian.PutUint16(b[udpChecksum:], checksum) } +// SetLength sets the "length" field of the udp header. +func (b UDP) SetLength(length uint16) { + binary.BigEndian.PutUint16(b[udpLength:], length) +} + // CalculateChecksum calculates the checksum of the udp packet, given the // checksum of the network-layer pseudo-header and the checksum of the payload. func (b UDP) CalculateChecksum(partialChecksum uint16) uint16 { diff --git a/test/packetimpact/dut/posix_server.cc b/test/packetimpact/dut/posix_server.cc index 2f10dda40..4a71c54c6 100644 --- a/test/packetimpact/dut/posix_server.cc +++ b/test/packetimpact/dut/posix_server.cc @@ -181,6 +181,17 @@ class PosixImpl final : public posix_server::Posix::Service { response->set_errno_(errno); return ::grpc::Status::OK; } + + ::grpc::Status Recv(::grpc::ServerContext *context, + const ::posix_server::RecvRequest *request, + ::posix_server::RecvResponse *response) override { + std::vector buf(request->len()); + response->set_ret( + recv(request->sockfd(), buf.data(), buf.size(), request->flags())); + response->set_errno_(errno); + response->set_buf(buf.data(), response->ret()); + return ::grpc::Status::OK; + } }; // Parse command line options. Returns a pointer to the first argument beyond diff --git a/test/packetimpact/proto/posix_server.proto b/test/packetimpact/proto/posix_server.proto index 026876fc2..53ec49410 100644 --- a/test/packetimpact/proto/posix_server.proto +++ b/test/packetimpact/proto/posix_server.proto @@ -24,7 +24,7 @@ message SocketRequest { message SocketResponse { int32 fd = 1; - int32 errno_ = 2; + int32 errno_ = 2; // "errno" may fail to compile in c++. } message SockaddrIn { @@ -55,7 +55,7 @@ message BindRequest { message BindResponse { int32 ret = 1; - int32 errno_ = 2; + int32 errno_ = 2; // "errno" may fail to compile in c++. } message GetSockNameRequest { @@ -64,7 +64,7 @@ message GetSockNameRequest { message GetSockNameResponse { int32 ret = 1; - int32 errno_ = 2; + int32 errno_ = 2; // "errno" may fail to compile in c++. Sockaddr addr = 3; } @@ -75,7 +75,7 @@ message ListenRequest { message ListenResponse { int32 ret = 1; - int32 errno_ = 2; + int32 errno_ = 2; // "errno" may fail to compile in c++. } message AcceptRequest { @@ -84,7 +84,7 @@ message AcceptRequest { message AcceptResponse { int32 fd = 1; - int32 errno_ = 2; + int32 errno_ = 2; // "errno" may fail to compile in c++. Sockaddr addr = 3; } @@ -97,7 +97,7 @@ message SetSockOptRequest { message SetSockOptResponse { int32 ret = 1; - int32 errno_ = 2; + int32 errno_ = 2; // "errno" may fail to compile in c++. } message Timeval { @@ -114,7 +114,7 @@ message SetSockOptTimevalRequest { message SetSockOptTimevalResponse { int32 ret = 1; - int32 errno_ = 2; + int32 errno_ = 2; // "errno" may fail to compile in c++. } message CloseRequest { @@ -123,7 +123,19 @@ message CloseRequest { message CloseResponse { int32 ret = 1; - int32 errno_ = 2; + int32 errno_ = 2; // "errno" may fail to compile in c++. +} + +message RecvRequest { + int32 sockfd = 1; + int32 len = 2; + int32 flags = 3; +} + +message RecvResponse { + int32 ret = 1; + int32 errno_ = 2; // "errno" may fail to compile in c++. + bytes buf = 3; } service Posix { @@ -147,4 +159,6 @@ service Posix { returns (SetSockOptTimevalResponse); // Call close() on the DUT. rpc Close(CloseRequest) returns (CloseResponse); + // Call recv() on the DUT. + rpc Recv(RecvRequest) returns (RecvResponse); } diff --git a/test/packetimpact/testbench/BUILD b/test/packetimpact/testbench/BUILD index a34c81fcc..4a9d8efa6 100644 --- a/test/packetimpact/testbench/BUILD +++ b/test/packetimpact/testbench/BUILD @@ -16,6 +16,7 @@ go_library( ], deps = [ "//pkg/tcpip", + "//pkg/tcpip/buffer", "//pkg/tcpip/header", "//pkg/tcpip/seqnum", "//pkg/usermem", diff --git a/test/packetimpact/testbench/connections.go b/test/packetimpact/testbench/connections.go index b7aa63934..8d1f562ee 100644 --- a/test/packetimpact/testbench/connections.go +++ b/test/packetimpact/testbench/connections.go @@ -36,19 +36,6 @@ var remoteIPv4 = flag.String("remote_ipv4", "", "remote IPv4 address for test pa var localMAC = flag.String("local_mac", "", "local mac address for test packets") var remoteMAC = flag.String("remote_mac", "", "remote mac address for test packets") -// TCPIPv4 maintains state about a TCP/IPv4 connection. -type TCPIPv4 struct { - outgoing Layers - incoming Layers - LocalSeqNum seqnum.Value - RemoteSeqNum seqnum.Value - SynAck *TCP - sniffer Sniffer - injector Injector - portPickerFD int - t *testing.T -} - // pickPort makes a new socket and returns the socket FD and port. The caller // must close the FD when done with the port if there is no error. func pickPort() (int, uint16, error) { @@ -75,12 +62,25 @@ func pickPort() (int, uint16, error) { return fd, uint16(newSockAddrInet4.Port), nil } +// TCPIPv4 maintains state about a TCP/IPv4 connection. +type TCPIPv4 struct { + outgoing Layers + incoming Layers + LocalSeqNum seqnum.Value + RemoteSeqNum seqnum.Value + SynAck *TCP + sniffer Sniffer + injector Injector + portPickerFD int + t *testing.T +} + // tcpLayerIndex is the position of the TCP layer in the TCPIPv4 connection. It // is the third, after Ethernet and IPv4. const tcpLayerIndex int = 2 // NewTCPIPv4 creates a new TCPIPv4 connection with reasonable defaults. -func NewTCPIPv4(t *testing.T, dut DUT, outgoingTCP, incomingTCP TCP) TCPIPv4 { +func NewTCPIPv4(t *testing.T, outgoingTCP, incomingTCP TCP) TCPIPv4 { lMAC, err := tcpip.ParseMACAddress(*localMAC) if err != nil { t.Fatalf("can't parse localMAC %q: %s", *localMAC, err) @@ -109,18 +109,16 @@ func NewTCPIPv4(t *testing.T, dut DUT, outgoingTCP, incomingTCP TCP) TCPIPv4 { } newOutgoingTCP := &TCP{ - DataOffset: Uint8(header.TCPMinimumSize), - WindowSize: Uint16(32768), - SrcPort: &localPort, + SrcPort: &localPort, } if err := newOutgoingTCP.merge(outgoingTCP); err != nil { - t.Fatalf("can't merge %v into %v: %s", outgoingTCP, newOutgoingTCP, err) + t.Fatalf("can't merge %+v into %+v: %s", outgoingTCP, newOutgoingTCP, err) } newIncomingTCP := &TCP{ DstPort: &localPort, } if err := newIncomingTCP.merge(incomingTCP); err != nil { - t.Fatalf("can't merge %v into %v: %s", incomingTCP, newIncomingTCP, err) + t.Fatalf("can't merge %+v into %+v: %s", incomingTCP, newIncomingTCP, err) } return TCPIPv4{ outgoing: Layers{ @@ -149,8 +147,9 @@ func (conn *TCPIPv4) Close() { conn.portPickerFD = -1 } -// Send a packet with reasonable defaults and override some fields by tcp. -func (conn *TCPIPv4) Send(tcp TCP, additionalLayers ...Layer) { +// CreateFrame builds a frame for the connection with tcp overriding defaults +// and additionalLayers added after the TCP header. +func (conn *TCPIPv4) CreateFrame(tcp TCP, additionalLayers ...Layer) Layers { if tcp.SeqNum == nil { tcp.SeqNum = Uint32(uint32(conn.LocalSeqNum)) } @@ -159,30 +158,41 @@ func (conn *TCPIPv4) Send(tcp TCP, additionalLayers ...Layer) { } layersToSend := deepcopy.Copy(conn.outgoing).(Layers) if err := layersToSend[tcpLayerIndex].(*TCP).merge(tcp); err != nil { - conn.t.Fatalf("can't merge %v into %v: %s", tcp, layersToSend[tcpLayerIndex], err) + conn.t.Fatalf("can't merge %+v into %+v: %s", tcp, layersToSend[tcpLayerIndex], err) } layersToSend = append(layersToSend, additionalLayers...) - outBytes, err := layersToSend.toBytes() + return layersToSend +} + +// SendFrame sends a frame with reasonable defaults. +func (conn *TCPIPv4) SendFrame(frame Layers) { + outBytes, err := frame.toBytes() if err != nil { conn.t.Fatalf("can't build outgoing TCP packet: %s", err) } conn.injector.Send(outBytes) // Compute the next TCP sequence number. - for i := tcpLayerIndex + 1; i < len(layersToSend); i++ { - conn.LocalSeqNum.UpdateForward(seqnum.Size(layersToSend[i].length())) + for i := tcpLayerIndex + 1; i < len(frame); i++ { + conn.LocalSeqNum.UpdateForward(seqnum.Size(frame[i].length())) } + tcp := frame[tcpLayerIndex].(*TCP) if tcp.Flags != nil && *tcp.Flags&(header.TCPFlagSyn|header.TCPFlagFin) != 0 { conn.LocalSeqNum.UpdateForward(1) } } +// Send a packet with reasonable defaults and override some fields by tcp. +func (conn *TCPIPv4) Send(tcp TCP, additionalLayers ...Layer) { + conn.SendFrame(conn.CreateFrame(tcp, additionalLayers...)) +} + // Recv gets a packet from the sniffer within the timeout provided. If no packet // arrives before the timeout, it returns nil. func (conn *TCPIPv4) Recv(timeout time.Duration) *TCP { deadline := time.Now().Add(timeout) for { - timeout = deadline.Sub(time.Now()) + timeout = time.Until(deadline) if timeout <= 0 { break } @@ -192,6 +202,7 @@ func (conn *TCPIPv4) Recv(timeout time.Duration) *TCP { } layers, err := ParseEther(b) if err != nil { + conn.t.Logf("can't parse frame: %s", err) continue // Ignore packets that can't be parsed. } if !conn.incoming.match(layers) { @@ -215,7 +226,7 @@ func (conn *TCPIPv4) Recv(timeout time.Duration) *TCP { func (conn *TCPIPv4) Expect(tcp TCP, timeout time.Duration) *TCP { deadline := time.Now().Add(timeout) for { - timeout = deadline.Sub(time.Now()) + timeout = time.Until(deadline) if timeout <= 0 { return nil } @@ -243,3 +254,154 @@ func (conn *TCPIPv4) Handshake() { // Send an ACK. conn.Send(TCP{Flags: Uint8(header.TCPFlagAck)}) } + +// UDPIPv4 maintains state about a UDP/IPv4 connection. +type UDPIPv4 struct { + outgoing Layers + incoming Layers + sniffer Sniffer + injector Injector + portPickerFD int + t *testing.T +} + +// udpLayerIndex is the position of the UDP layer in the UDPIPv4 connection. It +// is the third, after Ethernet and IPv4. +const udpLayerIndex int = 2 + +// NewUDPIPv4 creates a new UDPIPv4 connection with reasonable defaults. +func NewUDPIPv4(t *testing.T, outgoingUDP, incomingUDP UDP) UDPIPv4 { + lMAC, err := tcpip.ParseMACAddress(*localMAC) + if err != nil { + t.Fatalf("can't parse localMAC %q: %s", *localMAC, err) + } + + rMAC, err := tcpip.ParseMACAddress(*remoteMAC) + if err != nil { + t.Fatalf("can't parse remoteMAC %q: %s", *remoteMAC, err) + } + + portPickerFD, localPort, err := pickPort() + if err != nil { + t.Fatalf("can't pick a port: %s", err) + } + lIP := tcpip.Address(net.ParseIP(*localIPv4).To4()) + rIP := tcpip.Address(net.ParseIP(*remoteIPv4).To4()) + + sniffer, err := NewSniffer(t) + if err != nil { + t.Fatalf("can't make new sniffer: %s", err) + } + + injector, err := NewInjector(t) + if err != nil { + t.Fatalf("can't make new injector: %s", err) + } + + newOutgoingUDP := &UDP{ + SrcPort: &localPort, + } + if err := newOutgoingUDP.merge(outgoingUDP); err != nil { + t.Fatalf("can't merge %+v into %+v: %s", outgoingUDP, newOutgoingUDP, err) + } + newIncomingUDP := &UDP{ + DstPort: &localPort, + } + if err := newIncomingUDP.merge(incomingUDP); err != nil { + t.Fatalf("can't merge %+v into %+v: %s", incomingUDP, newIncomingUDP, err) + } + return UDPIPv4{ + outgoing: Layers{ + &Ether{SrcAddr: &lMAC, DstAddr: &rMAC}, + &IPv4{SrcAddr: &lIP, DstAddr: &rIP}, + newOutgoingUDP}, + incoming: Layers{ + &Ether{SrcAddr: &rMAC, DstAddr: &lMAC}, + &IPv4{SrcAddr: &rIP, DstAddr: &lIP}, + newIncomingUDP}, + sniffer: sniffer, + injector: injector, + portPickerFD: portPickerFD, + t: t, + } +} + +// Close the injector and sniffer associated with this connection. +func (conn *UDPIPv4) Close() { + conn.sniffer.Close() + conn.injector.Close() + if err := unix.Close(conn.portPickerFD); err != nil { + conn.t.Fatalf("can't close portPickerFD: %s", err) + } + conn.portPickerFD = -1 +} + +// CreateFrame builds a frame for the connection with the provided udp +// overriding defaults and the additionalLayers added after the UDP header. +func (conn *UDPIPv4) CreateFrame(udp UDP, additionalLayers ...Layer) Layers { + layersToSend := deepcopy.Copy(conn.outgoing).(Layers) + if err := layersToSend[udpLayerIndex].(*UDP).merge(udp); err != nil { + conn.t.Fatalf("can't merge %+v into %+v: %s", udp, layersToSend[udpLayerIndex], err) + } + layersToSend = append(layersToSend, additionalLayers...) + return layersToSend +} + +// SendFrame sends a frame with reasonable defaults. +func (conn *UDPIPv4) SendFrame(frame Layers) { + outBytes, err := frame.toBytes() + if err != nil { + conn.t.Fatalf("can't build outgoing UDP packet: %s", err) + } + conn.injector.Send(outBytes) +} + +// Send a packet with reasonable defaults and override some fields by udp. +func (conn *UDPIPv4) Send(udp UDP, additionalLayers ...Layer) { + conn.SendFrame(conn.CreateFrame(udp, additionalLayers...)) +} + +// Recv gets a packet from the sniffer within the timeout provided. If no packet +// arrives before the timeout, it returns nil. +func (conn *UDPIPv4) Recv(timeout time.Duration) *UDP { + deadline := time.Now().Add(timeout) + for { + timeout = time.Until(deadline) + if timeout <= 0 { + break + } + b := conn.sniffer.Recv(timeout) + if b == nil { + break + } + layers, err := ParseEther(b) + if err != nil { + conn.t.Logf("can't parse frame: %s", err) + continue // Ignore packets that can't be parsed. + } + if !conn.incoming.match(layers) { + continue // Ignore packets that don't match the expected incoming. + } + return (layers[udpLayerIndex]).(*UDP) + } + return nil +} + +// Expect a packet that matches the provided udp within the timeout specified. +// If it doesn't arrive in time, the test fails. +func (conn *UDPIPv4) Expect(udp UDP, timeout time.Duration) *UDP { + deadline := time.Now().Add(timeout) + for { + timeout = time.Until(deadline) + if timeout <= 0 { + return nil + } + gotUDP := conn.Recv(timeout) + if gotUDP == nil { + return nil + } + if udp.match(gotUDP) { + return gotUDP + } + } +} diff --git a/test/packetimpact/testbench/dut.go b/test/packetimpact/testbench/dut.go index 8ea1706d3..f80dbb35f 100644 --- a/test/packetimpact/testbench/dut.go +++ b/test/packetimpact/testbench/dut.go @@ -305,6 +305,35 @@ func (dut *DUT) SetSockOptTimeval(sockfd, level, optname int32, tv *unix.Timeval } } +// RecvWithErrno calls recv on the DUT. +func (dut *DUT) RecvWithErrno(ctx context.Context, sockfd, len, flags int32) (int32, []byte, error) { + dut.t.Helper() + req := pb.RecvRequest{ + Sockfd: sockfd, + Len: len, + Flags: flags, + } + resp, err := dut.posixServer.Recv(ctx, &req) + if err != nil { + dut.t.Fatalf("failed to call Recv: %s", err) + } + return resp.GetRet(), resp.GetBuf(), syscall.Errno(resp.GetErrno_()) +} + +// Recv calls recv on the DUT and causes a fatal test failure if it doesn't +// succeed. If more control over the timeout or error handling is needed, use +// RecvWithErrno. +func (dut *DUT) Recv(sockfd, len, flags int32) []byte { + dut.t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout) + defer cancel() + ret, buf, err := dut.RecvWithErrno(ctx, sockfd, len, flags) + if ret == -1 { + dut.t.Fatalf("failed to recv: %s", err) + } + return buf +} + // CloseWithErrno calls close on the DUT. func (dut *DUT) CloseWithErrno(fd int32) (int32, error) { dut.t.Helper() @@ -330,10 +359,11 @@ func (dut *DUT) Close(fd int32) { } } -// CreateListener makes a new TCP connection. If it fails, the test ends. -func (dut *DUT) CreateListener(typ, proto, backlog int32) (int32, uint16) { +// CreateBoundSocket makes a new socket on the DUT, with type typ and protocol +// proto, and bound to the IP address addr. Returns the new file descriptor and +// the port that was selected on the DUT. +func (dut *DUT) CreateBoundSocket(typ, proto int32, addr net.IP) (int32, uint16) { dut.t.Helper() - addr := net.ParseIP(*remoteIPv4) var fd int32 if addr.To4() != nil { fd = dut.Socket(unix.AF_INET, typ, proto) @@ -358,6 +388,12 @@ func (dut *DUT) CreateListener(typ, proto, backlog int32) (int32, uint16) { default: dut.t.Fatalf("unknown sockaddr type from getsockname: %t", sa) } - dut.Listen(fd, backlog) return fd, uint16(port) } + +// CreateListener makes a new TCP connection. If it fails, the test ends. +func (dut *DUT) CreateListener(typ, proto, backlog int32) (int32, uint16) { + fd, remotePort := dut.CreateBoundSocket(typ, proto, net.ParseIP(*remoteIPv4)) + dut.Listen(fd, backlog) + return fd, remotePort +} diff --git a/test/packetimpact/testbench/layers.go b/test/packetimpact/testbench/layers.go index 35fa4dcb6..d7434c3d2 100644 --- a/test/packetimpact/testbench/layers.go +++ b/test/packetimpact/testbench/layers.go @@ -22,6 +22,7 @@ import ( "github.com/google/go-cmp/cmp/cmpopts" "github.com/imdario/mergo" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" ) @@ -97,7 +98,7 @@ func equalLayer(x, y Layer) bool { return cmp.Equal(x, y, opt, cmpopts.IgnoreUnexported(LayerBase{})) } -// Ether can construct and match the ethernet encapsulation. +// Ether can construct and match an ethernet encapsulation. type Ether struct { LayerBase SrcAddr *tcpip.LinkAddress @@ -161,7 +162,7 @@ func ParseEther(b []byte) (Layers, error) { return append(layers, moreLayers...), nil default: // TODO(b/150301488): Support more protocols, like IPv6. - return nil, fmt.Errorf("can't deduce the ethernet header's next protocol: %v", b) + return nil, fmt.Errorf("can't deduce the ethernet header's next protocol: %#v", b) } } @@ -173,7 +174,7 @@ func (l *Ether) length() int { return header.EthernetMinimumSize } -// IPv4 can construct and match the ethernet excapulation. +// IPv4 can construct and match an IPv4 encapsulation. type IPv4 struct { LayerBase IHL *uint8 @@ -236,9 +237,11 @@ func (l *IPv4) toBytes() ([]byte, error) { switch n := l.next().(type) { case *TCP: fields.Protocol = uint8(header.TCPProtocolNumber) + case *UDP: + fields.Protocol = uint8(header.UDPProtocolNumber) default: - // TODO(b/150301488): Support more protocols, like UDP. - return nil, fmt.Errorf("can't deduce the ip header's next protocol: %+v", n) + // TODO(b/150301488): Support more protocols as needed. + return nil, fmt.Errorf("can't deduce the ip header's next protocol: %#v", n) } } if l.SrcAddr != nil { @@ -294,13 +297,19 @@ func ParseIPv4(b []byte) (Layers, error) { DstAddr: Address(h.DestinationAddress()), } layers := Layers{&ipv4} - switch h.Protocol() { - case uint8(header.TCPProtocolNumber): + switch h.TransportProtocol() { + case header.TCPProtocolNumber: moreLayers, err := ParseTCP(b[ipv4.length():]) if err != nil { return nil, err } return append(layers, moreLayers...), nil + case header.UDPProtocolNumber: + moreLayers, err := ParseUDP(b[ipv4.length():]) + if err != nil { + return nil, err + } + return append(layers, moreLayers...), nil } return nil, fmt.Errorf("can't deduce the ethernet header's next protocol: %d", h.Protocol()) } @@ -316,7 +325,7 @@ func (l *IPv4) length() int { return int(*l.IHL) } -// TCP can construct and match the TCP excapulation. +// TCP can construct and match a TCP encapsulation. type TCP struct { LayerBase SrcPort *uint16 @@ -347,12 +356,16 @@ func (l *TCP) toBytes() ([]byte, error) { } if l.DataOffset != nil { h.SetDataOffset(*l.DataOffset) + } else { + h.SetDataOffset(uint8(l.length())) } if l.Flags != nil { h.SetFlags(*l.Flags) } if l.WindowSize != nil { h.SetWindowSize(*l.WindowSize) + } else { + h.SetWindowSize(32768) } if l.UrgentPointer != nil { h.SetUrgentPoiner(*l.UrgentPointer) @@ -361,38 +374,52 @@ func (l *TCP) toBytes() ([]byte, error) { h.SetChecksum(*l.Checksum) return h, nil } - if err := setChecksum(&h, l); err != nil { + if err := setTCPChecksum(&h, l); err != nil { return nil, err } return h, nil } -// setChecksum calculates the checksum of the TCP header and sets it in h. -func setChecksum(h *header.TCP, tcp *TCP) error { - h.SetChecksum(0) - tcpLength := uint16(tcp.length()) - current := tcp.next() - for current != nil { - tcpLength += uint16(current.length()) - current = current.next() +// totalLength returns the length of the provided layer and all following +// layers. +func totalLength(l Layer) int { + var totalLength int + for ; l != nil; l = l.next() { + totalLength += l.length() } + return totalLength +} +// layerChecksum calculates the checksum of the Layer header, including the +// peusdeochecksum of the layer before it and all the bytes after it.. +func layerChecksum(l Layer, protoNumber tcpip.TransportProtocolNumber) (uint16, error) { + totalLength := uint16(totalLength(l)) var xsum uint16 - switch s := tcp.prev().(type) { + switch s := l.prev().(type) { case *IPv4: - xsum = header.PseudoHeaderChecksum(header.TCPProtocolNumber, *s.SrcAddr, *s.DstAddr, tcpLength) + xsum = header.PseudoHeaderChecksum(protoNumber, *s.SrcAddr, *s.DstAddr, totalLength) default: // TODO(b/150301488): Support more protocols, like IPv6. - return fmt.Errorf("can't get src and dst addr from previous layer") + return 0, fmt.Errorf("can't get src and dst addr from previous layer: %#v", s) } - current = tcp.next() - for current != nil { + var payloadBytes buffer.VectorisedView + for current := l.next(); current != nil; current = current.next() { payload, err := current.toBytes() if err != nil { - return fmt.Errorf("can't get bytes for next header: %s", payload) + return 0, fmt.Errorf("can't get bytes for next header: %s", payload) } - xsum = header.Checksum(payload, xsum) - current = current.next() + payloadBytes.AppendView(payload) + } + xsum = header.ChecksumVV(payloadBytes, xsum) + return xsum, nil +} + +// setTCPChecksum calculates the checksum of the TCP header and sets it in h. +func setTCPChecksum(h *header.TCP, tcp *TCP) error { + h.SetChecksum(0) + xsum, err := layerChecksum(tcp, header.TCPProtocolNumber) + if err != nil { + return err } h.SetChecksum(^h.CalculateChecksum(xsum)) return nil @@ -444,6 +471,85 @@ func (l *TCP) merge(other TCP) error { return mergo.Merge(l, other, mergo.WithOverride) } +// UDP can construct and match a UDP encapsulation. +type UDP struct { + LayerBase + SrcPort *uint16 + DstPort *uint16 + Length *uint16 + Checksum *uint16 +} + +func (l *UDP) toBytes() ([]byte, error) { + b := make([]byte, header.UDPMinimumSize) + h := header.UDP(b) + if l.SrcPort != nil { + h.SetSourcePort(*l.SrcPort) + } + if l.DstPort != nil { + h.SetDestinationPort(*l.DstPort) + } + if l.Length != nil { + h.SetLength(*l.Length) + } else { + h.SetLength(uint16(totalLength(l))) + } + if l.Checksum != nil { + h.SetChecksum(*l.Checksum) + return h, nil + } + if err := setUDPChecksum(&h, l); err != nil { + return nil, err + } + return h, nil +} + +// setUDPChecksum calculates the checksum of the UDP header and sets it in h. +func setUDPChecksum(h *header.UDP, udp *UDP) error { + h.SetChecksum(0) + xsum, err := layerChecksum(udp, header.UDPProtocolNumber) + if err != nil { + return err + } + h.SetChecksum(^h.CalculateChecksum(xsum)) + return nil +} + +// ParseUDP parses the bytes assuming that they start with a udp header and +// continues parsing further encapsulations. +func ParseUDP(b []byte) (Layers, error) { + h := header.UDP(b) + udp := UDP{ + SrcPort: Uint16(h.SourcePort()), + DstPort: Uint16(h.DestinationPort()), + Length: Uint16(h.Length()), + Checksum: Uint16(h.Checksum()), + } + layers := Layers{&udp} + moreLayers, err := ParsePayload(b[udp.length():]) + if err != nil { + return nil, err + } + return append(layers, moreLayers...), nil +} + +func (l *UDP) match(other Layer) bool { + return equalLayer(l, other) +} + +func (l *UDP) length() int { + if l.Length == nil { + return header.UDPMinimumSize + } + return int(*l.Length) +} + +// merge overrides the values in l with the values from other but only in fields +// where the value is not nil. +func (l *UDP) merge(other UDP) error { + return mergo.Merge(l, other, mergo.WithOverride) +} + // Payload has bytes beyond OSI layer 4. type Payload struct { LayerBase diff --git a/test/packetimpact/tests/BUILD b/test/packetimpact/tests/BUILD index 1dff2a4d5..9a4d66ea9 100644 --- a/test/packetimpact/tests/BUILD +++ b/test/packetimpact/tests/BUILD @@ -15,6 +15,19 @@ packetimpact_go_test( ], ) +packetimpact_go_test( + name = "udp_recv_multicast", + srcs = ["udp_recv_multicast_test.go"], + # TODO(b/152813495): Fix netstack then remove the line below. + netstack = False, + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + sh_binary( name = "test_runner", srcs = ["test_runner.sh"], diff --git a/test/packetimpact/tests/Dockerfile b/test/packetimpact/tests/Dockerfile index 507030cc7..9075bc555 100644 --- a/test/packetimpact/tests/Dockerfile +++ b/test/packetimpact/tests/Dockerfile @@ -1,5 +1,17 @@ FROM ubuntu:bionic -RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y iptables netcat tcpdump iproute2 tshark +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y \ + # iptables to disable OS native packet processing. + iptables \ + # nc to check that the posix_server is running. + netcat \ + # tcpdump to log brief packet sniffing. + tcpdump \ + # ip link show to display MAC addresses. + iproute2 \ + # tshark to log verbose packet sniffing. + tshark \ + # killall for cleanup. + psmisc RUN hash -r CMD /bin/bash diff --git a/test/packetimpact/tests/defs.bzl b/test/packetimpact/tests/defs.bzl index 1b4213d9b..8c0d058b2 100644 --- a/test/packetimpact/tests/defs.bzl +++ b/test/packetimpact/tests/defs.bzl @@ -93,7 +93,17 @@ def packetimpact_netstack_test(name, testbench_binary, **kwargs): **kwargs ) -def packetimpact_go_test(name, size = "small", pure = True, **kwargs): +def packetimpact_go_test(name, size = "small", pure = True, linux = True, netstack = True, **kwargs): + """Add packetimpact tests written in go. + + Args: + name: name of the test + size: size of the test + pure: make a static go binary + linux: generate a linux test + netstack: generate a netstack test + **kwargs: all the other args, forwarded to go_test + """ testbench_binary = name + "_test" go_test( name = testbench_binary, @@ -102,5 +112,7 @@ def packetimpact_go_test(name, size = "small", pure = True, **kwargs): tags = PACKETIMPACT_TAGS, **kwargs ) - packetimpact_linux_test(name = name, testbench_binary = testbench_binary) - packetimpact_netstack_test(name = name, testbench_binary = testbench_binary) + if linux: + packetimpact_linux_test(name = name, testbench_binary = testbench_binary) + if netstack: + packetimpact_netstack_test(name = name, testbench_binary = testbench_binary) diff --git a/test/packetimpact/tests/fin_wait2_timeout_test.go b/test/packetimpact/tests/fin_wait2_timeout_test.go index 5f54e67ed..2b3f39045 100644 --- a/test/packetimpact/tests/fin_wait2_timeout_test.go +++ b/test/packetimpact/tests/fin_wait2_timeout_test.go @@ -36,7 +36,7 @@ func TestFinWait2Timeout(t *testing.T) { defer dut.TearDown() listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) defer dut.Close(listenFd) - conn := tb.NewTCPIPv4(t, dut, tb.TCP{DstPort: &remotePort}, tb.TCP{SrcPort: &remotePort}) + conn := tb.NewTCPIPv4(t, tb.TCP{DstPort: &remotePort}, tb.TCP{SrcPort: &remotePort}) defer conn.Close() conn.Handshake() diff --git a/test/packetimpact/tests/test_runner.sh b/test/packetimpact/tests/test_runner.sh index 5281cb53d..e99fc7d09 100755 --- a/test/packetimpact/tests/test_runner.sh +++ b/test/packetimpact/tests/test_runner.sh @@ -29,13 +29,15 @@ function failure() { } trap 'failure ${LINENO} "$BASH_COMMAND"' ERR -declare -r LONGOPTS="dut_platform:,posix_server_binary:,testbench_binary:,runtime:,tshark" +declare -r LONGOPTS="dut_platform:,posix_server_binary:,testbench_binary:,runtime:,tshark,extra_test_arg:" # Don't use declare below so that the error from getopt will end the script. PARSED=$(getopt --options "" --longoptions=$LONGOPTS --name "$0" -- "$@") eval set -- "$PARSED" +declare -a EXTRA_TEST_ARGS + while true; do case "$1" in --dut_platform) @@ -62,6 +64,10 @@ while true; do declare -r TSHARK="1" shift 1 ;; + --extra_test_arg) + EXTRA_TEST_ARGS+="$2" + shift 2 + ;; --) shift break @@ -125,6 +131,19 @@ docker --version function finish { local cleanup_success=1 + + if [[ -z "${TSHARK-}" ]]; then + # Kill tcpdump so that it will flush output. + docker exec -t "${TESTBENCH}" \ + killall tcpdump || \ + cleanup_success=0 + else + # Kill tshark so that it will flush output. + docker exec -t "${TESTBENCH}" \ + killall tshark || \ + cleanup_success=0 + fi + for net in "${CTRL_NET}" "${TEST_NET}"; do # Kill all processes attached to ${net}. for docker_command in "kill" "rm"; do @@ -224,6 +243,8 @@ else # interface with the test packets. docker exec -t "${TESTBENCH}" \ tshark -V -l -n -i "${TEST_DEVICE}" \ + -o tcp.check_checksum:TRUE \ + -o udp.check_checksum:TRUE \ host "${TEST_NET_PREFIX}${TESTBENCH_NET_SUFFIX}" & fi @@ -235,6 +256,7 @@ sleep 3 # be executed on the DUT. docker exec -t "${TESTBENCH}" \ /bin/bash -c "${DOCKER_TESTBENCH_BINARY} \ + ${EXTRA_TEST_ARGS[@]-} \ --posix_server_ip=${CTRL_NET_PREFIX}${DUT_NET_SUFFIX} \ --posix_server_port=${CTRL_PORT} \ --remote_ipv4=${TEST_NET_PREFIX}${DUT_NET_SUFFIX} \ diff --git a/test/packetimpact/tests/udp_recv_multicast_test.go b/test/packetimpact/tests/udp_recv_multicast_test.go new file mode 100644 index 000000000..bc1b0be49 --- /dev/null +++ b/test/packetimpact/tests/udp_recv_multicast_test.go @@ -0,0 +1,37 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package udp_recv_multicast_test + +import ( + "net" + "testing" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip" + tb "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func TestUDPRecvMulticast(t *testing.T) { + dut := tb.NewDUT(t) + defer dut.TearDown() + boundFD, remotePort := dut.CreateBoundSocket(unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP("0.0.0.0")) + defer dut.Close(boundFD) + conn := tb.NewUDPIPv4(t, tb.UDP{DstPort: &remotePort}, tb.UDP{SrcPort: &remotePort}) + defer conn.Close() + frame := conn.CreateFrame(tb.UDP{}, &tb.Payload{Bytes: []byte("hello world")}) + frame[1].(*tb.IPv4).DstAddr = tb.Address(tcpip.Address(net.ParseIP("224.0.0.1").To4())) + conn.SendFrame(frame) + dut.Recv(boundFD, 100, 0) +}