gvisor/pkg/lisafs/sock.go

209 lines
4.7 KiB
Go
Raw Normal View History

// Copyright 2021 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package lisafs
import (
"io"
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/unet"
)
// sockHeaderLen is the marshalled size of sockHeader in bytes. Every message
// buffer starts with this many header bytes, followed by the payload.
var sockHeaderLen = uint32((*sockHeader)(nil).SizeBytes())
// sockHeader is the header present in front of each message sent or received
// on a UDS.
//
// It is encoded with MarshalUnsafe and decoded with UnmarshalUnsafe, so its
// in-memory field layout is the on-the-wire format. Do not reorder fields or
// remove the padding field.
//
// +marshal
type sockHeader struct {
	// payloadLen is the number of payload bytes that follow this header.
	payloadLen uint32
	// message is the ID of the message carried in this frame.
	message MID
	_       uint16 // Need to make struct packed.
}
// sockCommunicator implements Communicator over a Unix domain socket. This is
// not thread safe.
//
// Each message is a sockHeader immediately followed by its payload. Both are
// staged in buf before being written, and read into buf on receipt.
type sockCommunicator struct {
	// fdTracker records FDs donated by the peer on received messages
	// (see rcvMsg).
	fdTracker
	// sock is the socket that messages are exchanged over.
	sock *unet.Socket
	// buf holds the header followed by the payload of the message currently
	// being sent or received. It is reused and resized across messages.
	buf []byte
}

