cgroupfs: Implement task migration

PiperOrigin-RevId: 430273113
This commit is contained in:
Rahat Mahmood 2022-02-22 12:41:20 -08:00 committed by gVisor bot
parent 377cfd813d
commit 8643fe526e
20 changed files with 367 additions and 42 deletions

View File

@ -19,6 +19,7 @@ import (
"fmt"
"sort"
"strconv"
"strings"
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
@ -27,6 +28,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/usermem"
)
@ -91,6 +93,26 @@ type controller interface {
// AddControlFiles should extend the contents map with inodes representing
// control files defined by this controller.
AddControlFiles(ctx context.Context, creds *auth.Credentials, c *cgroupInode, contents map[string]kernfs.Inode)
// PrepareMigrate signals the controller that a migration is about to
// happen. The controller should check for any conditions that would prevent
// the migration. If PrepareMigrate succeeds, the controller must
// unconditionally either accept the migration via CommitMigrate, or roll it
// back via AbortMigrate.
//
// Postcondition: If PrepareMigrate returns nil, caller must resolve the
// migration by calling either CommitMigrate or AbortMigrate.
PrepareMigrate(t *kernel.Task, src controller) error
// CommitMigrate completes an in-flight migration.
//
// Precondition: Caller must call a corresponding PrepareMigrate.
CommitMigrate(t *kernel.Task, src controller)
// AbortMigrate cancels an in-flight migration.
//
// Precondition: Caller must call a corresponding PrepareMigrate.
AbortMigrate(t *kernel.Task, src controller)
}
// cgroupInode implements kernel.CgroupImpl and kernfs.Inode.
@ -185,6 +207,54 @@ func (c *cgroupInode) Leave(t *kernel.Task) {
c.fs.tasksMu.Unlock()
}
// PrepareMigrate implements kernel.CgroupImpl.PrepareMigrate.
func (c *cgroupInode) PrepareMigrate(t *kernel.Task, src *kernel.Cgroup) error {
prepared := make([]controller, 0, len(c.controllers))
rollback := func() {
for _, p := range prepared {
c.controllers[p.Type()].AbortMigrate(t, p)
}
}
for srcType, srcCtl := range src.CgroupImpl.(*cgroupInode).controllers {
ctl := c.controllers[srcType]
if err := ctl.PrepareMigrate(t, srcCtl); err != nil {
rollback()
return err
}
prepared = append(prepared, srcCtl)
}
return nil
}
// CommitMigrate implements kernel.CgroupImpl.CommitMigrate.
func (c *cgroupInode) CommitMigrate(t *kernel.Task, src *kernel.Cgroup) {
for srcType, srcCtl := range src.CgroupImpl.(*cgroupInode).controllers {
c.controllers[srcType].CommitMigrate(t, srcCtl)
}
srcI := src.CgroupImpl.(*cgroupInode)
c.fs.tasksMu.Lock()
defer c.fs.tasksMu.Unlock()
delete(srcI.ts, t)
c.ts[t] = struct{}{}
}
// AbortMigrate implements kernel.CgroupImpl.AbortMigrate.
func (c *cgroupInode) AbortMigrate(t *kernel.Task, src *kernel.Cgroup) {
for srcType, srcCtl := range src.CgroupImpl.(*cgroupInode).controllers {
c.controllers[srcType].AbortMigrate(t, srcCtl)
}
}
func (c *cgroupInode) Cgroup(fd *vfs.FileDescription) kernel.Cgroup {
return kernel.Cgroup{
Dentry: fd.Dentry().Impl().(*kernfs.Dentry),
CgroupImpl: c,
}
}
func sortTIDs(tids []kernel.ThreadID) {
sort.Slice(tids, func(i, j int) bool { return tids[i] < tids[j] })
}
@ -223,9 +293,16 @@ func (d *cgroupProcsData) Generate(ctx context.Context, buf *bytes.Buffer) error
}
// Write implements vfs.WritableDynamicBytesSource.Write.
func (d *cgroupProcsData) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
// TODO(b/183137098): Payload is the pid for a process to add to this cgroup.
return src.NumBytes(), nil
func (d *cgroupProcsData) Write(ctx context.Context, fd *vfs.FileDescription, src usermem.IOSequence, offset int64) (int64, error) {
tgid, n, err := parseInt64FromString(ctx, src)
if err != nil {
return n, err
}
t := kernel.TaskFromContext(ctx)
currPidns := t.ThreadGroup().PIDNamespace()
targetTG := currPidns.ThreadGroupWithID(kernel.ThreadID(tgid))
return n, targetTG.MigrateCgroup(d.Cgroup(fd))
}
// +stateify savable
@ -255,9 +332,16 @@ func (d *tasksData) Generate(ctx context.Context, buf *bytes.Buffer) error {
}
// Write implements vfs.WritableDynamicBytesSource.Write.
func (d *tasksData) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
// TODO(b/183137098): Payload is the pid for a process to add to this cgroup.
return src.NumBytes(), nil
func (d *tasksData) Write(ctx context.Context, fd *vfs.FileDescription, src usermem.IOSequence, offset int64) (int64, error) {
tid, n, err := parseInt64FromString(ctx, src)
if err != nil {
return n, err
}
t := kernel.TaskFromContext(ctx)
currPidns := t.ThreadGroup().PIDNamespace()
targetTask := currPidns.TaskWithID(kernel.ThreadID(tid))
return n, targetTask.MigrateCgroup(d.Cgroup(fd))
}
// parseInt64FromString interprets src as string encoding a int64 value, and
@ -272,15 +356,30 @@ func parseInt64FromString(ctx context.Context, src usermem.IOSequence) (val, len
if err != nil {
return 0, int64(n), err
}
buf = buf[:n]
str := strings.TrimSpace(string(buf[:n]))
val, err = strconv.ParseInt(string(buf), 10, 64)
val, err = strconv.ParseInt(str, 10, 64)
if err != nil {
// Note: This also handles zero-len writes if offset is beyond the end
// of src, or src is empty.
ctx.Warningf("cgroupfs.parseInt64FromString: failed to parse %q: %v", string(buf), err)
ctx.Warningf("cgroupfs.parseInt64FromString: failed to parse %q: %v", str, err)
return 0, int64(n), linuxerr.EINVAL
}
return val, int64(n), nil
}
// controllerNoopMigrate partially implements controller. It stubs the migration
// methods with noops for a stateless controller.
type controllerNoopMigrate struct{}
// PrepareMigrate implements controller.PrepareMigrate.
func (*controllerNoopMigrate) PrepareMigrate(t *kernel.Task, src controller) error {
return nil
}
// CommitMigrate implements controller.CommitMigrate.
func (*controllerNoopMigrate) CommitMigrate(t *kernel.Task, src controller) {}
// AbortMigrate implements controller.AbortMigrate.
func (*controllerNoopMigrate) AbortMigrate(t *kernel.Task, src controller) {}

