Record VFS2 sockets in global socket map.

Updates #1476, #1478, #1484, #1485.

PiperOrigin-RevId: 304845354
This commit is contained in:
Dean Deng 2020-04-04 21:01:42 -07:00 committed by gVisor bot
parent fc99a7ebf0
commit 24bee1c181
6 changed files with 91 additions and 42 deletions

View File

@ -22,7 +22,6 @@ go_library(
"//pkg/log",
"//pkg/refs",
"//pkg/safemem",
"//pkg/sentry/fs",
"//pkg/sentry/fsbridge",
"//pkg/sentry/fsimpl/kernfs",
"//pkg/sentry/inet",

View File

@ -24,7 +24,6 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
"gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@ -32,6 +31,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/socket"
"gvisor.dev/gvisor/pkg/sentry/socket/unix"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/usermem"
@ -206,22 +206,21 @@ var _ dynamicInode = (*netUnixData)(nil)
func (n *netUnixData) Generate(ctx context.Context, buf *bytes.Buffer) error {
buf.WriteString("Num RefCount Protocol Flags Type St Inode Path\n")
for _, se := range n.kernel.ListSockets() {
s := se.Sock.Get()
if s == nil {
log.Debugf("Couldn't resolve weakref %v in socket table, racing with destruction?", se.Sock)
s := se.SockVFS2
if !s.TryIncRef() {
log.Debugf("Couldn't get reference on %v in socket table, racing with destruction?", s)
continue
}
sfile := s.(*fs.File)
if family, _, _ := sfile.FileOperations.(socket.Socket).Type(); family != linux.AF_UNIX {
if family, _, _ := s.Impl().(socket.SocketVFS2).Type(); family != linux.AF_UNIX {
s.DecRef()
// Not a unix socket.
continue
}
sops := sfile.FileOperations.(*unix.SocketOperations)
sops := s.Impl().(*unix.SocketVFS2)
addr, err := sops.Endpoint().GetLocalAddress()
if err != nil {
log.Warningf("Failed to retrieve socket name from %+v: %v", sfile, err)
log.Warningf("Failed to retrieve socket name from %+v: %v", s, err)
addr.Addr = "<unknown>"
}
@ -234,6 +233,15 @@ func (n *netUnixData) Generate(ctx context.Context, buf *bytes.Buffer) error {
}
}
// Get inode number.
var ino uint64
stat, statErr := s.Stat(ctx, vfs.StatOptions{Mask: linux.STATX_INO})
if statErr != nil || stat.Mask&linux.STATX_INO == 0 {
log.Warningf("Failed to retrieve ino for socket file: %v", statErr)
} else {
ino = stat.Ino
}
// In the socket entry below, the value for the 'Num' field requires
// some consideration. Linux prints the address to the struct
// unix_sock representing a socket in the kernel, but may redact the
@ -252,14 +260,14 @@ func (n *netUnixData) Generate(ctx context.Context, buf *bytes.Buffer) error {
// the definition of this struct changes over time.
//
// For now, we always redact this pointer.
fmt.Fprintf(buf, "%#016p: %08X %08X %08X %04X %02X %5d",
fmt.Fprintf(buf, "%#016p: %08X %08X %08X %04X %02X %8d",
(*unix.SocketOperations)(nil), // Num, pointer to kernel socket struct.
sfile.ReadRefs()-1, // RefCount, don't count our own ref.
s.Refs()-1, // RefCount, don't count our own ref.
0, // Protocol, always 0 for UDS.
sockFlags, // Flags.
sops.Endpoint().Type(), // Type.
sops.State(), // State.
sfile.InodeID(), // Inode.
ino, // Inode.
)
// Path
@ -341,15 +349,14 @@ func commonGenerateTCP(ctx context.Context, buf *bytes.Buffer, k *kernel.Kernel,
t := kernel.TaskFromContext(ctx)
for _, se := range k.ListSockets() {
s := se.Sock.Get()
if s == nil {
log.Debugf("Couldn't resolve weakref with ID %v in socket table, racing with destruction?", se.ID)
s := se.SockVFS2
if !s.TryIncRef() {
log.Debugf("Couldn't get reference on %v in socket table, racing with destruction?", s)
continue
}
sfile := s.(*fs.File)
sops, ok := sfile.FileOperations.(socket.Socket)
sops, ok := s.Impl().(socket.SocketVFS2)
if !ok {
panic(fmt.Sprintf("Found non-socket file in socket table: %+v", sfile))
panic(fmt.Sprintf("Found non-socket file in socket table: %+v", s))
}
if fa, stype, _ := sops.Type(); !(family == fa && stype == linux.SOCK_STREAM) {
s.DecRef()
@ -398,14 +405,15 @@ func commonGenerateTCP(ctx context.Context, buf *bytes.Buffer, k *kernel.Kernel,
// Unimplemented.
fmt.Fprintf(buf, "%08X ", 0)
stat, statErr := s.Stat(ctx, vfs.StatOptions{Mask: linux.STATX_UID | linux.STATX_INO})
// Field: uid.
uattr, err := sfile.Dirent.Inode.UnstableAttr(ctx)
if err != nil {
log.Warningf("Failed to retrieve unstable attr for socket file: %v", err)
if statErr != nil || stat.Mask&linux.STATX_UID == 0 {
log.Warningf("Failed to retrieve uid for socket file: %v", statErr)
fmt.Fprintf(buf, "%5d ", 0)
} else {
creds := auth.CredentialsFromContext(ctx)
fmt.Fprintf(buf, "%5d ", uint32(uattr.Owner.UID.In(creds.UserNamespace).OrOverflow()))
fmt.Fprintf(buf, "%5d ", uint32(auth.KUID(stat.UID).In(creds.UserNamespace).OrOverflow()))
}
// Field: timeout; number of unanswered 0-window probes.
@ -413,11 +421,16 @@ func commonGenerateTCP(ctx context.Context, buf *bytes.Buffer, k *kernel.Kernel,
fmt.Fprintf(buf, "%8d ", 0)
// Field: inode.
fmt.Fprintf(buf, "%8d ", sfile.InodeID())
if statErr != nil || stat.Mask&linux.STATX_INO == 0 {
log.Warningf("Failed to retrieve inode for socket file: %v", statErr)
fmt.Fprintf(buf, "%8d ", 0)
} else {
fmt.Fprintf(buf, "%8d ", stat.Ino)
}
// Field: refcount. Don't count the ref we obtain while deferencing
// the weakref to this socket.
fmt.Fprintf(buf, "%d ", sfile.ReadRefs()-1)
fmt.Fprintf(buf, "%d ", s.Refs()-1)
// Field: Socket struct address. Redacted due to the same reason as
// the 'Num' field in /proc/net/unix, see netUnix.ReadSeqFileData.
@ -499,15 +512,14 @@ func (d *netUDPData) Generate(ctx context.Context, buf *bytes.Buffer) error {
t := kernel.TaskFromContext(ctx)
for _, se := range d.kernel.ListSockets() {
s := se.Sock.Get()
if s == nil {
log.Debugf("Couldn't resolve weakref with ID %v in socket table, racing with destruction?", se.ID)
s := se.SockVFS2
if !s.TryIncRef() {
log.Debugf("Couldn't get reference on %v in socket table, racing with destruction?", s)
continue
}
sfile := s.(*fs.File)
sops, ok := sfile.FileOperations.(socket.Socket)
sops, ok := s.Impl().(socket.SocketVFS2)
if !ok {
panic(fmt.Sprintf("Found non-socket file in socket table: %+v", sfile))
panic(fmt.Sprintf("Found non-socket file in socket table: %+v", s))
}
if family, stype, _ := sops.Type(); family != linux.AF_INET || stype != linux.SOCK_DGRAM {
s.DecRef()
@ -551,25 +563,31 @@ func (d *netUDPData) Generate(ctx context.Context, buf *bytes.Buffer) error {
// Field: retrnsmt. Always 0 for UDP.
fmt.Fprintf(buf, "%08X ", 0)
stat, statErr := s.Stat(ctx, vfs.StatOptions{Mask: linux.STATX_UID | linux.STATX_INO})
// Field: uid.
uattr, err := sfile.Dirent.Inode.UnstableAttr(ctx)
if err != nil {
log.Warningf("Failed to retrieve unstable attr for socket file: %v", err)
if statErr != nil || stat.Mask&linux.STATX_UID == 0 {
log.Warningf("Failed to retrieve uid for socket file: %v", statErr)
fmt.Fprintf(buf, "%5d ", 0)
} else {
creds := auth.CredentialsFromContext(ctx)
fmt.Fprintf(buf, "%5d ", uint32(uattr.Owner.UID.In(creds.UserNamespace).OrOverflow()))
fmt.Fprintf(buf, "%5d ", uint32(auth.KUID(stat.UID).In(creds.UserNamespace).OrOverflow()))
}
// Field: timeout. Always 0 for UDP.
fmt.Fprintf(buf, "%8d ", 0)
// Field: inode.
fmt.Fprintf(buf, "%8d ", sfile.InodeID())
if statErr != nil || stat.Mask&linux.STATX_INO == 0 {
log.Warningf("Failed to retrieve inode for socket file: %v", statErr)
fmt.Fprintf(buf, "%8d ", 0)
} else {
fmt.Fprintf(buf, "%8d ", stat.Ino)
}
// Field: ref; reference count on the socket inode. Don't count the ref
// we obtain while deferencing the weakref to this socket.
fmt.Fprintf(buf, "%d ", sfile.ReadRefs()-1)
fmt.Fprintf(buf, "%d ", s.Refs()-1)
// Field: Socket struct address. Redacted due to the same reason as
// the 'Num' field in /proc/net/unix, see netUnix.ReadSeqFileData.

View File

@ -1445,9 +1445,10 @@ func (k *Kernel) SupervisorContext() context.Context {
// +stateify savable
type SocketEntry struct {
socketEntry
k *Kernel
Sock *refs.WeakRef
ID uint64 // Socket table entry number.
k *Kernel
Sock *refs.WeakRef
SockVFS2 *vfs.FileDescription
ID uint64 // Socket table entry number.
}
// WeakRefGone implements refs.WeakRefUser.WeakRefGone.
@ -1470,7 +1471,30 @@ func (k *Kernel) RecordSocket(sock *fs.File) {
k.extMu.Unlock()
}
// RecordSocketVFS2 adds a VFS2 socket to the system-wide socket table for
// tracking.
//
// Precondition: Caller must hold a reference to sock.
//
// Note that the socket table will not hold a reference on the
// vfs.FileDescription, because we do not support weak refs on VFS2 files.
func (k *Kernel) RecordSocketVFS2(sock *vfs.FileDescription) {
k.extMu.Lock()
id := k.nextSocketEntry
k.nextSocketEntry++
s := &SocketEntry{
k: k,
ID: id,
SockVFS2: sock,
}
k.sockets.PushBack(s)
k.extMu.Unlock()
}
// ListSockets returns a snapshot of all sockets.
//
// Callers of ListSockets() in VFS2 should use SocketEntry.SockVFS2.TryIncRef()
// to get a reference on a socket in the table.
func (k *Kernel) ListSockets() []*SocketEntry {
k.extMu.Lock()
var socks []*SocketEntry

View File

@ -269,7 +269,7 @@ func NewVFS2(t *kernel.Task, family int, stype linux.SockType, protocol int) (*v
return nil, err
}
if s != nil {
// TODO: Add vfs2 sockets to global socket table.
t.Kernel().RecordSocketVFS2(s)
return s, nil
}
}
@ -291,7 +291,9 @@ func PairVFS2(t *kernel.Task, family int, stype linux.SockType, protocol int) (*
return nil, nil, err
}
if s1 != nil && s2 != nil {
// TODO: Add vfs2 sockets to global socket table.
k := t.Kernel()
k.RecordSocketVFS2(s1)
k.RecordSocketVFS2(s2)
return s1, s2, nil
}
}

View File

@ -141,7 +141,7 @@ func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, block
return 0, nil, 0, syserr.FromError(e)
}
// TODO: add vfs2 sockets to global table.
t.Kernel().RecordSocketVFS2(ns)
return fd, addr, addrLen, nil
}

View File

@ -182,6 +182,12 @@ func (fd *FileDescription) DecRef() {
}
}
// Refs returns the current number of references. The returned count
// is inherently racy and is unsafe to use without external synchronization.
func (fd *FileDescription) Refs() int64 {
return atomic.LoadInt64(&fd.refs)
}
// Mount returns the mount on which fd was opened. It does not take a reference
// on the returned Mount.
func (fd *FileDescription) Mount() *Mount {