From 816a9211e96b3eabe7bead7fa625831fe8d2fbf8 Mon Sep 17 00:00:00 2001 From: Tamir Duberstein Date: Thu, 15 Aug 2019 11:12:28 -0700 Subject: [PATCH] netstack: move resumption logic into *_state.go 13a98df rearranged some of this code in a way that broke compilation of the netstack-only export at github.com/google/netstack because *_state.go files are not included in that export. This commit moves resumption logic back into *_state.go, fixing the compilation breakage. PiperOrigin-RevId: 263601629 --- pkg/tcpip/transport/icmp/endpoint.go | 28 ------ pkg/tcpip/transport/icmp/endpoint_state.go | 29 ++++++ pkg/tcpip/transport/raw/endpoint.go | 25 ----- pkg/tcpip/transport/raw/endpoint_state.go | 26 ++++++ pkg/tcpip/transport/tcp/endpoint.go | 101 -------------------- pkg/tcpip/transport/tcp/endpoint_state.go | 102 +++++++++++++++++++++ pkg/tcpip/transport/udp/endpoint.go | 47 ---------- pkg/tcpip/transport/udp/endpoint_state.go | 49 ++++++++++ 8 files changed, 206 insertions(+), 201 deletions(-) diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 5ca208619..2e8d5d4bf 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -136,34 +136,6 @@ func (e *endpoint) IPTables() (iptables.IPTables, error) { return e.stack.IPTables(), nil } -// Resume implements tcpip.ResumableEndpoint.Resume. -func (e *endpoint) Resume(s *stack.Stack) { - e.stack = s - - if e.state != stateBound && e.state != stateConnected { - return - } - - var err *tcpip.Error - if e.state == stateConnected { - e.route, err = e.stack.FindRoute(e.regNICID, e.bindAddr, e.id.RemoteAddress, e.netProto, false /* multicastLoop */) - if err != nil { - panic(*err) - } - - e.id.LocalAddress = e.route.LocalAddress - } else if len(e.id.LocalAddress) != 0 { // stateBound - if e.stack.CheckLocalAddress(e.regNICID, e.netProto, e.id.LocalAddress) == 0 { - panic(tcpip.ErrBadLocalAddress) - } - } - - e.id, err = e.registerWithStack(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.id) - if err != nil { - panic(*err) - } -} - // Read reads data from the endpoint. This method does not block if // there is no data pending. func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { diff --git a/pkg/tcpip/transport/icmp/endpoint_state.go b/pkg/tcpip/transport/icmp/endpoint_state.go index 43551d642..c5690174e 100644 --- a/pkg/tcpip/transport/icmp/endpoint_state.go +++ b/pkg/tcpip/transport/icmp/endpoint_state.go @@ -15,6 +15,7 @@ package icmp import ( + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -64,3 +65,31 @@ func (e *endpoint) loadRcvBufSizeMax(max int) { func (e *endpoint) afterLoad() { stack.StackFromEnv.RegisterRestoredEndpoint(e) } + +// Resume implements tcpip.ResumableEndpoint.Resume. +func (e *endpoint) Resume(s *stack.Stack) { + e.stack = s + + if e.state != stateBound && e.state != stateConnected { + return + } + + var err *tcpip.Error + if e.state == stateConnected { + e.route, err = e.stack.FindRoute(e.regNICID, e.bindAddr, e.id.RemoteAddress, e.netProto, false /* multicastLoop */) + if err != nil { + panic(*err) + } + + e.id.LocalAddress = e.route.LocalAddress + } else if len(e.id.LocalAddress) != 0 { // stateBound + if e.stack.CheckLocalAddress(e.regNICID, e.netProto, e.id.LocalAddress) == 0 { + panic(tcpip.ErrBadLocalAddress) + } + } + + e.id, err = e.registerWithStack(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.id) + if err != nil { + panic(*err) + } +} diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index cde655bb6..53c9515a4 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -174,31 +174,6 @@ func (ep *endpoint) IPTables() (iptables.IPTables, error) { return ep.stack.IPTables(), nil } -// Resume implements tcpip.ResumableEndpoint.Resume. -func (ep *endpoint) Resume(s *stack.Stack) { - ep.stack = s - - // If the endpoint is connected, re-connect. - if ep.connected { - var err *tcpip.Error - ep.route, err = ep.stack.FindRoute(ep.registeredNIC, ep.boundAddr, ep.route.RemoteAddress, ep.netProto, false) - if err != nil { - panic(*err) - } - } - - // If the endpoint is bound, re-bind. - if ep.bound { - if ep.stack.CheckLocalAddress(ep.registeredNIC, ep.netProto, ep.boundAddr) == 0 { - panic(tcpip.ErrBadLocalAddress) - } - } - - if err := ep.stack.RegisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep); err != nil { - panic(*err) - } -} - // Read implements tcpip.Endpoint.Read. func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { if !ep.associated { diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go index 44abddb2b..a3d6f4580 100644 --- a/pkg/tcpip/transport/raw/endpoint_state.go +++ b/pkg/tcpip/transport/raw/endpoint_state.go @@ -15,6 +15,7 @@ package raw import ( + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -64,3 +65,28 @@ func (ep *endpoint) loadRcvBufSizeMax(max int) { func (ep *endpoint) afterLoad() { stack.StackFromEnv.RegisterRestoredEndpoint(ep) } + +// Resume implements tcpip.ResumableEndpoint.Resume. +func (ep *endpoint) Resume(s *stack.Stack) { + ep.stack = s + + // If the endpoint is connected, re-connect. + if ep.connected { + var err *tcpip.Error + ep.route, err = ep.stack.FindRoute(ep.registeredNIC, ep.boundAddr, ep.route.RemoteAddress, ep.netProto, false) + if err != nil { + panic(*err) + } + } + + // If the endpoint is bound, re-bind. + if ep.bound { + if ep.stack.CheckLocalAddress(ep.registeredNIC, ep.netProto, ep.boundAddr) == 0 { + panic(tcpip.ErrBadLocalAddress) + } + } + + if err := ep.stack.RegisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep); err != nil { + panic(*err) + } +} diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index ed23ea0b8..e5f835c20 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -720,107 +720,6 @@ func (e *endpoint) IPTables() (iptables.IPTables, error) { return e.stack.IPTables(), nil } -// Resume implements tcpip.ResumableEndpoint.Resume. -func (e *endpoint) Resume(s *stack.Stack) { - e.stack = s - e.segmentQueue.setLimit(MaxUnprocessedSegments) - e.workMu.Init() - - state := e.state - switch state { - case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished: - var ss SendBufferSizeOption - if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil { - if e.sndBufSize < ss.Min || e.sndBufSize > ss.Max { - panic(fmt.Sprintf("endpoint.sndBufSize %d is outside the min and max allowed [%d, %d]", e.sndBufSize, ss.Min, ss.Max)) - } - if e.rcvBufSize < ss.Min || e.rcvBufSize > ss.Max { - panic(fmt.Sprintf("endpoint.rcvBufSize %d is outside the min and max allowed [%d, %d]", e.rcvBufSize, ss.Min, ss.Max)) - } - } - } - - bind := func() { - e.state = StateInitial - if len(e.bindAddress) == 0 { - e.bindAddress = e.id.LocalAddress - } - if err := e.Bind(tcpip.FullAddress{Addr: e.bindAddress, Port: e.id.LocalPort}); err != nil { - panic("endpoint binding failed: " + err.String()) - } - } - - switch state { - case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing: - bind() - if len(e.connectingAddress) == 0 { - e.connectingAddress = e.id.RemoteAddress - // This endpoint is accepted by netstack but not yet by - // the app. If the endpoint is IPv6 but the remote - // address is IPv4, we need to connect as IPv6 so that - // dual-stack mode can be properly activated. - if e.netProto == header.IPv6ProtocolNumber && len(e.id.RemoteAddress) != header.IPv6AddressSize { - e.connectingAddress = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + e.id.RemoteAddress - } - } - // Reset the scoreboard to reinitialize the sack information as - // we do not restore SACK information. - e.scoreboard.Reset() - if err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.id.RemotePort}, false, e.workerRunning); err != tcpip.ErrConnectStarted { - panic("endpoint connecting failed: " + err.String()) - } - connectedLoading.Done() - case StateListen: - tcpip.AsyncLoading.Add(1) - go func() { - connectedLoading.Wait() - bind() - backlog := cap(e.acceptedChan) - if err := e.Listen(backlog); err != nil { - panic("endpoint listening failed: " + err.String()) - } - listenLoading.Done() - tcpip.AsyncLoading.Done() - }() - case StateConnecting, StateSynSent, StateSynRecv: - tcpip.AsyncLoading.Add(1) - go func() { - connectedLoading.Wait() - listenLoading.Wait() - bind() - if err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.id.RemotePort}); err != tcpip.ErrConnectStarted { - panic("endpoint connecting failed: " + err.String()) - } - connectingLoading.Done() - tcpip.AsyncLoading.Done() - }() - case StateBound: - tcpip.AsyncLoading.Add(1) - go func() { - connectedLoading.Wait() - listenLoading.Wait() - connectingLoading.Wait() - bind() - tcpip.AsyncLoading.Done() - }() - case StateClose: - if e.isPortReserved { - tcpip.AsyncLoading.Add(1) - go func() { - connectedLoading.Wait() - listenLoading.Wait() - connectingLoading.Wait() - bind() - e.state = StateClose - tcpip.AsyncLoading.Done() - }() - } - fallthrough - case StateError: - tcpip.DeleteDanglingEndpoint(e) - } -} - // Read reads data from the endpoint. func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { e.mu.RLock() diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index ef88dc618..831389ec7 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -20,6 +20,7 @@ import ( "time" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -167,6 +168,107 @@ func (e *endpoint) afterLoad() { stack.StackFromEnv.RegisterRestoredEndpoint(e) } +// Resume implements tcpip.ResumableEndpoint.Resume. +func (e *endpoint) Resume(s *stack.Stack) { + e.stack = s + e.segmentQueue.setLimit(MaxUnprocessedSegments) + e.workMu.Init() + + state := e.state + switch state { + case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished: + var ss SendBufferSizeOption + if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil { + if e.sndBufSize < ss.Min || e.sndBufSize > ss.Max { + panic(fmt.Sprintf("endpoint.sndBufSize %d is outside the min and max allowed [%d, %d]", e.sndBufSize, ss.Min, ss.Max)) + } + if e.rcvBufSize < ss.Min || e.rcvBufSize > ss.Max { + panic(fmt.Sprintf("endpoint.rcvBufSize %d is outside the min and max allowed [%d, %d]", e.rcvBufSize, ss.Min, ss.Max)) + } + } + } + + bind := func() { + e.state = StateInitial + if len(e.bindAddress) == 0 { + e.bindAddress = e.id.LocalAddress + } + if err := e.Bind(tcpip.FullAddress{Addr: e.bindAddress, Port: e.id.LocalPort}); err != nil { + panic("endpoint binding failed: " + err.String()) + } + } + + switch state { + case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing: + bind() + if len(e.connectingAddress) == 0 { + e.connectingAddress = e.id.RemoteAddress + // This endpoint is accepted by netstack but not yet by + // the app. If the endpoint is IPv6 but the remote + // address is IPv4, we need to connect as IPv6 so that + // dual-stack mode can be properly activated. + if e.netProto == header.IPv6ProtocolNumber && len(e.id.RemoteAddress) != header.IPv6AddressSize { + e.connectingAddress = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + e.id.RemoteAddress + } + } + // Reset the scoreboard to reinitialize the sack information as + // we do not restore SACK information. + e.scoreboard.Reset() + if err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.id.RemotePort}, false, e.workerRunning); err != tcpip.ErrConnectStarted { + panic("endpoint connecting failed: " + err.String()) + } + connectedLoading.Done() + case StateListen: + tcpip.AsyncLoading.Add(1) + go func() { + connectedLoading.Wait() + bind() + backlog := cap(e.acceptedChan) + if err := e.Listen(backlog); err != nil { + panic("endpoint listening failed: " + err.String()) + } + listenLoading.Done() + tcpip.AsyncLoading.Done() + }() + case StateConnecting, StateSynSent, StateSynRecv: + tcpip.AsyncLoading.Add(1) + go func() { + connectedLoading.Wait() + listenLoading.Wait() + bind() + if err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.id.RemotePort}); err != tcpip.ErrConnectStarted { + panic("endpoint connecting failed: " + err.String()) + } + connectingLoading.Done() + tcpip.AsyncLoading.Done() + }() + case StateBound: + tcpip.AsyncLoading.Add(1) + go func() { + connectedLoading.Wait() + listenLoading.Wait() + connectingLoading.Wait() + bind() + tcpip.AsyncLoading.Done() + }() + case StateClose: + if e.isPortReserved { + tcpip.AsyncLoading.Add(1) + go func() { + connectedLoading.Wait() + listenLoading.Wait() + connectingLoading.Wait() + bind() + e.state = StateClose + tcpip.AsyncLoading.Done() + }() + } + fallthrough + case StateError: + tcpip.DeleteDanglingEndpoint(e) + } +} + // saveLastError is invoked by stateify. func (e *endpoint) saveLastError() string { if e.lastError == nil { diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index c50f5abb3..640bb8667 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -178,53 +178,6 @@ func (e *endpoint) IPTables() (iptables.IPTables, error) { return e.stack.IPTables(), nil } -// Resume implements tcpip.ResumableEndpoint.Resume. -func (e *endpoint) Resume(s *stack.Stack) { - e.stack = s - - for _, m := range e.multicastMemberships { - if err := e.stack.JoinGroup(e.netProto, m.nicID, m.multicastAddr); err != nil { - panic(err) - } - } - - if e.state != stateBound && e.state != stateConnected { - return - } - - netProto := e.effectiveNetProtos[0] - // Connect() and bindLocked() both assert - // - // netProto == header.IPv6ProtocolNumber - // - // before creating a multi-entry effectiveNetProtos. - if len(e.effectiveNetProtos) > 1 { - netProto = header.IPv6ProtocolNumber - } - - var err *tcpip.Error - if e.state == stateConnected { - e.route, err = e.stack.FindRoute(e.regNICID, e.id.LocalAddress, e.id.RemoteAddress, netProto, e.multicastLoop) - if err != nil { - panic(*err) - } - } else if len(e.id.LocalAddress) != 0 { // stateBound - if e.stack.CheckLocalAddress(e.regNICID, netProto, e.id.LocalAddress) == 0 { - panic(tcpip.ErrBadLocalAddress) - } - } - - // Our saved state had a port, but we don't actually have a - // reservation. We need to remove the port from our state, but still - // pass it to the reservation machinery. - id := e.id - e.id.LocalPort = 0 - e.id, err = e.registerWithStack(e.regNICID, e.effectiveNetProtos, id) - if err != nil { - panic(*err) - } -} - // Read reads data from the endpoint. This method does not block if // there is no data pending. func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go index 86db36260..bc821a96f 100644 --- a/pkg/tcpip/transport/udp/endpoint_state.go +++ b/pkg/tcpip/transport/udp/endpoint_state.go @@ -15,7 +15,9 @@ package udp import ( + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -64,3 +66,50 @@ func (e *endpoint) loadRcvBufSizeMax(max int) { func (e *endpoint) afterLoad() { stack.StackFromEnv.RegisterRestoredEndpoint(e) } + +// Resume implements tcpip.ResumableEndpoint.Resume. +func (e *endpoint) Resume(s *stack.Stack) { + e.stack = s + + for _, m := range e.multicastMemberships { + if err := e.stack.JoinGroup(e.netProto, m.nicID, m.multicastAddr); err != nil { + panic(err) + } + } + + if e.state != stateBound && e.state != stateConnected { + return + } + + netProto := e.effectiveNetProtos[0] + // Connect() and bindLocked() both assert + // + // netProto == header.IPv6ProtocolNumber + // + // before creating a multi-entry effectiveNetProtos. + if len(e.effectiveNetProtos) > 1 { + netProto = header.IPv6ProtocolNumber + } + + var err *tcpip.Error + if e.state == stateConnected { + e.route, err = e.stack.FindRoute(e.regNICID, e.id.LocalAddress, e.id.RemoteAddress, netProto, e.multicastLoop) + if err != nil { + panic(*err) + } + } else if len(e.id.LocalAddress) != 0 { // stateBound + if e.stack.CheckLocalAddress(e.regNICID, netProto, e.id.LocalAddress) == 0 { + panic(tcpip.ErrBadLocalAddress) + } + } + + // Our saved state had a port, but we don't actually have a + // reservation. We need to remove the port from our state, but still + // pass it to the reservation machinery. + id := e.id + e.id.LocalPort = 0 + e.id, err = e.registerWithStack(e.regNICID, e.effectiveNetProtos, id) + if err != nil { + panic(*err) + } +}