Annotate checklocks on mutex protected fields

...to catch lock-related bugs in nogo tests.

Checklocks also pointed out a locking violation which is fixed
in this change.

Updates #6566.

PiperOrigin-RevId: 397225322
This commit is contained in:
Ghanan Gowripalan 2021-09-16 19:50:00 -07:00 committed by gVisor bot
parent eccd46e67c
commit 85bd3dd9b1
1 changed files with 29 additions and 19 deletions

View File

@ -32,11 +32,13 @@ type protocolIDs struct {
// transportEndpoints manages all endpoints of a given protocol. It has its own
// mutex so as to reduce interference between protocols.
type transportEndpoints struct {
// mu protects all fields of the transportEndpoints.
mu sync.RWMutex
mu sync.RWMutex
// +checklocks:mu
endpoints map[TransportEndpointID]*endpointsByNIC
// rawEndpoints contains endpoints for raw sockets, which receive all
// traffic of a given protocol regardless of port.
//
// +checklocks:mu
rawEndpoints []RawTransportEndpoint
}
@ -69,7 +71,7 @@ func (eps *transportEndpoints) transportEndpoints() []TransportEndpoint {
// descending order of match quality. If a call to yield returns false,
// iterEndpointsLocked stops iteration and returns immediately.
//
// Preconditions: eps.mu must be locked.
// +checklocks:eps.mu
func (eps *transportEndpoints) iterEndpointsLocked(id TransportEndpointID, yield func(*endpointsByNIC) bool) {
// Try to find a match with the id as provided.
if ep, ok := eps.endpoints[id]; ok {
@ -110,7 +112,7 @@ func (eps *transportEndpoints) iterEndpointsLocked(id TransportEndpointID, yield
// findAllEndpointsLocked returns all endpointsByNIC in eps that match id, in
// descending order of match quality.
//
// Preconditions: eps.mu must be locked.
// +checklocks:eps.mu
func (eps *transportEndpoints) findAllEndpointsLocked(id TransportEndpointID) []*endpointsByNIC {
var matchedEPs []*endpointsByNIC
eps.iterEndpointsLocked(id, func(ep *endpointsByNIC) bool {
@ -122,7 +124,7 @@ func (eps *transportEndpoints) findAllEndpointsLocked(id TransportEndpointID) []
// findEndpointLocked returns the endpoint that most closely matches the given id.
//
// Preconditions: eps.mu must be locked.
// +checklocks:eps.mu
func (eps *transportEndpoints) findEndpointLocked(id TransportEndpointID) *endpointsByNIC {
var matchedEP *endpointsByNIC
eps.iterEndpointsLocked(id, func(ep *endpointsByNIC) bool {
@ -133,10 +135,12 @@ func (eps *transportEndpoints) findEndpointLocked(id TransportEndpointID) *endpo
}
type endpointsByNIC struct {
mu sync.RWMutex
endpoints map[tcpip.NICID]*multiPortEndpoint
// seed is a random secret for a jenkins hash.
seed uint32
mu sync.RWMutex
// +checklocks:mu
endpoints map[tcpip.NICID]*multiPortEndpoint
}
func (epsByNIC *endpointsByNIC) transportEndpoints() []TransportEndpoint {
@ -171,7 +175,7 @@ func (epsByNIC *endpointsByNIC) handlePacket(id TransportEndpointID, pkt *Packet
return true
}
// multiPortEndpoints are guaranteed to have at least one element.
transEP := selectEndpoint(id, mpep, epsByNIC.seed)
transEP := mpep.selectEndpoint(id, epsByNIC.seed)
if queuedProtocol, mustQueue := mpep.demux.queuedProtocols[protocolIDs{mpep.netProto, mpep.transProto}]; mustQueue {
queuedProtocol.QueuePacket(transEP, id, pkt)
epsByNIC.mu.RUnlock()
@ -200,7 +204,7 @@ func (epsByNIC *endpointsByNIC) handleError(n *nic, id TransportEndpointID, tran
// broadcast like we are doing with handlePacket above?
// multiPortEndpoints are guaranteed to have at least one element.
selectEndpoint(id, mpep, epsByNIC.seed).HandleError(transErr, pkt)
mpep.selectEndpoint(id, epsByNIC.seed).HandleError(transErr, pkt)
}
// registerEndpoint returns true if it succeeds. It fails and returns
@ -333,15 +337,18 @@ func (d *transportDemuxer) checkEndpoint(netProtos []tcpip.NetworkProtocolNumber
//
// +stateify savable
type multiPortEndpoint struct {
mu sync.RWMutex `state:"nosave"`
demux *transportDemuxer
netProto tcpip.NetworkProtocolNumber
transProto tcpip.TransportProtocolNumber
flags ports.FlagCounter
mu sync.RWMutex `state:"nosave"`
// endpoints stores the transport endpoints in the order in which they
// were bound. This is required for UDP SO_REUSEADDR.
//
// +checklocks:mu
endpoints []TransportEndpoint
flags ports.FlagCounter
}
func (ep *multiPortEndpoint) transportEndpoints() []TransportEndpoint {
@ -362,13 +369,16 @@ func reciprocalScale(val, n uint32) uint32 {
// selectEndpoint calculates a hash of destination and source addresses and
// ports then uses it to select a socket. In this case, all packets from one
// address will be sent to same endpoint.
func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32) TransportEndpoint {
if len(mpep.endpoints) == 1 {
return mpep.endpoints[0]
func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID, seed uint32) TransportEndpoint {
ep.mu.RLock()
defer ep.mu.RUnlock()
if len(ep.endpoints) == 1 {
return ep.endpoints[0]
}
if mpep.flags.SharedFlags().ToFlags().Effective().MostRecent {
return mpep.endpoints[len(mpep.endpoints)-1]
if ep.flags.SharedFlags().ToFlags().Effective().MostRecent {
return ep.endpoints[len(ep.endpoints)-1]
}
payload := []byte{
@ -384,8 +394,8 @@ func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32
h.Write([]byte(id.RemoteAddress))
hash := h.Sum32()
idx := reciprocalScale(hash, uint32(len(mpep.endpoints)))
return mpep.endpoints[idx]
idx := reciprocalScale(hash, uint32(len(ep.endpoints)))
return ep.endpoints[idx]
}
func (ep *multiPortEndpoint) handlePacketAll(id TransportEndpointID, pkt *PacketBuffer) {
@ -657,7 +667,7 @@ func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolN
}
}
ep := selectEndpoint(id, mpep, epsByNIC.seed)
ep := mpep.selectEndpoint(id, epsByNIC.seed)
epsByNIC.mu.RUnlock()
return ep
}