The packet forwarding should resolve the link address if necessary.

Fixes #1510

Test:
- stack_test.TestForwardingWithStaticResolver
- stack_test.TestForwardingWithFakeResolver
- stack_test.TestForwardingWithNoResolver
- stack_test.TestForwardingWithFakeResolverPartialTimeout
- stack_test.TestForwardingWithFakeResolverTwoPackets
- stack_test.TestForwardingWithFakeResolverManyPackets
- stack_test.TestForwardingWithFakeResolverManyResolutions
PiperOrigin-RevId: 300182570
This commit is contained in:
gVisor bot 2020-03-10 14:49:16 -07:00
parent 0990ef7517
commit d6440ec5a1
5 changed files with 809 additions and 17 deletions

View File

@ -19,6 +19,7 @@ go_library(
name = "stack",
srcs = [
"dhcpv6configurationfromndpra_string.go",
"forwarder.go",
"icmp_rate_limit.go",
"linkaddrcache.go",
"linkaddrentry_list.go",
@ -80,6 +81,7 @@ go_test(
name = "stack_test",
size = "small",
srcs = [
"forwarder_test.go",
"linkaddrcache_test.go",
"nic_test.go",
],

View File

@ -0,0 +1,131 @@
// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package stack
import (
"fmt"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
)
const (
// maxPendingResolutions is the maximum number of pending link-address
// resolutions.
maxPendingResolutions = 64
maxPendingPacketsPerResolution = 256
)
type pendingPacket struct {
nic *NIC
route *Route
proto tcpip.NetworkProtocolNumber
pkt tcpip.PacketBuffer
}
type forwardQueue struct {
sync.Mutex
// The packets to send once the resolver completes.
packets map[<-chan struct{}][]*pendingPacket
// FIFO of channels used to cancel the oldest goroutine waiting for
// link-address resolution.
cancelChans []chan struct{}
}
func newForwardQueue() *forwardQueue {
return &forwardQueue{packets: make(map[<-chan struct{}][]*pendingPacket)}
}
func (f *forwardQueue) enqueue(ch <-chan struct{}, n *NIC, r *Route, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) {
shouldWait := false
f.Lock()
packets, ok := f.packets[ch]
if !ok {
shouldWait = true
}
for len(packets) == maxPendingPacketsPerResolution {
p := packets[0]
packets = packets[1:]
p.nic.stack.stats.IP.OutgoingPacketErrors.Increment()
p.route.Release()
}
if l := len(packets); l >= maxPendingPacketsPerResolution {
panic(fmt.Sprintf("max pending packets for resolution reached; got %d packets, max = %d", l, maxPendingPacketsPerResolution))
}
f.packets[ch] = append(packets, &pendingPacket{
nic: n,
route: r,
proto: protocol,
pkt: pkt,
})
f.Unlock()
if !shouldWait {
return
}
// Wait for the link-address resolution to complete.
// Start a goroutine with a forwarding-cancel channel so that we can
// limit the maximum number of goroutines running concurrently.
cancel := f.newCancelChannel()
go func() {
cancelled := false
select {
case <-ch:
case <-cancel:
cancelled = true
}
f.Lock()
packets := f.packets[ch]
delete(f.packets, ch)
f.Unlock()
for _, p := range packets {
if cancelled {
p.nic.stack.stats.IP.OutgoingPacketErrors.Increment()
} else if _, err := p.route.Resolve(nil); err != nil {
p.nic.stack.stats.IP.OutgoingPacketErrors.Increment()
} else {
p.nic.forwardPacket(p.route, p.proto, p.pkt)
}
p.route.Release()
}
}()
}
// newCancelChannel creates a channel that can cancel a pending forwarding
// activity. The oldest channel is closed if the number of open channels would
// exceed maxPendingResolutions.
func (f *forwardQueue) newCancelChannel() chan struct{} {
f.Lock()
defer f.Unlock()
if len(f.cancelChans) == maxPendingResolutions {
ch := f.cancelChans[0]
f.cancelChans = f.cancelChans[1:]
close(ch)
}
if l := len(f.cancelChans); l >= maxPendingResolutions {
panic(fmt.Sprintf("max pending resolutions reached; got %d active resolutions, max = %d", l, maxPendingResolutions))
}
ch := make(chan struct{})
f.cancelChans = append(f.cancelChans, ch)
return ch
}

View File

@ -0,0 +1,635 @@
// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package stack
import (
"encoding/binary"
"math"
"testing"
"time"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
)
const (
fwdTestNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32
fwdTestNetHeaderLen = 12
fwdTestNetDefaultPrefixLen = 8
// fwdTestNetDefaultMTU is the MTU, in bytes, used throughout the tests,
// except where another value is explicitly used. It is chosen to match
// the MTU of loopback interfaces on linux systems.
fwdTestNetDefaultMTU = 65536
)
// fwdTestNetworkEndpoint is a network-layer protocol endpoint.
// Headers of this protocol are fwdTestNetHeaderLen bytes, but we currently only
// use the first three: destination address, source address, and transport
// protocol. They're all one byte fields to simplify parsing.
type fwdTestNetworkEndpoint struct {
nicID tcpip.NICID
id NetworkEndpointID
prefixLen int
proto *fwdTestNetworkProtocol
dispatcher TransportDispatcher
ep LinkEndpoint
}
func (f *fwdTestNetworkEndpoint) MTU() uint32 {
return f.ep.MTU() - uint32(f.MaxHeaderLength())
}
func (f *fwdTestNetworkEndpoint) NICID() tcpip.NICID {
return f.nicID
}
func (f *fwdTestNetworkEndpoint) PrefixLen() int {
return f.prefixLen
}
func (*fwdTestNetworkEndpoint) DefaultTTL() uint8 {
return 123
}
func (f *fwdTestNetworkEndpoint) ID() *NetworkEndpointID {
return &f.id
}
func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt tcpip.PacketBuffer) {
// Consume the network header.
b := pkt.Data.First()
pkt.Data.TrimFront(fwdTestNetHeaderLen)
// Dispatch the packet to the transport protocol.
f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(b[2]), pkt)
}
func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 {
return f.ep.MaxHeaderLength() + fwdTestNetHeaderLen
}
func (f *fwdTestNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 {
return 0
}
func (f *fwdTestNetworkEndpoint) Capabilities() LinkEndpointCapabilities {
return f.ep.Capabilities()
}
func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt tcpip.PacketBuffer) *tcpip.Error {
// Add the protocol's header to the packet and send it to the link
// endpoint.
b := pkt.Header.Prepend(fwdTestNetHeaderLen)
b[0] = r.RemoteAddress[0]
b[1] = f.id.LocalAddress[0]
b[2] = byte(params.Protocol)
return f.ep.WritePacket(r, gso, fwdTestNetNumber, pkt)
}
// WritePackets implements LinkEndpoint.WritePackets.
func (f *fwdTestNetworkEndpoint) WritePackets(r *Route, gso *GSO, pkts []tcpip.PacketBuffer, params NetworkHeaderParams) (int, *tcpip.Error) {
panic("not implemented")
}
func (*fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt tcpip.PacketBuffer) *tcpip.Error {
return tcpip.ErrNotSupported
}
func (*fwdTestNetworkEndpoint) Close() {}
// fwdTestNetworkProtocol is a network-layer protocol that implements Address
// resolution.
type fwdTestNetworkProtocol struct {
addrCache *linkAddrCache
addrResolveDelay time.Duration
onLinkAddressResolved func(cache *linkAddrCache, addr tcpip.Address)
onResolveStaticAddress func(tcpip.Address) (tcpip.LinkAddress, bool)
}
func (f *fwdTestNetworkProtocol) Number() tcpip.NetworkProtocolNumber {
return fwdTestNetNumber
}
func (f *fwdTestNetworkProtocol) MinimumPacketSize() int {
return fwdTestNetHeaderLen
}
func (f *fwdTestNetworkProtocol) DefaultPrefixLen() int {
return fwdTestNetDefaultPrefixLen
}
func (*fwdTestNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
return tcpip.Address(v[1:2]), tcpip.Address(v[0:1])
}
func (f *fwdTestNetworkProtocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, ep LinkEndpoint, _ *Stack) (NetworkEndpoint, *tcpip.Error) {
return &fwdTestNetworkEndpoint{
nicID: nicID,
id: NetworkEndpointID{LocalAddress: addrWithPrefix.Address},
prefixLen: addrWithPrefix.PrefixLen,
proto: f,
dispatcher: dispatcher,
ep: ep,
}, nil
}
func (f *fwdTestNetworkProtocol) SetOption(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
func (f *fwdTestNetworkProtocol) Option(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
func (f *fwdTestNetworkProtocol) Close() {}
func (f *fwdTestNetworkProtocol) Wait() {}
func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP LinkEndpoint) *tcpip.Error {
if f.addrCache != nil && f.onLinkAddressResolved != nil {
time.AfterFunc(f.addrResolveDelay, func() {
f.onLinkAddressResolved(f.addrCache, addr)
})
}
return nil
}
func (f *fwdTestNetworkProtocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
if f.onResolveStaticAddress != nil {
return f.onResolveStaticAddress(addr)
}
return "", false
}
func (f *fwdTestNetworkProtocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
return fwdTestNetNumber
}
// fwdTestPacketInfo holds all the information about an outbound packet.
type fwdTestPacketInfo struct {
RemoteLinkAddress tcpip.LinkAddress
LocalLinkAddress tcpip.LinkAddress
Pkt tcpip.PacketBuffer
}
type fwdTestLinkEndpoint struct {
dispatcher NetworkDispatcher
mtu uint32
linkAddr tcpip.LinkAddress
// C is where outbound packets are queued.
C chan fwdTestPacketInfo
}
// InjectInbound injects an inbound packet.
func (e *fwdTestLinkEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) {
e.InjectLinkAddr(protocol, "", pkt)
}
// InjectLinkAddr injects an inbound packet with a remote link address.
func (e *fwdTestLinkEndpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt tcpip.PacketBuffer) {
e.dispatcher.DeliverNetworkPacket(e, remote, "" /* local */, protocol, pkt)
}
// Attach saves the stack network-layer dispatcher for use later when packets
// are injected.
func (e *fwdTestLinkEndpoint) Attach(dispatcher NetworkDispatcher) {
e.dispatcher = dispatcher
}
// IsAttached implements stack.LinkEndpoint.IsAttached.
func (e *fwdTestLinkEndpoint) IsAttached() bool {
return e.dispatcher != nil
}
// MTU implements stack.LinkEndpoint.MTU. It returns the value initialized
// during construction.
func (e *fwdTestLinkEndpoint) MTU() uint32 {
return e.mtu
}
// Capabilities implements stack.LinkEndpoint.Capabilities.
func (e fwdTestLinkEndpoint) Capabilities() LinkEndpointCapabilities {
caps := LinkEndpointCapabilities(0)
return caps | CapabilityResolutionRequired
}
// GSOMaxSize returns the maximum GSO packet size.
func (*fwdTestLinkEndpoint) GSOMaxSize() uint32 {
return 1 << 15
}
// MaxHeaderLength returns the maximum size of the link layer header. Given it
// doesn't have a header, it just returns 0.
func (*fwdTestLinkEndpoint) MaxHeaderLength() uint16 {
return 0
}
// LinkAddress returns the link address of this endpoint.
func (e *fwdTestLinkEndpoint) LinkAddress() tcpip.LinkAddress {
return e.linkAddr
}
func (e fwdTestLinkEndpoint) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error {
p := fwdTestPacketInfo{
RemoteLinkAddress: r.RemoteLinkAddress,
LocalLinkAddress: r.LocalLinkAddress,
Pkt: pkt,
}
select {
case e.C <- p:
default:
}
return nil
}
// WritePackets stores outbound packets into the channel.
func (e *fwdTestLinkEndpoint) WritePackets(r *Route, gso *GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
n := 0
for _, pkt := range pkts {
e.WritePacket(r, gso, protocol, pkt)
n++
}
return n, nil
}
// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
func (e *fwdTestLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
p := fwdTestPacketInfo{
Pkt: tcpip.PacketBuffer{Data: vv},
}
select {
case e.C <- p:
default:
}
return nil
}
// Wait implements stack.LinkEndpoint.Wait.
func (*fwdTestLinkEndpoint) Wait() {}
func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *fwdTestLinkEndpoint) {
// Create a stack with the network protocol and two NICs.
s := New(Options{
NetworkProtocols: []NetworkProtocol{proto},
})
proto.addrCache = s.linkAddrCache
// Enable forwarding.
s.SetForwarding(true)
// NIC 1 has the link address "a", and added the network address 1.
ep1 = &fwdTestLinkEndpoint{
C: make(chan fwdTestPacketInfo, 300),
mtu: fwdTestNetDefaultMTU,
linkAddr: "a",
}
if err := s.CreateNIC(1, ep1); err != nil {
t.Fatal("CreateNIC #1 failed:", err)
}
if err := s.AddAddress(1, fwdTestNetNumber, "\x01"); err != nil {
t.Fatal("AddAddress #1 failed:", err)
}
// NIC 2 has the link address "b", and added the network address 2.
ep2 = &fwdTestLinkEndpoint{
C: make(chan fwdTestPacketInfo, 300),
mtu: fwdTestNetDefaultMTU,
linkAddr: "b",
}
if err := s.CreateNIC(2, ep2); err != nil {
t.Fatal("CreateNIC #2 failed:", err)
}
if err := s.AddAddress(2, fwdTestNetNumber, "\x02"); err != nil {
t.Fatal("AddAddress #2 failed:", err)
}
// Route all packets to NIC 2.
{
subnet, err := tcpip.NewSubnet("\x00", "\x00")
if err != nil {
t.Fatal(err)
}
s.SetRouteTable([]tcpip.Route{{Destination: subnet, NIC: 2}})
}
return ep1, ep2
}
func TestForwardingWithStaticResolver(t *testing.T) {
// Create a network protocol with a static resolver.
proto := &fwdTestNetworkProtocol{
onResolveStaticAddress:
// The network address 3 is resolved to the link address "c".
func(addr tcpip.Address) (tcpip.LinkAddress, bool) {
if addr == "\x03" {
return "c", true
}
return "", false
},
}
ep1, ep2 := fwdTestNetFactory(t, proto)
// Inject an inbound packet to address 3 on NIC 1, and see if it is
// forwarded to NIC 2.
buf := buffer.NewView(30)
buf[0] = 3
ep1.InjectInbound(fwdTestNetNumber, tcpip.PacketBuffer{
Data: buf.ToVectorisedView(),
})
var p fwdTestPacketInfo
select {
case p = <-ep2.C:
default:
t.Fatal("packet not forwarded")
}
// Test that the static address resolution happened correctly.
if p.RemoteLinkAddress != "c" {
t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
}
if p.LocalLinkAddress != "b" {
t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
}
}
func TestForwardingWithFakeResolver(t *testing.T) {
// Create a network protocol with a fake resolver.
proto := &fwdTestNetworkProtocol{
addrResolveDelay: 500 * time.Millisecond,
onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) {
// Any address will be resolved to the link address "c".
cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
},
}
ep1, ep2 := fwdTestNetFactory(t, proto)
// Inject an inbound packet to address 3 on NIC 1, and see if it is
// forwarded to NIC 2.
buf := buffer.NewView(30)
buf[0] = 3
ep1.InjectInbound(fwdTestNetNumber, tcpip.PacketBuffer{
Data: buf.ToVectorisedView(),
})
var p fwdTestPacketInfo
select {
case p = <-ep2.C:
case <-time.After(time.Second):
t.Fatal("packet not forwarded")
}
// Test that the address resolution happened correctly.
if p.RemoteLinkAddress != "c" {
t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
}
if p.LocalLinkAddress != "b" {
t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
}
}
func TestForwardingWithNoResolver(t *testing.T) {
// Create a network protocol without a resolver.
proto := &fwdTestNetworkProtocol{}
ep1, ep2 := fwdTestNetFactory(t, proto)
// inject an inbound packet to address 3 on NIC 1, and see if it is
// forwarded to NIC 2.
buf := buffer.NewView(30)
buf[0] = 3
ep1.InjectInbound(fwdTestNetNumber, tcpip.PacketBuffer{
Data: buf.ToVectorisedView(),
})
select {
case <-ep2.C:
t.Fatal("Packet should not be forwarded")
case <-time.After(time.Second):
}
}
func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) {
// Create a network protocol with a fake resolver.
proto := &fwdTestNetworkProtocol{
addrResolveDelay: 500 * time.Millisecond,
onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) {
// Only packets to address 3 will be resolved to the
// link address "c".
if addr == "\x03" {
cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
}
},
}
ep1, ep2 := fwdTestNetFactory(t, proto)
// Inject an inbound packet to address 4 on NIC 1. This packet should
// not be forwarded.
buf := buffer.NewView(30)
buf[0] = 4
ep1.InjectInbound(fwdTestNetNumber, tcpip.PacketBuffer{
Data: buf.ToVectorisedView(),
})
// Inject an inbound packet to address 3 on NIC 1, and see if it is
// forwarded to NIC 2.
buf = buffer.NewView(30)
buf[0] = 3
ep1.InjectInbound(fwdTestNetNumber, tcpip.PacketBuffer{
Data: buf.ToVectorisedView(),
})
var p fwdTestPacketInfo
select {
case p = <-ep2.C:
case <-time.After(time.Second):
t.Fatal("packet not forwarded")
}
b := p.Pkt.Header.View()
if b[0] != 3 {
t.Fatalf("got b[0] = %d, want = 3", b[0])
}
// Test that the address resolution happened correctly.
if p.RemoteLinkAddress != "c" {
t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
}
if p.LocalLinkAddress != "b" {
t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
}
}
func TestForwardingWithFakeResolverTwoPackets(t *testing.T) {
// Create a network protocol with a fake resolver.
proto := &fwdTestNetworkProtocol{
addrResolveDelay: 500 * time.Millisecond,
onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) {
// Any packets will be resolved to the link address "c".
cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
},
}
ep1, ep2 := fwdTestNetFactory(t, proto)
// Inject two inbound packets to address 3 on NIC 1.
for i := 0; i < 2; i++ {
buf := buffer.NewView(30)
buf[0] = 3
ep1.InjectInbound(fwdTestNetNumber, tcpip.PacketBuffer{
Data: buf.ToVectorisedView(),
})
}
for i := 0; i < 2; i++ {
var p fwdTestPacketInfo
select {
case p = <-ep2.C:
case <-time.After(time.Second):
t.Fatal("packet not forwarded")
}
b := p.Pkt.Header.View()
if b[0] != 3 {
t.Fatalf("got b[0] = %d, want = 3", b[0])
}
// Test that the address resolution happened correctly.
if p.RemoteLinkAddress != "c" {
t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
}
if p.LocalLinkAddress != "b" {
t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
}
}
}
func TestForwardingWithFakeResolverManyPackets(t *testing.T) {
// Create a network protocol with a fake resolver.
proto := &fwdTestNetworkProtocol{
addrResolveDelay: 500 * time.Millisecond,
onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) {
// Any packets will be resolved to the link address "c".
cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
},
}
ep1, ep2 := fwdTestNetFactory(t, proto)
for i := 0; i < maxPendingPacketsPerResolution+5; i++ {
// Inject inbound 'maxPendingPacketsPerResolution + 5' packets on NIC 1.
buf := buffer.NewView(30)
buf[0] = 3
// Set the packet sequence number.
binary.BigEndian.PutUint16(buf[fwdTestNetHeaderLen:], uint16(i))
ep1.InjectInbound(fwdTestNetNumber, tcpip.PacketBuffer{
Data: buf.ToVectorisedView(),
})
}
for i := 0; i < maxPendingPacketsPerResolution; i++ {
var p fwdTestPacketInfo
select {
case p = <-ep2.C:
case <-time.After(time.Second):
t.Fatal("packet not forwarded")
}
b := p.Pkt.Header.View()
if b[0] != 3 {
t.Fatalf("got b[0] = %d, want = 3", b[0])
}
// The first 5 packets should not be forwarded so the the
// sequemnce number should start with 5.
want := uint16(i + 5)
if n := binary.BigEndian.Uint16(b[fwdTestNetHeaderLen:]); n != want {
t.Fatalf("got the packet #%d, want = #%d", n, want)
}
// Test that the address resolution happened correctly.
if p.RemoteLinkAddress != "c" {
t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
}
if p.LocalLinkAddress != "b" {
t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
}
}
}
func TestForwardingWithFakeResolverManyResolutions(t *testing.T) {
// Create a network protocol with a fake resolver.
proto := &fwdTestNetworkProtocol{
addrResolveDelay: 500 * time.Millisecond,
onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) {
// Any packets will be resolved to the link address "c".
cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
},
}
ep1, ep2 := fwdTestNetFactory(t, proto)
for i := 0; i < maxPendingResolutions+5; i++ {
// Inject inbound 'maxPendingResolutions + 5' packets on NIC 1.
// Each packet has a different destination address (3 to
// maxPendingResolutions + 7).
buf := buffer.NewView(30)
buf[0] = byte(3 + i)
ep1.InjectInbound(fwdTestNetNumber, tcpip.PacketBuffer{
Data: buf.ToVectorisedView(),
})
}
for i := 0; i < maxPendingResolutions; i++ {
var p fwdTestPacketInfo
select {
case p = <-ep2.C:
case <-time.After(time.Second):
t.Fatal("packet not forwarded")
}
// The first 5 packets (address 3 to 7) should not be forwarded
// because their address resolutions are interrupted.
b := p.Pkt.Header.View()
if b[0] < 8 {
t.Fatalf("got b[0] = %d, want b[0] >= 8", b[0])
}
// Test that the address resolution happened correctly.
if p.RemoteLinkAddress != "c" {
t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
}
if p.LocalLinkAddress != "b" {
t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
}
}
}

