Ensure that flipcall.Endpoint.Shutdown() shuts down inactive peers.

PiperOrigin-RevId: 267022978
This commit is contained in:
Jamie Liu 2019-09-03 15:09:34 -07:00 committed by gVisor bot
parent 648170f527
commit eb94066ef2
6 changed files with 261 additions and 76 deletions

View File

@ -121,7 +121,16 @@ func (ep *Endpoint) enterFutexWait() error {
}
func (ep *Endpoint) exitFutexWait() {
atomic.AddInt32(&ep.ctrl.state, -epsBlocked)
switch eps := atomic.AddInt32(&ep.ctrl.state, -epsBlocked); eps {
case 0:
return
case epsShutdown:
// ep.ctrlShutdown() was called while we were blocked, so we are
// repsonsible for indicating connection shutdown.
ep.shutdownConn()
default:
panic(fmt.Sprintf("invalid flipcall.Endpoint.ctrl.state after flipcall.Endpoint.exitFutexWait(): %v", eps+epsBlocked))
}
}
func (ep *Endpoint) ctrlShutdown() {
@ -142,5 +151,25 @@ func (ep *Endpoint) ctrlShutdown() {
break
}
}
} else {
// There is no blocked thread, so we are responsible for indicating
// connection shutdown.
ep.shutdownConn()
}
}
func (ep *Endpoint) shutdownConn() {
switch cs := atomic.SwapUint32(ep.connState(), csShutdown); cs {
case ep.activeState:
if err := ep.futexWakeConnState(1); err != nil {
log.Warningf("failed to FUTEX_WAKE peer Endpoint for shutdown: %v", err)
}
case ep.inactiveState:
// The peer is currently active and will detect shutdown when it tries
// to update the connection state.
case csShutdown:
// The peer also called Endpoint.Shutdown().
default:
log.Warningf("unexpected connection state before Endpoint.shutdownConn(): %v", cs)
}
}

View File

