Add support for tearing down protocol dispatchers and TIME_WAIT endpoints.

Protocol dispatchers were previously leaked. Bypassing TIME_WAIT is required to
test this change.

Also fix a race when a socket in SYN-RCVD is closed. This is also required to
test this change.

PiperOrigin-RevId: 296922548
This commit is contained in:
Ian Gudger 2020-02-24 10:31:01 -08:00 committed by gVisor bot
parent b8f56c79be
commit c37b196455
21 changed files with 256 additions and 66 deletions

View File

@ -127,6 +127,10 @@ func TestCloseReader(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("newLoopbackStack() = %v", err) t.Fatalf("newLoopbackStack() = %v", err)
} }
defer func() {
s.Close()
s.Wait()
}()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
@ -175,6 +179,10 @@ func TestCloseReaderWithForwarder(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("newLoopbackStack() = %v", err) t.Fatalf("newLoopbackStack() = %v", err)
} }
defer func() {
s.Close()
s.Wait()
}()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
@ -225,30 +233,21 @@ func TestCloseRead(t *testing.T) {
if terr != nil { if terr != nil {
t.Fatalf("newLoopbackStack() = %v", terr) t.Fatalf("newLoopbackStack() = %v", terr)
} }
defer func() {
s.Close()
s.Wait()
}()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) { fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) {
var wq waiter.Queue var wq waiter.Queue
ep, err := r.CreateEndpoint(&wq) _, err := r.CreateEndpoint(&wq)
if err != nil { if err != nil {
t.Fatalf("r.CreateEndpoint() = %v", err) t.Fatalf("r.CreateEndpoint() = %v", err)
} }
defer ep.Close() // Endpoint will be closed in deferred s.Close (above).
r.Complete(false)
c := NewTCPConn(&wq, ep)
buf := make([]byte, 256)
n, e := c.Read(buf)
if e != nil || string(buf[:n]) != "abc123" {
t.Fatalf("c.Read() = (%d, %v), want (6, nil)", n, e)
}
if n, e = c.Write([]byte("abc123")); e != nil {
t.Errorf("c.Write() = (%d, %v), want (6, nil)", n, e)
}
}) })
s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket) s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket)
@ -278,6 +277,10 @@ func TestCloseWrite(t *testing.T) {
if terr != nil { if terr != nil {
t.Fatalf("newLoopbackStack() = %v", terr) t.Fatalf("newLoopbackStack() = %v", terr)
} }
defer func() {
s.Close()
s.Wait()
}()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
@ -334,6 +337,10 @@ func TestUDPForwarder(t *testing.T) {
if terr != nil { if terr != nil {
t.Fatalf("newLoopbackStack() = %v", terr) t.Fatalf("newLoopbackStack() = %v", terr)
} }
defer func() {
s.Close()
s.Wait()
}()
ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
addr1 := tcpip.FullAddress{NICID, ip1, 11211} addr1 := tcpip.FullAddress{NICID, ip1, 11211}
@ -391,6 +398,10 @@ func TestDeadlineChange(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("newLoopbackStack() = %v", err) t.Fatalf("newLoopbackStack() = %v", err)
} }
defer func() {
s.Close()
s.Wait()
}()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
@ -440,6 +451,10 @@ func TestPacketConnTransfer(t *testing.T) {
if e != nil { if e != nil {
t.Fatalf("newLoopbackStack() = %v", e) t.Fatalf("newLoopbackStack() = %v", e)
} }
defer func() {
s.Close()
s.Wait()
}()
ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
addr1 := tcpip.FullAddress{NICID, ip1, 11211} addr1 := tcpip.FullAddress{NICID, ip1, 11211}
@ -492,6 +507,10 @@ func TestConnectedPacketConnTransfer(t *testing.T) {
if e != nil { if e != nil {
t.Fatalf("newLoopbackStack() = %v", e) t.Fatalf("newLoopbackStack() = %v", e)
} }
defer func() {
s.Close()
s.Wait()
}()
ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
addr := tcpip.FullAddress{NICID, ip, 11211} addr := tcpip.FullAddress{NICID, ip, 11211}
@ -562,6 +581,8 @@ func makePipe() (c1, c2 net.Conn, stop func(), err error) {
stop = func() { stop = func() {
c1.Close() c1.Close()
c2.Close() c2.Close()
s.Close()
s.Wait()
} }
if err := l.Close(); err != nil { if err := l.Close(); err != nil {
@ -624,6 +645,10 @@ func TestTCPDialError(t *testing.T) {
if e != nil { if e != nil {
t.Fatalf("newLoopbackStack() = %v", e) t.Fatalf("newLoopbackStack() = %v", e)
} }
defer func() {
s.Close()
s.Wait()
}()
ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
addr := tcpip.FullAddress{NICID, ip, 11211} addr := tcpip.FullAddress{NICID, ip, 11211}
@ -641,6 +666,10 @@ func TestDialContextTCPCanceled(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("newLoopbackStack() = %v", err) t.Fatalf("newLoopbackStack() = %v", err)
} }
defer func() {
s.Close()
s.Wait()
}()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
@ -659,6 +688,10 @@ func TestDialContextTCPTimeout(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("newLoopbackStack() = %v", err) t.Fatalf("newLoopbackStack() = %v", err)
} }
defer func() {
s.Close()
s.Wait()
}()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)

View File

@ -148,12 +148,12 @@ func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWi
}, nil }, nil
} }
// LinkAddressProtocol implements stack.LinkAddressResolver. // LinkAddressProtocol implements stack.LinkAddressResolver.LinkAddressProtocol.
func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
return header.IPv4ProtocolNumber return header.IPv4ProtocolNumber
} }
// LinkAddressRequest implements stack.LinkAddressResolver. // LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest.
func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error { func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error {
r := &stack.Route{ r := &stack.Route{
RemoteLinkAddress: broadcastMAC, RemoteLinkAddress: broadcastMAC,
@ -172,7 +172,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.
}) })
} }
// ResolveStaticAddress implements stack.LinkAddressResolver. // ResolveStaticAddress implements stack.LinkAddressResolver.ResolveStaticAddress.
func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
if addr == header.IPv4Broadcast { if addr == header.IPv4Broadcast {
return broadcastMAC, true return broadcastMAC, true
@ -183,16 +183,22 @@ func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bo
return tcpip.LinkAddress([]byte(nil)), false return tcpip.LinkAddress([]byte(nil)), false
} }
// SetOption implements NetworkProtocol. // SetOption implements stack.NetworkProtocol.SetOption.
func (p *protocol) SetOption(option interface{}) *tcpip.Error { func (*protocol) SetOption(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption return tcpip.ErrUnknownProtocolOption
} }
// Option implements NetworkProtocol. // Option implements stack.NetworkProtocol.Option.
func (p *protocol) Option(option interface{}) *tcpip.Error { func (*protocol) Option(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption return tcpip.ErrUnknownProtocolOption
} }
// Close implements stack.TransportProtocol.Close.
func (*protocol) Close() {}
// Wait implements stack.TransportProtocol.Wait.
func (*protocol) Wait() {}
var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff}) var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff})
// NewProtocol returns an ARP network protocol. // NewProtocol returns an ARP network protocol.

