Correct fragmentation reference counting.
Before this change the only reference on the packet after reassembly processing was held by the reassembler in the holes array. This meant that after the reassembly cleanup job, there were no references left on the packet, leading to use after free bugs. PiperOrigin-RevId: 424479461
This commit is contained in:
parent
b5962471e1
commit
6a28dc7c59
|
@ -112,6 +112,9 @@ func TestFragmentationProcess(t *testing.T) {
|
||||||
for i, in := range c.in {
|
for i, in := range c.in {
|
||||||
defer in.pkt.DecRef()
|
defer in.pkt.DecRef()
|
||||||
resPkt, proto, done, err := f.Process(in.id, in.first, in.last, in.more, in.proto, in.pkt)
|
resPkt, proto, done, err := f.Process(in.id, in.first, in.last, in.more, in.proto, in.pkt)
|
||||||
|
if resPkt != nil {
|
||||||
|
defer resPkt.DecRef()
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("f.Process(%+v, %d, %d, %t, %d, %#v) failed: %s",
|
t.Fatalf("f.Process(%+v, %d, %d, %t, %d, %#v) failed: %s",
|
||||||
in.id, in.first, in.last, in.more, in.proto, in.pkt, err)
|
in.id, in.first, in.last, in.more, in.proto, in.pkt, err)
|
||||||
|
@ -258,7 +261,10 @@ func TestReassemblingTimeout(t *testing.T) {
|
||||||
if frag := event.fragment; frag != nil {
|
if frag := event.fragment; frag != nil {
|
||||||
p := pkt(len(frag.data), frag.data)
|
p := pkt(len(frag.data), frag.data)
|
||||||
defer p.DecRef()
|
defer p.DecRef()
|
||||||
_, _, done, err := f.Process(FragmentID{}, frag.first, frag.last, frag.more, protocol, p)
|
pkt, _, done, err := f.Process(FragmentID{}, frag.first, frag.last, frag.more, protocol, p)
|
||||||
|
if pkt != nil {
|
||||||
|
pkt.DecRef()
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("%s: f.Process failed: %s", event.name, err)
|
t.Fatalf("%s: f.Process failed: %s", event.name, err)
|
||||||
}
|
}
|
||||||
|
@ -686,3 +692,18 @@ func TestTimeoutHandler(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestFragmentSurvivesReleaseJob(t *testing.T) {
|
||||||
|
handler := &testTimeoutHandler{pkt: nil}
|
||||||
|
c := faketime.NewManualClock()
|
||||||
|
f := NewFragmentation(minBlockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, c, handler)
|
||||||
|
pkt := pkt(2, "01")
|
||||||
|
// Values to Process don't matter except for pkt.
|
||||||
|
resPkt, _, _, _ := f.Process(FragmentID{ID: 0}, 0, 1, false, 0, pkt)
|
||||||
|
pkt.DecRef()
|
||||||
|
// This clears out the references held by the reassembler.
|
||||||
|
c.Advance(reassembleTimeout)
|
||||||
|
// If Process doesn't give the returned packet its own reference, this will
|
||||||
|
// fail.
|
||||||
|
resPkt.DecRef()
|
||||||
|
}
|
||||||
|
|
|
@ -170,6 +170,7 @@ func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *s
|
||||||
})
|
})
|
||||||
|
|
||||||
resPkt := r.holes[0].pkt
|
resPkt := r.holes[0].pkt
|
||||||
|
resPkt.IncRef()
|
||||||
for i := 1; i < len(r.holes); i++ {
|
for i := 1; i < len(r.holes); i++ {
|
||||||
stack.MergeFragment(resPkt, r.holes[i].pkt)
|
stack.MergeFragment(resPkt, r.holes[i].pkt)
|
||||||
}
|
}
|
||||||
|
|
|
@ -203,6 +203,9 @@ func TestReassemblerProcess(t *testing.T) {
|
||||||
var isDone bool
|
var isDone bool
|
||||||
for _, param := range test.params {
|
for _, param := range test.params {
|
||||||
pkt, _, done, _, err := r.process(param.first, param.last, param.more, proto, param.pkt)
|
pkt, _, done, _, err := r.process(param.first, param.last, param.more, proto, param.pkt)
|
||||||
|
if pkt != nil {
|
||||||
|
defer pkt.DecRef()
|
||||||
|
}
|
||||||
if done != param.wantDone || err != param.wantError {
|
if done != param.wantDone || err != param.wantError {
|
||||||
t.Errorf("got r.process(%d, %d, %t, %d, _) = (_, _, %t, _, %v), want = (%t, %v)", param.first, param.last, param.more, proto, done, err, param.wantDone, param.wantError)
|
t.Errorf("got r.process(%d, %d, %t, %d, _) = (_, _, %t, _, %v), want = (%t, %v)", param.first, param.last, param.more, proto, done, err, param.wantDone, param.wantError)
|
||||||
}
|
}
|
||||||
|
|
|
@ -941,6 +941,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer,
|
||||||
if !ready {
|
if !ready {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
defer resPkt.DecRef()
|
||||||
pkt = resPkt
|
pkt = resPkt
|
||||||
h = header.IPv4(pkt.NetworkHeader().View())
|
h = header.IPv4(pkt.NetworkHeader().View())
|
||||||
|
|
||||||
|
|
|
@ -1136,7 +1136,16 @@ func (e *endpoint) processExtensionHeaders(h header.IPv6, pkt *stack.PacketBuffe
|
||||||
var (
|
var (
|
||||||
hasFragmentHeader bool
|
hasFragmentHeader bool
|
||||||
routerAlert *header.IPv6RouterAlertOption
|
routerAlert *header.IPv6RouterAlertOption
|
||||||
|
// Create an extra packet buffer reference to keep track of the packet to
|
||||||
|
// DecRef so that we do not incur a memory allocation for deferring a DecRef
|
||||||
|
// within the loop.
|
||||||
|
resPktToDecRef *stack.PacketBuffer
|
||||||
)
|
)
|
||||||
|
defer func() {
|
||||||
|
if resPktToDecRef != nil {
|
||||||
|
resPktToDecRef.DecRef()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
// Keep track of the start of the previous header so we can report the
|
// Keep track of the start of the previous header so we can report the
|
||||||
|
@ -1392,6 +1401,7 @@ func (e *endpoint) processExtensionHeaders(h header.IPv6, pkt *stack.PacketBuffe
|
||||||
}
|
}
|
||||||
|
|
||||||
if ready {
|
if ready {
|
||||||
|
resPktToDecRef = resPkt
|
||||||
pkt = resPkt
|
pkt = resPkt
|
||||||
|
|
||||||
// We create a new iterator with the reassembled packet because we could
|
// We create a new iterator with the reassembled packet because we could
|
||||||
|
|
Loading…
Reference in New Issue