Annotate checklocks on mutex protected fields
...to catch lock-related bugs in nogo tests. Checklocks also pointed out a locking violation which is fixed in this change. Updates #6566. PiperOrigin-RevId: 397225322
This commit is contained in:
parent
eccd46e67c
commit
85bd3dd9b1
|
@ -32,11 +32,13 @@ type protocolIDs struct {
|
|||
// transportEndpoints manages all endpoints of a given protocol. It has its own
|
||||
// mutex so as to reduce interference between protocols.
|
||||
type transportEndpoints struct {
|
||||
// mu protects all fields of the transportEndpoints.
|
||||
mu sync.RWMutex
|
||||
mu sync.RWMutex
|
||||
// +checklocks:mu
|
||||
endpoints map[TransportEndpointID]*endpointsByNIC
|
||||
// rawEndpoints contains endpoints for raw sockets, which receive all
|
||||
// traffic of a given protocol regardless of port.
|
||||
//
|
||||
// +checklocks:mu
|
||||
rawEndpoints []RawTransportEndpoint
|
||||
}
|
||||
|
||||
|
@ -69,7 +71,7 @@ func (eps *transportEndpoints) transportEndpoints() []TransportEndpoint {
|
|||
// descending order of match quality. If a call to yield returns false,
|
||||
// iterEndpointsLocked stops iteration and returns immediately.
|
||||
//
|
||||
// Preconditions: eps.mu must be locked.
|
||||
// +checklocks:eps.mu
|
||||
func (eps *transportEndpoints) iterEndpointsLocked(id TransportEndpointID, yield func(*endpointsByNIC) bool) {
|
||||
// Try to find a match with the id as provided.
|
||||
if ep, ok := eps.endpoints[id]; ok {
|
||||
|
@ -110,7 +112,7 @@ func (eps *transportEndpoints) iterEndpointsLocked(id TransportEndpointID, yield
|
|||
// findAllEndpointsLocked returns all endpointsByNIC in eps that match id, in
|
||||
// descending order of match quality.
|
||||
//
|
||||
// Preconditions: eps.mu must be locked.
|
||||
// +checklocks:eps.mu
|
||||
func (eps *transportEndpoints) findAllEndpointsLocked(id TransportEndpointID) []*endpointsByNIC {
|
||||
var matchedEPs []*endpointsByNIC
|
||||
eps.iterEndpointsLocked(id, func(ep *endpointsByNIC) bool {
|
||||
|
@ -122,7 +124,7 @@ func (eps *transportEndpoints) findAllEndpointsLocked(id TransportEndpointID) []
|
|||
|
||||
// findEndpointLocked returns the endpoint that most closely matches the given id.
|
||||
//
|
||||
// Preconditions: eps.mu must be locked.
|
||||
// +checklocks:eps.mu
|
||||
func (eps *transportEndpoints) findEndpointLocked(id TransportEndpointID) *endpointsByNIC {
|
||||
var matchedEP *endpointsByNIC
|
||||
eps.iterEndpointsLocked(id, func(ep *endpointsByNIC) bool {
|
||||
|
@ -133,10 +135,12 @@ func (eps *transportEndpoints) findEndpointLocked(id TransportEndpointID) *endpo
|
|||
}
|
||||
|
||||
type endpointsByNIC struct {
|
||||
mu sync.RWMutex
|
||||
endpoints map[tcpip.NICID]*multiPortEndpoint
|
||||
// seed is a random secret for a jenkins hash.
|
||||
seed uint32
|
||||
|
||||
mu sync.RWMutex
|
||||
// +checklocks:mu
|
||||
endpoints map[tcpip.NICID]*multiPortEndpoint
|
||||
}
|
||||
|
||||
func (epsByNIC *endpointsByNIC) transportEndpoints() []TransportEndpoint {
|
||||
|
@ -171,7 +175,7 @@ func (epsByNIC *endpointsByNIC) handlePacket(id TransportEndpointID, pkt *Packet
|
|||
return true
|
||||
}
|
||||
// multiPortEndpoints are guaranteed to have at least one element.
|
||||
transEP := selectEndpoint(id, mpep, epsByNIC.seed)
|
||||
transEP := mpep.selectEndpoint(id, epsByNIC.seed)
|
||||
if queuedProtocol, mustQueue := mpep.demux.queuedProtocols[protocolIDs{mpep.netProto, mpep.transProto}]; mustQueue {
|
||||
queuedProtocol.QueuePacket(transEP, id, pkt)
|
||||
epsByNIC.mu.RUnlock()
|
||||
|
@ -200,7 +204,7 @@ func (epsByNIC *endpointsByNIC) handleError(n *nic, id TransportEndpointID, tran
|
|||
// broadcast like we are doing with handlePacket above?
|
||||
|
||||
// multiPortEndpoints are guaranteed to have at least one element.
|
||||
selectEndpoint(id, mpep, epsByNIC.seed).HandleError(transErr, pkt)
|
||||
mpep.selectEndpoint(id, epsByNIC.seed).HandleError(transErr, pkt)
|
||||
}
|
||||
|
||||
// registerEndpoint returns true if it succeeds. It fails and returns
|
||||
|
@ -333,15 +337,18 @@ func (d *transportDemuxer) checkEndpoint(netProtos []tcpip.NetworkProtocolNumber
|
|||
//
|
||||
// +stateify savable
|
||||
type multiPortEndpoint struct {
|
||||
mu sync.RWMutex `state:"nosave"`
|
||||
demux *transportDemuxer
|
||||
netProto tcpip.NetworkProtocolNumber
|
||||
transProto tcpip.TransportProtocolNumber
|
||||
|
||||
flags ports.FlagCounter
|
||||
|
||||
mu sync.RWMutex `state:"nosave"`
|
||||
// endpoints stores the transport endpoints in the order in which they
|
||||
// were bound. This is required for UDP SO_REUSEADDR.
|
||||
//
|
||||
// +checklocks:mu
|
||||
endpoints []TransportEndpoint
|
||||
flags ports.FlagCounter
|
||||
}
|
||||
|
||||
func (ep *multiPortEndpoint) transportEndpoints() []TransportEndpoint {
|
||||
|
@ -362,13 +369,16 @@ func reciprocalScale(val, n uint32) uint32 {
|
|||
// selectEndpoint calculates a hash of destination and source addresses and
|
||||
// ports then uses it to select a socket. In this case, all packets from one
|
||||
// address will be sent to same endpoint.
|
||||
func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32) TransportEndpoint {
|
||||
if len(mpep.endpoints) == 1 {
|
||||
return mpep.endpoints[0]
|
||||
func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID, seed uint32) TransportEndpoint {
|
||||
ep.mu.RLock()
|
||||
defer ep.mu.RUnlock()
|
||||
|
||||
if len(ep.endpoints) == 1 {
|
||||
return ep.endpoints[0]
|
||||
}
|
||||
|
||||
if mpep.flags.SharedFlags().ToFlags().Effective().MostRecent {
|
||||
return mpep.endpoints[len(mpep.endpoints)-1]
|
||||
if ep.flags.SharedFlags().ToFlags().Effective().MostRecent {
|
||||
return ep.endpoints[len(ep.endpoints)-1]
|
||||
}
|
||||
|
||||
payload := []byte{
|
||||
|
@ -384,8 +394,8 @@ func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32
|
|||
h.Write([]byte(id.RemoteAddress))
|
||||
hash := h.Sum32()
|
||||
|
||||
idx := reciprocalScale(hash, uint32(len(mpep.endpoints)))
|
||||
return mpep.endpoints[idx]
|
||||
idx := reciprocalScale(hash, uint32(len(ep.endpoints)))
|
||||
return ep.endpoints[idx]
|
||||
}
|
||||
|
||||
func (ep *multiPortEndpoint) handlePacketAll(id TransportEndpointID, pkt *PacketBuffer) {
|
||||
|
@ -657,7 +667,7 @@ func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolN
|
|||
}
|
||||
}
|
||||
|
||||
ep := selectEndpoint(id, mpep, epsByNIC.seed)
|
||||
ep := mpep.selectEndpoint(id, epsByNIC.seed)
|
||||
epsByNIC.mu.RUnlock()
|
||||
return ep
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue