Do not use tentative addresses for routes

Tentative addresses should not be used when finding a route. This change
fixes a bug where a tentative address may have been used.

Test: stack_test.TestDADResolve
PiperOrigin-RevId: 315997624
This commit is contained in:
Ghanan Gowripalan 2020-06-11 16:08:06 -07:00 committed by gVisor bot
parent 4f111b6384
commit 4c0a8bdaf5
2 changed files with 91 additions and 36 deletions

View File

@ -421,28 +421,52 @@ func TestDADResolve(t *testing.T) {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
} }
// We add a default route so the call to FindRoute below will succeed
// once we have an assigned address.
s.SetRouteTable([]tcpip.Route{{
Destination: header.IPv6EmptySubnet,
Gateway: addr3,
NIC: nicID,
}})
if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil { if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil {
t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err) t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err)
} }
// Address should not be considered bound to the NIC yet (DAD ongoing). // Address should not be considered bound to the NIC yet (DAD ongoing).
addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) if addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
if err != nil { t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %s), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) } else if want := (tcpip.AddressWithPrefix{}); addr != want {
}
if want := (tcpip.AddressWithPrefix{}); addr != want {
t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
} }
// Make sure the address does not resolve before the resolution time has // Make sure the address does not resolve before the resolution time has
// passed. // passed.
time.Sleep(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - defaultAsyncEventTimeout) time.Sleep(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - defaultAsyncEventTimeout)
addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) if addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
if err != nil { t.Errorf("got stack.GetMainNICAddress(%d, %d) = (_, %s), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) } else if want := (tcpip.AddressWithPrefix{}); addr != want {
t.Errorf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
} }
if want := (tcpip.AddressWithPrefix{}); addr != want { // Should not get a route even if we specify the local address as the
t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) // tentative address.
{
r, err := s.FindRoute(nicID, "", addr2, header.IPv6ProtocolNumber, false)
if err != tcpip.ErrNoRoute {
t.Errorf("got FindRoute(%d, '', %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr2, header.IPv6ProtocolNumber, r, err, tcpip.ErrNoRoute)
}
r.Release()
}
{
r, err := s.FindRoute(nicID, addr1, addr2, header.IPv6ProtocolNumber, false)
if err != tcpip.ErrNoRoute {
t.Errorf("got FindRoute(%d, %s, %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr1, addr2, header.IPv6ProtocolNumber, r, err, tcpip.ErrNoRoute)
}
r.Release()
}
if t.Failed() {
t.FailNow()
} }
// Wait for DAD to resolve. // Wait for DAD to resolve.
@ -454,12 +478,33 @@ func TestDADResolve(t *testing.T) {
t.Errorf("dad event mismatch (-want +got):\n%s", diff) t.Errorf("dad event mismatch (-want +got):\n%s", diff)
} }
} }
addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) if addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
if err != nil { t.Errorf("got stack.GetMainNICAddress(%d, %d) = (_, %s), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) } else if addr.Address != addr1 {
t.Errorf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, addr, addr1)
} }
if addr.Address != addr1 { // Should get a route using the address now that it is resolved.
t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, addr, addr1) {
r, err := s.FindRoute(nicID, "", addr2, header.IPv6ProtocolNumber, false)
if err != nil {
t.Errorf("got FindRoute(%d, '', %s, %d, false): %s", nicID, addr2, header.IPv6ProtocolNumber, err)
} else if r.LocalAddress != addr1 {
t.Errorf("got r.LocalAddress = %s, want = %s", r.LocalAddress, addr1)
}
r.Release()
}
{
r, err := s.FindRoute(nicID, addr1, addr2, header.IPv6ProtocolNumber, false)
if err != nil {
t.Errorf("got FindRoute(%d, %s, %s, %d, false): %s", nicID, addr1, addr2, header.IPv6ProtocolNumber, err)
} else if r.LocalAddress != addr1 {
t.Errorf("got r.LocalAddress = %s, want = %s", r.LocalAddress, addr1)
}
r.Release()
}
if t.Failed() {
t.FailNow()
} }
// Should not have sent any more NS messages. // Should not have sent any more NS messages.

