diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index 824cf6526..542d9257c 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -32,11 +32,13 @@ type protocolIDs struct { // transportEndpoints manages all endpoints of a given protocol. It has its own // mutex so as to reduce interference between protocols. type transportEndpoints struct { - // mu protects all fields of the transportEndpoints. - mu sync.RWMutex + mu sync.RWMutex + // +checklocks:mu endpoints map[TransportEndpointID]*endpointsByNIC // rawEndpoints contains endpoints for raw sockets, which receive all // traffic of a given protocol regardless of port. + // + // +checklocks:mu rawEndpoints []RawTransportEndpoint } @@ -69,7 +71,7 @@ func (eps *transportEndpoints) transportEndpoints() []TransportEndpoint { // descending order of match quality. If a call to yield returns false, // iterEndpointsLocked stops iteration and returns immediately. // -// Preconditions: eps.mu must be locked. +// +checklocks:eps.mu func (eps *transportEndpoints) iterEndpointsLocked(id TransportEndpointID, yield func(*endpointsByNIC) bool) { // Try to find a match with the id as provided. if ep, ok := eps.endpoints[id]; ok { @@ -110,7 +112,7 @@ func (eps *transportEndpoints) iterEndpointsLocked(id TransportEndpointID, yield // findAllEndpointsLocked returns all endpointsByNIC in eps that match id, in // descending order of match quality. // -// Preconditions: eps.mu must be locked. +// +checklocks:eps.mu func (eps *transportEndpoints) findAllEndpointsLocked(id TransportEndpointID) []*endpointsByNIC { var matchedEPs []*endpointsByNIC eps.iterEndpointsLocked(id, func(ep *endpointsByNIC) bool { @@ -122,7 +124,7 @@ func (eps *transportEndpoints) findAllEndpointsLocked(id TransportEndpointID) [] // findEndpointLocked returns the endpoint that most closely matches the given id. // -// Preconditions: eps.mu must be locked. +// +checklocks:eps.mu func (eps *transportEndpoints) findEndpointLocked(id TransportEndpointID) *endpointsByNIC { var matchedEP *endpointsByNIC eps.iterEndpointsLocked(id, func(ep *endpointsByNIC) bool { @@ -133,10 +135,12 @@ func (eps *transportEndpoints) findEndpointLocked(id TransportEndpointID) *endpo } type endpointsByNIC struct { - mu sync.RWMutex - endpoints map[tcpip.NICID]*multiPortEndpoint // seed is a random secret for a jenkins hash. seed uint32 + + mu sync.RWMutex + // +checklocks:mu + endpoints map[tcpip.NICID]*multiPortEndpoint } func (epsByNIC *endpointsByNIC) transportEndpoints() []TransportEndpoint { @@ -171,7 +175,7 @@ func (epsByNIC *endpointsByNIC) handlePacket(id TransportEndpointID, pkt *Packet return true } // multiPortEndpoints are guaranteed to have at least one element. - transEP := selectEndpoint(id, mpep, epsByNIC.seed) + transEP := mpep.selectEndpoint(id, epsByNIC.seed) if queuedProtocol, mustQueue := mpep.demux.queuedProtocols[protocolIDs{mpep.netProto, mpep.transProto}]; mustQueue { queuedProtocol.QueuePacket(transEP, id, pkt) epsByNIC.mu.RUnlock() @@ -200,7 +204,7 @@ func (epsByNIC *endpointsByNIC) handleError(n *nic, id TransportEndpointID, tran // broadcast like we are doing with handlePacket above? // multiPortEndpoints are guaranteed to have at least one element. - selectEndpoint(id, mpep, epsByNIC.seed).HandleError(transErr, pkt) + mpep.selectEndpoint(id, epsByNIC.seed).HandleError(transErr, pkt) } // registerEndpoint returns true if it succeeds. It fails and returns @@ -333,15 +337,18 @@ func (d *transportDemuxer) checkEndpoint(netProtos []tcpip.NetworkProtocolNumber // // +stateify savable type multiPortEndpoint struct { - mu sync.RWMutex `state:"nosave"` demux *transportDemuxer netProto tcpip.NetworkProtocolNumber transProto tcpip.TransportProtocolNumber + flags ports.FlagCounter + + mu sync.RWMutex `state:"nosave"` // endpoints stores the transport endpoints in the order in which they // were bound. This is required for UDP SO_REUSEADDR. + // + // +checklocks:mu endpoints []TransportEndpoint - flags ports.FlagCounter } func (ep *multiPortEndpoint) transportEndpoints() []TransportEndpoint { @@ -362,13 +369,16 @@ func reciprocalScale(val, n uint32) uint32 { // selectEndpoint calculates a hash of destination and source addresses and // ports then uses it to select a socket. In this case, all packets from one // address will be sent to same endpoint. -func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32) TransportEndpoint { - if len(mpep.endpoints) == 1 { - return mpep.endpoints[0] +func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID, seed uint32) TransportEndpoint { + ep.mu.RLock() + defer ep.mu.RUnlock() + + if len(ep.endpoints) == 1 { + return ep.endpoints[0] } - if mpep.flags.SharedFlags().ToFlags().Effective().MostRecent { - return mpep.endpoints[len(mpep.endpoints)-1] + if ep.flags.SharedFlags().ToFlags().Effective().MostRecent { + return ep.endpoints[len(ep.endpoints)-1] } payload := []byte{ @@ -384,8 +394,8 @@ func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32 h.Write([]byte(id.RemoteAddress)) hash := h.Sum32() - idx := reciprocalScale(hash, uint32(len(mpep.endpoints))) - return mpep.endpoints[idx] + idx := reciprocalScale(hash, uint32(len(ep.endpoints))) + return ep.endpoints[idx] } func (ep *multiPortEndpoint) handlePacketAll(id TransportEndpointID, pkt *PacketBuffer) { @@ -657,7 +667,7 @@ func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolN } } - ep := selectEndpoint(id, mpep, epsByNIC.seed) + ep := mpep.selectEndpoint(id, epsByNIC.seed) epsByNIC.mu.RUnlock() return ep }