gvisor/pkg/sentry/socket/netlink/socket.go

518 lines
15 KiB
Go

// Copyright 2018 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Package netlink provides core functionality for netlink sockets.
package netlink
import (
"sync"
"gvisor.googlesource.com/gvisor/pkg/abi/linux"
"gvisor.googlesource.com/gvisor/pkg/binary"
"gvisor.googlesource.com/gvisor/pkg/sentry/arch"
"gvisor.googlesource.com/gvisor/pkg/sentry/context"
"gvisor.googlesource.com/gvisor/pkg/sentry/device"
"gvisor.googlesource.com/gvisor/pkg/sentry/fs"
"gvisor.googlesource.com/gvisor/pkg/sentry/fs/fsutil"
"gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
"gvisor.googlesource.com/gvisor/pkg/sentry/kernel/kdefs"
ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time"
"gvisor.googlesource.com/gvisor/pkg/sentry/socket"
"gvisor.googlesource.com/gvisor/pkg/sentry/socket/netlink/port"
sunix "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix"
"gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
"gvisor.googlesource.com/gvisor/pkg/syserr"
"gvisor.googlesource.com/gvisor/pkg/syserror"
"gvisor.googlesource.com/gvisor/pkg/tcpip"
"gvisor.googlesource.com/gvisor/pkg/tcpip/transport/unix"
"gvisor.googlesource.com/gvisor/pkg/waiter"
)
// defaultSendBufferSize is the send buffer size, in bytes, that NewSocket
// assigns to every new Socket (16 KiB).
const defaultSendBufferSize = 16 << 10
// netlinkSocketDevice is the netlink socket virtual device. It is created
// once, at package initialization, via device.NewAnonDevice.
//
// NOTE(review): nothing in this part of the file references it; presumably it
// is used by the code that wraps a Socket in a file — confirm against the
// rest of the package.
var netlinkSocketDevice = device.NewAnonDevice()
// Socket is the base socket type for netlink sockets.
//
// This implementation only supports userspace sending and receiving messages
// to/from the kernel: messages written by userspace are handed to protocol
// synchronously, and any responses are queued on ep for userspace to read.
//
// Socket implements socket.Socket.
type Socket struct {
	// Embedded helpers. The fsutil ones are excluded from saved state via
	// the `state:"nosave"` tag.
	// NOTE(review): presumably these supply default implementations of the
	// file operations this type does not define itself — confirm against
	// the socket and fsutil packages.
	socket.ReceiveTimeout
	fsutil.PipeSeek      `state:"nosave"`
	fsutil.NotDirReaddir `state:"nosave"`
	fsutil.NoFsync       `state:"nosave"`
	fsutil.NoopFlush     `state:"nosave"`
	fsutil.NoMMap        `state:"nosave"`

	// ports provides netlink port allocation. NewSocket takes it from the
	// creating task's kernel.
	ports *port.Manager

	// protocol is the netlink protocol implementation. It handles each
	// incoming message and identifies the protocol for port allocation.
	protocol Protocol

	// ep is a datagram unix endpoint used to buffer messages sent from the
	// kernel to userspace. RecvMsg and Read read messages from this
	// endpoint.
	ep unix.Endpoint

	// connection is the kernel's connection to ep, used to write messages
	// sent to userspace.
	connection unix.ConnectedEndpoint

	// mu protects the fields below.
	mu sync.Mutex `state:"nosave"`

	// bound indicates that portID is valid.
	bound bool

	// portID is the port ID allocated for this socket. It is only
	// meaningful while bound is true.
	portID int32

	// sendBufferSize is the send buffer "size". We don't actually have a
	// fixed buffer but only consume this many bytes: sendMsg rejects any
	// message larger than this.
	sendBufferSize uint64
}
// Compile-time check that *Socket implements socket.Socket.
var _ socket.Socket = (*Socket)(nil)
// NewSocket creates a new Socket that speaks the given netlink protocol.
//
// It sets up a connectionless unix endpoint to queue kernel -> user messages
// and a one-way connection to that endpoint for the kernel side to write
// through. The returned socket is unbound and uses defaultSendBufferSize. If
// either setup step fails, the endpoint is closed and the translated netstack
// error is returned.
func NewSocket(t *kernel.Task, protocol Protocol) (*Socket, *syserr.Error) {
	// Datagram endpoint that buffers messages headed for userspace.
	recvEP := unix.NewConnectionless()

	// The endpoint has to be bound before anything can connect to it. The
	// address is a placeholder that is never exposed.
	bindErr := recvEP.Bind(tcpip.FullAddress{Addr: "dummy"}, nil)
	if bindErr != nil {
		recvEP.Close()
		return nil, syserr.TranslateNetstackError(bindErr)
	}

	// Kernel-side handle used to write messages into recvEP.
	kernelConn, connErr := recvEP.(unix.BoundEndpoint).UnidirectionalConnect()
	if connErr != nil {
		recvEP.Close()
		return nil, syserr.TranslateNetstackError(connErr)
	}

	sock := &Socket{
		ports:          t.Kernel().NetlinkPorts(),
		protocol:       protocol,
		ep:             recvEP,
		connection:     kernelConn,
		sendBufferSize: defaultSendBufferSize,
	}
	return sock, nil
}
// Release implements fs.FileOperations.Release.
//
// It releases the kernel-side connection, closes the receive endpoint, and
// returns the socket's port to the port manager if one was allocated.
func (s *Socket) Release() {
	s.connection.Release()
	s.ep.Close()

	if !s.bound {
		// No port was ever allocated, so there is nothing to give back.
		return
	}
	s.ports.Release(s.protocol.Protocol(), s.portID)
}
// Readiness implements waiter.Waitable.Readiness.
//
// Readiness is delegated to ep, which holds the messages waiting to be read.
// EventOut is additionally reported whenever the caller asks about it.
func (s *Socket) Readiness(mask waiter.EventMask) waiter.EventMask {
	events := s.ep.Readiness(mask)
	if mask&waiter.EventOut != waiter.EventOut {
		return events
	}
	// sendMsg handles messages synchronously, so the socket is always
	// ready for writing.
	return events | waiter.EventOut
}
// EventRegister implements waiter.Waitable.EventRegister.
//
// Only ep is registered with. Writable readiness never changes, so there is
// nothing to register for on the send side.
func (s *Socket) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
	s.ep.EventRegister(e, mask)
}
// EventUnregister implements waiter.Waitable.EventUnregister.
//
// It removes e from ep, the only place EventRegister adds it.
func (s *Socket) EventUnregister(e *waiter.Entry) {
	s.ep.EventUnregister(e)
}
// Ioctl implements fs.FileOperations.Ioctl.
//
// No ioctls are supported; every request fails with ENOTTY.
func (s *Socket) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
	// TODO: no ioctls supported.
	return 0, syserror.ENOTTY
}
// ExtractSockAddr extracts the SockAddrNetlink from b.
//
// It returns ErrBadAddress if b is shorter than a SockAddrNetlink, and
// ErrInvalidArgument if the decoded address family is not AF_NETLINK. Any
// bytes in b beyond the address itself are ignored.
func ExtractSockAddr(b []byte) (*linux.SockAddrNetlink, *syserr.Error) {
	size := linux.SockAddrNetlinkSize
	if len(b) < size {
		return nil, syserr.ErrBadAddress
	}

	addr := new(linux.SockAddrNetlink)
	binary.Unmarshal(b[:size], usermem.ByteOrder, addr)

	if addr.Family != linux.AF_NETLINK {
		return nil, syserr.ErrInvalidArgument
	}
	return addr, nil
}
// bindPort binds this socket to a port, preferring 'port' if it is available.
//
// A port of 0 defaults to the ThreadGroup ID of t. On an already-bound socket
// the call succeeds only if 'port' equals the current port ID, and fails with
// ErrInvalidArgument otherwise. ErrBusy is returned when the port manager
// cannot allocate a port.
//
// Preconditions: mu is held.
func (s *Socket) bindPort(t *kernel.Task, port int32) *syserr.Error {
	switch {
	case s.bound && port == s.portID:
		// Re-binding to the same port is a no-op.
		return nil
	case s.bound:
		// Re-binding is only allowed if the port doesn't change.
		return syserr.ErrInvalidArgument
	}

	requested := port
	if requested == 0 {
		requested = int32(t.ThreadGroup().ID())
	}

	allocated, ok := s.ports.Allocate(s.protocol.Protocol(), requested)
	if !ok {
		return syserr.ErrBusy
	}

	s.portID, s.bound = allocated, true
	return nil
}
// Bind implements socket.Socket.Bind.
//
// sockaddr must decode to a netlink address with no multicast groups; a
// non-zero group mask is rejected with ErrPermissionDenied. The port is then
// bound via bindPort with mu held.
func (s *Socket) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
	addr, err := ExtractSockAddr(sockaddr)
	switch {
	case err != nil:
		return err
	case addr.Groups != 0:
		// No support for multicast groups yet.
		return syserr.ErrPermissionDenied
	}

	s.mu.Lock()
	defer s.mu.Unlock()
	return s.bindPort(t, int32(addr.PortID))
}
// Connect implements socket.Socket.Connect.
//
// Only the kernel (port ID 0) may be connected to. Connecting to it binds the
// socket to an auto-selected port if it is not bound yet. Any other
// destination, or a non-zero multicast group mask, fails with
// ErrPermissionDenied.
func (s *Socket) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error {
	addr, err := ExtractSockAddr(sockaddr)
	switch {
	case err != nil:
		return err
	case addr.Groups != 0:
		// No support for multicast groups yet.
		return syserr.ErrPermissionDenied
	}

	s.mu.Lock()
	defer s.mu.Unlock()

	switch {
	case addr.PortID != 0:
		// We don't support non-kernel destination ports. Linux returns
		// EPERM if applications attempt to do this without
		// NL_CFG_F_NONROOT_SEND, so we emulate that.
		return syserr.ErrPermissionDenied
	case s.bound:
		// Netlink sockets default to connected to the kernel, so an
		// already-bound socket has nothing left to do.
		return nil
	default:
		// Connecting automatically binds. Pass port 0 to get an
		// auto-selected port ID.
		return s.bindPort(t, 0)
	}
}
// Accept implements socket.Socket.Accept.
//
// Netlink sockets never support accept, so this always fails with
// ErrNotSupported.
func (s *Socket) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (kdefs.FD, interface{}, uint32, *syserr.Error) {
	return 0, nil, 0, syserr.ErrNotSupported
}
// Listen implements socket.Socket.Listen.
//
// Netlink sockets never support listen, so this always fails with
// ErrNotSupported.
func (s *Socket) Listen(t *kernel.Task, backlog int) *syserr.Error {
	return syserr.ErrNotSupported
}
// Shutdown implements socket.Socket.Shutdown.
//
// Netlink sockets never support shutdown, so this always fails with
// ErrNotSupported.
func (s *Socket) Shutdown(t *kernel.Task, how int) *syserr.Error {
	return syserr.ErrNotSupported
}
// GetSockOpt implements socket.Socket.GetSockOpt.
//
// No socket options are supported; every query fails with
// ErrProtocolNotAvailable.
func (s *Socket) GetSockOpt(t *kernel.Task, level int, name int, outLen int) (interface{}, *syserr.Error) {
	// TODO: no sockopts supported.
	return nil, syserr.ErrProtocolNotAvailable
}
// SetSockOpt implements socket.Socket.SetSockOpt.
//
// No socket options are supported; every attempt to set one fails with
// ErrProtocolNotAvailable.
func (s *Socket) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error {
	// TODO: no sockopts supported.
	return syserr.ErrProtocolNotAvailable
}
// GetSockName implements socket.Socket.GetSockName.
//
// It returns an AF_NETLINK address carrying the socket's current port ID
// (read under mu), together with the encoded size of that address.
func (s *Socket) GetSockName(t *kernel.Task) (interface{}, uint32, *syserr.Error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	var local linux.SockAddrNetlink
	local.Family = linux.AF_NETLINK
	local.PortID = uint32(s.portID)

	return local, uint32(binary.Size(local)), nil
}
// GetPeerName implements socket.Socket.GetPeerName.
//
// The peer is always reported as the kernel: an AF_NETLINK address with port
// ID 0, together with the encoded size of that address.
func (s *Socket) GetPeerName(t *kernel.Task) (interface{}, uint32, *syserr.Error) {
	var peer linux.SockAddrNetlink
	peer.Family = linux.AF_NETLINK
	// TODO: Support non-kernel peers. For now the peer must be the kernel.
	peer.PortID = 0

	return peer, uint32(binary.Size(peer)), nil
}
// RecvMsg implements socket.Socket.RecvMsg.
//
// A message is read from ep into dst. The sender is always reported as the
// kernel (AF_NETLINK, port ID 0) and no control messages are returned. With
// MSG_TRUNC the returned length is the reader's MsgSize rather than the
// number of bytes copied; with MSG_PEEK the reader is put in peek mode.
//
// If no message is available and MSG_DONTWAIT is not set, the call blocks
// until one arrives, the deadline expires (ErrTryAgain), or the wait fails
// with some other error.
func (s *Socket) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, interface{}, uint32, socket.ControlMessages, *syserr.Error) {
	sender := linux.SockAddrNetlink{
		Family: linux.AF_NETLINK,
		PortID: 0,
	}
	senderLen := uint32(binary.Size(sender))

	reportFullSize := flags&linux.MSG_TRUNC != 0
	nonblocking := flags&linux.MSG_DONTWAIT != 0

	reader := sunix.EndpointReader{
		Endpoint: s.ep,
		Peek:     flags&linux.MSG_PEEK != 0,
	}

	// finish builds the return values for a receive attempt that is not
	// going to be retried.
	finish := func(copied int64, rerr error) (int, interface{}, uint32, socket.ControlMessages, *syserr.Error) {
		if reportFullSize {
			copied = int64(reader.MsgSize)
		}
		return int(copied), sender, senderLen, socket.ControlMessages{}, syserr.FromError(rerr)
	}

	copied, rerr := dst.CopyOutFrom(t, &reader)
	if rerr != syserror.ErrWouldBlock || nonblocking {
		return finish(copied, rerr)
	}

	// We'll have to block. Register for notification and keep trying to
	// receive all the data.
	entry, wake := waiter.NewChannelEntry(nil)
	s.EventRegister(&entry, waiter.EventIn)
	defer s.EventUnregister(&entry)

	for {
		copied, rerr = dst.CopyOutFrom(t, &reader)
		if rerr != syserror.ErrWouldBlock {
			return finish(copied, rerr)
		}

		switch werr := t.BlockWithDeadline(wake, haveDeadline, deadline); werr {
		case nil:
			// Woken up; try the read again.
		case syserror.ETIMEDOUT:
			return 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain
		default:
			return 0, nil, 0, socket.ControlMessages{}, syserr.FromError(werr)
		}
	}
}
// Read implements fs.FileOperations.Read.
//
// A zero-length destination returns immediately without touching the
// endpoint. Otherwise a single copy from ep into dst is attempted and its
// result returned as-is.
func (s *Socket) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) {
	if dst.NumBytes() == 0 {
		return 0, nil
	}

	reader := sunix.EndpointReader{Endpoint: s.ep}
	return dst.CopyOutFrom(ctx, &reader)
}
// sendResponse sends the response messages in ms back to userspace.
//
// All messages in ms are combined into one datagram. For a multi-part
// response, a separate NLMSG_DONE datagram follows, even when ms holds no
// messages. A full receive buffer (ErrWouldBlock) is not an error: the
// datagram is dropped. Any other send failure is translated and returned.
func (s *Socket) sendResponse(ctx context.Context, ms *MessageSet) *syserr.Error {
	// deliver writes one datagram to the receive endpoint and notifies
	// readers when the endpoint asks for it.
	deliver := func(datagram [][]byte) *syserr.Error {
		// RecvMsg never receives the address, so we don't need to send
		// one.
		_, notify, terr := s.connection.Send(datagram, unix.ControlMessages{}, tcpip.FullAddress{})
		// If the buffer is full, we simply drop messages, just like
		// Linux.
		if terr != nil && terr != tcpip.ErrWouldBlock {
			return syserr.TranslateNetstackError(terr)
		}
		if notify {
			s.connection.SendNotify()
		}
		return nil
	}

	// Linux combines multiple netlink messages into a single datagram.
	if len(ms.Messages) > 0 {
		payload := make([][]byte, len(ms.Messages))
		for i, m := range ms.Messages {
			payload[i] = m.Finalize()
		}
		if err := deliver(payload); err != nil {
			return err
		}
	}

	if !ms.Multi {
		return nil
	}

	// N.B. multi-part messages should still send NLMSG_DONE even if
	// MessageSet contains no messages.
	//
	// N.B. NLMSG_DONE is always sent in a different datagram. See
	// net/netlink/af_netlink.c:netlink_dump.
	done := NewMessage(linux.NetlinkMessageHeader{
		Type:   linux.NLMSG_DONE,
		Flags:  linux.NLM_F_MULTI,
		Seq:    ms.Seq,
		PortID: uint32(ms.PortID),
	})
	return deliver([][]byte{done.Finalize()})
}
// processMessages handles each message in buf, passing it to the protocol
// handler for final handling.
//
// buf may hold several netlink messages back to back, each padded to
// NLMSG_ALIGNTO. Parsing stops silently at the first truncated or malformed
// header, as in Linux. Control messages (type below NLMSG_MIN_TYPE) are
// skipped. A message requesting an ACK fails with ErrNotSupported. For every
// other message the protocol's response set is sent back to userspace before
// the next message is examined. The first error from the protocol handler or
// from sending a response aborts processing.
//
// s.portID is read here without locking, so callers are expected to hold
// s.mu (sendMsg does).
func (s *Socket) processMessages(ctx context.Context, buf []byte) *syserr.Error {
	for len(buf) > 0 {
		if len(buf) < linux.NetlinkMessageHeaderSize {
			// Linux ignores messages that are too short. See
			// net/netlink/af_netlink.c:netlink_rcv_skb.
			break
		}

		var hdr linux.NetlinkMessageHeader
		binary.Unmarshal(buf[:linux.NetlinkMessageHeaderSize], usermem.ByteOrder, &hdr)

		if hdr.Length < linux.NetlinkMessageHeaderSize || uint64(hdr.Length) > uint64(len(buf)) {
			// Linux ignores malformed messages. See
			// net/netlink/af_netlink.c:netlink_rcv_skb.
			break
		}

		// Data from this message.
		data := buf[linux.NetlinkMessageHeaderSize:hdr.Length]

		// Advance to the next message. The aligned length may run past
		// the end of the buffer when the final message carries no
		// trailing padding, so clamp to the buffer length. Clamping to
		// len(buf)-1 would leave a stray trailing byte that only the
		// short-header check above happens to discard.
		next := alignUp(int(hdr.Length), linux.NLMSG_ALIGNTO)
		if next > len(buf) {
			next = len(buf)
		}
		buf = buf[next:]

		// Ignore control messages.
		if hdr.Type < linux.NLMSG_MIN_TYPE {
			continue
		}

		// TODO: ACKs not supported yet.
		if hdr.Flags&linux.NLM_F_ACK == linux.NLM_F_ACK {
			return syserr.ErrNotSupported
		}

		ms := NewMessageSet(s.portID, hdr.Seq)
		if err := s.protocol.ProcessMessage(ctx, hdr, data, ms); err != nil {
			return err
		}

		if err := s.sendResponse(ctx, ms); err != nil {
			return err
		}
	}

	return nil
}
// sendMsg is the core of message send, used for SendMsg and Write.
//
// An empty 'to' means the kernel. If 'to' is present it must decode to a
// netlink address with no multicast groups and port ID 0; anything else fails
// with ErrPermissionDenied. The whole message is copied in and processed
// synchronously with mu held. Messages larger than sendBufferSize fail with
// ErrMessageTooLong. On any failure the reported byte count is 0, so a message
// is never reported as partially consumed.
func (s *Socket) sendMsg(ctx context.Context, src usermem.IOSequence, to []byte, flags int, controlMessages socket.ControlMessages) (int, *syserr.Error) {
	var dstPort int32
	if len(to) != 0 {
		addr, err := ExtractSockAddr(to)
		switch {
		case err != nil:
			return 0, err
		case addr.Groups != 0:
			// No support for multicast groups yet.
			return 0, syserr.ErrPermissionDenied
		}
		dstPort = int32(addr.PortID)
	}

	if dstPort != 0 {
		// Non-kernel destinations not supported yet. Treat as if
		// NL_CFG_F_NONROOT_SEND is not set.
		return 0, syserr.ErrPermissionDenied
	}

	s.mu.Lock()
	defer s.mu.Unlock()

	// For simplicity, and consistency with Linux, we copy in the entire
	// message up front.
	size := src.NumBytes()
	if uint64(size) > s.sendBufferSize {
		return 0, syserr.ErrMessageTooLong
	}

	msg := make([]byte, size)
	copied, copyErr := src.CopyIn(ctx, msg)
	if copyErr != nil {
		// Don't partially consume messages.
		return 0, syserr.FromError(copyErr)
	}

	if procErr := s.processMessages(ctx, msg); procErr != nil {
		return 0, procErr
	}
	return copied, nil
}
// SendMsg implements socket.Socket.SendMsg.
//
// It forwards to sendMsg, using the task as the context.
func (s *Socket) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, controlMessages socket.ControlMessages) (int, *syserr.Error) {
	sent, err := s.sendMsg(t, src, to, flags, controlMessages)
	return sent, err
}
// Write implements fs.FileOperations.Write.
//
// The data is sent via sendMsg with no destination address, no flags and no
// control messages. The resulting syserr.Error is converted to a plain error.
func (s *Socket) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) {
	written, sendErr := s.sendMsg(ctx, src, nil, 0, socket.ControlMessages{})
	return int64(written), sendErr.ToError()
}