// Compile-time check that *sockCommunicator satisfies Communicator.
var _ Communicator = (*sockCommunicator)(nil)
// newSockComm returns a sockCommunicator that exchanges messages over sock.
// The internal buffer starts out exactly large enough to hold one sockHeader.
func newSockComm(sock *unet.Socket) *sockCommunicator {
	c := new(sockCommunicator)
	c.sock = sock
	c.buf = make([]byte, sockHeaderLen)
	return c
}
// FD returns the file descriptor of the underlying socket.
func (s *sockCommunicator) FD() int {
	fd := s.sock.FD()
	return fd
}
// destroy closes the underlying socket. Any error returned by Close is
// discarded; teardown here is best-effort.
func (s *sockCommunicator) destroy() {
	_ = s.sock.Close()
}
// shutdown shuts down the underlying socket. A failure is not returned; it is
// only logged as a warning together with the socket FD.
func (s *sockCommunicator) shutdown() {
	err := s.sock.Shutdown()
	if err == nil {
		return
	}
	log.Warningf("Socket.Shutdown() failed (FD: %d): %v", s.sock.FD(), err)
}
// resizeBuf sets len(s.buf) to size. Bytes already in the buffer (up to its
// capacity) are preserved. If the capacity is too small, the buffer is
// extended with zero bytes using append.
func (s *sockCommunicator) resizeBuf(size uint32) {
	want := int(size)
	if want <= cap(s.buf) {
		// Enough room already: just reslice.
		s.buf = s.buf[:want]
		return
	}
	full := s.buf[:cap(s.buf)]
	s.buf = append(full, make([]byte, want-len(full))...)
}
// PayloadBuf implements Communicator.PayloadBuf.
//
// It resizes the internal buffer to hold a header plus size payload bytes and
// returns the payload region, which begins right after the header slot.
// sockHeaderLen+size is computed in uint32, so callers must ensure that sum
// does not wrap around.
func (s *sockCommunicator) PayloadBuf(size uint32) []byte {
	end := sockHeaderLen + size
	s.resizeBuf(end)
	return s.buf[sockHeaderLen:end]
}
// SndRcvMessage implements Communicator.SndRcvMessage.
//
// It sends message m, whose payloadLen payload bytes must already be staged in
// the payload region of the internal buffer (see PayloadBuf). No FDs are
// donated with the request. It then receives the response, accepting up to
// wantFDs donated FDs, and returns the response's message ID and payload
// length.
func (s *sockCommunicator) SndRcvMessage(m MID, payloadLen uint32, wantFDs uint8) (MID, uint32, error) {
	err := s.sndPrepopulatedMsg(m, payloadLen, nil)
	if err != nil {
		return 0, 0, err
	}
	return s.rcvMsg(wantFDs)
}
// sndPrepopulatedMsg sends message m. It assumes the payloadLen payload bytes
// are already present at s.buf[sockHeaderLen:]. It marshals the header into
// the front of s.buf, then writes header and payload as one frame, donating
// fds alongside it.
func (s *sockCommunicator) sndPrepopulatedMsg(m MID, payloadLen uint32, fds []int) error {
	hdr := sockHeader{
		payloadLen: payloadLen,
		message:    m,
	}
	hdr.MarshalUnsafe(s.buf)

	total := sockHeaderLen + payloadLen
	frame := s.buf[:total]
	return writeTo(s.sock, [][]byte{frame}, int(total), fds)
}
// writeTo writes the passed iovec to the UDS and donates any passed FDs.
//
// dataLen is expected to equal the total number of bytes across iovec. The
// function keeps calling WriteVec until at least dataLen bytes have been
// written, or returns the first error. FDs, if any, are packed before the
// first write and unpacked after the first write that makes progress but does
// not finish, so the control message is not resent.
//
// NOTE(review): on a partial write, the first remaining entry is re-sliced in
// place (iovec[0] = ...). This writes into the backing array of the caller's
// [][]byte, so callers that reuse their iovec after this call will observe
// the change.
// NOTE(review): if WriteVec ever returned (0, nil), this loop would retry
// indefinitely. Presumably unet never does that; confirm.
func writeTo(sock *unet.Socket, iovec [][]byte, dataLen int, fds []int) error {
	w := sock.Writer(true)
	if len(fds) > 0 {
		w.PackFDs(fds...)
	}

	// fdsUnpacked records whether the FDs have already been removed from the
	// writer after being sent once.
	fdsUnpacked := false
	for n := 0; n < dataLen; {
		cur, err := w.WriteVec(iovec)
		if err != nil {
			return err
		}
		n += cur

		// Fast common path.
		if n >= dataLen {
			break
		}

		// Consume iovecs.
		// Drop entries that were written completely and trim the first
		// partially written entry, so the next WriteVec resumes exactly where
		// this one stopped.
		for consumed := 0; consumed < cur; {
			if len(iovec[0]) <= cur-consumed {
				consumed += len(iovec[0])
				iovec = iovec[1:]
			} else {
				iovec[0] = iovec[0][cur-consumed:]
				break
			}
		}

		if n > 0 && !fdsUnpacked {
			// Don't resend any control message.
			fdsUnpacked = true
			w.UnpackFDs()
		}
	}
	return nil
}
// rcvMsg reads the message header and payload from the UDS. Any donated FDs
// (at most wantFDs) are recorded via s.TrackFD.
//
// It returns the message ID and payload length from the header. The payload
// bytes are left in the buffer returned by PayloadBuf(payloadLen).
//
// payloadLen is supplied by the peer. PayloadBuf computes
// sockHeaderLen+payloadLen in uint32, so a value close to MaxUint32 would wrap
// around and cause a slice-bounds panic. Such headers are rejected with EINVAL
// instead.
func (s *sockCommunicator) rcvMsg(wantFDs uint8) (MID, uint32, error) {
	fds, err := readFrom(s.sock, s.buf[:sockHeaderLen], wantFDs)
	if err != nil {
		return 0, 0, err
	}
	for _, fd := range fds {
		s.TrackFD(fd)
	}

	var header sockHeader
	header.UnmarshalUnsafe(s.buf)

	// No payload? We are done.
	if header.payloadLen == 0 {
		return header.message, 0, nil
	}

	// Guard the uint32 arithmetic in PayloadBuf against wrap-around caused by
	// an untrusted length.
	if header.payloadLen > ^uint32(0)-sockHeaderLen {
		return 0, 0, unix.EINVAL
	}

	if _, err := readFrom(s.sock, s.PayloadBuf(header.payloadLen), 0); err != nil {
		return 0, 0, err
	}
	return header.message, header.payloadLen, nil
}
// readFrom fills the passed buffer with data from the socket. It also returns
// any donated FDs, accepting at most wantFDs of them.
//
// FDs are collected only from the first read that returns data. Later reads
// are made with FD reception disabled. If a read fails, both the FDs still
// held by the reader and any FDs already extracted by an earlier partial read
// are closed, and nil is returned. This keeps the caller from leaking FDs on
// error.
//
// An io.EOF that arrives together with data is ignored. The loop continues,
// and the EOF is returned only once a read yields no data.
func readFrom(sock *unet.Socket, buf []byte, wantFDs uint8) ([]int, error) {
	r := sock.Reader(true)
	r.EnableFDs(int(wantFDs))

	var (
		fds    []int
		fdInit bool
	)
	n := len(buf)
	for got := 0; got < n; {
		cur, err := r.ReadVec([][]byte{buf[got:]})

		// Ignore EOF if cur > 0.
		if err != nil && (err != io.EOF || cur == 0) {
			// r.CloseFDs only covers FDs not yet extracted. FDs extracted by
			// an earlier partial read are no longer owned by the reader (its
			// control buffer was reset by EnableFDs(0)), so close them
			// explicitly as well.
			r.CloseFDs()
			closeFDs(fds)
			return nil, err
		}

		if !fdInit && cur > 0 {
			fds, err = r.ExtractFDs()
			if err != nil {
				return nil, err
			}
			fdInit = true
			// Only the first chunk carries FDs; stop accepting more.
			r.EnableFDs(0)
		}

		got += cur
	}
	return fds, nil
}
// closeFDs closes every non-negative FD in fds. Negative entries are skipped,
// and errors from close are ignored.
func closeFDs(fds []int) {
	for i := range fds {
		if fds[i] < 0 {
			continue
		}
		unix.Close(fds[i])
	}
}