Add IPv4 to bind_to_device distribution test

PiperOrigin-RevId: 303156734
This commit is contained in:
Jay Zhuang 2020-03-26 11:28:05 -07:00 committed by gVisor bot
parent bc3def43c3
commit d5ef8091b4
1 changed files with 86 additions and 21 deletions

View File

@ -31,12 +31,14 @@ import (
)
const (
stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
testSrcAddrV6 = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
testDstAddrV6 = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
stackAddr = "\x0a\x00\x00\x01"
stackPort = 1234
testPort = 4096
testSrcAddrV4 = "\x0a\x00\x00\x01"
testDstAddrV4 = "\x0a\x00\x00\x02"
testDstPort = 1234
testSrcPort = 4096
)
type testContext struct {
@ -59,11 +61,11 @@ func newDualTestContextMultiNIC(t *testing.T, mtu uint32, linkEpIDs []tcpip.NICI
}
linkEps[linkEpID] = channelEp
if err := s.AddAddress(linkEpID, ipv4.ProtocolNumber, stackAddr); err != nil {
if err := s.AddAddress(linkEpID, ipv4.ProtocolNumber, testDstAddrV4); err != nil {
t.Fatalf("AddAddress IPv4 failed: %s", err)
}
if err := s.AddAddress(linkEpID, ipv6.ProtocolNumber, stackV6Addr); err != nil {
if err := s.AddAddress(linkEpID, ipv6.ProtocolNumber, testDstAddrV6); err != nil {
t.Fatalf("AddAddress IPv6 failed: %s", err)
}
}
@ -91,6 +93,47 @@ func newPayload() []byte {
return b
}
func (c *testContext) sendV4Packet(payload []byte, h *headers, linkEpID tcpip.NICID) {
buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload))
payloadStart := len(buf) - len(payload)
copy(buf[payloadStart:], payload)
// Initialize the IP header.
ip := header.IPv4(buf)
ip.Encode(&header.IPv4Fields{
IHL: header.IPv4MinimumSize,
TOS: 0x80,
TotalLength: uint16(len(buf)),
TTL: 65,
Protocol: uint8(udp.ProtocolNumber),
SrcAddr: testSrcAddrV4,
DstAddr: testDstAddrV4,
})
ip.SetChecksum(^ip.CalculateChecksum())
// Initialize the UDP header.
u := header.UDP(buf[header.IPv4MinimumSize:])
u.Encode(&header.UDPFields{
SrcPort: h.srcPort,
DstPort: h.dstPort,
Length: uint16(header.UDPMinimumSize + len(payload)),
})
// Calculate the UDP pseudo-header checksum.
xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testSrcAddrV4, testDstAddrV4, uint16(len(u)))
// Calculate the UDP checksum and set it.
xsum = header.Checksum(payload, xsum)
u.SetChecksum(^u.CalculateChecksum(xsum))
// Inject packet.
c.linkEps[linkEpID].InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{
Data: buf.ToVectorisedView(),
NetworkHeader: buffer.View(ip),
TransportHeader: buffer.View(u),
})
}
func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NICID) {
// Allocate a buffer for data and headers.
buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload))
@ -102,8 +145,8 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NI
PayloadLength: uint16(header.UDPMinimumSize + len(payload)),
NextHeader: uint8(udp.ProtocolNumber),
HopLimit: 65,
SrcAddr: testV6Addr,
DstAddr: stackV6Addr,
SrcAddr: testSrcAddrV6,
DstAddr: testDstAddrV6,
})
// Initialize the UDP header.
@ -115,7 +158,7 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NI
})
// Calculate the UDP pseudo-header checksum.
xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testV6Addr, stackV6Addr, uint16(len(u)))
xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testSrcAddrV6, testDstAddrV6, uint16(len(u)))
// Calculate the UDP checksum and set it.
xsum = header.Checksum(payload, xsum)
@ -123,7 +166,9 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NI
// Inject packet.
c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, stack.PacketBuffer{
Data: buf.ToVectorisedView(),
Data: buf.ToVectorisedView(),
NetworkHeader: buffer.View(ip),
TransportHeader: buffer.View(u),
})
}
@ -227,9 +272,12 @@ func TestBindToDeviceDistribution(t *testing.T) {
},
},
} {
t.Run(test.name, func(t *testing.T) {
for protoName, netProtoNum := range map[string]tcpip.NetworkProtocolNumber{
"IPv4": ipv4.ProtocolNumber,
"IPv6": ipv6.ProtocolNumber,
} {
for device, wantDistribution := range test.wantDistributions {
t.Run(string(device), func(t *testing.T) {
t.Run(test.name+protoName+string(device), func(t *testing.T) {
var devices []tcpip.NICID
for d := range test.wantDistributions {
devices = append(devices, d)
@ -248,7 +296,7 @@ func TestBindToDeviceDistribution(t *testing.T) {
defer close(ch)
var err *tcpip.Error
ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &wq)
ep, err := c.s.NewEndpoint(udp.ProtocolNumber, netProtoNum, &wq)
if err != nil {
t.Fatalf("NewEndpoint failed: %s", err)
}
@ -269,7 +317,17 @@ func TestBindToDeviceDistribution(t *testing.T) {
if err := ep.SetSockOpt(bindToDeviceOption); err != nil {
t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %s", bindToDeviceOption, i, err)
}
if err := ep.Bind(tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}); err != nil {
var dstAddr tcpip.Address
switch netProtoNum {
case ipv4.ProtocolNumber:
dstAddr = testDstAddrV4
case ipv6.ProtocolNumber:
dstAddr = testDstAddrV6
default:
t.Fatalf("unexpected protocol number: %d", netProtoNum)
}
if err := ep.Bind(tcpip.FullAddress{Addr: dstAddr, Port: testDstPort}); err != nil {
t.Fatalf("ep.Bind(...) on endpoint %d failed: %s", i, err)
}
}
@ -285,11 +343,18 @@ func TestBindToDeviceDistribution(t *testing.T) {
// Send a packet.
port := uint16(i % nports)
payload := newPayload()
c.sendV6Packet(payload,
&headers{
srcPort: testPort + port,
dstPort: stackPort},
device)
hdrs := &headers{
srcPort: testSrcPort + port,
dstPort: testDstPort,
}
switch netProtoNum {
case ipv4.ProtocolNumber:
c.sendV4Packet(payload, hdrs, device)
case ipv6.ProtocolNumber:
c.sendV6Packet(payload, hdrs, device)
default:
t.Fatalf("unexpected protocol number: %d", netProtoNum)
}
ep := <-pollChannel
if _, _, err := ep.Read(nil); err != nil {
@ -320,6 +385,6 @@ func TestBindToDeviceDistribution(t *testing.T) {
}
})
}
})
}
}
}