Enlarge port range and fix integer overflow

Also count failed TCP port allocations

PiperOrigin-RevId: 368939619
This commit is contained in:
Kevin Krakauer 2021-04-16 16:26:31 -07:00 committed by gVisor bot
parent 6241f89f49
commit 32c18f443f
6 changed files with 51 additions and 14 deletions

View File

@ -242,6 +242,7 @@ var Metrics = tcpip.Stats{
FastRetransmit: mustCreateMetric("/netstack/tcp/fast_retransmit", "Number of TCP segments which were fast retransmitted."), FastRetransmit: mustCreateMetric("/netstack/tcp/fast_retransmit", "Number of TCP segments which were fast retransmitted."),
Timeouts: mustCreateMetric("/netstack/tcp/timeouts", "Number of times RTO expired."), Timeouts: mustCreateMetric("/netstack/tcp/timeouts", "Number of times RTO expired."),
ChecksumErrors: mustCreateMetric("/netstack/tcp/checksum_errors", "Number of segments dropped due to bad checksums."), ChecksumErrors: mustCreateMetric("/netstack/tcp/checksum_errors", "Number of segments dropped due to bad checksums."),
FailedPortReservations: mustCreateMetric("/netstack/tcp/failed_port_reservations", "Number of time TCP failed to reserve a port."),
}, },
UDP: tcpip.UDPStats{ UDP: tcpip.UDPStats{
PacketsReceived: mustCreateMetric("/netstack/udp/packets_received", "Number of UDP datagrams received via HandlePacket."), PacketsReceived: mustCreateMetric("/netstack/udp/packets_received", "Number of UDP datagrams received via HandlePacket."),

View File

@ -17,6 +17,7 @@
package ports package ports
import ( import (
"math"
"math/rand" "math/rand"
"sync/atomic" "sync/atomic"
@ -24,7 +25,10 @@ import (
"gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip"
) )
const anyIPAddress tcpip.Address = "" const (
firstEphemeral = 16000
anyIPAddress tcpip.Address = ""
)
// Reservation describes a port reservation. // Reservation describes a port reservation.
type Reservation struct { type Reservation struct {
@ -220,10 +224,8 @@ type PortManager struct {
func NewPortManager() *PortManager { func NewPortManager() *PortManager {
return &PortManager{ return &PortManager{
allocatedPorts: make(map[portDescriptor]addrToDevice), allocatedPorts: make(map[portDescriptor]addrToDevice),
// Match Linux's default ephemeral range. See: firstEphemeral: firstEphemeral,
// https://github.com/torvalds/linux/blob/e54937963fa249595824439dc839c948188dea83/net/ipv4/af_inet.c#L1842 numEphemeral: math.MaxUint16 - firstEphemeral + 1,
firstEphemeral: 32768,
numEphemeral: 28232,
} }
} }
@ -242,13 +244,13 @@ func (pm *PortManager) PickEphemeralPort(testPort PortTester) (port uint16, err
numEphemeral := pm.numEphemeral numEphemeral := pm.numEphemeral
pm.ephemeralMu.RUnlock() pm.ephemeralMu.RUnlock()
offset := uint16(rand.Int31n(int32(numEphemeral))) offset := uint32(rand.Int31n(int32(numEphemeral)))
return pickEphemeralPort(offset, firstEphemeral, numEphemeral, testPort) return pickEphemeralPort(offset, firstEphemeral, numEphemeral, testPort)
} }
// portHint atomically reads and returns the pm.hint value. // portHint atomically reads and returns the pm.hint value.
func (pm *PortManager) portHint() uint16 { func (pm *PortManager) portHint() uint32 {
return uint16(atomic.LoadUint32(&pm.hint)) return atomic.LoadUint32(&pm.hint)
} }
// incPortHint atomically increments pm.hint by 1. // incPortHint atomically increments pm.hint by 1.
@ -260,7 +262,7 @@ func (pm *PortManager) incPortHint() {
// iterates over all ephemeral ports, allowing the caller to decide whether a // iterates over all ephemeral ports, allowing the caller to decide whether a
// given port is suitable for its needs and stopping when a port is found or an // given port is suitable for its needs and stopping when a port is found or an
// error occurs. // error occurs.
func (pm *PortManager) PickEphemeralPortStable(offset uint16, testPort PortTester) (port uint16, err tcpip.Error) { func (pm *PortManager) PickEphemeralPortStable(offset uint32, testPort PortTester) (port uint16, err tcpip.Error) {
pm.ephemeralMu.RLock() pm.ephemeralMu.RLock()
firstEphemeral := pm.firstEphemeral firstEphemeral := pm.firstEphemeral
numEphemeral := pm.numEphemeral numEphemeral := pm.numEphemeral
@ -277,9 +279,9 @@ func (pm *PortManager) PickEphemeralPortStable(offset uint16, testPort PortTeste
// and iterates over the number of ports specified by count and allows the // and iterates over the number of ports specified by count and allows the
// caller to decide whether a given port is suitable for its needs, and stopping // caller to decide whether a given port is suitable for its needs, and stopping
// when a port is found or an error occurs. // when a port is found or an error occurs.
func pickEphemeralPort(offset, first, count uint16, testPort PortTester) (port uint16, err tcpip.Error) { func pickEphemeralPort(offset uint32, first, count uint16, testPort PortTester) (port uint16, err tcpip.Error) {
for i := uint16(0); i < count; i++ { for i := uint32(0); i < uint32(count); i++ {
port = first + (offset+i)%count port := uint16(uint32(first) + (offset+i)%uint32(count))
ok, err := testPort(port) ok, err := testPort(port)
if err != nil { if err != nil {
return 0, err return 0, err

View File

@ -15,6 +15,7 @@
package ports package ports
import ( import (
"math"
"math/rand" "math/rand"
"testing" "testing"
@ -482,7 +483,7 @@ func TestPickEphemeralPortStable(t *testing.T) {
if err := pm.SetPortRange(firstEphemeral, firstEphemeral+numEphemeralPorts); err != nil { if err := pm.SetPortRange(firstEphemeral, firstEphemeral+numEphemeralPorts); err != nil {
t.Fatalf("failed to set ephemeral port range: %s", err) t.Fatalf("failed to set ephemeral port range: %s", err)
} }
portOffset := uint16(rand.Int31n(int32(numEphemeralPorts))) portOffset := uint32(rand.Int31n(int32(numEphemeralPorts)))
port, err := pm.PickEphemeralPortStable(portOffset, test.f) port, err := pm.PickEphemeralPortStable(portOffset, test.f)
if diff := cmp.Diff(test.wantErr, err); diff != "" { if diff := cmp.Diff(test.wantErr, err); diff != "" {
t.Fatalf("unexpected error from PickEphemeralPort(..), (-want, +got):\n%s", diff) t.Fatalf("unexpected error from PickEphemeralPort(..), (-want, +got):\n%s", diff)
@ -493,3 +494,29 @@ func TestPickEphemeralPortStable(t *testing.T) {
}) })
} }
} }
// TestOverflow addresses b/183593432, wherein an overflowing uint16 causes a
// port allocation failure.
func TestOverflow(t *testing.T) {
// Use a small range and start at offsets that will cause an overflow.
count := uint16(50)
for offset := uint32(math.MaxUint16 - count); offset < math.MaxUint16; offset++ {
reservedPorts := make(map[uint16]struct{})
// Ensure we can reserve everything in the allowed range.
for i := uint16(0); i < count; i++ {
port, err := pickEphemeralPort(offset, firstEphemeral, count, func(port uint16) (bool, tcpip.Error) {
if _, ok := reservedPorts[port]; !ok {
reservedPorts[port] = struct{}{}
return true, nil
}
return false, nil
})
if err != nil {
t.Fatalf("port picking failed at iteration %d, for offset %d, len(reserved): %+v", i, offset, len(reservedPorts))
}
if port < firstEphemeral || port > firstEphemeral+count {
t.Fatalf("reserved port %d, which is not in range [%d, %d]", port, firstEphemeral, firstEphemeral+count-1)
}
}
}
}

View File

@ -1732,6 +1732,10 @@ type TCPStats struct {
// ChecksumErrors is the number of segments dropped due to bad checksums. // ChecksumErrors is the number of segments dropped due to bad checksums.
ChecksumErrors *StatCounter ChecksumErrors *StatCounter
// FailedPortReservations is the number of times TCP failed to reserve
// a port.
FailedPortReservations *StatCounter
} }
// UDPStats collects UDP-specific stats. // UDPStats collects UDP-specific stats.

View File

@ -455,6 +455,7 @@ func (e *endpoint) reserveTupleLocked() bool {
Dest: dest, Dest: dest,
} }
if !e.stack.ReserveTuple(portRes) { if !e.stack.ReserveTuple(portRes) {
e.stack.Stats().TCP.FailedPortReservations.Increment()
return false return false
} }

View File

@ -2251,7 +2251,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp
panic(err) panic(err)
} }
} }
portOffset := uint16(h.Sum32()) portOffset := h.Sum32()
var twReuse tcpip.TCPTimeWaitReuseOption var twReuse tcpip.TCPTimeWaitReuseOption
if err := e.stack.TransportProtocolOption(ProtocolNumber, &twReuse); err != nil { if err := e.stack.TransportProtocolOption(ProtocolNumber, &twReuse); err != nil {
@ -2362,6 +2362,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp
e.boundDest = addr e.boundDest = addr
return true, nil return true, nil
}); err != nil { }); err != nil {
e.stack.Stats().TCP.FailedPortReservations.Increment()
return err return err
} }
} }
@ -2685,6 +2686,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err tcpip.Error) {
return true, nil return true, nil
}) })
if err != nil { if err != nil {
e.stack.Stats().TCP.FailedPortReservations.Increment()
return err return err
} }