netstack: move resumption logic into *_state.go
13a98df
rearranged some of this code in a way that broke compilation of
the netstack-only export at github.com/google/netstack because
*_state.go files are not included in that export.
This commit moves resumption logic back into *_state.go, fixing the
compilation breakage.
PiperOrigin-RevId: 263601629
This commit is contained in:
parent
d81d94ac4c
commit
816a9211e9
|
@ -136,34 +136,6 @@ func (e *endpoint) IPTables() (iptables.IPTables, error) {
|
||||||
return e.stack.IPTables(), nil
|
return e.stack.IPTables(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Resume implements tcpip.ResumableEndpoint.Resume.
|
|
||||||
func (e *endpoint) Resume(s *stack.Stack) {
|
|
||||||
e.stack = s
|
|
||||||
|
|
||||||
if e.state != stateBound && e.state != stateConnected {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var err *tcpip.Error
|
|
||||||
if e.state == stateConnected {
|
|
||||||
e.route, err = e.stack.FindRoute(e.regNICID, e.bindAddr, e.id.RemoteAddress, e.netProto, false /* multicastLoop */)
|
|
||||||
if err != nil {
|
|
||||||
panic(*err)
|
|
||||||
}
|
|
||||||
|
|
||||||
e.id.LocalAddress = e.route.LocalAddress
|
|
||||||
} else if len(e.id.LocalAddress) != 0 { // stateBound
|
|
||||||
if e.stack.CheckLocalAddress(e.regNICID, e.netProto, e.id.LocalAddress) == 0 {
|
|
||||||
panic(tcpip.ErrBadLocalAddress)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
e.id, err = e.registerWithStack(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.id)
|
|
||||||
if err != nil {
|
|
||||||
panic(*err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Read reads data from the endpoint. This method does not block if
|
// Read reads data from the endpoint. This method does not block if
|
||||||
// there is no data pending.
|
// there is no data pending.
|
||||||
func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
|
func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
package icmp
|
package icmp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/buffer"
|
"gvisor.dev/gvisor/pkg/tcpip/buffer"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||||
)
|
)
|
||||||
|
@ -64,3 +65,31 @@ func (e *endpoint) loadRcvBufSizeMax(max int) {
|
||||||
func (e *endpoint) afterLoad() {
|
func (e *endpoint) afterLoad() {
|
||||||
stack.StackFromEnv.RegisterRestoredEndpoint(e)
|
stack.StackFromEnv.RegisterRestoredEndpoint(e)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Resume implements tcpip.ResumableEndpoint.Resume.
|
||||||
|
func (e *endpoint) Resume(s *stack.Stack) {
|
||||||
|
e.stack = s
|
||||||
|
|
||||||
|
if e.state != stateBound && e.state != stateConnected {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var err *tcpip.Error
|
||||||
|
if e.state == stateConnected {
|
||||||
|
e.route, err = e.stack.FindRoute(e.regNICID, e.bindAddr, e.id.RemoteAddress, e.netProto, false /* multicastLoop */)
|
||||||
|
if err != nil {
|
||||||
|
panic(*err)
|
||||||
|
}
|
||||||
|
|
||||||
|
e.id.LocalAddress = e.route.LocalAddress
|
||||||
|
} else if len(e.id.LocalAddress) != 0 { // stateBound
|
||||||
|
if e.stack.CheckLocalAddress(e.regNICID, e.netProto, e.id.LocalAddress) == 0 {
|
||||||
|
panic(tcpip.ErrBadLocalAddress)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
e.id, err = e.registerWithStack(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.id)
|
||||||
|
if err != nil {
|
||||||
|
panic(*err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -174,31 +174,6 @@ func (ep *endpoint) IPTables() (iptables.IPTables, error) {
|
||||||
return ep.stack.IPTables(), nil
|
return ep.stack.IPTables(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Resume implements tcpip.ResumableEndpoint.Resume.
|
|
||||||
func (ep *endpoint) Resume(s *stack.Stack) {
|
|
||||||
ep.stack = s
|
|
||||||
|
|
||||||
// If the endpoint is connected, re-connect.
|
|
||||||
if ep.connected {
|
|
||||||
var err *tcpip.Error
|
|
||||||
ep.route, err = ep.stack.FindRoute(ep.registeredNIC, ep.boundAddr, ep.route.RemoteAddress, ep.netProto, false)
|
|
||||||
if err != nil {
|
|
||||||
panic(*err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the endpoint is bound, re-bind.
|
|
||||||
if ep.bound {
|
|
||||||
if ep.stack.CheckLocalAddress(ep.registeredNIC, ep.netProto, ep.boundAddr) == 0 {
|
|
||||||
panic(tcpip.ErrBadLocalAddress)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := ep.stack.RegisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep); err != nil {
|
|
||||||
panic(*err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Read implements tcpip.Endpoint.Read.
|
// Read implements tcpip.Endpoint.Read.
|
||||||
func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
|
func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
|
||||||
if !ep.associated {
|
if !ep.associated {
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
package raw
|
package raw
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/buffer"
|
"gvisor.dev/gvisor/pkg/tcpip/buffer"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||||
)
|
)
|
||||||
|
@ -64,3 +65,28 @@ func (ep *endpoint) loadRcvBufSizeMax(max int) {
|
||||||
func (ep *endpoint) afterLoad() {
|
func (ep *endpoint) afterLoad() {
|
||||||
stack.StackFromEnv.RegisterRestoredEndpoint(ep)
|
stack.StackFromEnv.RegisterRestoredEndpoint(ep)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Resume implements tcpip.ResumableEndpoint.Resume.
|
||||||
|
func (ep *endpoint) Resume(s *stack.Stack) {
|
||||||
|
ep.stack = s
|
||||||
|
|
||||||
|
// If the endpoint is connected, re-connect.
|
||||||
|
if ep.connected {
|
||||||
|
var err *tcpip.Error
|
||||||
|
ep.route, err = ep.stack.FindRoute(ep.registeredNIC, ep.boundAddr, ep.route.RemoteAddress, ep.netProto, false)
|
||||||
|
if err != nil {
|
||||||
|
panic(*err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the endpoint is bound, re-bind.
|
||||||
|
if ep.bound {
|
||||||
|
if ep.stack.CheckLocalAddress(ep.registeredNIC, ep.netProto, ep.boundAddr) == 0 {
|
||||||
|
panic(tcpip.ErrBadLocalAddress)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := ep.stack.RegisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep); err != nil {
|
||||||
|
panic(*err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -720,107 +720,6 @@ func (e *endpoint) IPTables() (iptables.IPTables, error) {
|
||||||
return e.stack.IPTables(), nil
|
return e.stack.IPTables(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Resume implements tcpip.ResumableEndpoint.Resume.
|
|
||||||
func (e *endpoint) Resume(s *stack.Stack) {
|
|
||||||
e.stack = s
|
|
||||||
e.segmentQueue.setLimit(MaxUnprocessedSegments)
|
|
||||||
e.workMu.Init()
|
|
||||||
|
|
||||||
state := e.state
|
|
||||||
switch state {
|
|
||||||
case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished:
|
|
||||||
var ss SendBufferSizeOption
|
|
||||||
if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
|
|
||||||
if e.sndBufSize < ss.Min || e.sndBufSize > ss.Max {
|
|
||||||
panic(fmt.Sprintf("endpoint.sndBufSize %d is outside the min and max allowed [%d, %d]", e.sndBufSize, ss.Min, ss.Max))
|
|
||||||
}
|
|
||||||
if e.rcvBufSize < ss.Min || e.rcvBufSize > ss.Max {
|
|
||||||
panic(fmt.Sprintf("endpoint.rcvBufSize %d is outside the min and max allowed [%d, %d]", e.rcvBufSize, ss.Min, ss.Max))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
bind := func() {
|
|
||||||
e.state = StateInitial
|
|
||||||
if len(e.bindAddress) == 0 {
|
|
||||||
e.bindAddress = e.id.LocalAddress
|
|
||||||
}
|
|
||||||
if err := e.Bind(tcpip.FullAddress{Addr: e.bindAddress, Port: e.id.LocalPort}); err != nil {
|
|
||||||
panic("endpoint binding failed: " + err.String())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
switch state {
|
|
||||||
case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing:
|
|
||||||
bind()
|
|
||||||
if len(e.connectingAddress) == 0 {
|
|
||||||
e.connectingAddress = e.id.RemoteAddress
|
|
||||||
// This endpoint is accepted by netstack but not yet by
|
|
||||||
// the app. If the endpoint is IPv6 but the remote
|
|
||||||
// address is IPv4, we need to connect as IPv6 so that
|
|
||||||
// dual-stack mode can be properly activated.
|
|
||||||
if e.netProto == header.IPv6ProtocolNumber && len(e.id.RemoteAddress) != header.IPv6AddressSize {
|
|
||||||
e.connectingAddress = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + e.id.RemoteAddress
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Reset the scoreboard to reinitialize the sack information as
|
|
||||||
// we do not restore SACK information.
|
|
||||||
e.scoreboard.Reset()
|
|
||||||
if err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.id.RemotePort}, false, e.workerRunning); err != tcpip.ErrConnectStarted {
|
|
||||||
panic("endpoint connecting failed: " + err.String())
|
|
||||||
}
|
|
||||||
connectedLoading.Done()
|
|
||||||
case StateListen:
|
|
||||||
tcpip.AsyncLoading.Add(1)
|
|
||||||
go func() {
|
|
||||||
connectedLoading.Wait()
|
|
||||||
bind()
|
|
||||||
backlog := cap(e.acceptedChan)
|
|
||||||
if err := e.Listen(backlog); err != nil {
|
|
||||||
panic("endpoint listening failed: " + err.String())
|
|
||||||
}
|
|
||||||
listenLoading.Done()
|
|
||||||
tcpip.AsyncLoading.Done()
|
|
||||||
}()
|
|
||||||
case StateConnecting, StateSynSent, StateSynRecv:
|
|
||||||
tcpip.AsyncLoading.Add(1)
|
|
||||||
go func() {
|
|
||||||
connectedLoading.Wait()
|
|
||||||
listenLoading.Wait()
|
|
||||||
bind()
|
|
||||||
if err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.id.RemotePort}); err != tcpip.ErrConnectStarted {
|
|
||||||
panic("endpoint connecting failed: " + err.String())
|
|
||||||
}
|
|
||||||
connectingLoading.Done()
|
|
||||||
tcpip.AsyncLoading.Done()
|
|
||||||
}()
|
|
||||||
case StateBound:
|
|
||||||
tcpip.AsyncLoading.Add(1)
|
|
||||||
go func() {
|
|
||||||
connectedLoading.Wait()
|
|
||||||
listenLoading.Wait()
|
|
||||||
connectingLoading.Wait()
|
|
||||||
bind()
|
|
||||||
tcpip.AsyncLoading.Done()
|
|
||||||
}()
|
|
||||||
case StateClose:
|
|
||||||
if e.isPortReserved {
|
|
||||||
tcpip.AsyncLoading.Add(1)
|
|
||||||
go func() {
|
|
||||||
connectedLoading.Wait()
|
|
||||||
listenLoading.Wait()
|
|
||||||
connectingLoading.Wait()
|
|
||||||
bind()
|
|
||||||
e.state = StateClose
|
|
||||||
tcpip.AsyncLoading.Done()
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
fallthrough
|
|
||||||
case StateError:
|
|
||||||
tcpip.DeleteDanglingEndpoint(e)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Read reads data from the endpoint.
|
// Read reads data from the endpoint.
|
||||||
func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
|
func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
|
||||||
e.mu.RLock()
|
e.mu.RLock()
|
||||||
|
|
|
@ -20,6 +20,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"gvisor.dev/gvisor/pkg/tcpip"
|
"gvisor.dev/gvisor/pkg/tcpip"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -167,6 +168,107 @@ func (e *endpoint) afterLoad() {
|
||||||
stack.StackFromEnv.RegisterRestoredEndpoint(e)
|
stack.StackFromEnv.RegisterRestoredEndpoint(e)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Resume implements tcpip.ResumableEndpoint.Resume.
|
||||||
|
func (e *endpoint) Resume(s *stack.Stack) {
|
||||||
|
e.stack = s
|
||||||
|
e.segmentQueue.setLimit(MaxUnprocessedSegments)
|
||||||
|
e.workMu.Init()
|
||||||
|
|
||||||
|
state := e.state
|
||||||
|
switch state {
|
||||||
|
case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished:
|
||||||
|
var ss SendBufferSizeOption
|
||||||
|
if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
|
||||||
|
if e.sndBufSize < ss.Min || e.sndBufSize > ss.Max {
|
||||||
|
panic(fmt.Sprintf("endpoint.sndBufSize %d is outside the min and max allowed [%d, %d]", e.sndBufSize, ss.Min, ss.Max))
|
||||||
|
}
|
||||||
|
if e.rcvBufSize < ss.Min || e.rcvBufSize > ss.Max {
|
||||||
|
panic(fmt.Sprintf("endpoint.rcvBufSize %d is outside the min and max allowed [%d, %d]", e.rcvBufSize, ss.Min, ss.Max))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bind := func() {
|
||||||
|
e.state = StateInitial
|
||||||
|
if len(e.bindAddress) == 0 {
|
||||||
|
e.bindAddress = e.id.LocalAddress
|
||||||
|
}
|
||||||
|
if err := e.Bind(tcpip.FullAddress{Addr: e.bindAddress, Port: e.id.LocalPort}); err != nil {
|
||||||
|
panic("endpoint binding failed: " + err.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch state {
|
||||||
|
case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing:
|
||||||
|
bind()
|
||||||
|
if len(e.connectingAddress) == 0 {
|
||||||
|
e.connectingAddress = e.id.RemoteAddress
|
||||||
|
// This endpoint is accepted by netstack but not yet by
|
||||||
|
// the app. If the endpoint is IPv6 but the remote
|
||||||
|
// address is IPv4, we need to connect as IPv6 so that
|
||||||
|
// dual-stack mode can be properly activated.
|
||||||
|
if e.netProto == header.IPv6ProtocolNumber && len(e.id.RemoteAddress) != header.IPv6AddressSize {
|
||||||
|
e.connectingAddress = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + e.id.RemoteAddress
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Reset the scoreboard to reinitialize the sack information as
|
||||||
|
// we do not restore SACK information.
|
||||||
|
e.scoreboard.Reset()
|
||||||
|
if err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.id.RemotePort}, false, e.workerRunning); err != tcpip.ErrConnectStarted {
|
||||||
|
panic("endpoint connecting failed: " + err.String())
|
||||||
|
}
|
||||||
|
connectedLoading.Done()
|
||||||
|
case StateListen:
|
||||||
|
tcpip.AsyncLoading.Add(1)
|
||||||
|
go func() {
|
||||||
|
connectedLoading.Wait()
|
||||||
|
bind()
|
||||||
|
backlog := cap(e.acceptedChan)
|
||||||
|
if err := e.Listen(backlog); err != nil {
|
||||||
|
panic("endpoint listening failed: " + err.String())
|
||||||
|
}
|
||||||
|
listenLoading.Done()
|
||||||
|
tcpip.AsyncLoading.Done()
|
||||||
|
}()
|
||||||
|
case StateConnecting, StateSynSent, StateSynRecv:
|
||||||
|
tcpip.AsyncLoading.Add(1)
|
||||||
|
go func() {
|
||||||
|
connectedLoading.Wait()
|
||||||
|
listenLoading.Wait()
|
||||||
|
bind()
|
||||||
|
if err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.id.RemotePort}); err != tcpip.ErrConnectStarted {
|
||||||
|
panic("endpoint connecting failed: " + err.String())
|
||||||
|
}
|
||||||
|
connectingLoading.Done()
|
||||||
|
tcpip.AsyncLoading.Done()
|
||||||
|
}()
|
||||||
|
case StateBound:
|
||||||
|
tcpip.AsyncLoading.Add(1)
|
||||||
|
go func() {
|
||||||
|
connectedLoading.Wait()
|
||||||
|
listenLoading.Wait()
|
||||||
|
connectingLoading.Wait()
|
||||||
|
bind()
|
||||||
|
tcpip.AsyncLoading.Done()
|
||||||
|
}()
|
||||||
|
case StateClose:
|
||||||
|
if e.isPortReserved {
|
||||||
|
tcpip.AsyncLoading.Add(1)
|
||||||
|
go func() {
|
||||||
|
connectedLoading.Wait()
|
||||||
|
listenLoading.Wait()
|
||||||
|
connectingLoading.Wait()
|
||||||
|
bind()
|
||||||
|
e.state = StateClose
|
||||||
|
tcpip.AsyncLoading.Done()
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
fallthrough
|
||||||
|
case StateError:
|
||||||
|
tcpip.DeleteDanglingEndpoint(e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// saveLastError is invoked by stateify.
|
// saveLastError is invoked by stateify.
|
||||||
func (e *endpoint) saveLastError() string {
|
func (e *endpoint) saveLastError() string {
|
||||||
if e.lastError == nil {
|
if e.lastError == nil {
|
||||||
|
|
|
@ -178,53 +178,6 @@ func (e *endpoint) IPTables() (iptables.IPTables, error) {
|
||||||
return e.stack.IPTables(), nil
|
return e.stack.IPTables(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Resume implements tcpip.ResumableEndpoint.Resume.
|
|
||||||
func (e *endpoint) Resume(s *stack.Stack) {
|
|
||||||
e.stack = s
|
|
||||||
|
|
||||||
for _, m := range e.multicastMemberships {
|
|
||||||
if err := e.stack.JoinGroup(e.netProto, m.nicID, m.multicastAddr); err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if e.state != stateBound && e.state != stateConnected {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
netProto := e.effectiveNetProtos[0]
|
|
||||||
// Connect() and bindLocked() both assert
|
|
||||||
//
|
|
||||||
// netProto == header.IPv6ProtocolNumber
|
|
||||||
//
|
|
||||||
// before creating a multi-entry effectiveNetProtos.
|
|
||||||
if len(e.effectiveNetProtos) > 1 {
|
|
||||||
netProto = header.IPv6ProtocolNumber
|
|
||||||
}
|
|
||||||
|
|
||||||
var err *tcpip.Error
|
|
||||||
if e.state == stateConnected {
|
|
||||||
e.route, err = e.stack.FindRoute(e.regNICID, e.id.LocalAddress, e.id.RemoteAddress, netProto, e.multicastLoop)
|
|
||||||
if err != nil {
|
|
||||||
panic(*err)
|
|
||||||
}
|
|
||||||
} else if len(e.id.LocalAddress) != 0 { // stateBound
|
|
||||||
if e.stack.CheckLocalAddress(e.regNICID, netProto, e.id.LocalAddress) == 0 {
|
|
||||||
panic(tcpip.ErrBadLocalAddress)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Our saved state had a port, but we don't actually have a
|
|
||||||
// reservation. We need to remove the port from our state, but still
|
|
||||||
// pass it to the reservation machinery.
|
|
||||||
id := e.id
|
|
||||||
e.id.LocalPort = 0
|
|
||||||
e.id, err = e.registerWithStack(e.regNICID, e.effectiveNetProtos, id)
|
|
||||||
if err != nil {
|
|
||||||
panic(*err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Read reads data from the endpoint. This method does not block if
|
// Read reads data from the endpoint. This method does not block if
|
||||||
// there is no data pending.
|
// there is no data pending.
|
||||||
func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
|
func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
|
||||||
|
|
|
@ -15,7 +15,9 @@
|
||||||
package udp
|
package udp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/buffer"
|
"gvisor.dev/gvisor/pkg/tcpip/buffer"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -64,3 +66,50 @@ func (e *endpoint) loadRcvBufSizeMax(max int) {
|
||||||
func (e *endpoint) afterLoad() {
|
func (e *endpoint) afterLoad() {
|
||||||
stack.StackFromEnv.RegisterRestoredEndpoint(e)
|
stack.StackFromEnv.RegisterRestoredEndpoint(e)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Resume implements tcpip.ResumableEndpoint.Resume.
|
||||||
|
func (e *endpoint) Resume(s *stack.Stack) {
|
||||||
|
e.stack = s
|
||||||
|
|
||||||
|
for _, m := range e.multicastMemberships {
|
||||||
|
if err := e.stack.JoinGroup(e.netProto, m.nicID, m.multicastAddr); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if e.state != stateBound && e.state != stateConnected {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
netProto := e.effectiveNetProtos[0]
|
||||||
|
// Connect() and bindLocked() both assert
|
||||||
|
//
|
||||||
|
// netProto == header.IPv6ProtocolNumber
|
||||||
|
//
|
||||||
|
// before creating a multi-entry effectiveNetProtos.
|
||||||
|
if len(e.effectiveNetProtos) > 1 {
|
||||||
|
netProto = header.IPv6ProtocolNumber
|
||||||
|
}
|
||||||
|
|
||||||
|
var err *tcpip.Error
|
||||||
|
if e.state == stateConnected {
|
||||||
|
e.route, err = e.stack.FindRoute(e.regNICID, e.id.LocalAddress, e.id.RemoteAddress, netProto, e.multicastLoop)
|
||||||
|
if err != nil {
|
||||||
|
panic(*err)
|
||||||
|
}
|
||||||
|
} else if len(e.id.LocalAddress) != 0 { // stateBound
|
||||||
|
if e.stack.CheckLocalAddress(e.regNICID, netProto, e.id.LocalAddress) == 0 {
|
||||||
|
panic(tcpip.ErrBadLocalAddress)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Our saved state had a port, but we don't actually have a
|
||||||
|
// reservation. We need to remove the port from our state, but still
|
||||||
|
// pass it to the reservation machinery.
|
||||||
|
id := e.id
|
||||||
|
e.id.LocalPort = 0
|
||||||
|
e.id, err = e.registerWithStack(e.regNICID, e.effectiveNetProtos, id)
|
||||||
|
if err != nil {
|
||||||
|
panic(*err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue