diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 53cb6694f..13b42d34b 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -71,22 +71,24 @@ type Route struct { // ownership of the provided local address. // // Returns an empty route if validation fails. -func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndpoint AssignableAddressEndpoint, localAddressNIC, outgoingNIC *NIC, gateway, remoteAddr tcpip.Address, handleLocal, multicastLoop bool) Route { - addrWithPrefix := addressEndpoint.AddressWithPrefix() +func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndpoint AssignableAddressEndpoint, localAddressNIC, outgoingNIC *NIC, gateway, localAddr, remoteAddr tcpip.Address, handleLocal, multicastLoop bool) Route { + if len(localAddr) == 0 { + localAddr = addressEndpoint.AddressWithPrefix().Address + } - if localAddressNIC != outgoingNIC && header.IsV6LinkLocalAddress(addrWithPrefix.Address) { + if localAddressNIC != outgoingNIC && header.IsV6LinkLocalAddress(localAddr) { addressEndpoint.DecRef() return Route{} } // If no remote address is provided, use the local address. if len(remoteAddr) == 0 { - remoteAddr = addrWithPrefix.Address + remoteAddr = localAddr } r := makeRoute( netProto, - addrWithPrefix.Address, + localAddr, remoteAddr, outgoingNIC, localAddressNIC, @@ -99,7 +101,7 @@ func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndp // broadcast it. if len(gateway) > 0 { r.NextHop = gateway - } else if subnet := addrWithPrefix.Subnet(); subnet.IsBroadcast(remoteAddr) { + } else if subnet := addressEndpoint.Subnet(); subnet.IsBroadcast(remoteAddr) { r.RemoteLinkAddress = header.EthernetBroadcastAddress } @@ -113,6 +115,10 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip panic(fmt.Sprintf("cannot create a route with NICs from different stacks")) } + if len(localAddr) == 0 { + localAddr = localAddressEndpoint.AddressWithPrefix().Address + } + loop := PacketOut // TODO(gvisor.dev/issue/4689): Loopback interface loops back packets at the diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index f4504e633..e788344d9 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -1240,7 +1240,7 @@ func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, re r := makeLocalRoute( netProto, - localAddressEndpoint.AddressWithPrefix().Address, + localAddr, remoteAddr, outgoingNIC, localAddressNIC, @@ -1317,7 +1317,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil { return makeRoute( netProto, - addressEndpoint.AddressWithPrefix().Address, + localAddr, remoteAddr, nic, /* outboundNIC */ nic, /* localAddressNIC*/ @@ -1354,7 +1354,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n if needRoute { gateway = route.Gateway } - r := constructAndValidateRoute(netProto, addressEndpoint, nic /* outgoingNIC */, nic /* outgoingNIC */, gateway, remoteAddr, s.handleLocal, multicastLoop) + r := constructAndValidateRoute(netProto, addressEndpoint, nic /* outgoingNIC */, nic /* outgoingNIC */, gateway, localAddr, remoteAddr, s.handleLocal, multicastLoop) if r == (Route{}) { panic(fmt.Sprintf("non-forwarding route validation failed with route table entry = %#v, id = %d, localAddr = %s, remoteAddr = %s", route, id, localAddr, remoteAddr)) } @@ -1391,7 +1391,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n if id != 0 { if aNIC, ok := s.nics[id]; ok { if addressEndpoint := s.getAddressEP(aNIC, localAddr, remoteAddr, netProto); addressEndpoint != nil { - if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, remoteAddr, s.handleLocal, multicastLoop); r != (Route{}) { + if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, localAddr, remoteAddr, s.handleLocal, multicastLoop); r != (Route{}) { return r, nil } } @@ -1409,7 +1409,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n continue } - if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, remoteAddr, s.handleLocal, multicastLoop); r != (Route{}) { + if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, localAddr, remoteAddr, s.handleLocal, multicastLoop); r != (Route{}) { return r, nil } } @@ -2130,3 +2130,43 @@ func (s *Stack) networkProtocolNumbers() []tcpip.NetworkProtocolNumber { } return protos } + +func isSubnetBroadcastOnNIC(nic *NIC, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) bool { + addressEndpoint := nic.getAddressOrCreateTempInner(protocol, addr, false /* createTemp */, NeverPrimaryEndpoint) + if addressEndpoint == nil { + return false + } + + subnet := addressEndpoint.Subnet() + addressEndpoint.DecRef() + return subnet.IsBroadcast(addr) +} + +// IsSubnetBroadcast returns true if the provided address is a subnet-local +// broadcast address on the specified NIC and protocol. +// +// Returns false if the NIC is unknown or if the protocol is unknown or does +// not support addressing. +// +// If the NIC is not specified, the stack will check all NICs. +func (s *Stack) IsSubnetBroadcast(nicID tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) bool { + s.mu.RLock() + defer s.mu.RUnlock() + + if nicID != 0 { + nic, ok := s.nics[nicID] + if !ok { + return false + } + + return isSubnetBroadcastOnNIC(nic, protocol, addr) + } + + for _, nic := range s.nics { + if isSubnetBroadcastOnNIC(nic, protocol, addr) { + return true + } + } + + return false +} diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 81601f559..648587137 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -369,7 +369,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi // specified address is a multicast address. func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (stack.Route, tcpip.NICID, *tcpip.Error) { localAddr := e.ID.LocalAddress - if isBroadcastOrMulticast(localAddr) { + if e.isBroadcastOrMulticast(nicID, netProto, localAddr) { // A packet can only originate from a unicast address (i.e., an interface). localAddr = "" } @@ -1290,7 +1290,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { } nicID := addr.NIC - if len(addr.Addr) != 0 && !isBroadcastOrMulticast(addr.Addr) { + if len(addr.Addr) != 0 && !e.isBroadcastOrMulticast(addr.NIC, netProto, addr.Addr) { // A local unicast address was specified, verify that it's valid. nicID = e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) if nicID == 0 { @@ -1531,8 +1531,8 @@ func (e *endpoint) Stats() tcpip.EndpointStats { // Wait implements tcpip.Endpoint.Wait. func (*endpoint) Wait() {} -func isBroadcastOrMulticast(a tcpip.Address) bool { - return a == header.IPv4Broadcast || header.IsV4MulticastAddress(a) || header.IsV6MulticastAddress(a) +func (e *endpoint) isBroadcastOrMulticast(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, addr tcpip.Address) bool { + return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr) || e.stack.IsSubnetBroadcast(nicID, netProto, addr) } // SetOwner implements tcpip.Endpoint.SetOwner. diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go index 858c99a45..99f3fc37f 100644 --- a/pkg/tcpip/transport/udp/endpoint_state.go +++ b/pkg/tcpip/transport/udp/endpoint_state.go @@ -118,7 +118,7 @@ func (e *endpoint) Resume(s *stack.Stack) { if err != nil { panic(err) } - } else if len(e.ID.LocalAddress) != 0 && !isBroadcastOrMulticast(e.ID.LocalAddress) { // stateBound + } else if len(e.ID.LocalAddress) != 0 && !e.isBroadcastOrMulticast(e.RegisterNICID, netProto, e.ID.LocalAddress) { // stateBound // A local unicast address is specified, verify that it's valid. if e.stack.CheckLocalAddress(e.RegisterNICID, netProto, e.ID.LocalAddress) == 0 { panic(tcpip.ErrBadLocalAddress)