Correct fragmentation reference counting.

Before this change the only reference on the packet after reassembly
processing was held by the reassembler in the holes array. This meant that
after the reassembly cleanup job, there were no references left on the
packet, leading to use after free bugs.

PiperOrigin-RevId: 424479461
This commit is contained in:
Lucas Manning 2022-01-26 17:24:05 -08:00 committed by gVisor bot
parent b5962471e1
commit 6a28dc7c59
5 changed files with 37 additions and 1 deletions

View File

@ -112,6 +112,9 @@ func TestFragmentationProcess(t *testing.T) {
for i, in := range c.in { for i, in := range c.in {
defer in.pkt.DecRef() defer in.pkt.DecRef()
resPkt, proto, done, err := f.Process(in.id, in.first, in.last, in.more, in.proto, in.pkt) resPkt, proto, done, err := f.Process(in.id, in.first, in.last, in.more, in.proto, in.pkt)
if resPkt != nil {
defer resPkt.DecRef()
}
if err != nil { if err != nil {
t.Fatalf("f.Process(%+v, %d, %d, %t, %d, %#v) failed: %s", t.Fatalf("f.Process(%+v, %d, %d, %t, %d, %#v) failed: %s",
in.id, in.first, in.last, in.more, in.proto, in.pkt, err) in.id, in.first, in.last, in.more, in.proto, in.pkt, err)
@ -258,7 +261,10 @@ func TestReassemblingTimeout(t *testing.T) {
if frag := event.fragment; frag != nil { if frag := event.fragment; frag != nil {
p := pkt(len(frag.data), frag.data) p := pkt(len(frag.data), frag.data)
defer p.DecRef() defer p.DecRef()
_, _, done, err := f.Process(FragmentID{}, frag.first, frag.last, frag.more, protocol, p) pkt, _, done, err := f.Process(FragmentID{}, frag.first, frag.last, frag.more, protocol, p)
if pkt != nil {
pkt.DecRef()
}
if err != nil { if err != nil {
t.Fatalf("%s: f.Process failed: %s", event.name, err) t.Fatalf("%s: f.Process failed: %s", event.name, err)
} }
@ -686,3 +692,18 @@ func TestTimeoutHandler(t *testing.T) {
}) })
} }
} }
func TestFragmentSurvivesReleaseJob(t *testing.T) {
handler := &testTimeoutHandler{pkt: nil}
c := faketime.NewManualClock()
f := NewFragmentation(minBlockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, c, handler)
pkt := pkt(2, "01")
// Values to Process don't matter except for pkt.
resPkt, _, _, _ := f.Process(FragmentID{ID: 0}, 0, 1, false, 0, pkt)
pkt.DecRef()
// This clears out the references held by the reassembler.
c.Advance(reassembleTimeout)
// If Process doesn't give the returned packet its own reference, this will
// fail.
resPkt.DecRef()
}

View File

@ -170,6 +170,7 @@ func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *s
}) })
resPkt := r.holes[0].pkt resPkt := r.holes[0].pkt
resPkt.IncRef()
for i := 1; i < len(r.holes); i++ { for i := 1; i < len(r.holes); i++ {
stack.MergeFragment(resPkt, r.holes[i].pkt) stack.MergeFragment(resPkt, r.holes[i].pkt)
} }

View File

@ -203,6 +203,9 @@ func TestReassemblerProcess(t *testing.T) {
var isDone bool var isDone bool
for _, param := range test.params { for _, param := range test.params {
pkt, _, done, _, err := r.process(param.first, param.last, param.more, proto, param.pkt) pkt, _, done, _, err := r.process(param.first, param.last, param.more, proto, param.pkt)
if pkt != nil {
defer pkt.DecRef()
}
if done != param.wantDone || err != param.wantError { if done != param.wantDone || err != param.wantError {
t.Errorf("got r.process(%d, %d, %t, %d, _) = (_, _, %t, _, %v), want = (%t, %v)", param.first, param.last, param.more, proto, done, err, param.wantDone, param.wantError) t.Errorf("got r.process(%d, %d, %t, %d, _) = (_, _, %t, _, %v), want = (%t, %v)", param.first, param.last, param.more, proto, done, err, param.wantDone, param.wantError)
} }

View File

@ -941,6 +941,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer,
if !ready { if !ready {
return return
} }
defer resPkt.DecRef()
pkt = resPkt pkt = resPkt
h = header.IPv4(pkt.NetworkHeader().View()) h = header.IPv4(pkt.NetworkHeader().View())

View File

@ -1136,7 +1136,16 @@ func (e *endpoint) processExtensionHeaders(h header.IPv6, pkt *stack.PacketBuffe
var ( var (
hasFragmentHeader bool hasFragmentHeader bool
routerAlert *header.IPv6RouterAlertOption routerAlert *header.IPv6RouterAlertOption
// Create an extra packet buffer reference to keep track of the packet to
// DecRef so that we do not incur a memory allocation for deferring a DecRef
// within the loop.
resPktToDecRef *stack.PacketBuffer
) )
defer func() {
if resPktToDecRef != nil {
resPktToDecRef.DecRef()
}
}()
for { for {
// Keep track of the start of the previous header so we can report the // Keep track of the start of the previous header so we can report the
@ -1392,6 +1401,7 @@ func (e *endpoint) processExtensionHeaders(h header.IPv6, pkt *stack.PacketBuffe
} }
if ready { if ready {
resPktToDecRef = resPkt
pkt = resPkt pkt = resPkt
// We create a new iterator with the reassembled packet because we could // We create a new iterator with the reassembled packet because we could