gvisor/pkg/tcpip/stack/transport_demuxer_test.go

386 lines
12 KiB
Go
Raw Normal View History

// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package stack_test
import (
"io/ioutil"
"math"
"math/rand"
"testing"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/ports"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
)
const (
testSrcAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
testDstAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
testSrcAddrV4 = "\x0a\x00\x00\x01"
testDstAddrV4 = "\x0a\x00\x00\x02"
testDstPort = 1234
testSrcPort = 4096
)
type testContext struct {
linkEps map[tcpip.NICID]*channel.Endpoint
s *stack.Stack
wq waiter.Queue
}
// newDualTestContextMultiNIC creates the testing context and also linkEpIDs NICs.
func newDualTestContextMultiNIC(t *testing.T, mtu uint32, linkEpIDs []tcpip.NICID) *testContext {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
})
linkEps := make(map[tcpip.NICID]*channel.Endpoint)
for _, linkEpID := range linkEpIDs {
channelEp := channel.New(256, mtu, "")
if err := s.CreateNIC(linkEpID, channelEp); err != nil {
t.Fatalf("CreateNIC failed: %s", err)
}
linkEps[linkEpID] = channelEp
if err := s.AddAddress(linkEpID, ipv4.ProtocolNumber, testDstAddrV4); err != nil {
t.Fatalf("AddAddress IPv4 failed: %s", err)
}
if err := s.AddAddress(linkEpID, ipv6.ProtocolNumber, testDstAddrV6); err != nil {
t.Fatalf("AddAddress IPv6 failed: %s", err)
}
}
s.SetRouteTable([]tcpip.Route{
{Destination: header.IPv4EmptySubnet, NIC: 1},
{Destination: header.IPv6EmptySubnet, NIC: 1},
})
return &testContext{
s: s,
linkEps: linkEps,
}
}
type headers struct {
srcPort, dstPort uint16
}
func newPayload() []byte {
b := make([]byte, 30+rand.Intn(100))
for i := range b {
b[i] = byte(rand.Intn(256))
}
return b
}
func (c *testContext) sendV4Packet(payload []byte, h *headers, linkEpID tcpip.NICID) {
buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload))
payloadStart := len(buf) - len(payload)
copy(buf[payloadStart:], payload)
// Initialize the IP header.
ip := header.IPv4(buf)
ip.Encode(&header.IPv4Fields{
TOS: 0x80,
TotalLength: uint16(len(buf)),
TTL: 65,
Protocol: uint8(udp.ProtocolNumber),
SrcAddr: testSrcAddrV4,
DstAddr: testDstAddrV4,
})
ip.SetChecksum(^ip.CalculateChecksum())
// Initialize the UDP header.
u := header.UDP(buf[header.IPv4MinimumSize:])
u.Encode(&header.UDPFields{
SrcPort: h.srcPort,
DstPort: h.dstPort,
Length: uint16(header.UDPMinimumSize + len(payload)),
})
// Calculate the UDP pseudo-header checksum.
xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testSrcAddrV4, testDstAddrV4, uint16(len(u)))
// Calculate the UDP checksum and set it.
xsum = header.Checksum(payload, xsum)
u.SetChecksum(^u.CalculateChecksum(xsum))
// Inject packet.
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buf.ToVectorisedView(),
})
c.linkEps[linkEpID].InjectInbound(ipv4.ProtocolNumber, pkt)
}
func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NICID) {
// Allocate a buffer for data and headers.
buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload))
copy(buf[len(buf)-len(payload):], payload)
// Initialize the IP header.
ip := header.IPv6(buf)
ip.Encode(&header.IPv6Fields{
PayloadLength: uint16(header.UDPMinimumSize + len(payload)),
TransportProtocol: udp.ProtocolNumber,
HopLimit: 65,
SrcAddr: testSrcAddrV6,
DstAddr: testDstAddrV6,
})
// Initialize the UDP header.
u := header.UDP(buf[header.IPv6MinimumSize:])
u.Encode(&header.UDPFields{
SrcPort: h.srcPort,
DstPort: h.dstPort,
Length: uint16(header.UDPMinimumSize + len(payload)),
})
// Calculate the UDP pseudo-header checksum.
xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testSrcAddrV6, testDstAddrV6, uint16(len(u)))
// Calculate the UDP checksum and set it.
xsum = header.Checksum(payload, xsum)
u.SetChecksum(^u.CalculateChecksum(xsum))
// Inject packet.
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buf.ToVectorisedView(),
Use PacketBuffers, rather than VectorisedViews, in netstack. PacketBuffers are analogous to Linux's sk_buff. They hold all information about a packet, headers, and payload. This is important for: * iptables to access various headers of packets * Preventing the clutter of passing different net and link headers along with VectorisedViews to packet handling functions. This change only affects the incoming packet path, and a future change will change the outgoing path. Benchmark Regular PacketBufferPtr PacketBufferConcrete -------------------------------------------------------------------------------- BM_Recvmsg 400.715MB/s 373.676MB/s 396.276MB/s BM_Sendmsg 361.832MB/s 333.003MB/s 335.571MB/s BM_Recvfrom 453.336MB/s 393.321MB/s 381.650MB/s BM_Sendto 378.052MB/s 372.134MB/s 341.342MB/s BM_SendmsgTCP/0/1k 353.711MB/s 316.216MB/s 322.747MB/s BM_SendmsgTCP/0/2k 600.681MB/s 588.776MB/s 565.050MB/s BM_SendmsgTCP/0/4k 995.301MB/s 888.808MB/s 941.888MB/s BM_SendmsgTCP/0/8k 1.517GB/s 1.274GB/s 1.345GB/s BM_SendmsgTCP/0/16k 1.872GB/s 1.586GB/s 1.698GB/s BM_SendmsgTCP/0/32k 1.017GB/s 1.020GB/s 1.133GB/s BM_SendmsgTCP/0/64k 475.626MB/s 584.587MB/s 627.027MB/s BM_SendmsgTCP/0/128k 416.371MB/s 503.434MB/s 409.850MB/s BM_SendmsgTCP/0/256k 323.449MB/s 449.599MB/s 388.852MB/s BM_SendmsgTCP/0/512k 243.992MB/s 267.676MB/s 314.474MB/s BM_SendmsgTCP/0/1M 95.138MB/s 95.874MB/s 95.417MB/s BM_SendmsgTCP/0/2M 96.261MB/s 94.977MB/s 96.005MB/s BM_SendmsgTCP/0/4M 96.512MB/s 95.978MB/s 95.370MB/s BM_SendmsgTCP/0/8M 95.603MB/s 95.541MB/s 94.935MB/s BM_SendmsgTCP/0/16M 94.598MB/s 94.696MB/s 94.521MB/s BM_SendmsgTCP/0/32M 94.006MB/s 94.671MB/s 94.768MB/s BM_SendmsgTCP/0/64M 94.133MB/s 94.333MB/s 94.746MB/s BM_SendmsgTCP/0/128M 93.615MB/s 93.497MB/s 93.573MB/s BM_SendmsgTCP/0/256M 93.241MB/s 95.100MB/s 93.272MB/s BM_SendmsgTCP/1/1k 303.644MB/s 316.074MB/s 308.430MB/s BM_SendmsgTCP/1/2k 537.093MB/s 584.962MB/s 529.020MB/s BM_SendmsgTCP/1/4k 882.362MB/s 939.087MB/s 892.285MB/s BM_SendmsgTCP/1/8k 1.272GB/s 1.394GB/s 1.296GB/s BM_SendmsgTCP/1/16k 1.802GB/s 2.019GB/s 1.830GB/s BM_SendmsgTCP/1/32k 2.084GB/s 2.173GB/s 2.156GB/s BM_SendmsgTCP/1/64k 2.515GB/s 2.463GB/s 2.473GB/s BM_SendmsgTCP/1/128k 2.811GB/s 3.004GB/s 2.946GB/s BM_SendmsgTCP/1/256k 3.008GB/s 3.159GB/s 3.171GB/s BM_SendmsgTCP/1/512k 2.980GB/s 3.150GB/s 3.126GB/s BM_SendmsgTCP/1/1M 2.165GB/s 2.233GB/s 2.163GB/s BM_SendmsgTCP/1/2M 2.370GB/s 2.219GB/s 2.453GB/s BM_SendmsgTCP/1/4M 2.005GB/s 2.091GB/s 2.214GB/s BM_SendmsgTCP/1/8M 2.111GB/s 2.013GB/s 2.109GB/s BM_SendmsgTCP/1/16M 1.902GB/s 1.868GB/s 1.897GB/s BM_SendmsgTCP/1/32M 1.655GB/s 1.665GB/s 1.635GB/s BM_SendmsgTCP/1/64M 1.575GB/s 1.547GB/s 1.575GB/s BM_SendmsgTCP/1/128M 1.524GB/s 1.584GB/s 1.580GB/s BM_SendmsgTCP/1/256M 1.579GB/s 1.607GB/s 1.593GB/s PiperOrigin-RevId: 278940079
2019-11-06 22:24:38 +00:00
})
c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, pkt)
}
func TestTransportDemuxerRegister(t *testing.T) {
for _, test := range []struct {
name string
proto tcpip.NetworkProtocolNumber
want tcpip.Error
}{
{"failure", ipv6.ProtocolNumber, &tcpip.ErrUnknownProtocol{}},
{"success", ipv4.ProtocolNumber, nil},
} {
t.Run(test.name, func(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
})
var wq waiter.Queue
ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
if err != nil {
t.Fatal(err)
}
tEP, ok := ep.(stack.TransportEndpoint)
if !ok {
t.Fatalf("%T does not implement stack.TransportEndpoint", ep)
}
if got, want := s.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, tEP, ports.Flags{}, 0), test.want; got != want {
t.Fatalf("s.RegisterTransportEndpoint(...) = %s, want %s", got, want)
}
})
}
}
// TestBindToDeviceDistribution injects varied packets on input devices and checks that
// the distribution of packets received matches expectations.
func TestBindToDeviceDistribution(t *testing.T) {
type endpointSockopts struct {
reuse bool
bindToDevice tcpip.NICID
}
for _, test := range []struct {
name string
// endpoints will received the inject packets.
endpoints []endpointSockopts
// wantDistributions is the want ratio of packets received on each
// endpoint for each NIC on which packets are injected.
wantDistributions map[tcpip.NICID][]float64
}{
{
"BindPortReuse",
// 5 endpoints that all have reuse set.
[]endpointSockopts{
{reuse: true, bindToDevice: 0},
{reuse: true, bindToDevice: 0},
{reuse: true, bindToDevice: 0},
{reuse: true, bindToDevice: 0},
{reuse: true, bindToDevice: 0},
},
map[tcpip.NICID][]float64{
// Injected packets on dev0 get distributed evenly.
1: {0.2, 0.2, 0.2, 0.2, 0.2},
},
},
{
"BindToDevice",
// 3 endpoints with various bindings.
[]endpointSockopts{
{reuse: false, bindToDevice: 1},
{reuse: false, bindToDevice: 2},
{reuse: false, bindToDevice: 3},
},
map[tcpip.NICID][]float64{
// Injected packets on dev0 go only to the endpoint bound to dev0.
1: {1, 0, 0},
// Injected packets on dev1 go only to the endpoint bound to dev1.
2: {0, 1, 0},
// Injected packets on dev2 go only to the endpoint bound to dev2.
3: {0, 0, 1},
},
},
{
"ReuseAndBindToDevice",
// 6 endpoints with various bindings.
[]endpointSockopts{
{reuse: true, bindToDevice: 1},
{reuse: true, bindToDevice: 1},
{reuse: true, bindToDevice: 2},
{reuse: true, bindToDevice: 2},
{reuse: true, bindToDevice: 2},
{reuse: true, bindToDevice: 0},
},
map[tcpip.NICID][]float64{
// Injected packets on dev0 get distributed among endpoints bound to
// dev0.
1: {0.5, 0.5, 0, 0, 0, 0},
// Injected packets on dev1 get distributed among endpoints bound to
// dev1 or unbound.
2: {0, 0, 1. / 3, 1. / 3, 1. / 3, 0},
// Injected packets on dev999 go only to the unbound.
1000: {0, 0, 0, 0, 0, 1},
},
},
} {
for protoName, netProtoNum := range map[string]tcpip.NetworkProtocolNumber{
"IPv4": ipv4.ProtocolNumber,
"IPv6": ipv6.ProtocolNumber,
} {
for device, wantDistribution := range test.wantDistributions {
t.Run(test.name+protoName+string(device), func(t *testing.T) {
var devices []tcpip.NICID
for d := range test.wantDistributions {
devices = append(devices, d)
}
c := newDualTestContextMultiNIC(t, defaultMTU, devices)
eps := make(map[tcpip.Endpoint]int)
pollChannel := make(chan tcpip.Endpoint)
for i, endpoint := range test.endpoints {
// Try to receive the data.
wq := waiter.Queue{}
we, ch := waiter.NewChannelEntry(nil)
wq.EventRegister(&we, waiter.ReadableEvents)
defer wq.EventUnregister(&we)
defer close(ch)
var err tcpip.Error
ep, err := c.s.NewEndpoint(udp.ProtocolNumber, netProtoNum, &wq)
if err != nil {
t.Fatalf("NewEndpoint failed: %s", err)
}
eps[ep] = i
go func(ep tcpip.Endpoint) {
for range ch {
pollChannel <- ep
}
}(ep)
defer ep.Close()
ep.SocketOptions().SetReusePort(endpoint.reuse)
if err := ep.SocketOptions().SetBindToDevice(int32(endpoint.bindToDevice)); err != nil {
t.Fatalf("SetSockOpt(&%T(%d)) on endpoint %d failed: %s", endpoint.bindToDevice, endpoint.bindToDevice, i, err)
}
var dstAddr tcpip.Address
switch netProtoNum {
case ipv4.ProtocolNumber:
dstAddr = testDstAddrV4
case ipv6.ProtocolNumber:
dstAddr = testDstAddrV6
default:
t.Fatalf("unexpected protocol number: %d", netProtoNum)
}
if err := ep.Bind(tcpip.FullAddress{Addr: dstAddr, Port: testDstPort}); err != nil {
t.Fatalf("ep.Bind(...) on endpoint %d failed: %s", i, err)
}
}
npackets := 100000
nports := 10000
if got, want := len(test.endpoints), len(wantDistribution); got != want {
t.Fatalf("got len(test.endpoints) = %d, want %d", got, want)
}
ports := make(map[uint16]tcpip.Endpoint)
stats := make(map[tcpip.Endpoint]int)
for i := 0; i < npackets; i++ {
// Send a packet.
port := uint16(i % nports)
payload := newPayload()
hdrs := &headers{
srcPort: testSrcPort + port,
dstPort: testDstPort,
}
switch netProtoNum {
case ipv4.ProtocolNumber:
c.sendV4Packet(payload, hdrs, device)
case ipv6.ProtocolNumber:
c.sendV6Packet(payload, hdrs, device)
default:
t.Fatalf("unexpected protocol number: %d", netProtoNum)
}
ep := <-pollChannel
if _, err := ep.Read(ioutil.Discard, tcpip.ReadOptions{}); err != nil {
t.Fatalf("Read on endpoint %d failed: %s", eps[ep], err)
}
stats[ep]++
if i < nports {
ports[uint16(i)] = ep
} else {
// Check that all packets from one client are handled by the same
// socket.
if want, got := ports[port], ep; want != got {
t.Fatalf("Packet sent on port %d expected on endpoint %d but received on endpoint %d", port, eps[want], eps[got])
}
}
}
// Check that a packet distribution is as expected.
for ep, i := range eps {
wantRatio := wantDistribution[i]
wantRecv := wantRatio * float64(npackets)
actualRecv := stats[ep]
actualRatio := float64(stats[ep]) / float64(npackets)
// The deviation is less than 10%.
if math.Abs(actualRatio-wantRatio) > 0.05 {
t.Errorf("want about %.0f%% (%.0f of %d) packets to arrive on endpoint %d, got %.0f%% (%d of %d)", wantRatio*100, wantRecv, npackets, i, actualRatio*100, actualRecv, npackets)
}
}
})
}
}
}
}