gvisor/test/packetimpact/testbench/connections.go

445 lines
13 KiB
Go
Raw Normal View History

// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package testbench has utilities to send and receive packets and also command
// the DUT to run POSIX functions.
package testbench
import (
"flag"
"fmt"
"math/rand"
"net"
"strings"
"testing"
"time"
"github.com/mohae/deepcopy"
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
)
// Command-line flags naming the addresses of the two ends of the test link.
// The connection constructors in this file dereference them when building the
// Ethernet and IPv4 layers.
var (
	localIPv4  = flag.String("local_ipv4", "", "local IPv4 address for test packets")
	remoteIPv4 = flag.String("remote_ipv4", "", "remote IPv4 address for test packets")
	localMAC   = flag.String("local_mac", "", "local mac address for test packets")
	remoteMAC  = flag.String("remote_mac", "", "remote mac address for test packets")
)
// pickPort makes a new socket and returns the socket FD and port. The caller
// must close the FD when done with the port if there is no error.
func pickPort() (int, uint16, error) {
fd, err := unix.Socket(unix.AF_INET, unix.SOCK_STREAM, 0)
if err != nil {
return -1, 0, err
}
var sa unix.SockaddrInet4
copy(sa.Addr[0:4], net.ParseIP(*localIPv4).To4())
if err := unix.Bind(fd, &sa); err != nil {
unix.Close(fd)
return -1, 0, err
}
newSockAddr, err := unix.Getsockname(fd)
if err != nil {
unix.Close(fd)
return -1, 0, err
}
newSockAddrInet4, ok := newSockAddr.(*unix.SockaddrInet4)
if !ok {
unix.Close(fd)
return -1, 0, fmt.Errorf("can't cast Getsockname result to SockaddrInet4")
}
return fd, uint16(newSockAddrInet4.Port), nil
}
// TCPIPv4 maintains state about a TCP/IPv4 connection: the layer templates
// used to build and recognize frames, the sequence numbers of both ends, and
// the resources that Close releases.
type TCPIPv4 struct {
// outgoing is the Ethernet/IPv4/TCP template that CreateFrame deep-copies
// for every frame it builds.
outgoing Layers
// incoming is the pattern RecvFrame matches sniffed frames against; frames
// that don't match are ignored.
incoming Layers
// LocalSeqNum is the SeqNum CreateFrame uses when the caller doesn't set
// one. SendFrame advances it by the length of the layers after TCP, plus one
// for SYN or FIN.
LocalSeqNum seqnum.Value
// RemoteSeqNum is the AckNum CreateFrame uses when the caller doesn't set
// one. RecvFrame sets it from each matching frame it receives.
RemoteSeqNum seqnum.Value
// SynAck is set by Handshake to the SYN-ACK it received.
SynAck *TCP
// sniffer supplies the raw frames read by RecvFrame.
sniffer Sniffer
// injector sends the serialized frames produced by SendFrame.
injector Injector
// portPickerFD is the socket FD returned by pickPort; Close closes it.
portPickerFD int
// t is used to report fatal test failures.
t *testing.T
}
// tcpLayerIndex is the zero-based position of the TCP layer in the Layers of
// a TCPIPv4 connection. It is the third, after Ethernet (0) and IPv4 (1).
// SendFrame and RecvFrame count the layers after this index toward the
// sequence numbers.
const tcpLayerIndex int = 2
// NewTCPIPv4 creates a new TCPIPv4 connection with reasonable defaults.
//
// The MAC and IPv4 addresses come from the local_mac, remote_mac, local_ipv4
// and remote_ipv4 flags. The local port is chosen by pickPort and used as the
// source port of outgoing segments and the destination port expected on
// incoming ones; outgoingTCP and incomingTCP are merged on top of those
// defaults. LocalSeqNum starts at a random value.
//
// Any setup failure is a fatal test failure. Resources acquired before the
// failure (the port-picker socket, the sniffer and the injector) are released
// before the test goroutine exits. On success the caller owns them and must
// call Close.
func NewTCPIPv4(t *testing.T, outgoingTCP, incomingTCP TCP) TCPIPv4 {
	lMAC, err := tcpip.ParseMACAddress(*localMAC)
	if err != nil {
		t.Fatalf("can't parse localMAC %q: %s", *localMAC, err)
	}
	rMAC, err := tcpip.ParseMACAddress(*remoteMAC)
	if err != nil {
		t.Fatalf("can't parse remoteMAC %q: %s", *remoteMAC, err)
	}

	// built is flipped just before the successful return. t.Fatalf stops the
	// test with runtime.Goexit, which still runs deferred calls, so each
	// cleanup below fires only when a later step fails.
	built := false

	portPickerFD, localPort, err := pickPort()
	if err != nil {
		t.Fatalf("can't pick a port: %s", err)
	}
	defer func() {
		if !built {
			// Best effort: the test has already failed for another reason.
			_ = unix.Close(portPickerFD)
		}
	}()

	lIP := tcpip.Address(net.ParseIP(*localIPv4).To4())
	rIP := tcpip.Address(net.ParseIP(*remoteIPv4).To4())

	sniffer, err := NewSniffer(t)
	if err != nil {
		t.Fatalf("can't make new sniffer: %s", err)
	}
	defer func() {
		if !built {
			sniffer.Close()
		}
	}()

	injector, err := NewInjector(t)
	if err != nil {
		t.Fatalf("can't make new injector: %s", err)
	}
	defer func() {
		if !built {
			injector.Close()
		}
	}()

	newOutgoingTCP := &TCP{
		SrcPort: &localPort,
	}
	if err := newOutgoingTCP.merge(outgoingTCP); err != nil {
		t.Fatalf("can't merge %+v into %+v: %s", outgoingTCP, newOutgoingTCP, err)
	}
	newIncomingTCP := &TCP{
		DstPort: &localPort,
	}
	if err := newIncomingTCP.merge(incomingTCP); err != nil {
		t.Fatalf("can't merge %+v into %+v: %s", incomingTCP, newIncomingTCP, err)
	}

	conn := TCPIPv4{
		outgoing: Layers{
			&Ether{SrcAddr: &lMAC, DstAddr: &rMAC},
			&IPv4{SrcAddr: &lIP, DstAddr: &rIP},
			newOutgoingTCP},
		incoming: Layers{
			&Ether{SrcAddr: &rMAC, DstAddr: &lMAC},
			&IPv4{SrcAddr: &rIP, DstAddr: &lIP},
			newIncomingTCP},
		sniffer:      sniffer,
		injector:     injector,
		portPickerFD: portPickerFD,
		t:            t,
		LocalSeqNum:  seqnum.Value(rand.Uint32()),
	}
	built = true
	return conn
}
// Close releases everything the connection holds: the sniffer, the injector
// and the socket that was opened to pick the local port. Failing to close
// that socket is a fatal test failure.
func (conn *TCPIPv4) Close() {
	conn.sniffer.Close()
	conn.injector.Close()

	err := unix.Close(conn.portPickerFD)
	if err != nil {
		conn.t.Fatalf("can't close portPickerFD: %s", err)
	}
	// Record that the FD is no longer valid.
	conn.portPickerFD = -1
}
// CreateFrame builds a frame for the connection with tcp overriding defaults
// and additionalLayers added after the TCP header.
//
// When tcp leaves SeqNum or AckNum unset they default to the connection's
// LocalSeqNum and RemoteSeqNum. The connection's outgoing template is
// deep-copied, so the returned frame can be modified without affecting the
// connection. A merge failure is a fatal test failure.
func (conn *TCPIPv4) CreateFrame(tcp TCP, additionalLayers ...Layer) Layers {
	if tcp.SeqNum == nil {
		tcp.SeqNum = Uint32(uint32(conn.LocalSeqNum))
	}
	if tcp.AckNum == nil {
		tcp.AckNum = Uint32(uint32(conn.RemoteSeqNum))
	}

	frame := deepcopy.Copy(conn.outgoing).(Layers)
	tcpLayer := frame[tcpLayerIndex].(*TCP)
	if err := tcpLayer.merge(tcp); err != nil {
		conn.t.Fatalf("can't merge %+v into %+v: %s", tcp, tcpLayer, err)
	}
	return append(frame, additionalLayers...)
}
// SendFrame serializes frame, hands it to the injector and then advances
// LocalSeqNum by the length of every layer after the TCP header, plus one if
// the segment has SYN or FIN set. A frame that can't be serialized is a fatal
// test failure.
func (conn *TCPIPv4) SendFrame(frame Layers) {
	const synOrFin = header.TCPFlagSyn | header.TCPFlagFin

	raw, err := frame.toBytes()
	if err != nil {
		conn.t.Fatalf("can't build outgoing TCP packet: %s", err)
	}
	conn.injector.Send(raw)

	// Everything after the TCP layer counts toward the next sequence number.
	idx := tcpLayerIndex + 1
	for idx < len(frame) {
		conn.LocalSeqNum.UpdateForward(seqnum.Size(frame[idx].length()))
		idx++
	}
	// SYN and FIN each occupy one sequence number of their own.
	if flags := frame[tcpLayerIndex].(*TCP).Flags; flags != nil && *flags&synOrFin != 0 {
		conn.LocalSeqNum.UpdateForward(1)
	}
}
// Send builds a frame with CreateFrame, using tcp to override the defaults
// and appending additionalLayers, and transmits it with SendFrame.
func (conn *TCPIPv4) Send(tcp TCP, additionalLayers ...Layer) {
	frame := conn.CreateFrame(tcp, additionalLayers...)
	conn.SendFrame(frame)
}
// Recv gets a packet from the sniffer within the timeout provided and returns
// its TCP layer. If no packet arrives before the timeout, it returns nil.
func (conn *TCPIPv4) Recv(timeout time.Duration) *TCP {
	frame := conn.RecvFrame(timeout)
	if len(frame) <= tcpLayerIndex {
		// No frame arrived, or it has no layer at the TCP position.
		return nil
	}
	return frame[tcpLayerIndex].(*TCP)
}
// RecvFrame gets a frame (of type Layers) within the timeout provided.
// If no frame arrives before the timeout, it returns nil.
//
// Sniffed frames that don't match the connection's incoming pattern are
// skipped. For the frame that is returned, RemoteSeqNum is set to the frame's
// sequence number, plus one if SYN or FIN is set, plus the length of every
// layer after the TCP header.
func (conn *TCPIPv4) RecvFrame(timeout time.Duration) Layers {
	const synOrFin = header.TCPFlagSyn | header.TCPFlagFin

	deadline := time.Now().Add(timeout)
	for remaining := time.Until(deadline); remaining > 0; remaining = time.Until(deadline) {
		raw := conn.sniffer.Recv(remaining)
		if raw == nil {
			return nil
		}
		frame := Parse(ParseEther, raw)
		if !conn.incoming.match(frame) {
			// Not addressed to this connection; keep waiting.
			continue
		}

		segment := frame[tcpLayerIndex].(*TCP)
		conn.RemoteSeqNum = seqnum.Value(*segment.SeqNum)
		if *segment.Flags&synOrFin != 0 {
			conn.RemoteSeqNum.UpdateForward(1)
		}
		idx := tcpLayerIndex + 1
		for idx < len(frame) {
			conn.RemoteSeqNum.UpdateForward(seqnum.Size(frame[idx].length()))
			idx++
		}
		return frame
	}
	return nil
}
// Expect a packet that matches the provided tcp within the timeout specified.
// If it doesn't arrive in time, it returns nil and an error listing every
// non-matching packet that was received while waiting.
func (conn *TCPIPv4) Expect(tcp TCP, timeout time.Duration) (*TCP, error) {
	// We cannot implement this directly using ExpectFrame as we cannot specify
	// the Payload part.
	deadline := time.Now().Add(timeout)
	var seen []string
	// timedOut reports the packets that were received but didn't match.
	timedOut := func() error {
		return fmt.Errorf("got %d packets:\n%s", len(seen), strings.Join(seen, "\n"))
	}
	for {
		remaining := time.Until(deadline)
		if remaining <= 0 {
			return nil, timedOut()
		}
		got := conn.Recv(remaining)
		if got == nil {
			return nil, timedOut()
		}
		if tcp.match(got) {
			return got, nil
		}
		seen = append(seen, got.String())
	}
}
// ExpectFrame expects a frame that matches the specified layers within the
// timeout specified. If it doesn't arrive in time, it returns nil. Frames
// received in the meantime that don't match are discarded.
func (conn *TCPIPv4) ExpectFrame(layers Layers, timeout time.Duration) Layers {
	deadline := time.Now().Add(timeout)
	for remaining := time.Until(deadline); remaining > 0; remaining = time.Until(deadline) {
		if got := conn.RecvFrame(remaining); layers.match(got) {
			return got
		}
	}
	return nil
}
// ExpectData is a convenient method that expects a TCP packet along with
// the payload to arrive within the timeout specified. If it doesn't arrive
// in time, it causes a fatal test failure.
//
// The Ethernet and IPv4 layers of the expectation are left empty. A Payload
// layer is only added to the expectation when data is non-empty.
func (conn *TCPIPv4) ExpectData(tcp TCP, data []byte, timeout time.Duration) {
	want := Layers{&Ether{}, &IPv4{}, &tcp}
	if len(data) != 0 {
		want = append(want, &Payload{Bytes: data})
	}
	if got := conn.ExpectFrame(want, timeout); got == nil {
		conn.t.Fatalf("expected to get a TCP frame %s with payload %x", &tcp, data)
	}
}
// Handshake performs a TCP 3-way handshake: it sends a SYN, waits up to one
// second for a SYN-ACK, stores it in SynAck and replies with an ACK. A missing
// SYN-ACK is a fatal test failure.
func (conn *TCPIPv4) Handshake() {
	// Step 1: SYN.
	conn.Send(TCP{Flags: Uint8(header.TCPFlagSyn)})

	// Step 2: SYN-ACK from the remote end.
	wantSynAck := TCP{Flags: Uint8(header.TCPFlagSyn | header.TCPFlagAck)}
	gotSynAck, err := conn.Expect(wantSynAck, time.Second)
	if gotSynAck == nil {
		conn.t.Fatalf("didn't get synack during handshake: %s", err)
	}
	conn.SynAck = gotSynAck

	// Step 3: ACK.
	conn.Send(TCP{Flags: Uint8(header.TCPFlagAck)})
}
// UDPIPv4 maintains state about a UDP/IPv4 connection: the layer templates
// used to build and recognize frames, and the resources that Close releases.
type UDPIPv4 struct {
// outgoing is the Ethernet/IPv4/UDP template that CreateFrame deep-copies
// for every frame it builds.
outgoing Layers
// incoming is the pattern Recv matches sniffed frames against; frames that
// don't match are ignored.
incoming Layers
// sniffer supplies the raw frames read by Recv.
sniffer Sniffer
// injector sends the serialized frames produced by SendFrame.
injector Injector
// portPickerFD is the socket FD returned by pickPort; Close closes it.
portPickerFD int
// t is used to report fatal test failures.
t *testing.T
}
// udpLayerIndex is the zero-based position of the UDP layer in the Layers of
// a UDPIPv4 connection. It is the third, after Ethernet (0) and IPv4 (1).
const udpLayerIndex int = 2
// NewUDPIPv4 creates a new UDPIPv4 connection with reasonable defaults.
//
// The MAC and IPv4 addresses come from the local_mac, remote_mac, local_ipv4
// and remote_ipv4 flags. The local port is reserved with a UDP socket (see
// pickUDPPort) and used as the source port of outgoing datagrams and the
// destination port expected on incoming ones; outgoingUDP and incomingUDP are
// merged on top of those defaults.
//
// Any setup failure is a fatal test failure. Resources acquired before the
// failure (the port-picker socket, the sniffer and the injector) are released
// before the test goroutine exits. On success the caller owns them and must
// call Close.
func NewUDPIPv4(t *testing.T, outgoingUDP, incomingUDP UDP) UDPIPv4 {
	lMAC, err := tcpip.ParseMACAddress(*localMAC)
	if err != nil {
		t.Fatalf("can't parse localMAC %q: %s", *localMAC, err)
	}
	rMAC, err := tcpip.ParseMACAddress(*remoteMAC)
	if err != nil {
		t.Fatalf("can't parse remoteMAC %q: %s", *remoteMAC, err)
	}

	// built is flipped just before the successful return. t.Fatalf stops the
	// test with runtime.Goexit, which still runs deferred calls, so each
	// cleanup below fires only when a later step fails.
	built := false

	// The port must be reserved with a datagram socket: a stream socket only
	// holds the number in the TCP port space and leaves the UDP port unbound.
	portPickerFD, localPort, err := pickUDPPort()
	if err != nil {
		t.Fatalf("can't pick a port: %s", err)
	}
	defer func() {
		if !built {
			// Best effort: the test has already failed for another reason.
			_ = unix.Close(portPickerFD)
		}
	}()

	lIP := tcpip.Address(net.ParseIP(*localIPv4).To4())
	rIP := tcpip.Address(net.ParseIP(*remoteIPv4).To4())

	sniffer, err := NewSniffer(t)
	if err != nil {
		t.Fatalf("can't make new sniffer: %s", err)
	}
	defer func() {
		if !built {
			sniffer.Close()
		}
	}()

	injector, err := NewInjector(t)
	if err != nil {
		t.Fatalf("can't make new injector: %s", err)
	}
	defer func() {
		if !built {
			injector.Close()
		}
	}()

	newOutgoingUDP := &UDP{
		SrcPort: &localPort,
	}
	if err := newOutgoingUDP.merge(outgoingUDP); err != nil {
		t.Fatalf("can't merge %+v into %+v: %s", outgoingUDP, newOutgoingUDP, err)
	}
	newIncomingUDP := &UDP{
		DstPort: &localPort,
	}
	if err := newIncomingUDP.merge(incomingUDP); err != nil {
		t.Fatalf("can't merge %+v into %+v: %s", incomingUDP, newIncomingUDP, err)
	}

	conn := UDPIPv4{
		outgoing: Layers{
			&Ether{SrcAddr: &lMAC, DstAddr: &rMAC},
			&IPv4{SrcAddr: &lIP, DstAddr: &rIP},
			newOutgoingUDP},
		incoming: Layers{
			&Ether{SrcAddr: &rMAC, DstAddr: &lMAC},
			&IPv4{SrcAddr: &rIP, DstAddr: &lIP},
			newIncomingUDP},
		sniffer:      sniffer,
		injector:     injector,
		portPickerFD: portPickerFD,
		t:            t,
	}
	built = true
	return conn
}

