Do not perform IGMP/MLD on loopback interfaces

The loopback interface will never have any neighbouring nodes so
advertising its interest in multicast groups is unnecessary.

Bug #4682, #4861

Startblock:
  has LGTM from asfez
  and then
  add reviewer tamird
PiperOrigin-RevId: 346587604
This commit is contained in:
Ghanan Gowripalan 2020-12-09 10:50:42 -08:00 committed by Shentubot
parent a855a814d6
commit 50189b0d6f
4 changed files with 89 additions and 9 deletions

View File

@ -161,7 +161,8 @@ type GenericMulticastProtocolState struct {
// Init initializes the Generic Multicast Protocol state.
//
// Must only be called once for the lifetime of g.
// Must only be called once for the lifetime of g; Init will panic if it is
// called twice.
//
// The GenericMulticastProtocolState will only grab the lock when timers/jobs
// fire.
@ -170,9 +171,11 @@ func (g *GenericMulticastProtocolState) Init(protocolMU *sync.RWMutex, opts Gene
panic("attempted to initialize generic membership protocol state twice")
}
g.opts = opts
g.memberships = make(map[tcpip.Address]multicastGroupState)
g.protocolMU = protocolMU
*g = GenericMulticastProtocolState{
opts: opts,
memberships: make(map[tcpip.Address]multicastGroupState),
protocolMU: protocolMU,
}
}
// MakeAllNonMemberLocked transitions all groups to the non-member state.

View File

@ -57,6 +57,9 @@ type IGMPOptions struct {
// When enabled, IGMP may transmit IGMP report and leave messages when
// joining and leaving multicast groups respectively, and handle incoming
// IGMP packets.
//
// This field is ignored and is always assumed to be false for interfaces
// without neighbouring nodes (e.g. loopback).
Enabled bool
}
@ -69,6 +72,8 @@ type igmpState struct {
// The IPv4 endpoint this igmpState is for.
ep *endpoint
enabled bool
genericMulticastProtocol ip.GenericMulticastProtocolState
// igmpV1Present is for maintaining compatibility with IGMPv1 Routers, from
@ -117,8 +122,11 @@ func (igmp *igmpState) SendLeave(groupAddress tcpip.Address) *tcpip.Error {
// Must only be called once for the lifetime of igmp.
func (igmp *igmpState) init(ep *endpoint) {
igmp.ep = ep
// No need to perform IGMP on loopback interfaces since they don't have
// neighbouring nodes.
igmp.enabled = ep.protocol.options.IGMP.Enabled && !igmp.ep.nic.IsLoopback()
igmp.genericMulticastProtocol.Init(&ep.mu.RWMutex, ip.GenericMulticastProtocolOptions{
Enabled: ep.protocol.options.IGMP.Enabled,
Enabled: igmp.enabled,
Rand: ep.protocol.stack.Rand(),
Clock: ep.protocol.stack.Clock(),
Protocol: igmp,
@ -210,7 +218,7 @@ func (igmp *igmpState) handleMembershipQuery(groupAddress tcpip.Address, maxResp
// As per RFC 2236 Section 6, Page 10: If the maximum response time is zero
// then change the state to note that an IGMPv1 router is present and
// schedule the query received Job.
if maxRespTime == 0 && igmp.ep.protocol.options.IGMP.Enabled {
if igmp.enabled && maxRespTime == 0 {
igmp.igmpV1Job.Cancel()
igmp.igmpV1Job.Schedule(v1RouterPresentTimeout)
igmp.setV1Present(true)

View File

@ -40,6 +40,9 @@ type MLDOptions struct {
// When enabled, MLD may transmit MLD report and done messages when
// joining and leaving multicast groups respectively, and handle incoming
// MLD packets.
//
// This field is ignored and is always assumed to be false for interfaces
// without neighbouring nodes (e.g. loopback).
Enabled bool
}
@ -72,7 +75,9 @@ func (mld *mldState) SendLeave(groupAddress tcpip.Address) *tcpip.Error {
func (mld *mldState) init(ep *endpoint) {
mld.ep = ep
mld.genericMulticastProtocol.Init(&ep.mu.RWMutex, ip.GenericMulticastProtocolOptions{
Enabled: ep.protocol.options.MLD.Enabled,
// No need to perform MLD on loopback interfaces since they don't have
// neighbouring nodes.
Enabled: ep.protocol.options.MLD.Enabled && !mld.ep.nic.IsLoopback(),
Rand: ep.protocol.stack.Rand(),
Clock: ep.protocol.stack.Clock(),
Protocol: mld,

View File

@ -26,6 +26,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@ -104,7 +105,14 @@ func createStack(t *testing.T, mgpEnabled bool) (*channel.Endpoint, *stack.Stack
// Create an endpoint of queue size 2, since no more than 2 packets are ever
// queued in the tests in this file.
e := channel.New(2, 1280, linkAddr)
e := channel.New(2, header.IPv6MinimumMTU, linkAddr)
s, clock := createStackWithLinkEndpoint(t, mgpEnabled, e)
return e, s, clock
}
func createStackWithLinkEndpoint(t *testing.T, mgpEnabled bool, e stack.LinkEndpoint) (*stack.Stack, *faketime.ManualClock) {
t.Helper()
clock := faketime.NewManualClock()
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{
@ -125,7 +133,7 @@ func createStack(t *testing.T, mgpEnabled bool) (*channel.Endpoint, *stack.Stack
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
return e, s, clock
return s, clock
}
// createAndInjectIGMPPacket creates and injects an IGMP packet with the
@ -1067,3 +1075,59 @@ func TestMGPWithNICLifecycle(t *testing.T) {
})
}
}
// TestMGPDisabledOnLoopback tests that the multicast group protocol is not
// performed on loopback interfaces since they have no neighbours.
func TestMGPDisabledOnLoopback(t *testing.T) {
tests := []struct {
name string
protoNum tcpip.NetworkProtocolNumber
multicastAddr tcpip.Address
sentReportStat func(*stack.Stack) *tcpip.StatCounter
}{
{
name: "IGMP",
protoNum: ipv4.ProtocolNumber,
multicastAddr: ipv4MulticastAddr1,
sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
return s.Stats().IGMP.PacketsSent.V2MembershipReport
},
},
{
name: "MLD",
protoNum: ipv6.ProtocolNumber,
multicastAddr: ipv6MulticastAddr1,
sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport
},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
s, clock := createStackWithLinkEndpoint(t, true /* mgpEnabled */, loopback.New())
sentReportStat := test.sentReportStat(s)
if got := sentReportStat.Value(); got != 0 {
t.Fatalf("got sentReportStat.Value() = %d, want = 0", got)
}
clock.Advance(time.Hour)
if got := sentReportStat.Value(); got != 0 {
t.Fatalf("got sentReportStat.Value() = %d, want = 0", got)
}
// Test joining a specific group explicitly and verify that no reports are
// sent.
if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
}
if got := sentReportStat.Value(); got != 0 {
t.Fatalf("got sentReportStat.Value() = %d, want = 0", got)
}
clock.Advance(time.Hour)
if got := sentReportStat.Value(); got != 0 {
t.Fatalf("got sentReportStat.Value() = %d, want = 0", got)
}
})
}
}