diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index c86ee1c13..66ce10357 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -429,7 +429,7 @@ type Stack struct { // If not nil, then any new endpoints will have this probe function // invoked everytime they receive a TCP segment. - tcpProbeFunc TCPProbeFunc + tcpProbeFunc atomic.Value // TCPProbeFunc // clock is used to generate user-visible times. clock tcpip.Clock @@ -1795,18 +1795,17 @@ func (s *Stack) TransportProtocolInstance(num tcpip.TransportProtocolNumber) Tra // guarantee provided on which probe will be invoked. Ideally this should only // be called once per stack. func (s *Stack) AddTCPProbe(probe TCPProbeFunc) { - s.mu.Lock() - s.tcpProbeFunc = probe - s.mu.Unlock() + s.tcpProbeFunc.Store(probe) } // GetTCPProbe returns the TCPProbeFunc if installed with AddTCPProbe, nil // otherwise. func (s *Stack) GetTCPProbe() TCPProbeFunc { - s.mu.Lock() - p := s.tcpProbeFunc - s.mu.Unlock() - return p + p := s.tcpProbeFunc.Load() + if p == nil { + return nil + } + return p.(TCPProbeFunc) } // RemoveTCPProbe removes an installed TCP probe. @@ -1815,9 +1814,8 @@ func (s *Stack) GetTCPProbe() TCPProbeFunc { // have a probe attached. Endpoints already created will continue to invoke // TCP probe. func (s *Stack) RemoveTCPProbe() { - s.mu.Lock() - s.tcpProbeFunc = nil - s.mu.Unlock() + // This must be TCPProbeFunc(nil) because atomic.Value.Store(nil) panics. + s.tcpProbeFunc.Store(TCPProbeFunc(nil)) } // JoinGroup joins the given multicast group on the given NIC.