View File

@ -26,6 +26,7 @@ import (
// +stateify savable
type cpuController struct {
controllerCommon
controllerNoopMigrate
// CFS bandwidth control parameters, values in microseconds.
cfsPeriod int64

View File

@ -28,6 +28,7 @@ import (
// +stateify savable
type cpuacctController struct {
controllerCommon
controllerNoopMigrate
}
var _ controller = (*cpuacctController)(nil)

View File

@ -26,6 +26,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/usermem"
)
@ -33,6 +34,7 @@ import (
// +stateify savable
type cpusetController struct {
controllerCommon
controllerNoopMigrate
maxCpus uint32
maxMems uint32
@ -95,7 +97,7 @@ func (d *cpusData) Generate(ctx context.Context, buf *bytes.Buffer) error {
}
// Write implements vfs.WritableDynamicBytesSource.Write.
func (d *cpusData) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
func (d *cpusData) Write(ctx context.Context, _ *vfs.FileDescription, src usermem.IOSequence, offset int64) (int64, error) {
if src.NumBytes() > hostarch.PageSize {
return 0, linuxerr.EINVAL
}
@ -139,7 +141,7 @@ func (d *memsData) Generate(ctx context.Context, buf *bytes.Buffer) error {
}
// Write implements vfs.WritableDynamicBytesSource.Write.
func (d *memsData) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
func (d *memsData) Write(ctx context.Context, _ *vfs.FileDescription, src usermem.IOSequence, offset int64) (int64, error) {
if src.NumBytes() > hostarch.PageSize {
return 0, linuxerr.EINVAL
}

View File

@ -21,12 +21,15 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/usermem"
)
// +stateify savable
type jobController struct {
controllerCommon
controllerNoopMigrate
id int64
}
@ -63,7 +66,7 @@ func (d *jobIDData) Generate(ctx context.Context, buf *bytes.Buffer) error {
}
// Write implements vfs.WritableDynamicBytesSource.Write.
func (d *jobIDData) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
func (d *jobIDData) Write(ctx context.Context, _ *vfs.FileDescription, src usermem.IOSequence, offset int64) (int64, error) {
val, n, err := parseInt64FromString(ctx, src)
if err != nil {
return n, err

View File

@ -30,6 +30,7 @@ import (
// +stateify savable
type memoryController struct {
controllerCommon
controllerNoopMigrate
limitBytes int64
softLimitBytes int64

View File

@ -101,7 +101,7 @@ func (fd *DynamicBytesFD) Init(m *vfs.Mount, d *Dentry, data vfs.DynamicBytesSou
return err
}
fd.inode = d.inode
fd.SetDataSource(data)
fd.DynamicBytesFileDescriptionImpl.Init(&fd.vfsfd, data)
return nil
}

View File

@ -77,7 +77,7 @@ func (fd *queueFD) Init(m *vfs.Mount, d *kernfs.Dentry, data vfs.DynamicBytesSou
return err
}
fd.inode = d.Inode()
fd.SetDataSource(data)
fd.DynamicBytesFileDescriptionImpl.Init(&fd.vfsfd, data)
return nil
}

View File

@ -303,6 +303,7 @@ type idMapData struct {
}
var _ dynamicInode = (*idMapData)(nil)
var _ vfs.WritableDynamicBytesSource = (*idMapData)(nil)
// Generate implements vfs.WritableDynamicBytesSource.Generate.
func (d *idMapData) Generate(ctx context.Context, buf *bytes.Buffer) error {
@ -319,7 +320,7 @@ func (d *idMapData) Generate(ctx context.Context, buf *bytes.Buffer) error {
}
// Write implements vfs.WritableDynamicBytesSource.Write.
func (d *idMapData) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
func (d *idMapData) Write(ctx context.Context, _ *vfs.FileDescription, src usermem.IOSequence, offset int64) (int64, error) {
// "In addition, the number of bytes written to the file must be less than
// the system page size, and the write must be performed at the start of
// the file ..." - user_namespaces(7)
@ -718,7 +719,7 @@ func (s *statusInode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs
if err := fd.vfsfd.Init(fd, opts.Flags, rp.Mount(), d.VFSDentry(), &vfs.FileDescriptionOptions{}); err != nil {
return nil, err
}
fd.SetDataSource(fd)
fd.DynamicBytesFileDescriptionImpl.Init(&fd.vfsfd, fd)
return &fd.vfsfd, nil
}
@ -863,7 +864,7 @@ func (o *oomScoreAdj) Generate(ctx context.Context, buf *bytes.Buffer) error {
}
// Write implements vfs.WritableDynamicBytesSource.Write.
func (o *oomScoreAdj) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
func (o *oomScoreAdj) Write(ctx context.Context, _ *vfs.FileDescription, src usermem.IOSequence, offset int64) (int64, error) {
if src.NumBytes() == 0 {
return 0, nil
}

View File

@ -209,7 +209,7 @@ func (d *tcpSackData) Generate(ctx context.Context, buf *bytes.Buffer) error {
}
// Write implements vfs.WritableDynamicBytesSource.Write.
func (d *tcpSackData) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
func (d *tcpSackData) Write(ctx context.Context, _ *vfs.FileDescription, src usermem.IOSequence, offset int64) (int64, error) {
if offset != 0 {
// No need to handle partial writes thus far.
return 0, linuxerr.EINVAL
@ -257,7 +257,7 @@ func (d *tcpRecoveryData) Generate(ctx context.Context, buf *bytes.Buffer) error
}
// Write implements vfs.WritableDynamicBytesSource.Write.
func (d *tcpRecoveryData) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
func (d *tcpRecoveryData) Write(ctx context.Context, _ *vfs.FileDescription, src usermem.IOSequence, offset int64) (int64, error) {
if offset != 0 {
// No need to handle partial writes thus far.
return 0, linuxerr.EINVAL
@ -311,7 +311,7 @@ func (d *tcpMemData) Generate(ctx context.Context, buf *bytes.Buffer) error {
}
// Write implements vfs.WritableDynamicBytesSource.Write.
func (d *tcpMemData) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
func (d *tcpMemData) Write(ctx context.Context, _ *vfs.FileDescription, src usermem.IOSequence, offset int64) (int64, error) {
if offset != 0 {
// No need to handle partial writes thus far.
return 0, linuxerr.EINVAL
@ -396,7 +396,7 @@ func (ipf *ipForwarding) Generate(ctx context.Context, buf *bytes.Buffer) error
}
// Write implements vfs.WritableDynamicBytesSource.Write.
func (ipf *ipForwarding) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
func (ipf *ipForwarding) Write(ctx context.Context, _ *vfs.FileDescription, src usermem.IOSequence, offset int64) (int64, error) {
if offset != 0 {
// No need to handle partial writes thus far.
return 0, linuxerr.EINVAL
@ -449,7 +449,7 @@ func (pr *portRange) Generate(ctx context.Context, buf *bytes.Buffer) error {
}
// Write implements vfs.WritableDynamicBytesSource.Write.
func (pr *portRange) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
func (pr *portRange) Write(ctx context.Context, _ *vfs.FileDescription, src usermem.IOSequence, offset int64) (int64, error) {
if offset != 0 {
// No need to handle partial writes thus far.
return 0, linuxerr.EINVAL

View File

@ -136,7 +136,7 @@ func TestConfigureIPForwarding(t *testing.T) {
// Write the values.
src := usermem.BytesIOSequence([]byte(c.str))
if n, err := file.Write(ctx, src, 0); n != int64(len(c.str)) || err != nil {
if n, err := file.Write(ctx, nil, src, 0); n != int64(len(c.str)) || err != nil {
t.Errorf("file.Write(ctx, nil, %q, 0) = (%d, %v); want (%d, nil)", c.str, n, err, len(c.str))
}

View File

@ -26,6 +26,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/usermem"
)
@ -46,6 +47,8 @@ type yamaPtraceScope struct {
level *int32
}
var _ vfs.WritableDynamicBytesSource = (*yamaPtraceScope)(nil)
// Generate implements vfs.DynamicBytesSource.Generate.
func (s *yamaPtraceScope) Generate(ctx context.Context, buf *bytes.Buffer) error {
_, err := fmt.Fprintf(buf, "%d\n", atomic.LoadInt32(s.level))
@ -53,7 +56,7 @@ func (s *yamaPtraceScope) Generate(ctx context.Context, buf *bytes.Buffer) error
}
// Write implements vfs.WritableDynamicBytesSource.Write.
func (s *yamaPtraceScope) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
func (s *yamaPtraceScope) Write(ctx context.Context, _ *vfs.FileDescription, src usermem.IOSequence, offset int64) (int64, error) {
if offset != 0 {
// Ignore partial writes.
return 0, linuxerr.EINVAL

View File

@ -86,11 +86,46 @@ func (c *Cgroup) HierarchyID() uint32 {
return c.Controllers()[0].HierarchyID()
}
// CgroupMigrationContext represents an in-flight cgroup migration for
// a single task.
type CgroupMigrationContext struct {
src Cgroup
dst Cgroup
t *Task
}
// Abort cancels a migration.
func (ctx *CgroupMigrationContext) Abort() {
ctx.dst.AbortMigrate(ctx.t, &ctx.src)
}
// Commit completes a migration.
func (ctx *CgroupMigrationContext) Commit() {
ctx.dst.CommitMigrate(ctx.t, &ctx.src)
}
// CgroupImpl is the common interface to cgroups.
type CgroupImpl interface {
// Controllers lists the controller associated with this cgroup.
Controllers() []CgroupController
// Enter moves t into this cgroup.
Enter(t *Task)
// Leave moves t out of this cgroup.
Leave(t *Task)
// PrepareMigrate initiates a migration of t from src to this cgroup. See
// cgroupfs.controller.PrepareMigrate.
PrepareMigrate(t *Task, src *Cgroup) error
// CommitMigrate completes an in-flight migration. See
// cgroupfs.controller.CommitMigrate.
CommitMigrate(t *Task, src *Cgroup)
// AbortMigrate cancels an in-flight migration. See
// cgroupfs.controller.AbortMigrate.
AbortMigrate(t *Task, src *Cgroup)
}
// hierarchy represents a cgroupfs filesystem instance, with a unique set of

View File

@ -58,15 +58,9 @@ func (t *Task) EnterCgroup(c Cgroup) error {
defer t.mu.Unlock()
for oldCG, _ := range t.cgroups {
for _, oldCtl := range oldCG.Controllers() {
if _, ok := newControllers[oldCtl.Type()]; ok {
// Already in a cgroup with the same controller as one of the
// new ones. Requires migration between cgroups.
//
// TODO(b/183137098): Implement cgroup migration.
log.Warningf("Cgroup migration is not implemented")
return linuxerr.EBUSY
}
if oldCG.HierarchyID() == c.HierarchyID() {
log.Warningf("Cannot enter new cgroup %v due to conflicting controllers. Try migrate instead?", c)
return linuxerr.EBUSY
}
}
@ -107,6 +101,82 @@ func (t *Task) leaveCgroupLocked(c Cgroup) {
c.decRef()
}
// +checklocks:t.mu
func (t *Task) findCgroupWithMatchingHierarchyLocked(other Cgroup) (Cgroup, bool) {
for c, _ := range t.cgroups {
if c.HierarchyID() != other.HierarchyID() {
continue
}
return c, true
}
return Cgroup{}, false
}
// CgroupPrepareMigrate starts a cgroup migration for this task to dst. The
// migration must be completed through the returned context.
func (t *Task) CgroupPrepareMigrate(dst Cgroup) (*CgroupMigrationContext, error) {
t.mu.Lock()
defer t.mu.Unlock()
src, found := t.findCgroupWithMatchingHierarchyLocked(dst)
if !found {
log.Warningf("Cannot migrate to cgroup %v since task %v not currently in target hierarchy %v", dst, t, dst.HierarchyID())
return nil, linuxerr.EINVAL
}
if err := dst.PrepareMigrate(t, &src); err != nil {
return nil, err
}
return &CgroupMigrationContext{
src: src,
dst: dst,
t: t,
}, nil
}
// MigrateCgroup migrates all tasks in tg to the dst cgroup. Either all tasks
// are migrated, or none are. Atomicity of migrations wrt cgroup membership
// (i.e. a task can't switch cgroups mid-migration due to another migration) is
// guaranteed because migrations are serialized by TaskSet.mu.
func (tg *ThreadGroup) MigrateCgroup(dst Cgroup) error {
tg.pidns.owner.mu.RLock()
defer tg.pidns.owner.mu.RUnlock()
var ctxs []*CgroupMigrationContext
// Prepare migrations. On partial failure, abort.
for t := tg.tasks.Front(); t != nil; t = t.Next() {
ctx, err := t.CgroupPrepareMigrate(dst)
if err != nil {
// Rollback.
for _, ctx := range ctxs {
ctx.Abort()
}
return err
}
ctxs = append(ctxs, ctx)
}
// All migrations are now guaranteed to succeed.
for _, ctx := range ctxs {
ctx.Commit()
}
return nil
}
// MigrateCgroup migrates this task to the dst cgroup.
func (t *Task) MigrateCgroup(dst Cgroup) error {
t.tg.pidns.owner.mu.RLock()
defer t.tg.pidns.owner.mu.RUnlock()
ctx, err := t.CgroupPrepareMigrate(dst)
if err != nil {
return err
}
ctx.Commit()
return nil
}
// taskCgroupEntry represents a line in /proc/<pid>/cgroup, and is used to
// format a cgroup for display.
type taskCgroupEntry struct {

View File

@ -251,7 +251,7 @@ type WritableDynamicBytesSource interface {
DynamicBytesSource
// Write sends writes to the source.
Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error)
Write(ctx context.Context, fd *FileDescription, src usermem.IOSequence, offset int64) (int64, error)
}
// DynamicBytesFileDescriptionImpl may be embedded by implementations of
@ -262,11 +262,12 @@ type WritableDynamicBytesSource interface {
// If data additionally implements WritableDynamicBytesSource, writes are
// dispatched to the implementer. The source data is not automatically modified.
//
// DynamicBytesFileDescriptionImpl.SetDataSource() must be called before first
// DynamicBytesFileDescriptionImpl.Init() must be called before first
// use.
//
// +stateify savable
type DynamicBytesFileDescriptionImpl struct {
vfsfd *FileDescription // immutable
data DynamicBytesSource // immutable
mu sync.Mutex `state:"nosave"` // protects the following fields
buf bytes.Buffer `state:".([]byte)"`
@ -282,8 +283,9 @@ func (fd *DynamicBytesFileDescriptionImpl) loadBuf(p []byte) {
fd.buf.Write(p)
}
// SetDataSource must be called exactly once on fd before first use.
func (fd *DynamicBytesFileDescriptionImpl) SetDataSource(data DynamicBytesSource) {
// Init must be called before first use.
func (fd *DynamicBytesFileDescriptionImpl) Init(vfsfd *FileDescription, data DynamicBytesSource) {
fd.vfsfd = vfsfd
fd.data = data
}
@ -376,7 +378,7 @@ func (fd *DynamicBytesFileDescriptionImpl) pwriteLocked(ctx context.Context, src
if !ok {
return 0, linuxerr.EIO
}
n, err := writable.Write(ctx, src, offset)
n, err := writable.Write(ctx, fd.vfsfd, src, offset)
if err != nil {
return 0, err
}

View File

@ -61,7 +61,7 @@ func (d *storeData) Generate(ctx context.Context, buf *bytes.Buffer) error {
}
// Generate implements WritableDynamicBytesSource.
func (d *storeData) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
func (d *storeData) Write(ctx context.Context, _ *FileDescription, src usermem.IOSequence, offset int64) (int64, error) {
buf := make([]byte, src.NumBytes())
n, err := src.CopyIn(ctx, buf)
if err != nil {
@ -84,9 +84,9 @@ func newTestFD(ctx context.Context, vfsObj *VirtualFilesystem, statusFlags uint3
vd := vfsObj.NewAnonVirtualDentry("genCountFD")
defer vd.DecRef(ctx)
var fd testFD
fd.vfsfd.Init(&fd, statusFlags, vd.Mount(), vd.Dentry(), &FileDescriptionOptions{})
fd.DynamicBytesFileDescriptionImpl.SetDataSource(data)
return &fd.vfsfd
fd.fileDescription.vfsfd.Init(&fd, statusFlags, vd.Mount(), vd.Dentry(), &FileDescriptionOptions{})
fd.DynamicBytesFileDescriptionImpl.Init(&fd.fileDescription.vfsfd, data)
return &fd.fileDescription.vfsfd
}
// Release implements FileDescriptionImpl.Release.

View File

@ -4497,6 +4497,7 @@ cc_binary(
"//test/util:fs_util",
"//test/util:mount_util",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
gtest,
"//test/util:cleanup",
"//test/util:posix_error",

View File

@ -23,6 +23,7 @@
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_split.h"
#include "absl/synchronization/notification.h"
#include "test/util/capability_util.h"
#include "test/util/cgroup_util.h"
#include "test/util/cleanup.h"
@ -53,6 +54,39 @@ bool CgroupsAvailable() {
TEST_CHECK_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN));
}
// NoopThreads spawns a set of threads that do nothing until they're asked to
// exit. Useful for testing functionality that requires a process with multiple
// threads.
class NoopThreads {
public:
NoopThreads(int count) {
auto noop = [this]() { exit_.WaitForNotification(); };
for (int i = 0; i < count; ++i) {
threads_.emplace_back(noop);
}
}
~NoopThreads() { Join(); }
void Join() {
if (joined_) {
return;
}
joined_ = true;
exit_.Notify();
for (auto& thread : threads_) {
thread.Join();
}
}
private:
std::list<ScopedThread> threads_;
absl::Notification exit_;
bool joined_ = false;
};
TEST(Cgroup, MountSucceeds) {
SKIP_IF(!CgroupsAvailable());
@ -326,6 +360,49 @@ TEST(Cgroup, SubcontainersHaveIndependentState) {
EXPECT_THAT(c.ReadIntegerControlFile("job.id"), IsPosixErrorOkAndHolds(5678));
}
TEST(Cgroup, MigrateToSubcontainer) {
SKIP_IF(!CgroupsAvailable());
Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()));
Cgroup c = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs(""));
Cgroup child = ASSERT_NO_ERRNO_AND_VALUE(c.CreateChild("child1"));
// Initially, test process should be in the root cgroup c.
EXPECT_NO_ERRNO(c.ContainsCallingProcess());
pid_t pid = getpid();
EXPECT_NO_ERRNO(child.Enter(pid));
// After migration, child should contain the test process, and the c should
// not.
EXPECT_NO_ERRNO(child.ContainsCallingProcess());
auto procs = ASSERT_NO_ERRNO_AND_VALUE(c.Procs());
EXPECT_FALSE(procs.contains(pid));
}
TEST(Cgroup, MigrateToSubcontainerThread) {
SKIP_IF(!CgroupsAvailable());
Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()));
Cgroup c = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs(""));
Cgroup child = ASSERT_NO_ERRNO_AND_VALUE(c.CreateChild("child1"));
// Ensure there are some threads for this process.
NoopThreads threads(10);
// Initially, test process should be in the root cgroup c.
EXPECT_NO_ERRNO(c.ContainsCallingThread());
const pid_t tid = syscall(SYS_gettid);
EXPECT_NO_ERRNO(child.EnterThread(tid));
// After migration, child should contain the test process, and the c should
// not.
EXPECT_NO_ERRNO(child.ContainsCallingThread());
auto tasks = ASSERT_NO_ERRNO_AND_VALUE(c.Tasks());
EXPECT_FALSE(tasks.contains(tid));
}
TEST(MemoryCgroup, MemoryUsageInBytes) {
SKIP_IF(!CgroupsAvailable());

View File

@ -106,6 +106,24 @@ PosixError Cgroup::ContainsCallingProcess() const {
return NoError();
}
PosixError Cgroup::ContainsCallingThread() const {
ASSIGN_OR_RETURN_ERRNO(const absl::flat_hash_set<pid_t> tasks, Tasks());
const pid_t tid = syscall(SYS_gettid);
if (!tasks.contains(tid)) {
return PosixError(ENOENT,
absl::StrFormat("Cgroup doesn't contain task %d", tid));
}
return NoError();
}
PosixError Cgroup::Enter(pid_t pid) const {
return WriteIntegerControlFile("cgroup.procs", static_cast<int64_t>(pid));
}
PosixError Cgroup::EnterThread(pid_t pid) const {
return WriteIntegerControlFile("tasks", static_cast<int64_t>(pid));
}
PosixErrorOr<absl::flat_hash_set<pid_t>> Cgroup::ParsePIDList(
absl::string_view data) const {
absl::flat_hash_set<pid_t> res;

View File

@ -74,8 +74,19 @@ class Cgroup {
PosixErrorOr<absl::flat_hash_set<pid_t>> Tasks() const;
// ContainsCallingProcess checks whether the calling process is part of the
// cgroup.
PosixError ContainsCallingProcess() const;
// ContainsCallingThread checks whether the calling thread is part of the
// cgroup.
PosixError ContainsCallingThread() const;
// Moves process with the specified pid to this cgroup.
PosixError Enter(pid_t pid) const;
// Moves thread with the specified pid to this cgroup.
PosixError EnterThread(pid_t pid) const;
private:
PosixErrorOr<absl::flat_hash_set<pid_t>> ParsePIDList(
absl::string_view data) const;