View File

@ -473,6 +473,12 @@ func (p *protocol) DefaultTTL() uint8 {
return uint8(atomic.LoadUint32(&p.defaultTTL)) return uint8(atomic.LoadUint32(&p.defaultTTL))
} }
// Close implements stack.TransportProtocol.Close.
func (*protocol) Close() {}
// Wait implements stack.TransportProtocol.Wait.
func (*protocol) Wait() {}
// calculateMTU calculates the network-layer payload MTU based on the link-layer // calculateMTU calculates the network-layer payload MTU based on the link-layer
// payload mtu. // payload mtu.
func calculateMTU(mtu uint32) uint32 { func calculateMTU(mtu uint32) uint32 {

View File

@ -265,6 +265,12 @@ func (p *protocol) DefaultTTL() uint8 {
return uint8(atomic.LoadUint32(&p.defaultTTL)) return uint8(atomic.LoadUint32(&p.defaultTTL))
} }
// Close implements stack.TransportProtocol.Close.
func (*protocol) Close() {}
// Wait implements stack.TransportProtocol.Wait.
func (*protocol) Wait() {}
// calculateMTU calculates the network-layer payload MTU based on the link-layer // calculateMTU calculates the network-layer payload MTU based on the link-layer
// payload mtu. // payload mtu.
func calculateMTU(mtu uint32) uint32 { func calculateMTU(mtu uint32) uint32 {

View File

@ -74,10 +74,11 @@ type TransportEndpoint interface {
// HandleControlPacket takes ownership of pkt. // HandleControlPacket takes ownership of pkt.
HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, pkt tcpip.PacketBuffer) HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, pkt tcpip.PacketBuffer)
// Close puts the endpoint in a closed state and frees all resources // Abort initiates an expedited endpoint teardown. It puts the endpoint
// associated with it. This cleanup may happen asynchronously. Wait can // in a closed state and frees all resources associated with it. This
// be used to block on this asynchronous cleanup. // cleanup may happen asynchronously. Wait can be used to block on this
Close() // asynchronous cleanup.
Abort()
// Wait waits for any worker goroutines owned by the endpoint to stop. // Wait waits for any worker goroutines owned by the endpoint to stop.
// //
@ -160,6 +161,13 @@ type TransportProtocol interface {
// Option returns an error if the option is not supported or the // Option returns an error if the option is not supported or the
// provided option value is invalid. // provided option value is invalid.
Option(option interface{}) *tcpip.Error Option(option interface{}) *tcpip.Error
// Close requests that any worker goroutines owned by the protocol
// stop.
Close()
// Wait waits for any worker goroutines owned by the protocol to stop.
Wait()
} }
// TransportDispatcher contains the methods used by the network stack to deliver // TransportDispatcher contains the methods used by the network stack to deliver
@ -293,6 +301,13 @@ type NetworkProtocol interface {
// Option returns an error if the option is not supported or the // Option returns an error if the option is not supported or the
// provided option value is invalid. // provided option value is invalid.
Option(option interface{}) *tcpip.Error Option(option interface{}) *tcpip.Error
// Close requests that any worker goroutines owned by the protocol
// stop.
Close()
// Wait waits for any worker goroutines owned by the protocol to stop.
Wait()
} }
// NetworkDispatcher contains the methods used by the network stack to deliver // NetworkDispatcher contains the methods used by the network stack to deliver

