Queue packets in WritePackets when resolving link address
Test: integration_test.TestWritePacketsLinkResolution Fixes #4458. PiperOrigin-RevId: 353108826
This commit is contained in:
parent
0ca4cf7698
commit
89df5a681c
|
@ -45,12 +45,7 @@ type Endpoint struct {
|
|||
linkAddr tcpip.LinkAddress
|
||||
}
|
||||
|
||||
// WritePacket implements stack.LinkEndpoint.
|
||||
func (e *Endpoint) WritePacket(r stack.RouteInfo, _ *stack.GSO, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
|
||||
if !e.linked.IsAttached() {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Endpoint) deliverPackets(r stack.RouteInfo, proto tcpip.NetworkProtocolNumber, pkts stack.PacketBufferList) {
|
||||
// Note that the local address from the perspective of this endpoint is the
|
||||
// remote address from the perspective of the other end of the pipe
|
||||
// (e.linked). Similarly, the remote address from the perspective of this
|
||||
|
@ -70,16 +65,33 @@ func (e *Endpoint) WritePacket(r stack.RouteInfo, _ *stack.GSO, proto tcpip.Netw
|
|||
//
|
||||
// TODO(gvisor.dev/issue/5289): don't use a new goroutine once we support send
|
||||
// and receive queues.
|
||||
go e.linked.dispatcher.DeliverNetworkPacket(r.LocalLinkAddress /* remote */, r.RemoteLinkAddress /* local */, proto, stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||
Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()),
|
||||
}))
|
||||
go func() {
|
||||
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
|
||||
e.linked.dispatcher.DeliverNetworkPacket(r.LocalLinkAddress /* remote */, r.RemoteLinkAddress /* local */, proto, stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||
Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()),
|
||||
}))
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// WritePacket implements stack.LinkEndpoint.
|
||||
func (e *Endpoint) WritePacket(r stack.RouteInfo, _ *stack.GSO, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
|
||||
if e.linked.IsAttached() {
|
||||
var pkts stack.PacketBufferList
|
||||
pkts.PushBack(pkt)
|
||||
e.deliverPackets(r, proto, pkts)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// WritePackets implements stack.LinkEndpoint.
|
||||
func (*Endpoint) WritePackets(stack.RouteInfo, *stack.GSO, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
|
||||
panic("not implemented")
|
||||
func (e *Endpoint) WritePackets(r stack.RouteInfo, _ *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
|
||||
if e.linked.IsAttached() {
|
||||
e.deliverPackets(r, proto, pkts)
|
||||
}
|
||||
|
||||
return pkts.Len(), nil
|
||||
}
|
||||
|
||||
// Attach implements stack.LinkEndpoint.
|
||||
|
|
|
@ -358,16 +358,43 @@ func (n *NIC) writePacket(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolN
|
|||
|
||||
// WritePackets implements NetworkLinkEndpoint.
|
||||
func (n *NIC) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
|
||||
// TODO(gvisor.dev/issue/4458): Queue packets whie link address resolution
|
||||
// is being peformed like WritePacket.
|
||||
routeInfo := r.Fields()
|
||||
// As per relevant RFCs, we should queue packets while we wait for link
|
||||
// resolution to complete.
|
||||
//
|
||||
// RFC 1122 section 2.3.2.2 (for IPv4):
|
||||
// The link layer SHOULD save (rather than discard) at least
|
||||
// one (the latest) packet of each set of packets destined to
|
||||
// the same unresolved IP address, and transmit the saved
|
||||
// packet when the address has been resolved.
|
||||
//
|
||||
// RFC 4861 section 7.2.2 (for IPv6):
|
||||
// While waiting for address resolution to complete, the sender MUST, for
|
||||
// each neighbor, retain a small queue of packets waiting for address
|
||||
// resolution to complete. The queue MUST hold at least one packet, and MAY
|
||||
// contain more. However, the number of queued packets per neighbor SHOULD
|
||||
// be limited to some small value. When a queue overflows, the new arrival
|
||||
// SHOULD replace the oldest entry. Once address resolution completes, the
|
||||
// node transmits any queued packets.
|
||||
if ch, err := r.Resolve(nil); err != nil {
|
||||
if err == tcpip.ErrWouldBlock {
|
||||
r.Acquire()
|
||||
n.linkResQueue.enqueue(ch, r, protocol, &pkts)
|
||||
return pkts.Len(), nil
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return n.writePackets(r.Fields(), gso, protocol, pkts)
|
||||
}
|
||||
|
||||
func (n *NIC) writePackets(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkts PacketBufferList) (int, *tcpip.Error) {
|
||||
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
|
||||
pkt.EgressRoute = routeInfo
|
||||
pkt.EgressRoute = r
|
||||
pkt.GSOOptions = gso
|
||||
pkt.NetworkProtocolNumber = protocol
|
||||
}
|
||||
|
||||
writtenPackets, err := n.LinkEndpoint.WritePackets(routeInfo, gso, pkts, protocol)
|
||||
writtenPackets, err := n.LinkEndpoint.WritePackets(r, gso, pkts, protocol)
|
||||
n.stats.Tx.Packets.IncrementBy(uint64(writtenPackets))
|
||||
writtenBytes := 0
|
||||
for i, pb := 0, pkts.Front(); i < writtenPackets && pb != nil; i, pb = i+1, pb.Next() {
|
||||
|
|
|
@ -28,10 +28,26 @@ const (
|
|||
maxPendingPacketsPerResolution = 256
|
||||
)
|
||||
|
||||
// pendingPacketBuffer is a pending packet buffer.
|
||||
//
|
||||
// TODO(gvisor.dev/issue/5331): Drop this when we drop WritePacket and only use
|
||||
// WritePackets so we can use a PacketBufferList everywhere.
|
||||
type pendingPacketBuffer interface {
|
||||
len() int
|
||||
}
|
||||
|
||||
func (*PacketBuffer) len() int {
|
||||
return 1
|
||||
}
|
||||
|
||||
func (p *PacketBufferList) len() int {
|
||||
return p.Len()
|
||||
}
|
||||
|
||||
type pendingPacket struct {
|
||||
route *Route
|
||||
proto tcpip.NetworkProtocolNumber
|
||||
pkt *PacketBuffer
|
||||
pkt pendingPacketBuffer
|
||||
}
|
||||
|
||||
// packetsPendingLinkResolution is a queue of packets pending link resolution.
|
||||
|
@ -54,16 +70,17 @@ func (f *packetsPendingLinkResolution) init() {
|
|||
f.packets = make(map[<-chan struct{}][]pendingPacket)
|
||||
}
|
||||
|
||||
func incrementOutgoingPacketErrors(r *Route, proto tcpip.NetworkProtocolNumber) {
|
||||
r.Stats().IP.OutgoingPacketErrors.Increment()
|
||||
func incrementOutgoingPacketErrors(r *Route, proto tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) {
|
||||
n := uint64(pkt.len())
|
||||
r.Stats().IP.OutgoingPacketErrors.IncrementBy(n)
|
||||
|
||||
// ok may be false if the endpoint's stats do not collect IP-related data.
|
||||
if ipEndpointStats, ok := r.outgoingNIC.getNetworkEndpoint(proto).Stats().(IPNetworkEndpointStats); ok {
|
||||
ipEndpointStats.IPStats().OutgoingPacketErrors.Increment()
|
||||
ipEndpointStats.IPStats().OutgoingPacketErrors.IncrementBy(n)
|
||||
}
|
||||
}
|
||||
|
||||
func (f *packetsPendingLinkResolution) enqueue(ch <-chan struct{}, r *Route, proto tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
|
||||
func (f *packetsPendingLinkResolution) enqueue(ch <-chan struct{}, r *Route, proto tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) {
|
||||
f.Lock()
|
||||
defer f.Unlock()
|
||||
|
||||
|
@ -73,7 +90,7 @@ func (f *packetsPendingLinkResolution) enqueue(ch <-chan struct{}, r *Route, pro
|
|||
packets[0] = pendingPacket{}
|
||||
packets = packets[1:]
|
||||
|
||||
incrementOutgoingPacketErrors(r, proto)
|
||||
incrementOutgoingPacketErrors(r, proto, p.pkt)
|
||||
|
||||
p.route.Release()
|
||||
}
|
||||
|
@ -113,13 +130,29 @@ func (f *packetsPendingLinkResolution) enqueue(ch <-chan struct{}, r *Route, pro
|
|||
|
||||
for _, p := range packets {
|
||||
if cancelled || p.route.IsResolutionRequired() {
|
||||
incrementOutgoingPacketErrors(r, proto)
|
||||
incrementOutgoingPacketErrors(r, proto, p.pkt)
|
||||
|
||||
if linkResolvableEP, ok := p.route.outgoingNIC.getNetworkEndpoint(p.route.NetProto).(LinkResolvableNetworkEndpoint); ok {
|
||||
linkResolvableEP.HandleLinkResolutionFailure(pkt)
|
||||
switch pkt := p.pkt.(type) {
|
||||
case *PacketBuffer:
|
||||
linkResolvableEP.HandleLinkResolutionFailure(pkt)
|
||||
case *PacketBufferList:
|
||||
for pb := pkt.Front(); pb != nil; pb = pb.Next() {
|
||||
linkResolvableEP.HandleLinkResolutionFailure(pb)
|
||||
}
|
||||
default:
|
||||
panic(fmt.Sprintf("unrecognized pending packet buffer type = %T", p.pkt))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
p.route.outgoingNIC.writePacket(p.route.Fields(), nil /* gso */, p.proto, p.pkt)
|
||||
switch pkt := p.pkt.(type) {
|
||||
case *PacketBuffer:
|
||||
p.route.outgoingNIC.writePacket(p.route.Fields(), nil /* gso */, p.proto, pkt)
|
||||
case *PacketBufferList:
|
||||
p.route.outgoingNIC.writePackets(p.route.Fields(), nil /* gso */, p.proto, *pkt)
|
||||
default:
|
||||
panic(fmt.Sprintf("unrecognized pending packet buffer type = %T", p.pkt))
|
||||
}
|
||||
}
|
||||
p.route.Release()
|
||||
}
|
||||
|
|
|
@ -23,6 +23,7 @@ import (
|
|||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/buffer"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/checker"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/link/pipe"
|
||||
|
@ -32,6 +33,7 @@ import (
|
|||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
||||
"gvisor.dev/gvisor/pkg/waiter"
|
||||
)
|
||||
|
||||
|
@ -456,3 +458,126 @@ func TestGetLinkAddress(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWritePacketsLinkResolution(t *testing.T) {
|
||||
const (
|
||||
host1NICID = 1
|
||||
host2NICID = 4
|
||||
)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
netProto tcpip.NetworkProtocolNumber
|
||||
remoteAddr tcpip.Address
|
||||
expectedWriteErr *tcpip.Error
|
||||
}{
|
||||
{
|
||||
name: "IPv4",
|
||||
netProto: ipv4.ProtocolNumber,
|
||||
remoteAddr: ipv4Addr2.AddressWithPrefix.Address,
|
||||
expectedWriteErr: nil,
|
||||
},
|
||||
{
|
||||
name: "IPv6",
|
||||
netProto: ipv6.ProtocolNumber,
|
||||
remoteAddr: ipv6Addr2.AddressWithPrefix.Address,
|
||||
expectedWriteErr: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
stackOpts := stack.Options{
|
||||
NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
|
||||
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
|
||||
}
|
||||
|
||||
host1Stack, host2Stack := setupStack(t, stackOpts, host1NICID, host2NICID)
|
||||
|
||||
var serverWQ waiter.Queue
|
||||
serverWE, serverCH := waiter.NewChannelEntry(nil)
|
||||
serverWQ.EventRegister(&serverWE, waiter.EventIn)
|
||||
serverEP, err := host2Stack.NewEndpoint(udp.ProtocolNumber, test.netProto, &serverWQ)
|
||||
if err != nil {
|
||||
t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.netProto, err)
|
||||
}
|
||||
defer serverEP.Close()
|
||||
|
||||
serverAddr := tcpip.FullAddress{Port: 1234}
|
||||
if err := serverEP.Bind(serverAddr); err != nil {
|
||||
t.Fatalf("serverEP.Bind(%#v): %s", serverAddr, err)
|
||||
}
|
||||
|
||||
r, err := host1Stack.FindRoute(host1NICID, "", test.remoteAddr, test.netProto, false /* multicastLoop */)
|
||||
if err != nil {
|
||||
t.Fatalf("host1Stack.FindRoute(%d, '', %s, %d, false): %s", host1NICID, test.remoteAddr, test.netProto, err)
|
||||
}
|
||||
defer r.Release()
|
||||
|
||||
data := []byte{1, 2}
|
||||
var pkts stack.PacketBufferList
|
||||
for _, d := range data {
|
||||
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||
ReserveHeaderBytes: header.UDPMinimumSize + int(r.MaxHeaderLength()),
|
||||
Data: buffer.View([]byte{d}).ToVectorisedView(),
|
||||
})
|
||||
pkt.TransportProtocolNumber = udp.ProtocolNumber
|
||||
length := uint16(pkt.Size())
|
||||
udpHdr := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
|
||||
udpHdr.Encode(&header.UDPFields{
|
||||
SrcPort: 5555,
|
||||
DstPort: serverAddr.Port,
|
||||
Length: length,
|
||||
})
|
||||
xsum := r.PseudoHeaderChecksum(udp.ProtocolNumber, length)
|
||||
for _, v := range pkt.Data.Views() {
|
||||
xsum = header.Checksum(v, xsum)
|
||||
}
|
||||
udpHdr.SetChecksum(^udpHdr.CalculateChecksum(xsum))
|
||||
|
||||
pkts.PushBack(pkt)
|
||||
}
|
||||
|
||||
params := stack.NetworkHeaderParams{
|
||||
Protocol: udp.ProtocolNumber,
|
||||
TTL: 64,
|
||||
TOS: stack.DefaultTOS,
|
||||
}
|
||||
|
||||
if n, err := r.WritePackets(nil /* gso */, pkts, params); err != nil {
|
||||
t.Fatalf("r.WritePackets(nil, %#v, _): %s", params, err)
|
||||
} else if want := pkts.Len(); want != n {
|
||||
t.Fatalf("got r.WritePackets(nil, %#v, _) = %d, want = %d", n, params, want)
|
||||
}
|
||||
|
||||
var writer bytes.Buffer
|
||||
count := 0
|
||||
for {
|
||||
var rOpts tcpip.ReadOptions
|
||||
res, err := serverEP.Read(&writer, rOpts)
|
||||
if err != nil {
|
||||
if err == tcpip.ErrWouldBlock {
|
||||
// Should not have anymore bytes to read after we read the sent
|
||||
// number of bytes.
|
||||
if count == len(data) {
|
||||
break
|
||||
}
|
||||
|
||||
<-serverCH
|
||||
continue
|
||||
}
|
||||
|
||||
t.Fatalf("serverEP.Read(_, %#v): %s", rOpts, err)
|
||||
}
|
||||
count += res.Count
|
||||
}
|
||||
|
||||
if got, want := host2Stack.Stats().UDP.PacketsReceived.Value(), uint64(len(data)); got != want {
|
||||
t.Errorf("got host2Stack.Stats().UDP.PacketsReceived.Value() = %d, want = %d", got, want)
|
||||
}
|
||||
if diff := cmp.Diff(data, writer.Bytes()); diff != "" {
|
||||
t.Errorf("read bytes mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue