diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index c27c0dd89..707fda4d2 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -84,6 +84,17 @@ var ( errSubnetAddressMasked = errors.New("subnet address has bits set outside the mask") ) +// ErrSaveRejection indicates a failed save due to unsupported networking state. +// This type of errors is only used for save logic. +type ErrSaveRejection struct { + Err error +} + +// Error returns a sensible description of the save rejection error. +func (e ErrSaveRejection) Error() string { + return "save rejected due to unsupported networking state: " + e.Err.Error() +} + // A Clock provides the current time. // // Times returned by a Clock should always be used for application-visible diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index a71cb444f..ac213e310 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -349,13 +349,17 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { // to the endpoint. e.mu.Lock() e.state = stateClosed - e.mu.Unlock() // Notify waiters that the endpoint is shutdown. e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut) // Do cleanup if needed. - e.completeWorker() + e.completeWorkerLocked() + + if e.drainDone != nil { + close(e.drainDone) + } + e.mu.Unlock() }() e.mu.Lock() diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 698e2b440..4cc0d733d 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -690,7 +690,7 @@ func (e *endpoint) sendRaw(data buffer.View, flags byte, seq, ack seqnum.Value, return err } -func (e *endpoint) handleWrite() bool { +func (e *endpoint) handleWrite() *tcpip.Error { // Move packets from send queue to send list. The queue is accessible // from other goroutines and protected by the send mutex, while the send // list is only accessible from the handler goroutine, so it needs no @@ -714,47 +714,42 @@ func (e *endpoint) handleWrite() bool { // Push out any new packets. e.snd.sendData() - return true + return nil } -func (e *endpoint) handleClose() bool { +func (e *endpoint) handleClose() *tcpip.Error { // Drain the send queue. e.handleWrite() // Mark send side as closed. e.snd.closed = true - return true + return nil } -// resetConnection sends a RST segment and puts the endpoint in an error state -// with the given error code. -// This method must only be called from the protocol goroutine. -func (e *endpoint) resetConnection(err *tcpip.Error) { +// resetConnectionLocked sends a RST segment and puts the endpoint in an error +// state with the given error code. This method must only be called from the +// protocol goroutine. +func (e *endpoint) resetConnectionLocked(err *tcpip.Error) { e.sendRaw(nil, flagAck|flagRst, e.snd.sndUna, e.rcv.rcvNxt, 0) - e.mu.Lock() e.state = stateError e.hardError = err - e.mu.Unlock() } -// completeWorker is called by the worker goroutine when it's about to exit. It -// marks the worker as completed and performs cleanup work if requested by -// Close(). -func (e *endpoint) completeWorker() { - e.mu.Lock() - defer e.mu.Unlock() - +// completeWorkerLocked is called by the worker goroutine when it's about to +// exit. It marks the worker as completed and performs cleanup work if requested +// by Close(). +func (e *endpoint) completeWorkerLocked() { e.workerRunning = false if e.workerCleanup { - e.cleanup() + e.cleanupLocked() } } // handleSegments pulls segments from the queue and processes them. It returns -// true if the protocol loop should continue, false otherwise. -func (e *endpoint) handleSegments() bool { +// no error if the protocol loop should continue, an error otherwise. +func (e *endpoint) handleSegments() *tcpip.Error { checkRequeue := true for i := 0; i < maxSegmentsPerWake; i++ { s := e.segmentQueue.dequeue() @@ -775,11 +770,7 @@ func (e *endpoint) handleSegments() bool { // validated by checking their SEQ-fields." So // we only process it if it's acceptable. s.decRef() - e.mu.Lock() - e.state = stateError - e.hardError = tcpip.ErrConnectionReset - e.mu.Unlock() - return false + return tcpip.ErrConnectionReset } } else if s.flagIsSet(flagAck) { // Patch the window size in the segment according to the @@ -816,7 +807,7 @@ func (e *endpoint) handleSegments() bool { e.snd.sendAck() } - return true + return nil } // protocolMainLoop is the main loop of the TCP protocol. It runs in its own @@ -827,8 +818,9 @@ func (e *endpoint) protocolMainLoop(passive bool) *tcpip.Error { var closeWaker sleep.Waker defer func() { + // e.mu is expected to be hold upon entering this section. e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut) - e.completeWorker() + e.completeWorkerLocked() if e.snd != nil { e.snd.resendTimer.cleanup() @@ -837,6 +829,12 @@ func (e *endpoint) protocolMainLoop(passive bool) *tcpip.Error { if closeTimer != nil { closeTimer.Stop() } + + if e.drainDone != nil { + close(e.drainDone) + } + + e.mu.Unlock() }() if !passive { @@ -855,12 +853,7 @@ func (e *endpoint) protocolMainLoop(passive bool) *tcpip.Error { e.mu.Lock() e.state = stateError e.hardError = err - drained := e.drainDone != nil - e.mu.Unlock() - if drained { - close(e.drainDone) - <-e.undrain - } + // Lock released in deferred statement. return err } @@ -894,7 +887,7 @@ func (e *endpoint) protocolMainLoop(passive bool) *tcpip.Error { // wakes up. funcs := []struct { w *sleep.Waker - f func() bool + f func() *tcpip.Error }{ { w: &e.sndWaker, @@ -910,24 +903,22 @@ func (e *endpoint) protocolMainLoop(passive bool) *tcpip.Error { }, { w: &closeWaker, - f: func() bool { - e.resetConnection(tcpip.ErrConnectionAborted) - return false + f: func() *tcpip.Error { + return tcpip.ErrConnectionAborted }, }, { w: &e.snd.resendWaker, - f: func() bool { + f: func() *tcpip.Error { if !e.snd.retransmitTimerExpired() { - e.resetConnection(tcpip.ErrTimeout) - return false + return tcpip.ErrTimeout } - return true + return nil }, }, { w: &e.notificationWaker, - f: func() bool { + f: func() *tcpip.Error { n := e.fetchNotifications() if n¬ifyNonZeroReceiveWindow != 0 { e.rcv.nonZeroWindow() @@ -954,7 +945,7 @@ func (e *endpoint) protocolMainLoop(passive bool) *tcpip.Error { closeWaker.Assert() }) } - return true + return nil }, }, } @@ -971,7 +962,10 @@ func (e *endpoint) protocolMainLoop(passive bool) *tcpip.Error { e.workMu.Unlock() v, _ := s.Fetch(true) e.workMu.Lock() - if !funcs[v].f() { + if err := funcs[v].f(); err != nil { + e.mu.Lock() + e.resetConnectionLocked(err) + // Lock released in deferred statement. return nil } } @@ -979,7 +973,7 @@ func (e *endpoint) protocolMainLoop(passive bool) *tcpip.Error { // Mark endpoint as closed. e.mu.Lock() e.state = stateClosed - e.mu.Unlock() + // Lock released in deferred statement. return nil } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index f26b28632..3f87c4cac 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -94,7 +94,7 @@ type endpoint struct { state endpointState isPortReserved bool `state:"manual"` isRegistered bool - boundNICID tcpip.NICID + boundNICID tcpip.NICID `state:"manual"` route stack.Route `state:"manual"` v6only bool isConnectNotified bool @@ -118,7 +118,7 @@ type endpoint struct { // workerCleanup specifies if the worker goroutine must perform cleanup // before exitting. This can only be set to true when workerRunning is // also true, and they're both protected by the mutex. - workerCleanup bool + workerCleanup bool `state:"zerovalue"` // sendTSOk is used to indicate when the TS Option has been negotiated. // When sendTSOk is true every non-RST segment should carry a TS as per @@ -326,13 +326,7 @@ func (e *endpoint) Close() { // if we're connected, or stop accepting if we're listening. e.Shutdown(tcpip.ShutdownWrite | tcpip.ShutdownRead) - // While we hold the lock, determine if the cleanup should happen - // inline or if we should tell the worker (if any) to do the cleanup. e.mu.Lock() - worker := e.workerRunning - if worker { - e.workerCleanup = true - } // We always release ports inline so that they are immediately available // for reuse after Close() is called. If also registered, it means this @@ -348,29 +342,32 @@ func (e *endpoint) Close() { } } - e.mu.Unlock() - - // Now that we don't hold the lock anymore, either perform the local - // cleanup or kick the worker to make sure it knows it needs to cleanup. - if !worker { - e.cleanup() + // Either perform the local cleanup or kick the worker to make sure it + // knows it needs to cleanup. + if !e.workerRunning { + e.cleanupLocked() } else { + e.workerCleanup = true e.notifyProtocolGoroutine(notifyClose) } + + e.mu.Unlock() } -// cleanup frees all resources associated with the endpoint. It is called after -// Close() is called and the worker goroutine (if any) is done with its work. -func (e *endpoint) cleanup() { +// cleanupLocked frees all resources associated with the endpoint. It is called +// after Close() is called and the worker goroutine (if any) is done with its +// work. +func (e *endpoint) cleanupLocked() { // Close all endpoints that might have been accepted by TCP but not by // the client. if e.acceptedChan != nil { close(e.acceptedChan) for n := range e.acceptedChan { - n.resetConnection(tcpip.ErrConnectionAborted) + n.resetConnectionLocked(tcpip.ErrConnectionAborted) n.Close() } } + e.workerCleanup = false if e.isRegistered { e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id) diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index ebab7006d..212d2513a 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -12,17 +12,6 @@ import ( "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" ) -// ErrSaveRejection indicates a failed save due to unsupported tcp endpoint -// state. -type ErrSaveRejection struct { - Err error -} - -// Error returns a sensible description of the save rejection error. -func (e ErrSaveRejection) Error() string { - return "save rejected due to unsupported endpoint state: " + e.Err.Error() -} - func (e *endpoint) drainSegmentLocked() { // Drain only up to once. if e.drainDone != nil { @@ -48,8 +37,7 @@ func (e *endpoint) beforeSave() { defer e.mu.Unlock() switch e.state { - case stateInitial: - case stateBound: + case stateInitial, stateBound: case stateListen: if !e.segmentQueue.empty() { e.drainSegmentLocked() @@ -62,9 +50,11 @@ func (e *endpoint) beforeSave() { fallthrough case stateConnected: // FIXME - panic(ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%v, remote %v:%v", e.id.LocalAddress, e.id.LocalPort, e.id.RemoteAddress, e.id.RemotePort)}) - case stateClosed: - case stateError: + panic(tcpip.ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%v, remote %v:%v", e.id.LocalAddress, e.id.LocalPort, e.id.RemoteAddress, e.id.RemotePort)}) + case stateClosed, stateError: + if e.workerRunning { + panic(fmt.Sprintf("endpoint still has worker running in closed or error state")) + } default: panic(fmt.Sprintf("endpoint in unknown state %v", e.state)) } diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index c742fc394..07e4bfd73 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -30,7 +30,7 @@ type segment struct { segmentEntry refCnt int32 id stack.TransportEndpointID - route stack.Route + route stack.Route `state:"manual"` data buffer.VectorisedView // views is used as buffer for data when its length is large // enough to store a VectorisedView.