diff --git a/pkg/flipcall/ctrl_futex.go b/pkg/flipcall/ctrl_futex.go index 865b6f640..d59159912 100644 --- a/pkg/flipcall/ctrl_futex.go +++ b/pkg/flipcall/ctrl_futex.go @@ -121,7 +121,16 @@ func (ep *Endpoint) enterFutexWait() error { } func (ep *Endpoint) exitFutexWait() { - atomic.AddInt32(&ep.ctrl.state, -epsBlocked) + switch eps := atomic.AddInt32(&ep.ctrl.state, -epsBlocked); eps { + case 0: + return + case epsShutdown: + // ep.ctrlShutdown() was called while we were blocked, so we are + // responsible for indicating connection shutdown. + ep.shutdownConn() + default: + panic(fmt.Sprintf("invalid flipcall.Endpoint.ctrl.state after flipcall.Endpoint.exitFutexWait(): %v", eps+epsBlocked)) + } } func (ep *Endpoint) ctrlShutdown() { @@ -142,5 +151,25 @@ func (ep *Endpoint) ctrlShutdown() { break } } + } else { + // There is no blocked thread, so we are responsible for indicating + // connection shutdown. + ep.shutdownConn() + } +} + +func (ep *Endpoint) shutdownConn() { + switch cs := atomic.SwapUint32(ep.connState(), csShutdown); cs { + case ep.activeState: + if err := ep.futexWakeConnState(1); err != nil { + log.Warningf("failed to FUTEX_WAKE peer Endpoint for shutdown: %v", err) + } + case ep.inactiveState: + // The peer is currently active and will detect shutdown when it tries + // to update the connection state. + case csShutdown: + // The peer also called Endpoint.Shutdown(). + default: + log.Warningf("unexpected connection state before Endpoint.shutdownConn(): %v", cs) } } diff --git a/pkg/flipcall/flipcall.go b/pkg/flipcall/flipcall.go index 5c9212c33..991018684 100644 --- a/pkg/flipcall/flipcall.go +++ b/pkg/flipcall/flipcall.go @@ -42,11 +42,6 @@ type Endpoint struct { // dataCap is immutable. dataCap uint32 - // shutdown is non-zero if Endpoint.Shutdown() has been called, or if the - // Endpoint has acknowledged shutdown initiated by the peer. shutdown is - // accessed using atomic memory operations. 
- shutdown uint32 - // activeState is csClientActive if this is a client Endpoint and // csServerActive if this is a server Endpoint. activeState uint32 @@ -55,9 +50,27 @@ type Endpoint struct { // csClientActive if this is a server Endpoint. inactiveState uint32 + // shutdown is non-zero if Endpoint.Shutdown() has been called, or if the + // Endpoint has acknowledged shutdown initiated by the peer. shutdown is + // accessed using atomic memory operations. + shutdown uint32 + ctrl endpointControlImpl } +// EndpointSide indicates which side of a connection an Endpoint belongs to. +type EndpointSide int + +const ( + // ClientSide indicates that an Endpoint is a client (initially-active; + // first method call should be Connect). + ClientSide EndpointSide = iota + + // ServerSide indicates that an Endpoint is a server (initially-inactive; + // first method call should be RecvFirst). + ServerSide +) + // Init must be called on zero-value Endpoints before first use. If it // succeeds, ep.Destroy() must be called once the Endpoint is no longer in use. // @@ -65,7 +78,17 @@ type Endpoint struct { // Endpoint. FD may differ between Endpoints if they are in different // processes, but must represent the same file. The packet window must // initially be filled with zero bytes. 
-func (ep *Endpoint) Init(pwd PacketWindowDescriptor, opts ...EndpointOption) error { +func (ep *Endpoint) Init(side EndpointSide, pwd PacketWindowDescriptor, opts ...EndpointOption) error { + switch side { + case ClientSide: + ep.activeState = csClientActive + ep.inactiveState = csServerActive + case ServerSide: + ep.activeState = csServerActive + ep.inactiveState = csClientActive + default: + return fmt.Errorf("invalid EndpointSide: %v", side) + } if pwd.Length < pageSize { return fmt.Errorf("packet window size (%d) less than minimum (%d)", pwd.Length, pageSize) } @@ -78,9 +101,6 @@ func (ep *Endpoint) Init(pwd PacketWindowDescriptor, opts ...EndpointOption) err } ep.packet = m ep.dataCap = uint32(pwd.Length) - uint32(PacketHeaderBytes) - // These will be overwritten by ep.Connect() for client Endpoints. - ep.activeState = csServerActive - ep.inactiveState = csClientActive if err := ep.ctrlInit(opts...); err != nil { ep.unmapPacket() return err @@ -90,9 +110,9 @@ func (ep *Endpoint) Init(pwd PacketWindowDescriptor, opts ...EndpointOption) err // NewEndpoint is a convenience function that returns an initialized Endpoint // allocated on the heap. -func NewEndpoint(pwd PacketWindowDescriptor, opts ...EndpointOption) (*Endpoint, error) { +func NewEndpoint(side EndpointSide, pwd PacketWindowDescriptor, opts ...EndpointOption) (*Endpoint, error) { var ep Endpoint - if err := ep.Init(pwd, opts...); err != nil { + if err := ep.Init(side, pwd, opts...); err != nil { return nil, err } return &ep, nil @@ -115,9 +135,9 @@ func (ep *Endpoint) unmapPacket() { } // Shutdown causes concurrent and future calls to ep.Connect(), ep.SendRecv(), -// ep.RecvFirst(), and ep.SendLast() to unblock and return errors. It does not -// wait for concurrent calls to return. The effect of Shutdown on the peer -// Endpoint is unspecified. Successive calls to Shutdown have no effect. 
+// ep.RecvFirst(), and ep.SendLast(), as well as the same calls in the peer +// Endpoint, to unblock and return errors. It does not wait for concurrent +// calls to return. Successive calls to Shutdown have no effect. // // Shutdown is the only Endpoint method that may be called concurrently with // other methods on the same Endpoint. @@ -152,24 +172,22 @@ const ( // The client is, by definition, initially active, so this must be 0. csClientActive = 0 csServerActive = 1 + csShutdown = 2 ) -// Connect designates ep as a client Endpoint and blocks until the peer -// Endpoint has called Endpoint.RecvFirst(). +// Connect blocks until the peer Endpoint has called Endpoint.RecvFirst(). // -// Preconditions: ep.Connect(), ep.RecvFirst(), ep.SendRecv(), and -// ep.SendLast() have never been called. +// Preconditions: ep is a client Endpoint. ep.Connect(), ep.RecvFirst(), +// ep.SendRecv(), and ep.SendLast() have never been called. func (ep *Endpoint) Connect() error { - ep.activeState = csClientActive - ep.inactiveState = csServerActive return ep.ctrlConnect() } // RecvFirst blocks until the peer Endpoint calls Endpoint.SendRecv(), then // returns the datagram length specified by that call. // -// Preconditions: ep.SendRecv(), ep.RecvFirst(), and ep.SendLast() have never -// been called. +// Preconditions: ep is a server Endpoint. ep.SendRecv(), ep.RecvFirst(), and +// ep.SendLast() have never been called. 
func (ep *Endpoint) RecvFirst() (uint32, error) { if err := ep.ctrlWaitFirst(); err != nil { return 0, err diff --git a/pkg/flipcall/flipcall_example_test.go b/pkg/flipcall/flipcall_example_test.go index edb6a8bef..8d88b845d 100644 --- a/pkg/flipcall/flipcall_example_test.go +++ b/pkg/flipcall/flipcall_example_test.go @@ -38,12 +38,12 @@ func Example() { panic(err) } var clientEP Endpoint - if err := clientEP.Init(pwd); err != nil { + if err := clientEP.Init(ClientSide, pwd); err != nil { panic(err) } defer clientEP.Destroy() var serverEP Endpoint - if err := serverEP.Init(pwd); err != nil { + if err := serverEP.Init(ServerSide, pwd); err != nil { panic(err) } defer serverEP.Destroy() diff --git a/pkg/flipcall/flipcall_test.go b/pkg/flipcall/flipcall_test.go index da9d736ab..435e4eeae 100644 --- a/pkg/flipcall/flipcall_test.go +++ b/pkg/flipcall/flipcall_test.go @@ -39,11 +39,11 @@ func newTestConnectionWithOptions(tb testing.TB, clientOpts, serverOpts []Endpoi c.pwa.Destroy() tb.Fatalf("PacketWindowAllocator.Allocate() failed: %v", err) } - if err := c.clientEP.Init(pwd, clientOpts...); err != nil { + if err := c.clientEP.Init(ClientSide, pwd, clientOpts...); err != nil { c.pwa.Destroy() tb.Fatalf("failed to create client Endpoint: %v", err) } - if err := c.serverEP.Init(pwd, serverOpts...); err != nil { + if err := c.serverEP.Init(ServerSide, pwd, serverOpts...); err != nil { c.pwa.Destroy() c.clientEP.Destroy() tb.Fatalf("failed to create server Endpoint: %v", err) @@ -68,11 +68,13 @@ func testSendRecv(t *testing.T, c *testConnection) { defer serverRun.Done() t.Logf("server Endpoint waiting for packet 1") if _, err := c.serverEP.RecvFirst(); err != nil { - t.Fatalf("server Endpoint.RecvFirst() failed: %v", err) + t.Errorf("server Endpoint.RecvFirst() failed: %v", err) + return } t.Logf("server Endpoint got packet 1, sending packet 2 and waiting for packet 3") if _, err := c.serverEP.SendRecv(0); err != nil { - t.Fatalf("server Endpoint.SendRecv() failed: %v", 
err) + t.Errorf("server Endpoint.SendRecv() failed: %v", err) + return } t.Logf("server Endpoint got packet 3") }() @@ -105,7 +107,30 @@ func TestSendRecv(t *testing.T) { testSendRecv(t, c) } -func testShutdownConnect(t *testing.T, c *testConnection) { +func testShutdownBeforeConnect(t *testing.T, c *testConnection, remoteShutdown bool) { + if remoteShutdown { + c.serverEP.Shutdown() + } else { + c.clientEP.Shutdown() + } + if err := c.clientEP.Connect(); err == nil { + t.Errorf("client Endpoint.Connect() succeeded unexpectedly") + } +} + +func TestShutdownBeforeConnectLocal(t *testing.T) { + c := newTestConnection(t) + defer c.destroy() + testShutdownBeforeConnect(t, c, false) +} + +func TestShutdownBeforeConnectRemote(t *testing.T) { + c := newTestConnection(t) + defer c.destroy() + testShutdownBeforeConnect(t, c, true) +} + +func testShutdownDuringConnect(t *testing.T, c *testConnection, remoteShutdown bool) { var clientRun sync.WaitGroup clientRun.Add(1) go func() { @@ -115,44 +140,86 @@ func testShutdownConnect(t *testing.T, c *testConnection) { } }() time.Sleep(time.Second) // to allow c.clientEP.Connect() to block - c.clientEP.Shutdown() + if remoteShutdown { + c.serverEP.Shutdown() + } else { + c.clientEP.Shutdown() + } clientRun.Wait() } -func TestShutdownConnect(t *testing.T) { +func TestShutdownDuringConnectLocal(t *testing.T) { c := newTestConnection(t) defer c.destroy() - testShutdownConnect(t, c) + testShutdownDuringConnect(t, c, false) } -func testShutdownRecvFirstBeforeConnect(t *testing.T, c *testConnection) { - var serverRun sync.WaitGroup - serverRun.Add(1) - go func() { - defer serverRun.Done() - _, err := c.serverEP.RecvFirst() - if err == nil { - t.Errorf("server Endpoint.RecvFirst() succeeded unexpectedly") - } - }() - time.Sleep(time.Second) // to allow c.serverEP.RecvFirst() to block - c.serverEP.Shutdown() - serverRun.Wait() -} - -func TestShutdownRecvFirstBeforeConnect(t *testing.T) { +func TestShutdownDuringConnectRemote(t *testing.T) { 
c := newTestConnection(t) defer c.destroy() - testShutdownRecvFirstBeforeConnect(t, c) + testShutdownDuringConnect(t, c, true) } -func testShutdownRecvFirstAfterConnect(t *testing.T, c *testConnection) { +func testShutdownBeforeRecvFirst(t *testing.T, c *testConnection, remoteShutdown bool) { + if remoteShutdown { + c.clientEP.Shutdown() + } else { + c.serverEP.Shutdown() + } + if _, err := c.serverEP.RecvFirst(); err == nil { + t.Errorf("server Endpoint.RecvFirst() succeeded unexpectedly") + } +} + +func TestShutdownBeforeRecvFirstLocal(t *testing.T) { + c := newTestConnection(t) + defer c.destroy() + testShutdownBeforeRecvFirst(t, c, false) +} + +func TestShutdownBeforeRecvFirstRemote(t *testing.T) { + c := newTestConnection(t) + defer c.destroy() + testShutdownBeforeRecvFirst(t, c, true) +} + +func testShutdownDuringRecvFirstBeforeConnect(t *testing.T, c *testConnection, remoteShutdown bool) { var serverRun sync.WaitGroup serverRun.Add(1) go func() { defer serverRun.Done() if _, err := c.serverEP.RecvFirst(); err == nil { - t.Fatalf("server Endpoint.RecvFirst() succeeded unexpectedly") + t.Errorf("server Endpoint.RecvFirst() succeeded unexpectedly") + } + }() + time.Sleep(time.Second) // to allow c.serverEP.RecvFirst() to block + if remoteShutdown { + c.clientEP.Shutdown() + } else { + c.serverEP.Shutdown() + } + serverRun.Wait() +} + +func TestShutdownDuringRecvFirstBeforeConnectLocal(t *testing.T) { + c := newTestConnection(t) + defer c.destroy() + testShutdownDuringRecvFirstBeforeConnect(t, c, false) +} + +func TestShutdownDuringRecvFirstBeforeConnectRemote(t *testing.T) { + c := newTestConnection(t) + defer c.destroy() + testShutdownDuringRecvFirstBeforeConnect(t, c, true) +} + +func testShutdownDuringRecvFirstAfterConnect(t *testing.T, c *testConnection, remoteShutdown bool) { + var serverRun sync.WaitGroup + serverRun.Add(1) + go func() { + defer serverRun.Done() + if _, err := c.serverEP.RecvFirst(); err == nil { + t.Errorf("server Endpoint.RecvFirst() 
succeeded unexpectedly") } }() defer func() { @@ -164,23 +231,75 @@ func testShutdownRecvFirstAfterConnect(t *testing.T, c *testConnection) { if err := c.clientEP.Connect(); err != nil { t.Fatalf("client Endpoint.Connect() failed: %v", err) } - c.serverEP.Shutdown() + if remoteShutdown { + c.clientEP.Shutdown() + } else { + c.serverEP.Shutdown() + } serverRun.Wait() } -func TestShutdownRecvFirstAfterConnect(t *testing.T) { +func TestShutdownDuringRecvFirstAfterConnectLocal(t *testing.T) { c := newTestConnection(t) defer c.destroy() - testShutdownRecvFirstAfterConnect(t, c) + testShutdownDuringRecvFirstAfterConnect(t, c, false) } -func testShutdownSendRecv(t *testing.T, c *testConnection) { +func TestShutdownDuringRecvFirstAfterConnectRemote(t *testing.T) { + c := newTestConnection(t) + defer c.destroy() + testShutdownDuringRecvFirstAfterConnect(t, c, true) +} + +func testShutdownDuringClientSendRecv(t *testing.T, c *testConnection, remoteShutdown bool) { var serverRun sync.WaitGroup serverRun.Add(1) go func() { defer serverRun.Done() if _, err := c.serverEP.RecvFirst(); err != nil { - t.Fatalf("server Endpoint.RecvFirst() failed: %v", err) + t.Errorf("server Endpoint.RecvFirst() failed: %v", err) + } + // At this point, the client must be blocked in c.clientEP.SendRecv(). + if remoteShutdown { + c.serverEP.Shutdown() + } else { + c.clientEP.Shutdown() + } + }() + defer func() { + // Ensure that the server goroutine is cleaned up before + // c.serverEP.Destroy(), even if the test fails. 
+ c.serverEP.Shutdown() + serverRun.Wait() + }() + if err := c.clientEP.Connect(); err != nil { + t.Fatalf("client Endpoint.Connect() failed: %v", err) + } + if _, err := c.clientEP.SendRecv(0); err == nil { + t.Errorf("client Endpoint.SendRecv() succeeded unexpectedly") + } +} + +func TestShutdownDuringClientSendRecvLocal(t *testing.T) { + c := newTestConnection(t) + defer c.destroy() + testShutdownDuringClientSendRecv(t, c, false) +} + +func TestShutdownDuringClientSendRecvRemote(t *testing.T) { + c := newTestConnection(t) + defer c.destroy() + testShutdownDuringClientSendRecv(t, c, true) +} + +func testShutdownDuringServerSendRecv(t *testing.T, c *testConnection, remoteShutdown bool) { + var serverRun sync.WaitGroup + serverRun.Add(1) + go func() { + defer serverRun.Done() + if _, err := c.serverEP.RecvFirst(); err != nil { + t.Errorf("server Endpoint.RecvFirst() failed: %v", err) + return } if _, err := c.serverEP.SendRecv(0); err == nil { t.Errorf("server Endpoint.SendRecv() succeeded unexpectedly") @@ -199,14 +318,24 @@ func testShutdownSendRecv(t *testing.T, c *testConnection) { t.Fatalf("client Endpoint.SendRecv() failed: %v", err) } time.Sleep(time.Second) // to allow serverEP.SendRecv() to block - c.serverEP.Shutdown() + if remoteShutdown { + c.clientEP.Shutdown() + } else { + c.serverEP.Shutdown() + } serverRun.Wait() } -func TestShutdownSendRecv(t *testing.T) { +func TestShutdownDuringServerSendRecvLocal(t *testing.T) { c := newTestConnection(t) defer c.destroy() - testShutdownSendRecv(t, c) + testShutdownDuringServerSendRecv(t, c, false) +} + +func TestShutdownDuringServerSendRecvRemote(t *testing.T) { + c := newTestConnection(t) + defer c.destroy() + testShutdownDuringServerSendRecv(t, c, true) } func benchmarkSendRecv(b *testing.B, c *testConnection) { @@ -218,15 +347,17 @@ func benchmarkSendRecv(b *testing.B, c *testConnection) { return } if _, err := c.serverEP.RecvFirst(); err != nil { - b.Fatalf("server Endpoint.RecvFirst() failed: %v", err) + 
b.Errorf("server Endpoint.RecvFirst() failed: %v", err) + return } for i := 1; i < b.N; i++ { if _, err := c.serverEP.SendRecv(0); err != nil { - b.Fatalf("server Endpoint.SendRecv() failed: %v", err) + b.Errorf("server Endpoint.SendRecv() failed: %v", err) + return } } if err := c.serverEP.SendLast(0); err != nil { - b.Fatalf("server Endpoint.SendLast() failed: %v", err) + b.Errorf("server Endpoint.SendLast() failed: %v", err) } }() defer func() { diff --git a/pkg/flipcall/flipcall_unsafe.go b/pkg/flipcall/flipcall_unsafe.go index 7c8977893..73e6eef29 100644 --- a/pkg/flipcall/flipcall_unsafe.go +++ b/pkg/flipcall/flipcall_unsafe.go @@ -19,15 +19,15 @@ import ( "unsafe" ) -// Packets consist of an 8-byte header followed by an arbitrarily-sized +// Packets consist of a 16-byte header followed by an arbitrarily-sized // datagram. The header consists of: // // - A 4-byte native-endian connection state. // // - A 4-byte native-endian datagram length in bytes. +// +// - 8 reserved bytes. const ( - sizeofUint32 = unsafe.Sizeof(uint32(0)) - // PacketHeaderBytes is the size of a flipcall packet header in bytes. The // maximum datagram size supported by a flipcall connection is equal to the // length of the packet window minus PacketHeaderBytes. @@ -35,7 +35,7 @@ const ( // PacketHeaderBytes is exported to support its use in constant // expressions. Non-constant expressions may prefer to use // PacketWindowLengthForDataCap(). - PacketHeaderBytes = 2 * sizeofUint32 + PacketHeaderBytes = 16 ) func (ep *Endpoint) connState() *uint32 { @@ -43,7 +43,7 @@ func (ep *Endpoint) connState() *uint32 { } func (ep *Endpoint) dataLen() *uint32 { - return (*uint32)((unsafe.Pointer)(ep.packet + sizeofUint32)) + return (*uint32)((unsafe.Pointer)(ep.packet + 4)) } // Data returns the datagram part of ep's packet window as a byte slice. 
diff --git a/pkg/flipcall/futex_linux.go b/pkg/flipcall/futex_linux.go index e7dd812b3..b127a2bbb 100644 --- a/pkg/flipcall/futex_linux.go +++ b/pkg/flipcall/futex_linux.go @@ -59,7 +59,12 @@ func (ep *Endpoint) futexConnect(req *ctrlHandshakeRequest) (ctrlHandshakeRespon func (ep *Endpoint) futexSwitchToPeer() error { // Update connection state to indicate that the peer should be active. if !atomic.CompareAndSwapUint32(ep.connState(), ep.activeState, ep.inactiveState) { - return fmt.Errorf("unexpected connection state before FUTEX_WAKE: %v", atomic.LoadUint32(ep.connState())) + switch cs := atomic.LoadUint32(ep.connState()); cs { + case csShutdown: + return shutdownError{} + default: + return fmt.Errorf("unexpected connection state before FUTEX_WAKE: %v", cs) + } } // Wake the peer's Endpoint.futexSwitchFromPeer(). @@ -75,16 +80,18 @@ func (ep *Endpoint) futexSwitchFromPeer() error { case ep.activeState: return nil case ep.inactiveState: - // Continue to FUTEX_WAIT. + if ep.isShutdownLocally() { + return shutdownError{} + } + if err := ep.futexWaitConnState(ep.inactiveState); err != nil { + return fmt.Errorf("failed to FUTEX_WAIT for peer Endpoint: %v", err) + } + continue + case csShutdown: + return shutdownError{} default: return fmt.Errorf("unexpected connection state before FUTEX_WAIT: %v", cs) } - if ep.isShutdownLocally() { - return shutdownError{} - } - if err := ep.futexWaitConnState(ep.inactiveState); err != nil { - return fmt.Errorf("failed to FUTEX_WAIT for peer Endpoint: %v", err) - } } }