// Copyright 2021 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package lisafs import ( "io" "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/unet" ) var ( sockHeaderLen = uint32((*sockHeader)(nil).SizeBytes()) ) // sockHeader is the header present in front of each message received on a UDS. // // +marshal type sockHeader struct { payloadLen uint32 message MID _ uint16 // Need to make struct packed. } // sockCommunicator implements Communicator. This is not thread safe. type sockCommunicator struct { fdTracker sock *unet.Socket buf []byte } var _ Communicator = (*sockCommunicator)(nil) func newSockComm(sock *unet.Socket) *sockCommunicator { return &sockCommunicator{ sock: sock, buf: make([]byte, sockHeaderLen), } } func (s *sockCommunicator) FD() int { return s.sock.FD() } func (s *sockCommunicator) destroy() { s.sock.Close() } func (s *sockCommunicator) shutdown() { if err := s.sock.Shutdown(); err != nil { log.Warningf("Socket.Shutdown() failed (FD: %d): %v", s.sock.FD(), err) } } func (s *sockCommunicator) resizeBuf(size uint32) { if cap(s.buf) < int(size) { s.buf = s.buf[:cap(s.buf)] s.buf = append(s.buf, make([]byte, int(size)-cap(s.buf))...) } else { s.buf = s.buf[:size] } } // PayloadBuf implements Communicator.PayloadBuf. func (s *sockCommunicator) PayloadBuf(size uint32) []byte { s.resizeBuf(sockHeaderLen + size) return s.buf[sockHeaderLen : sockHeaderLen+size] } // SndRcvMessage implements Communicator.SndRcvMessage. func (s *sockCommunicator) SndRcvMessage(m MID, payloadLen uint32, wantFDs uint8) (MID, uint32, error) { if err := s.sndPrepopulatedMsg(m, payloadLen, nil); err != nil { return 0, 0, err } return s.rcvMsg(wantFDs) } // sndPrepopulatedMsg assumes that s.buf has already been populated with // `payloadLen` bytes of data. func (s *sockCommunicator) sndPrepopulatedMsg(m MID, payloadLen uint32, fds []int) error { header := sockHeader{payloadLen: payloadLen, message: m} header.MarshalUnsafe(s.buf) dataLen := sockHeaderLen + payloadLen return writeTo(s.sock, [][]byte{s.buf[:dataLen]}, int(dataLen), fds) } // writeTo writes the passed iovec to the UDS and donates any passed FDs. func writeTo(sock *unet.Socket, iovec [][]byte, dataLen int, fds []int) error { w := sock.Writer(true) if len(fds) > 0 { w.PackFDs(fds...) } fdsUnpacked := false for n := 0; n < dataLen; { cur, err := w.WriteVec(iovec) if err != nil { return err } n += cur // Fast common path. if n >= dataLen { break } // Consume iovecs. for consumed := 0; consumed < cur; { if len(iovec[0]) <= cur-consumed { consumed += len(iovec[0]) iovec = iovec[1:] } else { iovec[0] = iovec[0][cur-consumed:] break } } if n > 0 && !fdsUnpacked { // Don't resend any control message. fdsUnpacked = true w.UnpackFDs() } } return nil } // rcvMsg reads the message header and payload from the UDS. It also populates // fds with any donated FDs. func (s *sockCommunicator) rcvMsg(wantFDs uint8) (MID, uint32, error) { fds, err := readFrom(s.sock, s.buf[:sockHeaderLen], wantFDs) if err != nil { return 0, 0, err } for _, fd := range fds { s.TrackFD(fd) } var header sockHeader header.UnmarshalUnsafe(s.buf) // No payload? We are done. if header.payloadLen == 0 { return header.message, 0, nil } if _, err := readFrom(s.sock, s.PayloadBuf(header.payloadLen), 0); err != nil { return 0, 0, err } return header.message, header.payloadLen, nil } // readFrom fills the passed buffer with data from the socket. It also returns // any donated FDs. func readFrom(sock *unet.Socket, buf []byte, wantFDs uint8) ([]int, error) { r := sock.Reader(true) r.EnableFDs(int(wantFDs)) var ( fds []int fdInit bool ) n := len(buf) for got := 0; got < n; { cur, err := r.ReadVec([][]byte{buf[got:]}) // Ignore EOF if cur > 0. if err != nil && (err != io.EOF || cur == 0) { r.CloseFDs() return nil, err } if !fdInit && cur > 0 { fds, err = r.ExtractFDs() if err != nil { return nil, err } fdInit = true r.EnableFDs(0) } got += cur } return fds, nil } func closeFDs(fds []int) { for _, fd := range fds { if fd >= 0 { unix.Close(fd) } } }