// Copyright 2018 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package p9 import ( "errors" "fmt" "syscall" "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/flipcall" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/pool" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/unet" ) // ErrOutOfTags indicates no tags are available. var ErrOutOfTags = errors.New("out of tags -- messages lost?") // ErrOutOfFIDs indicates no more FIDs are available. var ErrOutOfFIDs = errors.New("out of FIDs -- messages lost?") // ErrUnexpectedTag indicates a response with an unexpected tag was received. var ErrUnexpectedTag = errors.New("unexpected tag in response") // ErrVersionsExhausted indicates that all versions to negotiate have been exhausted. var ErrVersionsExhausted = errors.New("exhausted all versions to negotiate") // ErrBadVersionString indicates that the version string is malformed or unsupported. var ErrBadVersionString = errors.New("bad version string") // ErrBadResponse indicates the response didn't match the request. type ErrBadResponse struct { Got MsgType Want MsgType } // Error returns a highly descriptive error. func (e *ErrBadResponse) Error() string { return fmt.Sprintf("unexpected message type: got %v, want %v", e.Got, e.Want) } // response is the asynchronous return from recv. // // This is used in the pending map below. type response struct { r message done chan error } var responsePool = sync.Pool{ New: func() interface{} { return &response{ done: make(chan error, 1), } }, } // Client is at least a 9P2000.L client. type Client struct { // socket is the connected socket. socket *unet.Socket // tagPool is the collection of available tags. tagPool pool.Pool // fidPool is the collection of available fids. fidPool pool.Pool // messageSize is the maximum total size of a message. messageSize uint32 // payloadSize is the maximum payload size of a read or write. // // For large reads and writes this means that the read or write is // broken up into buffer-size/payloadSize requests. payloadSize uint32 // version is the agreed upon version X of 9P2000.L.Google.X. // version 0 implies 9P2000.L. version uint32 // closedWg is marked as done when the Client.watch() goroutine, which is // responsible for closing channels and the socket fd, returns. closedWg sync.WaitGroup // sendRecv is the transport function. // // This is determined dynamically based on whether or not the server // supports flipcall channels (preferred as it is faster and more // efficient, and does not require tags). sendRecv func(message, message) error // -- below corresponds to sendRecvChannel -- // channelsMu protects channels. channelsMu sync.Mutex // channelsWg counts the number of channels for which channel.active == // true. channelsWg sync.WaitGroup // channels is the set of all initialized channels. channels []*channel // availableChannels is a FIFO of inactive channels. availableChannels []*channel // -- below corresponds to sendRecvLegacy -- // pending is the set of pending messages. pending map[Tag]*response pendingMu sync.Mutex // sendMu is the lock for sending a request. sendMu sync.Mutex // recvr is essentially a mutex for calling recv. // // Whoever writes to this channel is permitted to call recv. When // finished calling recv, this channel should be emptied. recvr chan bool } // NewClient creates a new client. It performs a Tversion exchange with // the server to assert that messageSize is ok to use. // // If NewClient succeeds, ownership of socket is transferred to the new Client. func NewClient(socket *unet.Socket, messageSize uint32, version string) (*Client, error) { // Need at least one byte of payload. if messageSize <= msgRegistry.largestFixedSize { return nil, &ErrMessageTooLarge{ size: messageSize, msize: msgRegistry.largestFixedSize, } } // Compute a payload size and round to 512 (normal block size) // if it's larger than a single block. payloadSize := messageSize - msgRegistry.largestFixedSize if payloadSize > 512 && payloadSize%512 != 0 { payloadSize -= (payloadSize % 512) } c := &Client{ socket: socket, tagPool: pool.Pool{Start: 1, Limit: uint64(NoTag)}, fidPool: pool.Pool{Start: 1, Limit: uint64(NoFID)}, pending: make(map[Tag]*response), recvr: make(chan bool, 1), messageSize: messageSize, payloadSize: payloadSize, } // Agree upon a version. requested, ok := parseVersion(version) if !ok { return nil, ErrBadVersionString } for { // Always exchange the version using the legacy version of the // protocol. If the protocol supports flipcall, then we switch // our sendRecv function to use that functionality. Otherwise, // we stick to sendRecvLegacy. rversion := Rversion{} err := c.sendRecvLegacy(&Tversion{ Version: versionString(requested), MSize: messageSize, }, &rversion) // The server told us to try again with a lower version. if err == syscall.EAGAIN { if requested == lowestSupportedVersion { return nil, ErrVersionsExhausted } requested-- continue } // We requested an impossible version or our other parameters were bogus. if err != nil { return nil, err } // Parse the version. version, ok := parseVersion(rversion.Version) if !ok { // The server gave us a bad version. We return a generically worrisome error. log.Warningf("server returned bad version string %q", rversion.Version) return nil, ErrBadVersionString } c.version = version break } // Can we switch to use the more advanced channels and create // independent channels for communication? Prefer it if possible. if versionSupportsFlipcall(c.version) { // Attempt to initialize IPC-based communication. for i := 0; i < channelsPerClient; i++ { if err := c.openChannel(i); err != nil { log.Warningf("error opening flipcall channel: %v", err) break // Stop. } } if len(c.channels) >= 1 { // At least one channel created. c.sendRecv = c.sendRecvChannel } else { // Channel setup failed; fallback. c.sendRecv = c.sendRecvLegacy } } else { // No channels available: use the legacy mechanism. c.sendRecv = c.sendRecvLegacy } // Ensure that the socket and channels are closed when the socket is shut // down. c.closedWg.Add(1) go c.watch(socket) // S/R-SAFE: not relevant. return c, nil } // watch watches the given socket and releases resources on hangup events. // // This is intended to be called as a goroutine. func (c *Client) watch(socket *unet.Socket) { defer c.closedWg.Done() events := []unix.PollFd{ unix.PollFd{ Fd: int32(socket.FD()), Events: unix.POLLHUP | unix.POLLRDHUP, }, } // Wait for a shutdown event. for { n, err := unix.Ppoll(events, nil, nil) if err == syscall.EINTR || err == syscall.EAGAIN { continue } if err != nil { log.Warningf("p9.Client.watch(): %v", err) break } if n != 1 { log.Warningf("p9.Client.watch(): got %d events, wanted 1", n) } break } // Set availableChannels to nil so that future calls to c.sendRecvChannel() // don't attempt to activate a channel, and concurrent calls to // c.sendRecvChannel() don't mark released channels as available. c.channelsMu.Lock() c.availableChannels = nil // Shut down all active channels. for _, ch := range c.channels { if ch.active { log.Debugf("shutting down active channel@%p...", ch) ch.Shutdown() } } c.channelsMu.Unlock() // Wait for active channels to become inactive. c.channelsWg.Wait() // Close all channels. c.channelsMu.Lock() for _, ch := range c.channels { ch.Close() } c.channelsMu.Unlock() // Close the main socket. c.socket.Close() } // openChannel attempts to open a client channel. // // Note that this function returns naked errors which should not be propagated // directly to a caller. It is expected that the errors will be logged and a // fallback path will be used instead. func (c *Client) openChannel(id int) error { var ( rchannel0 Rchannel rchannel1 Rchannel res = new(channel) ) // Open the data channel. if err := c.sendRecvLegacy(&Tchannel{ ID: uint32(id), Control: 0, }, &rchannel0); err != nil { return fmt.Errorf("error handling Tchannel message: %v", err) } if rchannel0.FilePayload() == nil { return fmt.Errorf("missing file descriptor on primary channel") } // We don't need to hold this. defer rchannel0.FilePayload().Close() // Open the channel for file descriptors. if err := c.sendRecvLegacy(&Tchannel{ ID: uint32(id), Control: 1, }, &rchannel1); err != nil { return err } if rchannel1.FilePayload() == nil { return fmt.Errorf("missing file descriptor on file descriptor channel") } // Construct the endpoints. res.desc = flipcall.PacketWindowDescriptor{ FD: rchannel0.FilePayload().FD(), Offset: int64(rchannel0.Offset), Length: int(rchannel0.Length), } if err := res.data.Init(flipcall.ClientSide, res.desc); err != nil { rchannel1.FilePayload().Close() return err } // The fds channel owns the control payload, and it will be closed when // the channel object is closed. res.fds.Init(rchannel1.FilePayload().Release()) // Save the channel. c.channelsMu.Lock() defer c.channelsMu.Unlock() c.channels = append(c.channels, res) c.availableChannels = append(c.availableChannels, res) return nil } // handleOne handles a single incoming message. // // This should only be called with the token from recvr. Note that the received // tag will automatically be cleared from pending. func (c *Client) handleOne() { tag, r, err := recv(c.socket, c.messageSize, func(tag Tag, t MsgType) (message, error) { c.pendingMu.Lock() resp := c.pending[tag] c.pendingMu.Unlock() // Not expecting this message? if resp == nil { log.Warningf("client received unexpected tag %v, ignoring", tag) return nil, ErrUnexpectedTag } // Is it an error? We specifically allow this to // go through, and then we deserialize below. if t == MsgRlerror { return &Rlerror{}, nil } // Does it match expectations? if t != resp.r.Type() { return nil, &ErrBadResponse{Got: t, Want: resp.r.Type()} } // Return the response. return resp.r, nil }) if err != nil { // No tag was extracted (probably a socket error). // // Likely catastrophic. Notify all waiters and clear pending. c.pendingMu.Lock() for _, resp := range c.pending { resp.done <- err } c.pending = make(map[Tag]*response) c.pendingMu.Unlock() } else { // Process the tag. // // We know that is is contained in the map because our lookup function // above must have succeeded (found the tag) to return nil err. c.pendingMu.Lock() resp := c.pending[tag] delete(c.pending, tag) c.pendingMu.Unlock() resp.r = r resp.done <- err } } // waitAndRecv co-ordinates with other receivers to handle responses. func (c *Client) waitAndRecv(done chan error) error { for { select { case err := <-done: return err case c.recvr <- true: select { case err := <-done: // It's possible that we got the token, despite // done also being available. Check for that. <-c.recvr return err default: // Handle receiving one tag. c.handleOne() // Return the token. <-c.recvr } } } } // sendRecvLegacy performs a roundtrip message exchange. // // This is called by internal functions. func (c *Client) sendRecvLegacy(t message, r message) error { tag, ok := c.tagPool.Get() if !ok { return ErrOutOfTags } defer c.tagPool.Put(tag) // Indicate we're expecting a response. // // Note that the tag will be cleared from pending // automatically (see handleOne for details). resp := responsePool.Get().(*response) defer responsePool.Put(resp) resp.r = r c.pendingMu.Lock() c.pending[Tag(tag)] = resp c.pendingMu.Unlock() // Send the request over the wire. c.sendMu.Lock() err := send(c.socket, Tag(tag), t) c.sendMu.Unlock() if err != nil { return err } // Co-ordinate with other receivers. if err := c.waitAndRecv(resp.done); err != nil { return err } // Is it an error message? // // For convenience, we transform these directly // into errors. Handlers need not handle this case. if rlerr, ok := resp.r.(*Rlerror); ok { return syscall.Errno(rlerr.Error) } // At this point, we know it matches. // // Per recv call above, we will only allow a type // match (and give our r) or an instance of Rlerror. return nil } // sendRecvChannel uses channels to send a message. func (c *Client) sendRecvChannel(t message, r message) error { // Acquire an available channel. c.channelsMu.Lock() if len(c.availableChannels) == 0 { c.channelsMu.Unlock() return c.sendRecvLegacy(t, r) } idx := len(c.availableChannels) - 1 ch := c.availableChannels[idx] c.availableChannels = c.availableChannels[:idx] ch.active = true c.channelsWg.Add(1) c.channelsMu.Unlock() // Ensure that it's connected. if !ch.connected { ch.connected = true if err := ch.data.Connect(); err != nil { // The channel is unusable, so don't return it to // c.availableChannels. However, we still have to mark it as // inactive so c.watch() doesn't wait for it. c.channelsMu.Lock() ch.active = false c.channelsMu.Unlock() c.channelsWg.Done() // Map all transport errors to EIO, but ensure that the real error // is logged. log.Warningf("p9.Client.sendRecvChannel: flipcall.Endpoint.Connect: %v", err) return syscall.EIO } } // Send the request and receive the server's response. rsz, err := ch.send(t) if err != nil { // See above. c.channelsMu.Lock() ch.active = false c.channelsMu.Unlock() c.channelsWg.Done() log.Warningf("p9.Client.sendRecvChannel: p9.channel.send: %v", err) return syscall.EIO } // Parse the server's response. _, retErr := ch.recv(r, rsz) // Release the channel. c.channelsMu.Lock() ch.active = false // If c.availableChannels is nil, c.watch() has fired and we should not // mark this channel as available. if c.availableChannels != nil { c.availableChannels = append(c.availableChannels, ch) } c.channelsMu.Unlock() c.channelsWg.Done() return retErr } // Version returns the negotiated 9P2000.L.Google version number. func (c *Client) Version() uint32 { return c.version } // Close closes the underlying socket and channels. func (c *Client) Close() { // unet.Socket.Shutdown() has no effect if unet.Socket.Close() has already // been called (by c.watch()). c.socket.Shutdown() c.closedWg.Wait() }