Clean up transport_demuxer.go and test
- Change receiver of endpoint lookup functions - Remove unused struct fields and functions in test - s/%v/%s/ for errors - Capitalize NIC https://github.com/golang/go/wiki/CodeReviewComments#initialisms PiperOrigin-RevId: 303119580
This commit is contained in:
parent
7aa388ce74
commit
c64796748c
|
@ -35,7 +35,7 @@ type protocolIDs struct {
|
|||
type transportEndpoints struct {
|
||||
// mu protects all fields of the transportEndpoints.
|
||||
mu sync.RWMutex
|
||||
endpoints map[TransportEndpointID]*endpointsByNic
|
||||
endpoints map[TransportEndpointID]*endpointsByNIC
|
||||
// rawEndpoints contains endpoints for raw sockets, which receive all
|
||||
// traffic of a given protocol regardless of port.
|
||||
rawEndpoints []RawTransportEndpoint
|
||||
|
@ -46,11 +46,11 @@ type transportEndpoints struct {
|
|||
func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) {
|
||||
eps.mu.Lock()
|
||||
defer eps.mu.Unlock()
|
||||
epsByNic, ok := eps.endpoints[id]
|
||||
epsByNIC, ok := eps.endpoints[id]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if !epsByNic.unregisterEndpoint(bindToDevice, ep) {
|
||||
if !epsByNIC.unregisterEndpoint(bindToDevice, ep) {
|
||||
return
|
||||
}
|
||||
delete(eps.endpoints, id)
|
||||
|
@ -66,18 +66,85 @@ func (eps *transportEndpoints) transportEndpoints() []TransportEndpoint {
|
|||
return es
|
||||
}
|
||||
|
||||
type endpointsByNic struct {
|
||||
// iterEndpointsLocked yields all endpointsByNIC in eps that match id, in
|
||||
// descending order of match quality. If a call to yield returns false,
|
||||
// iterEndpointsLocked stops iteration and returns immediately.
|
||||
//
|
||||
// Preconditions: eps.mu must be locked.
|
||||
func (eps *transportEndpoints) iterEndpointsLocked(id TransportEndpointID, yield func(*endpointsByNIC) bool) {
|
||||
// Try to find a match with the id as provided.
|
||||
if ep, ok := eps.endpoints[id]; ok {
|
||||
if !yield(ep) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Try to find a match with the id minus the local address.
|
||||
nid := id
|
||||
|
||||
nid.LocalAddress = ""
|
||||
if ep, ok := eps.endpoints[nid]; ok {
|
||||
if !yield(ep) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Try to find a match with the id minus the remote part.
|
||||
nid.LocalAddress = id.LocalAddress
|
||||
nid.RemoteAddress = ""
|
||||
nid.RemotePort = 0
|
||||
if ep, ok := eps.endpoints[nid]; ok {
|
||||
if !yield(ep) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Try to find a match with only the local port.
|
||||
nid.LocalAddress = ""
|
||||
if ep, ok := eps.endpoints[nid]; ok {
|
||||
if !yield(ep) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// findAllEndpointsLocked returns all endpointsByNIC in eps that match id, in
|
||||
// descending order of match quality.
|
||||
//
|
||||
// Preconditions: eps.mu must be locked.
|
||||
func (eps *transportEndpoints) findAllEndpointsLocked(id TransportEndpointID) []*endpointsByNIC {
|
||||
var matchedEPs []*endpointsByNIC
|
||||
eps.iterEndpointsLocked(id, func(ep *endpointsByNIC) bool {
|
||||
matchedEPs = append(matchedEPs, ep)
|
||||
return true
|
||||
})
|
||||
return matchedEPs
|
||||
}
|
||||
|
||||
// findEndpointLocked returns the endpoint that most closely matches the given id.
|
||||
//
|
||||
// Preconditions: eps.mu must be locked.
|
||||
func (eps *transportEndpoints) findEndpointLocked(id TransportEndpointID) *endpointsByNIC {
|
||||
var matchedEP *endpointsByNIC
|
||||
eps.iterEndpointsLocked(id, func(ep *endpointsByNIC) bool {
|
||||
matchedEP = ep
|
||||
return false
|
||||
})
|
||||
return matchedEP
|
||||
}
|
||||
|
||||
type endpointsByNIC struct {
|
||||
mu sync.RWMutex
|
||||
endpoints map[tcpip.NICID]*multiPortEndpoint
|
||||
// seed is a random secret for a jenkins hash.
|
||||
seed uint32
|
||||
}
|
||||
|
||||
func (epsByNic *endpointsByNic) transportEndpoints() []TransportEndpoint {
|
||||
epsByNic.mu.RLock()
|
||||
defer epsByNic.mu.RUnlock()
|
||||
func (epsByNIC *endpointsByNIC) transportEndpoints() []TransportEndpoint {
|
||||
epsByNIC.mu.RLock()
|
||||
defer epsByNIC.mu.RUnlock()
|
||||
var eps []TransportEndpoint
|
||||
for _, ep := range epsByNic.endpoints {
|
||||
for _, ep := range epsByNIC.endpoints {
|
||||
eps = append(eps, ep.transportEndpoints()...)
|
||||
}
|
||||
return eps
|
||||
|
@ -85,13 +152,13 @@ func (epsByNic *endpointsByNic) transportEndpoints() []TransportEndpoint {
|
|||
|
||||
// HandlePacket is called by the stack when new packets arrive to this transport
|
||||
// endpoint.
|
||||
func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, pkt PacketBuffer) {
|
||||
epsByNic.mu.RLock()
|
||||
func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, pkt PacketBuffer) {
|
||||
epsByNIC.mu.RLock()
|
||||
|
||||
mpep, ok := epsByNic.endpoints[r.ref.nic.ID()]
|
||||
mpep, ok := epsByNIC.endpoints[r.ref.nic.ID()]
|
||||
if !ok {
|
||||
if mpep, ok = epsByNic.endpoints[0]; !ok {
|
||||
epsByNic.mu.RUnlock() // Don't use defer for performance reasons.
|
||||
if mpep, ok = epsByNIC.endpoints[0]; !ok {
|
||||
epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
|
||||
return
|
||||
}
|
||||
}
|
||||
|
@ -100,29 +167,29 @@ func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, p
|
|||
// endpoints bound to the right device.
|
||||
if isMulticastOrBroadcast(id.LocalAddress) {
|
||||
mpep.handlePacketAll(r, id, pkt)
|
||||
epsByNic.mu.RUnlock() // Don't use defer for performance reasons.
|
||||
epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
|
||||
return
|
||||
}
|
||||
// multiPortEndpoints are guaranteed to have at least one element.
|
||||
transEP := selectEndpoint(id, mpep, epsByNic.seed)
|
||||
transEP := selectEndpoint(id, mpep, epsByNIC.seed)
|
||||
if queuedProtocol, mustQueue := mpep.demux.queuedProtocols[protocolIDs{mpep.netProto, mpep.transProto}]; mustQueue {
|
||||
queuedProtocol.QueuePacket(r, transEP, id, pkt)
|
||||
epsByNic.mu.RUnlock()
|
||||
epsByNIC.mu.RUnlock()
|
||||
return
|
||||
}
|
||||
|
||||
transEP.HandlePacket(r, id, pkt)
|
||||
epsByNic.mu.RUnlock() // Don't use defer for performance reasons.
|
||||
epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
|
||||
}
|
||||
|
||||
// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
|
||||
func (epsByNic *endpointsByNic) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, pkt PacketBuffer) {
|
||||
epsByNic.mu.RLock()
|
||||
defer epsByNic.mu.RUnlock()
|
||||
func (epsByNIC *endpointsByNIC) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, pkt PacketBuffer) {
|
||||
epsByNIC.mu.RLock()
|
||||
defer epsByNIC.mu.RUnlock()
|
||||
|
||||
mpep, ok := epsByNic.endpoints[n.ID()]
|
||||
mpep, ok := epsByNIC.endpoints[n.ID()]
|
||||
if !ok {
|
||||
mpep, ok = epsByNic.endpoints[0]
|
||||
mpep, ok = epsByNIC.endpoints[0]
|
||||
}
|
||||
if !ok {
|
||||
return
|
||||
|
@ -132,16 +199,16 @@ func (epsByNic *endpointsByNic) handleControlPacket(n *NIC, id TransportEndpoint
|
|||
// broadcast like we are doing with handlePacket above?
|
||||
|
||||
// multiPortEndpoints are guaranteed to have at least one element.
|
||||
selectEndpoint(id, mpep, epsByNic.seed).HandleControlPacket(id, typ, extra, pkt)
|
||||
selectEndpoint(id, mpep, epsByNIC.seed).HandleControlPacket(id, typ, extra, pkt)
|
||||
}
|
||||
|
||||
// registerEndpoint returns true if it succeeds. It fails and returns
|
||||
// false if ep already has an element with the same key.
|
||||
func (epsByNic *endpointsByNic) registerEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, t TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error {
|
||||
epsByNic.mu.Lock()
|
||||
defer epsByNic.mu.Unlock()
|
||||
func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, t TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error {
|
||||
epsByNIC.mu.Lock()
|
||||
defer epsByNIC.mu.Unlock()
|
||||
|
||||
multiPortEp, ok := epsByNic.endpoints[bindToDevice]
|
||||
multiPortEp, ok := epsByNIC.endpoints[bindToDevice]
|
||||
if !ok {
|
||||
multiPortEp = &multiPortEndpoint{
|
||||
demux: d,
|
||||
|
@ -149,24 +216,24 @@ func (epsByNic *endpointsByNic) registerEndpoint(d *transportDemuxer, netProto t
|
|||
transProto: transProto,
|
||||
reuse: reusePort,
|
||||
}
|
||||
epsByNic.endpoints[bindToDevice] = multiPortEp
|
||||
epsByNIC.endpoints[bindToDevice] = multiPortEp
|
||||
}
|
||||
|
||||
return multiPortEp.singleRegisterEndpoint(t, reusePort)
|
||||
}
|
||||
|
||||
// unregisterEndpoint returns true if endpointsByNic has to be unregistered.
|
||||
func (epsByNic *endpointsByNic) unregisterEndpoint(bindToDevice tcpip.NICID, t TransportEndpoint) bool {
|
||||
epsByNic.mu.Lock()
|
||||
defer epsByNic.mu.Unlock()
|
||||
multiPortEp, ok := epsByNic.endpoints[bindToDevice]
|
||||
// unregisterEndpoint returns true if endpointsByNIC has to be unregistered.
|
||||
func (epsByNIC *endpointsByNIC) unregisterEndpoint(bindToDevice tcpip.NICID, t TransportEndpoint) bool {
|
||||
epsByNIC.mu.Lock()
|
||||
defer epsByNIC.mu.Unlock()
|
||||
multiPortEp, ok := epsByNIC.endpoints[bindToDevice]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if multiPortEp.unregisterEndpoint(t) {
|
||||
delete(epsByNic.endpoints, bindToDevice)
|
||||
delete(epsByNIC.endpoints, bindToDevice)
|
||||
}
|
||||
return len(epsByNic.endpoints) == 0
|
||||
return len(epsByNIC.endpoints) == 0
|
||||
}
|
||||
|
||||
// transportDemuxer demultiplexes packets targeted at a transport endpoint
|
||||
|
@ -198,7 +265,7 @@ func newTransportDemuxer(stack *Stack) *transportDemuxer {
|
|||
for proto := range stack.transportProtocols {
|
||||
protoIDs := protocolIDs{netProto, proto}
|
||||
d.protocol[protoIDs] = &transportEndpoints{
|
||||
endpoints: make(map[TransportEndpointID]*endpointsByNic),
|
||||
endpoints: make(map[TransportEndpointID]*endpointsByNIC),
|
||||
}
|
||||
qTransProto, isQueued := (stack.transportProtocols[proto].proto).(queuedTransportProtocol)
|
||||
if isQueued {
|
||||
|
@ -378,16 +445,16 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol
|
|||
eps.mu.Lock()
|
||||
defer eps.mu.Unlock()
|
||||
|
||||
epsByNic, ok := eps.endpoints[id]
|
||||
epsByNIC, ok := eps.endpoints[id]
|
||||
if !ok {
|
||||
epsByNic = &endpointsByNic{
|
||||
epsByNIC = &endpointsByNIC{
|
||||
endpoints: make(map[tcpip.NICID]*multiPortEndpoint),
|
||||
seed: rand.Uint32(),
|
||||
}
|
||||
eps.endpoints[id] = epsByNic
|
||||
eps.endpoints[id] = epsByNIC
|
||||
}
|
||||
|
||||
return epsByNic.registerEndpoint(d, netProto, protocol, ep, reusePort, bindToDevice)
|
||||
return epsByNIC.registerEndpoint(d, netProto, protocol, ep, reusePort, bindToDevice)
|
||||
}
|
||||
|
||||
// unregisterEndpoint unregisters the endpoint with the given id such that it
|
||||
|
@ -413,7 +480,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto
|
|||
// transport endpoints.
|
||||
if protocol == header.UDPProtocolNumber && isMulticastOrBroadcast(id.LocalAddress) {
|
||||
eps.mu.RLock()
|
||||
destEPs := d.findAllEndpointsLocked(eps, id)
|
||||
destEPs := eps.findAllEndpointsLocked(id)
|
||||
eps.mu.RUnlock()
|
||||
// Fail if we didn't find at least one matching transport endpoint.
|
||||
if len(destEPs) == 0 {
|
||||
|
@ -439,7 +506,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto
|
|||
}
|
||||
|
||||
eps.mu.RLock()
|
||||
ep := d.findEndpointLocked(eps, id)
|
||||
ep := eps.findEndpointLocked(id)
|
||||
eps.mu.RUnlock()
|
||||
if ep == nil {
|
||||
if protocol == header.UDPProtocolNumber {
|
||||
|
@ -483,115 +550,47 @@ func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtoco
|
|||
return false
|
||||
}
|
||||
|
||||
// Try to find the endpoint.
|
||||
eps.mu.RLock()
|
||||
ep := d.findEndpointLocked(eps, id)
|
||||
ep := eps.findEndpointLocked(id)
|
||||
eps.mu.RUnlock()
|
||||
|
||||
// Fail if we didn't find one.
|
||||
if ep == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Deliver the packet.
|
||||
ep.handleControlPacket(n, id, typ, extra, pkt)
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// iterEndpointsLocked yields all endpointsByNic in eps that match id, in
|
||||
// descending order of match quality. If a call to yield returns false,
|
||||
// iterEndpointsLocked stops iteration and returns immediately.
|
||||
//
|
||||
// Preconditions: eps.mu must be locked.
|
||||
func (d *transportDemuxer) iterEndpointsLocked(eps *transportEndpoints, id TransportEndpointID, yield func(*endpointsByNic) bool) {
|
||||
// Try to find a match with the id as provided.
|
||||
if ep, ok := eps.endpoints[id]; ok {
|
||||
if !yield(ep) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Try to find a match with the id minus the local address.
|
||||
nid := id
|
||||
|
||||
nid.LocalAddress = ""
|
||||
if ep, ok := eps.endpoints[nid]; ok {
|
||||
if !yield(ep) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Try to find a match with the id minus the remote part.
|
||||
nid.LocalAddress = id.LocalAddress
|
||||
nid.RemoteAddress = ""
|
||||
nid.RemotePort = 0
|
||||
if ep, ok := eps.endpoints[nid]; ok {
|
||||
if !yield(ep) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Try to find a match with only the local port.
|
||||
nid.LocalAddress = ""
|
||||
if ep, ok := eps.endpoints[nid]; ok {
|
||||
if !yield(ep) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (d *transportDemuxer) findAllEndpointsLocked(eps *transportEndpoints, id TransportEndpointID) []*endpointsByNic {
|
||||
var matchedEPs []*endpointsByNic
|
||||
d.iterEndpointsLocked(eps, id, func(ep *endpointsByNic) bool {
|
||||
matchedEPs = append(matchedEPs, ep)
|
||||
return true
|
||||
})
|
||||
return matchedEPs
|
||||
}
|
||||
|
||||
// findTransportEndpoint find a single endpoint that most closely matches the provided id.
|
||||
func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, r *Route) TransportEndpoint {
|
||||
eps, ok := d.protocol[protocolIDs{netProto, transProto}]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
// Try to find the endpoint.
|
||||
|
||||
eps.mu.RLock()
|
||||
epsByNic := d.findEndpointLocked(eps, id)
|
||||
// Fail if we didn't find one.
|
||||
if epsByNic == nil {
|
||||
epsByNIC := eps.findEndpointLocked(id)
|
||||
if epsByNIC == nil {
|
||||
eps.mu.RUnlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
epsByNic.mu.RLock()
|
||||
epsByNIC.mu.RLock()
|
||||
eps.mu.RUnlock()
|
||||
|
||||
mpep, ok := epsByNic.endpoints[r.ref.nic.ID()]
|
||||
mpep, ok := epsByNIC.endpoints[r.ref.nic.ID()]
|
||||
if !ok {
|
||||
if mpep, ok = epsByNic.endpoints[0]; !ok {
|
||||
epsByNic.mu.RUnlock() // Don't use defer for performance reasons.
|
||||
if mpep, ok = epsByNIC.endpoints[0]; !ok {
|
||||
epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
ep := selectEndpoint(id, mpep, epsByNic.seed)
|
||||
epsByNic.mu.RUnlock()
|
||||
ep := selectEndpoint(id, mpep, epsByNIC.seed)
|
||||
epsByNIC.mu.RUnlock()
|
||||
return ep
|
||||
}
|
||||
|
||||
// findEndpointLocked returns the endpoint that most closely matches the given
|
||||
// id.
|
||||
func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, id TransportEndpointID) *endpointsByNic {
|
||||
var matchedEP *endpointsByNic
|
||||
d.iterEndpointsLocked(eps, id, func(ep *endpointsByNic) bool {
|
||||
matchedEP = ep
|
||||
return false
|
||||
})
|
||||
return matchedEP
|
||||
}
|
||||
|
||||
// registerRawEndpoint registers the given endpoint with the dispatcher such
|
||||
// that packets of the appropriate protocol are delivered to it. A single
|
||||
// packet can be sent to one or more raw endpoints along with a non-raw
|
||||
|
|
|
@ -40,75 +40,47 @@ const (
|
|||
)
|
||||
|
||||
type testContext struct {
|
||||
t *testing.T
|
||||
linkEps map[tcpip.NICID]*channel.Endpoint
|
||||
s *stack.Stack
|
||||
|
||||
ep tcpip.Endpoint
|
||||
wq waiter.Queue
|
||||
}
|
||||
|
||||
func (c *testContext) cleanup() {
|
||||
if c.ep != nil {
|
||||
c.ep.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *testContext) createV6Endpoint(v6only bool) {
|
||||
var err *tcpip.Error
|
||||
c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq)
|
||||
if err != nil {
|
||||
c.t.Fatalf("NewEndpoint failed: %v", err)
|
||||
}
|
||||
|
||||
if err := c.ep.SetSockOptBool(tcpip.V6OnlyOption, v6only); err != nil {
|
||||
c.t.Fatalf("SetSockOpt failed: %v", err)
|
||||
}
|
||||
wq waiter.Queue
|
||||
}
|
||||
|
||||
// newDualTestContextMultiNIC creates the testing context and also linkEpIDs NICs.
|
||||
func newDualTestContextMultiNIC(t *testing.T, mtu uint32, linkEpIDs []tcpip.NICID) *testContext {
|
||||
s := stack.New(stack.Options{
|
||||
NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
|
||||
TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}})
|
||||
TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
|
||||
})
|
||||
linkEps := make(map[tcpip.NICID]*channel.Endpoint)
|
||||
for _, linkEpID := range linkEpIDs {
|
||||
channelEp := channel.New(256, mtu, "")
|
||||
if err := s.CreateNIC(linkEpID, channelEp); err != nil {
|
||||
t.Fatalf("CreateNIC failed: %v", err)
|
||||
t.Fatalf("CreateNIC failed: %s", err)
|
||||
}
|
||||
linkEps[linkEpID] = channelEp
|
||||
|
||||
if err := s.AddAddress(linkEpID, ipv4.ProtocolNumber, stackAddr); err != nil {
|
||||
t.Fatalf("AddAddress IPv4 failed: %v", err)
|
||||
t.Fatalf("AddAddress IPv4 failed: %s", err)
|
||||
}
|
||||
|
||||
if err := s.AddAddress(linkEpID, ipv6.ProtocolNumber, stackV6Addr); err != nil {
|
||||
t.Fatalf("AddAddress IPv6 failed: %v", err)
|
||||
t.Fatalf("AddAddress IPv6 failed: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
s.SetRouteTable([]tcpip.Route{
|
||||
{
|
||||
Destination: header.IPv4EmptySubnet,
|
||||
NIC: 1,
|
||||
},
|
||||
{
|
||||
Destination: header.IPv6EmptySubnet,
|
||||
NIC: 1,
|
||||
},
|
||||
{Destination: header.IPv4EmptySubnet, NIC: 1},
|
||||
{Destination: header.IPv6EmptySubnet, NIC: 1},
|
||||
})
|
||||
|
||||
return &testContext{
|
||||
t: t,
|
||||
s: s,
|
||||
linkEps: linkEps,
|
||||
}
|
||||
}
|
||||
|
||||
type headers struct {
|
||||
srcPort uint16
|
||||
dstPort uint16
|
||||
srcPort, dstPort uint16
|
||||
}
|
||||
|
||||
func newPayload() []byte {
|
||||
|
@ -179,15 +151,15 @@ func TestTransportDemuxerRegister(t *testing.T) {
|
|||
t.Fatalf("%T does not implement stack.TransportEndpoint", ep)
|
||||
}
|
||||
if got, want := s.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, tEP, false, 0), test.want; got != want {
|
||||
t.Fatalf("s.RegisterTransportEndpoint(...) = %v, want %v", got, want)
|
||||
t.Fatalf("s.RegisterTransportEndpoint(...) = %s, want %s", got, want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestReuseBindToDevice injects varied packets on input devices and checks that
|
||||
// TestBindToDeviceDistribution injects varied packets on input devices and checks that
|
||||
// the distribution of packets received matches expectations.
|
||||
func TestDistribution(t *testing.T) {
|
||||
func TestBindToDeviceDistribution(t *testing.T) {
|
||||
type endpointSockopts struct {
|
||||
reuse int
|
||||
bindToDevice tcpip.NICID
|
||||
|
@ -196,19 +168,19 @@ func TestDistribution(t *testing.T) {
|
|||
name string
|
||||
// endpoints will received the inject packets.
|
||||
endpoints []endpointSockopts
|
||||
// wantedDistribution is the wanted ratio of packets received on each
|
||||
// wantDistributions is the want ratio of packets received on each
|
||||
// endpoint for each NIC on which packets are injected.
|
||||
wantedDistributions map[tcpip.NICID][]float64
|
||||
wantDistributions map[tcpip.NICID][]float64
|
||||
}{
|
||||
{
|
||||
"BindPortReuse",
|
||||
// 5 endpoints that all have reuse set.
|
||||
[]endpointSockopts{
|
||||
{1, 0},
|
||||
{1, 0},
|
||||
{1, 0},
|
||||
{1, 0},
|
||||
{1, 0},
|
||||
{reuse: 1, bindToDevice: 0},
|
||||
{reuse: 1, bindToDevice: 0},
|
||||
{reuse: 1, bindToDevice: 0},
|
||||
{reuse: 1, bindToDevice: 0},
|
||||
{reuse: 1, bindToDevice: 0},
|
||||
},
|
||||
map[tcpip.NICID][]float64{
|
||||
// Injected packets on dev0 get distributed evenly.
|
||||
|
@ -219,9 +191,9 @@ func TestDistribution(t *testing.T) {
|
|||
"BindToDevice",
|
||||
// 3 endpoints with various bindings.
|
||||
[]endpointSockopts{
|
||||
{0, 1},
|
||||
{0, 2},
|
||||
{0, 3},
|
||||
{reuse: 0, bindToDevice: 1},
|
||||
{reuse: 0, bindToDevice: 2},
|
||||
{reuse: 0, bindToDevice: 3},
|
||||
},
|
||||
map[tcpip.NICID][]float64{
|
||||
// Injected packets on dev0 go only to the endpoint bound to dev0.
|
||||
|
@ -236,12 +208,12 @@ func TestDistribution(t *testing.T) {
|
|||
"ReuseAndBindToDevice",
|
||||
// 6 endpoints with various bindings.
|
||||
[]endpointSockopts{
|
||||
{1, 1},
|
||||
{1, 1},
|
||||
{1, 2},
|
||||
{1, 2},
|
||||
{1, 2},
|
||||
{1, 0},
|
||||
{reuse: 1, bindToDevice: 1},
|
||||
{reuse: 1, bindToDevice: 1},
|
||||
{reuse: 1, bindToDevice: 2},
|
||||
{reuse: 1, bindToDevice: 2},
|
||||
{reuse: 1, bindToDevice: 2},
|
||||
{reuse: 1, bindToDevice: 0},
|
||||
},
|
||||
map[tcpip.NICID][]float64{
|
||||
// Injected packets on dev0 get distributed among endpoints bound to
|
||||
|
@ -256,16 +228,13 @@ func TestDistribution(t *testing.T) {
|
|||
},
|
||||
} {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
for device, wantedDistribution := range test.wantedDistributions {
|
||||
for device, wantDistribution := range test.wantDistributions {
|
||||
t.Run(string(device), func(t *testing.T) {
|
||||
var devices []tcpip.NICID
|
||||
for d := range test.wantedDistributions {
|
||||
for d := range test.wantDistributions {
|
||||
devices = append(devices, d)
|
||||
}
|
||||
c := newDualTestContextMultiNIC(t, defaultMTU, devices)
|
||||
defer c.cleanup()
|
||||
|
||||
c.createV6Endpoint(false)
|
||||
|
||||
eps := make(map[tcpip.Endpoint]int)
|
||||
|
||||
|
@ -281,7 +250,7 @@ func TestDistribution(t *testing.T) {
|
|||
var err *tcpip.Error
|
||||
ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &wq)
|
||||
if err != nil {
|
||||
c.t.Fatalf("NewEndpoint failed: %v", err)
|
||||
t.Fatalf("NewEndpoint failed: %s", err)
|
||||
}
|
||||
eps[ep] = i
|
||||
|
||||
|
@ -294,20 +263,20 @@ func TestDistribution(t *testing.T) {
|
|||
defer ep.Close()
|
||||
reusePortOption := tcpip.ReusePortOption(endpoint.reuse)
|
||||
if err := ep.SetSockOpt(reusePortOption); err != nil {
|
||||
c.t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %v", reusePortOption, i, err)
|
||||
t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %s", reusePortOption, i, err)
|
||||
}
|
||||
bindToDeviceOption := tcpip.BindToDeviceOption(endpoint.bindToDevice)
|
||||
if err := ep.SetSockOpt(bindToDeviceOption); err != nil {
|
||||
c.t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %v", bindToDeviceOption, i, err)
|
||||
t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %s", bindToDeviceOption, i, err)
|
||||
}
|
||||
if err := ep.Bind(tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}); err != nil {
|
||||
t.Fatalf("ep.Bind(...) on endpoint %d failed: %v", i, err)
|
||||
t.Fatalf("ep.Bind(...) on endpoint %d failed: %s", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
npackets := 100000
|
||||
nports := 10000
|
||||
if got, want := len(test.endpoints), len(wantedDistribution); got != want {
|
||||
if got, want := len(test.endpoints), len(wantDistribution); got != want {
|
||||
t.Fatalf("got len(test.endpoints) = %d, want %d", got, want)
|
||||
}
|
||||
ports := make(map[uint16]tcpip.Endpoint)
|
||||
|
@ -322,11 +291,9 @@ func TestDistribution(t *testing.T) {
|
|||
dstPort: stackPort},
|
||||
device)
|
||||
|
||||
var addr tcpip.FullAddress
|
||||
ep := <-pollChannel
|
||||
_, _, err := ep.Read(&addr)
|
||||
if err != nil {
|
||||
c.t.Fatalf("Read on endpoint %d failed: %v", eps[ep], err)
|
||||
if _, _, err := ep.Read(nil); err != nil {
|
||||
t.Fatalf("Read on endpoint %d failed: %s", eps[ep], err)
|
||||
}
|
||||
stats[ep]++
|
||||
if i < nports {
|
||||
|
@ -342,13 +309,13 @@ func TestDistribution(t *testing.T) {
|
|||
|
||||
// Check that a packet distribution is as expected.
|
||||
for ep, i := range eps {
|
||||
wantedRatio := wantedDistribution[i]
|
||||
wantedRecv := wantedRatio * float64(npackets)
|
||||
wantRatio := wantDistribution[i]
|
||||
wantRecv := wantRatio * float64(npackets)
|
||||
actualRecv := stats[ep]
|
||||
actualRatio := float64(stats[ep]) / float64(npackets)
|
||||
// The deviation is less than 10%.
|
||||
if math.Abs(actualRatio-wantedRatio) > 0.05 {
|
||||
t.Errorf("wanted about %.0f%% (%.0f of %d) packets to arrive on endpoint %d, got %.0f%% (%d of %d)", wantedRatio*100, wantedRecv, npackets, i, actualRatio*100, actualRecv, npackets)
|
||||
if math.Abs(actualRatio-wantRatio) > 0.05 {
|
||||
t.Errorf("want about %.0f%% (%.0f of %d) packets to arrive on endpoint %d, got %.0f%% (%d of %d)", wantRatio*100, wantRecv, npackets, i, actualRatio*100, actualRecv, npackets)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
|
Loading…
Reference in New Issue