View File

@ -1446,7 +1446,13 @@ func (s *Stack) RestoreCleanupEndpoints(es []TransportEndpoint) {
// Endpoints created or modified during this call may not get closed. // Endpoints created or modified during this call may not get closed.
func (s *Stack) Close() { func (s *Stack) Close() {
for _, e := range s.RegisteredEndpoints() { for _, e := range s.RegisteredEndpoints() {
e.Close() e.Abort()
}
for _, p := range s.transportProtocols {
p.proto.Close()
}
for _, p := range s.networkProtocols {
p.Close()
} }
} }
@ -1464,6 +1470,12 @@ func (s *Stack) Wait() {
for _, e := range s.CleanupEndpoints() { for _, e := range s.CleanupEndpoints() {
e.Wait() e.Wait()
} }
for _, p := range s.transportProtocols {
p.proto.Wait()
}
for _, p := range s.networkProtocols {
p.Wait()
}
s.mu.RLock() s.mu.RLock()
defer s.mu.RUnlock() defer s.mu.RUnlock()

View File

@ -235,6 +235,12 @@ func (f *fakeNetworkProtocol) Option(option interface{}) *tcpip.Error {
} }
} }
// Close implements TransportProtocol.Close.
func (*fakeNetworkProtocol) Close() {}
// Wait implements TransportProtocol.Wait.
func (*fakeNetworkProtocol) Wait() {}
func fakeNetFactory() stack.NetworkProtocol { func fakeNetFactory() stack.NetworkProtocol {
return &fakeNetworkProtocol{} return &fakeNetworkProtocol{}
} }

View File

