Honour the link's MaxHeaderLength when forwarding

LinkEndpoints may expect/assume that the a tcpip.PacketBuffer's Header
has enough capacity for its own headers, as per documentation for
LinkEndpoint.MaxHeaderLength.

Test: stack_test.TestNICForwarding
PiperOrigin-RevId: 300784192
This commit is contained in:
Ghanan Gowripalan 2020-03-13 10:43:09 -07:00 committed by gVisor bot
parent 8f8f16efaf
commit 28d26d2c4f
2 changed files with 87 additions and 43 deletions

View File

@ -15,6 +15,7 @@
package stack
import (
"fmt"
"log"
"reflect"
"sort"
@ -1259,9 +1260,24 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link
func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) {
// TODO(b/143425874) Decrease the TTL field in forwarded packets.
pkt.Header = buffer.NewPrependableFromView(pkt.Data.First())
firstData := pkt.Data.First()
pkt.Data.RemoveFirst()
if linkHeaderLen := int(n.linkEP.MaxHeaderLength()); linkHeaderLen == 0 {
pkt.Header = buffer.NewPrependableFromView(firstData)
} else {
firstDataLen := len(firstData)
// pkt.Header should have enough capacity to hold n.linkEP's headers.
pkt.Header = buffer.NewPrependable(firstDataLen + linkHeaderLen)
// TODO(b/151227689): avoid copying the packet when forwarding
if n := copy(pkt.Header.Prepend(firstDataLen), firstData); n != firstDataLen {
panic(fmt.Sprintf("copied %d bytes, expected %d", n, firstDataLen))
}
}
if err := n.linkEP.WritePacket(r, nil /* gso */, protocol, pkt); err != nil {
r.Stats().IP.OutgoingPacketErrors.Increment()
return

View File

@ -2240,56 +2240,84 @@ func TestNICStats(t *testing.T) {
}
func TestNICForwarding(t *testing.T) {
// Create a stack with the fake network protocol, two NICs, each with
// an address.
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
})
s.SetForwarding(true)
const nicID1 = 1
const nicID2 = 2
const dstAddr = tcpip.Address("\x03")
ep1 := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, ep1); err != nil {
t.Fatal("CreateNIC #1 failed:", err)
}
if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
t.Fatal("AddAddress #1 failed:", err)
tests := []struct {
name string
headerLen uint16
}{
{
name: "Zero header length",
},
{
name: "Non-zero header length",
headerLen: 16,
},
}
ep2 := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(2, ep2); err != nil {
t.Fatal("CreateNIC #2 failed:", err)
}
if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil {
t.Fatal("AddAddress #2 failed:", err)
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
})
s.SetForwarding(true)
// Route all packets to address 3 to NIC 2.
{
subnet, err := tcpip.NewSubnet("\x03", "\xff")
if err != nil {
t.Fatal(err)
}
s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 2}})
}
ep1 := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(nicID1, ep1); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
}
if err := s.AddAddress(nicID1, fakeNetNumber, "\x01"); err != nil {
t.Fatalf("AddAddress(%d, %d, 0x01): %s", nicID1, fakeNetNumber, err)
}
// Send a packet to address 3.
buf := buffer.NewView(30)
buf[0] = 3
ep1.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
Data: buf.ToVectorisedView(),
})
ep2 := channelLinkWithHeaderLength{
Endpoint: channel.New(10, defaultMTU, ""),
headerLength: test.headerLen,
}
if err := s.CreateNIC(nicID2, &ep2); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
}
if err := s.AddAddress(nicID2, fakeNetNumber, "\x02"); err != nil {
t.Fatalf("AddAddress(%d, %d, 0x02): %s", nicID2, fakeNetNumber, err)
}
if _, ok := ep2.Read(); !ok {
t.Fatal("Packet not forwarded")
}
// Route all packets to dstAddr to NIC 2.
{
subnet, err := tcpip.NewSubnet(dstAddr, "\xff")
if err != nil {
t.Fatal(err)
}
s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: nicID2}})
}
// Test that forwarding increments Tx stats correctly.
if got, want := s.NICInfo()[2].Stats.Tx.Packets.Value(), uint64(1); got != want {
t.Errorf("got Tx.Packets.Value() = %d, want = %d", got, want)
}
// Send a packet to dstAddr.
buf := buffer.NewView(30)
buf[0] = dstAddr[0]
ep1.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if got, want := s.NICInfo()[2].Stats.Tx.Bytes.Value(), uint64(len(buf)); got != want {
t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want)
pkt, ok := ep2.Read()
if !ok {
t.Fatal("packet not forwarded")
}
// Test that the link's MaxHeaderLength is honoured.
if capacity, want := pkt.Pkt.Header.AvailableLength(), int(test.headerLen); capacity != want {
t.Errorf("got Header.AvailableLength() = %d, want = %d", capacity, want)
}
// Test that forwarding increments Tx stats correctly.
if got, want := s.NICInfo()[nicID2].Stats.Tx.Packets.Value(), uint64(1); got != want {
t.Errorf("got Tx.Packets.Value() = %d, want = %d", got, want)
}
if got, want := s.NICInfo()[nicID2].Stats.Tx.Bytes.Value(), uint64(len(buf)); got != want {
t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want)
}
})
}
}