Don't hold AddressEndpoints for multicast addresses

Group addressable endpoints can simply check if it has joined the
multicast group without maintaining address endpoints. This also
helps remove the dependency on AddressableEndpoint from
GroupAddressableEndpoint.

Now that group addresses are not tracked with address endpoints, we can
avoid accidentally obtaining a route with a multicast local address.

PiperOrigin-RevId: 343336912
This commit is contained in:
Ghanan Gowripalan 2020-11-19 11:46:09 -08:00 committed by gVisor bot
parent 332671c339
commit 27ee4fe76a
6 changed files with 48 additions and 51 deletions

View File

@ -566,21 +566,6 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
stats.IP.MalformedPacketsReceived.Increment() stats.IP.MalformedPacketsReceived.Increment()
return return
} }
srcAddr := h.SourceAddress()
dstAddr := h.DestinationAddress()
addressEndpoint := e.AcquireAssignedAddress(dstAddr, e.nic.Promiscuous(), stack.CanBePrimaryEndpoint)
if addressEndpoint == nil {
if !e.protocol.Forwarding() {
stats.IP.InvalidDestinationAddressesReceived.Increment()
return
}
_ = e.forwardPacket(pkt)
return
}
subnet := addressEndpoint.AddressWithPrefix().Subnet()
addressEndpoint.DecRef()
// There has been some confusion regarding verifying checksums. We need // There has been some confusion regarding verifying checksums. We need
// just look for negative 0 (0xffff) as the checksum, as it's not possible to // just look for negative 0 (0xffff) as the checksum, as it's not possible to
@ -608,16 +593,42 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
return return
} }
srcAddr := h.SourceAddress()
dstAddr := h.DestinationAddress()
// As per RFC 1122 section 3.2.1.3: // As per RFC 1122 section 3.2.1.3:
// When a host sends any datagram, the IP source address MUST // When a host sends any datagram, the IP source address MUST
// be one of its own IP addresses (but not a broadcast or // be one of its own IP addresses (but not a broadcast or
// multicast address). // multicast address).
if directedBroadcast := subnet.IsBroadcast(srcAddr); directedBroadcast || srcAddr == header.IPv4Broadcast || header.IsV4MulticastAddress(srcAddr) { if srcAddr == header.IPv4Broadcast || header.IsV4MulticastAddress(srcAddr) {
stats.IP.InvalidSourceAddressesReceived.Increment() stats.IP.InvalidSourceAddressesReceived.Increment()
return return
} }
// Make sure the source address is not a subnet-local broadcast address.
if addressEndpoint := e.AcquireAssignedAddress(srcAddr, false /* createTemp */, stack.NeverPrimaryEndpoint); addressEndpoint != nil {
subnet := addressEndpoint.Subnet()
addressEndpoint.DecRef()
if subnet.IsBroadcast(srcAddr) {
stats.IP.InvalidSourceAddressesReceived.Increment()
return
}
}
pkt.NetworkPacketInfo.LocalAddressBroadcast = subnet.IsBroadcast(dstAddr) || dstAddr == header.IPv4Broadcast // The destination address should be an address we own or a group we joined
// for us to receive the packet. Otherwise, attempt to forward the packet.
if addressEndpoint := e.AcquireAssignedAddress(dstAddr, e.nic.Promiscuous(), stack.CanBePrimaryEndpoint); addressEndpoint != nil {
subnet := addressEndpoint.AddressWithPrefix().Subnet()
addressEndpoint.DecRef()
pkt.NetworkPacketInfo.LocalAddressBroadcast = subnet.IsBroadcast(dstAddr) || dstAddr == header.IPv4Broadcast
} else if !e.IsInGroup(dstAddr) {
if !e.protocol.Forwarding() {
stats.IP.InvalidDestinationAddressesReceived.Increment()
return
}
_ = e.forwardPacket(pkt)
return
}
// iptables filtering. All packets that reach here are intended for // iptables filtering. All packets that reach here are intended for
// this machine and will not be forwarded. // this machine and will not be forwarded.

View File

@ -796,7 +796,8 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi
allowResponseToMulticast = reason.respondToMulticast allowResponseToMulticast = reason.respondToMulticast
} }
if (!allowResponseToMulticast && header.IsV6MulticastAddress(origIPHdrDst)) || origIPHdrSrc == header.IPv6Any { isOrigDstMulticast := header.IsV6MulticastAddress(origIPHdrDst)
if (!allowResponseToMulticast && isOrigDstMulticast) || origIPHdrSrc == header.IPv6Any {
return nil return nil
} }
@ -812,8 +813,13 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi
// If we are operating as a router, do not use the packet's destination // If we are operating as a router, do not use the packet's destination
// address as the response's source address as we should not own the // address as the response's source address as we should not own the
// destination address of a packet we are forwarding. // destination address of a packet we are forwarding.
//
// If the packet was originally destined to a multicast address, then do not
// use the packet's destination address as the source for the response ICMP
// packet as "multicast addresses must not be used as source addresses in IPv6
// packets", as per RFC 4291 section 2.7.
localAddr := origIPHdrDst localAddr := origIPHdrDst
if _, ok := reason.(*icmpReasonHopLimitExceeded); ok { if _, ok := reason.(*icmpReasonHopLimitExceeded); ok || isOrigDstMulticast {
localAddr = "" localAddr = ""
} }
// Even if we were able to receive a packet from some remote, we may not have // Even if we were able to receive a packet from some remote, we may not have