@ -306,26 +306,6 @@ func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, p
ep.mu.RUnlock() // Don't use defer for performance reasons. ep.mu.RUnlock() // Don't use defer for performance reasons.
} }
// Close implements stack.TransportEndpoint.Close.
func (ep *multiPortEndpoint) Close() {
ep.mu.RLock()
eps := append([]TransportEndpoint(nil), ep.endpointsArr...)
ep.mu.RUnlock()
for _, e := range eps {
e.Close()
}
}
// Wait implements stack.TransportEndpoint.Wait.
func (ep *multiPortEndpoint) Wait() {
ep.mu.RLock()
eps := append([]TransportEndpoint(nil), ep.endpointsArr...)
ep.mu.RUnlock()
for _, e := range eps {
e.Wait()
}
}
// singleRegisterEndpoint tries to add an endpoint to the multiPortEndpoint // singleRegisterEndpoint tries to add an endpoint to the multiPortEndpoint
// list. The list might be empty already. // list. The list might be empty already.
func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, reusePort bool) *tcpip.Error { func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, reusePort bool) *tcpip.Error {

View File

@ -61,6 +61,10 @@ func newFakeTransportEndpoint(s *stack.Stack, proto *fakeTransportProtocol, netP
return &fakeTransportEndpoint{stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID} return &fakeTransportEndpoint{stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID}
} }
func (f *fakeTransportEndpoint) Abort() {
f.Close()
}
func (f *fakeTransportEndpoint) Close() { func (f *fakeTransportEndpoint) Close() {
f.route.Release() f.route.Release()
} }
@ -272,7 +276,7 @@ func (f *fakeTransportProtocol) NewEndpoint(stack *stack.Stack, netProto tcpip.N
return newFakeTransportEndpoint(stack, f, netProto, stack.UniqueID()), nil return newFakeTransportEndpoint(stack, f, netProto, stack.UniqueID()), nil
} }
func (f *fakeTransportProtocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { func (*fakeTransportProtocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
return nil, tcpip.ErrUnknownProtocol return nil, tcpip.ErrUnknownProtocol
} }
@ -310,6 +314,15 @@ func (f *fakeTransportProtocol) Option(option interface{}) *tcpip.Error {
} }
} }
// Abort implements TransportProtocol.Abort.
func (*fakeTransportProtocol) Abort() {}
// Close implements tcpip.Endpoint.Close.
func (*fakeTransportProtocol) Close() {}
// Wait implements TransportProtocol.Wait.
func (*fakeTransportProtocol) Wait() {}
func fakeTransFactory() stack.TransportProtocol { func fakeTransFactory() stack.TransportProtocol {
return &fakeTransportProtocol{} return &fakeTransportProtocol{}
} }

View File

