Remove AssignableAddressEndpoint.NetworkEndpoint

We can get the network endpoint directly from the NIC.

This is a preparatory CL for when a Route needs to hold a dedicated NIC
as its output interface. This is because when forwarding is enabled,
packets may be sent from a NIC different from the NIC a route's local
address is associated with.

PiperOrigin-RevId: 335484500
This commit is contained in:
Ghanan Gowripalan 2020-10-05 13:15:06 -07:00 committed by gVisor bot
parent 5aa75653ab
commit 91e2d15a62
6 changed files with 22 additions and 33 deletions

View File

@ -679,11 +679,6 @@ type addressState struct {
}
}
// NetworkEndpoint implements AddressEndpoint.
func (a *addressState) NetworkEndpoint() NetworkEndpoint {
return a.addressableEndpointState.networkEndpoint
}
// AddressWithPrefix implements AddressEndpoint.
func (a *addressState) AddressWithPrefix() tcpip.AddressWithPrefix {
return a.addr

View File

@ -38,8 +38,11 @@ type NIC struct {
linkEP LinkEndpoint
context NICContext
stats NICStats
neigh *neighborCache
stats NICStats
neigh *neighborCache
// The network endpoints themselves may be modified by calling the interface's
// methods, but the map reference and entries must be constant.
networkEndpoints map[tcpip.NetworkProtocolNumber]NetworkEndpoint
// enabled is set to 1 when the NIC is enabled and 0 when it is disabled.
@ -132,6 +135,10 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
return nic
}
func (n *NIC) getNetworkEndpoint(proto tcpip.NetworkProtocolNumber) NetworkEndpoint {
return n.networkEndpoints[proto]
}
// Enabled implements NetworkInterface.
func (n *NIC) Enabled() bool {
return atomic.LoadUint32(&n.enabled) == 1
@ -211,7 +218,6 @@ func (n *NIC) remove() *tcpip.Error {
for _, ep := range n.networkEndpoints {
ep.Close()
}
n.networkEndpoints = nil
// Detach from link endpoint, so no packet comes in.
n.linkEP.Attach(nil)
@ -483,9 +489,9 @@ func (n *NIC) isInGroup(addr tcpip.Address) bool {
func (n *NIC) handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, remotelinkAddr tcpip.LinkAddress, addressEndpoint AssignableAddressEndpoint, pkt *PacketBuffer) {
r := makeRoute(protocol, dst, src, n, addressEndpoint, false /* handleLocal */, false /* multicastLoop */)
defer r.Release()
r.RemoteLinkAddress = remotelinkAddr
addressEndpoint.NetworkEndpoint().HandlePacket(&r, pkt)
addressEndpoint.DecRef()
n.getNetworkEndpoint(protocol).HandlePacket(&r, pkt)
}
// DeliverNetworkPacket finds the appropriate network protocol endpoint and
@ -603,7 +609,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
r.RemoteLinkAddress = remote
r.RemoteAddress = src
// TODO(b/123449044): Update the source NIC as well.
addressEndpoint.NetworkEndpoint().HandlePacket(&r, pkt)
n.getNetworkEndpoint(protocol).HandlePacket(&r, pkt)
addressEndpoint.DecRef()
r.Release()
return

View File

@ -326,10 +326,6 @@ const (
// AssignableAddressEndpoint is a reference counted address endpoint that may be
// assigned to a NetworkEndpoint.
type AssignableAddressEndpoint interface {
// NetworkEndpoint returns the NetworkEndpoint the receiver is associated
// with.
NetworkEndpoint() NetworkEndpoint
// AddressWithPrefix returns the endpoint's address.
AddressWithPrefix() tcpip.AddressWithPrefix

View File

@ -100,7 +100,7 @@ func (r *Route) NICID() tcpip.NICID {
// MaxHeaderLength forwards the call to the network endpoint's implementation.
func (r *Route) MaxHeaderLength() uint16 {
return r.addressEndpoint.NetworkEndpoint().MaxHeaderLength()
return r.nic.getNetworkEndpoint(r.NetProto).MaxHeaderLength()
}
// Stats returns a mutable copy of current stats.
@ -121,7 +121,7 @@ func (r *Route) Capabilities() LinkEndpointCapabilities {
// GSOMaxSize returns the maximum GSO packet size.
func (r *Route) GSOMaxSize() uint32 {
if gso, ok := r.addressEndpoint.NetworkEndpoint().(GSOEndpoint); ok {
if gso, ok := r.nic.getNetworkEndpoint(r.NetProto).(GSOEndpoint); ok {
return gso.GSOMaxSize()
}
return 0
@ -211,7 +211,7 @@ func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt *PacketBuf
// WritePacket takes ownership of pkt, calculate numBytes first.
numBytes := pkt.Size()
if err := r.addressEndpoint.NetworkEndpoint().WritePacket(r, gso, params, pkt); err != nil {
if err := r.nic.getNetworkEndpoint(r.NetProto).WritePacket(r, gso, params, pkt); err != nil {
return err
}
@ -227,7 +227,7 @@ func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHead
return 0, tcpip.ErrInvalidEndpointState
}
n, err := r.addressEndpoint.NetworkEndpoint().WritePackets(r, gso, pkts, params)
n, err := r.nic.getNetworkEndpoint(r.NetProto).WritePackets(r, gso, pkts, params)
r.nic.stats.Tx.Packets.IncrementBy(uint64(n))
writtenBytes := 0
for i, pb := 0, pkts.Front(); i < n && pb != nil; i, pb = i+1, pb.Next() {
@ -248,7 +248,7 @@ func (r *Route) WriteHeaderIncludedPacket(pkt *PacketBuffer) *tcpip.Error {
// WriteHeaderIncludedPacket takes ownership of pkt, calculate numBytes first.
numBytes := pkt.Data.Size()
if err := r.addressEndpoint.NetworkEndpoint().WriteHeaderIncludedPacket(r, pkt); err != nil {
if err := r.nic.getNetworkEndpoint(r.NetProto).WriteHeaderIncludedPacket(r, pkt); err != nil {
return err
}
r.nic.stats.Tx.Packets.Increment()
@ -258,18 +258,12 @@ func (r *Route) WriteHeaderIncludedPacket(pkt *PacketBuffer) *tcpip.Error {
// DefaultTTL returns the default TTL of the underlying network endpoint.
func (r *Route) DefaultTTL() uint8 {
return r.addressEndpoint.NetworkEndpoint().DefaultTTL()
return r.nic.getNetworkEndpoint(r.NetProto).DefaultTTL()
}
// MTU returns the MTU of the underlying network endpoint.
func (r *Route) MTU() uint32 {
return r.addressEndpoint.NetworkEndpoint().MTU()
}
// NetworkProtocolNumber returns the NetworkProtocolNumber of the underlying
// network endpoint.
func (r *Route) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
return r.addressEndpoint.NetworkEndpoint().NetworkProtocolNumber()
return r.nic.getNetworkEndpoint(r.NetProto).MTU()
}
// Release frees all resources associated with the route.

View File

@ -1796,7 +1796,7 @@ func (s *Stack) GetNetworkEndpoint(nicID tcpip.NICID, proto tcpip.NetworkProtoco
return nil, tcpip.ErrUnknownNICID
}
return nic.networkEndpoints[proto], nil
return nic.getNetworkEndpoint(proto), nil
}
// NUDConfigurations gets the per-interface NUD configurations.
@ -1873,10 +1873,8 @@ func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, addres
if addressEndpoint == nil {
continue
}
ep := addressEndpoint.NetworkEndpoint()
addressEndpoint.DecRef()
return ep, nil
return nic.getNetworkEndpoint(netProto), nil
}
return nil, tcpip.ErrBadAddress
}

View File

@ -804,7 +804,7 @@ func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso
pkt.Owner = owner
pkt.EgressRoute = r
pkt.GSOOptions = gso
pkt.NetworkProtocolNumber = r.NetworkProtocolNumber()
pkt.NetworkProtocolNumber = r.NetProto
data.ReadToVV(&pkt.Data, packetSize)
buildTCPHdr(r, tf, pkt, gso)
tf.seq = tf.seq.Add(seqnum.Size(packetSize))