Fix deadlock in UDP handleControlPacket path.
Fixing the sendto deadlock exposed yet another deadlock where a lock inversion occurs on the handleControlPacket path where e.mu and demuxer.epsByNIC.mu are acquired in reverse order from say when RegisterTransportEndpoint is called in endpoint.Connect(). This fix sidesteps the issue by just making endpoint.state an atomic and gets rid of the need to acquire e.mu in e.HandleControlPacket. PiperOrigin-RevId: 344939895
This commit is contained in:
parent
54ad145f2e
commit
79e2364933
|
@ -16,6 +16,7 @@ package udp
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
|
||||
"gvisor.dev/gvisor/pkg/sleep"
|
||||
"gvisor.dev/gvisor/pkg/sync"
|
||||
|
@ -95,9 +96,11 @@ type endpoint struct {
|
|||
rcvClosed bool
|
||||
|
||||
// The following fields are protected by the mu mutex.
|
||||
mu sync.RWMutex `state:"nosave"`
|
||||
sndBufSize int
|
||||
sndBufSizeMax int
|
||||
mu sync.RWMutex `state:"nosave"`
|
||||
sndBufSize int
|
||||
sndBufSizeMax int
|
||||
// state must be read/set using the EndpointState()/setEndpointState()
|
||||
// methods.
|
||||
state EndpointState
|
||||
route *stack.Route `state:"manual"`
|
||||
dstPort uint16
|
||||
|
@ -198,6 +201,20 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
|
|||
return e
|
||||
}
|
||||
|
||||
// setEndpointState updates the state of the endpoint to state atomically. This
|
||||
// method is unexported as the only place we should update the state is in this
|
||||
// package but we allow the state to be read freely without holding e.mu.
|
||||
//
|
||||
// Precondition: e.mu must be held to call this method.
|
||||
func (e *endpoint) setEndpointState(state EndpointState) {
|
||||
atomic.StoreUint32((*uint32)(&e.state), uint32(state))
|
||||
}
|
||||
|
||||
// EndpointState() returns the current state of the endpoint.
|
||||
func (e *endpoint) EndpointState() EndpointState {
|
||||
return EndpointState(atomic.LoadUint32((*uint32)(&e.state)))
|
||||
}
|
||||
|
||||
// UniqueID implements stack.TransportEndpoint.UniqueID.
|
||||
func (e *endpoint) UniqueID() uint64 {
|
||||
return e.uniqueID
|
||||
|
@ -223,7 +240,7 @@ func (e *endpoint) Close() {
|
|||
e.mu.Lock()
|
||||
e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite
|
||||
|
||||
switch e.state {
|
||||
switch e.EndpointState() {
|
||||
case StateBound, StateConnected:
|
||||
e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice)
|
||||
e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice, tcpip.FullAddress{})
|
||||
|
@ -252,7 +269,7 @@ func (e *endpoint) Close() {
|
|||
}
|
||||
|
||||
// Update the state.
|
||||
e.state = StateClosed
|
||||
e.setEndpointState(StateClosed)
|
||||
|
||||
e.mu.Unlock()
|
||||
|
||||
|
@ -316,7 +333,7 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess
|
|||
//
|
||||
// Returns true for retry if preparation should be retried.
|
||||
func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpip.Error) {
|
||||
switch e.state {
|
||||
switch e.EndpointState() {
|
||||
case StateInitial:
|
||||
case StateConnected:
|
||||
return false, nil
|
||||
|
@ -338,7 +355,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi
|
|||
|
||||
// The state changed when we released the shared locked and re-acquired
|
||||
// it in exclusive mode. Try again.
|
||||
if e.state != StateInitial {
|
||||
if e.EndpointState() != StateInitial {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
|
@ -453,7 +470,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
|
|||
e.mu.Lock()
|
||||
|
||||
// Recheck state after lock was re-acquired.
|
||||
if e.state != StateConnected {
|
||||
if e.EndpointState() != StateConnected {
|
||||
err = tcpip.ErrInvalidEndpointState
|
||||
}
|
||||
if err == nil && route.IsResolutionRequired() {
|
||||
|
@ -464,7 +481,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
|
|||
e.mu.RLock()
|
||||
|
||||
// Recheck state after lock was re-acquired.
|
||||
if e.state != StateConnected {
|
||||
if e.EndpointState() != StateConnected {
|
||||
err = tcpip.ErrInvalidEndpointState
|
||||
}
|
||||
return ch, err
|
||||
|
@ -934,7 +951,7 @@ func (e *endpoint) Disconnect() *tcpip.Error {
|
|||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
|
||||
if e.state != StateConnected {
|
||||
if e.EndpointState() != StateConnected {
|
||||
return nil
|
||||
}
|
||||
var (
|
||||
|
@ -957,7 +974,7 @@ func (e *endpoint) Disconnect() *tcpip.Error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
e.state = StateBound
|
||||
e.setEndpointState(StateBound)
|
||||
boundPortFlags = e.boundPortFlags
|
||||
} else {
|
||||
if e.ID.LocalPort != 0 {
|
||||
|
@ -965,7 +982,7 @@ func (e *endpoint) Disconnect() *tcpip.Error {
|
|||
e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, boundPortFlags, e.boundBindToDevice, tcpip.FullAddress{})
|
||||
e.boundPortFlags = ports.Flags{}
|
||||
}
|
||||
e.state = StateInitial
|
||||
e.setEndpointState(StateInitial)
|
||||
}
|
||||
|
||||
e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, boundPortFlags, e.boundBindToDevice)
|
||||
|
@ -990,7 +1007,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
|
|||
|
||||
nicID := addr.NIC
|
||||
var localPort uint16
|
||||
switch e.state {
|
||||
switch e.EndpointState() {
|
||||
case StateInitial:
|
||||
case StateBound, StateConnected:
|
||||
localPort = e.ID.LocalPort
|
||||
|
@ -1025,7 +1042,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
|
|||
RemoteAddress: r.RemoteAddress,
|
||||
}
|
||||
|
||||
if e.state == StateInitial {
|
||||
if e.EndpointState() == StateInitial {
|
||||
id.LocalAddress = r.LocalAddress
|
||||
}
|
||||
|
||||
|
@ -1059,7 +1076,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
|
|||
e.RegisterNICID = nicID
|
||||
e.effectiveNetProtos = netProtos
|
||||
|
||||
e.state = StateConnected
|
||||
e.setEndpointState(StateConnected)
|
||||
|
||||
e.rcvMu.Lock()
|
||||
e.rcvReady = true
|
||||
|
@ -1081,7 +1098,7 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
|
|||
|
||||
// A socket in the bound state can still receive multicast messages,
|
||||
// so we need to notify waiters on shutdown.
|
||||
if e.state != StateBound && e.state != StateConnected {
|
||||
if state := e.EndpointState(); state != StateBound && state != StateConnected {
|
||||
return tcpip.ErrNotConnected
|
||||
}
|
||||
|
||||
|
@ -1132,7 +1149,7 @@ func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.Networ
|
|||
func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
|
||||
// Don't allow binding once endpoint is not in the initial state
|
||||
// anymore.
|
||||
if e.state != StateInitial {
|
||||
if e.EndpointState() != StateInitial {
|
||||
return tcpip.ErrInvalidEndpointState
|
||||
}
|
||||
|
||||
|
@ -1176,7 +1193,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
|
|||
e.effectiveNetProtos = netProtos
|
||||
|
||||
// Mark endpoint as bound.
|
||||
e.state = StateBound
|
||||
e.setEndpointState(StateBound)
|
||||
|
||||
e.rcvMu.Lock()
|
||||
e.rcvReady = true
|
||||
|
@ -1208,7 +1225,7 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
|
|||
defer e.mu.RUnlock()
|
||||
|
||||
addr := e.ID.LocalAddress
|
||||
if e.state == StateConnected {
|
||||
if e.EndpointState() == StateConnected {
|
||||
addr = e.route.LocalAddress
|
||||
}
|
||||
|
||||
|
@ -1224,7 +1241,7 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
|
|||
e.mu.RLock()
|
||||
defer e.mu.RUnlock()
|
||||
|
||||
if e.state != StateConnected {
|
||||
if e.EndpointState() != StateConnected {
|
||||
return tcpip.FullAddress{}, tcpip.ErrNotConnected
|
||||
}
|
||||
|
||||
|
@ -1356,25 +1373,20 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
|
|||
// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
|
||||
func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) {
|
||||
if typ == stack.ControlPortUnreachable {
|
||||
e.mu.RLock()
|
||||
if e.state == StateConnected {
|
||||
if e.EndpointState() == StateConnected {
|
||||
e.lastErrorMu.Lock()
|
||||
e.lastError = tcpip.ErrConnectionRefused
|
||||
e.lastErrorMu.Unlock()
|
||||
e.mu.RUnlock()
|
||||
|
||||
e.waiterQueue.Notify(waiter.EventErr)
|
||||
return
|
||||
}
|
||||
e.mu.RUnlock()
|
||||
}
|
||||
}
|
||||
|
||||
// State implements tcpip.Endpoint.State.
|
||||
func (e *endpoint) State() uint32 {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
return uint32(e.state)
|
||||
return uint32(e.EndpointState())
|
||||
}
|
||||
|
||||
// Info returns a copy of the endpoint info.
|
||||
|
|
|
@ -98,7 +98,8 @@ func (e *endpoint) Resume(s *stack.Stack) {
|
|||
}
|
||||
}
|
||||
|
||||
if e.state != StateBound && e.state != StateConnected {
|
||||
state := e.EndpointState()
|
||||
if state != StateBound && state != StateConnected {
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -113,7 +114,7 @@ func (e *endpoint) Resume(s *stack.Stack) {
|
|||
}
|
||||
|
||||
var err *tcpip.Error
|
||||
if e.state == StateConnected {
|
||||
if state == StateConnected {
|
||||
e.route, err = e.stack.FindRoute(e.RegisterNICID, e.ID.LocalAddress, e.ID.RemoteAddress, netProto, e.ops.GetMulticastLoop())
|
||||
if err != nil {
|
||||
panic(err)
|
||||
|
|
|
@ -375,8 +375,6 @@ TEST_P(UdpSocketTest, BindInUse) {
|
|||
}
|
||||
|
||||
TEST_P(UdpSocketTest, ConnectWriteToInvalidPort) {
|
||||
ASSERT_NO_ERRNO(BindLoopback());
|
||||
|
||||
// Discover a free unused port by creating a new UDP socket, binding it
|
||||
// recording the just bound port and closing it. This is not guaranteed as it
|
||||
// can still race with other port UDP sockets trying to bind a port at the
|
||||
|
@ -410,6 +408,35 @@ TEST_P(UdpSocketTest, ConnectWriteToInvalidPort) {
|
|||
ASSERT_EQ(optlen, sizeof(err));
|
||||
}
|
||||
|
||||
TEST_P(UdpSocketTest, ConnectSimultaneousWriteToInvalidPort) {
|
||||
// Discover a free unused port by creating a new UDP socket, binding it
|
||||
// recording the just bound port and closing it. This is not guaranteed as it
|
||||
// can still race with other port UDP sockets trying to bind a port at the
|
||||
// same time.
|
||||
struct sockaddr_storage addr_storage = InetLoopbackAddr();
|
||||
socklen_t addrlen = sizeof(addr_storage);
|
||||
struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage);
|
||||
FileDescriptor s =
|
||||
ASSERT_NO_ERRNO_AND_VALUE(Socket(GetFamily(), SOCK_DGRAM, IPPROTO_UDP));
|
||||
ASSERT_THAT(bind(s.get(), addr, addrlen), SyscallSucceeds());
|
||||
ASSERT_THAT(getsockname(s.get(), addr, &addrlen), SyscallSucceeds());
|
||||
EXPECT_EQ(addrlen, addrlen_);
|
||||
EXPECT_NE(*Port(&addr_storage), 0);
|
||||
ASSERT_THAT(close(s.release()), SyscallSucceeds());
|
||||
|
||||
// Now connect to the port that we just released.
|
||||
ScopedThread t([&] {
|
||||
ASSERT_THAT(connect(sock_.get(), addr, addrlen_), SyscallSucceeds());
|
||||
});
|
||||
|
||||
char buf[512];
|
||||
RandomizeBuffer(buf, sizeof(buf));
|
||||
// Send from sock_ to an unbound port.
|
||||
ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, addr, addrlen_),
|
||||
SyscallSucceedsWithValue(sizeof(buf)));
|
||||
t.Join();
|
||||
}
|
||||
|
||||
TEST_P(UdpSocketTest, ReceiveAfterConnect) {
|
||||
ASSERT_NO_ERRNO(BindLoopback());
|
||||
ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds());
|
||||
|
|
Loading…
Reference in New Issue