From 0e013d8b00dbc3ad96e98bc0405ec2e21887308e Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Tue, 21 Apr 2020 16:54:08 -0700 Subject: [PATCH] Don't ignore override if it is longer than layerStates PiperOrigin-RevId: 307708653 --- test/packetimpact/testbench/connections.go | 33 ++++++++++---- test/packetimpact/testbench/layers_test.go | 50 ++++++++++++++++++++++ 2 files changed, 75 insertions(+), 8 deletions(-) diff --git a/test/packetimpact/testbench/connections.go b/test/packetimpact/testbench/connections.go index f84fd8ba7..00a366894 100644 --- a/test/packetimpact/testbench/connections.go +++ b/test/packetimpact/testbench/connections.go @@ -363,16 +363,33 @@ type Connection struct { // reverse is never a match. override overrides the default matchers for each // Layer. func (conn *Connection) match(override, received Layers) bool { - if len(received) < len(conn.layerStates) { + var layersToMatch int + if len(override) < len(conn.layerStates) { + layersToMatch = len(conn.layerStates) + } else { + layersToMatch = len(override) + } + if len(received) < layersToMatch { return false } - for i, s := range conn.layerStates { - toMatch := s.incoming(received[i]) - if toMatch == nil { - return false - } - if i < len(override) { - toMatch.merge(override[i]) + for i := 0; i < layersToMatch; i++ { + var toMatch Layer + if i < len(conn.layerStates) { + s := conn.layerStates[i] + toMatch = s.incoming(received[i]) + if toMatch == nil { + return false + } + if i < len(override) { + if err := toMatch.merge(override[i]); err != nil { + conn.t.Fatalf("failed to merge: %s", err) + } + } + } else { + toMatch = override[i] + if toMatch == nil { + conn.t.Fatalf("expect the overriding layers to be non-nil") + } } if !toMatch.match(received[i]) { return false diff --git a/test/packetimpact/testbench/layers_test.go b/test/packetimpact/testbench/layers_test.go index b32efda93..c99cf6312 100644 --- a/test/packetimpact/testbench/layers_test.go +++ b/test/packetimpact/testbench/layers_test.go @@ -154,3 +154,53 @@ func TestLayerStringFormat(t *testing.T) { }) } } + +func TestConnectionMatch(t *testing.T) { + conn := Connection{ + layerStates: []layerState{ðerState{}}, + } + protoNum0 := tcpip.NetworkProtocolNumber(0) + protoNum1 := tcpip.NetworkProtocolNumber(1) + for _, tt := range []struct { + description string + override, received Layers + wantMatch bool + }{ + { + description: "shorter override", + override: []Layer{&Ether{}}, + received: []Layer{&Ether{}, &Payload{Bytes: []byte("hello")}}, + wantMatch: true, + }, + { + description: "longer override", + override: []Layer{&Ether{}, &Payload{Bytes: []byte("hello")}}, + received: []Layer{&Ether{}}, + wantMatch: false, + }, + { + description: "ether layer mismatch", + override: []Layer{&Ether{Type: &protoNum0}}, + received: []Layer{&Ether{Type: &protoNum1}}, + wantMatch: false, + }, + { + description: "both nil", + override: nil, + received: nil, + wantMatch: false, + }, + { + description: "nil override", + override: nil, + received: []Layer{&Ether{}}, + wantMatch: true, + }, + } { + t.Run(tt.description, func(t *testing.T) { + if gotMatch := conn.match(tt.override, tt.received); gotMatch != tt.wantMatch { + t.Fatalf("conn.match(%s, %s) = %t, want %t", tt.override, tt.received, gotMatch, tt.wantMatch) + } + }) + } +}