diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index f2b1b68da..405a6dce7 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -172,14 +172,12 @@ func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) { // If we started off with a window larger than what can he held in // the 16bit window field, we ceil the value to the max value. - // While ceiling, we still do not want to grow the right edge when - // not applicable. if scaledWnd > math.MaxUint16 { - if toGrow { - scaledWnd = seqnum.Size(math.MaxUint16) - } else { - scaledWnd = seqnum.Size(uint16(scaledWnd)) - } + scaledWnd = seqnum.Size(math.MaxUint16) + + // Ensure that the stashed receive window always reflects what + // is being advertised. + r.rcvWnd = scaledWnd << r.rcvWndScale } return r.rcvNxt, scaledWnd } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 7581bdc97..351a5e4f5 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -1932,6 +1932,84 @@ func TestFullWindowReceive(t *testing.T) { ) } +// Test the stack receive window advertisement on receiving segments smaller than +// segment overhead. It tests for the right edge of the window to not grow when +// the endpoint is not being read from. +func TestSmallSegReceiveWindowAdvertisement(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + opt := tcpip.TCPReceiveBufferSizeRangeOption{ + Min: 1, + Default: tcp.DefaultReceiveBufferSize, + Max: tcp.DefaultReceiveBufferSize << tcp.FindWndScale(seqnum.Size(tcp.DefaultReceiveBufferSize)), + } + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) + } + + c.AcceptWithOptions(tcp.FindWndScale(seqnum.Size(opt.Default)), header.TCPSynOptions{MSS: defaultIPv4MSS}) + + // Bump up the receive buffer size such that, when the receive window grows, + // the scaled window exceeds maxUint16. + if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, opt.Max); err != nil { + t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, %d) failed: %s", opt.Max, err) + } + + // Keep the payload size < segment overhead and such that it is a multiple + // of the window scaled value. This enables the test to perform equality + // checks on the incoming receive window. + payload := generateRandomPayload(t, (tcp.SegSize-1)&(1< 65535 (0xffff). + if !readOnce { + if got := dut.Recv(t, acceptFd, int32(payloadLen), 0); len(got) != payloadLen { + t.Fatalf("got dut.Recv(t, %d, %d, 0) = %d, want %d", acceptFd, payloadLen, len(got), payloadLen) + } + readOnce = true + } windowSize := *gotTCP.WindowSize t.Logf("got window size = %d", windowSize) if windowSize == 0 { @@ -94,10 +105,9 @@ func TestNonZeroReceiveWindow(t *testing.T) { if err != nil { t.Fatalf("expected packet was not received: %s", err) } - if ret, _, err := dut.RecvWithErrno(context.Background(), t, acceptFd, int32(payloadLen), 0); ret == -1 { - t.Fatalf("dut.RecvWithErrno(ctx, t, %d, %d, 0) = %d,_, %s", acceptFd, payloadLen, ret, err) + if got := dut.Recv(t, acceptFd, int32(payloadLen), 0); len(got) != payloadLen { + t.Fatalf("got dut.Recv(t, %d, %d, 0) = %d, want %d", acceptFd, payloadLen, len(got), payloadLen) } - if *gotTCP.WindowSize == 0 { t.Fatalf("expected non-zero receive window.") }