diff --git a/pkg/tcpip/header/tcp.go b/pkg/tcpip/header/tcp.go index 21581257b..29454c4b9 100644 --- a/pkg/tcpip/header/tcp.go +++ b/pkg/tcpip/header/tcp.go @@ -609,5 +609,8 @@ func Acceptable(segSeq seqnum.Value, segLen seqnum.Size, rcvNxt, rcvAcc seqnum.V } // Page 70 of RFC 793 allows packets that can be made "acceptable" by trimming // the payload, so we'll accept any payload that overlaps the receieve window. - return rcvNxt.LessThan(segSeq.Add(segLen)) && segSeq.LessThan(rcvAcc) + // segSeq < rcvAcc is more correct according to RFC, however, Linux does it + // differently, it uses segSeq <= rcvAcc, we'd want to keep the same behavior + // as Linux. + return rcvNxt.LessThan(segSeq.Add(segLen)) && segSeq.LessThanEq(rcvAcc) } diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index 6fe97fefd..dd89a292a 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -70,7 +70,16 @@ func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale // acceptable checks if the segment sequence number range is acceptable // according to the table on page 26 of RFC 793. func (r *receiver) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool { - return header.Acceptable(segSeq, segLen, r.rcvNxt, r.rcvAcc) + // r.rcvWnd could be much larger than the window size we advertised in our + // outgoing packets, we should use what we have advertised for acceptability + // test. + scaledWindowSize := r.rcvWnd >> r.rcvWndScale + if scaledWindowSize > 0xffff { + // This is what we actually put in the Window field. + scaledWindowSize = 0xffff + } + advertisedWindowSize := scaledWindowSize << r.rcvWndScale + return header.Acceptable(segSeq, segLen, r.rcvNxt, r.rcvNxt.Add(advertisedWindowSize)) } // getSendParams returns the parameters needed by the sender when building @@ -259,7 +268,14 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo // If we are in one of the shutdown states then we need to do // additional checks before we try and process the segment. switch state { - case StateCloseWait, StateClosing, StateLastAck: + case StateCloseWait: + // If the ACK acks something not yet sent then we send an ACK. + if r.ep.snd.sndNxt.LessThan(s.ackNumber) { + r.ep.snd.sendAck() + return true, nil + } + fallthrough + case StateClosing, StateLastAck: if !s.sequenceNumber.LessThanEq(r.rcvNxt) { // Just drop the segment as we have // already received a FIN and this @@ -276,7 +292,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo // SHUT_RD) then any data past the rcvNxt should // trigger a RST. endDataSeq := s.sequenceNumber.Add(seqnum.Size(s.data.Size())) - if rcvClosed && r.rcvNxt.LessThan(endDataSeq) { + if state != StateCloseWait && rcvClosed && r.rcvNxt.LessThan(endDataSeq) { return true, tcpip.ErrConnectionAborted } if state == StateFinWait1 { @@ -329,13 +345,6 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) { state := r.ep.EndpointState() closed := r.ep.closed - if state != StateEstablished { - drop, err := r.handleRcvdSegmentClosing(s, state, closed) - if drop || err != nil { - return drop, err - } - } - segLen := seqnum.Size(s.data.Size()) segSeq := s.sequenceNumber @@ -347,6 +356,13 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) { return true, nil } + if state != StateEstablished { + drop, err := r.handleRcvdSegmentClosing(s, state, closed) + if drop || err != nil { + return drop, err + } + } + // Store the time of the last ack. r.lastRcvdAckTime = time.Now() diff --git a/pkg/tcpip/transport/tcp/rcv_test.go b/pkg/tcpip/transport/tcp/rcv_test.go index c9eeff935..8a026ec46 100644 --- a/pkg/tcpip/transport/tcp/rcv_test.go +++ b/pkg/tcpip/transport/tcp/rcv_test.go @@ -30,7 +30,7 @@ func TestAcceptable(t *testing.T) { }{ // The segment is smaller than the window. {105, 2, 100, 104, false}, - {105, 2, 101, 105, false}, + {105, 2, 101, 105, true}, {105, 2, 102, 106, true}, {105, 2, 103, 107, true}, {105, 2, 104, 108, true}, @@ -39,7 +39,7 @@ func TestAcceptable(t *testing.T) { {105, 2, 107, 111, false}, // The segment is larger than the window. - {105, 4, 103, 105, false}, + {105, 4, 103, 105, true}, {105, 4, 104, 106, true}, {105, 4, 105, 107, true}, {105, 4, 106, 108, true}, diff --git a/test/packetimpact/tests/BUILD b/test/packetimpact/tests/BUILD index e4ced444b..563823d04 100644 --- a/test/packetimpact/tests/BUILD +++ b/test/packetimpact/tests/BUILD @@ -117,8 +117,6 @@ packetimpact_go_test( packetimpact_go_test( name = "tcp_close_wait_ack", srcs = ["tcp_close_wait_ack_test.go"], - # TODO(b/153574037): Fix netstack then remove the line below. - netstack = False, deps = [ "//pkg/tcpip/header", "//pkg/tcpip/seqnum",