Store SO_BINDTODEVICE state at bind.

This allows us to ensure that the correct port reservation is released.

Fixes #1217

PiperOrigin-RevId: 282048155
This commit is contained in:
Ian Gudger 2019-11-22 14:56:54 -08:00 committed by gVisor bot
parent 9db08c4e58
commit 8eb68912e4
4 changed files with 43 additions and 22 deletions

View File

@ -243,7 +243,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i
n.initGSO()
// Register new endpoint so that packets are routed to it.
if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.reusePort, n.bindToDevice); err != nil {
if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.reusePort, n.boundBindToDevice); err != nil {
n.Close()
return nil, err
}

View File

@ -340,6 +340,9 @@ type endpoint struct {
// TCP should never broadcast but Linux nevertheless supports enabling/
// disabling SO_BROADCAST, albeit as a NOOP.
broadcast bool
// Values used to reserve a port or register a transport endpoint
// (which ever happens first).
boundBindToDevice tcpip.NICID
// effectiveNetProtos contains the network protocols actually in use. In
// most cases it will only contain "netProto", but in cases like IPv6
@ -730,12 +733,13 @@ func (e *endpoint) Close() {
// in Listen() when trying to register.
if e.state == StateListen && e.isPortReserved {
if e.isRegistered {
e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice)
e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice)
e.isRegistered = false
}
e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.bindToDevice)
e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundBindToDevice)
e.isPortReserved = false
e.boundBindToDevice = 0
}
// Mark endpoint as closed.
@ -791,14 +795,15 @@ func (e *endpoint) cleanupLocked() {
e.workerCleanup = false
if e.isRegistered {
e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice)
e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice)
e.isRegistered = false
}
if e.isPortReserved {
e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.bindToDevice)
e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundBindToDevice)
e.isPortReserved = false
}
e.boundBindToDevice = 0
e.route.Release()
e.stack.CompleteTransportEndpointCleanup(e)
@ -1741,7 +1746,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
if e.ID.LocalPort != 0 {
// The endpoint is bound to a port, attempt to register it.
err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, e.ID, e, e.reusePort, e.bindToDevice)
err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, e.ID, e, e.reusePort, e.boundBindToDevice)
if err != nil {
return err
}
@ -1778,7 +1783,10 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
id.LocalPort = p
switch e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice) {
case nil:
// Port picking successful. Save the details of
// the selected port.
e.ID = id
e.boundBindToDevice = e.bindToDevice
return true, nil
case tcpip.ErrPortInUse:
return false, nil
@ -1794,7 +1802,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
// before Connect: in such a case we don't want to hold on to
// reservations anymore.
if e.isPortReserved {
e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, origID.LocalAddress, origID.LocalPort, e.bindToDevice)
e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, origID.LocalAddress, origID.LocalPort, e.boundBindToDevice)
e.isPortReserved = false
}
@ -1950,7 +1958,7 @@ func (e *endpoint) listen(backlog int) *tcpip.Error {
}
// Register the endpoint.
if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.reusePort, e.bindToDevice); err != nil {
if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.reusePort, e.boundBindToDevice); err != nil {
return err
}
@ -2031,6 +2039,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) {
return err
}
e.boundBindToDevice = e.bindToDevice
e.isPortReserved = true
e.effectiveNetProtos = netProtos
e.ID.LocalPort = port
@ -2044,8 +2053,9 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) {
e.ID.LocalPort = 0
e.ID.LocalAddress = ""
e.boundNICID = 0
e.boundBindToDevice = 0
}
}(e.bindToDevice)
}(e.boundBindToDevice)
// If an address is specified, we must ensure that it's one of our
// local addresses.

View File

@ -104,6 +104,10 @@ type endpoint struct {
bindToDevice tcpip.NICID
broadcast bool
// Values used to reserve a port or register a transport endpoint.
// (which ever happens first).
boundBindToDevice tcpip.NICID
// sendTOS represents IPv4 TOS or IPv6 TrafficClass,
// applied while sending packets. Defaults to 0 as on Linux.
sendTOS uint8
@ -175,8 +179,9 @@ func (e *endpoint) Close() {
switch e.state {
case StateBound, StateConnected:
e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice)
e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.bindToDevice)
e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice)
e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundBindToDevice)
e.boundBindToDevice = 0
}
for _, mem := range e.multicastMemberships {
@ -870,7 +875,10 @@ func (e *endpoint) Disconnect() *tcpip.Error {
if e.state != StateConnected {
return nil
}
id := stack.TransportEndpointID{}
var (
id stack.TransportEndpointID
btd tcpip.NICID
)
// Exclude ephemerally bound endpoints.
if e.BindNICID != 0 || e.ID.LocalAddress == "" {
var err *tcpip.Error
@ -878,7 +886,7 @@ func (e *endpoint) Disconnect() *tcpip.Error {
LocalPort: e.ID.LocalPort,
LocalAddress: e.ID.LocalAddress,
}
id, err = e.registerWithStack(e.RegisterNICID, e.effectiveNetProtos, id)
id, btd, err = e.registerWithStack(e.RegisterNICID, e.effectiveNetProtos, id)
if err != nil {
return err
}
@ -886,13 +894,14 @@ func (e *endpoint) Disconnect() *tcpip.Error {
} else {
if e.ID.LocalPort != 0 {
// Release the ephemeral port.
e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.bindToDevice)
e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundBindToDevice)
}
e.state = StateInitial
}
e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice)
e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice)
e.ID = id
e.boundBindToDevice = btd
e.route.Release()
e.route = stack.Route{}
e.dstPort = 0
@ -961,17 +970,18 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
}
}
id, err = e.registerWithStack(nicID, netProtos, id)
id, btd, err := e.registerWithStack(nicID, netProtos, id)
if err != nil {
return err
}
// Remove the old registration.
if e.ID.LocalPort != 0 {
e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice)
e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice)
}
e.ID = id
e.boundBindToDevice = btd
e.route = r.Clone()
e.dstPort = addr.Port
e.RegisterNICID = nicID
@ -1029,11 +1039,11 @@ func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
return nil, nil, tcpip.ErrNotSupported
}
func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) {
func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, *tcpip.Error) {
if e.ID.LocalPort == 0 {
port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.reusePort, e.bindToDevice)
if err != nil {
return id, err
return id, e.bindToDevice, err
}
id.LocalPort = port
}
@ -1042,7 +1052,7 @@ func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.Networ
if err != nil {
e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.bindToDevice)
}
return id, err
return id, e.bindToDevice, err
}
func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
@ -1081,12 +1091,13 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
LocalPort: addr.Port,
LocalAddress: addr.Addr,
}
id, err = e.registerWithStack(nicID, netProtos, id)
id, btd, err := e.registerWithStack(nicID, netProtos, id)
if err != nil {
return err
}
e.ID = id
e.boundBindToDevice = btd
e.RegisterNICID = nicID
e.effectiveNetProtos = netProtos

View File

@ -109,7 +109,7 @@ func (e *endpoint) Resume(s *stack.Stack) {
// pass it to the reservation machinery.
id := e.ID
e.ID.LocalPort = 0
e.ID, err = e.registerWithStack(e.RegisterNICID, e.effectiveNetProtos, id)
e.ID, e.boundBindToDevice, err = e.registerWithStack(e.RegisterNICID, e.effectiveNetProtos, id)
if err != nil {
panic(err)
}