From 6a28dc7c59632b4007a095377073b8b74df85bea Mon Sep 17 00:00:00 2001 From: Lucas Manning Date: Wed, 26 Jan 2022 17:24:05 -0800 Subject: [PATCH] Correct fragmentation reference counting. Before this change the only reference on the packet after reassembly processing was held by the reassembler in the holes array. This meant that after the reassembly cleanup job, there were no references left on the packet, leading to use after free bugs. PiperOrigin-RevId: 424479461 --- .../fragmentation/fragmentation_test.go | 23 ++++++++++++++++++- .../internal/fragmentation/reassembler.go | 1 + .../fragmentation/reassembler_test.go | 3 +++ pkg/tcpip/network/ipv4/ipv4.go | 1 + pkg/tcpip/network/ipv6/ipv6.go | 10 ++++++++ 5 files changed, 37 insertions(+), 1 deletion(-) diff --git a/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go b/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go index 6f79abccf..18dbfe311 100644 --- a/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go +++ b/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go @@ -112,6 +112,9 @@ func TestFragmentationProcess(t *testing.T) { for i, in := range c.in { defer in.pkt.DecRef() resPkt, proto, done, err := f.Process(in.id, in.first, in.last, in.more, in.proto, in.pkt) + if resPkt != nil { + defer resPkt.DecRef() + } if err != nil { t.Fatalf("f.Process(%+v, %d, %d, %t, %d, %#v) failed: %s", in.id, in.first, in.last, in.more, in.proto, in.pkt, err) @@ -258,7 +261,10 @@ func TestReassemblingTimeout(t *testing.T) { if frag := event.fragment; frag != nil { p := pkt(len(frag.data), frag.data) defer p.DecRef() - _, _, done, err := f.Process(FragmentID{}, frag.first, frag.last, frag.more, protocol, p) + pkt, _, done, err := f.Process(FragmentID{}, frag.first, frag.last, frag.more, protocol, p) + if pkt != nil { + pkt.DecRef() + } if err != nil { t.Fatalf("%s: f.Process failed: %s", event.name, err) } @@ -686,3 +692,18 @@ func TestTimeoutHandler(t *testing.T) { }) } } + +func TestFragmentSurvivesReleaseJob(t *testing.T) { + handler := &testTimeoutHandler{pkt: nil} + c := faketime.NewManualClock() + f := NewFragmentation(minBlockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, c, handler) + pkt := pkt(2, "01") + // Values to Process don't matter except for pkt. + resPkt, _, _, _ := f.Process(FragmentID{ID: 0}, 0, 1, false, 0, pkt) + pkt.DecRef() + // This clears out the references held by the reassembler. + c.Advance(reassembleTimeout) + // If Process doesn't give the returned packet its own reference, this will + // fail. + resPkt.DecRef() +} diff --git a/pkg/tcpip/network/internal/fragmentation/reassembler.go b/pkg/tcpip/network/internal/fragmentation/reassembler.go index 73873b003..f00c57b86 100644 --- a/pkg/tcpip/network/internal/fragmentation/reassembler.go +++ b/pkg/tcpip/network/internal/fragmentation/reassembler.go @@ -170,6 +170,7 @@ func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *s }) resPkt := r.holes[0].pkt + resPkt.IncRef() for i := 1; i < len(r.holes); i++ { stack.MergeFragment(resPkt, r.holes[i].pkt) } diff --git a/pkg/tcpip/network/internal/fragmentation/reassembler_test.go b/pkg/tcpip/network/internal/fragmentation/reassembler_test.go index 62c84485d..83a9b56c6 100644 --- a/pkg/tcpip/network/internal/fragmentation/reassembler_test.go +++ b/pkg/tcpip/network/internal/fragmentation/reassembler_test.go @@ -203,6 +203,9 @@ func TestReassemblerProcess(t *testing.T) { var isDone bool for _, param := range test.params { pkt, _, done, _, err := r.process(param.first, param.last, param.more, proto, param.pkt) + if pkt != nil { + defer pkt.DecRef() + } if done != param.wantDone || err != param.wantError { t.Errorf("got r.process(%d, %d, %t, %d, _) = (_, _, %t, _, %v), want = (%t, %v)", param.first, param.last, param.more, proto, done, err, param.wantDone, param.wantError) } diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index bc01440db..c393aaa38 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -941,6 +941,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer, if !ready { return } + defer resPkt.DecRef() pkt = resPkt h = header.IPv4(pkt.NetworkHeader().View()) diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index a76bc25e1..fcecc9a82 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -1136,7 +1136,16 @@ func (e *endpoint) processExtensionHeaders(h header.IPv6, pkt *stack.PacketBuffe var ( hasFragmentHeader bool routerAlert *header.IPv6RouterAlertOption + // Create an extra packet buffer reference to keep track of the packet to + // DecRef so that we do not incur a memory allocation for deferring a DecRef + // within the loop. + resPktToDecRef *stack.PacketBuffer ) + defer func() { + if resPktToDecRef != nil { + resPktToDecRef.DecRef() + } + }() for { // Keep track of the start of the previous header so we can report the @@ -1392,6 +1401,7 @@ func (e *endpoint) processExtensionHeaders(h header.IPv6, pkt *stack.PacketBuffe } if ready { + resPktToDecRef = resPkt pkt = resPkt // We create a new iterator with the reassembled packet because we could