Disable a NIC before removing it

When a NIC is removed, attempt to disable the NIC first to cleanup
dynamic state and stop ongoing periodic tasks (e.g. IPv6 router
solicitations, DAD) so that a removed NIC does not attempt to send
packets.

Tests:
    - stack_test.TestRemoveUnknownNIC
    - stack_test.TestRemoveNIC
    - stack_test.TestDADStop
    - stack_test.TestCleanupNDPState
    - stack_test.TestRouteWithDownNIC
    - stack_test.TestStopStartSolicitingRouters
PiperOrigin-RevId: 300805857
This commit is contained in:
Ghanan Gowripalan 2020-03-13 12:29:19 -07:00 committed by gVisor bot
parent 86409c9181
commit 530a31f3c0
4 changed files with 417 additions and 233 deletions

View File

@ -639,8 +639,9 @@ func TestDADStop(t *testing.T) {
const nicID = 1
tests := []struct {
name string
stopFn func(t *testing.T, s *stack.Stack)
name string
stopFn func(t *testing.T, s *stack.Stack)
skipFinalAddrCheck bool
}{
// Tests to make sure that DAD stops when an address is removed.
{
@ -661,6 +662,19 @@ func TestDADStop(t *testing.T) {
}
},
},
// Tests to make sure that DAD stops when the NIC is removed.
{
name: "Remove NIC",
stopFn: func(t *testing.T, s *stack.Stack) {
if err := s.RemoveNIC(nicID); err != nil {
t.Fatalf("RemoveNIC(%d): %s", nicID, err)
}
},
// The NIC is removed so we can't check its addresses after calling
// stopFn.
skipFinalAddrCheck: true,
},
}
for _, test := range tests {
@ -710,12 +724,15 @@ func TestDADStop(t *testing.T) {
t.Errorf("dad event mismatch (-want +got):\n%s", diff)
}
}
addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
if err != nil {
t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
}
if want := (tcpip.AddressWithPrefix{}); addr != want {
t.Errorf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
if !test.skipFinalAddrCheck {
addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
if err != nil {
t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
}
if want := (tcpip.AddressWithPrefix{}); addr != want {
t.Errorf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
}
}
// Should not have sent more than 1 NS message.
@ -2983,11 +3000,12 @@ func TestCleanupNDPState(t *testing.T) {
cleanupFn func(t *testing.T, s *stack.Stack)
keepAutoGenLinkLocal bool
maxAutoGenAddrEvents int
skipFinalAddrCheck bool
}{
// A NIC should still keep its auto-generated link-local address when
// becoming a router.
{
name: "Forwarding Enable",
name: "Enable forwarding",
cleanupFn: func(t *testing.T, s *stack.Stack) {
t.Helper()
s.SetForwarding(true)
@ -2998,7 +3016,7 @@ func TestCleanupNDPState(t *testing.T) {
// A NIC should cleanup all NDP state when it is disabled.
{
name: "NIC Disable",
name: "Disable NIC",
cleanupFn: func(t *testing.T, s *stack.Stack) {
t.Helper()
@ -3012,6 +3030,26 @@ func TestCleanupNDPState(t *testing.T) {
keepAutoGenLinkLocal: false,
maxAutoGenAddrEvents: 6,
},
// A NIC should cleanup all NDP state when it is removed.
{
name: "Remove NIC",
cleanupFn: func(t *testing.T, s *stack.Stack) {
t.Helper()
if err := s.RemoveNIC(nicID1); err != nil {
t.Fatalf("s.RemoveNIC(%d): %s", nicID1, err)
}
if err := s.RemoveNIC(nicID2); err != nil {
t.Fatalf("s.RemoveNIC(%d): %s", nicID2, err)
}
},
keepAutoGenLinkLocal: false,
maxAutoGenAddrEvents: 6,
// The NICs are removed so we can't check their addresses after calling
// stopFn.
skipFinalAddrCheck: true,
},
}
for _, test := range tests {
@ -3230,35 +3268,37 @@ func TestCleanupNDPState(t *testing.T) {
t.Errorf("auto-generated address events mismatch (-want +got):\n%s", diff)
}
// Make sure the auto-generated addresses got removed.
nicinfo = s.NICInfo()
nic1Addrs = nicinfo[nicID1].ProtocolAddresses
nic2Addrs = nicinfo[nicID2].ProtocolAddresses
if containsV6Addr(nic1Addrs, llAddrWithPrefix1) != test.keepAutoGenLinkLocal {
if test.keepAutoGenLinkLocal {
t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs)
} else {
t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs)
if !test.skipFinalAddrCheck {
// Make sure the auto-generated addresses got removed.
nicinfo = s.NICInfo()
nic1Addrs = nicinfo[nicID1].ProtocolAddresses
nic2Addrs = nicinfo[nicID2].ProtocolAddresses
if containsV6Addr(nic1Addrs, llAddrWithPrefix1) != test.keepAutoGenLinkLocal {
if test.keepAutoGenLinkLocal {
t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs)
} else {
t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs)
}
}
}
if containsV6Addr(nic1Addrs, e1Addr1) {
t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs)
}
if containsV6Addr(nic1Addrs, e1Addr2) {
t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs)
}
if containsV6Addr(nic2Addrs, llAddrWithPrefix2) != test.keepAutoGenLinkLocal {
if test.keepAutoGenLinkLocal {
t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs)
} else {
t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs)
if containsV6Addr(nic1Addrs, e1Addr1) {
t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs)
}
if containsV6Addr(nic1Addrs, e1Addr2) {
t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs)
}
if containsV6Addr(nic2Addrs, llAddrWithPrefix2) != test.keepAutoGenLinkLocal {
if test.keepAutoGenLinkLocal {
t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs)
} else {
t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs)
}
}
if containsV6Addr(nic2Addrs, e2Addr1) {
t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs)
}
if containsV6Addr(nic2Addrs, e2Addr2) {
t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs)
}
}
if containsV6Addr(nic2Addrs, e2Addr1) {
t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs)
}
if containsV6Addr(nic2Addrs, e2Addr2) {
t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs)
}
// Should not get any more events (invalidation timers should have been
@ -3575,17 +3615,19 @@ func TestStopStartSolicitingRouters(t *testing.T) {
tests := []struct {
name string
startFn func(t *testing.T, s *stack.Stack)
stopFn func(t *testing.T, s *stack.Stack)
// first is used to tell stopFn that it is being called for the first time
// after router solicitations were last enabled.
stopFn func(t *testing.T, s *stack.Stack, first bool)
}{
// Tests that when forwarding is enabled or disabled, router solicitations
// are stopped or started, respectively.
{
name: "Forwarding enabled and disabled",
name: "Enable and disable forwarding",
startFn: func(t *testing.T, s *stack.Stack) {
t.Helper()
s.SetForwarding(false)
},
stopFn: func(t *testing.T, s *stack.Stack) {
stopFn: func(t *testing.T, s *stack.Stack, _ bool) {
t.Helper()
s.SetForwarding(true)
},
@ -3594,7 +3636,7 @@ func TestStopStartSolicitingRouters(t *testing.T) {
// Tests that when a NIC is enabled or disabled, router solicitations
// are started or stopped, respectively.
{
name: "NIC disabled and enabled",
name: "Enable and disable NIC",
startFn: func(t *testing.T, s *stack.Stack) {
t.Helper()
@ -3602,7 +3644,7 @@ func TestStopStartSolicitingRouters(t *testing.T) {
t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
}
},
stopFn: func(t *testing.T, s *stack.Stack) {
stopFn: func(t *testing.T, s *stack.Stack, _ bool) {
t.Helper()
if err := s.DisableNIC(nicID); err != nil {
@ -3610,6 +3652,25 @@ func TestStopStartSolicitingRouters(t *testing.T) {
}
},
},
// Tests that when a NIC is removed, router solicitations are stopped. We
// cannot start router solications on a removed NIC.
{
name: "Remove NIC",
stopFn: func(t *testing.T, s *stack.Stack, first bool) {
t.Helper()
// Only try to remove the NIC the first time stopFn is called since it's
// impossible to remove an already removed NIC.
if !first {
return
}
if err := s.RemoveNIC(nicID); err != nil {
t.Fatalf("s.RemoveNIC(%d): %s", nicID, err)
}
},
},
}
for _, test := range tests {
@ -3648,7 +3709,7 @@ func TestStopStartSolicitingRouters(t *testing.T) {
}
// Stop soliciting routers.
test.stopFn(t, s)
test.stopFn(t, s, true /* first */)
ctx, cancel := context.WithTimeout(context.Background(), delay+defaultTimeout)
defer cancel()
if _, ok := e.ReadContext(ctx); ok {
@ -3662,13 +3723,18 @@ func TestStopStartSolicitingRouters(t *testing.T) {
// Stopping router solicitations after it has already been stopped should
// do nothing.
test.stopFn(t, s)
test.stopFn(t, s, false /* first */)
ctx, cancel = context.WithTimeout(context.Background(), delay+defaultTimeout)
defer cancel()
if _, ok := e.ReadContext(ctx); ok {
t.Fatal("unexpectedly got a packet after router solicitation has been stopepd")
}
// If test.startFn is nil, there is no way to restart router solications.
if test.startFn == nil {
return
}
// Start soliciting routers.
test.startFn(t, s)
waitForPkt(delay + defaultAsyncEventTimeout)

View File

@ -56,7 +56,7 @@ type NIC struct {
primary map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint
endpoints map[NetworkEndpointID]*referencedNetworkEndpoint
addressRanges []tcpip.Subnet
mcastJoins map[NetworkEndpointID]int32
mcastJoins map[NetworkEndpointID]uint32
// packetEPs is protected by mu, but the contained PacketEndpoint
// values are not.
packetEPs map[tcpip.NetworkProtocolNumber][]PacketEndpoint
@ -123,7 +123,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
}
nic.mu.primary = make(map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint)
nic.mu.endpoints = make(map[NetworkEndpointID]*referencedNetworkEndpoint)
nic.mu.mcastJoins = make(map[NetworkEndpointID]int32)
nic.mu.mcastJoins = make(map[NetworkEndpointID]uint32)
nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber][]PacketEndpoint)
nic.mu.ndp = ndpState{
nic: nic,
@ -167,8 +167,17 @@ func (n *NIC) disable() *tcpip.Error {
}
n.mu.Lock()
defer n.mu.Unlock()
err := n.disableLocked()
n.mu.Unlock()
return err
}
// disableLocked disables n.
//
// It undoes the work done by enable.
//
// n MUST be locked.
func (n *NIC) disableLocked() *tcpip.Error {
if !n.mu.enabled {
return nil
}
@ -191,7 +200,7 @@ func (n *NIC) disable() *tcpip.Error {
}
// The NIC may have already left the multicast group.
if err := n.leaveGroupLocked(header.IPv6AllNodesMulticastAddress); err != nil && err != tcpip.ErrBadLocalAddress {
if err := n.leaveGroupLocked(header.IPv6AllNodesMulticastAddress, false /* force */); err != nil && err != tcpip.ErrBadLocalAddress {
return err
}
}
@ -307,24 +316,33 @@ func (n *NIC) remove() *tcpip.Error {
n.mu.Lock()
defer n.mu.Unlock()
// Detach from link endpoint, so no packet comes in.
n.linkEP.Attach(nil)
n.disableLocked()
// TODO(b/151378115): come up with a better way to pick an error than the
// first one.
var err *tcpip.Error
// Forcefully leave multicast groups.
for nid := range n.mu.mcastJoins {
if tempErr := n.leaveGroupLocked(nid.LocalAddress, true /* force */); tempErr != nil && err == nil {
err = tempErr
}
}
// Remove permanent and permanentTentative addresses, so no packet goes out.
var errs []*tcpip.Error
for nid, ref := range n.mu.endpoints {
switch ref.getKind() {
case permanentTentative, permanent:
if err := n.removePermanentAddressLocked(nid.LocalAddress); err != nil {
errs = append(errs, err)
if tempErr := n.removePermanentAddressLocked(nid.LocalAddress); tempErr != nil && err == nil {
err = tempErr
}
}
}
if len(errs) > 0 {
return errs[0]
}
return nil
// Detach from link endpoint, so no packet comes in.
n.linkEP.Attach(nil)
return err
}
// becomeIPv6Router transitions n into an IPv6 router.
@ -971,6 +989,7 @@ func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) {
for i, ref := range refs {
if ref == r {
n.mu.primary[r.protocol] = append(refs[:i], refs[i+1:]...)
refs[len(refs)-1] = nil
break
}
}
@ -1021,9 +1040,12 @@ func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error {
// If we are removing an IPv6 unicast address, leave the solicited-node
// multicast address.
//
// We ignore the tcpip.ErrBadLocalAddress error because the solicited-node
// multicast group may be left by user action.
if isIPv6Unicast {
snmc := header.SolicitedNodeAddr(addr)
if err := n.leaveGroupLocked(snmc); err != nil {
if err := n.leaveGroupLocked(snmc, false /* force */); err != nil && err != tcpip.ErrBadLocalAddress {
return err
}
}
@ -1083,26 +1105,31 @@ func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error {
n.mu.Lock()
defer n.mu.Unlock()
return n.leaveGroupLocked(addr)
return n.leaveGroupLocked(addr, false /* force */)
}
// leaveGroupLocked decrements the count for the given multicast address, and
// when it reaches zero removes the endpoint for this address. n MUST be locked
// before leaveGroupLocked is called.
func (n *NIC) leaveGroupLocked(addr tcpip.Address) *tcpip.Error {
//
// If force is true, then the count for the multicast addres is ignored and the
// endpoint will be removed immediately.
func (n *NIC) leaveGroupLocked(addr tcpip.Address, force bool) *tcpip.Error {
id := NetworkEndpointID{addr}
joins := n.mu.mcastJoins[id]
switch joins {
case 0:
joins, ok := n.mu.mcastJoins[id]
if !ok {
// There are no joins with this address on this NIC.
return tcpip.ErrBadLocalAddress
case 1:
// This is the last one, clean up.
if err := n.removePermanentAddressLocked(addr); err != nil {
return err
}
}
n.mu.mcastJoins[id] = joins - 1
joins--
if force || joins == 0 {
// There are no outstanding joins or we are forced to leave, clean up.
delete(n.mu.mcastJoins, id)
return n.removePermanentAddressLocked(addr)
}
n.mu.mcastJoins[id] = joins
return nil
}

View File

@ -401,6 +401,9 @@ type LinkEndpoint interface {
// Attach attaches the data link layer endpoint to the network-layer
// dispatcher of the stack.
//
// Attach will be called with a nil dispatcher if the receiver's associated
// NIC is being removed.
Attach(dispatcher NetworkDispatcher)
// IsAttached returns whether a NetworkDispatcher is attached to the

View File

@ -255,7 +255,7 @@ type linkEPWithMockedAttach struct {
// Attach implements stack.LinkEndpoint.Attach.
func (l *linkEPWithMockedAttach) Attach(d stack.NetworkDispatcher) {
l.LinkEndpoint.Attach(d)
l.attached = true
l.attached = d != nil
}
func (l *linkEPWithMockedAttach) isAttached() bool {
@ -566,7 +566,7 @@ func TestAttachToLinkEndpointImmediately(t *testing.T) {
t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, test.nicOpts, err)
}
if !e.isAttached() {
t.Fatalf("link endpoint not attached to a network disatcher")
t.Fatal("link endpoint not attached to a network dispatcher")
}
})
}
@ -631,196 +631,240 @@ func TestDisabledNICsNICInfoAndCheckNIC(t *testing.T) {
checkNIC(false)
}
func TestRoutesWithDisabledNIC(t *testing.T) {
const unspecifiedNIC = 0
const nicID1 = 1
const nicID2 = 2
func TestRemoveUnknownNIC(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
})
ep1 := channel.New(0, defaultMTU, "")
if err := s.CreateNIC(nicID1, ep1); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
if err := s.RemoveNIC(1); err != tcpip.ErrUnknownNICID {
t.Fatalf("got s.RemoveNIC(1) = %v, want = %s", err, tcpip.ErrUnknownNICID)
}
addr1 := tcpip.Address("\x01")
if err := s.AddAddress(nicID1, fakeNetNumber, addr1); err != nil {
t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, fakeNetNumber, addr1, err)
}
ep2 := channel.New(0, defaultMTU, "")
if err := s.CreateNIC(nicID2, ep2); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
}
addr2 := tcpip.Address("\x02")
if err := s.AddAddress(nicID2, fakeNetNumber, addr2); err != nil {
t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, fakeNetNumber, addr2, err)
}
// Set a route table that sends all packets with odd destination
// addresses through the first NIC, and all even destination address
// through the second one.
{
subnet0, err := tcpip.NewSubnet("\x00", "\x01")
if err != nil {
t.Fatal(err)
}
subnet1, err := tcpip.NewSubnet("\x01", "\x01")
if err != nil {
t.Fatal(err)
}
s.SetRouteTable([]tcpip.Route{
{Destination: subnet1, Gateway: "\x00", NIC: nicID1},
{Destination: subnet0, Gateway: "\x00", NIC: nicID2},
})
}
// Test routes to odd address.
testRoute(t, s, unspecifiedNIC, "", "\x05", addr1)
testRoute(t, s, unspecifiedNIC, addr1, "\x05", addr1)
testRoute(t, s, nicID1, addr1, "\x05", addr1)
// Test routes to even address.
testRoute(t, s, unspecifiedNIC, "", "\x06", addr2)
testRoute(t, s, unspecifiedNIC, addr2, "\x06", addr2)
testRoute(t, s, nicID2, addr2, "\x06", addr2)
// Disabling NIC1 should result in no routes to odd addresses. Routes to even
// addresses should continue to be available as NIC2 is still enabled.
if err := s.DisableNIC(nicID1); err != nil {
t.Fatalf("s.DisableNIC(%d): %s", nicID1, err)
}
nic1Dst := tcpip.Address("\x05")
testNoRoute(t, s, unspecifiedNIC, "", nic1Dst)
testNoRoute(t, s, unspecifiedNIC, addr1, nic1Dst)
testNoRoute(t, s, nicID1, addr1, nic1Dst)
nic2Dst := tcpip.Address("\x06")
testRoute(t, s, unspecifiedNIC, "", nic2Dst, addr2)
testRoute(t, s, unspecifiedNIC, addr2, nic2Dst, addr2)
testRoute(t, s, nicID2, addr2, nic2Dst, addr2)
// Disabling NIC2 should result in no routes to even addresses. No route
// should be available to any address as routes to odd addresses were made
// unavailable by disabling NIC1 above.
if err := s.DisableNIC(nicID2); err != nil {
t.Fatalf("s.DisableNIC(%d): %s", nicID2, err)
}
testNoRoute(t, s, unspecifiedNIC, "", nic1Dst)
testNoRoute(t, s, unspecifiedNIC, addr1, nic1Dst)
testNoRoute(t, s, nicID1, addr1, nic1Dst)
testNoRoute(t, s, unspecifiedNIC, "", nic2Dst)
testNoRoute(t, s, unspecifiedNIC, addr2, nic2Dst)
testNoRoute(t, s, nicID2, addr2, nic2Dst)
// Enabling NIC1 should make routes to odd addresses available again. Routes
// to even addresses should continue to be unavailable as NIC2 is still
// disabled.
if err := s.EnableNIC(nicID1); err != nil {
t.Fatalf("s.EnableNIC(%d): %s", nicID1, err)
}
testRoute(t, s, unspecifiedNIC, "", nic1Dst, addr1)
testRoute(t, s, unspecifiedNIC, addr1, nic1Dst, addr1)
testRoute(t, s, nicID1, addr1, nic1Dst, addr1)
testNoRoute(t, s, unspecifiedNIC, "", nic2Dst)
testNoRoute(t, s, unspecifiedNIC, addr2, nic2Dst)
testNoRoute(t, s, nicID2, addr2, nic2Dst)
}
func TestRouteWritePacketWithDisabledNIC(t *testing.T) {
const unspecifiedNIC = 0
const nicID1 = 1
const nicID2 = 2
func TestRemoveNIC(t *testing.T) {
const nicID = 1
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
})
ep1 := channel.New(1, defaultMTU, "")
if err := s.CreateNIC(nicID1, ep1); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
e := linkEPWithMockedAttach{
LinkEndpoint: loopback.New(),
}
if err := s.CreateNIC(nicID, &e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
addr1 := tcpip.Address("\x01")
if err := s.AddAddress(nicID1, fakeNetNumber, addr1); err != nil {
t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, fakeNetNumber, addr1, err)
// NIC should be present in NICInfo and attached to a NetworkDispatcher.
allNICInfo := s.NICInfo()
if _, ok := allNICInfo[nicID]; !ok {
t.Errorf("entry for %d missing from allNICInfo = %+v", nicID, allNICInfo)
}
if !e.isAttached() {
t.Fatal("link endpoint not attached to a network dispatcher")
}
ep2 := channel.New(1, defaultMTU, "")
if err := s.CreateNIC(nicID2, ep2); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
// Removing a NIC should remove it from NICInfo and e should be detached from
// the NetworkDispatcher.
if err := s.RemoveNIC(nicID); err != nil {
t.Fatalf("s.RemoveNIC(%d): %s", nicID, err)
}
if nicInfo, ok := s.NICInfo()[nicID]; ok {
t.Errorf("got unexpected NICInfo entry for deleted NIC %d = %+v", nicID, nicInfo)
}
if e.isAttached() {
t.Error("link endpoint for removed NIC still attached to a network dispatcher")
}
}
func TestRouteWithDownNIC(t *testing.T) {
tests := []struct {
name string
downFn func(s *stack.Stack, nicID tcpip.NICID) *tcpip.Error
upFn func(s *stack.Stack, nicID tcpip.NICID) *tcpip.Error
}{
{
name: "Disabled NIC",
downFn: (*stack.Stack).DisableNIC,
upFn: (*stack.Stack).EnableNIC,
},
// Once a NIC is removed, it cannot be brought up.
{
name: "Removed NIC",
downFn: (*stack.Stack).RemoveNIC,
},
}
addr2 := tcpip.Address("\x02")
if err := s.AddAddress(nicID2, fakeNetNumber, addr2); err != nil {
t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, fakeNetNumber, addr2, err)
}
const unspecifiedNIC = 0
const nicID1 = 1
const nicID2 = 2
const addr1 = tcpip.Address("\x01")
const addr2 = tcpip.Address("\x02")
const nic1Dst = tcpip.Address("\x05")
const nic2Dst = tcpip.Address("\x06")
// Set a route table that sends all packets with odd destination
// addresses through the first NIC, and all even destination address
// through the second one.
{
subnet0, err := tcpip.NewSubnet("\x00", "\x01")
if err != nil {
t.Fatal(err)
}
subnet1, err := tcpip.NewSubnet("\x01", "\x01")
if err != nil {
t.Fatal(err)
}
s.SetRouteTable([]tcpip.Route{
{Destination: subnet1, Gateway: "\x00", NIC: nicID1},
{Destination: subnet0, Gateway: "\x00", NIC: nicID2},
setup := func(t *testing.T) (*stack.Stack, *channel.Endpoint, *channel.Endpoint) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
})
ep1 := channel.New(1, defaultMTU, "")
if err := s.CreateNIC(nicID1, ep1); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
}
if err := s.AddAddress(nicID1, fakeNetNumber, addr1); err != nil {
t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, fakeNetNumber, addr1, err)
}
ep2 := channel.New(1, defaultMTU, "")
if err := s.CreateNIC(nicID2, ep2); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
}
if err := s.AddAddress(nicID2, fakeNetNumber, addr2); err != nil {
t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, fakeNetNumber, addr2, err)
}
// Set a route table that sends all packets with odd destination
// addresses through the first NIC, and all even destination address
// through the second one.
{
subnet0, err := tcpip.NewSubnet("\x00", "\x01")
if err != nil {
t.Fatal(err)
}
subnet1, err := tcpip.NewSubnet("\x01", "\x01")
if err != nil {
t.Fatal(err)
}
s.SetRouteTable([]tcpip.Route{
{Destination: subnet1, Gateway: "\x00", NIC: nicID1},
{Destination: subnet0, Gateway: "\x00", NIC: nicID2},
})
}
return s, ep1, ep2
}
nic1Dst := tcpip.Address("\x05")
r1, err := s.FindRoute(nicID1, addr1, nic1Dst, fakeNetNumber, false /* multicastLoop */)
if err != nil {
t.Errorf("FindRoute(%d, %s, %s, %d, false): %s", nicID1, addr1, nic1Dst, fakeNetNumber, err)
}
defer r1.Release()
// Tests that routes through a down NIC are not used when looking up a route
// for a destination.
t.Run("Find", func(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
s, _, _ := setup(t)
nic2Dst := tcpip.Address("\x06")
r2, err := s.FindRoute(nicID2, addr2, nic2Dst, fakeNetNumber, false /* multicastLoop */)
if err != nil {
t.Errorf("FindRoute(%d, %s, %s, %d, false): %s", nicID2, addr2, nic2Dst, fakeNetNumber, err)
}
defer r2.Release()
// Test routes to odd address.
testRoute(t, s, unspecifiedNIC, "", "\x05", addr1)
testRoute(t, s, unspecifiedNIC, addr1, "\x05", addr1)
testRoute(t, s, nicID1, addr1, "\x05", addr1)
// If we failed to get routes r1 or r2, we cannot proceed with the test.
if t.Failed() {
t.FailNow()
}
// Test routes to even address.
testRoute(t, s, unspecifiedNIC, "", "\x06", addr2)
testRoute(t, s, unspecifiedNIC, addr2, "\x06", addr2)
testRoute(t, s, nicID2, addr2, "\x06", addr2)
buf := buffer.View([]byte{1})
testSend(t, r1, ep1, buf)
testSend(t, r2, ep2, buf)
// Bringing NIC1 down should result in no routes to odd addresses. Routes to
// even addresses should continue to be available as NIC2 is still up.
if err := test.downFn(s, nicID1); err != nil {
t.Fatalf("test.downFn(_, %d): %s", nicID1, err)
}
testNoRoute(t, s, unspecifiedNIC, "", nic1Dst)
testNoRoute(t, s, unspecifiedNIC, addr1, nic1Dst)
testNoRoute(t, s, nicID1, addr1, nic1Dst)
testRoute(t, s, unspecifiedNIC, "", nic2Dst, addr2)
testRoute(t, s, unspecifiedNIC, addr2, nic2Dst, addr2)
testRoute(t, s, nicID2, addr2, nic2Dst, addr2)
// Writes with Routes that use the disabled NIC1 should fail.
if err := s.DisableNIC(nicID1); err != nil {
t.Fatalf("s.DisableNIC(%d): %s", nicID1, err)
}
testFailingSend(t, r1, ep1, buf, tcpip.ErrInvalidEndpointState)
testSend(t, r2, ep2, buf)
// Bringing NIC2 down should result in no routes to even addresses. No
// route should be available to any address as routes to odd addresses
// were made unavailable by bringing NIC1 down above.
if err := test.downFn(s, nicID2); err != nil {
t.Fatalf("test.downFn(_, %d): %s", nicID2, err)
}
testNoRoute(t, s, unspecifiedNIC, "", nic1Dst)
testNoRoute(t, s, unspecifiedNIC, addr1, nic1Dst)
testNoRoute(t, s, nicID1, addr1, nic1Dst)
testNoRoute(t, s, unspecifiedNIC, "", nic2Dst)
testNoRoute(t, s, unspecifiedNIC, addr2, nic2Dst)
testNoRoute(t, s, nicID2, addr2, nic2Dst)
// Writes with Routes that use the disabled NIC2 should fail.
if err := s.DisableNIC(nicID2); err != nil {
t.Fatalf("s.DisableNIC(%d): %s", nicID2, err)
}
testFailingSend(t, r1, ep1, buf, tcpip.ErrInvalidEndpointState)
testFailingSend(t, r2, ep2, buf, tcpip.ErrInvalidEndpointState)
if upFn := test.upFn; upFn != nil {
// Bringing NIC1 up should make routes to odd addresses available
// again. Routes to even addresses should continue to be unavailable
// as NIC2 is still down.
if err := upFn(s, nicID1); err != nil {
t.Fatalf("test.upFn(_, %d): %s", nicID1, err)
}
testRoute(t, s, unspecifiedNIC, "", nic1Dst, addr1)
testRoute(t, s, unspecifiedNIC, addr1, nic1Dst, addr1)
testRoute(t, s, nicID1, addr1, nic1Dst, addr1)
testNoRoute(t, s, unspecifiedNIC, "", nic2Dst)
testNoRoute(t, s, unspecifiedNIC, addr2, nic2Dst)
testNoRoute(t, s, nicID2, addr2, nic2Dst)
}
})
}
})
// Writes with Routes that use the re-enabled NIC1 should succeed.
// TODO(b/147015577): Should we instead completely invalidate all Routes that
// were bound to a disabled NIC at some point?
if err := s.EnableNIC(nicID1); err != nil {
t.Fatalf("s.EnableNIC(%d): %s", nicID1, err)
}
testSend(t, r1, ep1, buf)
testFailingSend(t, r2, ep2, buf, tcpip.ErrInvalidEndpointState)
// Tests that writing a packet using a Route through a down NIC fails.
t.Run("WritePacket", func(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
s, ep1, ep2 := setup(t)
r1, err := s.FindRoute(nicID1, addr1, nic1Dst, fakeNetNumber, false /* multicastLoop */)
if err != nil {
t.Errorf("FindRoute(%d, %s, %s, %d, false): %s", nicID1, addr1, nic1Dst, fakeNetNumber, err)
}
defer r1.Release()
r2, err := s.FindRoute(nicID2, addr2, nic2Dst, fakeNetNumber, false /* multicastLoop */)
if err != nil {
t.Errorf("FindRoute(%d, %s, %s, %d, false): %s", nicID2, addr2, nic2Dst, fakeNetNumber, err)
}
defer r2.Release()
// If we failed to get routes r1 or r2, we cannot proceed with the test.
if t.Failed() {
t.FailNow()
}
buf := buffer.View([]byte{1})
testSend(t, r1, ep1, buf)
testSend(t, r2, ep2, buf)
// Writes with Routes that use NIC1 after being brought down should fail.
if err := test.downFn(s, nicID1); err != nil {
t.Fatalf("test.downFn(_, %d): %s", nicID1, err)
}
testFailingSend(t, r1, ep1, buf, tcpip.ErrInvalidEndpointState)
testSend(t, r2, ep2, buf)
// Writes with Routes that use NIC2 after being brought down should fail.
if err := test.downFn(s, nicID2); err != nil {
t.Fatalf("test.downFn(_, %d): %s", nicID2, err)
}
testFailingSend(t, r1, ep1, buf, tcpip.ErrInvalidEndpointState)
testFailingSend(t, r2, ep2, buf, tcpip.ErrInvalidEndpointState)
if upFn := test.upFn; upFn != nil {
// Writes with Routes that use NIC1 after being brought up should
// succeed.
//
// TODO(b/147015577): Should we instead completely invalidate all
// Routes that were bound to a NIC that was brought down at some
// point?
if err := upFn(s, nicID1); err != nil {
t.Fatalf("test.upFn(_, %d): %s", nicID1, err)
}
testSend(t, r1, ep1, buf)
testFailingSend(t, r2, ep2, buf, tcpip.ErrInvalidEndpointState)
}
})
}
})
}
func TestRoutes(t *testing.T) {
@ -3038,6 +3082,50 @@ func TestAddRemoveIPv4BroadcastAddressOnNICEnableDisable(t *testing.T) {
}
}
// TestLeaveIPv6SolicitedNodeAddrBeforeAddrRemoval tests that removing an IPv6
// address after leaving its solicited node multicast address does not result in
// an error.
func TestLeaveIPv6SolicitedNodeAddrBeforeAddrRemoval(t *testing.T) {
const nicID = 1
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
})
e := channel.New(10, 1280, linkAddr1)
if err := s.CreateNIC(1, e); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
if err := s.AddAddress(nicID, ipv6.ProtocolNumber, addr1); err != nil {
t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, addr1, err)
}
// The NIC should have joined addr1's solicited node multicast address.
snmc := header.SolicitedNodeAddr(addr1)
in, err := s.IsInGroup(nicID, snmc)
if err != nil {
t.Fatalf("IsInGroup(%d, %s): %s", nicID, snmc, err)
}
if !in {
t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, snmc)
}
if err := s.LeaveGroup(ipv6.ProtocolNumber, nicID, snmc); err != nil {
t.Fatalf("LeaveGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, snmc, err)
}
in, err = s.IsInGroup(nicID, snmc)
if err != nil {
t.Fatalf("IsInGroup(%d, %s): %s", nicID, snmc, err)
}
if in {
t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, snmc)
}
if err := s.RemoveAddress(nicID, addr1); err != nil {
t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr1, err)
}
}
func TestJoinLeaveAllNodesMulticastOnNICEnableDisable(t *testing.T) {
const nicID = 1