Don't ignore override if it is longer than layerStates
PiperOrigin-RevId: 307708653
This commit is contained in:
parent
37e01fd2ea
commit
0e013d8b00
|
@ -363,16 +363,33 @@ type Connection struct {
|
|||
// reverse is never a match. override overrides the default matchers for each
|
||||
// Layer.
|
||||
func (conn *Connection) match(override, received Layers) bool {
|
||||
if len(received) < len(conn.layerStates) {
|
||||
var layersToMatch int
|
||||
if len(override) < len(conn.layerStates) {
|
||||
layersToMatch = len(conn.layerStates)
|
||||
} else {
|
||||
layersToMatch = len(override)
|
||||
}
|
||||
if len(received) < layersToMatch {
|
||||
return false
|
||||
}
|
||||
for i, s := range conn.layerStates {
|
||||
toMatch := s.incoming(received[i])
|
||||
for i := 0; i < layersToMatch; i++ {
|
||||
var toMatch Layer
|
||||
if i < len(conn.layerStates) {
|
||||
s := conn.layerStates[i]
|
||||
toMatch = s.incoming(received[i])
|
||||
if toMatch == nil {
|
||||
return false
|
||||
}
|
||||
if i < len(override) {
|
||||
toMatch.merge(override[i])
|
||||
if err := toMatch.merge(override[i]); err != nil {
|
||||
conn.t.Fatalf("failed to merge: %s", err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
toMatch = override[i]
|
||||
if toMatch == nil {
|
||||
conn.t.Fatalf("expect the overriding layers to be non-nil")
|
||||
}
|
||||
}
|
||||
if !toMatch.match(received[i]) {
|
||||
return false
|
||||
|
|
|
@ -154,3 +154,53 @@ func TestLayerStringFormat(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectionMatch(t *testing.T) {
|
||||
conn := Connection{
|
||||
layerStates: []layerState{ðerState{}},
|
||||
}
|
||||
protoNum0 := tcpip.NetworkProtocolNumber(0)
|
||||
protoNum1 := tcpip.NetworkProtocolNumber(1)
|
||||
for _, tt := range []struct {
|
||||
description string
|
||||
override, received Layers
|
||||
wantMatch bool
|
||||
}{
|
||||
{
|
||||
description: "shorter override",
|
||||
override: []Layer{&Ether{}},
|
||||
received: []Layer{&Ether{}, &Payload{Bytes: []byte("hello")}},
|
||||
wantMatch: true,
|
||||
},
|
||||
{
|
||||
description: "longer override",
|
||||
override: []Layer{&Ether{}, &Payload{Bytes: []byte("hello")}},
|
||||
received: []Layer{&Ether{}},
|
||||
wantMatch: false,
|
||||
},
|
||||
{
|
||||
description: "ether layer mismatch",
|
||||
override: []Layer{&Ether{Type: &protoNum0}},
|
||||
received: []Layer{&Ether{Type: &protoNum1}},
|
||||
wantMatch: false,
|
||||
},
|
||||
{
|
||||
description: "both nil",
|
||||
override: nil,
|
||||
received: nil,
|
||||
wantMatch: false,
|
||||
},
|
||||
{
|
||||
description: "nil override",
|
||||
override: nil,
|
||||
received: []Layer{&Ether{}},
|
||||
wantMatch: true,
|
||||
},
|
||||
} {
|
||||
t.Run(tt.description, func(t *testing.T) {
|
||||
if gotMatch := conn.match(tt.override, tt.received); gotMatch != tt.wantMatch {
|
||||
t.Fatalf("conn.match(%s, %s) = %t, want %t", tt.override, tt.received, gotMatch, tt.wantMatch)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue