diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go index 929814752..a5f78506a 100644 --- a/pkg/abi/linux/socket.go +++ b/pkg/abi/linux/socket.go @@ -191,6 +191,15 @@ const ( SO_TXTIME = 61 ) +// enum socket_state, from uapi/linux/net.h. +const ( + SS_FREE = 0 // Not allocated. + SS_UNCONNECTED = 1 // Unconnected to any socket. + SS_CONNECTING = 2 // In process of connecting. + SS_CONNECTED = 3 // Connected to socket. + SS_DISCONNECTING = 4 // In process of disconnecting. +) + // SockAddrMax is the maximum size of a struct sockaddr, from // uapi/linux/socket.h. const SockAddrMax = 128 @@ -343,3 +352,10 @@ const SizeOfControlMessageRight = 4 // SCM_MAX_FD is the maximum number of FDs accepted in a single sendmsg call. // From net/scm.h. const SCM_MAX_FD = 253 + +// SO_ACCEPTCON is defined as __SO_ACCEPTCON in +// include/uapi/linux/net.h, which represents a listening socket +// state. Note that this is distinct from SO_ACCEPTCONN, which is a +// socket option for querying whether a socket is in a listening +// state. 
+const SO_ACCEPTCON = 1 << 16 diff --git a/pkg/sentry/fs/proc/BUILD b/pkg/sentry/fs/proc/BUILD index f6bc90634..666b0ab3a 100644 --- a/pkg/sentry/fs/proc/BUILD +++ b/pkg/sentry/fs/proc/BUILD @@ -30,6 +30,7 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", + "//pkg/log", "//pkg/sentry/context", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", @@ -43,6 +44,8 @@ go_library( "//pkg/sentry/kernel/time", "//pkg/sentry/mm", "//pkg/sentry/socket/rpcinet", + "//pkg/sentry/socket/unix", + "//pkg/sentry/socket/unix/transport", "//pkg/sentry/usage", "//pkg/sentry/usermem", "//pkg/syserror", diff --git a/pkg/sentry/fs/proc/net.go b/pkg/sentry/fs/proc/net.go index 219eea7f8..55a958f9e 100644 --- a/pkg/sentry/fs/proc/net.go +++ b/pkg/sentry/fs/proc/net.go @@ -15,19 +15,24 @@ package proc import ( + "bytes" "fmt" "time" "gvisor.googlesource.com/gvisor/pkg/abi/linux" + "gvisor.googlesource.com/gvisor/pkg/log" "gvisor.googlesource.com/gvisor/pkg/sentry/context" "gvisor.googlesource.com/gvisor/pkg/sentry/fs" "gvisor.googlesource.com/gvisor/pkg/sentry/fs/proc/seqfile" "gvisor.googlesource.com/gvisor/pkg/sentry/fs/ramfs" "gvisor.googlesource.com/gvisor/pkg/sentry/inet" + "gvisor.googlesource.com/gvisor/pkg/sentry/kernel" + "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix" + "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport" ) // newNet creates a new proc net entry. 
-func (p *proc) newNetDir(ctx context.Context, msrc *fs.MountSource) *fs.Inode { +func (p *proc) newNetDir(ctx context.Context, k *kernel.Kernel, msrc *fs.MountSource) *fs.Inode { var contents map[string]*fs.Inode if s := p.k.NetworkStack(); s != nil { contents = map[string]*fs.Inode{ @@ -52,6 +57,8 @@ func (p *proc) newNetDir(ctx context.Context, msrc *fs.MountSource) *fs.Inode { "tcp": newStaticProcInode(ctx, msrc, []byte(" sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode")), "udp": newStaticProcInode(ctx, msrc, []byte(" sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode ref pointer drops")), + + "unix": seqfile.NewSeqFileInode(ctx, &netUnix{k: k}, msrc), } if s.SupportsIPv6() { @@ -182,3 +189,120 @@ func (n *netDev) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se return data, 0 } + +// netUnix implements seqfile.SeqSource for /proc/net/unix. +// +// +stateify savable +type netUnix struct { + k *kernel.Kernel +} + +// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate. +func (*netUnix) NeedsUpdate(generation int64) bool { + return true +} + +// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData. 
+func (n *netUnix) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) { + if h != nil { + return []seqfile.SeqData{}, 0 + } + + var buf bytes.Buffer + // Header + fmt.Fprintf(&buf, "Num RefCount Protocol Flags Type St Inode Path\n") + + // Entries + for _, sref := range n.k.ListSockets(linux.AF_UNIX) { + s := sref.Get() + if s == nil { + log.Debugf("Couldn't resolve weakref %v in socket table, racing with destruction?", sref) + continue + } + sfile := s.(*fs.File) + sops, ok := sfile.FileOperations.(*unix.SocketOperations) + if !ok { + panic(fmt.Sprintf("Found non-unix socket file in unix socket table: %+v", sfile)) + } + + addr, err := sops.Endpoint().GetLocalAddress() + if err != nil { + log.Warningf("Failed to retrieve socket name from %+v: %v", sfile, err) + addr.Addr = "" + } + + sockFlags := 0 + if ce, ok := sops.Endpoint().(transport.ConnectingEndpoint); ok { + if ce.Listening() { + // For unix domain sockets, linux reports a single flag + // value if the socket is listening, of __SO_ACCEPTCON. + sockFlags = linux.SO_ACCEPTCON + } + } + + var sockState int + switch sops.Endpoint().Type() { + case linux.SOCK_DGRAM: + sockState = linux.SS_CONNECTING + // Unlike Linux, we don't have unbound connection-less sockets, + // so no SS_DISCONNECTING. + + case linux.SOCK_SEQPACKET: + fallthrough + case linux.SOCK_STREAM: + // Connectioned. + if sops.Endpoint().(transport.ConnectingEndpoint).Connected() { + sockState = linux.SS_CONNECTED + } else { + sockState = linux.SS_UNCONNECTED + } + } + + // In the socket entry below, the value for the 'Num' field requires + // some consideration. Linux prints the address to the struct + // unix_sock representing a socket in the kernel, but may redact the + // value for unprivileged users depending on the kptr_restrict + // sysctl. 
+ // + // One use for this field is to allow a privileged user to + // introspect into the kernel memory to determine information about + // a socket not available through procfs, such as the socket's peer. + // + // On gvisor, returning a pointer to our internal structures would + // be pointless, as it wouldn't match the memory layout for struct + // unix_sock, making introspection difficult. We could populate a + // struct unix_sock with the appropriate data, but even that + // requires consideration for which kernel version to emulate, as + // the definition of this struct changes over time. + // + // For now, we always redact this pointer. + fmt.Fprintf(&buf, "%#016p: %08X %08X %08X %04X %02X %5d", + (*unix.SocketOperations)(nil), // Num, pointer to kernel socket struct. + sfile.ReadRefs()-1, // RefCount, don't count our own ref. + 0, // Protocol, always 0 for UDS. + sockFlags, // Flags. + sops.Endpoint().Type(), // Type. + sockState, // State. + sfile.InodeID(), // Inode. + ) + + // Path + if len(addr.Addr) != 0 { + if addr.Addr[0] == 0 { + // Abstract path. 
+ fmt.Fprintf(&buf, " @%s", string(addr.Addr[1:])) + } else { + fmt.Fprintf(&buf, " %s", string(addr.Addr)) + } + } + fmt.Fprintf(&buf, "\n") + + sfile.DecRef() + } + + data := []seqfile.SeqData{{ + Buf: buf.Bytes(), + Handle: (*netUnix)(nil), + }} + return data, 0 +} diff --git a/pkg/sentry/fs/proc/proc.go b/pkg/sentry/fs/proc/proc.go index be04f94af..88018e707 100644 --- a/pkg/sentry/fs/proc/proc.go +++ b/pkg/sentry/fs/proc/proc.go @@ -85,7 +85,7 @@ func New(ctx context.Context, msrc *fs.MountSource) (*fs.Inode, error) { if _, ok := p.k.NetworkStack().(*rpcinet.Stack); ok { p.AddChild(ctx, "net", newRPCInetProcNet(ctx, msrc)) } else { - p.AddChild(ctx, "net", p.newNetDir(ctx, msrc)) + p.AddChild(ctx, "net", p.newNetDir(ctx, k, msrc)) } return newProcInode(p, msrc, fs.SpecialDirectory, nil), nil diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index 43e9823cb..e7e5ff777 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -43,6 +43,7 @@ import ( "gvisor.googlesource.com/gvisor/pkg/cpuid" "gvisor.googlesource.com/gvisor/pkg/eventchannel" "gvisor.googlesource.com/gvisor/pkg/log" + "gvisor.googlesource.com/gvisor/pkg/refs" "gvisor.googlesource.com/gvisor/pkg/sentry/arch" "gvisor.googlesource.com/gvisor/pkg/sentry/context" "gvisor.googlesource.com/gvisor/pkg/sentry/fs" @@ -164,7 +165,7 @@ type Kernel struct { // nextInotifyCookie is a monotonically increasing counter used for // generating unique inotify event cookies. // - // nextInotifyCookie is mutable, and is accesed using atomic memory + // nextInotifyCookie is mutable, and is accessed using atomic memory // operations. nextInotifyCookie uint32 @@ -177,6 +178,10 @@ type Kernel struct { // danglingEndpoints is used to save / restore tcpip.DanglingEndpoints. danglingEndpoints struct{} `state:".([]tcpip.Endpoint)"` + + // socketTable is used to track all sockets on the system. Protected by + // extMu. 
+ socketTable map[int]map[*refs.WeakRef]struct{} } // InitKernelArgs holds arguments to Init. @@ -266,6 +271,7 @@ func (k *Kernel) Init(args InitKernelArgs) error { k.monotonicClock = &timekeeperClock{tk: args.Timekeeper, c: sentrytime.Monotonic} k.futexes = futex.NewManager() k.netlinkPorts = port.New() + k.socketTable = make(map[int]map[*refs.WeakRef]struct{}) return nil } @@ -1051,6 +1057,56 @@ func (k *Kernel) EmitUnimplementedEvent(ctx context.Context) { }) } +// socketEntry represents a socket recorded in Kernel.socketTable. It implements +// refs.WeakRefUser for sockets stored in the socket table. +// +// +stateify savable +type socketEntry struct { + k *Kernel + sock *refs.WeakRef + family int +} + +// WeakRefGone implements refs.WeakRefUser.WeakRefGone. +func (s *socketEntry) WeakRefGone() { + s.k.extMu.Lock() + // k.socketTable is guaranteed to point to a valid socket table for s.family + // at this point, since we made sure of the fact when we created this + // socketEntry, and we never delete socket tables. + delete(s.k.socketTable[s.family], s.sock) + s.k.extMu.Unlock() +} + +// RecordSocket adds a socket to the system-wide socket table for tracking. +// +// Precondition: Caller must hold a reference to sock. +func (k *Kernel) RecordSocket(sock *fs.File, family int) { + k.extMu.Lock() + table, ok := k.socketTable[family] + if !ok { + table = make(map[*refs.WeakRef]struct{}) + k.socketTable[family] = table + } + se := socketEntry{k: k, family: family} + se.sock = refs.NewWeakRef(sock, &se) + table[se.sock] = struct{}{} + k.extMu.Unlock() +} + +// ListSockets returns a snapshot of all sockets of a given family. 
+func (k *Kernel) ListSockets(family int) []*refs.WeakRef { + k.extMu.Lock() + socks := []*refs.WeakRef{} + if table, ok := k.socketTable[family]; ok { + socks = make([]*refs.WeakRef, 0, len(table)) + for s := range table { + socks = append(socks, s) + } + } + k.extMu.Unlock() + return socks +} + type supervisorContext struct { context.NoopSleeper log.Logger diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index e28d2c4fa..5ab423f3c 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -147,6 +147,7 @@ func New(t *kernel.Task, family int, stype transport.SockType, protocol int) (*f return nil, err } if s != nil { + t.Kernel().RecordSocket(s, family) return s, nil } } @@ -163,12 +164,15 @@ func Pair(t *kernel.Task, family int, stype transport.SockType, protocol int) (* } for _, p := range providers { - s, t, err := p.Pair(t, stype, protocol) + s1, s2, err := p.Pair(t, stype, protocol) if err != nil { return nil, nil, err } - if s != nil && t != nil { - return s, t, nil + if s1 != nil && s2 != nil { + k := t.Kernel() + k.RecordSocket(s1, family) + k.RecordSocket(s2, family) + return s1, s2, nil } } diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index 19258e692..c857a0f33 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -219,6 +219,8 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, return 0, nil, 0, syserr.FromError(e) } + t.Kernel().RecordSocket(ns, linux.AF_UNIX) + return fd, addr, addrLen, nil } diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD index 148d9c366..53da121ec 100644 --- a/test/syscalls/BUILD +++ b/test/syscalls/BUILD @@ -534,6 +534,13 @@ syscall_test( syscall_test(test = "//test/syscalls/linux:write_test") +syscall_test( + test = "//test/syscalls/linux:proc_net_unix_test", + # Unix domain socket creation isn't supported on all file systems. The + # sentry-internal tmpfs is known to support it. 
+ use_tmpfs = True, +) + go_binary( name = "syscall_test_runner", srcs = ["syscall_test_runner.go"], diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index a311ca12c..590ee1659 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -3102,3 +3102,20 @@ cc_binary( "@com_google_googletest//:gtest", ], ) + +cc_binary( + name = "proc_net_unix_test", + testonly = 1, + srcs = ["proc_net_unix.cc"], + linkstatic = 1, + deps = [ + ":unix_domain_socket_test_util", + "//test/util:file_descriptor", + "//test/util:fs_util", + "//test/util:test_main", + "//test/util:test_util", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_googletest//:gtest", + ], +) diff --git a/test/syscalls/linux/proc_net_unix.cc b/test/syscalls/linux/proc_net_unix.cc new file mode 100644 index 000000000..ea7c93012 --- /dev/null +++ b/test/syscalls/linux/proc_net_unix.cc @@ -0,0 +1,246 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" +#include "test/syscalls/linux/unix_domain_socket_test_util.h" +#include "test/util/file_descriptor.h" +#include "test/util/fs_util.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { +namespace { + +using absl::StrCat; +using absl::StreamFormat; +using absl::StrFormat; + +constexpr char kProcNetUnixHeader[] = + "Num RefCount Protocol Flags Type St Inode Path"; + +// UnixEntry represents a single entry from /proc/net/unix. +struct UnixEntry { + uintptr_t addr; + uint64_t refs; + uint64_t protocol; + uint64_t flags; + uint64_t type; + uint64_t state; + uint64_t inode; + std::string path; +}; + +std::string ExtractPath(const struct sockaddr* addr) { + const char* path = + reinterpret_cast<const struct sockaddr_un*>(addr)->sun_path; + // Note: sockaddr_un.sun_path is an embedded character array of length + // UNIX_PATH_MAX, so we can always safely dereference the first 2 bytes below. + // + // The kernel also enforces that the path is always null terminated. + if (path[0] == 0) { + // Abstract socket paths are null padded to the end of the struct + // sockaddr. However, these null bytes may or may not show up in + // /proc/net/unix depending on the kernel version. Truncate after the first + // null byte (by treating path as a C string). + return StrCat("@", &path[1]); + } + return std::string(path); +} + +// Returns a parsed representation of /proc/net/unix entries. 
+PosixErrorOr<std::vector<UnixEntry>> ProcNetUnixEntries() { + std::string content; + RETURN_IF_ERRNO(GetContents("/proc/net/unix", &content)); + + bool skipped_header = false; + std::vector<UnixEntry> entries; + std::vector<std::string> lines = absl::StrSplit(content, absl::ByAnyChar("\n")); + for (std::string line : lines) { + if (!skipped_header) { + EXPECT_EQ(line, kProcNetUnixHeader); + skipped_header = true; + continue; + } + if (line.empty()) { + continue; + } + + // Abstract socket paths can have trailing null bytes in them depending on + // the linux version. Strip off everything after a null byte, including the + // null byte. + std::size_t null_pos = line.find('\0'); + if (null_pos != std::string::npos) { + line.erase(null_pos); + } + + // Parse a single entry from /proc/net/unix. + // + // Sample file: + // + // clang-format off + // + // Num RefCount Protocol Flags Type St Inode Path + // ffffa130e7041c00: 00000002 00000000 00010000 0001 01 1299413685 /tmp/control_server/13293772586877554487 + // ffffa14f547dc400: 00000002 00000000 00010000 0001 01 3793 @remote_coredump + // + // clang-format on + // + // Note that from the second entry, the inode number can be padded using + // spaces, so we need to handle it separately during parsing. See + // net/unix/af_unix.c:unix_seq_show() for how these entries are produced. In + // particular, only the inode field is padded with spaces. + UnixEntry entry; + + // Process the first 6 fields, up to but not including "Inode". + std::vector<std::string> fields = absl::StrSplit(line, absl::MaxSplits(' ', 6)); + + if (fields.size() < 7) { + return PosixError(EINVAL, StrFormat("Invalid entry: '%s'\n", line)); + } + + // AtoiBase can't handle the ':' in the "Num" field, so strip it out. 
+ std::vector<std::string> addr = absl::StrSplit(fields[0], ':'); + ASSIGN_OR_RETURN_ERRNO(entry.addr, AtoiBase(addr[0], 16)); + + ASSIGN_OR_RETURN_ERRNO(entry.refs, AtoiBase(fields[1], 16)); + ASSIGN_OR_RETURN_ERRNO(entry.protocol, AtoiBase(fields[2], 16)); + ASSIGN_OR_RETURN_ERRNO(entry.flags, AtoiBase(fields[3], 16)); + ASSIGN_OR_RETURN_ERRNO(entry.type, AtoiBase(fields[4], 16)); + ASSIGN_OR_RETURN_ERRNO(entry.state, AtoiBase(fields[5], 16)); + + absl::string_view rest = absl::StripAsciiWhitespace(fields[6]); + fields = absl::StrSplit(rest, absl::MaxSplits(' ', 1)); + if (fields.empty()) { + return PosixError( + EINVAL, StrFormat("Invalid entry, missing 'Inode': '%s'\n", line)); + } + ASSIGN_OR_RETURN_ERRNO(entry.inode, AtoiBase(fields[0], 10)); + + entry.path = ""; + if (fields.size() > 1) { + entry.path = fields[1]; + } + + entries.push_back(entry); + } + + return entries; +} + +// Finds the first entry in 'entries' for which 'predicate' returns true. +// Returns true on match, and copies the matching entry into '*match'. +bool FindBy(std::vector<UnixEntry> entries, UnixEntry* match, + std::function<bool(const UnixEntry&)> predicate) { + for (size_t i = 0; i < entries.size(); ++i) { + if (predicate(entries[i])) { + *match = entries[i]; + return true; + } + } + return false; +} + +bool FindByPath(std::vector<UnixEntry> entries, UnixEntry* match, + const std::string& path) { + return FindBy(entries, match, [path](UnixEntry e) { return e.path == path; }); +} + +TEST(ProcNetUnix, Exists) { + const std::string content = + ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/unix")); + const std::string header_line = StrCat(kProcNetUnixHeader, "\n"); + if (IsRunningOnGvisor()) { + // Should be just the header since we don't have any unix domain sockets + // yet. + EXPECT_EQ(content, header_line); + } else { + // However, on a general linux machine, we could have arbitrary sockets on + // the system, so just check the header. 
+ EXPECT_THAT(content, ::testing::StartsWith(header_line)); + } +} + +TEST(ProcNetUnix, FilesystemBindAcceptConnect) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE( + FilesystemBoundUnixDomainSocketPair(SOCK_STREAM).Create()); + + std::string path1 = ExtractPath(sockets->first_addr()); + std::string path2 = ExtractPath(sockets->second_addr()); + std::cout << StreamFormat("Server socket address: %s\n", path1); + std::cout << StreamFormat("Client socket address: %s\n", path2); + + std::vector<UnixEntry> entries = + ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries()); + if (IsRunningOnGvisor()) { + EXPECT_EQ(entries.size(), 2); + } + + // The server-side socket's path is listed in the socket entry... + UnixEntry s1; + EXPECT_TRUE(FindByPath(entries, &s1, path1)); + + // ... but the client-side socket's path is not. + UnixEntry s2; + EXPECT_FALSE(FindByPath(entries, &s2, path2)); +} + +TEST(ProcNetUnix, AbstractBindAcceptConnect) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE( + AbstractBoundUnixDomainSocketPair(SOCK_STREAM).Create()); + + std::string path1 = ExtractPath(sockets->first_addr()); + std::string path2 = ExtractPath(sockets->second_addr()); + std::cout << StreamFormat("Server socket address: '%s'\n", path1); + std::cout << StreamFormat("Client socket address: '%s'\n", path2); + + std::vector<UnixEntry> entries = + ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries()); + if (IsRunningOnGvisor()) { + EXPECT_EQ(entries.size(), 2); + } + + // The server-side socket's path is listed in the socket entry... + UnixEntry s1; + EXPECT_TRUE(FindByPath(entries, &s1, path1)); + + // ... but the client-side socket's path is not. + UnixEntry s2; + EXPECT_FALSE(FindByPath(entries, &s2, path2)); +} + +TEST(ProcNetUnix, SocketPair) { + // Under gvisor, ensure a socketpair() syscall creates exactly 2 new + // entries. We have no way to verify this under Linux, as we have no control + // over socket creation on a general Linux machine. 
+ SKIP_IF(!IsRunningOnGvisor()); + + std::vector<UnixEntry> entries = + ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries()); + ASSERT_EQ(entries.size(), 0); + + auto sockets = + ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_STREAM).Create()); + + entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries()); + EXPECT_EQ(entries.size(), 2); +} + +} // namespace +} // namespace testing +} // namespace gvisor