View File

@ -610,20 +610,16 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t
if ref, ok := n.mu.endpoints[NetworkEndpointID{address}]; ok { if ref, ok := n.mu.endpoints[NetworkEndpointID{address}]; ok {
// An endpoint with this id exists, check if it can be used and return it. // An endpoint with this id exists, check if it can be used and return it.
switch ref.getKind() { if !ref.isAssignedRLocked(spoofingOrPromiscuous) {
case permanentExpired:
if !spoofingOrPromiscuous {
n.mu.RUnlock() n.mu.RUnlock()
return nil return nil
} }
fallthrough
case temporary, permanent:
if ref.tryIncRef() { if ref.tryIncRef() {
n.mu.RUnlock() n.mu.RUnlock()
return ref return ref
} }
} }
}
// A usable reference was not found, create a temporary one if requested by // A usable reference was not found, create a temporary one if requested by
// the caller or if the address is found in the NIC's subnets. // the caller or if the address is found in the NIC's subnets.
@ -689,7 +685,6 @@ func (n *NIC) getRefOrCreateTempLocked(protocol tcpip.NetworkProtocolNumber, add
PrefixLen: netProto.DefaultPrefixLen(), PrefixLen: netProto.DefaultPrefixLen(),
}, },
}, peb, temporary, static, false) }, peb, temporary, static, false)
return ref return ref
} }
@ -1660,8 +1655,8 @@ func (r *referencedNetworkEndpoint) setKind(kind networkEndpointKind) {
} }
// isValidForOutgoing returns true if the endpoint can be used to send out a // isValidForOutgoing returns true if the endpoint can be used to send out a
// packet. It requires the endpoint to not be marked expired (i.e., its address // packet. It requires the endpoint to not be marked expired (i.e., its address)
// has been removed), or the NIC to be in spoofing mode. // has been removed) unless the NIC is in spoofing mode, or temporary.
func (r *referencedNetworkEndpoint) isValidForOutgoing() bool { func (r *referencedNetworkEndpoint) isValidForOutgoing() bool {
r.nic.mu.RLock() r.nic.mu.RLock()
defer r.nic.mu.RUnlock() defer r.nic.mu.RUnlock()
@ -1669,13 +1664,28 @@ func (r *referencedNetworkEndpoint) isValidForOutgoing() bool {
return r.isValidForOutgoingRLocked() return r.isValidForOutgoingRLocked()
} }
// isValidForOutgoingRLocked returns true if the endpoint can be used to send // isValidForOutgoingRLocked is the same as isValidForOutgoing but requires
// out a packet. It requires the endpoint to not be marked expired (i.e., its // r.nic.mu to be read locked.
// address has been removed), or the NIC to be in spoofing mode.
//
// r's NIC must be read locked.
func (r *referencedNetworkEndpoint) isValidForOutgoingRLocked() bool { func (r *referencedNetworkEndpoint) isValidForOutgoingRLocked() bool {
return r.nic.mu.enabled && (r.getKind() != permanentExpired || r.nic.mu.spoofing) if !r.nic.mu.enabled {
return false
}
return r.isAssignedRLocked(r.nic.mu.spoofing)
}
// isAssignedRLocked returns true if r is considered to be assigned to the NIC.
//
// r.nic.mu must be read locked.
func (r *referencedNetworkEndpoint) isAssignedRLocked(spoofingOrPromiscuous bool) bool {
switch r.getKind() {
case permanentTentative:
return false
case permanentExpired:
return spoofingOrPromiscuous
default:
return true
}
} }
// expireLocked decrements the reference count and marks the permanent endpoint // expireLocked decrements the reference count and marks the permanent endpoint