Add subnet checking to NIC.findEndpoint and consolidate with NIC.getRef

This adds the same logic to NIC.findEndpoint that is already done in
NIC.getRef. Since this makes the two functions very similar they were combined
into one with the originals being wrappers.

PiperOrigin-RevId: 263864708
This commit is contained in:
Chris Kuiper 2019-08-16 15:57:46 -07:00 committed by gVisor bot
parent d60d99cb61
commit f7114e0a27
3 changed files with 99 additions and 71 deletions

View File

@ -186,41 +186,73 @@ func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedN
return nil
}
func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *referencedNetworkEndpoint {
return n.getRefOrCreateTemp(protocol, dst, CanBePrimaryEndpoint, n.promiscuous)
}
// findEndpoint finds the endpoint, if any, with the given address.
func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) *referencedNetworkEndpoint {
return n.getRefOrCreateTemp(protocol, address, peb, n.spoofing)
}
// getRefEpOrCreateTemp returns the referenced network endpoint for the given
// protocol and address. If none exists a temporary one may be created if
// requested.
func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, allowTemp bool) *referencedNetworkEndpoint {
id := NetworkEndpointID{address}
n.mu.RLock()
ref := n.endpoints[id]
if ref != nil && !ref.tryIncRef() {
ref = nil
if ref, ok := n.endpoints[id]; ok && ref.tryIncRef() {
n.mu.RUnlock()
return ref
}
spoofing := n.spoofing
// The address was not found, create a temporary one if requested by the
// caller or if the address is found in the NIC's subnets.
createTempEP := allowTemp
if !createTempEP {
for _, sn := range n.subnets {
if sn.Contains(address) {
createTempEP = true
break
}
}
}
n.mu.RUnlock()
if ref != nil || !spoofing {
return ref
if !createTempEP {
return nil
}
// Try again with the lock in exclusive mode. If we still can't get the
// endpoint, create a new "temporary" endpoint. It will only exist while
// there's a route through it.
n.mu.Lock()
ref = n.endpoints[id]
if ref == nil || !ref.tryIncRef() {
if netProto, ok := n.stack.networkProtocols[protocol]; ok {
ref, _ = n.addAddressLocked(tcpip.ProtocolAddress{
Protocol: protocol,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: address,
PrefixLen: netProto.DefaultPrefixLen(),
},
}, peb, true)
if ref != nil {
ref.holdsInsertRef = false
}
}
if ref, ok := n.endpoints[id]; ok && ref.tryIncRef() {
n.mu.Unlock()
return ref
}
netProto, ok := n.stack.networkProtocols[protocol]
if !ok {
n.mu.Unlock()
return nil
}
ref, _ := n.addAddressLocked(tcpip.ProtocolAddress{
Protocol: protocol,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: address,
PrefixLen: netProto.DefaultPrefixLen(),
},
}, peb, true)
if ref != nil {
ref.holdsInsertRef = false
}
n.mu.Unlock()
return ref
}
@ -553,57 +585,6 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr
n.stack.stats.IP.InvalidAddressesReceived.Increment()
}
func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *referencedNetworkEndpoint {
id := NetworkEndpointID{dst}
n.mu.RLock()
if ref, ok := n.endpoints[id]; ok && ref.tryIncRef() {
n.mu.RUnlock()
return ref
}
promiscuous := n.promiscuous
// Check if the packet is for a subnet this NIC cares about.
if !promiscuous {
for _, sn := range n.subnets {
if sn.Contains(dst) {
promiscuous = true
break
}
}
}
n.mu.RUnlock()
if promiscuous {
// Try again with the lock in exclusive mode. If we still can't
// get the endpoint, create a new "temporary" one. It will only
// exist while there's a route through it.
n.mu.Lock()
if ref, ok := n.endpoints[id]; ok && ref.tryIncRef() {
n.mu.Unlock()
return ref
}
netProto, ok := n.stack.networkProtocols[protocol]
if !ok {
n.mu.Unlock()
return nil
}
ref, err := n.addAddressLocked(tcpip.ProtocolAddress{
Protocol: protocol,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: dst,
PrefixLen: netProto.DefaultPrefixLen(),
},
}, CanBePrimaryEndpoint, true)
n.mu.Unlock()
if err == nil {
ref.holdsInsertRef = false
return ref
}
}
return nil
}
// DeliverTransportPacket delivers the packets to the appropriate transport
// protocol endpoint.
func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) {

View File

@ -830,6 +830,48 @@ func TestSubnetAcceptsMatchingPacket(t *testing.T) {
}
}
// Set the subnet, then check that CheckLocalAddress returns the correct NIC.
func TestCheckLocalAddressForSubnet(t *testing.T) {
const nicID tcpip.NICID = 1
s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
id, _ := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(nicID, id); err != nil {
t.Fatal("CreateNIC failed:", err)
}
s.SetRouteTable([]tcpip.Route{
{"\x00", "\x00", "\x00", nicID}, // default route
})
subnet, err := tcpip.NewSubnet(tcpip.Address("\xa0"), tcpip.AddressMask("\xf0"))
if err != nil {
t.Fatal("NewSubnet failed:", err)
}
if err := s.AddSubnet(nicID, fakeNetNumber, subnet); err != nil {
t.Fatal("AddSubnet failed:", err)
}
// Loop over all subnet addresses and check them.
numOfAddresses := (1 << uint((8 - subnet.Prefix())))
if numOfAddresses < 1 || numOfAddresses > 255 {
t.Fatalf("got numOfAddresses = %d, want = [1 .. 255] (subnet=%s)", numOfAddresses, subnet)
}
addr := []byte(subnet.ID())
for i := 0; i < numOfAddresses; i++ {
if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, tcpip.Address(addr)); gotNicID != nicID {
t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, tcpip.Address(addr), gotNicID, nicID)
}
addr[0]++
}
// Trying the next address should fail since it is outside the subnet range.
if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, tcpip.Address(addr)); gotNicID != 0 {
t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, tcpip.Address(addr), gotNicID, 0)
}
}
// Set destination outside the subnet, then check it doesn't get delivered.
func TestSubnetRejectsNonmatchingPacket(t *testing.T) {
s := stack.New([]string{"fakeNet"}, nil, stack.Options{})

View File

@ -168,6 +168,11 @@ func NewSubnet(a Address, m AddressMask) (Subnet, error) {
return Subnet{a, m}, nil
}
// String implements Stringer.
func (s Subnet) String() string {
return fmt.Sprintf("%s/%d", s.ID(), s.Prefix())
}
// Contains returns true iff the address is of the same length and matches the
// subnet address and mask.
func (s *Subnet) Contains(a Address) bool {