412 lines
12 KiB
Go
412 lines
12 KiB
Go
// Copyright 2018 The gVisor Authors.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package tcp
|
|
|
|
import (
|
|
"fmt"
|
|
"sync"
|
|
"time"
|
|
|
|
"gvisor.dev/gvisor/pkg/tcpip"
|
|
"gvisor.dev/gvisor/pkg/tcpip/header"
|
|
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
|
)
|
|
|
|
func (e *endpoint) drainSegmentLocked() {
|
|
// Drain only up to once.
|
|
if e.drainDone != nil {
|
|
return
|
|
}
|
|
|
|
e.drainDone = make(chan struct{})
|
|
e.undrain = make(chan struct{})
|
|
e.mu.Unlock()
|
|
|
|
e.notifyProtocolGoroutine(notifyDrain)
|
|
<-e.drainDone
|
|
|
|
e.mu.Lock()
|
|
}
|
|
|
|
// beforeSave is invoked by stateify.
//
// It quiesces e so that its state can be captured consistently:
//   - Incoming segments are stopped by lowering the segment queue limit to 0.
//   - For listening, connecting and connected endpoints with a running
//     worker, the protocol goroutine is asked to drain (see
//     drainSegmentLocked).
//   - For closed or errored endpoints, beforeSave polls until the worker has
//     stopped.
//
// Connected endpoints whose route lacks stack.CapabilitySaveRestore are
// handled in one of two ways. If the route does not advertise
// stack.CapabilityDisconnectOk, beforeSave panics with a
// tcpip.ErrSaveRejection. Otherwise the connection is reset with
// tcpip.ErrConnectionAborted and the endpoint is closed before saving.
//
// It also panics when invariants do not hold:
//   - the worker status is inconsistent with the endpoint state;
//   - the state is unknown;
//   - waiters are still queued;
//   - for a non-closed endpoint, isPortReserved does not match "state is
//     Bound or Listen".
func (e *endpoint) beforeSave() {
	// Stop incoming packets.
	e.segmentQueue.setLimit(0)

	e.mu.Lock()
	defer e.mu.Unlock()

	switch e.state {
	case StateInitial, StateBound:
		// TODO(b/138137272): this enumeration duplicates
		// EndpointState.connected. remove it.
	case StateEstablished, StateSynSent, StateSynRecv, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing:
		if e.route.Capabilities()&stack.CapabilitySaveRestore == 0 {
			if e.route.Capabilities()&stack.CapabilityDisconnectOk == 0 {
				panic(tcpip.ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%d, remote %v:%d", e.ID.LocalAddress, e.ID.LocalPort, e.ID.RemoteAddress, e.ID.RemotePort)})
			}
			e.resetConnectionLocked(tcpip.ErrConnectionAborted)
			// Close is called without e.mu held. Presumably it takes the
			// lock itself; confirm. The lock is re-taken afterwards so the
			// deferred Unlock stays balanced.
			e.mu.Unlock()
			e.Close()
			e.mu.Lock()
		}
		if !e.workerRunning {
			// The endpoint must be in acceptedChan or has been just
			// disconnected and closed.
			break
		}
		// A connected endpoint with a live worker is quiesced the same way
		// as a listening or connecting one.
		fallthrough
	case StateListen, StateConnecting:
		// drainSegmentLocked releases and re-acquires e.mu while it waits,
		// so e.state is re-read below rather than trusted from the switch.
		e.drainSegmentLocked()
		if e.state != StateClose && e.state != StateError {
			if !e.workerRunning {
				panic("endpoint has no worker running in listen, connecting, or connected state")
			}
			break
		}
		// The endpoint became closed or errored during the drain. Wait for
		// its worker to exit, as in the closed/error case.
		fallthrough
	case StateError, StateClose:
		// Poll every 100ms until the worker has stopped or the endpoint left
		// the closed/error states. e.mu is dropped while sleeping so the
		// worker can make progress.
		for (e.state == StateError || e.state == StateClose) && e.workerRunning {
			e.mu.Unlock()
			time.Sleep(100 * time.Millisecond)
			e.mu.Lock()
		}
		if e.workerRunning {
			panic("endpoint still has worker running in closed or error state")
		}
	default:
		panic(fmt.Sprintf("endpoint in unknown state %v", e.state))
	}

	// Refuse to save while anything is still registered on the waiter queue.
	if e.waiterQueue != nil && !e.waiterQueue.IsEmpty() {
		panic("endpoint still has waiters upon save")
	}

	// For every non-closed endpoint: a port is reserved exactly when the
	// state is Bound or Listen.
	if e.state != StateClose && !((e.state == StateBound || e.state == StateListen) == e.isPortReserved) {
		panic("endpoints which are not in the closed state must have a reserved port IFF they are in bound or listen state")
	}
}
|
|
|
|
// saveAcceptedChan is invoked by stateify.
|
|
func (e *endpoint) saveAcceptedChan() []*endpoint {
|
|
if e.acceptedChan == nil {
|
|
return nil
|
|
}
|
|
acceptedEndpoints := make([]*endpoint, len(e.acceptedChan), cap(e.acceptedChan))
|
|
for i := 0; i < len(acceptedEndpoints); i++ {
|
|
select {
|
|
case ep := <-e.acceptedChan:
|
|
acceptedEndpoints[i] = ep
|
|
default:
|
|
panic("endpoint acceptedChan buffer got consumed by background context")
|
|
}
|
|
}
|
|
for i := 0; i < len(acceptedEndpoints); i++ {
|
|
select {
|
|
case e.acceptedChan <- acceptedEndpoints[i]:
|
|
default:
|
|
panic("endpoint acceptedChan buffer got populated by background context")
|
|
}
|
|
}
|
|
return acceptedEndpoints
|
|
}
|
|
|
|
// loadAcceptedChan is invoked by stateify.
|
|
func (e *endpoint) loadAcceptedChan(acceptedEndpoints []*endpoint) {
|
|
if cap(acceptedEndpoints) > 0 {
|
|
e.acceptedChan = make(chan *endpoint, cap(acceptedEndpoints))
|
|
for _, ep := range acceptedEndpoints {
|
|
e.acceptedChan <- ep
|
|
}
|
|
}
|
|
}
|
|
|
|
// saveState is invoked by stateify.
|
|
func (e *endpoint) saveState() EndpointState {
|
|
return e.state
|
|
}
|
|
|
|
// Restored endpoints are brought back up in a fixed order keyed on their saved
// state. This avoids re-creating a connecting endpoint before the listener it
// targets exists, and avoids conflicts between port reservations. Each group
// counts the endpoints of one phase that are still being restored. Bound
// endpoint loading happens last, after all three groups have drained.
var (
	// connectedLoading counts connected endpoints, which are restored first.
	connectedLoading sync.WaitGroup
	// listenLoading counts listening endpoints, restored after connected ones.
	listenLoading sync.WaitGroup
	// connectingLoading counts endpoints that were mid-connect, restored
	// after listeners.
	connectingLoading sync.WaitGroup
)
|
|
|
|
// loadState is invoked by stateify.
|
|
func (e *endpoint) loadState(state EndpointState) {
|
|
// This is to ensure that the loading wait groups include all applicable
|
|
// endpoints before any asynchronous calls to the Wait() methods.
|
|
if state.connected() {
|
|
connectedLoading.Add(1)
|
|
}
|
|
switch state {
|
|
case StateListen:
|
|
listenLoading.Add(1)
|
|
case StateConnecting, StateSynSent, StateSynRecv:
|
|
connectingLoading.Add(1)
|
|
}
|
|
e.state = state
|
|
}
|
|
|
|
// afterLoad is invoked by stateify.
|
|
func (e *endpoint) afterLoad() {
|
|
// Freeze segment queue before registering to prevent any segments
|
|
// from being delivered while it is being restored.
|
|
e.origEndpointState = e.state
|
|
// Restore the endpoint to InitialState as it will be moved to
|
|
// its origEndpointState during Resume.
|
|
e.state = StateInitial
|
|
stack.StackFromEnv.RegisterRestoredEndpoint(e)
|
|
}
|
|
|
|
// Resume implements tcpip.ResumableEndpoint.Resume.
//
// It reattaches e to stack s, re-enables segment delivery by raising the
// segment queue limit to MaxUnprocessedSegments, and brings the endpoint back
// toward the state it was saved in (origEndpointState, stashed by afterLoad).
//
// Restoration is ordered with the package-level wait groups that loadState
// populated:
//   - connected endpoints re-bind and reconnect synchronously, then signal
//     connectedLoading;
//   - listeners wait for connectedLoading, then re-bind and listen;
//   - StateConnecting/SynSent/SynRecv endpoints wait for connected and
//     listening endpoints, then re-bind and connect;
//   - bound endpoints, and closed endpoints that still hold a port
//     reservation, wait for all three groups, then re-bind.
//
// Each asynchronous branch runs on its own goroutine bracketed by
// tcpip.AsyncLoading.Add/Done. Any restore failure panics.
func (e *endpoint) Resume(s *stack.Stack) {
	e.stack = s
	e.segmentQueue.setLimit(MaxUnprocessedSegments)
	e.workMu.Init()
	state := e.origEndpointState

	// Sanity-check the restored buffer sizes against the stack's current
	// SendBufferSizeOption bounds. The check is skipped if the option cannot
	// be read.
	// NOTE(review): rcvBufSize is also checked against the *send* buffer
	// bounds; confirm this is intended rather than a receive-buffer option.
	switch state {
	case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished:
		var ss SendBufferSizeOption
		if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
			if e.sndBufSize < ss.Min || e.sndBufSize > ss.Max {
				panic(fmt.Sprintf("endpoint.sndBufSize %d is outside the min and max allowed [%d, %d]", e.sndBufSize, ss.Min, ss.Max))
			}
			if e.rcvBufSize < ss.Min || e.rcvBufSize > ss.Max {
				panic(fmt.Sprintf("endpoint.rcvBufSize %d is outside the min and max allowed [%d, %d]", e.rcvBufSize, ss.Min, ss.Max))
			}
		}
	}

	// bind re-binds e to its saved local port. BindAddr defaults to the
	// endpoint's local address when it is empty.
	bind := func() {
		if len(e.BindAddr) == 0 {
			e.BindAddr = e.ID.LocalAddress
		}
		addr := e.BindAddr
		port := e.ID.LocalPort
		if err := e.Bind(tcpip.FullAddress{Addr: addr, Port: port}); err != nil {
			panic(fmt.Sprintf("endpoint binding [%v]:%d failed: %v", addr, port, err))
		}
	}

	switch state {
	case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing:
		bind()
		if len(e.connectingAddress) == 0 {
			e.connectingAddress = e.ID.RemoteAddress
			// This endpoint is accepted by netstack but not yet by
			// the app. If the endpoint is IPv6 but the remote
			// address is IPv4, we need to connect as IPv6 so that
			// dual-stack mode can be properly activated.
			if e.NetProto == header.IPv6ProtocolNumber && len(e.ID.RemoteAddress) != header.IPv6AddressSize {
				e.connectingAddress = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + e.ID.RemoteAddress
			}
		}
		// Reset the scoreboard to reinitialize the sack information as
		// we do not restore SACK information.
		e.scoreboard.Reset()
		// Any result other than tcpip.ErrConnectStarted is fatal.
		// NOTE(review): if connect returned nil, err.String() would be called
		// on a nil *tcpip.Error; whether that panics depends on String.
		if err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.ID.RemotePort}, false, e.workerRunning); err != tcpip.ErrConnectStarted {
			panic("endpoint connecting failed: " + err.String())
		}
		// Only now move e.state back to the saved state (afterLoad reset it
		// to StateInitial), then wake the worker.
		e.mu.Lock()
		e.state = e.origEndpointState
		closed := e.closed
		e.mu.Unlock()
		e.notifyProtocolGoroutine(notifyTickleWorker)
		if state == StateFinWait2 && closed {
			// If the endpoint has been closed then make sure we notify so
			// that the FIN_WAIT2 timer is started after a restore.
			e.notifyProtocolGoroutine(notifyClose)
		}
		connectedLoading.Done()
	case StateListen:
		tcpip.AsyncLoading.Add(1)
		go func() {
			connectedLoading.Wait()
			bind()
			// The saved accept queue capacity is reused as the backlog.
			backlog := cap(e.acceptedChan)
			if err := e.Listen(backlog); err != nil {
				panic("endpoint listening failed: " + err.String())
			}
			listenLoading.Done()
			tcpip.AsyncLoading.Done()
		}()
	case StateConnecting, StateSynSent, StateSynRecv:
		tcpip.AsyncLoading.Add(1)
		go func() {
			connectedLoading.Wait()
			listenLoading.Wait()
			bind()
			if err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.ID.RemotePort}); err != tcpip.ErrConnectStarted {
				panic("endpoint connecting failed: " + err.String())
			}
			connectingLoading.Done()
			tcpip.AsyncLoading.Done()
		}()
	case StateBound:
		tcpip.AsyncLoading.Add(1)
		go func() {
			connectedLoading.Wait()
			listenLoading.Wait()
			connectingLoading.Wait()
			bind()
			tcpip.AsyncLoading.Done()
		}()
	case StateClose:
		// NOTE(review): this goroutine calls bind() and writes e.state
		// without holding e.mu. Meanwhile the code below has already set
		// e.state = StateClose and completed cleanup. Whether Bind succeeds
		// from StateClose, and whether this unsynchronized write races with
		// other users of e.state, cannot be determined here; verify.
		if e.isPortReserved {
			tcpip.AsyncLoading.Add(1)
			go func() {
				connectedLoading.Wait()
				listenLoading.Wait()
				connectingLoading.Wait()
				bind()
				e.state = StateClose
				tcpip.AsyncLoading.Done()
			}()
		}
		e.state = StateClose
		e.stack.CompleteTransportEndpointCleanup(e)
		tcpip.DeleteDanglingEndpoint(e)
	case StateError:
		// Nothing to re-establish; finish the endpoint's cleanup right away.
		e.state = StateError
		e.stack.CompleteTransportEndpointCleanup(e)
		tcpip.DeleteDanglingEndpoint(e)
	}
}
|
|
|
|
// saveLastError is invoked by stateify.
|
|
func (e *endpoint) saveLastError() string {
|
|
if e.lastError == nil {
|
|
return ""
|
|
}
|
|
|
|
return e.lastError.String()
|
|
}
|
|
|
|
// loadLastError is invoked by stateify.
|
|
func (e *endpoint) loadLastError(s string) {
|
|
if s == "" {
|
|
return
|
|
}
|
|
|
|
e.lastError = loadError(s)
|
|
}
|
|
|
|
// saveHardError is invoked by stateify.
|
|
func (e *EndpointInfo) saveHardError() string {
|
|
if e.HardError == nil {
|
|
return ""
|
|
}
|
|
|
|
return e.HardError.String()
|
|
}
|
|
|
|
// loadHardError is invoked by stateify.
|
|
func (e *EndpointInfo) loadHardError(s string) {
|
|
if s == "" {
|
|
return
|
|
}
|
|
|
|
e.HardError = loadError(s)
|
|
}
|
|
|
|
var messageToError map[string]*tcpip.Error
|
|
|
|
var populate sync.Once
|
|
|
|
func loadError(s string) *tcpip.Error {
|
|
populate.Do(func() {
|
|
var errors = []*tcpip.Error{
|
|
tcpip.ErrUnknownProtocol,
|
|
tcpip.ErrUnknownNICID,
|
|
tcpip.ErrUnknownDevice,
|
|
tcpip.ErrUnknownProtocolOption,
|
|
tcpip.ErrDuplicateNICID,
|
|
tcpip.ErrDuplicateAddress,
|
|
tcpip.ErrNoRoute,
|
|
tcpip.ErrBadLinkEndpoint,
|
|
tcpip.ErrAlreadyBound,
|
|
tcpip.ErrInvalidEndpointState,
|
|
tcpip.ErrAlreadyConnecting,
|
|
tcpip.ErrAlreadyConnected,
|
|
tcpip.ErrNoPortAvailable,
|
|
tcpip.ErrPortInUse,
|
|
tcpip.ErrBadLocalAddress,
|
|
tcpip.ErrClosedForSend,
|
|
tcpip.ErrClosedForReceive,
|
|
tcpip.ErrWouldBlock,
|
|
tcpip.ErrConnectionRefused,
|
|
tcpip.ErrTimeout,
|
|
tcpip.ErrAborted,
|
|
tcpip.ErrConnectStarted,
|
|
tcpip.ErrDestinationRequired,
|
|
tcpip.ErrNotSupported,
|
|
tcpip.ErrQueueSizeNotSupported,
|
|
tcpip.ErrNotConnected,
|
|
tcpip.ErrConnectionReset,
|
|
tcpip.ErrConnectionAborted,
|
|
tcpip.ErrNoSuchFile,
|
|
tcpip.ErrInvalidOptionValue,
|
|
tcpip.ErrNoLinkAddress,
|
|
tcpip.ErrBadAddress,
|
|
tcpip.ErrNetworkUnreachable,
|
|
tcpip.ErrMessageTooLong,
|
|
tcpip.ErrNoBufferSpace,
|
|
tcpip.ErrBroadcastDisabled,
|
|
tcpip.ErrNotPermitted,
|
|
tcpip.ErrAddressFamilyNotSupported,
|
|
}
|
|
|
|
messageToError = make(map[string]*tcpip.Error)
|
|
for _, e := range errors {
|
|
if messageToError[e.String()] != nil {
|
|
panic("tcpip errors with duplicated message: " + e.String())
|
|
}
|
|
messageToError[e.String()] = e
|
|
}
|
|
})
|
|
|
|
e, ok := messageToError[s]
|
|
if !ok {
|
|
panic("unknown error message: " + s)
|
|
}
|
|
|
|
return e
|
|
}
|
|
|
|
// saveMeasureTime is invoked by stateify.
|
|
func (r *rcvBufAutoTuneParams) saveMeasureTime() unixTime {
|
|
return unixTime{r.measureTime.Unix(), r.measureTime.UnixNano()}
|
|
}
|
|
|
|
// loadMeasureTime is invoked by stateify.
|
|
func (r *rcvBufAutoTuneParams) loadMeasureTime(unix unixTime) {
|
|
r.measureTime = time.Unix(unix.second, unix.nano)
|
|
}
|
|
|
|
// saveRttMeasureTime is invoked by stateify.
|
|
func (r *rcvBufAutoTuneParams) saveRttMeasureTime() unixTime {
|
|
return unixTime{r.rttMeasureTime.Unix(), r.rttMeasureTime.UnixNano()}
|
|
}
|
|
|
|
// loadRttMeasureTime is invoked by stateify.
|
|
func (r *rcvBufAutoTuneParams) loadRttMeasureTime(unix unixTime) {
|
|
r.rttMeasureTime = time.Unix(unix.second, unix.nano)
|
|
}
|