// pickUDPPort makes a new AF_INET/SOCK_DGRAM socket bound to the address given
// by the local_ipv4 flag with the port left as zero, and returns the socket
// FD together with the port read back via Getsockname. The caller must close
// the FD when done with the port if there is no error. On failure the FD is
// -1, the port is 0, and any socket that was opened has been closed.
func pickUDPPort() (int, uint16, error) {
	// To4 returns nil for an empty or non-IPv4 string; reject that instead of
	// silently binding to 0.0.0.0.
	localIP := net.ParseIP(*localIPv4).To4()
	if localIP == nil {
		return -1, 0, fmt.Errorf("can't parse %q as an IPv4 address", *localIPv4)
	}

	fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, 0)
	if err != nil {
		return -1, 0, err
	}
	// fail closes the socket and produces the error return values. The close
	// error is deliberately dropped: the caller needs the original failure.
	fail := func(err error) (int, uint16, error) {
		_ = unix.Close(fd)
		return -1, 0, err
	}

	var sa unix.SockaddrInet4
	copy(sa.Addr[:], localIP)
	if err := unix.Bind(fd, &sa); err != nil {
		return fail(err)
	}
	boundAddr, err := unix.Getsockname(fd)
	if err != nil {
		return fail(err)
	}
	boundInet4, ok := boundAddr.(*unix.SockaddrInet4)
	if !ok {
		return fail(fmt.Errorf("can't cast Getsockname result to SockaddrInet4"))
	}
	return fd, uint16(boundInet4.Port), nil
}
// Close releases everything the connection holds: the sniffer, the injector
// and the socket that was opened to pick the local port. Failing to close
// that socket is a fatal test failure.
func (conn *UDPIPv4) Close() {
	conn.sniffer.Close()
	conn.injector.Close()

	err := unix.Close(conn.portPickerFD)
	if err != nil {
		conn.t.Fatalf("can't close portPickerFD: %s", err)
	}
	// Record that the FD is no longer valid.
	conn.portPickerFD = -1
}
// CreateFrame builds a frame for the connection with the provided udp
// overriding defaults and the additionalLayers added after the UDP header.
//
// The connection's outgoing template is deep-copied, so the returned frame
// can be modified without affecting the connection. A merge failure is a
// fatal test failure.
func (conn *UDPIPv4) CreateFrame(udp UDP, additionalLayers ...Layer) Layers {
	frame := deepcopy.Copy(conn.outgoing).(Layers)
	udpLayer := frame[udpLayerIndex].(*UDP)
	if err := udpLayer.merge(udp); err != nil {
		conn.t.Fatalf("can't merge %+v into %+v: %s", udp, udpLayer, err)
	}
	return append(frame, additionalLayers...)
}
// SendFrame serializes frame and hands the bytes to the injector. A frame
// that can't be serialized is a fatal test failure.
func (conn *UDPIPv4) SendFrame(frame Layers) {
	raw, err := frame.toBytes()
	if err != nil {
		conn.t.Fatalf("can't build outgoing UDP packet: %s", err)
	}
	conn.injector.Send(raw)
}
// Send builds a frame with CreateFrame, using udp to override the defaults
// and appending additionalLayers, and transmits it with SendFrame.
func (conn *UDPIPv4) Send(udp UDP, additionalLayers ...Layer) {
	frame := conn.CreateFrame(udp, additionalLayers...)
	conn.SendFrame(frame)
}
// Recv gets a packet from the sniffer within the timeout provided and returns
// its UDP layer. Sniffed frames that don't match the connection's incoming
// pattern are skipped. If no packet arrives before the timeout, it returns
// nil.
func (conn *UDPIPv4) Recv(timeout time.Duration) *UDP {
	deadline := time.Now().Add(timeout)
	for remaining := time.Until(deadline); remaining > 0; remaining = time.Until(deadline) {
		raw := conn.sniffer.Recv(remaining)
		if raw == nil {
			return nil
		}
		// Frames that aren't for this connection are dropped and the wait
		// continues with whatever time is left.
		if frame := Parse(ParseEther, raw); conn.incoming.match(frame) {
			return frame[udpLayerIndex].(*UDP)
		}
	}
	return nil
}
// Expect a packet that matches the provided udp within the timeout specified.
// If it doesn't arrive in time, it returns nil and an error listing every
// non-matching packet that was received while waiting.
func (conn *UDPIPv4) Expect(udp UDP, timeout time.Duration) (*UDP, error) {
	deadline := time.Now().Add(timeout)
	var seen []string
	// timedOut reports the packets that were received but didn't match.
	timedOut := func() error {
		return fmt.Errorf("got %d packets:\n%s", len(seen), strings.Join(seen, "\n"))
	}
	for {
		remaining := time.Until(deadline)
		if remaining <= 0 {
			return nil, timedOut()
		}
		got := conn.Recv(remaining)
		if got == nil {
			return nil, timedOut()
		}
		if udp.match(got) {
			return got, nil
		}
		seen = append(seen, got.String())
	}
}