Remove View.First() and View.RemoveFirst()

These methods let users easily break the VectorisedView abstraction, and
allowed netstack to slip into pseudo-enforcement of the "all headers are
in the first View" invariant. Removing them and replacing with PullUp(n)
breaks this reliance and will make it easier to add iptables support and
rework network buffer management.

The new View.PullUp(n) method is low cost in the common case, when
all the headers fit in the first View.
This commit is contained in:
Kevin Krakauer 2020-04-13 17:37:21 -07:00
parent 80deebb0bf
commit a551add5d8
25 changed files with 395 additions and 147 deletions

View File

@ -121,12 +121,13 @@ func (tm *TCPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceNa
tcpHeader = header.TCP(pkt.TransportHeader)
} else {
// The TCP header hasn't been parsed yet. We have to do it here.
if len(pkt.Data.First()) < header.TCPMinimumSize {
hdr, ok := pkt.Data.PullUp(header.TCPMinimumSize)
if !ok {
// There's no valid TCP header here, so we hotdrop the
// packet.
return false, true
}
tcpHeader = header.TCP(pkt.Data.First())
tcpHeader = header.TCP(hdr)
}
// Check whether the source and destination ports are within the

View File

@ -120,12 +120,13 @@ func (um *UDPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceNa
udpHeader = header.UDP(pkt.TransportHeader)
} else {
// The UDP header hasn't been parsed yet. We have to do it here.
if len(pkt.Data.First()) < header.UDPMinimumSize {
hdr, ok := pkt.Data.PullUp(header.UDPMinimumSize)
if !ok {
// There's no valid UDP header here, so we hotdrop the
// packet.
return false, true
}
udpHeader = header.UDP(pkt.Data.First())
udpHeader = header.UDP(hdr)
}
// Check whether the source and destination ports are within the

View File

@ -77,7 +77,8 @@ func NewVectorisedView(size int, views []View) VectorisedView {
return VectorisedView{views: views, size: size}
}
// TrimFront removes the first "count" bytes of the vectorised view.
// TrimFront removes the first "count" bytes of the vectorised view. It panics
// if count > vv.Size().
func (vv *VectorisedView) TrimFront(count int) {
for count > 0 && len(vv.views) > 0 {
if count < len(vv.views[0]) {
@ -86,7 +87,7 @@ func (vv *VectorisedView) TrimFront(count int) {
return
}
count -= len(vv.views[0])
vv.RemoveFirst()
vv.removeFirst()
}
}
@ -104,7 +105,7 @@ func (vv *VectorisedView) Read(v View) (copied int, err error) {
count -= len(vv.views[0])
copy(v[copied:], vv.views[0])
copied += len(vv.views[0])
vv.RemoveFirst()
vv.removeFirst()
}
if copied == 0 {
return 0, io.EOF
@ -126,7 +127,7 @@ func (vv *VectorisedView) ReadToVV(dstVV *VectorisedView, count int) (copied int
count -= len(vv.views[0])
dstVV.AppendView(vv.views[0])
copied += len(vv.views[0])
vv.RemoveFirst()
vv.removeFirst()
}
return copied
}
@ -162,22 +163,37 @@ func (vv *VectorisedView) Clone(buffer []View) VectorisedView {
return VectorisedView{views: append(buffer[:0], vv.views...), size: vv.size}
}
// First returns the first view of the vectorised view.
func (vv *VectorisedView) First() View {
// PullUp returns the first "count" bytes of the vectorised view. If those
// bytes aren't already contiguous inside the vectorised view, PullUp will
// reallocate as needed to make them contiguous. PullUp fails and returns false
// when count > vv.Size().
func (vv *VectorisedView) PullUp(count int) (View, bool) {
if len(vv.views) == 0 {
return nil
return nil, count == 0
}
if count <= len(vv.views[0]) {
return vv.views[0][:count], true
}
if count > vv.size {
return nil, false
}
return vv.views[0]
}
// RemoveFirst removes the first view of the vectorised view.
func (vv *VectorisedView) RemoveFirst() {
if len(vv.views) == 0 {
return
newFirst := NewView(count)
i := 0
for offset := 0; offset < count; i++ {
copy(newFirst[offset:], vv.views[i])
if count-offset < len(vv.views[i]) {
vv.views[i].TrimFront(count - offset)
break
}
offset += len(vv.views[i])
vv.views[i] = nil
}
vv.size -= len(vv.views[0])
vv.views[0] = nil
vv.views = vv.views[1:]
// We're guaranteed that i > 0, since count is too large for the first
// view.
vv.views[i-1] = newFirst
vv.views = vv.views[i-1:]
return newFirst, true
}
// Size returns the size in bytes of the entire content stored in the vectorised view.
@ -225,3 +241,10 @@ func (vv *VectorisedView) Readers() []bytes.Reader {
}
return readers
}
// removeFirst drops the leading view from vv and reduces vv's recorded size
// by that view's length. It panics when vv holds no views.
func (vv *VectorisedView) removeFirst() {
	head := vv.views[0]
	// Nil out the slot before reslicing past it.
	vv.views[0] = nil
	vv.views = vv.views[1:]
	vv.size -= len(head)
}

View File

@ -16,6 +16,7 @@
package buffer
import (
"bytes"
"reflect"
"testing"
)
@ -370,3 +371,115 @@ func TestVVRead(t *testing.T) {
})
}
}
// pullUpTestCases is the table that drives TestPullUp. Each entry supplies a
// vectorised view and a byte count to pull up, and states what PullUp should
// return as well as what the view should look like afterwards.
var pullUpTestCases = []struct {
// comment names the case; it is printed in failure messages.
comment string
// in is the vectorised view PullUp is called on.
in VectorisedView
// count is the argument passed to PullUp.
count int
// want is the expected byte content of the view returned by PullUp.
want []byte
// result is the expected structure of in after PullUp has run.
result VectorisedView
// ok is the expected boolean returned by PullUp.
ok bool
}{
{
comment: "simple case",
in: vv(2, "12"),
count: 1,
want: []byte("1"),
result: vv(2, "12"),
ok: true,
},
{
comment: "entire View",
in: vv(2, "1", "2"),
count: 1,
want: []byte("1"),
result: vv(2, "1", "2"),
ok: true,
},
{
comment: "spanning across two Views",
in: vv(3, "1", "23"),
count: 2,
want: []byte("12"),
result: vv(3, "12", "3"),
ok: true,
},
{
comment: "spanning across all Views",
in: vv(5, "1", "23", "45"),
count: 5,
want: []byte("12345"),
result: vv(5, "12345"),
ok: true,
},
{
comment: "count = 0",
in: vv(1, "1"),
count: 0,
want: []byte{},
result: vv(1, "1"),
ok: true,
},
{
comment: "count = size",
in: vv(1, "1"),
count: 1,
want: []byte("1"),
result: vv(1, "1"),
ok: true,
},
{
// Asking for more bytes than the view holds must fail and leave the
// view untouched.
comment: "count too large",
in: vv(3, "1", "23"),
count: 4,
want: nil,
result: vv(3, "1", "23"),
ok: false,
},
{
comment: "empty vv",
in: vv(0, ""),
count: 1,
want: nil,
result: vv(0, ""),
ok: false,
},
{
comment: "empty vv, count = 0",
in: vv(0, ""),
count: 0,
want: nil,
result: vv(0, ""),
ok: true,
},
{
// Zero-length views in the middle of the vector are expected to be
// dropped when the pulled-up bytes span across them.
comment: "empty views",
in: vv(3, "", "1", "", "23"),
count: 2,
want: []byte("12"),
result: vv(3, "12", "3"),
ok: true,
},
}
// TestPullUp runs every entry of pullUpTestCases through PullUp and checks
// three things for each: the returned ok flag, the returned bytes, and the
// structure of the vectorised view after the call.
//
// PullUp is called on c.in before any message is formatted, so the view
// printed in the first two failure messages is the view as PullUp left it.
func TestPullUp(t *testing.T) {
	for _, c := range pullUpTestCases {
		got, ok := c.in.PullUp(c.count)

		// Is the return value right?
		if ok != c.ok {
			t.Errorf("Test %q failed when calling PullUp(%d) on %v. Got an ok of %t. Want %t",
				c.comment, c.count, c.in, ok, c.ok)
		}
		// bytes.Equal treats nil and empty slices as equal, so a nil "want"
		// matches an empty returned view.
		if !bytes.Equal(got, View(c.want)) {
			t.Errorf("Test %q failed when calling PullUp(%d) on %v. Got %v. Want %v",
				c.comment, c.count, c.in, got, c.want)
		}

		// Is the underlying structure right?
		if !reflect.DeepEqual(c.in, c.result) {
			t.Errorf("Test %q failed when calling PullUp(%d). Got vv with structure %v. Wanted %v",
				c.comment, c.count, c.in, c.result)
		}
	}
}

View File

@ -98,13 +98,13 @@ func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList
// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
// Reject the packet if it's shorter than an ethernet header.
if vv.Size() < header.EthernetMinimumSize {
// There should be an ethernet header at the beginning of vv.
hdr, ok := vv.PullUp(header.EthernetMinimumSize)
if !ok {
// Reject the packet if it's shorter than an ethernet header.
return tcpip.ErrBadAddress
}
// There should be an ethernet header at the beginning of vv.
linkHeader := header.Ethernet(vv.First()[:header.EthernetMinimumSize])
linkHeader := header.Ethernet(hdr)
vv.TrimFront(len(linkHeader))
e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, linkHeader.Type(), stack.PacketBuffer{
Data: vv,

View File

@ -674,7 +674,7 @@ func TestSimpleReceive(t *testing.T) {
// Wait for packet to be received, then check it.
c.waitForPackets(1, time.After(5*time.Second), "Timeout waiting for packet")
c.mu.Lock()
rcvd := []byte(c.packets[0].vv.First())
rcvd := []byte(c.packets[0].vv.ToView())
c.packets = c.packets[:0]
c.mu.Unlock()

View File

@ -171,11 +171,7 @@ func (e *endpoint) GSOMaxSize() uint32 {
func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
writer := e.writer
if writer == nil && atomic.LoadUint32(&LogPackets) == 1 {
first := pkt.Header.View()
if len(first) == 0 {
first = pkt.Data.First()
}
logPacket(prefix, protocol, first, gso)
logPacket(prefix, protocol, pkt, gso)
}
if writer != nil && atomic.LoadUint32(&LogPacketsToPCAP) == 1 {
totalLength := pkt.Header.UsedLength() + pkt.Data.Size()
@ -238,7 +234,7 @@ func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
// Wait implements stack.LinkEndpoint.Wait.
func (e *endpoint) Wait() { e.lower.Wait() }
func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.View, gso *stack.GSO) {
func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer, gso *stack.GSO) {
// Figure out the network layer info.
var transProto uint8
src := tcpip.Address("unknown")
@ -247,28 +243,49 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie
size := uint16(0)
var fragmentOffset uint16
var moreFragments bool
// Create a clone of pkt, including any headers if present. Avoid allocating
// backing memory for the clone.
views := [8]buffer.View{}
vv := buffer.NewVectorisedView(0, views[:0])
vv.AppendView(pkt.Header.View())
vv.Append(pkt.Data)
switch protocol {
case header.IPv4ProtocolNumber:
ipv4 := header.IPv4(b)
hdr, ok := vv.PullUp(header.IPv4MinimumSize)
if !ok {
return
}
ipv4 := header.IPv4(hdr)
fragmentOffset = ipv4.FragmentOffset()
moreFragments = ipv4.Flags()&header.IPv4FlagMoreFragments == header.IPv4FlagMoreFragments
src = ipv4.SourceAddress()
dst = ipv4.DestinationAddress()
transProto = ipv4.Protocol()
size = ipv4.TotalLength() - uint16(ipv4.HeaderLength())
b = b[ipv4.HeaderLength():]
vv.TrimFront(int(ipv4.HeaderLength()))
id = int(ipv4.ID())
case header.IPv6ProtocolNumber:
ipv6 := header.IPv6(b)
hdr, ok := vv.PullUp(header.IPv6MinimumSize)
if !ok {
return
}
ipv6 := header.IPv6(hdr)
src = ipv6.SourceAddress()
dst = ipv6.DestinationAddress()
transProto = ipv6.NextHeader()
size = ipv6.PayloadLength()
b = b[header.IPv6MinimumSize:]
vv.TrimFront(header.IPv6MinimumSize)
case header.ARPProtocolNumber:
arp := header.ARP(b)
hdr, ok := vv.PullUp(header.ARPSize)
if !ok {
return
}
vv.TrimFront(header.ARPSize)
arp := header.ARP(hdr)
log.Infof(
"%s arp %v (%v) -> %v (%v) valid:%v",
prefix,
@ -284,7 +301,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie
// We aren't guaranteed to have a transport header - it's possible for
// writes via raw endpoints to contain only network headers.
if minSize, ok := transportProtocolMinSizes[tcpip.TransportProtocolNumber(transProto)]; ok && len(b) < minSize {
if minSize, ok := transportProtocolMinSizes[tcpip.TransportProtocolNumber(transProto)]; ok && vv.Size() < minSize {
log.Infof("%s %v -> %v transport protocol: %d, but no transport header found (possible raw packet)", prefix, src, dst, transProto)
return
}
@ -297,7 +314,11 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie
switch tcpip.TransportProtocolNumber(transProto) {
case header.ICMPv4ProtocolNumber:
transName = "icmp"
icmp := header.ICMPv4(b)
hdr, ok := vv.PullUp(header.ICMPv4MinimumSize)
if !ok {
break
}
icmp := header.ICMPv4(hdr)
icmpType := "unknown"
if fragmentOffset == 0 {
switch icmp.Type() {
@ -330,7 +351,11 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie
case header.ICMPv6ProtocolNumber:
transName = "icmp"
icmp := header.ICMPv6(b)
hdr, ok := vv.PullUp(header.ICMPv6MinimumSize)
if !ok {
break
}
icmp := header.ICMPv6(hdr)
icmpType := "unknown"
switch icmp.Type() {
case header.ICMPv6DstUnreachable:
@ -361,7 +386,11 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie
case header.UDPProtocolNumber:
transName = "udp"
udp := header.UDP(b)
hdr, ok := vv.PullUp(header.UDPMinimumSize)
if !ok {
break
}
udp := header.UDP(hdr)
if fragmentOffset == 0 && len(udp) >= header.UDPMinimumSize {
srcPort = udp.SourcePort()
dstPort = udp.DestinationPort()
@ -371,7 +400,11 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie
case header.TCPProtocolNumber:
transName = "tcp"
tcp := header.TCP(b)
hdr, ok := vv.PullUp(header.TCPMinimumSize)
if !ok {
break
}
tcp := header.TCP(hdr)
if fragmentOffset == 0 && len(tcp) >= header.TCPMinimumSize {
offset := int(tcp.DataOffset())
if offset < header.TCPMinimumSize {

View File

@ -93,7 +93,10 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuf
}
func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) {
v := pkt.Data.First()
v, ok := pkt.Data.PullUp(header.ARPSize)
if !ok {
return
}
h := header.ARP(v)
if !h.IsValid() {
return

View File

@ -25,7 +25,11 @@ import (
// used to find out which transport endpoint must be notified about the ICMP
// packet.
func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) {
h := header.IPv4(pkt.Data.First())
h, ok := pkt.Data.PullUp(header.IPv4MinimumSize)
if !ok {
return
}
hdr := header.IPv4(h)
// We don't use IsValid() here because ICMP only requires that the IP
// header plus 8 bytes of the transport header be included. So it's
@ -34,12 +38,12 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.
//
// Drop packet if it doesn't have the basic IPv4 header or if the
// original source address doesn't match the endpoint's address.
if len(h) < header.IPv4MinimumSize || h.SourceAddress() != e.id.LocalAddress {
if hdr.SourceAddress() != e.id.LocalAddress {
return
}
hlen := int(h.HeaderLength())
if pkt.Data.Size() < hlen || h.FragmentOffset() != 0 {
hlen := int(hdr.HeaderLength())
if pkt.Data.Size() < hlen || hdr.FragmentOffset() != 0 {
// We won't be able to handle this if it doesn't contain the
// full IPv4 header, or if it's a fragment not at offset 0
// (because it won't have the transport header).
@ -48,15 +52,15 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.
// Skip the ip header, then deliver control message.
pkt.Data.TrimFront(hlen)
p := h.TransportProtocol()
e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt)
p := hdr.TransportProtocol()
e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt)
}
func (e *endpoint) handleICMP(r *stack.Route, pkt stack.PacketBuffer) {
stats := r.Stats()
received := stats.ICMP.V4PacketsReceived
v := pkt.Data.First()
if len(v) < header.ICMPv4MinimumSize {
v, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize)
if !ok {
received.Invalid.Increment()
return
}

View File

@ -328,7 +328,11 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error {
// The packet already has an IP header, but there are a few required
// checks.
ip := header.IPv4(pkt.Data.First())
h, ok := pkt.Data.PullUp(header.IPv4MinimumSize)
if !ok {
return tcpip.ErrInvalidOptionValue
}
ip := header.IPv4(h)
if !ip.IsValid(pkt.Data.Size()) {
return tcpip.ErrInvalidOptionValue
}
@ -378,7 +382,11 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuf
// HandlePacket is called by the link layer when new ipv4 packets arrive for
// this endpoint.
func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) {
headerView := pkt.Data.First()
headerView, ok := pkt.Data.PullUp(header.IPv4MinimumSize)
if !ok {
r.Stats().IP.MalformedPacketsReceived.Increment()
return
}
h := header.IPv4(headerView)
if !h.IsValid(pkt.Data.Size()) {
r.Stats().IP.MalformedPacketsReceived.Increment()

View File

@ -28,7 +28,11 @@ import (
// used to find out which transport endpoint must be notified about the ICMP
// packet.
func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) {
h := header.IPv6(pkt.Data.First())
h, ok := pkt.Data.PullUp(header.IPv6MinimumSize)
if !ok {
return
}
hdr := header.IPv6(h)
// We don't use IsValid() here because ICMP only requires that up to
// 1280 bytes of the original packet be included. So it's likely that it
@ -36,17 +40,21 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.
//
// Drop packet if it doesn't have the basic IPv6 header or if the
// original source address doesn't match the endpoint's address.
if len(h) < header.IPv6MinimumSize || h.SourceAddress() != e.id.LocalAddress {
if hdr.SourceAddress() != e.id.LocalAddress {
return
}
// Skip the IP header, then handle the fragmentation header if there
// is one.
pkt.Data.TrimFront(header.IPv6MinimumSize)
p := h.TransportProtocol()
p := hdr.TransportProtocol()
if p == header.IPv6FragmentHeader {
f := header.IPv6Fragment(pkt.Data.First())
if !f.IsValid() || f.FragmentOffset() != 0 {
f, ok := pkt.Data.PullUp(header.IPv6FragmentHeaderSize)
if !ok {
return
}
fragHdr := header.IPv6Fragment(f)
if !fragHdr.IsValid() || fragHdr.FragmentOffset() != 0 {
// We can't handle fragments that aren't at offset 0
// because they don't have the transport headers.
return
@ -55,19 +63,19 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.
// Skip fragmentation header and find out the actual protocol
// number.
pkt.Data.TrimFront(header.IPv6FragmentHeaderSize)
p = f.TransportProtocol()
p = fragHdr.TransportProtocol()
}
// Deliver the control packet to the transport endpoint.
e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt)
e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt)
}
func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.PacketBuffer, hasFragmentHeader bool) {
stats := r.Stats().ICMP
sent := stats.V6PacketsSent
received := stats.V6PacketsReceived
v := pkt.Data.First()
if len(v) < header.ICMPv6MinimumSize {
v, ok := pkt.Data.PullUp(header.ICMPv6HeaderSize)
if !ok {
received.Invalid.Increment()
return
}
@ -76,11 +84,9 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P
// Validate ICMPv6 checksum before processing the packet.
//
// Only the first view in vv is accounted for by h. To account for the
// rest of vv, a shallow copy is made and the first view is removed.
// This copy is used as extra payload during the checksum calculation.
payload := pkt.Data.Clone(nil)
payload.RemoveFirst()
payload.TrimFront(len(h))
if got, want := h.Checksum(), header.ICMPv6Checksum(h, iph.SourceAddress(), iph.DestinationAddress(), payload); got != want {
received.Invalid.Increment()
return
@ -101,34 +107,40 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P
switch h.Type() {
case header.ICMPv6PacketTooBig:
received.PacketTooBig.Increment()
if len(v) < header.ICMPv6PacketTooBigMinimumSize {
hdr, ok := pkt.Data.PullUp(header.ICMPv6PacketTooBigMinimumSize)
if !ok {
received.Invalid.Increment()
return
}
pkt.Data.TrimFront(header.ICMPv6PacketTooBigMinimumSize)
mtu := h.MTU()
mtu := header.ICMPv6(hdr).MTU()
e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), pkt)
case header.ICMPv6DstUnreachable:
received.DstUnreachable.Increment()
if len(v) < header.ICMPv6DstUnreachableMinimumSize {
hdr, ok := pkt.Data.PullUp(header.ICMPv6DstUnreachableMinimumSize)
if !ok {
received.Invalid.Increment()
return
}
pkt.Data.TrimFront(header.ICMPv6DstUnreachableMinimumSize)
switch h.Code() {
switch header.ICMPv6(hdr).Code() {
case header.ICMPv6PortUnreachable:
e.handleControl(stack.ControlPortUnreachable, 0, pkt)
}
case header.ICMPv6NeighborSolicit:
received.NeighborSolicit.Increment()
if len(v) < header.ICMPv6NeighborSolicitMinimumSize || !isNDPValid() {
if pkt.Data.Size() < header.ICMPv6NeighborSolicitMinimumSize || !isNDPValid() {
received.Invalid.Increment()
return
}
ns := header.NDPNeighborSolicit(h.NDPPayload())
// The remainder of payload must be only the neighbor solicitation, so
// payload.ToView() always returns the solicitation. Per RFC 6980 section 5,
// NDP messages cannot be fragmented. Also note that in the common case NDP
// datagrams are very small and ToView() will not incur allocations.
ns := header.NDPNeighborSolicit(payload.ToView())
it, err := ns.Options().Iter(true)
if err != nil {
// If we have a malformed NDP NS option, drop the packet.
@ -286,12 +298,16 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P
case header.ICMPv6NeighborAdvert:
received.NeighborAdvert.Increment()
if len(v) < header.ICMPv6NeighborAdvertSize || !isNDPValid() {
if pkt.Data.Size() < header.ICMPv6NeighborAdvertSize || !isNDPValid() {
received.Invalid.Increment()
return
}
na := header.NDPNeighborAdvert(h.NDPPayload())
// The remainder of payload must be only the neighbor advertisement, so
// payload.ToView() always returns the advertisement. Per RFC 6980 section
// 5, NDP messages cannot be fragmented. Also note that in the common case
// NDP datagrams are very small and ToView() will not incur allocations.
na := header.NDPNeighborAdvert(payload.ToView())
it, err := na.Options().Iter(true)
if err != nil {
// If we have a malformed NDP NA option, drop the packet.
@ -363,14 +379,15 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P
case header.ICMPv6EchoRequest:
received.EchoRequest.Increment()
if len(v) < header.ICMPv6EchoMinimumSize {
icmpHdr, ok := pkt.Data.PullUp(header.ICMPv6EchoMinimumSize)
if !ok {
received.Invalid.Increment()
return
}
pkt.Data.TrimFront(header.ICMPv6EchoMinimumSize)
hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize)
packet := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize))
copy(packet, h)
copy(packet, icmpHdr)
packet.SetType(header.ICMPv6EchoReply)
packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, pkt.Data))
if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{
@ -384,7 +401,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P
case header.ICMPv6EchoReply:
received.EchoReply.Increment()
if len(v) < header.ICMPv6EchoMinimumSize {
if pkt.Data.Size() < header.ICMPv6EchoMinimumSize {
received.Invalid.Increment()
return
}
@ -406,8 +423,9 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P
case header.ICMPv6RouterAdvert:
received.RouterAdvert.Increment()
p := h.NDPPayload()
if len(p) < header.NDPRAMinimumSize || !isNDPValid() {
// Is the NDP payload of sufficient size to hold a Router
// Advertisement?
if pkt.Data.Size()-header.ICMPv6HeaderSize < header.NDPRAMinimumSize || !isNDPValid() {
received.Invalid.Increment()
return
}
@ -425,7 +443,11 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P
return
}
ra := header.NDPRouterAdvert(p)
// The remainder of payload must be only the router advertisement, so
// payload.ToView() always returns the advertisement. Per RFC 6980 section
// 5, NDP messages cannot be fragmented. Also note that in the common case
// NDP datagrams are very small and ToView() will not incur allocations.
ra := header.NDPRouterAdvert(payload.ToView())
opts := ra.Options()
// Are options valid as per the wire format?

View File

@ -166,7 +166,8 @@ func TestICMPCounts(t *testing.T) {
},
{
typ: header.ICMPv6NeighborSolicit,
size: header.ICMPv6NeighborSolicitMinimumSize},
size: header.ICMPv6NeighborSolicitMinimumSize,
},
{
typ: header.ICMPv6NeighborAdvert,
size: header.ICMPv6NeighborAdvertMinimumSize,

View File

@ -171,7 +171,11 @@ func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffe
// HandlePacket is called by the link layer when new ipv6 packets arrive for
// this endpoint.
func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) {
headerView := pkt.Data.First()
headerView, ok := pkt.Data.PullUp(header.IPv6MinimumSize)
if !ok {
r.Stats().IP.MalformedPacketsReceived.Increment()
return
}
h := header.IPv6(headerView)
if !h.IsValid(pkt.Data.Size()) {
r.Stats().IP.MalformedPacketsReceived.Increment()

View File

@ -70,7 +70,10 @@ func (f *fwdTestNetworkEndpoint) ID() *NetworkEndpointID {
func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt PacketBuffer) {
// Consume the network header.
b := pkt.Data.First()
b, ok := pkt.Data.PullUp(fwdTestNetHeaderLen)
if !ok {
return
}
pkt.Data.TrimFront(fwdTestNetHeaderLen)
// Dispatch the packet to the transport protocol.
@ -473,7 +476,7 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) {
t.Fatal("packet not forwarded")
}
b := p.Pkt.Header.View()
b := p.Pkt.Data.ToView()
if b[0] != 3 {
t.Fatalf("got b[0] = %d, want = 3", b[0])
}
@ -517,7 +520,7 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) {
t.Fatal("packet not forwarded")
}
b := p.Pkt.Header.View()
b := p.Pkt.Data.ToView()
if b[0] != 3 {
t.Fatalf("got b[0] = %d, want = 3", b[0])
}
@ -564,7 +567,7 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) {
t.Fatal("packet not forwarded")
}
b := p.Pkt.Header.View()
b := p.Pkt.Data.ToView()
if b[0] != 3 {
t.Fatalf("got b[0] = %d, want = 3", b[0])
}
@ -619,7 +622,7 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) {
// The first 5 packets (address 3 to 7) should not be forwarded
// because their address resolutions are interrupted.
b := p.Pkt.Header.View()
b := p.Pkt.Data.ToView()
if b[0] < 8 {
t.Fatalf("got b[0] = %d, want b[0] >= 8", b[0])
}

View File

@ -212,6 +212,11 @@ func (it *IPTables) Check(hook Hook, pkt PacketBuffer) bool {
// CheckPackets runs pkts through the rules for hook and returns a map of packets that
// should not go forward.
//
// Precondition: pkt is an IPv4 packet of at least length header.IPv4MinimumSize.
//
// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a
// precondition.
//
// NOTE: unlike the Check API the returned map contains packets that should be
// dropped.
func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList) (drop map[*PacketBuffer]struct{}) {
@ -226,7 +231,9 @@ func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList) (drop map[*Pa
return drop
}
// Precondition: pkt.NetworkHeader is set.
// Precondition: pkt is an IPv4 packet of at least length header.IPv4MinimumSize.
// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a
// precondition.
func (it *IPTables) checkChain(hook Hook, pkt PacketBuffer, table Table, ruleIdx int) chainVerdict {
// Start from ruleIdx and walk the list of rules until a rule gives us
// a verdict.
@ -271,14 +278,21 @@ func (it *IPTables) checkChain(hook Hook, pkt PacketBuffer, table Table, ruleIdx
return chainDrop
}
// Precondition: pk.NetworkHeader is set.
// Precondition: pkt is an IPv4 packet of at least length header.IPv4MinimumSize.
// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a
// precondition.
func (it *IPTables) checkRule(hook Hook, pkt PacketBuffer, table Table, ruleIdx int) (RuleVerdict, int) {
rule := table.Rules[ruleIdx]
// If pkt.NetworkHeader hasn't been set yet, it will be contained in
// pkt.Data.First().
// pkt.Data.
if pkt.NetworkHeader == nil {
pkt.NetworkHeader = pkt.Data.First()
var ok bool
pkt.NetworkHeader, ok = pkt.Data.PullUp(header.IPv4MinimumSize)
if !ok {
// Precondition has been violated.
panic(fmt.Sprintf("iptables checks require IPv4 headers of at least %d bytes", header.IPv4MinimumSize))
}
}
// Check whether the packet matches the IP header filter.

View File

@ -96,9 +96,12 @@ func (rt RedirectTarget) Action(pkt PacketBuffer) (RuleVerdict, int) {
newPkt := pkt.Clone()
// Set network header.
headerView := newPkt.Data.First()
headerView, ok := newPkt.Data.PullUp(header.IPv4MinimumSize)
if !ok {
return RuleDrop, 0
}
netHeader := header.IPv4(headerView)
newPkt.NetworkHeader = headerView[:header.IPv4MinimumSize]
newPkt.NetworkHeader = headerView
hlen := int(netHeader.HeaderLength())
tlen := int(netHeader.TotalLength())
@ -117,10 +120,14 @@ func (rt RedirectTarget) Action(pkt PacketBuffer) (RuleVerdict, int) {
if newPkt.TransportHeader != nil {
udpHeader = header.UDP(newPkt.TransportHeader)
} else {
if len(pkt.Data.First()) < header.UDPMinimumSize {
if pkt.Data.Size() < header.UDPMinimumSize {
return RuleDrop, 0
}
udpHeader = header.UDP(newPkt.Data.First())
hdr, ok := newPkt.Data.PullUp(header.UDPMinimumSize)
if !ok {
return RuleDrop, 0
}
udpHeader = header.UDP(hdr)
}
udpHeader.SetDestinationPort(rt.MinPort)
case header.TCPProtocolNumber:
@ -128,10 +135,14 @@ func (rt RedirectTarget) Action(pkt PacketBuffer) (RuleVerdict, int) {
if newPkt.TransportHeader != nil {
tcpHeader = header.TCP(newPkt.TransportHeader)
} else {
if len(pkt.Data.First()) < header.TCPMinimumSize {
if pkt.Data.Size() < header.TCPMinimumSize {
return RuleDrop, 0
}
tcpHeader = header.TCP(newPkt.TransportHeader)
hdr, ok := newPkt.Data.PullUp(header.TCPMinimumSize)
if !ok {
return RuleDrop, 0
}
tcpHeader = header.TCP(hdr)
}
// TODO(gvisor.dev/issue/170): Need to recompute checksum
// and implement nat connection tracking to support TCP.

View File

@ -1203,12 +1203,12 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link
n.stack.stats.IP.PacketsReceived.Increment()
}
if len(pkt.Data.First()) < netProto.MinimumPacketSize() {
netHeader, ok := pkt.Data.PullUp(netProto.MinimumPacketSize())
if !ok {
n.stack.stats.MalformedRcvdPackets.Increment()
return
}
src, dst := netProto.ParseAddresses(pkt.Data.First())
src, dst := netProto.ParseAddresses(netHeader)
if n.stack.handleLocal && !n.isLoopback() && n.getRef(protocol, src) != nil {
// The source address is one of our own, so we never should have gotten a
@ -1289,22 +1289,8 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link
func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) {
// TODO(b/143425874) Decrease the TTL field in forwarded packets.
firstData := pkt.Data.First()
pkt.Data.RemoveFirst()
if linkHeaderLen := int(n.linkEP.MaxHeaderLength()); linkHeaderLen == 0 {
pkt.Header = buffer.NewPrependableFromView(firstData)
} else {
firstDataLen := len(firstData)
// pkt.Header should have enough capacity to hold n.linkEP's headers.
pkt.Header = buffer.NewPrependable(firstDataLen + linkHeaderLen)
// TODO(b/151227689): avoid copying the packet when forwarding
if n := copy(pkt.Header.Prepend(firstDataLen), firstData); n != firstDataLen {
panic(fmt.Sprintf("copied %d bytes, expected %d", n, firstDataLen))
}
if linkHeaderLen := int(n.linkEP.MaxHeaderLength()); linkHeaderLen != 0 {
pkt.Header = buffer.NewPrependable(linkHeaderLen)
}
if err := n.linkEP.WritePacket(r, nil /* gso */, protocol, pkt); err != nil {
@ -1332,12 +1318,13 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN
// validly formed.
n.stack.demux.deliverRawPacket(r, protocol, pkt)
if len(pkt.Data.First()) < transProto.MinimumPacketSize() {
transHeader, ok := pkt.Data.PullUp(transProto.MinimumPacketSize())
if !ok {
n.stack.stats.MalformedRcvdPackets.Increment()
return
}
srcPort, dstPort, err := transProto.ParsePorts(pkt.Data.First())
srcPort, dstPort, err := transProto.ParsePorts(transHeader)
if err != nil {
n.stack.stats.MalformedRcvdPackets.Increment()
return
@ -1375,11 +1362,12 @@ func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcp
// ICMPv4 only guarantees that 8 bytes of the transport protocol will
// be present in the payload. We know that the ports are within the
// first 8 bytes for all known transport protocols.
if len(pkt.Data.First()) < 8 {
transHeader, ok := pkt.Data.PullUp(8)
if !ok {
return
}
srcPort, dstPort, err := transProto.ParsePorts(pkt.Data.First())
srcPort, dstPort, err := transProto.ParsePorts(transHeader)
if err != nil {
return
}

View File

@ -37,7 +37,9 @@ type PacketBuffer struct {
Data buffer.VectorisedView
// Header holds the headers of outbound packets. As a packet is passed
// down the stack, each layer adds to Header.
// down the stack, each layer adds to Header. Note that forwarded
// packets don't populate Header on their way out -- their headers and
// payload are never parsed out and remain in Data.
Header buffer.Prependable
// These fields are used by both inbound and outbound packets. They

View File

@ -95,16 +95,18 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffe
f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++
// Consume the network header.
b := pkt.Data.First()
b, ok := pkt.Data.PullUp(fakeNetHeaderLen)
if !ok {
return
}
pkt.Data.TrimFront(fakeNetHeaderLen)
// Handle control packets.
if b[2] == uint8(fakeControlProtocol) {
nb := pkt.Data.First()
if len(nb) < fakeNetHeaderLen {
nb, ok := pkt.Data.PullUp(fakeNetHeaderLen)
if !ok {
return
}
pkt.Data.TrimFront(fakeNetHeaderLen)
f.dispatcher.DeliverTransportControlPacket(tcpip.Address(nb[1:2]), tcpip.Address(nb[0:1]), fakeNetNumber, tcpip.TransportProtocolNumber(nb[2]), stack.ControlPortUnreachable, 0, pkt)
return

View File

@ -642,10 +642,11 @@ func TestTransportForwarding(t *testing.T) {
t.Fatal("Response packet not forwarded")
}
if dst := p.Pkt.Header.View()[0]; dst != 3 {
hdrs := p.Pkt.Data.ToView()
if dst := hdrs[0]; dst != 3 {
t.Errorf("Response packet has incorrect destination addresss: got = %d, want = 3", dst)
}
if src := p.Pkt.Header.View()[1]; src != 1 {
if src := hdrs[1]; src != 1 {
t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src)
}
}

View File

@ -747,15 +747,15 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
// Only accept echo replies.
switch e.NetProto {
case header.IPv4ProtocolNumber:
h := header.ICMPv4(pkt.Data.First())
if h.Type() != header.ICMPv4EchoReply {
h, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize)
if !ok || header.ICMPv4(h).Type() != header.ICMPv4EchoReply {
e.stack.Stats().DroppedPackets.Increment()
e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
return
}
case header.IPv6ProtocolNumber:
h := header.ICMPv6(pkt.Data.First())
if h.Type() != header.ICMPv6EchoReply {
h, ok := pkt.Data.PullUp(header.ICMPv6MinimumSize)
if !ok || header.ICMPv6(h).Type() != header.ICMPv6EchoReply {
e.stack.Stats().DroppedPackets.Increment()
e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
return

View File

@ -144,7 +144,11 @@ func (s *segment) logicalLen() seqnum.Size {
// TCP checksum and stores the checksum and result of checksum verification in
// the csum and csumValid fields of the segment.
func (s *segment) parse() bool {
h := header.TCP(s.data.First())
h, ok := s.data.PullUp(header.TCPMinimumSize)
if !ok {
return false
}
hdr := header.TCP(h)
// h is the header followed by the payload. We check that the offset to
// the data respects the following constraints:
@ -156,12 +160,16 @@ func (s *segment) parse() bool {
// N.B. The segment has already been validated as having at least the
// minimum TCP size before reaching here, so it's safe to read the
// fields.
offset := int(h.DataOffset())
if offset < header.TCPMinimumSize || offset > len(h) {
offset := int(hdr.DataOffset())
if offset < header.TCPMinimumSize {
return false
}
hdrWithOpts, ok := s.data.PullUp(offset)
if !ok {
return false
}
s.options = []byte(h[header.TCPMinimumSize:offset])
s.options = []byte(hdrWithOpts[header.TCPMinimumSize:])
s.parsedOptions = header.ParseTCPOptions(s.options)
// Query the link capabilities to decide if checksum validation is
@ -173,18 +181,19 @@ func (s *segment) parse() bool {
s.data.TrimFront(offset)
}
if verifyChecksum {
s.csum = h.Checksum()
hdr = header.TCP(hdrWithOpts)
s.csum = hdr.Checksum()
xsum := s.route.PseudoHeaderChecksum(ProtocolNumber, uint16(s.data.Size()))
xsum = h.CalculateChecksum(xsum)
xsum = hdr.CalculateChecksum(xsum)
s.data.TrimFront(offset)
xsum = header.ChecksumVV(s.data, xsum)
s.csumValid = xsum == 0xffff
}
s.sequenceNumber = seqnum.Value(h.SequenceNumber())
s.ackNumber = seqnum.Value(h.AckNumber())
s.flags = h.Flags()
s.window = seqnum.Size(h.WindowSize())
s.sequenceNumber = seqnum.Value(hdr.SequenceNumber())
s.ackNumber = seqnum.Value(hdr.AckNumber())
s.flags = hdr.Flags()
s.window = seqnum.Size(hdr.WindowSize())
return true
}

View File

@ -3548,7 +3548,7 @@ func TestReceivedInvalidSegmentCountIncrement(t *testing.T) {
AckNum: c.IRS.Add(1),
RcvWnd: 30000,
})
tcpbuf := vv.First()[header.IPv4MinimumSize:]
tcpbuf := vv.ToView()[header.IPv4MinimumSize:]
tcpbuf[header.TCPDataOffset] = ((header.TCPMinimumSize - 1) / 4) << 4
c.SendSegment(vv)
@ -3575,7 +3575,7 @@ func TestReceivedIncorrectChecksumIncrement(t *testing.T) {
AckNum: c.IRS.Add(1),
RcvWnd: 30000,
})
tcpbuf := vv.First()[header.IPv4MinimumSize:]
tcpbuf := vv.ToView()[header.IPv4MinimumSize:]
// Overwrite a byte in the payload which should cause checksum
// verification to fail.
tcpbuf[(tcpbuf[header.TCPDataOffset]>>4)*4] = 0x4

View File

@ -1250,8 +1250,8 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// endpoint.
func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) {
// Get the header then trim it from the view.
hdr := header.UDP(pkt.Data.First())
if int(hdr.Length()) > pkt.Data.Size() {
hdr, ok := pkt.Data.PullUp(header.UDPMinimumSize)
if !ok || int(header.UDP(hdr).Length()) > pkt.Data.Size() {
// Malformed packet.
e.stack.Stats().UDP.MalformedPacketsReceived.Increment()
e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
@ -1286,7 +1286,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
senderAddress: tcpip.FullAddress{
NIC: r.NICID(),
Addr: id.RemoteAddress,
Port: hdr.SourcePort(),
Port: header.UDP(hdr).SourcePort(),
},
}
packet.data = pkt.Data

View File

@ -68,8 +68,13 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
// that don't match any existing endpoint.
func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool {
// Get the header then trim it from the view.
hdr := header.UDP(pkt.Data.First())
if int(hdr.Length()) > pkt.Data.Size() {
h, ok := pkt.Data.PullUp(header.UDPMinimumSize)
if !ok {
// Malformed packet.
r.Stack().Stats().UDP.MalformedPacketsReceived.Increment()
return true
}
if int(header.UDP(h).Length()) > pkt.Data.Size() {
// Malformed packet.
r.Stack().Stats().UDP.MalformedPacketsReceived.Increment()
return true