gvisor/pkg/tcpip/transport/tcp/endpoint_state.go

235 lines
5.5 KiB
Go

// Copyright 2017 The Netstack Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tcp
import (
"fmt"
"sync"
"gvisor.googlesource.com/gvisor/pkg/tcpip"
"gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
)
// drainSegmentLocked must be called with e.mu held. On its first call it
// creates the e.drainDone and e.undrain channels, asserts e.notificationWaker,
// and then blocks until a receive from e.drainDone completes. While waiting it
// releases e.mu and reacquires it before returning.
// NOTE(review): whichever goroutine services notificationWaker (presumably the
// protocol goroutine) is expected to signal drainDone — confirm.
// Subsequent calls are no-ops because e.drainDone is already non-nil.
func (e *endpoint) drainSegmentLocked() {
	if e.drainDone == nil {
		e.drainDone, e.undrain = make(chan struct{}), make(chan struct{})

		// Drop the lock so the woken goroutine can make progress while we wait.
		e.mu.Unlock()
		e.notificationWaker.Assert()
		<-e.drainDone
		e.mu.Lock()
	}
}
// beforeSave is invoked by stateify before the endpoint is serialized.
//
// It sets the segment queue limit to 0 to stop accepting new segments. Where
// needed it drains segments that are already queued. It panics for states
// that cannot be saved:
//   - a connected endpoint panics with tcpip.ErrSaveRejection;
//   - a closed or errored endpoint with a running worker panics;
//   - any unknown state panics.
func (e *endpoint) beforeSave() {
	// Stop incoming packets.
	e.segmentQueue.setLimit(0)

	e.mu.Lock()
	defer e.mu.Unlock()

	switch e.state {
	case stateInitial, stateBound:
		// Nothing queued that needs handling.
	case stateListen:
		if !e.segmentQueue.empty() {
			e.drainSegmentLocked()
		}
	case stateConnecting:
		e.drainSegmentLocked()
		// drainSegmentLocked releases e.mu while waiting, so e.state may have
		// changed. Only an endpoint that is now connected falls through to be
		// rejected below.
		if e.state != stateConnected {
			break
		}
		fallthrough
	case stateConnected:
		// TODO: saving connected endpoints is not supported yet, so the save
		// is rejected.
		panic(tcpip.ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%v, remote %v:%v", e.id.LocalAddress, e.id.LocalPort, e.id.RemoteAddress, e.id.RemotePort)})
	case stateClosed, stateError:
		if e.workerRunning {
			panic("endpoint still has worker running in closed or error state")
		}
	default:
		panic(fmt.Sprintf("endpoint in unknown state %v", e.state))
	}
}
// afterLoad is invoked by stateify after the endpoint has been restored.
//
// It first reattaches the endpoint to stack.StackFromEnv, resets the segment
// queue limit and reinitializes workMu. It then validates the buffer sizes
// against the stack's protocol options. Finally it replays Bind, Listen and
// Connect as appropriate, so the endpoint returns to the state it was saved
// in. Any failure panics.
func (e *endpoint) afterLoad() {
	e.stack = stack.StackFromEnv
	e.segmentQueue.setLimit(2 * e.rcvBufSize)
	e.workMu.Init()

	// Capture the saved state up front: e.state is overwritten below before
	// re-binding.
	state := e.state

	switch state {
	case stateInitial, stateBound, stateListen, stateConnecting:
		var ss SendBufferSizeOption
		if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
			if e.sndBufSize < ss.Min || e.sndBufSize > ss.Max {
				panic(fmt.Sprintf("endpoint.sndBufSize %d is outside the min and max allowed [%d, %d]", e.sndBufSize, ss.Min, ss.Max))
			}
			// NOTE(review): rcvBufSize is validated against the *send* buffer
			// limits. The receive-buffer option is likely the intended bound;
			// confirm before tightening this check.
			if e.rcvBufSize < ss.Min || e.rcvBufSize > ss.Max {
				panic(fmt.Sprintf("endpoint.rcvBufSize %d is outside the min and max allowed [%d, %d]", e.rcvBufSize, ss.Min, ss.Max))
			}
		}
	}

	switch state {
	case stateBound, stateListen, stateConnecting:
		// Reset to the initial state before re-binding. Presumably Bind
		// rejects endpoints that are not in stateInitial.
		e.state = stateInitial
		if err := e.Bind(tcpip.FullAddress{Addr: e.id.LocalAddress, Port: e.id.LocalPort}, nil); err != nil {
			panic("endpoint binding failed: " + err.String())
		}
	}

	switch state {
	case stateListen:
		// Reuse the saved channel capacity as the backlog. The channel itself
		// is dropped; presumably Listen recreates it.
		// NOTE(review): any endpoints restored into it are discarded here.
		backlog := cap(e.acceptedChan)
		e.acceptedChan = nil
		if err := e.Listen(backlog); err != nil {
			panic("endpoint listening failed: " + err.String())
		}
	}

	switch state {
	case stateConnecting:
		if err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.id.RemotePort}); err != tcpip.ErrConnectStarted {
			// err may be nil here (e.g. if Connect completed synchronously).
			// Calling String on a nil *tcpip.Error would cause a
			// nil-dereference panic and hide this diagnostic.
			msg := "<nil>"
			if err != nil {
				msg = err.String()
			}
			panic("endpoint connecting failed: " + msg)
		}
	}
}
// saveAcceptedChan is invoked by stateify. It converts e.acceptedChan into an
// endpointChan:
//   - a nil channel is saved as the zero endpointChan;
//   - otherwise the channel is closed, every queued endpoint is drained into a
//     slice, the capacity is recorded, and e.acceptedChan is cleared.
//
// It panics if fewer endpoints are drained than were queued right after the
// close, presumably because something else received from the channel
// concurrently.
func (e *endpoint) saveAcceptedChan() endpointChan {
	ch := e.acceptedChan
	if ch == nil {
		return endpointChan{}
	}

	// Closing first lets the range below terminate once the queue is empty.
	close(ch)
	queued := len(ch)
	pending := make([]*endpoint, 0, queued)
	for ep := range ch {
		pending = append(pending, ep)
	}
	if len(pending) != queued {
		panic("endpoint.acceptedChan buffer got consumed by background context")
	}

	e.acceptedChan = nil
	return endpointChan{buffer: pending, cap: cap(ch)}
}
// loadAcceptedChan is invoked by stateify. It rebuilds e.acceptedChan from the
// endpointChan produced by saveAcceptedChan. A channel with the saved capacity
// is created and refilled with the saved endpoints in their original order. A
// zero capacity leaves e.acceptedChan untouched.
func (e *endpoint) loadAcceptedChan(c endpointChan) {
	if c.cap != 0 {
		e.acceptedChan = make(chan *endpoint, c.cap)
		for i := range c.buffer {
			e.acceptedChan <- c.buffer[i]
		}
	}
}
// endpointChan is the saved form of endpoint.acceptedChan.
// saveAcceptedChan produces it by draining the channel, and loadAcceptedChan
// rebuilds the channel from it.
type endpointChan struct {
// buffer holds the endpoints that were queued in the channel at save time,
// in receive order.
buffer []*endpoint
// cap is the channel's capacity. A value of 0 (as saved for a nil channel)
// makes loadAcceptedChan leave acceptedChan unset.
cap int
}
// saveLastError is invoked by stateify. It records e.lastError as its message
// text, which loadError can later map back to the error value. A nil error is
// saved as the empty string.
func (e *endpoint) saveLastError() string {
	if err := e.lastError; err != nil {
		return err.String()
	}
	return ""
}
// loadLastError is invoked by stateify. It restores e.lastError from the
// message text written by saveLastError. An empty string leaves the field
// untouched; an unrecognized message makes loadError panic.
func (e *endpoint) loadLastError(s string) {
	if s != "" {
		e.lastError = loadError(s)
	}
}
// saveHardError is invoked by stateify. It records e.hardError as its message
// text, which loadError can later map back to the error value. A nil error is
// saved as the empty string.
func (e *endpoint) saveHardError() string {
	if err := e.hardError; err != nil {
		return err.String()
	}
	return ""
}
// loadHardError is invoked by stateify. It restores e.hardError from the
// message text written by saveHardError. An empty string leaves the field
// untouched; an unrecognized message makes loadError panic.
func (e *endpoint) loadHardError(s string) {
	if s != "" {
		e.hardError = loadError(s)
	}
}
var (
	// messageToError maps each known tcpip error message back to its
	// canonical *tcpip.Error value. It is built lazily on the first call to
	// loadError.
	messageToError map[string]*tcpip.Error

	// populate guards the one-time construction of messageToError.
	populate sync.Once
)