View File

@ -737,8 +737,11 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
return return
} }
addressEndpoint := e.AcquireAssignedAddress(dstAddr, e.nic.Promiscuous(), stack.CanBePrimaryEndpoint) // The destination address should be an address we own or a group we joined
if addressEndpoint == nil { // for us to receive the packet. Otherwise, attempt to forward the packet.
if addressEndpoint := e.AcquireAssignedAddress(dstAddr, e.nic.Promiscuous(), stack.CanBePrimaryEndpoint); addressEndpoint != nil {
addressEndpoint.DecRef()
} else if !e.IsInGroup(dstAddr) {
if !e.protocol.Forwarding() { if !e.protocol.Forwarding() {
stats.IP.InvalidDestinationAddressesReceived.Increment() stats.IP.InvalidDestinationAddressesReceived.Increment()
return return
@ -747,7 +750,6 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
_ = e.forwardPacket(pkt) _ = e.forwardPacket(pkt)
return return
} }
addressEndpoint.DecRef()
// vv consists of: // vv consists of:
// - Any IPv6 header bytes after the first 40 (i.e. extensions). // - Any IPv6 header bytes after the first 40 (i.e. extensions).

View File

@ -873,7 +873,13 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
Length: uint16(udpLength), Length: uint16(udpLength),
}) })
copy(u.Payload(), udpPayload) copy(u.Payload(), udpPayload)
sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, addr1, addr2, uint16(udpLength))
dstAddr := tcpip.Address(addr2)
if test.multicast {
dstAddr = header.IPv6AllNodesMulticastAddress
}
sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, addr1, dstAddr, uint16(udpLength))
sum = header.Checksum(udpPayload, sum) sum = header.Checksum(udpPayload, sum)
u.SetChecksum(^u.CalculateChecksum(sum)) u.SetChecksum(^u.CalculateChecksum(sum))
@ -884,10 +890,6 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
// Serialize IPv6 fixed header. // Serialize IPv6 fixed header.
payloadLength := hdr.UsedLength() payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
dstAddr := tcpip.Address(addr2)
if test.multicast {
dstAddr = header.IPv6AllNodesMulticastAddress
}
ip.Encode(&header.IPv6Fields{ ip.Encode(&header.IPv6Fields{
PayloadLength: uint16(payloadLength), PayloadLength: uint16(payloadLength),
NextHeader: ipv6NextHdr, NextHeader: ipv6NextHdr,

View File

@ -594,15 +594,6 @@ func (a *AddressableEndpointState) JoinGroup(group tcpip.Address) (bool, *tcpip.
defer a.mu.Unlock() defer a.mu.Unlock()
joins, ok := a.mu.groups[group] joins, ok := a.mu.groups[group]
if !ok {
ep, err := a.addAndAcquireAddressLocked(group.WithPrefix(), NeverPrimaryEndpoint, AddressConfigStatic, false /* deprecated */, true /* permanent */)
if err != nil {
return false, err
}
// We have no need for the address endpoint.
a.decAddressRefLocked(ep)
}
a.mu.groups[group] = joins + 1 a.mu.groups[group] = joins + 1
return !ok, nil return !ok, nil
} }
@ -618,7 +609,6 @@ func (a *AddressableEndpointState) LeaveGroup(group tcpip.Address) (bool, *tcpip
} }
if joins == 1 { if joins == 1 {
a.removeGroupAddressLocked(group)
delete(a.mu.groups, group) delete(a.mu.groups, group)
return true, nil return true, nil
} }
@ -635,23 +625,11 @@ func (a *AddressableEndpointState) IsInGroup(group tcpip.Address) bool {
return ok return ok
} }
func (a *AddressableEndpointState) removeGroupAddressLocked(group tcpip.Address) {
if err := a.removePermanentAddressLocked(group); err != nil {
// removePermanentEndpointLocked would only return an error if group is
// not bound to the addressable endpoint, but we know it MUST be assigned
// since we have group in our map of groups.
panic(fmt.Sprintf("error removing group address = %s: %s", group, err))
}
}
// Cleanup forcefully leaves all groups and removes all permanent addresses. // Cleanup forcefully leaves all groups and removes all permanent addresses.
func (a *AddressableEndpointState) Cleanup() { func (a *AddressableEndpointState) Cleanup() {
a.mu.Lock() a.mu.Lock()
defer a.mu.Unlock() defer a.mu.Unlock()
for group := range a.mu.groups {
a.removeGroupAddressLocked(group)
}
a.mu.groups = make(map[tcpip.Address]uint32) a.mu.groups = make(map[tcpip.Address]uint32)
for _, ep := range a.mu.endpoints { for _, ep := range a.mu.endpoints {

View File

@ -248,8 +248,6 @@ ALL_TESTS = [
), ),
PacketimpactTestInfo( PacketimpactTestInfo(
name = "ipv6_unknown_options_action", name = "ipv6_unknown_options_action",
# TODO(b/159928940): Fix netstack then remove the line below.
expect_netstack_failure = True,
), ),
PacketimpactTestInfo( PacketimpactTestInfo(
name = "ipv4_fragment_reassembly", name = "ipv4_fragment_reassembly",