@ -341,9 +341,15 @@ type ControlMessages struct {
// networking stack. // networking stack.
type Endpoint interface { type Endpoint interface {
// Close puts the endpoint in a closed state and frees all resources // Close puts the endpoint in a closed state and frees all resources
// associated with it. // associated with it. Close initiates the teardown process, the
// Endpoint may not be fully closed when Close returns.
Close() Close()
// Abort initiates an expedited endpoint teardown. As compared to
// Close, Abort prioritizes closing the Endpoint quickly over cleanly.
// Abort is best effort; implementing Abort with Close is acceptable.
Abort()
// Read reads data from the endpoint and optionally returns the sender. // Read reads data from the endpoint and optionally returns the sender.
// //
// This method does not block if there is no data pending. It will also // This method does not block if there is no data pending. It will also

View File

@ -96,6 +96,11 @@ func (e *endpoint) UniqueID() uint64 {
return e.uniqueID return e.uniqueID
} }
// Abort implements stack.TransportEndpoint.Abort.
func (e *endpoint) Abort() {
e.Close()
}
// Close puts the endpoint in a closed state and frees all resources // Close puts the endpoint in a closed state and frees all resources
// associated with it. // associated with it.
func (e *endpoint) Close() { func (e *endpoint) Close() {

View File

@ -104,20 +104,26 @@ func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error)
// HandleUnknownDestinationPacket handles packets targeted at this protocol but // HandleUnknownDestinationPacket handles packets targeted at this protocol but
// that don't match any existing endpoint. // that don't match any existing endpoint.
func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, tcpip.PacketBuffer) bool { func (*protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, tcpip.PacketBuffer) bool {
return true return true
} }
// SetOption implements TransportProtocol.SetOption. // SetOption implements stack.TransportProtocol.SetOption.
func (p *protocol) SetOption(option interface{}) *tcpip.Error { func (*protocol) SetOption(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption return tcpip.ErrUnknownProtocolOption
} }
// Option implements TransportProtocol.Option. // Option implements stack.TransportProtocol.Option.
func (p *protocol) Option(option interface{}) *tcpip.Error { func (*protocol) Option(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption return tcpip.ErrUnknownProtocolOption
} }
// Close implements stack.TransportProtocol.Close.
func (*protocol) Close() {}
// Wait implements stack.TransportProtocol.Wait.
func (*protocol) Wait() {}
// NewProtocol4 returns an ICMPv4 transport protocol. // NewProtocol4 returns an ICMPv4 transport protocol.
func NewProtocol4() stack.TransportProtocol { func NewProtocol4() stack.TransportProtocol {
return &protocol{ProtocolNumber4} return &protocol{ProtocolNumber4}

View File

@ -98,6 +98,11 @@ func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumb
return ep, nil return ep, nil
} }
// Abort implements stack.TransportEndpoint.Abort.
func (e *endpoint) Abort() {
e.Close()
}
// Close implements tcpip.Endpoint.Close. // Close implements tcpip.Endpoint.Close.
func (ep *endpoint) Close() { func (ep *endpoint) Close() {
ep.mu.Lock() ep.mu.Lock()

View File

@ -121,6 +121,11 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt
return e, nil return e, nil
} }
// Abort implements stack.TransportEndpoint.Abort.
func (e *endpoint) Abort() {
e.Close()
}
// Close implements tcpip.Endpoint.Close. // Close implements tcpip.Endpoint.Close.
func (e *endpoint) Close() { func (e *endpoint) Close() {
e.mu.Lock() e.mu.Lock()

View File

@ -299,6 +299,13 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
h := newPassiveHandshake(ep, seqnum.Size(ep.initialReceiveWindow()), isn, irs, opts, deferAccept) h := newPassiveHandshake(ep, seqnum.Size(ep.initialReceiveWindow()), isn, irs, opts, deferAccept)
if err := h.execute(); err != nil { if err := h.execute(); err != nil {
ep.Close() ep.Close()
// Wake up any waiters. This is strictly not required normally
// as a socket that was never accepted can't really have any
// registered waiters except when stack.Wait() is called which
// waits for all registered endpoints to stop and expects an
// EventHUp.
ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
if l.listenEP != nil { if l.listenEP != nil {
l.removePendingEndpoint(ep) l.removePendingEndpoint(ep)
} }
@ -607,7 +614,7 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
e.mu.Unlock() e.mu.Unlock()
// Notify waiters that the endpoint is shutdown. // Notify waiters that the endpoint is shutdown.
e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut) e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut | waiter.EventHUp | waiter.EventErr)
}() }()
s := sleep.Sleeper{} s := sleep.Sleeper{}

View File

@ -1372,7 +1372,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
e.snd.updateMaxPayloadSize(mtu, count) e.snd.updateMaxPayloadSize(mtu, count)
} }
if n&notifyReset != 0 { if n&notifyReset != 0 || n&notifyAbort != 0 {
return tcpip.ErrConnectionAborted return tcpip.ErrConnectionAborted
} }
@ -1655,7 +1655,7 @@ func (e *endpoint) doTimeWait() (twReuse func()) {
} }
case notification: case notification:
n := e.fetchNotifications() n := e.fetchNotifications()
if n&notifyClose != 0 { if n&notifyClose != 0 || n&notifyAbort != 0 {
return nil return nil
} }
if n&notifyDrain != 0 { if n&notifyDrain != 0 {

View File

@ -68,17 +68,28 @@ func (q *epQueue) empty() bool {
type processor struct { type processor struct {
epQ epQueue epQ epQueue
newEndpointWaker sleep.Waker newEndpointWaker sleep.Waker
closeWaker sleep.Waker
id int id int
wg sync.WaitGroup
} }
func newProcessor(id int) *processor { func newProcessor(id int) *processor {
p := &processor{ p := &processor{
id: id, id: id,
} }
p.wg.Add(1)
go p.handleSegments() go p.handleSegments()
return p return p
} }
func (p *processor) close() {
p.closeWaker.Assert()
}
func (p *processor) wait() {
p.wg.Wait()
}
func (p *processor) queueEndpoint(ep *endpoint) { func (p *processor) queueEndpoint(ep *endpoint) {
// Queue an endpoint for processing by the processor goroutine. // Queue an endpoint for processing by the processor goroutine.
p.epQ.enqueue(ep) p.epQ.enqueue(ep)
@ -87,11 +98,17 @@ func (p *processor) queueEndpoint(ep *endpoint) {
func (p *processor) handleSegments() { func (p *processor) handleSegments() {
const newEndpointWaker = 1 const newEndpointWaker = 1
const closeWaker = 2
s := sleep.Sleeper{} s := sleep.Sleeper{}
s.AddWaker(&p.newEndpointWaker, newEndpointWaker) s.AddWaker(&p.newEndpointWaker, newEndpointWaker)
s.AddWaker(&p.closeWaker, closeWaker)
defer s.Done() defer s.Done()
for { for {
s.Fetch(true) id, ok := s.Fetch(true)
if ok && id == closeWaker {
p.wg.Done()
return
}
for ep := p.epQ.dequeue(); ep != nil; ep = p.epQ.dequeue() { for ep := p.epQ.dequeue(); ep != nil; ep = p.epQ.dequeue() {
if ep.segmentQueue.empty() { if ep.segmentQueue.empty() {
continue continue
@ -160,6 +177,18 @@ func newDispatcher(nProcessors int) *dispatcher {
} }
} }
func (d *dispatcher) close() {
for _, p := range d.processors {
p.close()
}
}
func (d *dispatcher) wait() {
for _, p := range d.processors {
p.wait()
}
}
func (d *dispatcher) queuePacket(r *stack.Route, stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) { func (d *dispatcher) queuePacket(r *stack.Route, stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) {
ep := stackEP.(*endpoint) ep := stackEP.(*endpoint)
s := newSegment(r, id, pkt) s := newSegment(r, id, pkt)

View File

@ -121,6 +121,8 @@ const (
notifyDrain notifyDrain
notifyReset notifyReset
notifyResetByPeer notifyResetByPeer
// notifyAbort is a request for an expedited teardown.
notifyAbort
notifyKeepaliveChanged notifyKeepaliveChanged
notifyMSSChanged notifyMSSChanged
// notifyTickleWorker is used to tickle the protocol main loop during a // notifyTickleWorker is used to tickle the protocol main loop during a
@ -785,6 +787,24 @@ func (e *endpoint) notifyProtocolGoroutine(n uint32) {
} }
} }
// Abort implements stack.TransportEndpoint.Abort.
func (e *endpoint) Abort() {
// The abort notification is not processed synchronously, so no
// synchronization is needed.
//
// If the endpoint becomes connected after this check, we still close
// the endpoint. This worst case results in a slower abort.
//
// If the endpoint disconnected after the check, nothing needs to be
// done, so sending a notification which will potentially be ignored is
// fine.
if e.EndpointState().connected() {
e.notifyProtocolGoroutine(notifyAbort)
return
}
e.Close()
}
// Close puts the endpoint in a closed state and frees all resources associated // Close puts the endpoint in a closed state and frees all resources associated
// with it. It must be called only once and with no other concurrent calls to // with it. It must be called only once and with no other concurrent calls to
// the endpoint. // the endpoint.
@ -829,9 +849,18 @@ func (e *endpoint) closeNoShutdown() {
// Either perform the local cleanup or kick the worker to make sure it // Either perform the local cleanup or kick the worker to make sure it
// knows it needs to cleanup. // knows it needs to cleanup.
tcpip.AddDanglingEndpoint(e) tcpip.AddDanglingEndpoint(e)
if !e.workerRunning { switch e.EndpointState() {
// Sockets in StateSynRecv state(passive connections) are closed when
// the handshake fails or if the listening socket is closed while
// handshake was in progress. In such cases the handshake goroutine
// is already gone by the time Close is called and we need to cleanup
// here.
case StateInitial, StateBound, StateSynRecv:
e.cleanupLocked() e.cleanupLocked()
} else { e.setEndpointState(StateClose)
case StateError, StateClose:
// do nothing.
default:
e.workerCleanup = true e.workerCleanup = true
e.notifyProtocolGoroutine(notifyClose) e.notifyProtocolGoroutine(notifyClose)
} }

View File

@ -194,7 +194,7 @@ func replyWithReset(s *segment) {
sendTCP(&s.route, s.id, buffer.VectorisedView{}, s.route.DefaultTTL(), stack.DefaultTOS, flags, seq, ack, 0 /* rcvWnd */, nil /* options */, nil /* gso */) sendTCP(&s.route, s.id, buffer.VectorisedView{}, s.route.DefaultTTL(), stack.DefaultTOS, flags, seq, ack, 0 /* rcvWnd */, nil /* options */, nil /* gso */)
} }
// SetOption implements TransportProtocol.SetOption. // SetOption implements stack.TransportProtocol.SetOption.
func (p *protocol) SetOption(option interface{}) *tcpip.Error { func (p *protocol) SetOption(option interface{}) *tcpip.Error {
switch v := option.(type) { switch v := option.(type) {
case SACKEnabled: case SACKEnabled:
@ -269,7 +269,7 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error {
} }
} }
// Option implements TransportProtocol.Option. // Option implements stack.TransportProtocol.Option.
func (p *protocol) Option(option interface{}) *tcpip.Error { func (p *protocol) Option(option interface{}) *tcpip.Error {
switch v := option.(type) { switch v := option.(type) {
case *SACKEnabled: case *SACKEnabled:
@ -331,6 +331,16 @@ func (p *protocol) Option(option interface{}) *tcpip.Error {
} }
} }
// Close implements stack.TransportProtocol.Close.
func (p *protocol) Close() {
p.dispatcher.close()
}
// Wait implements stack.TransportProtocol.Wait.
func (p *protocol) Wait() {
p.dispatcher.wait()
}
// NewProtocol returns a TCP transport protocol. // NewProtocol returns a TCP transport protocol.
func NewProtocol() stack.TransportProtocol { func NewProtocol() stack.TransportProtocol {
return &protocol{ return &protocol{

View File

@ -186,6 +186,11 @@ func (e *endpoint) UniqueID() uint64 {
return e.uniqueID return e.uniqueID
} }
// Abort implements stack.TransportEndpoint.Abort.
func (e *endpoint) Abort() {
e.Close()
}
// Close puts the endpoint in a closed state and frees all resources // Close puts the endpoint in a closed state and frees all resources
// associated with it. // associated with it.
func (e *endpoint) Close() { func (e *endpoint) Close() {

View File

@ -180,16 +180,22 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans
return true return true
} }
// SetOption implements TransportProtocol.SetOption. // SetOption implements stack.TransportProtocol.SetOption.
func (p *protocol) SetOption(option interface{}) *tcpip.Error { func (*protocol) SetOption(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption return tcpip.ErrUnknownProtocolOption
} }
// Option implements TransportProtocol.Option. // Option implements stack.TransportProtocol.Option.
func (p *protocol) Option(option interface{}) *tcpip.Error { func (*protocol) Option(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption return tcpip.ErrUnknownProtocolOption
} }
// Close implements stack.TransportProtocol.Close.
func (*protocol) Close() {}
// Wait implements stack.TransportProtocol.Wait.
func (*protocol) Wait() {}
// NewProtocol returns a UDP transport protocol. // NewProtocol returns a UDP transport protocol.
func NewProtocol() stack.TransportProtocol { func NewProtocol() stack.TransportProtocol {
return &protocol{} return &protocol{}