diff --git a/pkg/tcpip/buffer/BUILD b/pkg/tcpip/buffer/BUILD index c9bcf9326..23aa0ad05 100644 --- a/pkg/tcpip/buffer/BUILD +++ b/pkg/tcpip/buffer/BUILD @@ -7,6 +7,7 @@ go_library( srcs = [ "prependable.go", "view.go", + "view_unsafe.go", ], visibility = ["//visibility:public"], ) diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go index 91cc62cc8..b05e81526 100644 --- a/pkg/tcpip/buffer/view.go +++ b/pkg/tcpip/buffer/view.go @@ -239,6 +239,16 @@ func (vv *VectorisedView) Size() int { return vv.size } +// MemSize returns the estimation size of the vv in memory, including backing +// buffer data. +func (vv *VectorisedView) MemSize() int { + var size int + for _, v := range vv.views { + size += cap(v) + } + return size + cap(vv.views)*viewStructSize + vectorisedViewStructSize +} + // ToView returns a single view containing the content of the vectorised view. // // If the vectorised view contains a single view, that view will be returned diff --git a/pkg/tcpip/buffer/view_test.go b/pkg/tcpip/buffer/view_test.go index e7f7cc9f1..78b2faa26 100644 --- a/pkg/tcpip/buffer/view_test.go +++ b/pkg/tcpip/buffer/view_test.go @@ -20,6 +20,7 @@ import ( "io" "reflect" "testing" + "unsafe" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -578,3 +579,15 @@ func TestAppendView(t *testing.T) { } } } + +func TestMemSize(t *testing.T) { + const perViewCap = 128 + views := make([]buffer.View, 2, 32) + views[0] = make(buffer.View, 10, perViewCap) + views[1] = make(buffer.View, 20, perViewCap) + vv := buffer.NewVectorisedView(30, views) + want := int(unsafe.Sizeof(vv)) + cap(views)*int(unsafe.Sizeof(views)) + 2*perViewCap + if got := vv.MemSize(); got != want { + t.Errorf("vv.MemSize() = %d, want %d", got, want) + } +} diff --git a/pkg/tcpip/buffer/view_unsafe.go b/pkg/tcpip/buffer/view_unsafe.go new file mode 100644 index 000000000..75ccd40f8 --- /dev/null +++ b/pkg/tcpip/buffer/view_unsafe.go @@ -0,0 +1,22 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package buffer + +import "unsafe" + +const ( + vectorisedViewStructSize = int(unsafe.Sizeof(VectorisedView{})) + viewStructSize = int(unsafe.Sizeof(View{})) +) diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go index 1af87d713..243738951 100644 --- a/pkg/tcpip/network/fragmentation/fragmentation.go +++ b/pkg/tcpip/network/fragmentation/fragmentation.go @@ -84,7 +84,7 @@ type Fragmentation struct { lowLimit int reassemblers map[FragmentID]*reassembler rList reassemblerList - size int + memSize int timeout time.Duration blockSize uint16 clock tcpip.Clock @@ -156,22 +156,22 @@ func NewFragmentation(blockSize uint16, highMemoryLimit, lowMemoryLimit int, rea // the protocol to identify a fragment. func (f *Fragmentation) Process( id FragmentID, first, last uint16, more bool, proto uint8, pkt *stack.PacketBuffer) ( - buffer.VectorisedView, uint8, bool, error) { + *stack.PacketBuffer, uint8, bool, error) { if first > last { - return buffer.VectorisedView{}, 0, false, fmt.Errorf("first=%d is greater than last=%d: %w", first, last, ErrInvalidArgs) + return nil, 0, false, fmt.Errorf("first=%d is greater than last=%d: %w", first, last, ErrInvalidArgs) } if first%f.blockSize != 0 { - return buffer.VectorisedView{}, 0, false, fmt.Errorf("first=%d is not a multiple of block size=%d: %w", first, f.blockSize, ErrInvalidArgs) + return nil, 0, false, fmt.Errorf("first=%d is not a multiple of block size=%d: %w", first, f.blockSize, ErrInvalidArgs) } fragmentSize := last - first + 1 if more && fragmentSize%f.blockSize != 0 { - return buffer.VectorisedView{}, 0, false, fmt.Errorf("fragment size=%d bytes is not a multiple of block size=%d on non-final fragment: %w", fragmentSize, f.blockSize, ErrInvalidArgs) + return nil, 0, false, fmt.Errorf("fragment size=%d bytes is not a multiple of block size=%d on non-final fragment: %w", fragmentSize, f.blockSize, ErrInvalidArgs) } if l := pkt.Data.Size(); l != int(fragmentSize) { - return buffer.VectorisedView{}, 0, false, fmt.Errorf("got fragment size=%d bytes not equal to the expected fragment size=%d bytes (first=%d last=%d): %w", l, fragmentSize, first, last, ErrInvalidArgs) + return nil, 0, false, fmt.Errorf("got fragment size=%d bytes not equal to the expected fragment size=%d bytes (first=%d last=%d): %w", l, fragmentSize, first, last, ErrInvalidArgs) } f.mu.Lock() @@ -190,24 +190,24 @@ func (f *Fragmentation) Process( } f.mu.Unlock() - res, firstFragmentProto, done, consumed, err := r.process(first, last, more, proto, pkt) + resPkt, firstFragmentProto, done, memConsumed, err := r.process(first, last, more, proto, pkt) if err != nil { // We probably got an invalid sequence of fragments. Just // discard the reassembler and move on. f.mu.Lock() f.release(r, false /* timedOut */) f.mu.Unlock() - return buffer.VectorisedView{}, 0, false, fmt.Errorf("fragmentation processing error: %w", err) + return nil, 0, false, fmt.Errorf("fragmentation processing error: %w", err) } f.mu.Lock() - f.size += consumed + f.memSize += memConsumed if done { f.release(r, false /* timedOut */) } // Evict reassemblers if we are consuming more memory than highLimit until // we reach lowLimit. - if f.size > f.highLimit { - for f.size > f.lowLimit { + if f.memSize > f.highLimit { + for f.memSize > f.lowLimit { tail := f.rList.Back() if tail == nil { break @@ -216,7 +216,7 @@ func (f *Fragmentation) Process( } } f.mu.Unlock() - return res, firstFragmentProto, done, nil + return resPkt, firstFragmentProto, done, nil } func (f *Fragmentation) release(r *reassembler, timedOut bool) { @@ -228,10 +228,10 @@ func (f *Fragmentation) release(r *reassembler, timedOut bool) { delete(f.reassemblers, r.id) f.rList.Remove(r) - f.size -= r.size - if f.size < 0 { - log.Printf("memory counter < 0 (%d), this is an accounting bug that requires investigation", f.size) - f.size = 0 + f.memSize -= r.memSize + if f.memSize < 0 { + log.Printf("memory counter < 0 (%d), this is an accounting bug that requires investigation", f.memSize) + f.memSize = 0 } if h := f.timeoutHandler; timedOut && h != nil { diff --git a/pkg/tcpip/network/fragmentation/fragmentation_test.go b/pkg/tcpip/network/fragmentation/fragmentation_test.go index 3a79688a8..905bbc19b 100644 --- a/pkg/tcpip/network/fragmentation/fragmentation_test.go +++ b/pkg/tcpip/network/fragmentation/fragmentation_test.go @@ -16,7 +16,6 @@ package fragmentation import ( "errors" - "reflect" "testing" "time" @@ -112,20 +111,20 @@ func TestFragmentationProcess(t *testing.T) { f := NewFragmentation(minBlockSize, 1024, 512, reassembleTimeout, &faketime.NullClock{}, nil) firstFragmentProto := c.in[0].proto for i, in := range c.in { - vv, proto, done, err := f.Process(in.id, in.first, in.last, in.more, in.proto, in.pkt) + resPkt, proto, done, err := f.Process(in.id, in.first, in.last, in.more, in.proto, in.pkt) if err != nil { t.Fatalf("f.Process(%+v, %d, %d, %t, %d, %#v) failed: %s", in.id, in.first, in.last, in.more, in.proto, in.pkt, err) } - if !reflect.DeepEqual(vv, c.out[i].vv) { - t.Errorf("got Process(%+v, %d, %d, %t, %d, %#v) = (%X, _, _, _), want = (%X, _, _, _)", - in.id, in.first, in.last, in.more, in.proto, in.pkt, vv.ToView(), c.out[i].vv.ToView()) - } if done != c.out[i].done { t.Errorf("got Process(%+v, %d, %d, %t, %d, _) = (_, _, %t, _), want = (_, _, %t, _)", in.id, in.first, in.last, in.more, in.proto, done, c.out[i].done) } if c.out[i].done { + if diff := cmp.Diff(c.out[i].vv.ToOwnedView(), resPkt.Data.ToOwnedView()); diff != "" { + t.Errorf("got Process(%+v, %d, %d, %t, %d, %#v) result mismatch (-want, +got):\n%s", + in.id, in.first, in.last, in.more, in.proto, in.pkt, diff) + } if firstFragmentProto != proto { t.Errorf("got Process(%+v, %d, %d, %t, %d, _) = (_, %d, _, _), want = (_, %d, _, _)", in.id, in.first, in.last, in.more, in.proto, proto, firstFragmentProto) @@ -173,9 +172,17 @@ func TestReassemblingTimeout(t *testing.T) { // reassembly is done after the fragment is processd. expectDone bool - // sizeAfterEvent is the expected size of the fragmentation instance after - // the event. - sizeAfterEvent int + // memSizeAfterEvent is the expected memory size of the fragmentation + // instance after the event. + memSizeAfterEvent int + } + + memSizeOfFrags := func(frags ...*fragment) int { + var size int + for _, frag := range frags { + size += pkt(len(frag.data), frag.data).MemSize() + } + return size } half1 := &fragment{first: 0, last: 0, more: true, data: "0"} @@ -189,16 +196,16 @@ func TestReassemblingTimeout(t *testing.T) { name: "half1 and half2 are reassembled successfully", events: []event{ { - name: "half1", - fragment: half1, - expectDone: false, - sizeAfterEvent: 1, + name: "half1", + fragment: half1, + expectDone: false, + memSizeAfterEvent: memSizeOfFrags(half1), }, { - name: "half2", - fragment: half2, - expectDone: true, - sizeAfterEvent: 0, + name: "half2", + fragment: half2, + expectDone: true, + memSizeAfterEvent: 0, }, }, }, @@ -206,36 +213,36 @@ func TestReassemblingTimeout(t *testing.T) { name: "half1 timeout, half2 timeout", events: []event{ { - name: "half1", - fragment: half1, - expectDone: false, - sizeAfterEvent: 1, + name: "half1", + fragment: half1, + expectDone: false, + memSizeAfterEvent: memSizeOfFrags(half1), }, { - name: "half1 just before reassembly timeout", - clockAdvance: reassemblyTimeout - 1, - sizeAfterEvent: 1, + name: "half1 just before reassembly timeout", + clockAdvance: reassemblyTimeout - 1, + memSizeAfterEvent: memSizeOfFrags(half1), }, { - name: "half1 reassembly timeout", - clockAdvance: 1, - sizeAfterEvent: 0, + name: "half1 reassembly timeout", + clockAdvance: 1, + memSizeAfterEvent: 0, }, { - name: "half2", - fragment: half2, - expectDone: false, - sizeAfterEvent: 1, + name: "half2", + fragment: half2, + expectDone: false, + memSizeAfterEvent: memSizeOfFrags(half2), }, { - name: "half2 just before reassembly timeout", - clockAdvance: reassemblyTimeout - 1, - sizeAfterEvent: 1, + name: "half2 just before reassembly timeout", + clockAdvance: reassemblyTimeout - 1, + memSizeAfterEvent: memSizeOfFrags(half2), }, { - name: "half2 reassembly timeout", - clockAdvance: 1, - sizeAfterEvent: 0, + name: "half2 reassembly timeout", + clockAdvance: 1, + memSizeAfterEvent: 0, }, }, }, @@ -255,8 +262,8 @@ func TestReassemblingTimeout(t *testing.T) { t.Fatalf("%s: got done = %t, want = %t", event.name, done, event.expectDone) } } - if got, want := f.size, event.sizeAfterEvent; got != want { - t.Errorf("%s: got f.size = %d, want = %d", event.name, got, want) + if got, want := f.memSize, event.memSizeAfterEvent; got != want { + t.Errorf("%s: got f.memSize = %d, want = %d", event.name, got, want) } } }) @@ -264,7 +271,9 @@ func TestReassemblingTimeout(t *testing.T) { } func TestMemoryLimits(t *testing.T) { - f := NewFragmentation(minBlockSize, 3, 1, reassembleTimeout, &faketime.NullClock{}, nil) + lowLimit := pkt(1, "0").MemSize() + highLimit := 3 * lowLimit // Allow at most 3 such packets. + f := NewFragmentation(minBlockSize, highLimit, lowLimit, reassembleTimeout, &faketime.NullClock{}, nil) // Send first fragment with id = 0. f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, pkt(1, "0")) // Send first fragment with id = 1. @@ -288,15 +297,14 @@ func TestMemoryLimits(t *testing.T) { } func TestMemoryLimitsIgnoresDuplicates(t *testing.T) { - f := NewFragmentation(minBlockSize, 1, 0, reassembleTimeout, &faketime.NullClock{}, nil) + memSize := pkt(1, "0").MemSize() + f := NewFragmentation(minBlockSize, memSize, 0, reassembleTimeout, &faketime.NullClock{}, nil) // Send first fragment with id = 0. f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0")) // Send the same packet again. f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0")) - got := f.size - want := 1 - if got != want { + if got, want := f.memSize, memSize; got != want { t.Errorf("Wrong size, duplicates are not handled correctly: got=%d, want=%d.", got, want) } } diff --git a/pkg/tcpip/network/fragmentation/reassembler.go b/pkg/tcpip/network/fragmentation/reassembler.go index 9b20bb1d8..933d63d32 100644 --- a/pkg/tcpip/network/fragmentation/reassembler.go +++ b/pkg/tcpip/network/fragmentation/reassembler.go @@ -20,7 +20,6 @@ import ( "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -29,13 +28,15 @@ type hole struct { last uint16 filled bool final bool - data buffer.View + // pkt is the fragment packet if hole is filled. We keep the whole pkt rather + // than the fragmented payload to prevent binding to specific buffer types. + pkt *stack.PacketBuffer } type reassembler struct { reassemblerEntry id FragmentID - size int + memSize int proto uint8 mu sync.Mutex holes []hole @@ -59,18 +60,18 @@ func newReassembler(id FragmentID, clock tcpip.Clock) *reassembler { return r } -func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *stack.PacketBuffer) (buffer.VectorisedView, uint8, bool, int, error) { +func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *stack.PacketBuffer) (*stack.PacketBuffer, uint8, bool, int, error) { r.mu.Lock() defer r.mu.Unlock() if r.done { // A concurrent goroutine might have already reassembled // the packet and emptied the heap while this goroutine // was waiting on the mutex. We don't have to do anything in this case. - return buffer.VectorisedView{}, 0, false, 0, nil + return nil, 0, false, 0, nil } var holeFound bool - var consumed int + var memConsumed int for i := range r.holes { currentHole := &r.holes[i] @@ -90,12 +91,12 @@ func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *s // https://github.com/torvalds/linux/blob/38525c6/net/ipv4/inet_fragment.c#L349 if first < currentHole.first || currentHole.last < last { // Incoming fragment only partially fits in the free hole. - return buffer.VectorisedView{}, 0, false, 0, ErrFragmentOverlap + return nil, 0, false, 0, ErrFragmentOverlap } if !more { if !currentHole.final || currentHole.filled && currentHole.last != last { // We have another final fragment, which does not perfectly overlap. - return buffer.VectorisedView{}, 0, false, 0, ErrFragmentConflict + return nil, 0, false, 0, ErrFragmentConflict } } @@ -124,16 +125,15 @@ func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *s }) currentHole.final = false } - v := pkt.Data.ToOwnedView() - consumed = v.Size() - r.size += consumed + memConsumed = pkt.MemSize() + r.memSize += memConsumed // Update the current hole to precisely match the incoming fragment. r.holes[i] = hole{ first: first, last: last, filled: true, final: currentHole.final, - data: v, + pkt: pkt, } r.filled++ // For IPv6, it is possible to have different Protocol values between @@ -153,25 +153,24 @@ func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *s } if !holeFound { // Incoming fragment is beyond end. - return buffer.VectorisedView{}, 0, false, 0, ErrFragmentConflict + return nil, 0, false, 0, ErrFragmentConflict } // Check if all the holes have been filled and we are ready to reassemble. if r.filled < len(r.holes) { - return buffer.VectorisedView{}, 0, false, consumed, nil + return nil, 0, false, memConsumed, nil } sort.Slice(r.holes, func(i, j int) bool { return r.holes[i].first < r.holes[j].first }) - var size int - views := make([]buffer.View, 0, len(r.holes)) - for _, hole := range r.holes { - views = append(views, hole.data) - size += hole.data.Size() + resPkt := r.holes[0].pkt + for i := 1; i < len(r.holes); i++ { + fragPkt := r.holes[i].pkt + fragPkt.Data.ReadToVV(&resPkt.Data, fragPkt.Data.Size()) } - return buffer.NewVectorisedView(size, views), r.proto, true, consumed, nil + return resPkt, r.proto, true, memConsumed, nil } func (r *reassembler) checkDoneOrMark() bool { diff --git a/pkg/tcpip/network/fragmentation/reassembler_test.go b/pkg/tcpip/network/fragmentation/reassembler_test.go index 2ff03eeeb..214a93709 100644 --- a/pkg/tcpip/network/fragmentation/reassembler_test.go +++ b/pkg/tcpip/network/fragmentation/reassembler_test.go @@ -15,6 +15,7 @@ package fragmentation import ( + "bytes" "math" "testing" @@ -44,16 +45,21 @@ func TestReassemblerProcess(t *testing.T) { return payload } - pkt := func(size int) *stack.PacketBuffer { + pkt := func(sizes ...int) *stack.PacketBuffer { + var vv buffer.VectorisedView + for _, size := range sizes { + vv.AppendView(v(size)) + } return stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: v(size).ToVectorisedView(), + Data: vv, }) } var tests = []struct { - name string - params []processParams - want []hole + name string + params []processParams + want []hole + wantPkt *stack.PacketBuffer }{ { name: "No fragments", @@ -64,7 +70,7 @@ func TestReassemblerProcess(t *testing.T) { name: "One fragment at beginning", params: []processParams{{first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil}}, want: []hole{ - {first: 0, last: 1, filled: true, final: false, data: v(2)}, + {first: 0, last: 1, filled: true, final: false, pkt: pkt(2)}, {first: 2, last: math.MaxUint16, filled: false, final: true}, }, }, @@ -72,7 +78,7 @@ func TestReassemblerProcess(t *testing.T) { name: "One fragment in the middle", params: []processParams{{first: 1, last: 2, more: true, pkt: pkt(2), wantDone: false, wantError: nil}}, want: []hole{ - {first: 1, last: 2, filled: true, final: false, data: v(2)}, + {first: 1, last: 2, filled: true, final: false, pkt: pkt(2)}, {first: 0, last: 0, filled: false, final: false}, {first: 3, last: math.MaxUint16, filled: false, final: true}, }, @@ -81,7 +87,7 @@ func TestReassemblerProcess(t *testing.T) { name: "One fragment at the end", params: []processParams{{first: 1, last: 2, more: false, pkt: pkt(2), wantDone: false, wantError: nil}}, want: []hole{ - {first: 1, last: 2, filled: true, final: true, data: v(2)}, + {first: 1, last: 2, filled: true, final: true, pkt: pkt(2)}, {first: 0, last: 0, filled: false}, }, }, @@ -89,8 +95,9 @@ func TestReassemblerProcess(t *testing.T) { name: "One fragment completing a packet", params: []processParams{{first: 0, last: 1, more: false, pkt: pkt(2), wantDone: true, wantError: nil}}, want: []hole{ - {first: 0, last: 1, filled: true, final: true, data: v(2)}, + {first: 0, last: 1, filled: true, final: true}, }, + wantPkt: pkt(2), }, { name: "Two fragments completing a packet", @@ -99,9 +106,10 @@ func TestReassemblerProcess(t *testing.T) { {first: 2, last: 3, more: false, pkt: pkt(2), wantDone: true, wantError: nil}, }, want: []hole{ - {first: 0, last: 1, filled: true, final: false, data: v(2)}, - {first: 2, last: 3, filled: true, final: true, data: v(2)}, + {first: 0, last: 1, filled: true, final: false}, + {first: 2, last: 3, filled: true, final: true}, }, + wantPkt: pkt(2, 2), }, { name: "Two fragments completing a packet with a duplicate", @@ -111,9 +119,10 @@ func TestReassemblerProcess(t *testing.T) { {first: 2, last: 3, more: false, pkt: pkt(2), wantDone: true, wantError: nil}, }, want: []hole{ - {first: 0, last: 1, filled: true, final: false, data: v(2)}, - {first: 2, last: 3, filled: true, final: true, data: v(2)}, + {first: 0, last: 1, filled: true, final: false}, + {first: 2, last: 3, filled: true, final: true}, }, + wantPkt: pkt(2, 2), }, { name: "Two fragments completing a packet with a partial duplicate", @@ -123,9 +132,10 @@ func TestReassemblerProcess(t *testing.T) { {first: 4, last: 5, more: false, pkt: pkt(2), wantDone: true, wantError: nil}, }, want: []hole{ - {first: 0, last: 3, filled: true, final: false, data: v(4)}, - {first: 4, last: 5, filled: true, final: true, data: v(2)}, + {first: 0, last: 3, filled: true, final: false}, + {first: 4, last: 5, filled: true, final: true}, }, + wantPkt: pkt(4, 2), }, { name: "Two overlapping fragments", @@ -134,7 +144,7 @@ func TestReassemblerProcess(t *testing.T) { {first: 5, last: 15, more: false, pkt: pkt(11), wantDone: false, wantError: ErrFragmentOverlap}, }, want: []hole{ - {first: 0, last: 10, filled: true, final: false, data: v(11)}, + {first: 0, last: 10, filled: true, final: false, pkt: pkt(11)}, {first: 11, last: math.MaxUint16, filled: false, final: true}, }, }, @@ -145,7 +155,7 @@ func TestReassemblerProcess(t *testing.T) { {first: 0, last: 9, more: false, pkt: pkt(10), wantDone: false, wantError: ErrFragmentConflict}, }, want: []hole{ - {first: 10, last: 14, filled: true, final: true, data: v(5)}, + {first: 10, last: 14, filled: true, final: true, pkt: pkt(5)}, {first: 0, last: 9, filled: false, final: false}, }, }, @@ -156,7 +166,7 @@ func TestReassemblerProcess(t *testing.T) { {first: 10, last: 14, more: false, pkt: pkt(5), wantDone: false, wantError: nil}, }, want: []hole{ - {first: 5, last: 14, filled: true, final: true, data: v(10)}, + {first: 5, last: 14, filled: true, final: true, pkt: pkt(10)}, {first: 0, last: 4, filled: false, final: false}, }, }, @@ -167,7 +177,7 @@ func TestReassemblerProcess(t *testing.T) { {first: 10, last: 13, more: false, pkt: pkt(4), wantDone: false, wantError: ErrFragmentConflict}, }, want: []hole{ - {first: 5, last: 14, filled: true, final: true, data: v(10)}, + {first: 5, last: 14, filled: true, final: true, pkt: pkt(10)}, {first: 0, last: 4, filled: false, final: false}, }, }, @@ -176,14 +186,47 @@ func TestReassemblerProcess(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { r := newReassembler(FragmentID{}, &faketime.NullClock{}) + var resPkt *stack.PacketBuffer + var isDone bool for _, param := range test.params { - _, _, done, _, err := r.process(param.first, param.last, param.more, proto, param.pkt) + pkt, _, done, _, err := r.process(param.first, param.last, param.more, proto, param.pkt) if done != param.wantDone || err != param.wantError { t.Errorf("got r.process(%d, %d, %t, %d, _) = (_, _, %t, _, %v), want = (%t, %v)", param.first, param.last, param.more, proto, done, err, param.wantDone, param.wantError) } + if done { + resPkt = pkt + isDone = true + } } - if diff := cmp.Diff(test.want, r.holes, cmp.AllowUnexported(hole{})); diff != "" { - t.Errorf("r.holes mismatch (-want +got):\n%s", diff) + + ignorePkt := func(a, b *stack.PacketBuffer) bool { return true } + cmpPktData := func(a, b *stack.PacketBuffer) bool { + if a == nil || b == nil { + return a == b + } + return bytes.Equal(a.Data.ToOwnedView(), b.Data.ToOwnedView()) + } + + if isDone { + if diff := cmp.Diff( + test.want, r.holes, + cmp.AllowUnexported(hole{}), + // Do not compare pkt in hole. Data will be altered. + cmp.Comparer(ignorePkt), + ); diff != "" { + t.Errorf("r.holes mismatch (-want +got):\n%s", diff) + } + if diff := cmp.Diff(test.wantPkt, resPkt, cmp.Comparer(cmpPktData)); diff != "" { + t.Errorf("Reassembled pkt mismatch (-want +got):\n%s", diff) + } + } else { + if diff := cmp.Diff( + test.want, r.holes, + cmp.AllowUnexported(hole{}), + cmp.Comparer(cmpPktData), + ); diff != "" { + t.Errorf("r.holes mismatch (-want +got):\n%s", diff) + } } }) } diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD index 330a7d170..9713c4448 100644 --- a/pkg/tcpip/network/ipv4/BUILD +++ b/pkg/tcpip/network/ipv4/BUILD @@ -44,6 +44,7 @@ go_test( "//pkg/tcpip/network/testutil", "//pkg/tcpip/stack", "//pkg/tcpip/transport/icmp", + "//pkg/tcpip/transport/raw", "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", "//pkg/waiter", diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index b0703715a..04c6a6708 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -740,7 +740,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { } proto := h.Protocol() - data, _, ready, err := e.protocol.fragmentation.Process( + resPkt, _, ready, err := e.protocol.fragmentation.Process( // As per RFC 791 section 2.3, the identification value is unique // for a source-destination pair and protocol. fragmentation.FragmentID{ @@ -763,7 +763,8 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { if !ready { return } - pkt.Data = data + pkt = resPkt + h = header.IPv4(pkt.NetworkHeader().View()) // The reassembler doesn't take care of fixing up the header, so we need // to do it here. diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index ed5899f0b..a296bed79 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -38,6 +38,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/network/testutil" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" + "gvisor.dev/gvisor/pkg/tcpip/transport/raw" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" @@ -2058,7 +2059,7 @@ func TestReceiveFragments(t *testing.T) { // the fragment block size of 8 (RFC 791 section 3.1 page 14). ipv4Payload3Addr1ToAddr2 := udpGen(127, 3, addr1, addr2) udpPayload3Addr1ToAddr2 := ipv4Payload3Addr1ToAddr2[header.UDPMinimumSize:] - // Used to test the max reassembled payload length (65,535 octets). + // Used to test the max reassembled IPv4 payload length. ipv4Payload4Addr1ToAddr2 := udpGen(header.UDPMaximumSize-header.UDPMinimumSize, 4, addr1, addr2) udpPayload4Addr1ToAddr2 := ipv4Payload4Addr1ToAddr2[header.UDPMinimumSize:] @@ -2406,6 +2407,7 @@ func TestReceiveFragments(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + RawFactory: raw.EndpointFactory{}, }) e := channel.New(0, 1280, tcpip.LinkAddress("\xf0\x00")) if err := s.CreateNIC(nicID, e); err != nil { @@ -2431,6 +2433,13 @@ func TestReceiveFragments(t *testing.T) { t.Fatalf("Bind(%+v): %s", bindAddr, err) } + // Bring up a raw endpoint so we can examine network headers. + epRaw, err := s.NewRawEndpoint(udp.ProtocolNumber, header.IPv4ProtocolNumber, &wq, true /* associated */) + if err != nil { + t.Fatalf("NewRawEndpoint(%d, %d, _, true): %s", udp.ProtocolNumber, header.IPv4ProtocolNumber, err) + } + defer epRaw.Close() + // Prepare and send the fragments. for _, frag := range test.fragments { hdr := buffer.NewPrependable(header.IPv4MinimumSize) @@ -2462,10 +2471,11 @@ func TestReceiveFragments(t *testing.T) { } for i, expectedPayload := range test.expectedPayloads { + // Check UDP payload delivered by UDP endpoint. var buf bytes.Buffer result, err := ep.Read(&buf, tcpip.ReadOptions{}) if err != nil { - t.Fatalf("(i=%d) Read: %s", i, err) + t.Fatalf("(i=%d) ep.Read: %s", i, err) } if diff := cmp.Diff(tcpip.ReadResult{ Count: len(expectedPayload), @@ -2474,7 +2484,24 @@ func TestReceiveFragments(t *testing.T) { t.Errorf("(i=%d) ep.Read: unexpected result (-want +got):\n%s", i, diff) } if diff := cmp.Diff(expectedPayload, buf.Bytes()); diff != "" { - t.Errorf("(i=%d) got UDP payload mismatch (-want +got):\n%s", i, diff) + t.Errorf("(i=%d) ep.Read: UDP payload mismatch (-want +got):\n%s", i, diff) + } + + // Check IPv4 header in packet delivered by raw endpoint. + buf.Reset() + result, err = epRaw.Read(&buf, tcpip.ReadOptions{}) + if err != nil { + t.Fatalf("(i=%d) epRaw.Read: %s", i, err) + } + // Reassambly does not take care of checksum. Here we write our own + // check routine instead of using checker.IPv4. + ip := header.IPv4(buf.Bytes()) + for _, check := range []checker.NetworkChecker{ + checker.FragmentFlags(0), + checker.FragmentOffset(0), + checker.IPFullLength(uint16(header.IPv4MinimumSize + header.UDPMinimumSize + len(expectedPayload))), + } { + check(t, []header.Network{ip}) } } diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 94043ed4e..caa62b3a2 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -1167,7 +1167,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { // Note that pkt doesn't have its transport header set after reassembly, // and won't until DeliverNetworkPacket sets it. - data, proto, ready, err := e.protocol.fragmentation.Process( + resPkt, proto, ready, err := e.protocol.fragmentation.Process( // IPv6 ignores the Protocol field since the ID only needs to be unique // across source-destination pairs, as per RFC 8200 section 4.5. fragmentation.FragmentID{ @@ -1188,7 +1188,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { } if ready { - pkt.Data = data + pkt = resPkt // We create a new iterator with the reassembled packet because we could // have more extension headers in the reassembled payload, as per RFC diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index bb30556cf..ee23c9b98 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -72,6 +72,7 @@ go_library( "nud.go", "packet_buffer.go", "packet_buffer_list.go", + "packet_buffer_unsafe.go", "pending_packets.go", "rand.go", "registration.go", diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 9d4fc3e48..4f013b212 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -187,6 +187,12 @@ func (pk *PacketBuffer) Size() int { return pk.HeaderSize() + pk.Data.Size() } +// MemSize returns the estimation size of the pk in memory, including backing +// buffer data. +func (pk *PacketBuffer) MemSize() int { + return pk.HeaderSize() + pk.Data.MemSize() + packetBufferStructSize +} + // Views returns the underlying storage of the whole packet. func (pk *PacketBuffer) Views() []buffer.View { // Optimization for outbound packets that headers are in pk.header. diff --git a/pkg/tcpip/stack/packet_buffer_unsafe.go b/pkg/tcpip/stack/packet_buffer_unsafe.go new file mode 100644 index 000000000..ee3d47270 --- /dev/null +++ b/pkg/tcpip/stack/packet_buffer_unsafe.go @@ -0,0 +1,19 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import "unsafe" + +const packetBufferStructSize = int(unsafe.Sizeof(PacketBuffer{}))