From 5b4c20e1b83fad0691469a6b8a11a86f3453838a Mon Sep 17 00:00:00 2001 From: Zhaozhong Ni Date: Wed, 16 May 2018 14:14:28 -0700 Subject: [PATCH] netstack: make TCP endpoint closed and error state cleanup work synchronous. So that when saving TCP endpoint in these states, there are no pending or background activities. Also lift tcp network save rejection error to tcpip package. PiperOrigin-RevId: 196886839 Change-Id: I0fe73750f2743ec7e62d139eb2cec758c5dd6698 --- pkg/tcpip/tcpip.go | 11 +++ pkg/tcpip/transport/tcp/accept.go | 8 ++- pkg/tcpip/transport/tcp/connect.go | 84 +++++++++++------------ pkg/tcpip/transport/tcp/endpoint.go | 33 ++++----- pkg/tcpip/transport/tcp/endpoint_state.go | 22 ++---- pkg/tcpip/transport/tcp/segment.go | 2 +- 6 files changed, 78 insertions(+), 82 deletions(-) diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index c27c0dd89..707fda4d2 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -84,6 +84,17 @@ var ( errSubnetAddressMasked = errors.New("subnet address has bits set outside the mask") ) +// ErrSaveRejection indicates a failed save due to unsupported networking state. +// This type of error is only used for save logic. +type ErrSaveRejection struct { + Err error +} + +// Error returns a sensible description of the save rejection error. +func (e ErrSaveRejection) Error() string { + return "save rejected due to unsupported networking state: " + e.Err.Error() +} + // A Clock provides the current time. // // Times returned by a Clock should always be used for application-visible diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index a71cb444f..ac213e310 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -349,13 +349,17 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { // to the endpoint. e.mu.Lock() e.state = stateClosed - e.mu.Unlock() // Notify waiters that the endpoint is shutdown. 
e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut) // Do cleanup if needed. - e.completeWorker() + e.completeWorkerLocked() + + if e.drainDone != nil { + close(e.drainDone) + } + e.mu.Unlock() }() e.mu.Lock() diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 698e2b440..4cc0d733d 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -690,7 +690,7 @@ func (e *endpoint) sendRaw(data buffer.View, flags byte, seq, ack seqnum.Value, return err } -func (e *endpoint) handleWrite() bool { +func (e *endpoint) handleWrite() *tcpip.Error { // Move packets from send queue to send list. The queue is accessible // from other goroutines and protected by the send mutex, while the send // list is only accessible from the handler goroutine, so it needs no @@ -714,47 +714,42 @@ func (e *endpoint) handleWrite() bool { // Push out any new packets. e.snd.sendData() - return true + return nil } -func (e *endpoint) handleClose() bool { +func (e *endpoint) handleClose() *tcpip.Error { // Drain the send queue. e.handleWrite() // Mark send side as closed. e.snd.closed = true - return true + return nil } -// resetConnection sends a RST segment and puts the endpoint in an error state -// with the given error code. -// This method must only be called from the protocol goroutine. -func (e *endpoint) resetConnection(err *tcpip.Error) { +// resetConnectionLocked sends a RST segment and puts the endpoint in an error +// state with the given error code. This method must only be called from the +// protocol goroutine. +func (e *endpoint) resetConnectionLocked(err *tcpip.Error) { e.sendRaw(nil, flagAck|flagRst, e.snd.sndUna, e.rcv.rcvNxt, 0) - e.mu.Lock() e.state = stateError e.hardError = err - e.mu.Unlock() } -// completeWorker is called by the worker goroutine when it's about to exit. It -// marks the worker as completed and performs cleanup work if requested by -// Close(). 
-func (e *endpoint) completeWorker() { - e.mu.Lock() - defer e.mu.Unlock() - +// completeWorkerLocked is called by the worker goroutine when it's about to +// exit. It marks the worker as completed and performs cleanup work if requested +// by Close(). +func (e *endpoint) completeWorkerLocked() { e.workerRunning = false if e.workerCleanup { - e.cleanup() + e.cleanupLocked() } } // handleSegments pulls segments from the queue and processes them. It returns -// true if the protocol loop should continue, false otherwise. -func (e *endpoint) handleSegments() bool { +// no error if the protocol loop should continue, an error otherwise. +func (e *endpoint) handleSegments() *tcpip.Error { checkRequeue := true for i := 0; i < maxSegmentsPerWake; i++ { s := e.segmentQueue.dequeue() @@ -775,11 +770,7 @@ func (e *endpoint) handleSegments() bool { // validated by checking their SEQ-fields." So // we only process it if it's acceptable. s.decRef() - e.mu.Lock() - e.state = stateError - e.hardError = tcpip.ErrConnectionReset - e.mu.Unlock() - return false + return tcpip.ErrConnectionReset } } else if s.flagIsSet(flagAck) { // Patch the window size in the segment according to the @@ -816,7 +807,7 @@ func (e *endpoint) handleSegments() bool { e.snd.sendAck() } - return true + return nil } // protocolMainLoop is the main loop of the TCP protocol. It runs in its own @@ -827,8 +818,9 @@ func (e *endpoint) protocolMainLoop(passive bool) *tcpip.Error { var closeWaker sleep.Waker defer func() { + // e.mu is expected to be held upon entering this section. 
e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut) - e.completeWorker() + e.completeWorkerLocked() if e.snd != nil { e.snd.resendTimer.cleanup() @@ -837,6 +829,12 @@ func (e *endpoint) protocolMainLoop(passive bool) *tcpip.Error { if closeTimer != nil { closeTimer.Stop() } + + if e.drainDone != nil { + close(e.drainDone) + } + + e.mu.Unlock() }() if !passive { @@ -855,12 +853,7 @@ func (e *endpoint) protocolMainLoop(passive bool) *tcpip.Error { e.mu.Lock() e.state = stateError e.hardError = err - drained := e.drainDone != nil - e.mu.Unlock() - if drained { - close(e.drainDone) - <-e.undrain - } + // Lock released in deferred statement. return err } @@ -894,7 +887,7 @@ func (e *endpoint) protocolMainLoop(passive bool) *tcpip.Error { // wakes up. funcs := []struct { w *sleep.Waker - f func() bool + f func() *tcpip.Error }{ { w: &e.sndWaker, @@ -910,24 +903,22 @@ func (e *endpoint) protocolMainLoop(passive bool) *tcpip.Error { }, { w: &closeWaker, - f: func() bool { - e.resetConnection(tcpip.ErrConnectionAborted) - return false + f: func() *tcpip.Error { + return tcpip.ErrConnectionAborted }, }, { w: &e.snd.resendWaker, - f: func() bool { + f: func() *tcpip.Error { if !e.snd.retransmitTimerExpired() { - e.resetConnection(tcpip.ErrTimeout) - return false + return tcpip.ErrTimeout } - return true + return nil }, }, { w: &e.notificationWaker, - f: func() bool { + f: func() *tcpip.Error { n := e.fetchNotifications() if n&notifyNonZeroReceiveWindow != 0 { e.rcv.nonZeroWindow() @@ -954,7 +945,7 @@ func (e *endpoint) protocolMainLoop(passive bool) *tcpip.Error { closeWaker.Assert() }) } - return true + return nil }, }, } @@ -971,7 +962,10 @@ func (e *endpoint) protocolMainLoop(passive bool) *tcpip.Error { e.workMu.Unlock() v, _ := s.Fetch(true) e.workMu.Lock() - if !funcs[v].f() { + if err := funcs[v].f(); err != nil { + e.mu.Lock() + e.resetConnectionLocked(err) + // Lock released in deferred statement. 
return nil } } @@ -979,7 +973,7 @@ func (e *endpoint) protocolMainLoop(passive bool) *tcpip.Error { // Mark endpoint as closed. e.mu.Lock() e.state = stateClosed - e.mu.Unlock() + // Lock released in deferred statement. return nil } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index f26b28632..3f87c4cac 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -94,7 +94,7 @@ type endpoint struct { state endpointState isPortReserved bool `state:"manual"` isRegistered bool - boundNICID tcpip.NICID + boundNICID tcpip.NICID `state:"manual"` route stack.Route `state:"manual"` v6only bool isConnectNotified bool @@ -118,7 +118,7 @@ type endpoint struct { // workerCleanup specifies if the worker goroutine must perform cleanup // before exitting. This can only be set to true when workerRunning is // also true, and they're both protected by the mutex. - workerCleanup bool + workerCleanup bool `state:"zerovalue"` // sendTSOk is used to indicate when the TS Option has been negotiated. // When sendTSOk is true every non-RST segment should carry a TS as per @@ -326,13 +326,7 @@ func (e *endpoint) Close() { // if we're connected, or stop accepting if we're listening. e.Shutdown(tcpip.ShutdownWrite | tcpip.ShutdownRead) - // While we hold the lock, determine if the cleanup should happen - // inline or if we should tell the worker (if any) to do the cleanup. e.mu.Lock() - worker := e.workerRunning - if worker { - e.workerCleanup = true - } // We always release ports inline so that they are immediately available // for reuse after Close() is called. If also registered, it means this @@ -348,29 +342,32 @@ func (e *endpoint) Close() { } } - e.mu.Unlock() - - // Now that we don't hold the lock anymore, either perform the local - // cleanup or kick the worker to make sure it knows it needs to cleanup. 
- if !worker { - e.cleanup() + // Either perform the local cleanup or kick the worker to make sure it + // knows it needs to cleanup. + if !e.workerRunning { + e.cleanupLocked() } else { + e.workerCleanup = true e.notifyProtocolGoroutine(notifyClose) } + + e.mu.Unlock() } -// cleanup frees all resources associated with the endpoint. It is called after -// Close() is called and the worker goroutine (if any) is done with its work. -func (e *endpoint) cleanup() { +// cleanupLocked frees all resources associated with the endpoint. It is called +// after Close() is called and the worker goroutine (if any) is done with its +// work. +func (e *endpoint) cleanupLocked() { // Close all endpoints that might have been accepted by TCP but not by // the client. if e.acceptedChan != nil { close(e.acceptedChan) for n := range e.acceptedChan { - n.resetConnection(tcpip.ErrConnectionAborted) + n.resetConnectionLocked(tcpip.ErrConnectionAborted) n.Close() } } + e.workerCleanup = false if e.isRegistered { e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id) diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index ebab7006d..212d2513a 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -12,17 +12,6 @@ import ( "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" ) -// ErrSaveRejection indicates a failed save due to unsupported tcp endpoint -// state. -type ErrSaveRejection struct { - Err error -} - -// Error returns a sensible description of the save rejection error. -func (e ErrSaveRejection) Error() string { - return "save rejected due to unsupported endpoint state: " + e.Err.Error() -} - func (e *endpoint) drainSegmentLocked() { // Drain only up to once. 
if e.drainDone != nil { @@ -48,8 +37,7 @@ func (e *endpoint) beforeSave() { defer e.mu.Unlock() switch e.state { - case stateInitial: - case stateBound: + case stateInitial, stateBound: case stateListen: if !e.segmentQueue.empty() { e.drainSegmentLocked() @@ -62,9 +50,11 @@ func (e *endpoint) beforeSave() { fallthrough case stateConnected: // FIXME - panic(ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%v, remote %v:%v", e.id.LocalAddress, e.id.LocalPort, e.id.RemoteAddress, e.id.RemotePort)}) - case stateClosed: - case stateError: + panic(tcpip.ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%v, remote %v:%v", e.id.LocalAddress, e.id.LocalPort, e.id.RemoteAddress, e.id.RemotePort)}) + case stateClosed, stateError: + if e.workerRunning { + panic(fmt.Sprintf("endpoint still has worker running in closed or error state")) + } default: panic(fmt.Sprintf("endpoint in unknown state %v", e.state)) } diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index c742fc394..07e4bfd73 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -30,7 +30,7 @@ type segment struct { segmentEntry refCnt int32 id stack.TransportEndpointID - route stack.Route + route stack.Route `state:"manual"` data buffer.VectorisedView // views is used as buffer for data when its length is large // enough to store a VectorisedView.