gvisor/pkg/tcpip/checker/checker.go

1299 lines
37 KiB
Go

// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package checker provides helper functions to check networking packets for
// validity.
package checker
import (
"encoding/binary"
"reflect"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
)
// NetworkChecker is a function to check a property of a network packet.
//
// It is handed the chain of network headers found so far: in this file the
// first element is always the outermost IP header, and the last element is
// the innermost header (for example an IPv6 fragment header appended by
// IPv6Fragment).
type NetworkChecker func(*testing.T, []header.Network)
// TransportChecker is a function to check a property of a transport packet.
//
// Checkers that need protocol specific fields type-assert the header to the
// concrete header type they expect.
type TransportChecker func(*testing.T, header.Transport)
// ControlMessagesChecker is a function to check a property of ancillary data.
type ControlMessagesChecker func(*testing.T, tcpip.ControlMessages)
// IPv4 checks the validity and properties of the given IPv4 packet. It is
// expected to be used in conjunction with other network checkers for specific
// properties. For example, to check the source and destination address, one
// would call:
//
//	checker.IPv4(t, b, checker.SrcAddr(x), checker.DstAddr(y))
//
// The test is aborted immediately when b is not a valid IPv4 packet, because
// the header accessors used by the checksum validation and by the checkers
// cannot be relied upon for a malformed buffer. Any other failure is recorded,
// every checker is run, and the test is aborted afterwards.
func IPv4(t *testing.T, b []byte, checkers ...NetworkChecker) {
	t.Helper()

	ipv4 := header.IPv4(b)
	if !ipv4.IsValid(len(b)) {
		// Stop right away: reading fields of an invalid header below could
		// index past the end of b.
		t.Fatal("Not a valid IPv4 packet")
	}

	// A correct header checksums to all ones (or zero, depending on how the
	// ones' complement sum is folded).
	xsum := ipv4.CalculateChecksum()
	if xsum != 0 && xsum != 0xffff {
		t.Errorf("Bad checksum: 0x%x, checksum in packet: 0x%x", xsum, ipv4.Checksum())
	}

	for _, f := range checkers {
		f(t, []header.Network{ipv4})
	}
	if t.Failed() {
		t.FailNow()
	}
}
// IPv6 checks the validity and properties of the given IPv6 packet. The usage
// is similar to IPv4.
//
// The test is aborted immediately when b is not a valid IPv6 packet, because
// the checkers read header fields that cannot be relied upon for a malformed
// buffer. Any other failure is recorded, every checker is run, and the test
// is aborted afterwards.
func IPv6(t *testing.T, b []byte, checkers ...NetworkChecker) {
	t.Helper()

	ipv6 := header.IPv6(b)
	if !ipv6.IsValid(len(b)) {
		// Stop right away: the checkers would otherwise read fields of an
		// invalid header.
		t.Fatal("Not a valid IPv6 packet")
	}

	for _, f := range checkers {
		f(t, []header.Network{ipv6})
	}
	if t.Failed() {
		t.FailNow()
	}
}
// SrcAddr returns a NetworkChecker that reports an error unless the source
// address of the first network header equals addr.
func SrcAddr(addr tcpip.Address) NetworkChecker {
	check := func(t *testing.T, h []header.Network) {
		t.Helper()

		got := h[0].SourceAddress()
		if got == addr {
			return
		}
		t.Errorf("Bad source address, got %v, want %v", got, addr)
	}
	return check
}
// DstAddr returns a NetworkChecker that reports an error unless the
// destination address of the first network header equals addr.
func DstAddr(addr tcpip.Address) NetworkChecker {
	check := func(t *testing.T, h []header.Network) {
		t.Helper()

		got := h[0].DestinationAddress()
		if got == addr {
			return
		}
		t.Errorf("Bad destination address, got %v, want %v", got, addr)
	}
	return check
}
// TTL creates a checker that checks the TTL (ipv4) or HopLimit (ipv6) of the
// first network header.
//
// The test is aborted if the value does not match ttl, or if the first
// network header is neither a header.IPv4 nor a header.IPv6.
func TTL(ttl uint8) NetworkChecker {
	return func(t *testing.T, h []header.Network) {
		t.Helper()

		var v uint8
		switch ip := h[0].(type) {
		case header.IPv4:
			v = ip.TTL()
		case header.IPv6:
			v = ip.HopLimit()
		default:
			// Without this, an unsupported header would silently be compared
			// as if its TTL were zero.
			t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4 or header.IPv6", ip)
		}
		if v != ttl {
			t.Fatalf("Bad TTL, got = %d, want = %d", v, ttl)
		}
	}
}
// IPFullLength returns a checker for the full length of the IP packet.
//
// packetLength is compared against two values: the length declared by the
// header (Total Length for IPv4, Payload Length plus the fixed header size for
// IPv6) and the number of bytes actually held by the header buffer. The test
// is aborted if the first header is neither IPv4 nor IPv6.
func IPFullLength(packetLength uint16) NetworkChecker {
	return func(t *testing.T, h []header.Network) {
		t.Helper()

		var (
			declared uint16 // length according to the header fields
			received uint16 // length of the underlying buffer
		)
		switch ip := h[0].(type) {
		case header.IPv6:
			received = uint16(len(ip))
			declared = ip.PayloadLength() + header.IPv6FixedHeaderSize
		case header.IPv4:
			received = uint16(len(ip))
			declared = ip.TotalLength()
		default:
			t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4 or header.IPv6", ip)
		}

		if received != packetLength {
			t.Errorf("bad packet length, got = %d, want = %d", received, packetLength)
		}
		if declared != packetLength {
			t.Errorf("unexpected packet length in header, got = %d, want = %d", declared, packetLength)
		}
	}
}
// IPv4HeaderLength creates a checker that checks the IPv4 Header length.
//
// The test is aborted if the first network header is not a header.IPv4.
func IPv4HeaderLength(headerLength int) NetworkChecker {
	return func(t *testing.T, h []header.Network) {
		t.Helper()

		switch ip := h[0].(type) {
		case header.IPv4:
			// Widen the header value rather than narrowing the expectation:
			// converting headerLength to uint8 would make an out-of-range
			// expectation such as 276 compare equal to a real length of 20.
			if hl := int(ip.HeaderLength()); hl != headerLength {
				t.Errorf("Bad header length, got = %d, want = %d", hl, headerLength)
			}
		default:
			t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4", ip)
		}
	}
}
// PayloadLen returns a checker that reports an error unless the payload of
// the first network header is exactly payloadLength bytes long.
func PayloadLen(payloadLength int) NetworkChecker {
	return func(t *testing.T, h []header.Network) {
		t.Helper()

		got := len(h[0].Payload())
		if got == payloadLength {
			return
		}
		t.Errorf("Bad payload length, got = %d, want = %d", got, payloadLength)
	}
}
// IPPayload returns a checker that compares the payload of the first network
// header against payload. A nil payload and an empty payload are treated as
// equal.
func IPPayload(payload []byte) NetworkChecker {
	return func(t *testing.T, h []header.Network) {
		t.Helper()

		got := h[0].Payload()

		// cmp.Diff distinguishes nil from empty slices; this checker does not.
		bothEmpty := len(got) == 0 && len(payload) == 0
		if bothEmpty {
			return
		}

		diff := cmp.Diff(payload, got)
		if diff == "" {
			return
		}
		t.Errorf("payload mismatch (-want +got):\n%s", diff)
	}
}
// IPv4Options returns a checker that compares the options of an IPv4 packet
// against want. Nil and empty option sets are treated as equal. The test is
// aborted if the first network header is not a header.IPv4.
func IPv4Options(want header.IPv4Options) NetworkChecker {
	return func(t *testing.T, h []header.Network) {
		t.Helper()

		var got header.IPv4Options
		switch ip := h[0].(type) {
		case header.IPv4:
			got = ip.Options()
		default:
			t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4", h[0])
		}

		// cmp.Diff distinguishes nil from empty slices; this checker does not.
		if len(got) == 0 && len(want) == 0 {
			return
		}

		diff := cmp.Diff(want, got)
		if diff == "" {
			return
		}
		t.Errorf("options mismatch (-want +got):\n%s", diff)
	}
}
// FragmentOffset returns a checker for the FragmentOffset field.
//
// Only IPv4 headers are inspected; any other header type passes unchecked.
func FragmentOffset(offset uint16) NetworkChecker {
	return func(t *testing.T, h []header.Network) {
		t.Helper()

		ip, isIPv4 := h[0].(header.IPv4)
		if !isIPv4 {
			// We only do this for IPv4 for now.
			return
		}
		got := ip.FragmentOffset()
		if got != offset {
			t.Errorf("Bad fragment offset, got = %d, want = %d", got, offset)
		}
	}
}
// FragmentFlags creates a checker that checks the fragment flags field.
//
// Only IPv4 headers are inspected; any other header type passes unchecked.
func FragmentFlags(flags uint8) NetworkChecker {
	return func(t *testing.T, h []header.Network) {
		t.Helper()

		// We only do this for IPv4 for now.
		switch ip := h[0].(type) {
		case header.IPv4:
			if v := ip.Flags(); v != flags {
				// The message previously said "fragment offset", which pointed
				// at the wrong field.
				t.Errorf("Bad fragment flags, got = %d, want = %d", v, flags)
			}
		}
	}
}
// ReceiveTClass returns a checker that requires the TCLASS control message to
// be present and to carry the value want.
func ReceiveTClass(want uint32) ControlMessagesChecker {
	return func(t *testing.T, cm tcpip.ControlMessages) {
		t.Helper()

		switch {
		case !cm.HasTClass:
			t.Errorf("got cm.HasTClass = %t, want = true", cm.HasTClass)
		case cm.TClass != want:
			t.Errorf("got cm.TClass = %d, want %d", cm.TClass, want)
		}
	}
}
// ReceiveTOS returns a checker that requires the TOS control message to be
// present and to carry the value want.
func ReceiveTOS(want uint8) ControlMessagesChecker {
	return func(t *testing.T, cm tcpip.ControlMessages) {
		t.Helper()

		switch {
		case !cm.HasTOS:
			t.Errorf("got cm.HasTOS = %t, want = true", cm.HasTOS)
		case cm.TOS != want:
			t.Errorf("got cm.TOS = %d, want %d", cm.TOS, want)
		}
	}
}
// ReceiveIPPacketInfo returns a checker that requires the PacketInfo control
// message to be present and to be equal to want.
func ReceiveIPPacketInfo(want tcpip.IPPacketInfo) ControlMessagesChecker {
	return func(t *testing.T, cm tcpip.ControlMessages) {
		t.Helper()

		if !cm.HasIPPacketInfo {
			t.Errorf("got cm.HasIPPacketInfo = %t, want = true", cm.HasIPPacketInfo)
			return
		}
		diff := cmp.Diff(want, cm.PacketInfo)
		if diff == "" {
			return
		}
		t.Errorf("IPPacketInfo mismatch (-want +got):\n%s", diff)
	}
}
// TOS returns a checker that compares both values returned by the TOS
// accessor of the first network header against tos and label.
func TOS(tos uint8, label uint32) NetworkChecker {
	return func(t *testing.T, h []header.Network) {
		t.Helper()

		gotTOS, gotLabel := h[0].TOS()
		if gotTOS == tos && gotLabel == label {
			return
		}
		t.Errorf("Bad TOS, got = (%d, %d), want = (%d,%d)", gotTOS, gotLabel, tos, label)
	}
}
// Raw creates a checker that checks the bytes of payload.
// The checker always checks the payload of the last network header.
// For instance, in case of IPv6 fragments, the payload that will be checked
// is the one containing the actual data that the packet is carrying, without
// the bytes added by the IPv6 fragmentation.
//
// A nil payload and an empty payload are considered equal, matching the other
// payload checkers in this package.
func Raw(want []byte) NetworkChecker {
	return func(t *testing.T, h []header.Network) {
		t.Helper()

		got := h[len(h)-1].Payload()
		// reflect.DeepEqual does not consider nil slices equal to empty
		// slices, but we do.
		if len(got) == 0 && len(want) == 0 {
			return
		}
		if !reflect.DeepEqual(got, want) {
			t.Errorf("Wrong payload, got %v, want %v", got, want)
		}
	}
}
// IPv6Fragment creates a checker that validates an IPv6 fragment.
//
// It verifies that the first network header announces an IPv6 fragment header
// as its next header, interprets the payload as that fragment header, and then
// runs checkers with the fragment header appended to the header chain. The
// test is aborted immediately if the fragment header is invalid, and after all
// checkers have run if any failure was recorded.
func IPv6Fragment(checkers ...NetworkChecker) NetworkChecker {
	return func(t *testing.T, h []header.Network) {
		t.Helper()

		if p := h[0].TransportProtocol(); p != header.IPv6FragmentHeader {
			// Report the value that is actually being compared against; the
			// message previously printed the UDP protocol number.
			t.Errorf("Bad protocol, got = %d, want = %d", p, header.IPv6FragmentHeader)
		}

		ipv6Frag := header.IPv6Fragment(h[0].Payload())
		if !ipv6Frag.IsValid() {
			// Stop right away: the checkers would otherwise read fields of an
			// invalid fragment header.
			t.Fatal("Not a valid IPv6 fragment")
		}

		for _, f := range checkers {
			f(t, []header.Network{h[0], ipv6Frag})
		}
		if t.Failed() {
			t.FailNow()
		}
	}
}
// TCP returns a checker that requires the innermost network header to carry
// TCP, validates the TCP checksum (pseudo-header addresses are taken from the
// outermost header) and then runs the given transport checkers on the
// segment. The test is aborted once all checks ran if any of them failed.
func TCP(checkers ...TransportChecker) NetworkChecker {
	return func(t *testing.T, h []header.Network) {
		t.Helper()

		outer, inner := h[0], h[len(h)-1]

		if p := inner.TransportProtocol(); p != header.TCPProtocolNumber {
			t.Errorf("Bad protocol, got = %d, want = %d", p, header.TCPProtocolNumber)
		}

		segment := header.TCP(inner.Payload())
		segLen := uint16(len(segment))

		// Fold the pseudo-header pieces and then the segment itself into one
		// running checksum.
		pieces := [][]byte{
			[]byte(outer.SourceAddress()),
			[]byte(outer.DestinationAddress()),
			{0, byte(inner.TransportProtocol())},
			{byte(segLen >> 8), byte(segLen)},
			[]byte(segment),
		}
		var sum uint16
		for _, piece := range pieces {
			sum = header.Checksum(piece, sum)
		}
		if sum != 0 && sum != 0xffff {
			t.Errorf("Bad checksum: 0x%x, checksum in segment: 0x%x", sum, segment.Checksum())
		}

		for _, check := range checkers {
			check(t, segment)
		}
		if t.Failed() {
			t.FailNow()
		}
	}
}
// UDP returns a checker that requires the innermost network header to carry
// UDP and then runs the given transport checkers on the datagram. The test is
// aborted once all checks ran if any of them failed.
func UDP(checkers ...TransportChecker) NetworkChecker {
	return func(t *testing.T, h []header.Network) {
		t.Helper()

		inner := h[len(h)-1]
		proto := inner.TransportProtocol()
		if proto != header.UDPProtocolNumber {
			t.Errorf("Bad protocol, got = %d, want = %d", proto, header.UDPProtocolNumber)
		}

		datagram := header.UDP(inner.Payload())
		for _, check := range checkers {
			check(t, datagram)
		}
		if t.Failed() {
			t.FailNow()
		}
	}
}
// SrcPort returns a checker that reports an error unless the transport
// header's source port equals port.
func SrcPort(port uint16) TransportChecker {
	return func(t *testing.T, h header.Transport) {
		t.Helper()

		got := h.SourcePort()
		if got == port {
			return
		}
		t.Errorf("Bad source port, got = %d, want = %d", got, port)
	}
}
// DstPort returns a checker that reports an error unless the transport
// header's destination port equals port.
func DstPort(port uint16) TransportChecker {
	return func(t *testing.T, h header.Transport) {
		t.Helper()

		got := h.DestinationPort()
		if got == port {
			return
		}
		t.Errorf("Bad destination port, got = %d, want = %d", got, port)
	}
}
// NoChecksum returns a checker that compares "the UDP checksum field is zero"
// against noChecksum. The test is aborted if the transport header is not a
// header.UDP.
func NoChecksum(noChecksum bool) TransportChecker {
	return func(t *testing.T, h header.Transport) {
		t.Helper()

		switch udp := h.(type) {
		case header.UDP:
			checksumIsZero := udp.Checksum() == 0
			if checksumIsZero != noChecksum {
				t.Errorf("bad checksum state, got %t, want %t", checksumIsZero, noChecksum)
			}
		default:
			t.Fatalf("UDP header not found in h: %T", h)
		}
	}
}
// TCPSeqNum returns a checker that reports an error unless the TCP sequence
// number equals seq. The test is aborted if the transport header is not a
// header.TCP.
func TCPSeqNum(seq uint32) TransportChecker {
	return func(t *testing.T, h header.Transport) {
		t.Helper()

		switch tcp := h.(type) {
		case header.TCP:
			if got := tcp.SequenceNumber(); got != seq {
				t.Errorf("Bad sequence number, got = %d, want = %d", got, seq)
			}
		default:
			t.Fatalf("TCP header not found in h: %T", h)
		}
	}
}
// TCPAckNum returns a checker that reports an error unless the TCP
// acknowledgement number equals seq. The test is aborted if the transport
// header is not a header.TCP.
func TCPAckNum(seq uint32) TransportChecker {
	return func(t *testing.T, h header.Transport) {
		t.Helper()

		switch tcp := h.(type) {
		case header.TCP:
			if got := tcp.AckNumber(); got != seq {
				t.Errorf("Bad ack number, got = %d, want = %d", got, seq)
			}
		default:
			t.Fatalf("TCP header not found in h: %T", h)
		}
	}
}
// TCPWindow returns a checker that reports an error unless the TCP window
// size equals window. The test is aborted if the transport header is not a
// header.TCP.
func TCPWindow(window uint16) TransportChecker {
	return func(t *testing.T, h header.Transport) {
		t.Helper()

		switch tcp := h.(type) {
		case header.TCP:
			if got := tcp.WindowSize(); got != window {
				t.Errorf("Bad window, got %d, want %d", got, window)
			}
		default:
			t.Fatalf("TCP header not found in hdr : %T", h)
		}
	}
}
// TCPWindowGreaterThanEq creates a checker that checks that the TCP window
// is greater than or equal to the provided value.
//
// The test is aborted if the transport header is not a header.TCP.
func TCPWindowGreaterThanEq(window uint16) TransportChecker {
	return func(t *testing.T, h header.Transport) {
		t.Helper()

		tcp, ok := h.(header.TCP)
		if !ok {
			t.Fatalf("TCP header not found in h: %T", h)
		}

		if w := tcp.WindowSize(); w < window {
			// The bound is inclusive, so the message must say ">=".
			t.Errorf("Bad window, got %d, want >= %d", w, window)
		}
	}
}
// TCPWindowLessThanEq creates a checker that checks that the tcp window
// is less than or equal to the provided value.
//
// The test is aborted if the transport header is not a header.TCP.
func TCPWindowLessThanEq(window uint16) TransportChecker {
	return func(t *testing.T, h header.Transport) {
		t.Helper()

		tcp, ok := h.(header.TCP)
		if !ok {
			t.Fatalf("TCP header not found in h: %T", h)
		}

		if w := tcp.WindowSize(); w > window {
			// The bound is inclusive, so the message must say "<=".
			t.Errorf("Bad window, got %d, want <= %d", w, window)
		}
	}
}
// TCPFlags returns a checker that reports an error unless the TCP flags are
// exactly flags. The test is aborted if the transport header is not a
// header.TCP.
func TCPFlags(flags uint8) TransportChecker {
	return func(t *testing.T, h header.Transport) {
		t.Helper()

		switch tcp := h.(type) {
		case header.TCP:
			if got := tcp.Flags(); got != flags {
				t.Errorf("Bad flags, got 0x%x, want 0x%x", got, flags)
			}
		default:
			t.Fatalf("TCP header not found in h: %T", h)
		}
	}
}
// TCPFlagsMatch returns a checker that compares the TCP flags with flags,
// considering only the bits selected by mask. The test is aborted if the
// transport header is not a header.TCP.
func TCPFlagsMatch(flags, mask uint8) TransportChecker {
	return func(t *testing.T, h header.Transport) {
		t.Helper()

		switch tcp := h.(type) {
		case header.TCP:
			got := tcp.Flags()
			if got&mask == flags&mask {
				return
			}
			t.Errorf("Bad masked flags, got 0x%x, want 0x%x, mask 0x%x", got, flags, mask)
		default:
			t.Fatalf("TCP header not found in h: %T", h)
		}
	}
}
// TCPSynOptions creates a checker that checks the presence of TCP options in
// SYN segments.
//
// The MSS option must always be present and match wantOpts.MSS. If
// wantOpts.WS is negative, the window scale option must not be present;
// otherwise it must be present and match. The timestamp and SACK-permitted
// options must be present when wantOpts.TS / wantOpts.SACKPermitted are set.
//
// A truncated or otherwise malformed option is reported and stops the checker,
// since the remaining bytes cannot be parsed reliably. Transport headers that
// are not a header.TCP are ignored.
func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker {
	return func(t *testing.T, h header.Transport) {
		t.Helper()

		tcp, ok := h.(header.TCP)
		if !ok {
			return
		}
		opts := tcp.Options()
		limit := len(opts)
		foundMSS := false
		foundWS := false
		foundTS := false
		foundSACKPermitted := false
		tsVal := uint32(0)
		tsEcr := uint32(0)
		for i := 0; i < limit; {
			switch opts[i] {
			case header.TCPOptionEOL:
				i = limit
			case header.TCPOptionNOP:
				i++
			case header.TCPOptionMSS:
				// kind + length + 2 bytes of MSS.
				if i+4 > limit {
					t.Errorf("MSS option truncated, option is only: %d bytes, want 4", limit-i)
					return
				}
				v := uint16(opts[i+2])<<8 | uint16(opts[i+3])
				if wantOpts.MSS != v {
					t.Errorf("Bad MSS, got = %d, want = %d", v, wantOpts.MSS)
				}
				foundMSS = true
				i += 4
			case header.TCPOptionWS:
				if wantOpts.WS < 0 {
					t.Error("WS present when it shouldn't be")
				}
				// kind + length + 1 byte of shift count.
				if i+3 > limit {
					t.Errorf("WS option truncated, option is only: %d bytes, want 3", limit-i)
					return
				}
				v := int(opts[i+2])
				if v != wantOpts.WS {
					t.Errorf("Bad WS, got = %d, want = %d", v, wantOpts.WS)
				}
				foundWS = true
				i += 3
			case header.TCPOptionTS:
				// kind + length + 4 bytes TSVal + 4 bytes TSEcr.
				if i+10 > limit {
					t.Errorf("TS Option truncated , option is only: %d bytes, want 10", limit-i)
					return
				}
				if opts[i+1] != 10 {
					t.Errorf("Bad length %d for TS option, limit: %d", opts[i+1], limit)
				}
				tsVal = binary.BigEndian.Uint32(opts[i+2:])
				tsEcr = uint32(0)
				if tcp.Flags()&header.TCPFlagAck != 0 {
					// If the syn is an SYN-ACK then read
					// the tsEcr value as well.
					tsEcr = binary.BigEndian.Uint32(opts[i+6:])
				}
				foundTS = true
				i += 10
			case header.TCPOptionSACKPermitted:
				// kind + length only.
				if i+2 > limit {
					t.Errorf("SACKPermitted option truncated, option is only : %d bytes, want 2", limit-i)
					return
				}
				if opts[i+1] != 2 {
					t.Errorf("Bad length %d for SACKPermitted option, limit: %d", opts[i+1], limit)
				}
				foundSACKPermitted = true
				i += 2
			default:
				// Unknown option: skip it using its length byte. A missing
				// length byte, a length below the 2 byte minimum (which would
				// never advance i) or a length running past the end means the
				// options are malformed.
				if i+2 > limit {
					t.Errorf("option %d truncated, length byte missing. Options: %x", opts[i], opts)
					return
				}
				l := int(opts[i+1])
				if l < 2 || i+l > limit {
					t.Errorf("Bad length %d for option %d. Options: %x", l, opts[i], opts)
					return
				}
				i += l
			}
		}

		if !foundMSS {
			t.Errorf("MSS option not found. Options: %x", opts)
		}

		if !foundWS && wantOpts.WS >= 0 {
			t.Errorf("WS option not found. Options: %x", opts)
		}
		if wantOpts.TS && !foundTS {
			t.Errorf("TS option not found. Options: %x", opts)
		}
		if foundTS && tsVal == 0 {
			t.Error("TS option specified but the timestamp value is zero")
		}
		if foundTS && tsEcr == 0 && wantOpts.TSEcr != 0 {
			t.Errorf("TS option specified but TSEcr is incorrect, got = %d, want = %d", tsEcr, wantOpts.TSEcr)
		}
		if wantOpts.SACKPermitted && !foundSACKPermitted {
			t.Errorf("SACKPermitted option not found. Options: %x", opts)
		}
	}
}
// TCPTimestampChecker creates a checker that validates that a TCP segment has a
// TCP Timestamp option if wantTS is true, it also compares the wantTSVal and
// wantTSEcr values with those in the TCP segment (if present).
//
// If wantTSVal or wantTSEcr is zero then the corresponding comparison is
// skipped.
//
// A truncated timestamp option is reported and stops the checker. As before,
// the checker returns without performing its checks when an unrecognized
// option is malformed, and ignores transport headers that are not a
// header.TCP.
func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) TransportChecker {
	return func(t *testing.T, h header.Transport) {
		t.Helper()

		tcp, ok := h.(header.TCP)
		if !ok {
			return
		}
		opts := []byte(tcp.Options())
		limit := len(opts)
		foundTS := false
		tsVal := uint32(0)
		tsEcr := uint32(0)
		for i := 0; i < limit; {
			switch opts[i] {
			case header.TCPOptionEOL:
				i = limit
			case header.TCPOptionNOP:
				i++
			case header.TCPOptionTS:
				// kind + length + 4 bytes TSVal + 4 bytes TSEcr.
				if i+10 > limit {
					t.Errorf("TS option found, but option is truncated, option length: %d, want 10 bytes", limit-i)
					// The values below cannot be read from a truncated option.
					return
				}
				if opts[i+1] != 10 {
					t.Errorf("TS option found, but bad length specified: got = %d, want = 10", opts[i+1])
				}
				tsVal = binary.BigEndian.Uint32(opts[i+2:])
				tsEcr = binary.BigEndian.Uint32(opts[i+6:])
				foundTS = true
				i += 10
			default:
				// We don't recognize this option, just skip over it.
				if i+2 > limit {
					return
				}
				l := int(opts[i+1])
				// The option length (not the offset i) must cover at least
				// the kind and length bytes; anything smaller would never
				// advance i. Testing i here used to abort the checker for any
				// unknown option located in the first two bytes.
				if l < 2 || i+l > limit {
					return
				}
				i += l
			}
		}

		if wantTS != foundTS {
			t.Errorf("TS Option mismatch, got TS= %t, want TS= %t", foundTS, wantTS)
		}
		if wantTS && wantTSVal != 0 && wantTSVal != tsVal {
			t.Errorf("Timestamp value is incorrect, got = %d, want = %d", tsVal, wantTSVal)
		}
		if wantTS && wantTSEcr != 0 && tsEcr != wantTSEcr {
			t.Errorf("Timestamp Echo Reply is incorrect, got = %d, want = %d", tsEcr, wantTSEcr)
		}
	}
}
// TCPNoSACKBlockChecker returns a checker that requires the TCP options of a
// segment to contain no SACK blocks at all. It is TCPSACKBlockChecker with a
// nil list of expected blocks.
func TCPNoSACKBlockChecker() TransportChecker {
	var noBlocks []header.SACKBlock
	return TCPSACKBlockChecker(noBlocks)
}
// TCPSACKBlockChecker creates a checker that verifies that the segment does
// contain the specified SACK blocks in the TCP options.
//
// A malformed SACK option is reported and stops the checker. When an
// unrecognized option is malformed, parsing stops and the blocks collected up
// to that point are compared against sackBlocks. Transport headers that are
// not a header.TCP are ignored.
func TCPSACKBlockChecker(sackBlocks []header.SACKBlock) TransportChecker {
	return func(t *testing.T, h header.Transport) {
		t.Helper()

		tcp, ok := h.(header.TCP)
		if !ok {
			return
		}
		var gotSACKBlocks []header.SACKBlock

		opts := []byte(tcp.Options())
		limit := len(opts)
	parse:
		for i := 0; i < limit; {
			switch opts[i] {
			case header.TCPOptionEOL:
				i = limit
			case header.TCPOptionNOP:
				i++
			case header.TCPOptionSACK:
				if i+2 > limit {
					// Malformed SACK block: the length byte is missing.
					t.Errorf("malformed SACK option in options: %v", opts)
					return
				}
				sackOptionLen := int(opts[i+1])
				// The option holds the kind and length bytes followed by a
				// whole number of 8 byte blocks, all within the options.
				if sackOptionLen < 2 || i+sackOptionLen > limit || (sackOptionLen-2)%8 != 0 {
					// Malformed SACK block. Continuing would read past the
					// option or never advance i.
					t.Errorf("malformed SACK option length in options: %v", opts)
					return
				}
				numBlocks := (sackOptionLen - 2) / 8
				for j := 0; j < numBlocks; j++ {
					start := binary.BigEndian.Uint32(opts[i+2+j*8:])
					end := binary.BigEndian.Uint32(opts[i+2+j*8+4:])
					gotSACKBlocks = append(gotSACKBlocks, header.SACKBlock{
						Start: seqnum.Value(start),
						End:   seqnum.Value(end),
					})
				}
				i += sackOptionLen
			default:
				// We don't recognize this option, just skip over it. A plain
				// break would only leave the switch and spin forever on the
				// same byte, so leave the parsing loop explicitly.
				if i+2 > limit {
					break parse
				}
				l := int(opts[i+1])
				if l < 2 || i+l > limit {
					break parse
				}
				i += l
			}
		}

		if !reflect.DeepEqual(gotSACKBlocks, sackBlocks) {
			t.Errorf("SACKBlocks are not equal, got = %v, want = %v", gotSACKBlocks, sackBlocks)
		}
	}
}
// Payload creates a checker that checks the payload.
//
// A nil payload and an empty payload are considered equal, matching the other
// payload checkers in this package.
func Payload(want []byte) TransportChecker {
	return func(t *testing.T, h header.Transport) {
		t.Helper()

		got := h.Payload()
		// reflect.DeepEqual does not consider nil slices equal to empty
		// slices, but we do.
		if len(got) == 0 && len(want) == 0 {
			return
		}
		if !reflect.DeepEqual(got, want) {
			t.Errorf("Wrong payload, got %v, want %v", got, want)
		}
	}
}
// ICMPv4 returns a checker that requires the innermost network header to
// carry ICMPv4 and then runs the given transport checkers on the message. The
// test is aborted on a protocol mismatch, or once all checkers ran if any of
// them failed.
func ICMPv4(checkers ...TransportChecker) NetworkChecker {
	return func(t *testing.T, h []header.Network) {
		t.Helper()

		inner := h[len(h)-1]
		proto := inner.TransportProtocol()
		if proto != header.ICMPv4ProtocolNumber {
			t.Fatalf("Bad protocol, got %d, want %d", proto, header.ICMPv4ProtocolNumber)
		}

		msg := header.ICMPv4(inner.Payload())
		for _, check := range checkers {
			check(t, msg)
		}
		if t.Failed() {
			t.FailNow()
		}
	}
}
// ICMPv4Type returns a checker that aborts the test unless the transport
// header is a header.ICMPv4 whose Type field equals want.
func ICMPv4Type(want header.ICMPv4Type) TransportChecker {
	return func(t *testing.T, h header.Transport) {
		t.Helper()

		switch icmp := h.(type) {
		case header.ICMPv4:
			if got := icmp.Type(); got != want {
				t.Fatalf("unexpected icmp type, got = %d, want = %d", got, want)
			}
		default:
			t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
		}
	}
}
// ICMPv4Code returns a checker that aborts the test unless the transport
// header is a header.ICMPv4 whose Code field equals want.
func ICMPv4Code(want header.ICMPv4Code) TransportChecker {
	return func(t *testing.T, h header.Transport) {
		t.Helper()

		switch icmp := h.(type) {
		case header.ICMPv4:
			if got := icmp.Code(); got != want {
				t.Fatalf("unexpected ICMP code, got = %d, want = %d", got, want)
			}
		default:
			t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
		}
	}
}
// ICMPv4Ident returns a checker that aborts the test unless the transport
// header is a header.ICMPv4 whose echo Ident equals want.
func ICMPv4Ident(want uint16) TransportChecker {
	return func(t *testing.T, h header.Transport) {
		t.Helper()

		switch icmp := h.(type) {
		case header.ICMPv4:
			if got := icmp.Ident(); got != want {
				t.Fatalf("unexpected ICMP ident, got = %d, want = %d", got, want)
			}
		default:
			t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
		}
	}
}
// ICMPv4Seq returns a checker that aborts the test unless the transport
// header is a header.ICMPv4 whose echo Sequence equals want.
func ICMPv4Seq(want uint16) TransportChecker {
	return func(t *testing.T, h header.Transport) {
		t.Helper()

		switch icmp := h.(type) {
		case header.ICMPv4:
			if got := icmp.Sequence(); got != want {
				t.Fatalf("unexpected ICMP sequence, got = %d, want = %d", got, want)
			}
		default:
			t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
		}
	}
}
// ICMPv4Pointer returns a checker that aborts the test unless the transport
// header is a header.ICMPv4 whose Param Problem pointer equals want.
func ICMPv4Pointer(want uint8) TransportChecker {
	return func(t *testing.T, h header.Transport) {
		t.Helper()

		switch icmp := h.(type) {
		case header.ICMPv4:
			if got := icmp.Pointer(); got != want {
				t.Fatalf("unexpected ICMP Param Problem pointer, got = %d, want = %d", got, want)
			}
		default:
			t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
		}
	}
}
// ICMPv4Checksum returns a checker that recomputes the ICMPv4 checksum over
// the whole header slice and compares it with the stored one.
// This assumes that the payload exactly makes up the rest of the slice.
//
// The checksum field is zeroed while recomputing and restored before the
// comparison. The test is aborted if the transport header is not a
// header.ICMPv4.
func ICMPv4Checksum() TransportChecker {
	return func(t *testing.T, h header.Transport) {
		t.Helper()

		msg, isICMPv4 := h.(header.ICMPv4)
		if !isICMPv4 {
			t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
		}

		stored := msg.Checksum()

		// The checksum is defined over the message with a zero checksum field.
		msg.SetChecksum(0)
		computed := ^header.Checksum(msg, 0)
		msg.SetChecksum(stored)

		if stored == computed {
			return
		}
		t.Errorf("unexpected ICMP checksum, got = %d, want = %d", stored, computed)
	}
}
// ICMPv4Payload returns a checker that compares the payload of an ICMPv4
// packet against want. Nil and empty payloads are treated as equal. The test
// is aborted if the transport header is not a header.ICMPv4.
func ICMPv4Payload(want []byte) TransportChecker {
	return func(t *testing.T, h header.Transport) {
		t.Helper()

		msg, isICMPv4 := h.(header.ICMPv4)
		if !isICMPv4 {
			t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
		}

		got := msg.Payload()

		// cmp.Diff distinguishes nil from empty slices; this checker does not.
		if len(got) == 0 && len(want) == 0 {
			return
		}

		diff := cmp.Diff(want, got)
		if diff == "" {
			return
		}
		t.Errorf("ICMP payload mismatch (-want +got):\n%s", diff)
	}
}
// ICMPv6 returns a checker that requires the innermost network header to
// carry ICMPv6 and then runs the given transport checkers on the message.
//
// The checksum field is validated before any checker is called. The test is
// aborted on a protocol or checksum mismatch, or once all checkers ran if any
// of them failed.
func ICMPv6(checkers ...TransportChecker) NetworkChecker {
	return func(t *testing.T, h []header.Network) {
		t.Helper()

		inner := h[len(h)-1]
		proto := inner.TransportProtocol()
		if proto != header.ICMPv6ProtocolNumber {
			t.Fatalf("Bad protocol, got %d, want %d", proto, header.ICMPv6ProtocolNumber)
		}

		msg := header.ICMPv6(inner.Payload())
		gotXsum := msg.Checksum()
		wantXsum := header.ICMPv6Checksum(msg, inner.SourceAddress(), inner.DestinationAddress(), buffer.VectorisedView{})
		if gotXsum != wantXsum {
			t.Fatalf("Bad ICMPv6 checksum; got %d, want %d", gotXsum, wantXsum)
		}

		for _, check := range checkers {
			check(t, msg)
		}
		if t.Failed() {
			t.FailNow()
		}
	}
}
// ICMPv6Type returns a checker that aborts the test unless the transport
// header is a header.ICMPv6 whose Type field equals want.
func ICMPv6Type(want header.ICMPv6Type) TransportChecker {
	return func(t *testing.T, h header.Transport) {
		t.Helper()

		switch icmp := h.(type) {
		case header.ICMPv6:
			if got := icmp.Type(); got != want {
				t.Fatalf("unexpected icmp type, got = %d, want = %d", got, want)
			}
		default:
			t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h)
		}
	}
}
// ICMPv6Code returns a checker that aborts the test unless the transport
// header is a header.ICMPv6 whose Code field equals want.
func ICMPv6Code(want header.ICMPv6Code) TransportChecker {
	return func(t *testing.T, h header.Transport) {
		t.Helper()

		switch icmp := h.(type) {
		case header.ICMPv6:
			if got := icmp.Code(); got != want {
				t.Fatalf("unexpected ICMP code, got = %d, want = %d", got, want)
			}
		default:
			t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h)
		}
	}
}
// ICMPv6TypeSpecific returns a checker that aborts the test unless the
// transport header is a header.ICMPv6 whose TypeSpecific field equals want.
func ICMPv6TypeSpecific(want uint32) TransportChecker {
	return func(t *testing.T, h header.Transport) {
		t.Helper()

		switch icmp := h.(type) {
		case header.ICMPv6:
			if got := icmp.TypeSpecific(); got != want {
				t.Fatalf("unexpected ICMP TypeSpecific, got = %d, want = %d", got, want)
			}
		default:
			t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h)
		}
	}
}
// ICMPv6Payload returns a checker that compares the payload of an ICMPv6
// packet against want. Nil and empty payloads are treated as equal. The test
// is aborted if the transport header is not a header.ICMPv6.
func ICMPv6Payload(want []byte) TransportChecker {
	return func(t *testing.T, h header.Transport) {
		t.Helper()

		msg, isICMPv6 := h.(header.ICMPv6)
		if !isICMPv6 {
			t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h)
		}

		got := msg.Payload()

		// cmp.Diff distinguishes nil from empty slices; this checker does not.
		if len(got) == 0 && len(want) == 0 {
			return
		}

		diff := cmp.Diff(want, got)
		if diff == "" {
			return
		}
		t.Errorf("ICMP payload mismatch (-want +got):\n%s", diff)
	}
}
// NDP returns a checker that requires the packet to hold an ICMPv6 message of
// type msgType with code 0 whose NDP payload is at least minSize bytes long,
// and then runs the additional checkers on that ICMPv6 message.
//
// Checkers may therefore assume they receive a valid ICMPv6 header carrying an
// NDP message of sufficient size; validating the values inside the message is
// left to them.
func NDP(msgType header.ICMPv6Type, minSize int, checkers ...TransportChecker) NetworkChecker {
	// The generic ICMPv6 validation (protocol, checksum, type and code).
	baseCheck := ICMPv6(
		ICMPv6Type(msgType),
		ICMPv6Code(0))

	return func(t *testing.T, h []header.Network) {
		t.Helper()

		// Check normal ICMPv6 first.
		baseCheck(t, h)

		msg := header.ICMPv6(h[len(h)-1].Payload())
		payloadLen := len(msg.NDPPayload())
		if payloadLen < minSize {
			t.Fatalf("ICMPv6 NDP (type = %d) payload size of %d is less than the minimum size of %d", msgType, payloadLen, minSize)
		}

		for _, check := range checkers {
			check(t, msg)
		}
		if t.Failed() {
			t.FailNow()
		}
	}
}
// NDPNS returns a checker for NDP Neighbor Solicitation messages (as per the
// raw wire format). It is NDP specialised with the Neighbor Solicitation
// message type and minimum size; checkers run on the ICMPv6 message afterwards.
//
// Checkers may assume the message is large enough to be an NDPNS message; the
// values within the message are up to them to validate.
func NDPNS(checkers ...TransportChecker) NetworkChecker {
	msgType := header.ICMPv6NeighborSolicit
	minSize := header.NDPNSMinimumSize
	return NDP(msgType, minSize, checkers...)
}
// NDPNSTargetAddress returns a checker that compares the Target Address field
// of a header.NDPNeighborSolicit against want.
//
// The returned TransportChecker assumes it is handed a header.ICMPv6 holding
// an NDPNS message that is valid as far as its size is concerned.
func NDPNSTargetAddress(want tcpip.Address) TransportChecker {
	return func(t *testing.T, h header.Transport) {
		t.Helper()

		ns := header.NDPNeighborSolicit(h.(header.ICMPv6).NDPPayload())
		got := ns.TargetAddress()
		if got == want {
			return
		}
		t.Errorf("got %T.TargetAddress() = %s, want = %s", ns, got, want)
	}
}
// NDPNA returns a checker for NDP Neighbor Advertisement messages (as per the
// raw wire format). It is NDP specialised with the Neighbor Advertisement
// message type and minimum size; checkers run on the ICMPv6 message afterwards.
//
// Checkers may assume the message is large enough to be an NDPNA message; the
// values within the message are up to them to validate.
func NDPNA(checkers ...TransportChecker) NetworkChecker {
	msgType := header.ICMPv6NeighborAdvert
	minSize := header.NDPNAMinimumSize
	return NDP(msgType, minSize, checkers...)
}
// NDPNATargetAddress returns a checker that compares the Target Address field
// of a header.NDPNeighborAdvert against want.
//
// The returned TransportChecker assumes it is handed a header.ICMPv6 holding
// an NDPNA message that is valid as far as its size is concerned.
func NDPNATargetAddress(want tcpip.Address) TransportChecker {
	return func(t *testing.T, h header.Transport) {
		t.Helper()

		na := header.NDPNeighborAdvert(h.(header.ICMPv6).NDPPayload())
		got := na.TargetAddress()
		if got == want {
			return
		}
		t.Errorf("got %T.TargetAddress() = %s, want = %s", na, got, want)
	}
}
// NDPNASolicitedFlag returns a checker that compares the Solicited flag of a
// header.NDPNeighborAdvert against want.
//
// The returned TransportChecker assumes it is handed a header.ICMPv6 holding
// an NDPNA message that is valid as far as its size is concerned.
func NDPNASolicitedFlag(want bool) TransportChecker {
	return func(t *testing.T, h header.Transport) {
		t.Helper()

		na := header.NDPNeighborAdvert(h.(header.ICMPv6).NDPPayload())
		got := na.SolicitedFlag()
		if got == want {
			return
		}
		t.Errorf("got %T.SolicitedFlag = %t, want = %t", na, got, want)
	}
}
// ndpOptions verifies that optsBuf holds exactly the options listed in opts,
// in the same order.
//
// Only source and target link-layer address options are supported as expected
// options; any other expected option type aborts the test. Options found
// beyond the expected ones, and expected options that are absent, are each
// reported as errors.
func ndpOptions(t *testing.T, optsBuf header.NDPOptions, opts []header.NDPOption) {
	t.Helper()

	iter, err := optsBuf.Iter(true)
	if err != nil {
		t.Errorf("optsBuf.Iter(true): %s", err)
		return
	}

	// next is the index of the expected option the upcoming one is matched
	// against.
	next := 0
	for {
		found, finished, err := iter.Next()
		if err != nil {
			// This should never happen as Iter(true) above did not return an error.
			t.Fatalf("unexpected error when iterating over NDP options: %s", err)
		}
		if finished {
			break
		}

		if next >= len(opts) {
			t.Errorf("got unexpected option: %s", found)
			continue
		}

		switch expected := opts[next].(type) {
		case header.NDPSourceLinkLayerAddressOption:
			actual, ok := found.(header.NDPSourceLinkLayerAddressOption)
			if !ok {
				t.Errorf("got type = %T at index = %d; want = %T", found, next, expected)
				break
			}
			if got, want := actual.EthernetAddress(), expected.EthernetAddress(); got != want {
				t.Errorf("got EthernetAddress() = %s at index %d, want = %s", got, next, want)
			}
		case header.NDPTargetLinkLayerAddressOption:
			actual, ok := found.(header.NDPTargetLinkLayerAddressOption)
			if !ok {
				t.Errorf("got type = %T at index = %d; want = %T", found, next, expected)
				break
			}
			if got, want := actual.EthernetAddress(), expected.EthernetAddress(); got != want {
				t.Errorf("got EthernetAddress() = %s at index %d, want = %s", got, next, want)
			}
		default:
			t.Fatalf("checker not implemented for expected NDP option: %T", expected)
		}

		next++
	}

	// Whatever was expected but never matched is missing from the packet.
	missing := opts[next:]
	if len(missing) > 0 {
		t.Errorf("missing options: %s", missing)
	}
}
// NDPNAOptions returns a checker that requires an NDP Neighbor Advertisement
// message to contain exactly the provided NDP options.
//
// The returned TransportChecker assumes it is handed a header.ICMPv6 holding
// an NDPNA message that is valid as far as its size is concerned.
func NDPNAOptions(opts []header.NDPOption) TransportChecker {
	return func(t *testing.T, h header.Transport) {
		t.Helper()

		na := header.NDPNeighborAdvert(h.(header.ICMPv6).NDPPayload())
		ndpOptions(t, na.Options(), opts)
	}
}
// NDPNSOptions returns a checker that requires an NDP Neighbor Solicitation
// message to contain exactly the provided NDP options.
//
// The returned TransportChecker assumes it is handed a header.ICMPv6 holding
// an NDPNS message that is valid as far as its size is concerned.
func NDPNSOptions(opts []header.NDPOption) TransportChecker {
	return func(t *testing.T, h header.Transport) {
		t.Helper()

		ns := header.NDPNeighborSolicit(h.(header.ICMPv6).NDPPayload())
		ndpOptions(t, ns.Options(), opts)
	}
}
// NDPRS returns a checker for NDP Router Solicitation messages (as per the
// raw wire format). It is NDP specialised with the Router Solicitation message
// type and minimum size; checkers run on the ICMPv6 message afterwards.
//
// Checkers may assume the message is large enough to be an NDPRS message; the
// values within the message are up to them to validate.
func NDPRS(checkers ...TransportChecker) NetworkChecker {
	msgType := header.ICMPv6RouterSolicit
	minSize := header.NDPRSMinimumSize
	return NDP(msgType, minSize, checkers...)
}
// NDPRSOptions returns a checker that requires an NDP Router Solicitation
// message to contain exactly the provided NDP options.
//
// The returned TransportChecker assumes it is handed a header.ICMPv6 holding
// an NDPRS message that is valid as far as its size is concerned.
func NDPRSOptions(opts []header.NDPOption) TransportChecker {
	return func(t *testing.T, h header.Transport) {
		t.Helper()

		rs := header.NDPRouterSolicit(h.(header.ICMPv6).NDPPayload())
		ndpOptions(t, rs.Options(), opts)
	}
}
// IGMP returns a checker that requires the innermost network header to carry
// IGMP and then runs the given IGMP transport checkers on the message. The
// test is aborted on a protocol mismatch, or once all checkers ran if any of
// them failed.
func IGMP(checkers ...TransportChecker) NetworkChecker {
	return func(t *testing.T, h []header.Network) {
		t.Helper()

		inner := h[len(h)-1]
		proto := inner.TransportProtocol()
		if proto != header.IGMPProtocolNumber {
			t.Fatalf("Bad protocol, got %d, want %d", proto, header.IGMPProtocolNumber)
		}

		msg := header.IGMP(inner.Payload())
		for _, check := range checkers {
			check(t, msg)
		}
		if t.Failed() {
			t.FailNow()
		}
	}
}
// IGMPType returns a checker that reports an error unless the IGMP Type field
// equals want. The test is aborted if the transport header is not a
// header.IGMP.
func IGMPType(want header.IGMPType) TransportChecker {
	return func(t *testing.T, h header.Transport) {
		t.Helper()

		switch igmp := h.(type) {
		case header.IGMP:
			if got := igmp.Type(); got != want {
				t.Errorf("got igmp.Type() = %d, want = %d", got, want)
			}
		default:
			t.Fatalf("got transport header = %T, want = header.IGMP", h)
		}
	}
}
// IGMPMaxRespTime returns a checker that reports an error unless the IGMP Max
// Resp Time field equals want. The test is aborted if the transport header is
// not a header.IGMP.
func IGMPMaxRespTime(want time.Duration) TransportChecker {
	return func(t *testing.T, h header.Transport) {
		t.Helper()

		switch igmp := h.(type) {
		case header.IGMP:
			if got := igmp.MaxRespTime(); got != want {
				t.Errorf("got igmp.MaxRespTime() = %s, want = %s", got, want)
			}
		default:
			t.Fatalf("got transport header = %T, want = header.IGMP", h)
		}
	}
}
// IGMPGroupAddress returns a checker that reports an error unless the IGMP
// Group Address field equals want. The test is aborted if the transport
// header is not a header.IGMP.
func IGMPGroupAddress(want tcpip.Address) TransportChecker {
	return func(t *testing.T, h header.Transport) {
		t.Helper()

		switch igmp := h.(type) {
		case header.IGMP:
			if got := igmp.GroupAddress(); got != want {
				t.Errorf("got igmp.GroupAddress() = %s, want = %s", got, want)
			}
		default:
			t.Fatalf("got transport header = %T, want = header.IGMP", h)
		}
	}
}