View File

@ -1201,10 +1201,6 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link
n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment()
return
}
defer r.Release()
r.LocalLinkAddress = n.linkEP.LinkAddress()
r.RemoteLinkAddress = remote
// Found a NIC.
n := r.ref.nic
@ -1213,24 +1209,33 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link
ok = ok && ref.isValidForOutgoingRLocked() && ref.tryIncRef()
n.mu.RUnlock()
if ok {
r.LocalLinkAddress = n.linkEP.LinkAddress()
r.RemoteLinkAddress = remote
r.RemoteAddress = src
// TODO(b/123449044): Update the source NIC as well.
ref.ep.HandlePacket(&r, pkt)
ref.decRef()
} else {
// n doesn't have a destination endpoint.
// Send the packet out of n.
pkt.Header = buffer.NewPrependableFromView(pkt.Data.First())
pkt.Data.RemoveFirst()
// TODO(b/128629022): use route.WritePacket.
if err := n.linkEP.WritePacket(&r, nil /* gso */, protocol, pkt); err != nil {
r.Stats().IP.OutgoingPacketErrors.Increment()
} else {
n.stats.Tx.Packets.Increment()
n.stats.Tx.Bytes.IncrementBy(uint64(pkt.Header.UsedLength() + pkt.Data.Size()))
}
r.Release()
return
}
// n doesn't have a destination endpoint.
// Send the packet out of n.
// TODO(b/128629022): move this logic to route.WritePacket.
if ch, err := r.Resolve(nil); err != nil {
if err == tcpip.ErrWouldBlock {
n.stack.forwarder.enqueue(ch, n, &r, protocol, pkt)
// forwarder will release route.
return
}
n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment()
r.Release()
return
}
// The link-address resolution finished immediately.
n.forwardPacket(&r, protocol, pkt)
r.Release()
return
}
@ -1240,6 +1245,20 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link
}
}
func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) {
// TODO(b/143425874) Decrease the TTL field in forwarded packets.
pkt.Header = buffer.NewPrependableFromView(pkt.Data.First())
pkt.Data.RemoveFirst()
if err := n.linkEP.WritePacket(r, nil /* gso */, protocol, pkt); err != nil {
r.Stats().IP.OutgoingPacketErrors.Increment()
return
}
n.stats.Tx.Packets.Increment()
n.stats.Tx.Bytes.IncrementBy(uint64(pkt.Header.UsedLength() + pkt.Data.Size()))
}
// DeliverTransportPacket delivers the packets to the appropriate transport
// protocol endpoint.
func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer) {

View File

@ -462,6 +462,10 @@ type Stack struct {
// opaqueIIDOpts hold the options for generating opaque interface identifiers
// (IIDs) as outlined by RFC 7217.
opaqueIIDOpts OpaqueInterfaceIdentifierOptions
// forwarder holds the packets that wait for their link-address resolutions
// to complete, and forwards them when each resolution is done.
forwarder *forwardQueue
}
// UniqueID is an abstract generator of unique identifiers.
@ -641,6 +645,7 @@ func New(opts Options) *Stack {
uniqueIDGenerator: opts.UniqueID,
ndpDisp: opts.NDPDisp,
opaqueIIDOpts: opts.OpaqueIIDOpts,
forwarder: newForwardQueue(),
}
// Add specified network protocols.