diff --git a/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go b/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go index 6f79abccf..18dbfe311 100644 --- a/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go +++ b/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go @@ -112,6 +112,9 @@ func TestFragmentationProcess(t *testing.T) { for i, in := range c.in { defer in.pkt.DecRef() resPkt, proto, done, err := f.Process(in.id, in.first, in.last, in.more, in.proto, in.pkt) + if resPkt != nil { + defer resPkt.DecRef() + } if err != nil { t.Fatalf("f.Process(%+v, %d, %d, %t, %d, %#v) failed: %s", in.id, in.first, in.last, in.more, in.proto, in.pkt, err) @@ -258,7 +261,10 @@ func TestReassemblingTimeout(t *testing.T) { if frag := event.fragment; frag != nil { p := pkt(len(frag.data), frag.data) defer p.DecRef() - _, _, done, err := f.Process(FragmentID{}, frag.first, frag.last, frag.more, protocol, p) + pkt, _, done, err := f.Process(FragmentID{}, frag.first, frag.last, frag.more, protocol, p) + if pkt != nil { + pkt.DecRef() + } if err != nil { t.Fatalf("%s: f.Process failed: %s", event.name, err) } @@ -686,3 +692,18 @@ func TestTimeoutHandler(t *testing.T) { }) } } + +func TestFragmentSurvivesReleaseJob(t *testing.T) { + handler := &testTimeoutHandler{pkt: nil} + c := faketime.NewManualClock() + f := NewFragmentation(minBlockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, c, handler) + pkt := pkt(2, "01") + // Values to Process don't matter except for pkt. + resPkt, _, _, _ := f.Process(FragmentID{ID: 0}, 0, 1, false, 0, pkt) + pkt.DecRef() + // This clears out the references held by the reassembler. + c.Advance(reassembleTimeout) + // If Process doesn't give the returned packet its own reference, this will + // fail. + resPkt.DecRef() +} diff --git a/pkg/tcpip/network/internal/fragmentation/reassembler.go b/pkg/tcpip/network/internal/fragmentation/reassembler.go index 73873b003..f00c57b86 100644 --- a/pkg/tcpip/network/internal/fragmentation/reassembler.go +++ b/pkg/tcpip/network/internal/fragmentation/reassembler.go @@ -170,6 +170,7 @@ func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *s }) resPkt := r.holes[0].pkt + resPkt.IncRef() for i := 1; i < len(r.holes); i++ { stack.MergeFragment(resPkt, r.holes[i].pkt) } diff --git a/pkg/tcpip/network/internal/fragmentation/reassembler_test.go b/pkg/tcpip/network/internal/fragmentation/reassembler_test.go index 62c84485d..83a9b56c6 100644 --- a/pkg/tcpip/network/internal/fragmentation/reassembler_test.go +++ b/pkg/tcpip/network/internal/fragmentation/reassembler_test.go @@ -203,6 +203,9 @@ func TestReassemblerProcess(t *testing.T) { var isDone bool for _, param := range test.params { pkt, _, done, _, err := r.process(param.first, param.last, param.more, proto, param.pkt) + if pkt != nil { + defer pkt.DecRef() + } if done != param.wantDone || err != param.wantError { t.Errorf("got r.process(%d, %d, %t, %d, _) = (_, _, %t, _, %v), want = (%t, %v)", param.first, param.last, param.more, proto, done, err, param.wantDone, param.wantError) } diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index bc01440db..c393aaa38 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -941,6 +941,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer, if !ready { return } + defer resPkt.DecRef() pkt = resPkt h = header.IPv4(pkt.NetworkHeader().View()) diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index a76bc25e1..fcecc9a82 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -1136,7 +1136,16 @@ func (e *endpoint) processExtensionHeaders(h header.IPv6, pkt *stack.PacketBuffe var ( hasFragmentHeader bool routerAlert *header.IPv6RouterAlertOption + // Create an extra packet buffer reference to keep track of the packet to + // DecRef so that we do not incur a memory allocation for deferring a DecRef + // within the loop. + resPktToDecRef *stack.PacketBuffer ) + defer func() { + if resPktToDecRef != nil { + resPktToDecRef.DecRef() + } + }() for { // Keep track of the start of the previous header so we can report the @@ -1392,6 +1401,7 @@ func (e *endpoint) processExtensionHeaders(h header.IPv6, pkt *stack.PacketBuffe } if ready { + resPktToDecRef = resPkt pkt = resPkt // We create a new iterator with the reassembled packet because we could