diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 234fb95ce..bde071f2a 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -11,7 +11,8 @@ go_template_instance( template = "//pkg/ilist:generic_list", types = { "Element": "*segment", - "Linker": "*segment", + "ElementMapper": "segmentMapper", + "Linker": "*segmentEntry", }, ) @@ -27,6 +28,19 @@ go_template_instance( }, ) +go_template_instance( + name = "tcp_rack_segment_list", + out = "tcp_rack_segment_list.go", + package = "tcp", + prefix = "rackSegment", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*segment", + "ElementMapper": "rackSegmentMapper", + "Linker": "*rackSegmentEntry", + }, +) + go_library( name = "tcp", srcs = [ @@ -55,6 +69,7 @@ go_library( "snd.go", "snd_state.go", "tcp_endpoint_list.go", + "tcp_rack_segment_list.go", "tcp_segment_list.go", "timer.go", ], diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 290172ac9..87980c0a1 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -924,7 +924,18 @@ func (e *endpoint) handleWrite() *tcpip.Error { first := e.sndQueue.Front() if first != nil { + lastSeg := e.snd.writeList.Back() e.snd.writeList.PushBackList(&e.sndQueue) + if lastSeg == nil { + lastSeg = e.snd.writeList.Front() + } else { + lastSeg = lastSeg.segEntry.Next() + } + // Add new segments to rcList, as rcList and writeList should + // be consistent. + for seg := lastSeg; seg != nil; seg = seg.segEntry.Next() { + e.snd.rcList.PushBack(seg) + } e.sndBufInQueue = 0 } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 1ccedebcc..21a4b6e2f 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -1428,7 +1428,7 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro vec = append([][]byte(nil), vec...) var num int64 - for s := e.rcvList.Front(); s != nil; s = s.Next() { + for s := e.rcvList.Front(); s != nil; s = s.segEntry.Next() { views := s.data.Views() for i := s.viewToDeliver; i < len(views); i++ { @@ -2249,7 +2249,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc if !handshake { e.segmentQueue.mu.Lock() for _, l := range []segmentList{e.segmentQueue.list, e.sndQueue, e.snd.writeList} { - for s := l.Front(); s != nil; s = s.Next() { + for s := l.Front(); s != nil; s = s.segEntry.Next() { s.id = e.ID s.route = r.Clone() e.sndWaker.Assert() diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index 94307d31a..a20755f78 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -30,12 +30,13 @@ import ( // // +stateify savable type segment struct { - segmentEntry - refCnt int32 - id stack.TransportEndpointID `state:"manual"` - route stack.Route `state:"manual"` - data buffer.VectorisedView `state:".(buffer.VectorisedView)"` - hdr header.TCP + segEntry segmentEntry + rackSegEntry rackSegmentEntry + refCnt int32 + id stack.TransportEndpointID `state:"manual"` + route stack.Route `state:"manual"` + data buffer.VectorisedView `state:".(buffer.VectorisedView)"` + hdr header.TCP // views is used as buffer for data when its length is large // enough to store a VectorisedView. views [8]buffer.View `state:"nosave"` @@ -61,6 +62,16 @@ type segment struct { xmitCount uint32 } +// segmentMapper is the ElementMapper for the writeList. +type segmentMapper struct{} + +func (segmentMapper) linkerFor(seg *segment) *segmentEntry { return &seg.segEntry } + +// rackSegmentMapper is the ElementMapper for the rcList. +type rackSegmentMapper struct{} + +func (rackSegmentMapper) linkerFor(seg *segment) *rackSegmentEntry { return &seg.rackSegEntry } + func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) *segment { s := &segment{ refCnt: 1, diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index c55589c45..31151f23d 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -154,6 +154,7 @@ type sender struct { closed bool writeNext *segment writeList segmentList + rcList rackSegmentList resendTimer timer `state:"nosave"` resendWaker sleep.Waker `state:"nosave"` @@ -367,7 +368,7 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) { // Rewind writeNext to the first segment exceeding the MTU. Do nothing // if it is already before such a packet. - for seg := s.writeList.Front(); seg != nil; seg = seg.Next() { + for seg := s.writeList.Front(); seg != nil; seg = seg.segEntry.Next() { if seg == s.writeNext { // We got to writeNext before we could find a segment // exceeding the MTU. @@ -622,6 +623,7 @@ func (s *sender) splitSeg(seg *segment, size int) { nSeg.data.TrimFront(size) nSeg.sequenceNumber.UpdateForward(seqnum.Size(size)) s.writeList.InsertAfter(seg, nSeg) + s.rcList.InsertAfter(seg, nSeg) // The segment being split does not carry PUSH flag because it is // followed by the newly split segment. @@ -653,7 +655,7 @@ func (s *sender) NextSeg(nextSegHint *segment) (nextSeg, hint *segment, rescueRt var s3 *segment var s4 *segment // Step 1. - for seg := nextSegHint; seg != nil; seg = seg.Next() { + for seg := nextSegHint; seg != nil; seg = seg.segEntry.Next() { // Stop iteration if we hit a segment that has never been // transmitted (i.e. either it has no assigned sequence number // or if it does have one, it's >= the next sequence number @@ -683,7 +685,7 @@ func (s *sender) NextSeg(nextSegHint *segment) (nextSeg, hint *segment, rescueRt // NextSeg(): // (1.c) IsLost(S2) returns true. if s.ep.scoreboard.IsLost(segSeq) { - return seg, seg.Next(), false + return seg, seg.segEntry.Next(), false } // NextSeg(): @@ -697,7 +699,7 @@ func (s *sender) NextSeg(nextSegHint *segment) (nextSeg, hint *segment, rescueRt // SHOULD be returned. if s3 == nil { s3 = seg - hint = seg.Next() + hint = seg.segEntry.Next() } } // NextSeg(): @@ -731,7 +733,7 @@ func (s *sender) NextSeg(nextSegHint *segment) (nextSeg, hint *segment, rescueRt // range of one segment of up to SMSS octets of // previously unsent data starting with sequence number // HighData+1 MUST be returned." - for seg := s.writeNext; seg != nil; seg = seg.Next() { + for seg := s.writeNext; seg != nil; seg = seg.segEntry.Next() { if s.isAssignedSequenceNumber(seg) && seg.sequenceNumber.LessThan(s.sndNxt) { continue } @@ -773,15 +775,16 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se // triggering bugs in poorly written DNS // implementations. var nextTooBig bool - for seg.Next() != nil && seg.Next().data.Size() != 0 { - if seg.data.Size()+seg.Next().data.Size() > available { + for seg.segEntry.Next() != nil && seg.segEntry.Next().data.Size() != 0 { + if seg.data.Size()+seg.segEntry.Next().data.Size() > available { nextTooBig = true break } - seg.data.Append(seg.Next().data) + seg.data.Append(seg.segEntry.Next().data) // Consume the segment that we just merged in. - s.writeList.Remove(seg.Next()) + s.writeList.Remove(seg.segEntry.Next()) + s.rcList.Remove(seg.rackSegEntry.Next()) } if !nextTooBig && seg.data.Size() < available { // Segment is not full. @@ -948,7 +951,7 @@ func (s *sender) handleSACKRecovery(limit int, end seqnum.Value) (dataSent bool) } dataSent = true s.outstanding++ - s.writeNext = nextSeg.Next() + s.writeNext = nextSeg.segEntry.Next() continue } @@ -961,6 +964,7 @@ func (s *sender) handleSACKRecovery(limit int, end seqnum.Value) (dataSent bool) // transmitted in (C.1)." s.outstanding++ dataSent = true + s.sendSegment(nextSeg) segEnd := nextSeg.sequenceNumber.Add(nextSeg.logicalLen()) @@ -1035,7 +1039,7 @@ func (s *sender) sendData() { if s.fr.active && s.ep.sackPermitted { dataSent = s.handleSACKRecovery(s.maxPayloadSize, end) } else { - for seg := s.writeNext; seg != nil && s.outstanding < s.sndCwnd; seg = seg.Next() { + for seg := s.writeNext; seg != nil && s.outstanding < s.sndCwnd; seg = seg.segEntry.Next() { cwndLimit := (s.sndCwnd - s.outstanding) * s.maxPayloadSize if cwndLimit < limit { limit = cwndLimit @@ -1043,7 +1047,7 @@ func (s *sender) sendData() { if s.isAssignedSequenceNumber(seg) && s.ep.sackPermitted && s.ep.scoreboard.IsSACKED(seg.sackBlock()) { // Move writeNext along so that we don't try and scan data that // has already been SACKED. - s.writeNext = seg.Next() + s.writeNext = seg.segEntry.Next() continue } if sent := s.maybeSendSegment(seg, limit, end); !sent { @@ -1051,7 +1055,7 @@ func (s *sender) sendData() { } dataSent = true s.outstanding += s.pCount(seg) - s.writeNext = seg.Next() + s.writeNext = seg.segEntry.Next() } } @@ -1182,7 +1186,7 @@ func (s *sender) SetPipe() { } pipe := 0 smss := seqnum.Size(s.ep.scoreboard.SMSS()) - for s1 := s.writeList.Front(); s1 != nil && s1.data.Size() != 0 && s.isAssignedSequenceNumber(s1); s1 = s1.Next() { + for s1 := s.writeList.Front(); s1 != nil && s1.data.Size() != 0 && s.isAssignedSequenceNumber(s1); s1 = s1.segEntry.Next() { // With GSO each segment can be much larger than SMSS. So check the segment // in SMSS sized ranges. segEnd := s1.sequenceNumber.Add(seqnum.Size(s1.data.Size())) @@ -1384,7 +1388,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { } if s.writeNext == seg { - s.writeNext = seg.Next() + s.writeNext = seg.segEntry.Next() } // Update the RACK fields if SACK is enabled. @@ -1393,6 +1397,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { } s.writeList.Remove(seg) + s.rcList.Remove(seg) // if SACK is enabled then Only reduce outstanding if // the segment was not previously SACKED as these have @@ -1460,6 +1465,12 @@ func (s *sender) sendSegment(seg *segment) *tcpip.Error { if s.sndCwnd < s.sndSsthresh { s.ep.stack.Stats().TCP.SlowStartRetransmits.Increment() } + + // Move the segment which has to be retransmitted to the end of the list, as + // RACK requires the segments in the order of their transmission times. + // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-09#section-6.2 + // Step 5 + s.rcList.PushBack(seg) } seg.xmitTime = time.Now() seg.xmitCount++