// Copyright 2018 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package transport import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/waiter" ) // UniqueIDProvider generates a sequence of unique identifiers useful for, // among other things, lock ordering. type UniqueIDProvider interface { // UniqueID returns a new unique identifier. UniqueID() uint64 } // A ConnectingEndpoint is a connectioned unix endpoint that is attempting to // establish a bidirectional connection with a BoundEndpoint. type ConnectingEndpoint interface { // ID returns the endpoint's globally unique identifier. This identifier // must be used to determine locking order if more than one endpoint is // to be locked in the same codepath. The endpoint with the smaller // identifier must be locked before endpoints with larger identifiers. ID() uint64 // Passcred implements socket.Credentialer.Passcred. Passcred() bool // Type returns the socket type, typically either SockStream or // SockSeqpacket. The connection attempt must be aborted if this // value doesn't match the ConnectableEndpoint's type. Type() linux.SockType // GetLocalAddress returns the bound path. GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) // Locker protects the following methods. While locked, only the holder of // the lock can change the return value of the protected methods. sync.Locker // Connected returns true iff the ConnectingEndpoint is in the connected // state. ConnectingEndpoints can only be connected to a single endpoint, // so the connection attempt must be aborted if this returns true. Connected() bool // Listening returns true iff the ConnectingEndpoint is in the listening // state. ConnectingEndpoints cannot make connections while listening, so // the connection attempt must be aborted if this returns true. Listening() bool // WaiterQueue returns a pointer to the endpoint's waiter queue. WaiterQueue() *waiter.Queue } // connectionedEndpoint is a Unix-domain connected or connectable endpoint and implements // ConnectingEndpoint, ConnectableEndpoint and tcpip.Endpoint. // // connectionedEndpoints must be in connected state in order to transfer data. // // This implementation includes STREAM and SEQPACKET Unix sockets created with // socket(2), accept(2) or socketpair(2) and dgram unix sockets created with // socketpair(2). See unix_connectionless.go for the implementation of DGRAM // Unix sockets created with socket(2). // // The state is much simpler than a TCP endpoint, so it is not encoded // explicitly. Instead we enforce the following invariants: // // receiver != nil, connected != nil => connected. // path != "" && acceptedChan == nil => bound, not listening. // path != "" && acceptedChan != nil => bound and listening. // // Only one of these will be true at any moment. // // +stateify savable type connectionedEndpoint struct { baseEndpoint // id is the unique endpoint identifier. This is used exclusively for // lock ordering within connect. id uint64 // idGenerator is used to generate new unique endpoint identifiers. idGenerator UniqueIDProvider // stype is used by connecting sockets to ensure that they are the // same type. The value is typically either tcpip.SockSeqpacket or // tcpip.SockStream. stype linux.SockType // acceptedChan is per the TCP endpoint implementation. Note that the // sockets in this channel are _already in the connected state_, and // have another associated connectionedEndpoint. // // If nil, then no listen call has been made. acceptedChan chan *connectionedEndpoint `state:".([]*connectionedEndpoint)"` } var ( _ = BoundEndpoint((*connectionedEndpoint)(nil)) _ = Endpoint((*connectionedEndpoint)(nil)) ) // NewConnectioned creates a new unbound connectionedEndpoint. func NewConnectioned(ctx context.Context, stype linux.SockType, uid UniqueIDProvider) Endpoint { return &connectionedEndpoint{ baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}}, id: uid.UniqueID(), idGenerator: uid, stype: stype, } } // NewPair allocates a new pair of connected unix-domain connectionedEndpoints. func NewPair(ctx context.Context, stype linux.SockType, uid UniqueIDProvider) (Endpoint, Endpoint) { a := &connectionedEndpoint{ baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}}, id: uid.UniqueID(), idGenerator: uid, stype: stype, } b := &connectionedEndpoint{ baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}}, id: uid.UniqueID(), idGenerator: uid, stype: stype, } q1 := &queue{ReaderQueue: a.Queue, WriterQueue: b.Queue, limit: initialLimit} q1.EnableLeakCheck("transport.queue") q2 := &queue{ReaderQueue: b.Queue, WriterQueue: a.Queue, limit: initialLimit} q2.EnableLeakCheck("transport.queue") if stype == linux.SOCK_STREAM { a.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{q1}} b.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{q2}} } else { a.receiver = &queueReceiver{q1} b.receiver = &queueReceiver{q2} } q2.IncRef() a.connected = &connectedEndpoint{ endpoint: b, writeQueue: q2, } q1.IncRef() b.connected = &connectedEndpoint{ endpoint: a, writeQueue: q1, } return a, b } // NewExternal creates a new externally backed Endpoint. It behaves like a // socketpair. func NewExternal(ctx context.Context, stype linux.SockType, uid UniqueIDProvider, queue *waiter.Queue, receiver Receiver, connected ConnectedEndpoint) Endpoint { return &connectionedEndpoint{ baseEndpoint: baseEndpoint{Queue: queue, receiver: receiver, connected: connected}, id: uid.UniqueID(), idGenerator: uid, stype: stype, } } // ID implements ConnectingEndpoint.ID. func (e *connectionedEndpoint) ID() uint64 { return e.id } // Type implements ConnectingEndpoint.Type and Endpoint.Type. func (e *connectionedEndpoint) Type() linux.SockType { return e.stype } // WaiterQueue implements ConnectingEndpoint.WaiterQueue. func (e *connectionedEndpoint) WaiterQueue() *waiter.Queue { return e.Queue } // isBound returns true iff the connectionedEndpoint is bound (but not // listening). func (e *connectionedEndpoint) isBound() bool { return e.path != "" && e.acceptedChan == nil } // Listening implements ConnectingEndpoint.Listening. func (e *connectionedEndpoint) Listening() bool { return e.acceptedChan != nil } // Close puts the connectionedEndpoint in a closed state and frees all // resources associated with it. // // The socket will be a fresh state after a call to close and may be reused. // That is, close may be used to "unbind" or "disconnect" the socket in error // paths. func (e *connectionedEndpoint) Close() { e.Lock() var c ConnectedEndpoint var r Receiver switch { case e.Connected(): e.connected.CloseSend() e.receiver.CloseRecv() // Still have unread data? If yes, we set this into the write // end so that the peer can get ECONNRESET) when it does read. if e.receiver.RecvQueuedSize() > 0 { e.connected.CloseUnread() } c = e.connected r = e.receiver e.connected = nil e.receiver = nil case e.isBound(): e.path = "" case e.Listening(): close(e.acceptedChan) for n := range e.acceptedChan { n.Close() } e.acceptedChan = nil e.path = "" } e.Unlock() if c != nil { c.CloseNotify() c.Release() } if r != nil { r.CloseNotify() r.Release() } } // BidirectionalConnect implements BoundEndpoint.BidirectionalConnect. func (e *connectionedEndpoint) BidirectionalConnect(ctx context.Context, ce ConnectingEndpoint, returnConnect func(Receiver, ConnectedEndpoint)) *syserr.Error { if ce.Type() != e.stype { return syserr.ErrConnectionRefused } // Check if ce is e to avoid a deadlock. if ce, ok := ce.(*connectionedEndpoint); ok && ce == e { return syserr.ErrInvalidEndpointState } // Do a dance to safely acquire locks on both endpoints. if e.id < ce.ID() { e.Lock() ce.Lock() } else { ce.Lock() e.Lock() } // Check connecting state. if ce.Connected() { e.Unlock() ce.Unlock() return syserr.ErrAlreadyConnected } if ce.Listening() { e.Unlock() ce.Unlock() return syserr.ErrInvalidEndpointState } // Check bound state. if !e.Listening() { e.Unlock() ce.Unlock() return syserr.ErrConnectionRefused } // Create a newly bound connectionedEndpoint. ne := &connectionedEndpoint{ baseEndpoint: baseEndpoint{ path: e.path, Queue: &waiter.Queue{}, }, id: e.idGenerator.UniqueID(), idGenerator: e.idGenerator, stype: e.stype, } readQueue := &queue{ReaderQueue: ce.WaiterQueue(), WriterQueue: ne.Queue, limit: initialLimit} readQueue.EnableLeakCheck("transport.queue") ne.connected = &connectedEndpoint{ endpoint: ce, writeQueue: readQueue, } writeQueue := &queue{ReaderQueue: ne.Queue, WriterQueue: ce.WaiterQueue(), limit: initialLimit} writeQueue.EnableLeakCheck("transport.queue") if e.stype == linux.SOCK_STREAM { ne.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{readQueue: writeQueue}} } else { ne.receiver = &queueReceiver{readQueue: writeQueue} } select { case e.acceptedChan <- ne: // Commit state. writeQueue.IncRef() connected := &connectedEndpoint{ endpoint: ne, writeQueue: writeQueue, } readQueue.IncRef() if e.stype == linux.SOCK_STREAM { returnConnect(&streamQueueReceiver{queueReceiver: queueReceiver{readQueue: readQueue}}, connected) } else { returnConnect(&queueReceiver{readQueue: readQueue}, connected) } // Notify can deadlock if we are holding these locks. e.Unlock() ce.Unlock() // Notify on both ends. e.Notify(waiter.EventIn) ce.WaiterQueue().Notify(waiter.EventOut) return nil default: // Busy; return ECONNREFUSED per spec. ne.Close() e.Unlock() ce.Unlock() return syserr.ErrConnectionRefused } } // UnidirectionalConnect implements BoundEndpoint.UnidirectionalConnect. func (e *connectionedEndpoint) UnidirectionalConnect(ctx context.Context) (ConnectedEndpoint, *syserr.Error) { return nil, syserr.ErrConnectionRefused } // Connect attempts to directly connect to another Endpoint. // Implements Endpoint.Connect. func (e *connectionedEndpoint) Connect(ctx context.Context, server BoundEndpoint) *syserr.Error { returnConnect := func(r Receiver, ce ConnectedEndpoint) { e.receiver = r e.connected = ce } return server.BidirectionalConnect(ctx, e, returnConnect) } // Listen starts listening on the connection. func (e *connectionedEndpoint) Listen(backlog int) *syserr.Error { e.Lock() defer e.Unlock() if e.Listening() { // Adjust the size of the channel iff we can fix existing // pending connections into the new one. if len(e.acceptedChan) > backlog { return syserr.ErrInvalidEndpointState } origChan := e.acceptedChan e.acceptedChan = make(chan *connectionedEndpoint, backlog) close(origChan) for ep := range origChan { e.acceptedChan <- ep } return nil } if !e.isBound() { return syserr.ErrInvalidEndpointState } // Normal case. e.acceptedChan = make(chan *connectionedEndpoint, backlog) return nil } // Accept accepts a new connection. func (e *connectionedEndpoint) Accept() (Endpoint, *syserr.Error) { e.Lock() defer e.Unlock() if !e.Listening() { return nil, syserr.ErrInvalidEndpointState } select { case ne := <-e.acceptedChan: return ne, nil default: // Nothing left. return nil, syserr.ErrWouldBlock } } // Bind binds the connection. // // For Unix connectionedEndpoints, this _only sets the address associated with // the socket_. Work associated with sockets in the filesystem or finding those // sockets must be done by a higher level. // // Bind will fail only if the socket is connected, bound or the passed address // is invalid (the empty string). func (e *connectionedEndpoint) Bind(addr tcpip.FullAddress, commit func() *syserr.Error) *syserr.Error { e.Lock() defer e.Unlock() if e.isBound() || e.Listening() { return syserr.ErrAlreadyBound } if addr.Addr == "" { // The empty string is not permitted. return syserr.ErrBadLocalAddress } if commit != nil { if err := commit(); err != nil { return err } } // Save the bound address. e.path = string(addr.Addr) return nil } // SendMsg writes data and a control message to the endpoint's peer. // This method does not block if the data cannot be written. func (e *connectionedEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMessages, to BoundEndpoint) (int64, *syserr.Error) { // Stream sockets do not support specifying the endpoint. Seqpacket // sockets ignore the passed endpoint. if e.stype == linux.SOCK_STREAM && to != nil { return 0, syserr.ErrNotSupported } return e.baseEndpoint.SendMsg(ctx, data, c, to) } // Readiness returns the current readiness of the connectionedEndpoint. For // example, if waiter.EventIn is set, the connectionedEndpoint is immediately // readable. func (e *connectionedEndpoint) Readiness(mask waiter.EventMask) waiter.EventMask { e.Lock() defer e.Unlock() ready := waiter.EventMask(0) switch { case e.Connected(): if mask&waiter.EventIn != 0 && e.receiver.Readable() { ready |= waiter.EventIn } if mask&waiter.EventOut != 0 && e.connected.Writable() { ready |= waiter.EventOut } case e.Listening(): if mask&waiter.EventIn != 0 && len(e.acceptedChan) > 0 { ready |= waiter.EventIn } } return ready } // State implements socket.Socket.State. func (e *connectionedEndpoint) State() uint32 { if e.Connected() { return linux.SS_CONNECTED } return linux.SS_UNCONNECTED }