Don't ignore override if it is longer than layerStates

PiperOrigin-RevId: 307708653
This commit is contained in:
gVisor bot 2020-04-21 16:54:08 -07:00
parent 37e01fd2ea
commit 0e013d8b00
2 changed files with 75 additions and 8 deletions

View File

@ -363,16 +363,33 @@ type Connection struct {
// reverse is never a match. override overrides the default matchers for each
// Layer.
func (conn *Connection) match(override, received Layers) bool {
if len(received) < len(conn.layerStates) {
var layersToMatch int
if len(override) < len(conn.layerStates) {
layersToMatch = len(conn.layerStates)
} else {
layersToMatch = len(override)
}
if len(received) < layersToMatch {
return false
}
for i, s := range conn.layerStates {
toMatch := s.incoming(received[i])
if toMatch == nil {
return false
}
if i < len(override) {
toMatch.merge(override[i])
for i := 0; i < layersToMatch; i++ {
var toMatch Layer
if i < len(conn.layerStates) {
s := conn.layerStates[i]
toMatch = s.incoming(received[i])
if toMatch == nil {
return false
}
if i < len(override) {
if err := toMatch.merge(override[i]); err != nil {
conn.t.Fatalf("failed to merge: %s", err)
}
}
} else {
toMatch = override[i]
if toMatch == nil {
conn.t.Fatalf("expect the overriding layers to be non-nil")
}
}
if !toMatch.match(received[i]) {
return false

View File

@ -154,3 +154,53 @@ func TestLayerStringFormat(t *testing.T) {
})
}
}
func TestConnectionMatch(t *testing.T) {
conn := Connection{
layerStates: []layerState{&etherState{}},
}
protoNum0 := tcpip.NetworkProtocolNumber(0)
protoNum1 := tcpip.NetworkProtocolNumber(1)
for _, tt := range []struct {
description string
override, received Layers
wantMatch bool
}{
{
description: "shorter override",
override: []Layer{&Ether{}},
received: []Layer{&Ether{}, &Payload{Bytes: []byte("hello")}},
wantMatch: true,
},
{
description: "longer override",
override: []Layer{&Ether{}, &Payload{Bytes: []byte("hello")}},
received: []Layer{&Ether{}},
wantMatch: false,
},
{
description: "ether layer mismatch",
override: []Layer{&Ether{Type: &protoNum0}},
received: []Layer{&Ether{Type: &protoNum1}},
wantMatch: false,
},
{
description: "both nil",
override: nil,
received: nil,
wantMatch: false,
},
{
description: "nil override",
override: nil,
received: []Layer{&Ether{}},
wantMatch: true,
},
} {
t.Run(tt.description, func(t *testing.T) {
if gotMatch := conn.match(tt.override, tt.received); gotMatch != tt.wantMatch {
t.Fatalf("conn.match(%s, %s) = %t, want %t", tt.override, tt.received, gotMatch, tt.wantMatch)
}
})
}
}