diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go index ea0a0409a..3c552988a 100644 --- a/pkg/tcpip/adapters/gonet/gonet_test.go +++ b/pkg/tcpip/adapters/gonet/gonet_test.go @@ -127,6 +127,10 @@ func TestCloseReader(t *testing.T) { if err != nil { t.Fatalf("newLoopbackStack() = %v", err) } + defer func() { + s.Close() + s.Wait() + }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} @@ -175,6 +179,10 @@ func TestCloseReaderWithForwarder(t *testing.T) { if err != nil { t.Fatalf("newLoopbackStack() = %v", err) } + defer func() { + s.Close() + s.Wait() + }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) @@ -225,30 +233,21 @@ func TestCloseRead(t *testing.T) { if terr != nil { t.Fatalf("newLoopbackStack() = %v", terr) } + defer func() { + s.Close() + s.Wait() + }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) { var wq waiter.Queue - ep, err := r.CreateEndpoint(&wq) + _, err := r.CreateEndpoint(&wq) if err != nil { t.Fatalf("r.CreateEndpoint() = %v", err) } - defer ep.Close() - r.Complete(false) - - c := NewTCPConn(&wq, ep) - - buf := make([]byte, 256) - n, e := c.Read(buf) - if e != nil || string(buf[:n]) != "abc123" { - t.Fatalf("c.Read() = (%d, %v), want (6, nil)", n, e) - } - - if n, e = c.Write([]byte("abc123")); e != nil { - t.Errorf("c.Write() = (%d, %v), want (6, nil)", n, e) - } + // Endpoint will be closed in deferred s.Close (above). }) s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket) @@ -278,6 +277,10 @@ func TestCloseWrite(t *testing.T) { if terr != nil { t.Fatalf("newLoopbackStack() = %v", terr) } + defer func() { + s.Close() + s.Wait() + }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) @@ -334,6 +337,10 @@ func TestUDPForwarder(t *testing.T) { if terr != nil { t.Fatalf("newLoopbackStack() = %v", terr) } + defer func() { + s.Close() + s.Wait() + }() ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) addr1 := tcpip.FullAddress{NICID, ip1, 11211} @@ -391,6 +398,10 @@ func TestDeadlineChange(t *testing.T) { if err != nil { t.Fatalf("newLoopbackStack() = %v", err) } + defer func() { + s.Close() + s.Wait() + }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} @@ -440,6 +451,10 @@ func TestPacketConnTransfer(t *testing.T) { if e != nil { t.Fatalf("newLoopbackStack() = %v", e) } + defer func() { + s.Close() + s.Wait() + }() ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) addr1 := tcpip.FullAddress{NICID, ip1, 11211} @@ -492,6 +507,10 @@ func TestConnectedPacketConnTransfer(t *testing.T) { if e != nil { t.Fatalf("newLoopbackStack() = %v", e) } + defer func() { + s.Close() + s.Wait() + }() ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) addr := tcpip.FullAddress{NICID, ip, 11211} @@ -562,6 +581,8 @@ func makePipe() (c1, c2 net.Conn, stop func(), err error) { stop = func() { c1.Close() c2.Close() + s.Close() + s.Wait() } if err := l.Close(); err != nil { @@ -624,6 +645,10 @@ func TestTCPDialError(t *testing.T) { if e != nil { t.Fatalf("newLoopbackStack() = %v", e) } + defer func() { + s.Close() + s.Wait() + }() ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) addr := tcpip.FullAddress{NICID, ip, 11211} @@ -641,6 +666,10 @@ func TestDialContextTCPCanceled(t *testing.T) { if err != nil { t.Fatalf("newLoopbackStack() = %v", err) } + defer func() { + s.Close() + s.Wait() + }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) @@ -659,6 +688,10 @@ func TestDialContextTCPTimeout(t *testing.T) { if err != nil { t.Fatalf("newLoopbackStack() = %v", err) } + defer func() { + s.Close() + s.Wait() + }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 4da13c5df..e9fcc89a8 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -148,12 +148,12 @@ func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWi }, nil } -// LinkAddressProtocol implements stack.LinkAddressResolver. +// LinkAddressProtocol implements stack.LinkAddressResolver.LinkAddressProtocol. func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { return header.IPv4ProtocolNumber } -// LinkAddressRequest implements stack.LinkAddressResolver. +// LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest. func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error { r := &stack.Route{ RemoteLinkAddress: broadcastMAC, @@ -172,7 +172,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack. }) } -// ResolveStaticAddress implements stack.LinkAddressResolver. +// ResolveStaticAddress implements stack.LinkAddressResolver.ResolveStaticAddress. func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { if addr == header.IPv4Broadcast { return broadcastMAC, true @@ -183,16 +183,22 @@ func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bo return tcpip.LinkAddress([]byte(nil)), false } -// SetOption implements NetworkProtocol. -func (p *protocol) SetOption(option interface{}) *tcpip.Error { +// SetOption implements stack.NetworkProtocol.SetOption. +func (*protocol) SetOption(option interface{}) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } -// Option implements NetworkProtocol. -func (p *protocol) Option(option interface{}) *tcpip.Error { +// Option implements stack.NetworkProtocol.Option. +func (*protocol) Option(option interface{}) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } +// Close implements stack.TransportProtocol.Close. +func (*protocol) Close() {} + +// Wait implements stack.TransportProtocol.Wait. +func (*protocol) Wait() {} + var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff}) // NewProtocol returns an ARP network protocol. diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 6597e6781..4f1742938 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -473,6 +473,12 @@ func (p *protocol) DefaultTTL() uint8 { return uint8(atomic.LoadUint32(&p.defaultTTL)) } +// Close implements stack.TransportProtocol.Close. +func (*protocol) Close() {} + +// Wait implements stack.TransportProtocol.Wait. +func (*protocol) Wait() {} + // calculateMTU calculates the network-layer payload MTU based on the link-layer // payload mtu. func calculateMTU(mtu uint32) uint32 { diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 180a480fd..9aef5234b 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -265,6 +265,12 @@ func (p *protocol) DefaultTTL() uint8 { return uint8(atomic.LoadUint32(&p.defaultTTL)) } +// Close implements stack.TransportProtocol.Close. +func (*protocol) Close() {} + +// Wait implements stack.TransportProtocol.Wait. +func (*protocol) Wait() {} + // calculateMTU calculates the network-layer payload MTU based on the link-layer // payload mtu. func calculateMTU(mtu uint32) uint32 { diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index d83adf0ec..f9fd8f18f 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -74,10 +74,11 @@ type TransportEndpoint interface { // HandleControlPacket takes ownership of pkt. HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, pkt tcpip.PacketBuffer) - // Close puts the endpoint in a closed state and frees all resources - // associated with it. This cleanup may happen asynchronously. Wait can - // be used to block on this asynchronous cleanup. - Close() + // Abort initiates an expedited endpoint teardown. It puts the endpoint + // in a closed state and frees all resources associated with it. This + // cleanup may happen asynchronously. Wait can be used to block on this + // asynchronous cleanup. + Abort() // Wait waits for any worker goroutines owned by the endpoint to stop. // @@ -160,6 +161,13 @@ type TransportProtocol interface { // Option returns an error if the option is not supported or the // provided option value is invalid. Option(option interface{}) *tcpip.Error + + // Close requests that any worker goroutines owned by the protocol + // stop. + Close() + + // Wait waits for any worker goroutines owned by the protocol to stop. + Wait() } // TransportDispatcher contains the methods used by the network stack to deliver @@ -293,6 +301,13 @@ type NetworkProtocol interface { // Option returns an error if the option is not supported or the // provided option value is invalid. Option(option interface{}) *tcpip.Error + + // Close requests that any worker goroutines owned by the protocol + // stop. + Close() + + // Wait waits for any worker goroutines owned by the protocol to stop. + Wait() } // NetworkDispatcher contains the methods used by the network stack to deliver diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 900dd46c5..ebb6c5e3b 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -1446,7 +1446,13 @@ func (s *Stack) RestoreCleanupEndpoints(es []TransportEndpoint) { // Endpoints created or modified during this call may not get closed. func (s *Stack) Close() { for _, e := range s.RegisteredEndpoints() { - e.Close() + e.Abort() + } + for _, p := range s.transportProtocols { + p.proto.Close() + } + for _, p := range s.networkProtocols { + p.Close() } } @@ -1464,6 +1470,12 @@ func (s *Stack) Wait() { for _, e := range s.CleanupEndpoints() { e.Wait() } + for _, p := range s.transportProtocols { + p.proto.Wait() + } + for _, p := range s.networkProtocols { + p.Wait() + } s.mu.RLock() defer s.mu.RUnlock() diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 18016e7db..edf6bec52 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -235,6 +235,12 @@ func (f *fakeNetworkProtocol) Option(option interface{}) *tcpip.Error { } } +// Close implements TransportProtocol.Close. +func (*fakeNetworkProtocol) Close() {} + +// Wait implements TransportProtocol.Wait. +func (*fakeNetworkProtocol) Wait() {} + func fakeNetFactory() stack.NetworkProtocol { return &fakeNetworkProtocol{} } diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index d686e6eb8..778c0a4d6 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -306,26 +306,6 @@ func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, p ep.mu.RUnlock() // Don't use defer for performance reasons. } -// Close implements stack.TransportEndpoint.Close. -func (ep *multiPortEndpoint) Close() { - ep.mu.RLock() - eps := append([]TransportEndpoint(nil), ep.endpointsArr...) - ep.mu.RUnlock() - for _, e := range eps { - e.Close() - } -} - -// Wait implements stack.TransportEndpoint.Wait. -func (ep *multiPortEndpoint) Wait() { - ep.mu.RLock() - eps := append([]TransportEndpoint(nil), ep.endpointsArr...) - ep.mu.RUnlock() - for _, e := range eps { - e.Wait() - } -} - // singleRegisterEndpoint tries to add an endpoint to the multiPortEndpoint // list. The list might be empty already. func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, reusePort bool) *tcpip.Error { diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 869c69a6d..5d1da2f8b 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -61,6 +61,10 @@ func newFakeTransportEndpoint(s *stack.Stack, proto *fakeTransportProtocol, netP return &fakeTransportEndpoint{stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID} } +func (f *fakeTransportEndpoint) Abort() { + f.Close() +} + func (f *fakeTransportEndpoint) Close() { f.route.Release() } @@ -272,7 +276,7 @@ func (f *fakeTransportProtocol) NewEndpoint(stack *stack.Stack, netProto tcpip.N return newFakeTransportEndpoint(stack, f, netProto, stack.UniqueID()), nil } -func (f *fakeTransportProtocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (*fakeTransportProtocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { return nil, tcpip.ErrUnknownProtocol } @@ -310,6 +314,15 @@ func (f *fakeTransportProtocol) Option(option interface{}) *tcpip.Error { } } +// Abort implements TransportProtocol.Abort. +func (*fakeTransportProtocol) Abort() {} + +// Close implements tcpip.Endpoint.Close. +func (*fakeTransportProtocol) Close() {} + +// Wait implements TransportProtocol.Wait. +func (*fakeTransportProtocol) Wait() {} + func fakeTransFactory() stack.TransportProtocol { return &fakeTransportProtocol{} } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index ce5527391..3dc5d87d6 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -341,9 +341,15 @@ type ControlMessages struct { // networking stack. type Endpoint interface { // Close puts the endpoint in a closed state and frees all resources - // associated with it. + // associated with it. Close initiates the teardown process, the + // Endpoint may not be fully closed when Close returns. Close() + // Abort initiates an expedited endpoint teardown. As compared to + // Close, Abort prioritizes closing the Endpoint quickly over cleanly. + // Abort is best effort; implementing Abort with Close is acceptable. + Abort() + // Read reads data from the endpoint and optionally returns the sender. // // This method does not block if there is no data pending. It will also diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 42afb3f5b..426da1ee6 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -96,6 +96,11 @@ func (e *endpoint) UniqueID() uint64 { return e.uniqueID } +// Abort implements stack.TransportEndpoint.Abort. +func (e *endpoint) Abort() { + e.Close() +} + // Close puts the endpoint in a closed state and frees all resources // associated with it. func (e *endpoint) Close() { diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go index 9ce500e80..113d92901 100644 --- a/pkg/tcpip/transport/icmp/protocol.go +++ b/pkg/tcpip/transport/icmp/protocol.go @@ -104,20 +104,26 @@ func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) // HandleUnknownDestinationPacket handles packets targeted at this protocol but // that don't match any existing endpoint. -func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, tcpip.PacketBuffer) bool { +func (*protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, tcpip.PacketBuffer) bool { return true } -// SetOption implements TransportProtocol.SetOption. -func (p *protocol) SetOption(option interface{}) *tcpip.Error { +// SetOption implements stack.TransportProtocol.SetOption. +func (*protocol) SetOption(option interface{}) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } -// Option implements TransportProtocol.Option. -func (p *protocol) Option(option interface{}) *tcpip.Error { +// Option implements stack.TransportProtocol.Option. +func (*protocol) Option(option interface{}) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } +// Close implements stack.TransportProtocol.Close. +func (*protocol) Close() {} + +// Wait implements stack.TransportProtocol.Wait. +func (*protocol) Wait() {} + // NewProtocol4 returns an ICMPv4 transport protocol. func NewProtocol4() stack.TransportProtocol { return &protocol{ProtocolNumber4} diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index fc5bc69fa..5722815e9 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -98,6 +98,11 @@ func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumb return ep, nil } +// Abort implements stack.TransportEndpoint.Abort. +func (e *endpoint) Abort() { + e.Close() +} + // Close implements tcpip.Endpoint.Close. func (ep *endpoint) Close() { ep.mu.Lock() diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index ee9c4c58b..2ef5fac76 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -121,6 +121,11 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt return e, nil } +// Abort implements stack.TransportEndpoint.Abort. +func (e *endpoint) Abort() { + e.Close() +} + // Close implements tcpip.Endpoint.Close. func (e *endpoint) Close() { e.mu.Lock() diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 08afb7c17..13e383ffc 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -299,6 +299,13 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head h := newPassiveHandshake(ep, seqnum.Size(ep.initialReceiveWindow()), isn, irs, opts, deferAccept) if err := h.execute(); err != nil { ep.Close() + // Wake up any waiters. This is strictly not required normally + // as a socket that was never accepted can't really have any + // registered waiters except when stack.Wait() is called which + // waits for all registered endpoints to stop and expects an + // EventHUp. + ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) + if l.listenEP != nil { l.removePendingEndpoint(ep) } @@ -607,7 +614,7 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { e.mu.Unlock() // Notify waiters that the endpoint is shutdown. - e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut) + e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut | waiter.EventHUp | waiter.EventErr) }() s := sleep.Sleeper{} diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 5c5397823..7730e6445 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -1372,7 +1372,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ e.snd.updateMaxPayloadSize(mtu, count) } - if n¬ifyReset != 0 { + if n¬ifyReset != 0 || n¬ifyAbort != 0 { return tcpip.ErrConnectionAborted } @@ -1655,7 +1655,7 @@ func (e *endpoint) doTimeWait() (twReuse func()) { } case notification: n := e.fetchNotifications() - if n¬ifyClose != 0 { + if n¬ifyClose != 0 || n¬ifyAbort != 0 { return nil } if n¬ifyDrain != 0 { diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go index e18012ac0..d792b07d6 100644 --- a/pkg/tcpip/transport/tcp/dispatcher.go +++ b/pkg/tcpip/transport/tcp/dispatcher.go @@ -68,17 +68,28 @@ func (q *epQueue) empty() bool { type processor struct { epQ epQueue newEndpointWaker sleep.Waker + closeWaker sleep.Waker id int + wg sync.WaitGroup } func newProcessor(id int) *processor { p := &processor{ id: id, } + p.wg.Add(1) go p.handleSegments() return p } +func (p *processor) close() { + p.closeWaker.Assert() +} + +func (p *processor) wait() { + p.wg.Wait() +} + func (p *processor) queueEndpoint(ep *endpoint) { // Queue an endpoint for processing by the processor goroutine. p.epQ.enqueue(ep) @@ -87,11 +98,17 @@ func (p *processor) queueEndpoint(ep *endpoint) { func (p *processor) handleSegments() { const newEndpointWaker = 1 + const closeWaker = 2 s := sleep.Sleeper{} s.AddWaker(&p.newEndpointWaker, newEndpointWaker) + s.AddWaker(&p.closeWaker, closeWaker) defer s.Done() for { - s.Fetch(true) + id, ok := s.Fetch(true) + if ok && id == closeWaker { + p.wg.Done() + return + } for ep := p.epQ.dequeue(); ep != nil; ep = p.epQ.dequeue() { if ep.segmentQueue.empty() { continue @@ -160,6 +177,18 @@ func newDispatcher(nProcessors int) *dispatcher { } } +func (d *dispatcher) close() { + for _, p := range d.processors { + p.close() + } +} + +func (d *dispatcher) wait() { + for _, p := range d.processors { + p.wait() + } +} + func (d *dispatcher) queuePacket(r *stack.Route, stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) { ep := stackEP.(*endpoint) s := newSegment(r, id, pkt) diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index f2be0e651..f1ad19dac 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -121,6 +121,8 @@ const ( notifyDrain notifyReset notifyResetByPeer + // notifyAbort is a request for an expedited teardown. + notifyAbort notifyKeepaliveChanged notifyMSSChanged // notifyTickleWorker is used to tickle the protocol main loop during a @@ -785,6 +787,24 @@ func (e *endpoint) notifyProtocolGoroutine(n uint32) { } } +// Abort implements stack.TransportEndpoint.Abort. +func (e *endpoint) Abort() { + // The abort notification is not processed synchronously, so no + // synchronization is needed. + // + // If the endpoint becomes connected after this check, we still close + // the endpoint. This worst case results in a slower abort. + // + // If the endpoint disconnected after the check, nothing needs to be + // done, so sending a notification which will potentially be ignored is + // fine. + if e.EndpointState().connected() { + e.notifyProtocolGoroutine(notifyAbort) + return + } + e.Close() +} + // Close puts the endpoint in a closed state and frees all resources associated // with it. It must be called only once and with no other concurrent calls to // the endpoint. @@ -829,9 +849,18 @@ func (e *endpoint) closeNoShutdown() { // Either perform the local cleanup or kick the worker to make sure it // knows it needs to cleanup. tcpip.AddDanglingEndpoint(e) - if !e.workerRunning { + switch e.EndpointState() { + // Sockets in StateSynRecv state(passive connections) are closed when + // the handshake fails or if the listening socket is closed while + // handshake was in progress. In such cases the handshake goroutine + // is already gone by the time Close is called and we need to cleanup + // here. + case StateInitial, StateBound, StateSynRecv: e.cleanupLocked() - } else { + e.setEndpointState(StateClose) + case StateError, StateClose: + // do nothing. + default: e.workerCleanup = true e.notifyProtocolGoroutine(notifyClose) } diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 958c06fa7..73098d904 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -194,7 +194,7 @@ func replyWithReset(s *segment) { sendTCP(&s.route, s.id, buffer.VectorisedView{}, s.route.DefaultTTL(), stack.DefaultTOS, flags, seq, ack, 0 /* rcvWnd */, nil /* options */, nil /* gso */) } -// SetOption implements TransportProtocol.SetOption. +// SetOption implements stack.TransportProtocol.SetOption. func (p *protocol) SetOption(option interface{}) *tcpip.Error { switch v := option.(type) { case SACKEnabled: @@ -269,7 +269,7 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { } } -// Option implements TransportProtocol.Option. +// Option implements stack.TransportProtocol.Option. func (p *protocol) Option(option interface{}) *tcpip.Error { switch v := option.(type) { case *SACKEnabled: @@ -331,6 +331,16 @@ func (p *protocol) Option(option interface{}) *tcpip.Error { } } +// Close implements stack.TransportProtocol.Close. +func (p *protocol) Close() { + p.dispatcher.close() +} + +// Wait implements stack.TransportProtocol.Wait. +func (p *protocol) Wait() { + p.dispatcher.wait() +} + // NewProtocol returns a TCP transport protocol. func NewProtocol() stack.TransportProtocol { return &protocol{ diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index eff7f3600..1c6a600b8 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -186,6 +186,11 @@ func (e *endpoint) UniqueID() uint64 { return e.uniqueID } +// Abort implements stack.TransportEndpoint.Abort. +func (e *endpoint) Abort() { + e.Close() +} + // Close puts the endpoint in a closed state and frees all resources // associated with it. func (e *endpoint) Close() { diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index 259c3072a..8df089d22 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -180,16 +180,22 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans return true } -// SetOption implements TransportProtocol.SetOption. -func (p *protocol) SetOption(option interface{}) *tcpip.Error { +// SetOption implements stack.TransportProtocol.SetOption. +func (*protocol) SetOption(option interface{}) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } -// Option implements TransportProtocol.Option. -func (p *protocol) Option(option interface{}) *tcpip.Error { +// Option implements stack.TransportProtocol.Option. +func (*protocol) Option(option interface{}) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } +// Close implements stack.TransportProtocol.Close. +func (*protocol) Close() {} + +// Wait implements stack.TransportProtocol.Wait. +func (*protocol) Wait() {} + // NewProtocol returns a UDP transport protocol. func NewProtocol() stack.TransportProtocol { return &protocol{}