@ -42,11 +42,6 @@ type Endpoint struct {
// dataCap is immutable.
dataCap uint32
// shutdown is non-zero if Endpoint.Shutdown() has been called, or if the
// Endpoint has acknowledged shutdown initiated by the peer. shutdown is
// accessed using atomic memory operations.
shutdown uint32
// activeState is csClientActive if this is a client Endpoint and
// csServerActive if this is a server Endpoint.
activeState uint32
@ -55,9 +50,27 @@ type Endpoint struct {
// csClientActive if this is a server Endpoint.
inactiveState uint32
// shutdown is non-zero if Endpoint.Shutdown() has been called, or if the
// Endpoint has acknowledged shutdown initiated by the peer. shutdown is
// accessed using atomic memory operations.
shutdown uint32
ctrl endpointControlImpl
}
// EndpointSide indicates which side of a connection an Endpoint belongs to.
type EndpointSide int
const (
// ClientSide indicates that an Endpoint is a client (initially-active;
// first method call should be Connect).
ClientSide EndpointSide = iota
// ServerSide indicates that an Endpoint is a server (initially-inactive;
// first method call should be RecvFirst.)
ServerSide
)
// Init must be called on zero-value Endpoints before first use. If it
// succeeds, ep.Destroy() must be called once the Endpoint is no longer in use.
//
@ -65,7 +78,17 @@ type Endpoint struct {
// Endpoint. FD may differ between Endpoints if they are in different
// processes, but must represent the same file. The packet window must
// initially be filled with zero bytes.
func (ep *Endpoint) Init(pwd PacketWindowDescriptor, opts ...EndpointOption) error {
func (ep *Endpoint) Init(side EndpointSide, pwd PacketWindowDescriptor, opts ...EndpointOption) error {
switch side {
case ClientSide:
ep.activeState = csClientActive
ep.inactiveState = csServerActive
case ServerSide:
ep.activeState = csServerActive
ep.inactiveState = csClientActive
default:
return fmt.Errorf("invalid EndpointSide: %v", side)
}
if pwd.Length < pageSize {
return fmt.Errorf("packet window size (%d) less than minimum (%d)", pwd.Length, pageSize)
}
@ -78,9 +101,6 @@ func (ep *Endpoint) Init(pwd PacketWindowDescriptor, opts ...EndpointOption) err
}
ep.packet = m
ep.dataCap = uint32(pwd.Length) - uint32(PacketHeaderBytes)
// These will be overwritten by ep.Connect() for client Endpoints.
ep.activeState = csServerActive
ep.inactiveState = csClientActive
if err := ep.ctrlInit(opts...); err != nil {
ep.unmapPacket()
return err
@ -90,9 +110,9 @@ func (ep *Endpoint) Init(pwd PacketWindowDescriptor, opts ...EndpointOption) err
// NewEndpoint is a convenience function that returns an initialized Endpoint
// allocated on the heap.
func NewEndpoint(pwd PacketWindowDescriptor, opts ...EndpointOption) (*Endpoint, error) {
func NewEndpoint(side EndpointSide, pwd PacketWindowDescriptor, opts ...EndpointOption) (*Endpoint, error) {
var ep Endpoint
if err := ep.Init(pwd, opts...); err != nil {
if err := ep.Init(side, pwd, opts...); err != nil {
return nil, err
}
return &ep, nil
@ -115,9 +135,9 @@ func (ep *Endpoint) unmapPacket() {
}
// Shutdown causes concurrent and future calls to ep.Connect(), ep.SendRecv(),
// ep.RecvFirst(), and ep.SendLast() to unblock and return errors. It does not
// wait for concurrent calls to return. The effect of Shutdown on the peer
// Endpoint is unspecified. Successive calls to Shutdown have no effect.
// ep.RecvFirst(), and ep.SendLast(), as well as the same calls in the peer
// Endpoint, to unblock and return errors. It does not wait for concurrent
// calls to return. Successive calls to Shutdown have no effect.
//
// Shutdown is the only Endpoint method that may be called concurrently with
// other methods on the same Endpoint.
@ -152,24 +172,22 @@ const (
// The client is, by definition, initially active, so this must be 0.
csClientActive = 0
csServerActive = 1
csShutdown = 2
)
// Connect designates ep as a client Endpoint and blocks until the peer
// Endpoint has called Endpoint.RecvFirst().
// Connect blocks until the peer Endpoint has called Endpoint.RecvFirst().
//
// Preconditions: ep.Connect(), ep.RecvFirst(), ep.SendRecv(), and
// ep.SendLast() have never been called.
// Preconditions: ep is a client Endpoint. ep.Connect(), ep.RecvFirst(),
// ep.SendRecv(), and ep.SendLast() have never been called.
func (ep *Endpoint) Connect() error {
ep.activeState = csClientActive
ep.inactiveState = csServerActive
return ep.ctrlConnect()
}
// RecvFirst blocks until the peer Endpoint calls Endpoint.SendRecv(), then
// returns the datagram length specified by that call.
//
// Preconditions: ep.SendRecv(), ep.RecvFirst(), and ep.SendLast() have never
// been called.
// Preconditions: ep is a server Endpoint. ep.SendRecv(), ep.RecvFirst(), and
// ep.SendLast() have never been called.
func (ep *Endpoint) RecvFirst() (uint32, error) {
if err := ep.ctrlWaitFirst(); err != nil {
return 0, err

View File

@ -38,12 +38,12 @@ func Example() {
panic(err)
}
var clientEP Endpoint
if err := clientEP.Init(pwd); err != nil {
if err := clientEP.Init(ClientSide, pwd); err != nil {
panic(err)
}
defer clientEP.Destroy()
var serverEP Endpoint
if err := serverEP.Init(pwd); err != nil {
if err := serverEP.Init(ServerSide, pwd); err != nil {
panic(err)
}
defer serverEP.Destroy()

View File

@ -39,11 +39,11 @@ func newTestConnectionWithOptions(tb testing.TB, clientOpts, serverOpts []Endpoi
c.pwa.Destroy()
tb.Fatalf("PacketWindowAllocator.Allocate() failed: %v", err)
}
if err := c.clientEP.Init(pwd, clientOpts...); err != nil {
if err := c.clientEP.Init(ClientSide, pwd, clientOpts...); err != nil {
c.pwa.Destroy()
tb.Fatalf("failed to create client Endpoint: %v", err)
}
if err := c.serverEP.Init(pwd, serverOpts...); err != nil {
if err := c.serverEP.Init(ServerSide, pwd, serverOpts...); err != nil {
c.pwa.Destroy()
c.clientEP.Destroy()
tb.Fatalf("failed to create server Endpoint: %v", err)
@ -68,11 +68,13 @@ func testSendRecv(t *testing.T, c *testConnection) {
defer serverRun.Done()
t.Logf("server Endpoint waiting for packet 1")
if _, err := c.serverEP.RecvFirst(); err != nil {
t.Fatalf("server Endpoint.RecvFirst() failed: %v", err)
t.Errorf("server Endpoint.RecvFirst() failed: %v", err)
return
}
t.Logf("server Endpoint got packet 1, sending packet 2 and waiting for packet 3")
if _, err := c.serverEP.SendRecv(0); err != nil {
t.Fatalf("server Endpoint.SendRecv() failed: %v", err)
t.Errorf("server Endpoint.SendRecv() failed: %v", err)
return
}
t.Logf("server Endpoint got packet 3")
}()
@ -105,7 +107,30 @@ func TestSendRecv(t *testing.T) {
testSendRecv(t, c)
}
func testShutdownConnect(t *testing.T, c *testConnection) {
func testShutdownBeforeConnect(t *testing.T, c *testConnection, remoteShutdown bool) {
if remoteShutdown {
c.serverEP.Shutdown()
} else {
c.clientEP.Shutdown()
}
if err := c.clientEP.Connect(); err == nil {
t.Errorf("client Endpoint.Connect() succeeded unexpectedly")
}
}
func TestShutdownBeforeConnectLocal(t *testing.T) {
c := newTestConnection(t)
defer c.destroy()
testShutdownBeforeConnect(t, c, false)
}
func TestShutdownBeforeConnectRemote(t *testing.T) {
c := newTestConnection(t)
defer c.destroy()
testShutdownBeforeConnect(t, c, true)
}
func testShutdownDuringConnect(t *testing.T, c *testConnection, remoteShutdown bool) {
var clientRun sync.WaitGroup
clientRun.Add(1)
go func() {
@ -115,44 +140,86 @@ func testShutdownConnect(t *testing.T, c *testConnection) {
}
}()
time.Sleep(time.Second) // to allow c.clientEP.Connect() to block
c.clientEP.Shutdown()
if remoteShutdown {
c.serverEP.Shutdown()
} else {
c.clientEP.Shutdown()
}
clientRun.Wait()
}
func TestShutdownConnect(t *testing.T) {
func TestShutdownDuringConnectLocal(t *testing.T) {
c := newTestConnection(t)
defer c.destroy()
testShutdownConnect(t, c)
testShutdownDuringConnect(t, c, false)
}
func testShutdownRecvFirstBeforeConnect(t *testing.T, c *testConnection) {
var serverRun sync.WaitGroup
serverRun.Add(1)
go func() {
defer serverRun.Done()
_, err := c.serverEP.RecvFirst()
if err == nil {
t.Errorf("server Endpoint.RecvFirst() succeeded unexpectedly")
}
}()
time.Sleep(time.Second) // to allow c.serverEP.RecvFirst() to block
c.serverEP.Shutdown()
serverRun.Wait()
}
func TestShutdownRecvFirstBeforeConnect(t *testing.T) {
func TestShutdownDuringConnectRemote(t *testing.T) {
c := newTestConnection(t)
defer c.destroy()
testShutdownRecvFirstBeforeConnect(t, c)
testShutdownDuringConnect(t, c, true)
}
func testShutdownRecvFirstAfterConnect(t *testing.T, c *testConnection) {
func testShutdownBeforeRecvFirst(t *testing.T, c *testConnection, remoteShutdown bool) {
if remoteShutdown {
c.clientEP.Shutdown()
} else {
c.serverEP.Shutdown()
}
if _, err := c.serverEP.RecvFirst(); err == nil {
t.Errorf("server Endpoint.RecvFirst() succeeded unexpectedly")
}
}
func TestShutdownBeforeRecvFirstLocal(t *testing.T) {
c := newTestConnection(t)
defer c.destroy()
testShutdownBeforeRecvFirst(t, c, false)
}
func TestShutdownBeforeRecvFirstRemote(t *testing.T) {
c := newTestConnection(t)
defer c.destroy()
testShutdownBeforeRecvFirst(t, c, true)
}
func testShutdownDuringRecvFirstBeforeConnect(t *testing.T, c *testConnection, remoteShutdown bool) {
var serverRun sync.WaitGroup
serverRun.Add(1)
go func() {
defer serverRun.Done()
if _, err := c.serverEP.RecvFirst(); err == nil {
t.Fatalf("server Endpoint.RecvFirst() succeeded unexpectedly")
t.Errorf("server Endpoint.RecvFirst() succeeded unexpectedly")
}
}()
time.Sleep(time.Second) // to allow c.serverEP.RecvFirst() to block
if remoteShutdown {
c.clientEP.Shutdown()
} else {
c.serverEP.Shutdown()
}
serverRun.Wait()
}
func TestShutdownDuringRecvFirstBeforeConnectLocal(t *testing.T) {
c := newTestConnection(t)
defer c.destroy()
testShutdownDuringRecvFirstBeforeConnect(t, c, false)
}
func TestShutdownDuringRecvFirstBeforeConnectRemote(t *testing.T) {
c := newTestConnection(t)
defer c.destroy()
testShutdownDuringRecvFirstBeforeConnect(t, c, true)
}
func testShutdownDuringRecvFirstAfterConnect(t *testing.T, c *testConnection, remoteShutdown bool) {
var serverRun sync.WaitGroup
serverRun.Add(1)
go func() {
defer serverRun.Done()
if _, err := c.serverEP.RecvFirst(); err == nil {
t.Errorf("server Endpoint.RecvFirst() succeeded unexpectedly")
}
}()
defer func() {
@ -164,23 +231,75 @@ func testShutdownRecvFirstAfterConnect(t *testing.T, c *testConnection) {
if err := c.clientEP.Connect(); err != nil {
t.Fatalf("client Endpoint.Connect() failed: %v", err)
}
c.serverEP.Shutdown()
if remoteShutdown {
c.clientEP.Shutdown()
} else {
c.serverEP.Shutdown()
}
serverRun.Wait()
}
func TestShutdownRecvFirstAfterConnect(t *testing.T) {
func TestShutdownDuringRecvFirstAfterConnectLocal(t *testing.T) {
c := newTestConnection(t)
defer c.destroy()
testShutdownRecvFirstAfterConnect(t, c)
testShutdownDuringRecvFirstAfterConnect(t, c, false)
}
func testShutdownSendRecv(t *testing.T, c *testConnection) {
func TestShutdownDuringRecvFirstAfterConnectRemote(t *testing.T) {
c := newTestConnection(t)
defer c.destroy()
testShutdownDuringRecvFirstAfterConnect(t, c, true)
}
func testShutdownDuringClientSendRecv(t *testing.T, c *testConnection, remoteShutdown bool) {
var serverRun sync.WaitGroup
serverRun.Add(1)
go func() {
defer serverRun.Done()
if _, err := c.serverEP.RecvFirst(); err != nil {
t.Fatalf("server Endpoint.RecvFirst() failed: %v", err)
t.Errorf("server Endpoint.RecvFirst() failed: %v", err)
}
// At this point, the client must be blocked in c.clientEP.SendRecv().
if remoteShutdown {
c.serverEP.Shutdown()
} else {
c.clientEP.Shutdown()
}
}()
defer func() {
// Ensure that the server goroutine is cleaned up before
// c.serverEP.Destroy(), even if the test fails.
c.serverEP.Shutdown()
serverRun.Wait()
}()
if err := c.clientEP.Connect(); err != nil {
t.Fatalf("client Endpoint.Connect() failed: %v", err)
}
if _, err := c.clientEP.SendRecv(0); err == nil {
t.Errorf("client Endpoint.SendRecv() succeeded unexpectedly")
}
}
func TestShutdownDuringClientSendRecvLocal(t *testing.T) {
c := newTestConnection(t)
defer c.destroy()
testShutdownDuringClientSendRecv(t, c, false)
}
func TestShutdownDuringClientSendRecvRemote(t *testing.T) {
c := newTestConnection(t)
defer c.destroy()
testShutdownDuringClientSendRecv(t, c, true)
}
func testShutdownDuringServerSendRecv(t *testing.T, c *testConnection, remoteShutdown bool) {
var serverRun sync.WaitGroup
serverRun.Add(1)
go func() {
defer serverRun.Done()
if _, err := c.serverEP.RecvFirst(); err != nil {
t.Errorf("server Endpoint.RecvFirst() failed: %v", err)
return
}
if _, err := c.serverEP.SendRecv(0); err == nil {
t.Errorf("server Endpoint.SendRecv() succeeded unexpectedly")
@ -199,14 +318,24 @@ func testShutdownSendRecv(t *testing.T, c *testConnection) {
t.Fatalf("client Endpoint.SendRecv() failed: %v", err)
}
time.Sleep(time.Second) // to allow serverEP.SendRecv() to block
c.serverEP.Shutdown()
if remoteShutdown {
c.clientEP.Shutdown()
} else {
c.serverEP.Shutdown()
}
serverRun.Wait()
}
func TestShutdownSendRecv(t *testing.T) {
func TestShutdownDuringServerSendRecvLocal(t *testing.T) {
c := newTestConnection(t)
defer c.destroy()
testShutdownSendRecv(t, c)
testShutdownDuringServerSendRecv(t, c, false)
}
func TestShutdownDuringServerSendRecvRemote(t *testing.T) {
c := newTestConnection(t)
defer c.destroy()
testShutdownDuringServerSendRecv(t, c, true)
}
func benchmarkSendRecv(b *testing.B, c *testConnection) {
@ -218,15 +347,17 @@ func benchmarkSendRecv(b *testing.B, c *testConnection) {
return
}
if _, err := c.serverEP.RecvFirst(); err != nil {
b.Fatalf("server Endpoint.RecvFirst() failed: %v", err)
b.Errorf("server Endpoint.RecvFirst() failed: %v", err)
return
}
for i := 1; i < b.N; i++ {
if _, err := c.serverEP.SendRecv(0); err != nil {
b.Fatalf("server Endpoint.SendRecv() failed: %v", err)
b.Errorf("server Endpoint.SendRecv() failed: %v", err)
return
}
}
if err := c.serverEP.SendLast(0); err != nil {
b.Fatalf("server Endpoint.SendLast() failed: %v", err)
b.Errorf("server Endpoint.SendLast() failed: %v", err)
}
}()
defer func() {

View File

@ -19,15 +19,15 @@ import (
"unsafe"
)
// Packets consist of an 8-byte header followed by an arbitrarily-sized
// Packets consist of a 16-byte header followed by an arbitrarily-sized
// datagram. The header consists of:
//
// - A 4-byte native-endian connection state.
//
// - A 4-byte native-endian datagram length in bytes.
//
// - 8 reserved bytes.
const (
sizeofUint32 = unsafe.Sizeof(uint32(0))
// PacketHeaderBytes is the size of a flipcall packet header in bytes. The
// maximum datagram size supported by a flipcall connection is equal to the
// length of the packet window minus PacketHeaderBytes.
@ -35,7 +35,7 @@ const (
// PacketHeaderBytes is exported to support its use in constant
// expressions. Non-constant expressions may prefer to use
// PacketWindowLengthForDataCap().
PacketHeaderBytes = 2 * sizeofUint32
PacketHeaderBytes = 16
)
func (ep *Endpoint) connState() *uint32 {
@ -43,7 +43,7 @@ func (ep *Endpoint) connState() *uint32 {
}
func (ep *Endpoint) dataLen() *uint32 {
return (*uint32)((unsafe.Pointer)(ep.packet + sizeofUint32))
return (*uint32)((unsafe.Pointer)(ep.packet + 4))
}
// Data returns the datagram part of ep's packet window as a byte slice.

View File

@ -59,7 +59,12 @@ func (ep *Endpoint) futexConnect(req *ctrlHandshakeRequest) (ctrlHandshakeRespon
func (ep *Endpoint) futexSwitchToPeer() error {
// Update connection state to indicate that the peer should be active.
if !atomic.CompareAndSwapUint32(ep.connState(), ep.activeState, ep.inactiveState) {
return fmt.Errorf("unexpected connection state before FUTEX_WAKE: %v", atomic.LoadUint32(ep.connState()))
switch cs := atomic.LoadUint32(ep.connState()); cs {
case csShutdown:
return shutdownError{}
default:
return fmt.Errorf("unexpected connection state before FUTEX_WAKE: %v", cs)
}
}
// Wake the peer's Endpoint.futexSwitchFromPeer().
@ -75,16 +80,18 @@ func (ep *Endpoint) futexSwitchFromPeer() error {
case ep.activeState:
return nil
case ep.inactiveState:
// Continue to FUTEX_WAIT.
if ep.isShutdownLocally() {
return shutdownError{}
}
if err := ep.futexWaitConnState(ep.inactiveState); err != nil {
return fmt.Errorf("failed to FUTEX_WAIT for peer Endpoint: %v", err)
}
continue
case csShutdown:
return shutdownError{}
default:
return fmt.Errorf("unexpected connection state before FUTEX_WAIT: %v", cs)
}
if ep.isShutdownLocally() {
return shutdownError{}
}
if err := ep.futexWaitConnState(ep.inactiveState); err != nil {
return fmt.Errorf("failed to FUTEX_WAIT for peer Endpoint: %v", err)
}
}
}