From 13a98df49ea1b36cd21c528293b626a6a3639f0b Mon Sep 17 00:00:00 2001 From: Rahat Mahmood Date: Thu, 8 Aug 2019 12:32:00 -0700 Subject: [PATCH] netstack: Don't start endpoint goroutines too soon on restore. Endpoint protocol goroutines were previously started as part of loading the endpoint. This is potentially too soon, as resources used by these goroutines may not have been loaded. Protocol goroutines may perform meaningful work as soon as they're started (e.g., an incoming connect) which can cause them to indirectly access resources that haven't been loaded yet. This CL defers resuming all protocol goroutines until the end of restore. PiperOrigin-RevId: 262409429 --- pkg/sentry/inet/inet.go | 3 + pkg/sentry/inet/test_stack.go | 4 + pkg/sentry/kernel/kernel.go | 9 +- pkg/sentry/socket/epsocket/stack.go | 5 + pkg/sentry/socket/hostinet/stack.go | 3 + pkg/sentry/socket/rpcinet/stack.go | 3 + pkg/sentry/state/BUILD | 1 + pkg/sentry/state/state.go | 5 +- pkg/tcpip/stack/stack.go | 35 +++++++ pkg/tcpip/stack/transport_test.go | 3 + pkg/tcpip/tcpip.go | 10 +- pkg/tcpip/transport/icmp/endpoint.go | 28 ++++++ pkg/tcpip/transport/icmp/endpoint_state.go | 26 +----- pkg/tcpip/transport/raw/endpoint.go | 25 +++++ pkg/tcpip/transport/raw/endpoint_state.go | 24 +---- pkg/tcpip/transport/tcp/endpoint.go | 101 +++++++++++++++++++++ pkg/tcpip/transport/tcp/endpoint_state.go | 100 +------------------- pkg/tcpip/transport/udp/endpoint.go | 47 ++++++++++ pkg/tcpip/transport/udp/endpoint_state.go | 46 +--------- runsc/boot/controller.go | 5 +- 20 files changed, 279 insertions(+), 204 deletions(-) diff --git a/pkg/sentry/inet/inet.go b/pkg/sentry/inet/inet.go index 60d6dfb93..80f227dbe 100644 --- a/pkg/sentry/inet/inet.go +++ b/pkg/sentry/inet/inet.go @@ -55,6 +55,9 @@ type Stack interface { // RouteTable returns the network stack's route table. RouteTable() []Route + + // Resume restarts the network stack after restore. 
+ Resume() } // Interface contains information about a network interface. diff --git a/pkg/sentry/inet/test_stack.go b/pkg/sentry/inet/test_stack.go index 57d5510f0..b9eed7c3a 100644 --- a/pkg/sentry/inet/test_stack.go +++ b/pkg/sentry/inet/test_stack.go @@ -92,3 +92,7 @@ func (s *TestStack) Statistics(stat interface{}, arg string) error { func (s *TestStack) RouteTable() []Route { return s.RouteList } + +// Resume implements Stack.Resume. +func (s *TestStack) Resume() { +} diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index 56a329f83..8c1f79ab5 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -496,7 +496,7 @@ func (ts *TaskSet) unregisterEpollWaiters() { } // LoadFrom returns a new Kernel loaded from args. -func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack) error { +func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack, clocks sentrytime.Clocks) error { loadStart := time.Now() k.networkStack = net @@ -540,6 +540,11 @@ func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack) error { log.Infof("Overall load took [%s]", time.Since(loadStart)) + k.Timekeeper().SetClocks(clocks) + if net != nil { + net.Resume() + } + // Ensure that all pending asynchronous work is complete: // - namedpipe opening // - inode file opening @@ -549,7 +554,7 @@ func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack) error { tcpip.AsyncLoading.Wait() - log.Infof("Overall load took [%s]", time.Since(loadStart)) + log.Infof("Overall load took [%s] after async work", time.Since(loadStart)) // Applications may size per-cpu structures based on k.applicationCores, so // it can't change across save/restore. 
When we are virtualizing CPU diff --git a/pkg/sentry/socket/epsocket/stack.go b/pkg/sentry/socket/epsocket/stack.go index 0cf235b31..8f1572bf4 100644 --- a/pkg/sentry/socket/epsocket/stack.go +++ b/pkg/sentry/socket/epsocket/stack.go @@ -201,3 +201,8 @@ func (s *Stack) IPTables() (iptables.IPTables, error) { func (s *Stack) FillDefaultIPTables() error { return netfilter.FillDefaultIPTables(s.Stack) } + +// Resume implements inet.Stack.Resume. +func (s *Stack) Resume() { + s.Stack.Resume() +} diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go index 99b7a1e2b..1902fe155 100644 --- a/pkg/sentry/socket/hostinet/stack.go +++ b/pkg/sentry/socket/hostinet/stack.go @@ -329,3 +329,6 @@ func (s *Stack) Statistics(stat interface{}, arg string) error { func (s *Stack) RouteTable() []inet.Route { return append([]inet.Route(nil), s.routes...) } + +// Resume implements inet.Stack.Resume. +func (s *Stack) Resume() {} diff --git a/pkg/sentry/socket/rpcinet/stack.go b/pkg/sentry/socket/rpcinet/stack.go index d18305589..5dcb6b455 100644 --- a/pkg/sentry/socket/rpcinet/stack.go +++ b/pkg/sentry/socket/rpcinet/stack.go @@ -162,3 +162,6 @@ func (s *Stack) Statistics(stat interface{}, arg string) error { func (s *Stack) RouteTable() []inet.Route { return append([]inet.Route(nil), s.routes...) } + +// Resume implements inet.Stack.Resume. 
+func (s *Stack) Resume() {} diff --git a/pkg/sentry/state/BUILD b/pkg/sentry/state/BUILD index f297ef3b7..88765f4d6 100644 --- a/pkg/sentry/state/BUILD +++ b/pkg/sentry/state/BUILD @@ -16,6 +16,7 @@ go_library( "//pkg/log", "//pkg/sentry/inet", "//pkg/sentry/kernel", + "//pkg/sentry/time", "//pkg/sentry/watchdog", "//pkg/state/statefile", "//pkg/syserror", diff --git a/pkg/sentry/state/state.go b/pkg/sentry/state/state.go index 026549756..9eb626b76 100644 --- a/pkg/sentry/state/state.go +++ b/pkg/sentry/state/state.go @@ -22,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/time" "gvisor.dev/gvisor/pkg/sentry/watchdog" "gvisor.dev/gvisor/pkg/state/statefile" "gvisor.dev/gvisor/pkg/syserror" @@ -104,7 +105,7 @@ type LoadOpts struct { } // Load loads the given kernel, setting the provided platform and stack. -func (opts LoadOpts) Load(k *kernel.Kernel, n inet.Stack) error { +func (opts LoadOpts) Load(k *kernel.Kernel, n inet.Stack, clocks time.Clocks) error { // Open the file. r, m, err := statefile.NewReader(opts.Source, opts.Key) if err != nil { @@ -114,5 +115,5 @@ func (opts LoadOpts) Load(k *kernel.Kernel, n inet.Stack) error { previousMetadata = m // Restore the Kernel object graph. - return k.LoadFrom(r, n) + return k.LoadFrom(r, n, clocks) } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 78beb0dae..d45e547ee 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -334,6 +334,15 @@ type TCPEndpointState struct { Sender TCPSenderState } +// ResumableEndpoint is an endpoint that needs to be resumed after restore. +type ResumableEndpoint interface { + // Resume resumes an endpoint after restore. This can be used to restart + // background workers such as protocol goroutines. 
This must be called after + // all indirect dependencies of the endpoint have been restored, which + // generally means at the end of the restore process. + Resume(*Stack) +} + // Stack is a networking stack, with all supported protocols, NICs, and route // table. type Stack struct { @@ -376,6 +385,10 @@ type Stack struct { // tables are the iptables packet filtering and manipulation rules. tables iptables.IPTables + + // resumableEndpoints is a list of endpoints that need to be resumed if the + // stack is being restored. + resumableEndpoints []ResumableEndpoint } // Options contains optional Stack configuration. @@ -1091,6 +1104,28 @@ func (s *Stack) UnregisterRawTransportEndpoint(nicID tcpip.NICID, netProto tcpip } } +// RegisterRestoredEndpoint records e as an endpoint that has been restored on +// this stack. +func (s *Stack) RegisterRestoredEndpoint(e ResumableEndpoint) { + s.mu.Lock() + s.resumableEndpoints = append(s.resumableEndpoints, e) + s.mu.Unlock() +} + +// Resume restarts the stack after a restore. This must be called after the +// entire system has been restored. +func (s *Stack) Resume() { + // ResumableEndpoint.Resume() may call other methods on s, so we can't hold + // s.mu while resuming the endpoints. + s.mu.Lock() + eps := s.resumableEndpoints + s.resumableEndpoints = nil + s.mu.Unlock() + for _, e := range eps { + e.Resume(s) + } +} + // NetworkProtocolInstance returns the protocol instance in the stack for the // specified network protocol. This method is public for protocol implementers // and tests to use. 
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 8652d7814..eee3144cd 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -205,6 +205,9 @@ func (f *fakeTransportEndpoint) IPTables() (iptables.IPTables, error) { return iptables.IPTables{}, nil } +func (f *fakeTransportEndpoint) Resume(*stack.Stack) { +} + type fakeTransportGoodOption bool type fakeTransportBadOption bool diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 0df9f6d93..119712d2f 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -1082,11 +1082,13 @@ type ProtocolAddress struct { AddressWithPrefix AddressWithPrefix } -// danglingEndpointsMu protects access to danglingEndpoints. -var danglingEndpointsMu sync.Mutex +var ( + // danglingEndpointsMu protects access to danglingEndpoints. + danglingEndpointsMu sync.Mutex -// danglingEndpoints tracks all dangling endpoints no longer owned by the app. -var danglingEndpoints = make(map[Endpoint]struct{}) + // danglingEndpoints tracks all dangling endpoints no longer owned by the app. + danglingEndpoints = make(map[Endpoint]struct{}) +) // GetDanglingEndpoints returns all dangling endpoints. func GetDanglingEndpoints() []Endpoint { diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index a4527c041..9a4306011 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -136,6 +136,34 @@ func (e *endpoint) IPTables() (iptables.IPTables, error) { return e.stack.IPTables(), nil } +// Resume implements stack.ResumableEndpoint.Resume. 
+func (e *endpoint) Resume(s *stack.Stack) { + e.stack = s + + if e.state != stateBound && e.state != stateConnected { + return + } + + var err *tcpip.Error + if e.state == stateConnected { + e.route, err = e.stack.FindRoute(e.regNICID, e.bindAddr, e.id.RemoteAddress, e.netProto, false /* multicastLoop */) + if err != nil { + panic(*err) + } + + e.id.LocalAddress = e.route.LocalAddress + } else if len(e.id.LocalAddress) != 0 { // stateBound + if e.stack.CheckLocalAddress(e.regNICID, e.netProto, e.id.LocalAddress) == 0 { + panic(tcpip.ErrBadLocalAddress) + } + } + + e.id, err = e.registerWithStack(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.id) + if err != nil { + panic(*err) + } +} + // Read reads data from the endpoint. This method does not block if // there is no data pending. func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { diff --git a/pkg/tcpip/transport/icmp/endpoint_state.go b/pkg/tcpip/transport/icmp/endpoint_state.go index 99b8c4093..43551d642 100644 --- a/pkg/tcpip/transport/icmp/endpoint_state.go +++ b/pkg/tcpip/transport/icmp/endpoint_state.go @@ -15,7 +15,6 @@ package icmp import ( - "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -63,28 +62,5 @@ func (e *endpoint) loadRcvBufSizeMax(max int) { // afterLoad is invoked by stateify. 
func (e *endpoint) afterLoad() { - e.stack = stack.StackFromEnv - - if e.state != stateBound && e.state != stateConnected { - return - } - - var err *tcpip.Error - if e.state == stateConnected { - e.route, err = e.stack.FindRoute(e.regNICID, e.bindAddr, e.id.RemoteAddress, e.netProto, false /* multicastLoop */) - if err != nil { - panic(*err) - } - - e.id.LocalAddress = e.route.LocalAddress - } else if len(e.id.LocalAddress) != 0 { // stateBound - if e.stack.CheckLocalAddress(e.regNICID, e.netProto, e.id.LocalAddress) == 0 { - panic(tcpip.ErrBadLocalAddress) - } - } - - e.id, err = e.registerWithStack(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.id) - if err != nil { - panic(*err) - } + stack.StackFromEnv.RegisterRestoredEndpoint(e) } diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index b4be855c1..eab3dcbd2 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -174,6 +174,31 @@ func (ep *endpoint) IPTables() (iptables.IPTables, error) { return ep.stack.IPTables(), nil } +// Resume implements stack.ResumableEndpoint.Resume. +func (ep *endpoint) Resume(s *stack.Stack) { + ep.stack = s + + // If the endpoint is connected, re-connect. + if ep.connected { + var err *tcpip.Error + ep.route, err = ep.stack.FindRoute(ep.registeredNIC, ep.boundAddr, ep.route.RemoteAddress, ep.netProto, false) + if err != nil { + panic(*err) + } + } + + // If the endpoint is bound, re-bind. + if ep.bound { + if ep.stack.CheckLocalAddress(ep.registeredNIC, ep.netProto, ep.boundAddr) == 0 { + panic(tcpip.ErrBadLocalAddress) + } + } + + if err := ep.stack.RegisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep); err != nil { + panic(*err) + } +} + // Read implements tcpip.Endpoint.Read. 
func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { if !ep.associated { diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go index cb5534d90..44abddb2b 100644 --- a/pkg/tcpip/transport/raw/endpoint_state.go +++ b/pkg/tcpip/transport/raw/endpoint_state.go @@ -15,7 +15,6 @@ package raw import ( - "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -63,26 +62,5 @@ func (ep *endpoint) loadRcvBufSizeMax(max int) { // afterLoad is invoked by stateify. func (ep *endpoint) afterLoad() { - // StackFromEnv is a stack used specifically for save/restore. - ep.stack = stack.StackFromEnv - - // If the endpoint is connected, re-connect via the save/restore stack. - if ep.connected { - var err *tcpip.Error - ep.route, err = ep.stack.FindRoute(ep.registeredNIC, ep.boundAddr, ep.route.RemoteAddress, ep.netProto, false) - if err != nil { - panic(*err) - } - } - - // If the endpoint is bound, re-bind via the save/restore stack. - if ep.bound { - if ep.stack.CheckLocalAddress(ep.registeredNIC, ep.netProto, ep.boundAddr) == 0 { - panic(tcpip.ErrBadLocalAddress) - } - } - - if err := ep.stack.RegisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep); err != nil { - panic(*err) - } + stack.StackFromEnv.RegisterRestoredEndpoint(ep) } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 0e16877e7..e67169111 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -720,6 +720,107 @@ func (e *endpoint) IPTables() (iptables.IPTables, error) { return e.stack.IPTables(), nil } +// Resume implements stack.ResumableEndpoint.Resume. 
+func (e *endpoint) Resume(s *stack.Stack) { + e.stack = s + e.segmentQueue.setLimit(MaxUnprocessedSegments) + e.workMu.Init() + + state := e.state + switch state { + case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished: + var ss SendBufferSizeOption + if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil { + if e.sndBufSize < ss.Min || e.sndBufSize > ss.Max { + panic(fmt.Sprintf("endpoint.sndBufSize %d is outside the min and max allowed [%d, %d]", e.sndBufSize, ss.Min, ss.Max)) + } + if e.rcvBufSize < ss.Min || e.rcvBufSize > ss.Max { + panic(fmt.Sprintf("endpoint.rcvBufSize %d is outside the min and max allowed [%d, %d]", e.rcvBufSize, ss.Min, ss.Max)) + } + } + } + + bind := func() { + e.state = StateInitial + if len(e.bindAddress) == 0 { + e.bindAddress = e.id.LocalAddress + } + if err := e.Bind(tcpip.FullAddress{Addr: e.bindAddress, Port: e.id.LocalPort}); err != nil { + panic("endpoint binding failed: " + err.String()) + } + } + + switch state { + case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing: + bind() + if len(e.connectingAddress) == 0 { + e.connectingAddress = e.id.RemoteAddress + // This endpoint is accepted by netstack but not yet by + // the app. If the endpoint is IPv6 but the remote + // address is IPv4, we need to connect as IPv6 so that + // dual-stack mode can be properly activated. + if e.netProto == header.IPv6ProtocolNumber && len(e.id.RemoteAddress) != header.IPv6AddressSize { + e.connectingAddress = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + e.id.RemoteAddress + } + } + // Reset the scoreboard to reinitialize the sack information as + // we do not restore SACK information. 
+ e.scoreboard.Reset() + if err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.id.RemotePort}, false, e.workerRunning); err != tcpip.ErrConnectStarted { + panic("endpoint connecting failed: " + err.String()) + } + connectedLoading.Done() + case StateListen: + tcpip.AsyncLoading.Add(1) + go func() { + connectedLoading.Wait() + bind() + backlog := cap(e.acceptedChan) + if err := e.Listen(backlog); err != nil { + panic("endpoint listening failed: " + err.String()) + } + listenLoading.Done() + tcpip.AsyncLoading.Done() + }() + case StateConnecting, StateSynSent, StateSynRecv: + tcpip.AsyncLoading.Add(1) + go func() { + connectedLoading.Wait() + listenLoading.Wait() + bind() + if err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.id.RemotePort}); err != tcpip.ErrConnectStarted { + panic("endpoint connecting failed: " + err.String()) + } + connectingLoading.Done() + tcpip.AsyncLoading.Done() + }() + case StateBound: + tcpip.AsyncLoading.Add(1) + go func() { + connectedLoading.Wait() + listenLoading.Wait() + connectingLoading.Wait() + bind() + tcpip.AsyncLoading.Done() + }() + case StateClose: + if e.isPortReserved { + tcpip.AsyncLoading.Add(1) + go func() { + connectedLoading.Wait() + listenLoading.Wait() + connectingLoading.Wait() + bind() + e.state = StateClose + tcpip.AsyncLoading.Done() + }() + } + fallthrough + case StateError: + tcpip.DeleteDanglingEndpoint(e) + } +} + // Read reads data from the endpoint. 
func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { e.mu.RLock() diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index b3f0f6c5d..ef88dc618 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -20,7 +20,6 @@ import ( "time" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -165,104 +164,7 @@ func (e *endpoint) loadState(state EndpointState) { // afterLoad is invoked by stateify. func (e *endpoint) afterLoad() { - e.stack = stack.StackFromEnv - e.segmentQueue.setLimit(MaxUnprocessedSegments) - e.workMu.Init() - - state := e.state - switch state { - case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished: - var ss SendBufferSizeOption - if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil { - if e.sndBufSize < ss.Min || e.sndBufSize > ss.Max { - panic(fmt.Sprintf("endpoint.sndBufSize %d is outside the min and max allowed [%d, %d]", e.sndBufSize, ss.Min, ss.Max)) - } - if e.rcvBufSize < ss.Min || e.rcvBufSize > ss.Max { - panic(fmt.Sprintf("endpoint.rcvBufSize %d is outside the min and max allowed [%d, %d]", e.rcvBufSize, ss.Min, ss.Max)) - } - } - } - - bind := func() { - e.state = StateInitial - if len(e.bindAddress) == 0 { - e.bindAddress = e.id.LocalAddress - } - if err := e.Bind(tcpip.FullAddress{Addr: e.bindAddress, Port: e.id.LocalPort}); err != nil { - panic("endpoint binding failed: " + err.String()) - } - } - - switch state { - case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing: - bind() - if len(e.connectingAddress) == 0 { - // This endpoint is accepted by netstack but not yet by - // the app. If the endpoint is IPv6 but the remote - // address is IPv4, we need to connect as IPv6 so that - // dual-stack mode can be properly activated. 
- if e.netProto == header.IPv6ProtocolNumber && len(e.id.RemoteAddress) != header.IPv6AddressSize { - e.connectingAddress = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + e.id.RemoteAddress - } else { - e.connectingAddress = e.id.RemoteAddress - } - } - // Reset the scoreboard to reinitialize the sack information as - // we do not restore SACK information. - e.scoreboard.Reset() - if err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.id.RemotePort}, false, e.workerRunning); err != tcpip.ErrConnectStarted { - panic("endpoint connecting failed: " + err.String()) - } - connectedLoading.Done() - case StateListen: - tcpip.AsyncLoading.Add(1) - go func() { - connectedLoading.Wait() - bind() - backlog := cap(e.acceptedChan) - if err := e.Listen(backlog); err != nil { - panic("endpoint listening failed: " + err.String()) - } - listenLoading.Done() - tcpip.AsyncLoading.Done() - }() - case StateConnecting, StateSynSent, StateSynRecv: - tcpip.AsyncLoading.Add(1) - go func() { - connectedLoading.Wait() - listenLoading.Wait() - bind() - if err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.id.RemotePort}); err != tcpip.ErrConnectStarted { - panic("endpoint connecting failed: " + err.String()) - } - connectingLoading.Done() - tcpip.AsyncLoading.Done() - }() - case StateBound: - tcpip.AsyncLoading.Add(1) - go func() { - connectedLoading.Wait() - listenLoading.Wait() - connectingLoading.Wait() - bind() - tcpip.AsyncLoading.Done() - }() - case StateClose: - if e.isPortReserved { - tcpip.AsyncLoading.Add(1) - go func() { - connectedLoading.Wait() - listenLoading.Wait() - connectingLoading.Wait() - bind() - e.state = StateClose - tcpip.AsyncLoading.Done() - }() - } - fallthrough - case StateError: - tcpip.DeleteDanglingEndpoint(e) - } + stack.StackFromEnv.RegisterRestoredEndpoint(e) } // saveLastError is invoked by stateify. 
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 7210b3a9f..7c12a6092 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -178,6 +178,53 @@ func (e *endpoint) IPTables() (iptables.IPTables, error) { return e.stack.IPTables(), nil } +// Resume implements stack.ResumableEndpoint.Resume. +func (e *endpoint) Resume(s *stack.Stack) { + e.stack = s + + for _, m := range e.multicastMemberships { + if err := e.stack.JoinGroup(e.netProto, m.nicID, m.multicastAddr); err != nil { + panic(err) + } + } + + if e.state != stateBound && e.state != stateConnected { + return + } + + netProto := e.effectiveNetProtos[0] + // Connect() and bindLocked() both assert + // + // netProto == header.IPv6ProtocolNumber + // + // before creating a multi-entry effectiveNetProtos. + if len(e.effectiveNetProtos) > 1 { + netProto = header.IPv6ProtocolNumber + } + + var err *tcpip.Error + if e.state == stateConnected { + e.route, err = e.stack.FindRoute(e.regNICID, e.id.LocalAddress, e.id.RemoteAddress, netProto, e.multicastLoop) + if err != nil { + panic(*err) + } + } else if len(e.id.LocalAddress) != 0 { // stateBound + if e.stack.CheckLocalAddress(e.regNICID, netProto, e.id.LocalAddress) == 0 { + panic(tcpip.ErrBadLocalAddress) + } + } + + // Our saved state had a port, but we don't actually have a + // reservation. We need to remove the port from our state, but still + // pass it to the reservation machinery. + id := e.id + e.id.LocalPort = 0 + e.id, err = e.registerWithStack(e.regNICID, e.effectiveNetProtos, id) + if err != nil { + panic(*err) + } +} + // Read reads data from the endpoint. This method does not block if // there is no data pending. 
func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go index 18e786397..86db36260 100644 --- a/pkg/tcpip/transport/udp/endpoint_state.go +++ b/pkg/tcpip/transport/udp/endpoint_state.go @@ -15,9 +15,7 @@ package udp import ( - "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -64,47 +62,5 @@ func (e *endpoint) loadRcvBufSizeMax(max int) { // afterLoad is invoked by stateify. func (e *endpoint) afterLoad() { - e.stack = stack.StackFromEnv - - for _, m := range e.multicastMemberships { - if err := e.stack.JoinGroup(e.netProto, m.nicID, m.multicastAddr); err != nil { - panic(err) - } - } - - if e.state != stateBound && e.state != stateConnected { - return - } - - netProto := e.effectiveNetProtos[0] - // Connect() and bindLocked() both assert - // - // netProto == header.IPv6ProtocolNumber - // - // before creating a multi-entry effectiveNetProtos. - if len(e.effectiveNetProtos) > 1 { - netProto = header.IPv6ProtocolNumber - } - - var err *tcpip.Error - if e.state == stateConnected { - e.route, err = e.stack.FindRoute(e.regNICID, e.id.LocalAddress, e.id.RemoteAddress, netProto, e.multicastLoop) - if err != nil { - panic(*err) - } - } else if len(e.id.LocalAddress) != 0 { // stateBound - if e.stack.CheckLocalAddress(e.regNICID, netProto, e.id.LocalAddress) == 0 { - panic(tcpip.ErrBadLocalAddress) - } - } - - // Our saved state had a port, but we don't actually have a - // reservation. We need to remove the port from our state, but still - // pass it to the reservation machinery. 
- id := e.id - e.id.LocalPort = 0 - e.id, err = e.registerWithStack(e.regNICID, e.effectiveNetProtos, id) - if err != nil { - panic(*err) - } + stack.StackFromEnv.RegisterRestoredEndpoint(e) } diff --git a/runsc/boot/controller.go b/runsc/boot/controller.go index 0285f599d..72cbabd16 100644 --- a/runsc/boot/controller.go +++ b/runsc/boot/controller.go @@ -379,13 +379,10 @@ func (cm *containerManager) Restore(o *RestoreOpts, _ *struct{}) error { // Load the state. loadOpts := state.LoadOpts{Source: specFile} - if err := loadOpts.Load(k, networkStack); err != nil { + if err := loadOpts.Load(k, networkStack, time.NewCalibratedClocks()); err != nil { return err } - // Set timekeeper. - k.Timekeeper().SetClocks(time.NewCalibratedClocks()) - // Since we have a new kernel we also must make a new watchdog. dog := watchdog.New(k, watchdog.DefaultTimeout, cm.l.conf.WatchdogAction)