Make fragmentation return a reassembled PacketBuffer
This allows later decoupling of the backing network buffer implementation. PiperOrigin-RevId: 354643297
This commit is contained in:
parent
45fe9fe9c6
commit
825c185dc5
|
@ -7,6 +7,7 @@ go_library(
|
|||
srcs = [
|
||||
"prependable.go",
|
||||
"view.go",
|
||||
"view_unsafe.go",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
|
|
@ -239,6 +239,16 @@ func (vv *VectorisedView) Size() int {
|
|||
return vv.size
|
||||
}
|
||||
|
||||
// MemSize returns the estimation size of the vv in memory, including backing
|
||||
// buffer data.
|
||||
func (vv *VectorisedView) MemSize() int {
|
||||
var size int
|
||||
for _, v := range vv.views {
|
||||
size += cap(v)
|
||||
}
|
||||
return size + cap(vv.views)*viewStructSize + vectorisedViewStructSize
|
||||
}
|
||||
|
||||
// ToView returns a single view containing the content of the vectorised view.
|
||||
//
|
||||
// If the vectorised view contains a single view, that view will be returned
|
||||
|
|
|
@ -20,6 +20,7 @@ import (
|
|||
"io"
|
||||
"reflect"
|
||||
"testing"
|
||||
"unsafe"
|
||||
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/buffer"
|
||||
|
@ -578,3 +579,15 @@ func TestAppendView(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMemSize(t *testing.T) {
|
||||
const perViewCap = 128
|
||||
views := make([]buffer.View, 2, 32)
|
||||
views[0] = make(buffer.View, 10, perViewCap)
|
||||
views[1] = make(buffer.View, 20, perViewCap)
|
||||
vv := buffer.NewVectorisedView(30, views)
|
||||
want := int(unsafe.Sizeof(vv)) + cap(views)*int(unsafe.Sizeof(views)) + 2*perViewCap
|
||||
if got := vv.MemSize(); got != want {
|
||||
t.Errorf("vv.MemSize() = %d, want %d", got, want)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,22 @@
|
|||
// Copyright 2021 The gVisor Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package buffer
|
||||
|
||||
import "unsafe"
|
||||
|
||||
const (
|
||||
vectorisedViewStructSize = int(unsafe.Sizeof(VectorisedView{}))
|
||||
viewStructSize = int(unsafe.Sizeof(View{}))
|
||||
)
|
|
@ -84,7 +84,7 @@ type Fragmentation struct {
|
|||
lowLimit int
|
||||
reassemblers map[FragmentID]*reassembler
|
||||
rList reassemblerList
|
||||
size int
|
||||
memSize int
|
||||
timeout time.Duration
|
||||
blockSize uint16
|
||||
clock tcpip.Clock
|
||||
|
@ -156,22 +156,22 @@ func NewFragmentation(blockSize uint16, highMemoryLimit, lowMemoryLimit int, rea
|
|||
// the protocol to identify a fragment.
|
||||
func (f *Fragmentation) Process(
|
||||
id FragmentID, first, last uint16, more bool, proto uint8, pkt *stack.PacketBuffer) (
|
||||
buffer.VectorisedView, uint8, bool, error) {
|
||||
*stack.PacketBuffer, uint8, bool, error) {
|
||||
if first > last {
|
||||
return buffer.VectorisedView{}, 0, false, fmt.Errorf("first=%d is greater than last=%d: %w", first, last, ErrInvalidArgs)
|
||||
return nil, 0, false, fmt.Errorf("first=%d is greater than last=%d: %w", first, last, ErrInvalidArgs)
|
||||
}
|
||||
|
||||
if first%f.blockSize != 0 {
|
||||
return buffer.VectorisedView{}, 0, false, fmt.Errorf("first=%d is not a multiple of block size=%d: %w", first, f.blockSize, ErrInvalidArgs)
|
||||
return nil, 0, false, fmt.Errorf("first=%d is not a multiple of block size=%d: %w", first, f.blockSize, ErrInvalidArgs)
|
||||
}
|
||||
|
||||
fragmentSize := last - first + 1
|
||||
if more && fragmentSize%f.blockSize != 0 {
|
||||
return buffer.VectorisedView{}, 0, false, fmt.Errorf("fragment size=%d bytes is not a multiple of block size=%d on non-final fragment: %w", fragmentSize, f.blockSize, ErrInvalidArgs)
|
||||
return nil, 0, false, fmt.Errorf("fragment size=%d bytes is not a multiple of block size=%d on non-final fragment: %w", fragmentSize, f.blockSize, ErrInvalidArgs)
|
||||
}
|
||||
|
||||
if l := pkt.Data.Size(); l != int(fragmentSize) {
|
||||
return buffer.VectorisedView{}, 0, false, fmt.Errorf("got fragment size=%d bytes not equal to the expected fragment size=%d bytes (first=%d last=%d): %w", l, fragmentSize, first, last, ErrInvalidArgs)
|
||||
return nil, 0, false, fmt.Errorf("got fragment size=%d bytes not equal to the expected fragment size=%d bytes (first=%d last=%d): %w", l, fragmentSize, first, last, ErrInvalidArgs)
|
||||
}
|
||||
|
||||
f.mu.Lock()
|
||||
|
@ -190,24 +190,24 @@ func (f *Fragmentation) Process(
|
|||
}
|
||||
f.mu.Unlock()
|
||||
|
||||
res, firstFragmentProto, done, consumed, err := r.process(first, last, more, proto, pkt)
|
||||
resPkt, firstFragmentProto, done, memConsumed, err := r.process(first, last, more, proto, pkt)
|
||||
if err != nil {
|
||||
// We probably got an invalid sequence of fragments. Just
|
||||
// discard the reassembler and move on.
|
||||
f.mu.Lock()
|
||||
f.release(r, false /* timedOut */)
|
||||
f.mu.Unlock()
|
||||
return buffer.VectorisedView{}, 0, false, fmt.Errorf("fragmentation processing error: %w", err)
|
||||
return nil, 0, false, fmt.Errorf("fragmentation processing error: %w", err)
|
||||
}
|
||||
f.mu.Lock()
|
||||
f.size += consumed
|
||||
f.memSize += memConsumed
|
||||
if done {
|
||||
f.release(r, false /* timedOut */)
|
||||
}
|
||||
// Evict reassemblers if we are consuming more memory than highLimit until
|
||||
// we reach lowLimit.
|
||||
if f.size > f.highLimit {
|
||||
for f.size > f.lowLimit {
|
||||
if f.memSize > f.highLimit {
|
||||
for f.memSize > f.lowLimit {
|
||||
tail := f.rList.Back()
|
||||
if tail == nil {
|
||||
break
|
||||
|
@ -216,7 +216,7 @@ func (f *Fragmentation) Process(
|
|||
}
|
||||
}
|
||||
f.mu.Unlock()
|
||||
return res, firstFragmentProto, done, nil
|
||||
return resPkt, firstFragmentProto, done, nil
|
||||
}
|
||||
|
||||
func (f *Fragmentation) release(r *reassembler, timedOut bool) {
|
||||
|
@ -228,10 +228,10 @@ func (f *Fragmentation) release(r *reassembler, timedOut bool) {
|
|||
|
||||
delete(f.reassemblers, r.id)
|
||||
f.rList.Remove(r)
|
||||
f.size -= r.size
|
||||
if f.size < 0 {
|
||||
log.Printf("memory counter < 0 (%d), this is an accounting bug that requires investigation", f.size)
|
||||
f.size = 0
|
||||
f.memSize -= r.memSize
|
||||
if f.memSize < 0 {
|
||||
log.Printf("memory counter < 0 (%d), this is an accounting bug that requires investigation", f.memSize)
|
||||
f.memSize = 0
|
||||
}
|
||||
|
||||
if h := f.timeoutHandler; timedOut && h != nil {
|
||||
|
|
|
@ -16,7 +16,6 @@ package fragmentation
|
|||
|
||||
import (
|
||||
"errors"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -112,20 +111,20 @@ func TestFragmentationProcess(t *testing.T) {
|
|||
f := NewFragmentation(minBlockSize, 1024, 512, reassembleTimeout, &faketime.NullClock{}, nil)
|
||||
firstFragmentProto := c.in[0].proto
|
||||
for i, in := range c.in {
|
||||
vv, proto, done, err := f.Process(in.id, in.first, in.last, in.more, in.proto, in.pkt)
|
||||
resPkt, proto, done, err := f.Process(in.id, in.first, in.last, in.more, in.proto, in.pkt)
|
||||
if err != nil {
|
||||
t.Fatalf("f.Process(%+v, %d, %d, %t, %d, %#v) failed: %s",
|
||||
in.id, in.first, in.last, in.more, in.proto, in.pkt, err)
|
||||
}
|
||||
if !reflect.DeepEqual(vv, c.out[i].vv) {
|
||||
t.Errorf("got Process(%+v, %d, %d, %t, %d, %#v) = (%X, _, _, _), want = (%X, _, _, _)",
|
||||
in.id, in.first, in.last, in.more, in.proto, in.pkt, vv.ToView(), c.out[i].vv.ToView())
|
||||
}
|
||||
if done != c.out[i].done {
|
||||
t.Errorf("got Process(%+v, %d, %d, %t, %d, _) = (_, _, %t, _), want = (_, _, %t, _)",
|
||||
in.id, in.first, in.last, in.more, in.proto, done, c.out[i].done)
|
||||
}
|
||||
if c.out[i].done {
|
||||
if diff := cmp.Diff(c.out[i].vv.ToOwnedView(), resPkt.Data.ToOwnedView()); diff != "" {
|
||||
t.Errorf("got Process(%+v, %d, %d, %t, %d, %#v) result mismatch (-want, +got):\n%s",
|
||||
in.id, in.first, in.last, in.more, in.proto, in.pkt, diff)
|
||||
}
|
||||
if firstFragmentProto != proto {
|
||||
t.Errorf("got Process(%+v, %d, %d, %t, %d, _) = (_, %d, _, _), want = (_, %d, _, _)",
|
||||
in.id, in.first, in.last, in.more, in.proto, proto, firstFragmentProto)
|
||||
|
@ -173,9 +172,17 @@ func TestReassemblingTimeout(t *testing.T) {
|
|||
// reassembly is done after the fragment is processd.
|
||||
expectDone bool
|
||||
|
||||
// sizeAfterEvent is the expected size of the fragmentation instance after
|
||||
// the event.
|
||||
sizeAfterEvent int
|
||||
// memSizeAfterEvent is the expected memory size of the fragmentation
|
||||
// instance after the event.
|
||||
memSizeAfterEvent int
|
||||
}
|
||||
|
||||
memSizeOfFrags := func(frags ...*fragment) int {
|
||||
var size int
|
||||
for _, frag := range frags {
|
||||
size += pkt(len(frag.data), frag.data).MemSize()
|
||||
}
|
||||
return size
|
||||
}
|
||||
|
||||
half1 := &fragment{first: 0, last: 0, more: true, data: "0"}
|
||||
|
@ -189,16 +196,16 @@ func TestReassemblingTimeout(t *testing.T) {
|
|||
name: "half1 and half2 are reassembled successfully",
|
||||
events: []event{
|
||||
{
|
||||
name: "half1",
|
||||
fragment: half1,
|
||||
expectDone: false,
|
||||
sizeAfterEvent: 1,
|
||||
name: "half1",
|
||||
fragment: half1,
|
||||
expectDone: false,
|
||||
memSizeAfterEvent: memSizeOfFrags(half1),
|
||||
},
|
||||
{
|
||||
name: "half2",
|
||||
fragment: half2,
|
||||
expectDone: true,
|
||||
sizeAfterEvent: 0,
|
||||
name: "half2",
|
||||
fragment: half2,
|
||||
expectDone: true,
|
||||
memSizeAfterEvent: 0,
|
||||
},
|
||||
},
|
||||
},
|
||||
|
@ -206,36 +213,36 @@ func TestReassemblingTimeout(t *testing.T) {
|
|||
name: "half1 timeout, half2 timeout",
|
||||
events: []event{
|
||||
{
|
||||
name: "half1",
|
||||
fragment: half1,
|
||||
expectDone: false,
|
||||
sizeAfterEvent: 1,
|
||||
name: "half1",
|
||||
fragment: half1,
|
||||
expectDone: false,
|
||||
memSizeAfterEvent: memSizeOfFrags(half1),
|
||||
},
|
||||
{
|
||||
name: "half1 just before reassembly timeout",
|
||||
clockAdvance: reassemblyTimeout - 1,
|
||||
sizeAfterEvent: 1,
|
||||
name: "half1 just before reassembly timeout",
|
||||
clockAdvance: reassemblyTimeout - 1,
|
||||
memSizeAfterEvent: memSizeOfFrags(half1),
|
||||
},
|
||||
{
|
||||
name: "half1 reassembly timeout",
|
||||
clockAdvance: 1,
|
||||
sizeAfterEvent: 0,
|
||||
name: "half1 reassembly timeout",
|
||||
clockAdvance: 1,
|
||||
memSizeAfterEvent: 0,
|
||||
},
|
||||
{
|
||||
name: "half2",
|
||||
fragment: half2,
|
||||
expectDone: false,
|
||||
sizeAfterEvent: 1,
|
||||
name: "half2",
|
||||
fragment: half2,
|
||||
expectDone: false,
|
||||
memSizeAfterEvent: memSizeOfFrags(half2),
|
||||
},
|
||||
{
|
||||
name: "half2 just before reassembly timeout",
|
||||
clockAdvance: reassemblyTimeout - 1,
|
||||
sizeAfterEvent: 1,
|
||||
name: "half2 just before reassembly timeout",
|
||||
clockAdvance: reassemblyTimeout - 1,
|
||||
memSizeAfterEvent: memSizeOfFrags(half2),
|
||||
},
|
||||
{
|
||||
name: "half2 reassembly timeout",
|
||||
clockAdvance: 1,
|
||||
sizeAfterEvent: 0,
|
||||
name: "half2 reassembly timeout",
|
||||
clockAdvance: 1,
|
||||
memSizeAfterEvent: 0,
|
||||
},
|
||||
},
|
||||
},
|
||||
|
@ -255,8 +262,8 @@ func TestReassemblingTimeout(t *testing.T) {
|
|||
t.Fatalf("%s: got done = %t, want = %t", event.name, done, event.expectDone)
|
||||
}
|
||||
}
|
||||
if got, want := f.size, event.sizeAfterEvent; got != want {
|
||||
t.Errorf("%s: got f.size = %d, want = %d", event.name, got, want)
|
||||
if got, want := f.memSize, event.memSizeAfterEvent; got != want {
|
||||
t.Errorf("%s: got f.memSize = %d, want = %d", event.name, got, want)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
@ -264,7 +271,9 @@ func TestReassemblingTimeout(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestMemoryLimits(t *testing.T) {
|
||||
f := NewFragmentation(minBlockSize, 3, 1, reassembleTimeout, &faketime.NullClock{}, nil)
|
||||
lowLimit := pkt(1, "0").MemSize()
|
||||
highLimit := 3 * lowLimit // Allow at most 3 such packets.
|
||||
f := NewFragmentation(minBlockSize, highLimit, lowLimit, reassembleTimeout, &faketime.NullClock{}, nil)
|
||||
// Send first fragment with id = 0.
|
||||
f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, pkt(1, "0"))
|
||||
// Send first fragment with id = 1.
|
||||
|
@ -288,15 +297,14 @@ func TestMemoryLimits(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestMemoryLimitsIgnoresDuplicates(t *testing.T) {
|
||||
f := NewFragmentation(minBlockSize, 1, 0, reassembleTimeout, &faketime.NullClock{}, nil)
|
||||
memSize := pkt(1, "0").MemSize()
|
||||
f := NewFragmentation(minBlockSize, memSize, 0, reassembleTimeout, &faketime.NullClock{}, nil)
|
||||
// Send first fragment with id = 0.
|
||||
f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0"))
|
||||
// Send the same packet again.
|
||||
f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0"))
|
||||
|
||||
got := f.size
|
||||
want := 1
|
||||
if got != want {
|
||||
if got, want := f.memSize, memSize; got != want {
|
||||
t.Errorf("Wrong size, duplicates are not handled correctly: got=%d, want=%d.", got, want)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -20,7 +20,6 @@ import (
|
|||
|
||||
"gvisor.dev/gvisor/pkg/sync"
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/buffer"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
)
|
||||
|
||||
|
@ -29,13 +28,15 @@ type hole struct {
|
|||
last uint16
|
||||
filled bool
|
||||
final bool
|
||||
data buffer.View
|
||||
// pkt is the fragment packet if hole is filled. We keep the whole pkt rather
|
||||
// than the fragmented payload to prevent binding to specific buffer types.
|
||||
pkt *stack.PacketBuffer
|
||||
}
|
||||
|
||||
type reassembler struct {
|
||||
reassemblerEntry
|
||||
id FragmentID
|
||||
size int
|
||||
memSize int
|
||||
proto uint8
|
||||
mu sync.Mutex
|
||||
holes []hole
|
||||
|
@ -59,18 +60,18 @@ func newReassembler(id FragmentID, clock tcpip.Clock) *reassembler {
|
|||
return r
|
||||
}
|
||||
|
||||
func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *stack.PacketBuffer) (buffer.VectorisedView, uint8, bool, int, error) {
|
||||
func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *stack.PacketBuffer) (*stack.PacketBuffer, uint8, bool, int, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
if r.done {
|
||||
// A concurrent goroutine might have already reassembled
|
||||
// the packet and emptied the heap while this goroutine
|
||||
// was waiting on the mutex. We don't have to do anything in this case.
|
||||
return buffer.VectorisedView{}, 0, false, 0, nil
|
||||
return nil, 0, false, 0, nil
|
||||
}
|
||||
|
||||
var holeFound bool
|
||||
var consumed int
|
||||
var memConsumed int
|
||||
for i := range r.holes {
|
||||
currentHole := &r.holes[i]
|
||||
|
||||
|
@ -90,12 +91,12 @@ func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *s
|
|||
// https://github.com/torvalds/linux/blob/38525c6/net/ipv4/inet_fragment.c#L349
|
||||
if first < currentHole.first || currentHole.last < last {
|
||||
// Incoming fragment only partially fits in the free hole.
|
||||
return buffer.VectorisedView{}, 0, false, 0, ErrFragmentOverlap
|
||||
return nil, 0, false, 0, ErrFragmentOverlap
|
||||
}
|
||||
if !more {
|
||||
if !currentHole.final || currentHole.filled && currentHole.last != last {
|
||||
// We have another final fragment, which does not perfectly overlap.
|
||||
return buffer.VectorisedView{}, 0, false, 0, ErrFragmentConflict
|
||||
return nil, 0, false, 0, ErrFragmentConflict
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -124,16 +125,15 @@ func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *s
|
|||
})
|
||||
currentHole.final = false
|
||||
}
|
||||
v := pkt.Data.ToOwnedView()
|
||||
consumed = v.Size()
|
||||
r.size += consumed
|
||||
memConsumed = pkt.MemSize()
|
||||
r.memSize += memConsumed
|
||||
// Update the current hole to precisely match the incoming fragment.
|
||||
r.holes[i] = hole{
|
||||
first: first,
|
||||
last: last,
|
||||
filled: true,
|
||||
final: currentHole.final,
|
||||
data: v,
|
||||
pkt: pkt,
|
||||
}
|
||||
r.filled++
|
||||
// For IPv6, it is possible to have different Protocol values between
|
||||
|
@ -153,25 +153,24 @@ func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *s
|
|||
}
|
||||
if !holeFound {
|
||||
// Incoming fragment is beyond end.
|
||||
return buffer.VectorisedView{}, 0, false, 0, ErrFragmentConflict
|
||||
return nil, 0, false, 0, ErrFragmentConflict
|
||||
}
|
||||
|
||||
// Check if all the holes have been filled and we are ready to reassemble.
|
||||
if r.filled < len(r.holes) {
|
||||
return buffer.VectorisedView{}, 0, false, consumed, nil
|
||||
return nil, 0, false, memConsumed, nil
|
||||
}
|
||||
|
||||
sort.Slice(r.holes, func(i, j int) bool {
|
||||
return r.holes[i].first < r.holes[j].first
|
||||
})
|
||||
|
||||
var size int
|
||||
views := make([]buffer.View, 0, len(r.holes))
|
||||
for _, hole := range r.holes {
|
||||
views = append(views, hole.data)
|
||||
size += hole.data.Size()
|
||||
resPkt := r.holes[0].pkt
|
||||
for i := 1; i < len(r.holes); i++ {
|
||||
fragPkt := r.holes[i].pkt
|
||||
fragPkt.Data.ReadToVV(&resPkt.Data, fragPkt.Data.Size())
|
||||
}
|
||||
return buffer.NewVectorisedView(size, views), r.proto, true, consumed, nil
|
||||
return resPkt, r.proto, true, memConsumed, nil
|
||||
}
|
||||
|
||||
func (r *reassembler) checkDoneOrMark() bool {
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
package fragmentation
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
|
@ -44,16 +45,21 @@ func TestReassemblerProcess(t *testing.T) {
|
|||
return payload
|
||||
}
|
||||
|
||||
pkt := func(size int) *stack.PacketBuffer {
|
||||
pkt := func(sizes ...int) *stack.PacketBuffer {
|
||||
var vv buffer.VectorisedView
|
||||
for _, size := range sizes {
|
||||
vv.AppendView(v(size))
|
||||
}
|
||||
return stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||
Data: v(size).ToVectorisedView(),
|
||||
Data: vv,
|
||||
})
|
||||
}
|
||||
|
||||
var tests = []struct {
|
||||
name string
|
||||
params []processParams
|
||||
want []hole
|
||||
name string
|
||||
params []processParams
|
||||
want []hole
|
||||
wantPkt *stack.PacketBuffer
|
||||
}{
|
||||
{
|
||||
name: "No fragments",
|
||||
|
@ -64,7 +70,7 @@ func TestReassemblerProcess(t *testing.T) {
|
|||
name: "One fragment at beginning",
|
||||
params: []processParams{{first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil}},
|
||||
want: []hole{
|
||||
{first: 0, last: 1, filled: true, final: false, data: v(2)},
|
||||
{first: 0, last: 1, filled: true, final: false, pkt: pkt(2)},
|
||||
{first: 2, last: math.MaxUint16, filled: false, final: true},
|
||||
},
|
||||
},
|
||||
|
@ -72,7 +78,7 @@ func TestReassemblerProcess(t *testing.T) {
|
|||
name: "One fragment in the middle",
|
||||
params: []processParams{{first: 1, last: 2, more: true, pkt: pkt(2), wantDone: false, wantError: nil}},
|
||||
want: []hole{
|
||||
{first: 1, last: 2, filled: true, final: false, data: v(2)},
|
||||
{first: 1, last: 2, filled: true, final: false, pkt: pkt(2)},
|
||||
{first: 0, last: 0, filled: false, final: false},
|
||||
{first: 3, last: math.MaxUint16, filled: false, final: true},
|
||||
},
|
||||
|
@ -81,7 +87,7 @@ func TestReassemblerProcess(t *testing.T) {
|
|||
name: "One fragment at the end",
|
||||
params: []processParams{{first: 1, last: 2, more: false, pkt: pkt(2), wantDone: false, wantError: nil}},
|
||||
want: []hole{
|
||||
{first: 1, last: 2, filled: true, final: true, data: v(2)},
|
||||
{first: 1, last: 2, filled: true, final: true, pkt: pkt(2)},
|
||||
{first: 0, last: 0, filled: false},
|
||||
},
|
||||
},
|
||||
|
@ -89,8 +95,9 @@ func TestReassemblerProcess(t *testing.T) {
|
|||
name: "One fragment completing a packet",
|
||||
params: []processParams{{first: 0, last: 1, more: false, pkt: pkt(2), wantDone: true, wantError: nil}},
|
||||
want: []hole{
|
||||
{first: 0, last: 1, filled: true, final: true, data: v(2)},
|
||||
{first: 0, last: 1, filled: true, final: true},
|
||||
},
|
||||
wantPkt: pkt(2),
|
||||
},
|
||||
{
|
||||
name: "Two fragments completing a packet",
|
||||
|
@ -99,9 +106,10 @@ func TestReassemblerProcess(t *testing.T) {
|
|||
{first: 2, last: 3, more: false, pkt: pkt(2), wantDone: true, wantError: nil},
|
||||
},
|
||||
want: []hole{
|
||||
{first: 0, last: 1, filled: true, final: false, data: v(2)},
|
||||
{first: 2, last: 3, filled: true, final: true, data: v(2)},
|
||||
{first: 0, last: 1, filled: true, final: false},
|
||||
{first: 2, last: 3, filled: true, final: true},
|
||||
},
|
||||
wantPkt: pkt(2, 2),
|
||||
},
|
||||
{
|
||||
name: "Two fragments completing a packet with a duplicate",
|
||||
|
@ -111,9 +119,10 @@ func TestReassemblerProcess(t *testing.T) {
|
|||
{first: 2, last: 3, more: false, pkt: pkt(2), wantDone: true, wantError: nil},
|
||||
},
|
||||
want: []hole{
|
||||
{first: 0, last: 1, filled: true, final: false, data: v(2)},
|
||||
{first: 2, last: 3, filled: true, final: true, data: v(2)},
|
||||
{first: 0, last: 1, filled: true, final: false},
|
||||
{first: 2, last: 3, filled: true, final: true},
|
||||
},
|
||||
wantPkt: pkt(2, 2),
|
||||
},
|
||||
{
|
||||
name: "Two fragments completing a packet with a partial duplicate",
|
||||
|
@ -123,9 +132,10 @@ func TestReassemblerProcess(t *testing.T) {
|
|||
{first: 4, last: 5, more: false, pkt: pkt(2), wantDone: true, wantError: nil},
|
||||
},
|
||||
want: []hole{
|
||||
{first: 0, last: 3, filled: true, final: false, data: v(4)},
|
||||
{first: 4, last: 5, filled: true, final: true, data: v(2)},
|
||||
{first: 0, last: 3, filled: true, final: false},
|
||||
{first: 4, last: 5, filled: true, final: true},
|
||||
},
|
||||
wantPkt: pkt(4, 2),
|
||||
},
|
||||
{
|
||||
name: "Two overlapping fragments",
|
||||
|
@ -134,7 +144,7 @@ func TestReassemblerProcess(t *testing.T) {
|
|||
{first: 5, last: 15, more: false, pkt: pkt(11), wantDone: false, wantError: ErrFragmentOverlap},
|
||||
},
|
||||
want: []hole{
|
||||
{first: 0, last: 10, filled: true, final: false, data: v(11)},
|
||||
{first: 0, last: 10, filled: true, final: false, pkt: pkt(11)},
|
||||
{first: 11, last: math.MaxUint16, filled: false, final: true},
|
||||
},
|
||||
},
|
||||
|
@ -145,7 +155,7 @@ func TestReassemblerProcess(t *testing.T) {
|
|||
{first: 0, last: 9, more: false, pkt: pkt(10), wantDone: false, wantError: ErrFragmentConflict},
|
||||
},
|
||||
want: []hole{
|
||||
{first: 10, last: 14, filled: true, final: true, data: v(5)},
|
||||
{first: 10, last: 14, filled: true, final: true, pkt: pkt(5)},
|
||||
{first: 0, last: 9, filled: false, final: false},
|
||||
},
|
||||
},
|
||||
|
@ -156,7 +166,7 @@ func TestReassemblerProcess(t *testing.T) {
|
|||
{first: 10, last: 14, more: false, pkt: pkt(5), wantDone: false, wantError: nil},
|
||||
},
|
||||
want: []hole{
|
||||
{first: 5, last: 14, filled: true, final: true, data: v(10)},
|
||||
{first: 5, last: 14, filled: true, final: true, pkt: pkt(10)},
|
||||
{first: 0, last: 4, filled: false, final: false},
|
||||
},
|
||||
},
|
||||
|
@ -167,7 +177,7 @@ func TestReassemblerProcess(t *testing.T) {
|
|||
{first: 10, last: 13, more: false, pkt: pkt(4), wantDone: false, wantError: ErrFragmentConflict},
|
||||
},
|
||||
want: []hole{
|
||||
{first: 5, last: 14, filled: true, final: true, data: v(10)},
|
||||
{first: 5, last: 14, filled: true, final: true, pkt: pkt(10)},
|
||||
{first: 0, last: 4, filled: false, final: false},
|
||||
},
|
||||
},
|
||||
|
@ -176,14 +186,47 @@ func TestReassemblerProcess(t *testing.T) {
|
|||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
r := newReassembler(FragmentID{}, &faketime.NullClock{})
|
||||
var resPkt *stack.PacketBuffer
|
||||
var isDone bool
|
||||
for _, param := range test.params {
|
||||
_, _, done, _, err := r.process(param.first, param.last, param.more, proto, param.pkt)
|
||||
pkt, _, done, _, err := r.process(param.first, param.last, param.more, proto, param.pkt)
|
||||
if done != param.wantDone || err != param.wantError {
|
||||
t.Errorf("got r.process(%d, %d, %t, %d, _) = (_, _, %t, _, %v), want = (%t, %v)", param.first, param.last, param.more, proto, done, err, param.wantDone, param.wantError)
|
||||
}
|
||||
if done {
|
||||
resPkt = pkt
|
||||
isDone = true
|
||||
}
|
||||
}
|
||||
if diff := cmp.Diff(test.want, r.holes, cmp.AllowUnexported(hole{})); diff != "" {
|
||||
t.Errorf("r.holes mismatch (-want +got):\n%s", diff)
|
||||
|
||||
ignorePkt := func(a, b *stack.PacketBuffer) bool { return true }
|
||||
cmpPktData := func(a, b *stack.PacketBuffer) bool {
|
||||
if a == nil || b == nil {
|
||||
return a == b
|
||||
}
|
||||
return bytes.Equal(a.Data.ToOwnedView(), b.Data.ToOwnedView())
|
||||
}
|
||||
|
||||
if isDone {
|
||||
if diff := cmp.Diff(
|
||||
test.want, r.holes,
|
||||
cmp.AllowUnexported(hole{}),
|
||||
// Do not compare pkt in hole. Data will be altered.
|
||||
cmp.Comparer(ignorePkt),
|
||||
); diff != "" {
|
||||
t.Errorf("r.holes mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
if diff := cmp.Diff(test.wantPkt, resPkt, cmp.Comparer(cmpPktData)); diff != "" {
|
||||
t.Errorf("Reassembled pkt mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
} else {
|
||||
if diff := cmp.Diff(
|
||||
test.want, r.holes,
|
||||
cmp.AllowUnexported(hole{}),
|
||||
cmp.Comparer(cmpPktData),
|
||||
); diff != "" {
|
||||
t.Errorf("r.holes mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
@ -44,6 +44,7 @@ go_test(
|
|||
"//pkg/tcpip/network/testutil",
|
||||
"//pkg/tcpip/stack",
|
||||
"//pkg/tcpip/transport/icmp",
|
||||
"//pkg/tcpip/transport/raw",
|
||||
"//pkg/tcpip/transport/tcp",
|
||||
"//pkg/tcpip/transport/udp",
|
||||
"//pkg/waiter",
|
||||
|
|
|
@ -740,7 +740,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
|
|||
}
|
||||
|
||||
proto := h.Protocol()
|
||||
data, _, ready, err := e.protocol.fragmentation.Process(
|
||||
resPkt, _, ready, err := e.protocol.fragmentation.Process(
|
||||
// As per RFC 791 section 2.3, the identification value is unique
|
||||
// for a source-destination pair and protocol.
|
||||
fragmentation.FragmentID{
|
||||
|
@ -763,7 +763,8 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
|
|||
if !ready {
|
||||
return
|
||||
}
|
||||
pkt.Data = data
|
||||
pkt = resPkt
|
||||
h = header.IPv4(pkt.NetworkHeader().View())
|
||||
|
||||
// The reassembler doesn't take care of fixing up the header, so we need
|
||||
// to do it here.
|
||||
|
|
|
@ -38,6 +38,7 @@ import (
|
|||
"gvisor.dev/gvisor/pkg/tcpip/network/testutil"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/raw"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
||||
"gvisor.dev/gvisor/pkg/waiter"
|
||||
|
@ -2058,7 +2059,7 @@ func TestReceiveFragments(t *testing.T) {
|
|||
// the fragment block size of 8 (RFC 791 section 3.1 page 14).
|
||||
ipv4Payload3Addr1ToAddr2 := udpGen(127, 3, addr1, addr2)
|
||||
udpPayload3Addr1ToAddr2 := ipv4Payload3Addr1ToAddr2[header.UDPMinimumSize:]
|
||||
// Used to test the max reassembled payload length (65,535 octets).
|
||||
// Used to test the max reassembled IPv4 payload length.
|
||||
ipv4Payload4Addr1ToAddr2 := udpGen(header.UDPMaximumSize-header.UDPMinimumSize, 4, addr1, addr2)
|
||||
udpPayload4Addr1ToAddr2 := ipv4Payload4Addr1ToAddr2[header.UDPMinimumSize:]
|
||||
|
||||
|
@ -2406,6 +2407,7 @@ func TestReceiveFragments(t *testing.T) {
|
|||
s := stack.New(stack.Options{
|
||||
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
|
||||
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
|
||||
RawFactory: raw.EndpointFactory{},
|
||||
})
|
||||
e := channel.New(0, 1280, tcpip.LinkAddress("\xf0\x00"))
|
||||
if err := s.CreateNIC(nicID, e); err != nil {
|
||||
|
@ -2431,6 +2433,13 @@ func TestReceiveFragments(t *testing.T) {
|
|||
t.Fatalf("Bind(%+v): %s", bindAddr, err)
|
||||
}
|
||||
|
||||
// Bring up a raw endpoint so we can examine network headers.
|
||||
epRaw, err := s.NewRawEndpoint(udp.ProtocolNumber, header.IPv4ProtocolNumber, &wq, true /* associated */)
|
||||
if err != nil {
|
||||
t.Fatalf("NewRawEndpoint(%d, %d, _, true): %s", udp.ProtocolNumber, header.IPv4ProtocolNumber, err)
|
||||
}
|
||||
defer epRaw.Close()
|
||||
|
||||
// Prepare and send the fragments.
|
||||
for _, frag := range test.fragments {
|
||||
hdr := buffer.NewPrependable(header.IPv4MinimumSize)
|
||||
|
@ -2462,10 +2471,11 @@ func TestReceiveFragments(t *testing.T) {
|
|||
}
|
||||
|
||||
for i, expectedPayload := range test.expectedPayloads {
|
||||
// Check UDP payload delivered by UDP endpoint.
|
||||
var buf bytes.Buffer
|
||||
result, err := ep.Read(&buf, tcpip.ReadOptions{})
|
||||
if err != nil {
|
||||
t.Fatalf("(i=%d) Read: %s", i, err)
|
||||
t.Fatalf("(i=%d) ep.Read: %s", i, err)
|
||||
}
|
||||
if diff := cmp.Diff(tcpip.ReadResult{
|
||||
Count: len(expectedPayload),
|
||||
|
@ -2474,7 +2484,24 @@ func TestReceiveFragments(t *testing.T) {
|
|||
t.Errorf("(i=%d) ep.Read: unexpected result (-want +got):\n%s", i, diff)
|
||||
}
|
||||
if diff := cmp.Diff(expectedPayload, buf.Bytes()); diff != "" {
|
||||
t.Errorf("(i=%d) got UDP payload mismatch (-want +got):\n%s", i, diff)
|
||||
t.Errorf("(i=%d) ep.Read: UDP payload mismatch (-want +got):\n%s", i, diff)
|
||||
}
|
||||
|
||||
// Check IPv4 header in packet delivered by raw endpoint.
|
||||
buf.Reset()
|
||||
result, err = epRaw.Read(&buf, tcpip.ReadOptions{})
|
||||
if err != nil {
|
||||
t.Fatalf("(i=%d) epRaw.Read: %s", i, err)
|
||||
}
|
||||
// Reassambly does not take care of checksum. Here we write our own
|
||||
// check routine instead of using checker.IPv4.
|
||||
ip := header.IPv4(buf.Bytes())
|
||||
for _, check := range []checker.NetworkChecker{
|
||||
checker.FragmentFlags(0),
|
||||
checker.FragmentOffset(0),
|
||||
checker.IPFullLength(uint16(header.IPv4MinimumSize + header.UDPMinimumSize + len(expectedPayload))),
|
||||
} {
|
||||
check(t, []header.Network{ip})
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1167,7 +1167,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
|
|||
|
||||
// Note that pkt doesn't have its transport header set after reassembly,
|
||||
// and won't until DeliverNetworkPacket sets it.
|
||||
data, proto, ready, err := e.protocol.fragmentation.Process(
|
||||
resPkt, proto, ready, err := e.protocol.fragmentation.Process(
|
||||
// IPv6 ignores the Protocol field since the ID only needs to be unique
|
||||
// across source-destination pairs, as per RFC 8200 section 4.5.
|
||||
fragmentation.FragmentID{
|
||||
|
@ -1188,7 +1188,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
|
|||
}
|
||||
|
||||
if ready {
|
||||
pkt.Data = data
|
||||
pkt = resPkt
|
||||
|
||||
// We create a new iterator with the reassembled packet because we could
|
||||
// have more extension headers in the reassembled payload, as per RFC
|
||||
|
|
|
@ -72,6 +72,7 @@ go_library(
|
|||
"nud.go",
|
||||
"packet_buffer.go",
|
||||
"packet_buffer_list.go",
|
||||
"packet_buffer_unsafe.go",
|
||||
"pending_packets.go",
|
||||
"rand.go",
|
||||
"registration.go",
|
||||
|
|
|
@ -187,6 +187,12 @@ func (pk *PacketBuffer) Size() int {
|
|||
return pk.HeaderSize() + pk.Data.Size()
|
||||
}
|
||||
|
||||
// MemSize returns the estimation size of the pk in memory, including backing
|
||||
// buffer data.
|
||||
func (pk *PacketBuffer) MemSize() int {
|
||||
return pk.HeaderSize() + pk.Data.MemSize() + packetBufferStructSize
|
||||
}
|
||||
|
||||
// Views returns the underlying storage of the whole packet.
|
||||
func (pk *PacketBuffer) Views() []buffer.View {
|
||||
// Optimization for outbound packets that headers are in pk.header.
|
||||
|
|
|
@ -0,0 +1,19 @@
|
|||
// Copyright 2021 The gVisor Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package stack
|
||||
|
||||
import "unsafe"
|
||||
|
||||
const packetBufferStructSize = int(unsafe.Sizeof(PacketBuffer{}))
|
Loading…
Reference in New Issue