diff --git a/pkg/sentry/inet/inet.go b/pkg/sentry/inet/inet.go index 60d6dfb93..80f227dbe 100644 --- a/pkg/sentry/inet/inet.go +++ b/pkg/sentry/inet/inet.go @@ -55,6 +55,9 @@ type Stack interface { // RouteTable returns the network stack's route table. RouteTable() []Route + + // Resume restarts the network stack after restore. + Resume() } // Interface contains information about a network interface. diff --git a/pkg/sentry/inet/test_stack.go b/pkg/sentry/inet/test_stack.go index 57d5510f0..b9eed7c3a 100644 --- a/pkg/sentry/inet/test_stack.go +++ b/pkg/sentry/inet/test_stack.go @@ -92,3 +92,7 @@ func (s *TestStack) Statistics(stat interface{}, arg string) error { func (s *TestStack) RouteTable() []Route { return s.RouteList } + +// Resume implements Stack.Resume. +func (s *TestStack) Resume() { +} diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index 56a329f83..8c1f79ab5 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -496,7 +496,7 @@ func (ts *TaskSet) unregisterEpollWaiters() { } // LoadFrom returns a new Kernel loaded from args. -func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack) error { +func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack, clocks sentrytime.Clocks) error { loadStart := time.Now() k.networkStack = net @@ -540,6 +540,11 @@ func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack) error { log.Infof("Overall load took [%s]", time.Since(loadStart)) + k.Timekeeper().SetClocks(clocks) + if net != nil { + net.Resume() + } + // Ensure that all pending asynchronous work is complete: // - namedpipe opening // - inode file opening @@ -549,7 +554,7 @@ func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack) error { tcpip.AsyncLoading.Wait() - log.Infof("Overall load took [%s]", time.Since(loadStart)) + log.Infof("Overall load took [%s] after async work", time.Since(loadStart)) // Applications may size per-cpu structures based on k.applicationCores, so // it can't change across save/restore. When we are virtualizing CPU diff --git a/pkg/sentry/socket/epsocket/stack.go b/pkg/sentry/socket/epsocket/stack.go index 0cf235b31..8f1572bf4 100644 --- a/pkg/sentry/socket/epsocket/stack.go +++ b/pkg/sentry/socket/epsocket/stack.go @@ -201,3 +201,8 @@ func (s *Stack) IPTables() (iptables.IPTables, error) { func (s *Stack) FillDefaultIPTables() error { return netfilter.FillDefaultIPTables(s.Stack) } + +// Resume implements inet.Stack.Resume. +func (s *Stack) Resume() { + s.Stack.Resume() +} diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go index 99b7a1e2b..1902fe155 100644 --- a/pkg/sentry/socket/hostinet/stack.go +++ b/pkg/sentry/socket/hostinet/stack.go @@ -329,3 +329,6 @@ func (s *Stack) Statistics(stat interface{}, arg string) error { func (s *Stack) RouteTable() []inet.Route { return append([]inet.Route(nil), s.routes...) } + +// Resume implements inet.Stack.Resume. +func (s *Stack) Resume() {} diff --git a/pkg/sentry/socket/rpcinet/stack.go b/pkg/sentry/socket/rpcinet/stack.go index d18305589..5dcb6b455 100644 --- a/pkg/sentry/socket/rpcinet/stack.go +++ b/pkg/sentry/socket/rpcinet/stack.go @@ -162,3 +162,6 @@ func (s *Stack) Statistics(stat interface{}, arg string) error { func (s *Stack) RouteTable() []inet.Route { return append([]inet.Route(nil), s.routes...) } + +// Resume implements inet.Stack.Resume. +func (s *Stack) Resume() {} diff --git a/pkg/sentry/state/BUILD b/pkg/sentry/state/BUILD index f297ef3b7..88765f4d6 100644 --- a/pkg/sentry/state/BUILD +++ b/pkg/sentry/state/BUILD @@ -16,6 +16,7 @@ go_library( "//pkg/log", "//pkg/sentry/inet", "//pkg/sentry/kernel", + "//pkg/sentry/time", "//pkg/sentry/watchdog", "//pkg/state/statefile", "//pkg/syserror", diff --git a/pkg/sentry/state/state.go b/pkg/sentry/state/state.go index 026549756..9eb626b76 100644 --- a/pkg/sentry/state/state.go +++ b/pkg/sentry/state/state.go @@ -22,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/time" "gvisor.dev/gvisor/pkg/sentry/watchdog" "gvisor.dev/gvisor/pkg/state/statefile" "gvisor.dev/gvisor/pkg/syserror" @@ -104,7 +105,7 @@ type LoadOpts struct { } // Load loads the given kernel, setting the provided platform and stack. -func (opts LoadOpts) Load(k *kernel.Kernel, n inet.Stack) error { +func (opts LoadOpts) Load(k *kernel.Kernel, n inet.Stack, clocks time.Clocks) error { // Open the file. r, m, err := statefile.NewReader(opts.Source, opts.Key) if err != nil { @@ -114,5 +115,5 @@ func (opts LoadOpts) Load(k *kernel.Kernel, n inet.Stack) error { previousMetadata = m // Restore the Kernel object graph. - return k.LoadFrom(r, n) + return k.LoadFrom(r, n, clocks) } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 78beb0dae..d45e547ee 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -334,6 +334,15 @@ type TCPEndpointState struct { Sender TCPSenderState } +// ResumableEndpoint is an endpoint that needs to be resumed after restore. +type ResumableEndpoint interface { + // Resume resumes an endpoint after restore. This can be used to restart + // background workers such as protocol goroutines. This must be called after + // all indirect dependencies of the endpoint has been restored, which + // generally implies at the end of the restore process. + Resume(*Stack) +} + // Stack is a networking stack, with all supported protocols, NICs, and route // table. type Stack struct { @@ -376,6 +385,10 @@ type Stack struct { // tables are the iptables packet filtering and manipulation rules. tables iptables.IPTables + + // resumableEndpoints is a list of endpoints that need to be resumed if the + // stack is being restored. + resumableEndpoints []ResumableEndpoint } // Options contains optional Stack configuration. @@ -1091,6 +1104,28 @@ func (s *Stack) UnregisterRawTransportEndpoint(nicID tcpip.NICID, netProto tcpip } } +// RegisterRestoredEndpoint records e as an endpoint that has been restored on +// this stack. +func (s *Stack) RegisterRestoredEndpoint(e ResumableEndpoint) { + s.mu.Lock() + s.resumableEndpoints = append(s.resumableEndpoints, e) + s.mu.Unlock() +} + +// Resume restarts the stack after a restore. This must be called after the +// entire system has been restored. +func (s *Stack) Resume() { + // ResumableEndpoint.Resume() may call other methods on s, so we can't hold + // s.mu while resuming the endpoints. + s.mu.Lock() + eps := s.resumableEndpoints + s.resumableEndpoints = nil + s.mu.Unlock() + for _, e := range eps { + e.Resume(s) + } +} + // NetworkProtocolInstance returns the protocol instance in the stack for the // specified network protocol. This method is public for protocol implementers // and tests to use. diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 8652d7814..eee3144cd 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -205,6 +205,9 @@ func (f *fakeTransportEndpoint) IPTables() (iptables.IPTables, error) { return iptables.IPTables{}, nil } +func (f *fakeTransportEndpoint) Resume(*stack.Stack) { +} + type fakeTransportGoodOption bool type fakeTransportBadOption bool diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 0df9f6d93..119712d2f 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -1082,11 +1082,13 @@ type ProtocolAddress struct { AddressWithPrefix AddressWithPrefix } -// danglingEndpointsMu protects access to danglingEndpoints. -var danglingEndpointsMu sync.Mutex +var ( + // danglingEndpointsMu protects access to danglingEndpoints. + danglingEndpointsMu sync.Mutex -// danglingEndpoints tracks all dangling endpoints no longer owned by the app. -var danglingEndpoints = make(map[Endpoint]struct{}) + // danglingEndpoints tracks all dangling endpoints no longer owned by the app. + danglingEndpoints = make(map[Endpoint]struct{}) +) // GetDanglingEndpoints returns all dangling endpoints. func GetDanglingEndpoints() []Endpoint { diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index a4527c041..9a4306011 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -136,6 +136,34 @@ func (e *endpoint) IPTables() (iptables.IPTables, error) { return e.stack.IPTables(), nil } +// Resume implements tcpip.ResumableEndpoint.Resume. +func (e *endpoint) Resume(s *stack.Stack) { + e.stack = s + + if e.state != stateBound && e.state != stateConnected { + return + } + + var err *tcpip.Error + if e.state == stateConnected { + e.route, err = e.stack.FindRoute(e.regNICID, e.bindAddr, e.id.RemoteAddress, e.netProto, false /* multicastLoop */) + if err != nil { + panic(*err) + } + + e.id.LocalAddress = e.route.LocalAddress + } else if len(e.id.LocalAddress) != 0 { // stateBound + if e.stack.CheckLocalAddress(e.regNICID, e.netProto, e.id.LocalAddress) == 0 { + panic(tcpip.ErrBadLocalAddress) + } + } + + e.id, err = e.registerWithStack(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.id) + if err != nil { + panic(*err) + } +} + // Read reads data from the endpoint. This method does not block if // there is no data pending. func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { diff --git a/pkg/tcpip/transport/icmp/endpoint_state.go b/pkg/tcpip/transport/icmp/endpoint_state.go index 99b8c4093..43551d642 100644 --- a/pkg/tcpip/transport/icmp/endpoint_state.go +++ b/pkg/tcpip/transport/icmp/endpoint_state.go @@ -15,7 +15,6 @@ package icmp import ( - "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -63,28 +62,5 @@ func (e *endpoint) loadRcvBufSizeMax(max int) { // afterLoad is invoked by stateify. func (e *endpoint) afterLoad() { - e.stack = stack.StackFromEnv - - if e.state != stateBound && e.state != stateConnected { - return - } - - var err *tcpip.Error - if e.state == stateConnected { - e.route, err = e.stack.FindRoute(e.regNICID, e.bindAddr, e.id.RemoteAddress, e.netProto, false /* multicastLoop */) - if err != nil { - panic(*err) - } - - e.id.LocalAddress = e.route.LocalAddress - } else if len(e.id.LocalAddress) != 0 { // stateBound - if e.stack.CheckLocalAddress(e.regNICID, e.netProto, e.id.LocalAddress) == 0 { - panic(tcpip.ErrBadLocalAddress) - } - } - - e.id, err = e.registerWithStack(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.id) - if err != nil { - panic(*err) - } + stack.StackFromEnv.RegisterRestoredEndpoint(e) } diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index b4be855c1..eab3dcbd2 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -174,6 +174,31 @@ func (ep *endpoint) IPTables() (iptables.IPTables, error) { return ep.stack.IPTables(), nil } +// Resume implements tcpip.ResumableEndpoint.Resume. +func (ep *endpoint) Resume(s *stack.Stack) { + ep.stack = s + + // If the endpoint is connected, re-connect. + if ep.connected { + var err *tcpip.Error + ep.route, err = ep.stack.FindRoute(ep.registeredNIC, ep.boundAddr, ep.route.RemoteAddress, ep.netProto, false) + if err != nil { + panic(*err) + } + } + + // If the endpoint is bound, re-bind. + if ep.bound { + if ep.stack.CheckLocalAddress(ep.registeredNIC, ep.netProto, ep.boundAddr) == 0 { + panic(tcpip.ErrBadLocalAddress) + } + } + + if err := ep.stack.RegisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep); err != nil { + panic(*err) + } +} + // Read implements tcpip.Endpoint.Read. func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { if !ep.associated { diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go index cb5534d90..44abddb2b 100644 --- a/pkg/tcpip/transport/raw/endpoint_state.go +++ b/pkg/tcpip/transport/raw/endpoint_state.go @@ -15,7 +15,6 @@ package raw import ( - "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -63,26 +62,5 @@ func (ep *endpoint) loadRcvBufSizeMax(max int) { // afterLoad is invoked by stateify. func (ep *endpoint) afterLoad() { - // StackFromEnv is a stack used specifically for save/restore. - ep.stack = stack.StackFromEnv - - // If the endpoint is connected, re-connect via the save/restore stack. - if ep.connected { - var err *tcpip.Error - ep.route, err = ep.stack.FindRoute(ep.registeredNIC, ep.boundAddr, ep.route.RemoteAddress, ep.netProto, false) - if err != nil { - panic(*err) - } - } - - // If the endpoint is bound, re-bind via the save/restore stack. - if ep.bound { - if ep.stack.CheckLocalAddress(ep.registeredNIC, ep.netProto, ep.boundAddr) == 0 { - panic(tcpip.ErrBadLocalAddress) - } - } - - if err := ep.stack.RegisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep); err != nil { - panic(*err) - } + stack.StackFromEnv.RegisterRestoredEndpoint(ep) } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 0e16877e7..e67169111 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -720,6 +720,107 @@ func (e *endpoint) IPTables() (iptables.IPTables, error) { return e.stack.IPTables(), nil } +// Resume implements tcpip.ResumableEndpoint.Resume. +func (e *endpoint) Resume(s *stack.Stack) { + e.stack = s + e.segmentQueue.setLimit(MaxUnprocessedSegments) + e.workMu.Init() + + state := e.state + switch state { + case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished: + var ss SendBufferSizeOption + if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil { + if e.sndBufSize < ss.Min || e.sndBufSize > ss.Max { + panic(fmt.Sprintf("endpoint.sndBufSize %d is outside the min and max allowed [%d, %d]", e.sndBufSize, ss.Min, ss.Max)) + } + if e.rcvBufSize < ss.Min || e.rcvBufSize > ss.Max { + panic(fmt.Sprintf("endpoint.rcvBufSize %d is outside the min and max allowed [%d, %d]", e.rcvBufSize, ss.Min, ss.Max)) + } + } + } + + bind := func() { + e.state = StateInitial + if len(e.bindAddress) == 0 { + e.bindAddress = e.id.LocalAddress + } + if err := e.Bind(tcpip.FullAddress{Addr: e.bindAddress, Port: e.id.LocalPort}); err != nil { + panic("endpoint binding failed: " + err.String()) + } + } + + switch state { + case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing: + bind() + if len(e.connectingAddress) == 0 { + e.connectingAddress = e.id.RemoteAddress + // This endpoint is accepted by netstack but not yet by + // the app. If the endpoint is IPv6 but the remote + // address is IPv4, we need to connect as IPv6 so that + // dual-stack mode can be properly activated. + if e.netProto == header.IPv6ProtocolNumber && len(e.id.RemoteAddress) != header.IPv6AddressSize { + e.connectingAddress = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + e.id.RemoteAddress + } + } + // Reset the scoreboard to reinitialize the sack information as + // we do not restore SACK information. + e.scoreboard.Reset() + if err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.id.RemotePort}, false, e.workerRunning); err != tcpip.ErrConnectStarted { + panic("endpoint connecting failed: " + err.String()) + } + connectedLoading.Done() + case StateListen: + tcpip.AsyncLoading.Add(1) + go func() { + connectedLoading.Wait() + bind() + backlog := cap(e.acceptedChan) + if err := e.Listen(backlog); err != nil { + panic("endpoint listening failed: " + err.String()) + } + listenLoading.Done() + tcpip.AsyncLoading.Done() + }() + case StateConnecting, StateSynSent, StateSynRecv: + tcpip.AsyncLoading.Add(1) + go func() { + connectedLoading.Wait() + listenLoading.Wait() + bind() + if err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.id.RemotePort}); err != tcpip.ErrConnectStarted { + panic("endpoint connecting failed: " + err.String()) + } + connectingLoading.Done() + tcpip.AsyncLoading.Done() + }() + case StateBound: + tcpip.AsyncLoading.Add(1) + go func() { + connectedLoading.Wait() + listenLoading.Wait() + connectingLoading.Wait() + bind() + tcpip.AsyncLoading.Done() + }() + case StateClose: + if e.isPortReserved { + tcpip.AsyncLoading.Add(1) + go func() { + connectedLoading.Wait() + listenLoading.Wait() + connectingLoading.Wait() + bind() + e.state = StateClose + tcpip.AsyncLoading.Done() + }() + } + fallthrough + case StateError: + tcpip.DeleteDanglingEndpoint(e) + } +} + // Read reads data from the endpoint. func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { e.mu.RLock() diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index b3f0f6c5d..ef88dc618 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -20,7 +20,6 @@ import ( "time" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -165,104 +164,7 @@ func (e *endpoint) loadState(state EndpointState) { // afterLoad is invoked by stateify. func (e *endpoint) afterLoad() { - e.stack = stack.StackFromEnv - e.segmentQueue.setLimit(MaxUnprocessedSegments) - e.workMu.Init() - - state := e.state - switch state { - case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished: - var ss SendBufferSizeOption - if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil { - if e.sndBufSize < ss.Min || e.sndBufSize > ss.Max { - panic(fmt.Sprintf("endpoint.sndBufSize %d is outside the min and max allowed [%d, %d]", e.sndBufSize, ss.Min, ss.Max)) - } - if e.rcvBufSize < ss.Min || e.rcvBufSize > ss.Max { - panic(fmt.Sprintf("endpoint.rcvBufSize %d is outside the min and max allowed [%d, %d]", e.rcvBufSize, ss.Min, ss.Max)) - } - } - } - - bind := func() { - e.state = StateInitial - if len(e.bindAddress) == 0 { - e.bindAddress = e.id.LocalAddress - } - if err := e.Bind(tcpip.FullAddress{Addr: e.bindAddress, Port: e.id.LocalPort}); err != nil { - panic("endpoint binding failed: " + err.String()) - } - } - - switch state { - case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing: - bind() - if len(e.connectingAddress) == 0 { - // This endpoint is accepted by netstack but not yet by - // the app. If the endpoint is IPv6 but the remote - // address is IPv4, we need to connect as IPv6 so that - // dual-stack mode can be properly activated. - if e.netProto == header.IPv6ProtocolNumber && len(e.id.RemoteAddress) != header.IPv6AddressSize { - e.connectingAddress = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + e.id.RemoteAddress - } else { - e.connectingAddress = e.id.RemoteAddress - } - } - // Reset the scoreboard to reinitialize the sack information as - // we do not restore SACK information. - e.scoreboard.Reset() - if err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.id.RemotePort}, false, e.workerRunning); err != tcpip.ErrConnectStarted { - panic("endpoint connecting failed: " + err.String()) - } - connectedLoading.Done() - case StateListen: - tcpip.AsyncLoading.Add(1) - go func() { - connectedLoading.Wait() - bind() - backlog := cap(e.acceptedChan) - if err := e.Listen(backlog); err != nil { - panic("endpoint listening failed: " + err.String()) - } - listenLoading.Done() - tcpip.AsyncLoading.Done() - }() - case StateConnecting, StateSynSent, StateSynRecv: - tcpip.AsyncLoading.Add(1) - go func() { - connectedLoading.Wait() - listenLoading.Wait() - bind() - if err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.id.RemotePort}); err != tcpip.ErrConnectStarted { - panic("endpoint connecting failed: " + err.String()) - } - connectingLoading.Done() - tcpip.AsyncLoading.Done() - }() - case StateBound: - tcpip.AsyncLoading.Add(1) - go func() { - connectedLoading.Wait() - listenLoading.Wait() - connectingLoading.Wait() - bind() - tcpip.AsyncLoading.Done() - }() - case StateClose: - if e.isPortReserved { - tcpip.AsyncLoading.Add(1) - go func() { - connectedLoading.Wait() - listenLoading.Wait() - connectingLoading.Wait() - bind() - e.state = StateClose - tcpip.AsyncLoading.Done() - }() - } - fallthrough - case StateError: - tcpip.DeleteDanglingEndpoint(e) - } + stack.StackFromEnv.RegisterRestoredEndpoint(e) } // saveLastError is invoked by stateify. diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 7210b3a9f..7c12a6092 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -178,6 +178,53 @@ func (e *endpoint) IPTables() (iptables.IPTables, error) { return e.stack.IPTables(), nil } +// Resume implements tcpip.ResumableEndpoint.Resume. +func (e *endpoint) Resume(s *stack.Stack) { + e.stack = s + + for _, m := range e.multicastMemberships { + if err := e.stack.JoinGroup(e.netProto, m.nicID, m.multicastAddr); err != nil { + panic(err) + } + } + + if e.state != stateBound && e.state != stateConnected { + return + } + + netProto := e.effectiveNetProtos[0] + // Connect() and bindLocked() both assert + // + // netProto == header.IPv6ProtocolNumber + // + // before creating a multi-entry effectiveNetProtos. + if len(e.effectiveNetProtos) > 1 { + netProto = header.IPv6ProtocolNumber + } + + var err *tcpip.Error + if e.state == stateConnected { + e.route, err = e.stack.FindRoute(e.regNICID, e.id.LocalAddress, e.id.RemoteAddress, netProto, e.multicastLoop) + if err != nil { + panic(*err) + } + } else if len(e.id.LocalAddress) != 0 { // stateBound + if e.stack.CheckLocalAddress(e.regNICID, netProto, e.id.LocalAddress) == 0 { + panic(tcpip.ErrBadLocalAddress) + } + } + + // Our saved state had a port, but we don't actually have a + // reservation. We need to remove the port from our state, but still + // pass it to the reservation machinery. + id := e.id + e.id.LocalPort = 0 + e.id, err = e.registerWithStack(e.regNICID, e.effectiveNetProtos, id) + if err != nil { + panic(*err) + } +} + // Read reads data from the endpoint. This method does not block if // there is no data pending. func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go index 18e786397..86db36260 100644 --- a/pkg/tcpip/transport/udp/endpoint_state.go +++ b/pkg/tcpip/transport/udp/endpoint_state.go @@ -15,9 +15,7 @@ package udp import ( - "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -64,47 +62,5 @@ func (e *endpoint) loadRcvBufSizeMax(max int) { // afterLoad is invoked by stateify. func (e *endpoint) afterLoad() { - e.stack = stack.StackFromEnv - - for _, m := range e.multicastMemberships { - if err := e.stack.JoinGroup(e.netProto, m.nicID, m.multicastAddr); err != nil { - panic(err) - } - } - - if e.state != stateBound && e.state != stateConnected { - return - } - - netProto := e.effectiveNetProtos[0] - // Connect() and bindLocked() both assert - // - // netProto == header.IPv6ProtocolNumber - // - // before creating a multi-entry effectiveNetProtos. - if len(e.effectiveNetProtos) > 1 { - netProto = header.IPv6ProtocolNumber - } - - var err *tcpip.Error - if e.state == stateConnected { - e.route, err = e.stack.FindRoute(e.regNICID, e.id.LocalAddress, e.id.RemoteAddress, netProto, e.multicastLoop) - if err != nil { - panic(*err) - } - } else if len(e.id.LocalAddress) != 0 { // stateBound - if e.stack.CheckLocalAddress(e.regNICID, netProto, e.id.LocalAddress) == 0 { - panic(tcpip.ErrBadLocalAddress) - } - } - - // Our saved state had a port, but we don't actually have a - // reservation. We need to remove the port from our state, but still - // pass it to the reservation machinery. - id := e.id - e.id.LocalPort = 0 - e.id, err = e.registerWithStack(e.regNICID, e.effectiveNetProtos, id) - if err != nil { - panic(*err) - } + stack.StackFromEnv.RegisterRestoredEndpoint(e) } diff --git a/runsc/boot/controller.go b/runsc/boot/controller.go index 0285f599d..72cbabd16 100644 --- a/runsc/boot/controller.go +++ b/runsc/boot/controller.go @@ -379,13 +379,10 @@ func (cm *containerManager) Restore(o *RestoreOpts, _ *struct{}) error { // Load the state. loadOpts := state.LoadOpts{Source: specFile} - if err := loadOpts.Load(k, networkStack); err != nil { + if err := loadOpts.Load(k, networkStack, time.NewCalibratedClocks()); err != nil { return err } - // Set timekeeper. - k.Timekeeper().SetClocks(time.NewCalibratedClocks()) - // Since we have a new kernel we also must make a new watchdog. dog := watchdog.New(k, watchdog.DefaultTimeout, cm.l.conf.WatchdogAction)