// loadError returns the canonical *tcpip.Error whose String() equals s. It is
// the inverse of String() for the errors listed below. In this file it is used
// to restore lastError and hardError from their saved message text.
//
// It panics if s matches none of the listed errors. Building the table on the
// first call also panics if two listed errors share the same message.
func loadError(s string) *tcpip.Error {
	populate.Do(func() {
		known := []*tcpip.Error{
			tcpip.ErrUnknownProtocol,
			tcpip.ErrUnknownNICID,
			tcpip.ErrUnknownProtocolOption,
			tcpip.ErrDuplicateNICID,
			tcpip.ErrDuplicateAddress,
			tcpip.ErrNoRoute,
			tcpip.ErrBadLinkEndpoint,
			tcpip.ErrAlreadyBound,
			tcpip.ErrInvalidEndpointState,
			tcpip.ErrAlreadyConnecting,
			tcpip.ErrAlreadyConnected,
			tcpip.ErrNoPortAvailable,
			tcpip.ErrPortInUse,
			tcpip.ErrBadLocalAddress,
			tcpip.ErrClosedForSend,
			tcpip.ErrClosedForReceive,
			tcpip.ErrWouldBlock,
			tcpip.ErrConnectionRefused,
			tcpip.ErrTimeout,
			tcpip.ErrAborted,
			tcpip.ErrConnectStarted,
			tcpip.ErrDestinationRequired,
			tcpip.ErrNotSupported,
			tcpip.ErrQueueSizeNotSupported,
			tcpip.ErrNotConnected,
			tcpip.ErrConnectionReset,
			tcpip.ErrConnectionAborted,
			tcpip.ErrNoSuchFile,
			tcpip.ErrInvalidOptionValue,
			tcpip.ErrNoLinkAddress,
			tcpip.ErrBadAddress,
			tcpip.ErrNetworkUnreachable,
		}

		messageToError = make(map[string]*tcpip.Error, len(known))
		for _, candidate := range known {
			msg := candidate.String()
			// A shared message would make the reverse lookup ambiguous.
			if _, dup := messageToError[msg]; dup {
				panic("tcpip errors with duplicated message: " + msg)
			}
			messageToError[msg] = candidate
		}
	})

	if found, ok := messageToError[s]; ok {
		return found
	}
	panic("unknown error message: " + s)
}