diff --git a/pkg/sentry/fsimpl/cgroupfs/base.go b/pkg/sentry/fsimpl/cgroupfs/base.go
index f757dcda0..28a494ba2 100644
--- a/pkg/sentry/fsimpl/cgroupfs/base.go
+++ b/pkg/sentry/fsimpl/cgroupfs/base.go
@@ -19,6 +19,7 @@ import (
 	"fmt"
 	"sort"
 	"strconv"
+	"strings"
 	"sync/atomic"
 
 	"gvisor.dev/gvisor/pkg/abi/linux"
@@ -27,6 +28,7 @@ import (
 	"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
 	"gvisor.dev/gvisor/pkg/sentry/kernel"
 	"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+	"gvisor.dev/gvisor/pkg/sentry/vfs"
 	"gvisor.dev/gvisor/pkg/usermem"
 )
@@ -91,6 +93,26 @@ type controller interface {
 	// AddControlFiles should extend the contents map with inodes representing
 	// control files defined by this controller.
 	AddControlFiles(ctx context.Context, creds *auth.Credentials, c *cgroupInode, contents map[string]kernfs.Inode)
+
+	// PrepareMigrate signals the controller that a migration is about to
+	// happen. The controller should check for any conditions that would prevent
+	// the migration. If PrepareMigrate succeeds, the controller must
+	// unconditionally either accept the migration via CommitMigrate, or roll it
+	// back via AbortMigrate.
+	//
+	// Postcondition: If PrepareMigrate returns nil, caller must resolve the
+	// migration by calling either CommitMigrate or AbortMigrate.
+	PrepareMigrate(t *kernel.Task, src controller) error
+
+	// CommitMigrate completes an in-flight migration.
+	//
+	// Precondition: Caller must call a corresponding PrepareMigrate.
+	CommitMigrate(t *kernel.Task, src controller)
+
+	// AbortMigrate cancels an in-flight migration.
+	//
+	// Precondition: Caller must call a corresponding PrepareMigrate.
+	AbortMigrate(t *kernel.Task, src controller)
 }
 
 // cgroupInode implements kernel.CgroupImpl and kernfs.Inode.
@@ -185,6 +207,54 @@ func (c *cgroupInode) Leave(t *kernel.Task) {
 	c.fs.tasksMu.Unlock()
 }
 
+// PrepareMigrate implements kernel.CgroupImpl.PrepareMigrate.
+func (c *cgroupInode) PrepareMigrate(t *kernel.Task, src *kernel.Cgroup) error {
+	prepared := make([]controller, 0, len(c.controllers))
+	rollback := func() {
+		for _, p := range prepared {
+			c.controllers[p.Type()].AbortMigrate(t, p)
+		}
+	}
+
+	for srcType, srcCtl := range src.CgroupImpl.(*cgroupInode).controllers {
+		ctl := c.controllers[srcType]
+		if err := ctl.PrepareMigrate(t, srcCtl); err != nil {
+			rollback()
+			return err
+		}
+		prepared = append(prepared, srcCtl)
+	}
+	return nil
+}
+
+// CommitMigrate implements kernel.CgroupImpl.CommitMigrate.
+func (c *cgroupInode) CommitMigrate(t *kernel.Task, src *kernel.Cgroup) {
+	for srcType, srcCtl := range src.CgroupImpl.(*cgroupInode).controllers {
+		c.controllers[srcType].CommitMigrate(t, srcCtl)
+	}
+
+	srcI := src.CgroupImpl.(*cgroupInode)
+	c.fs.tasksMu.Lock()
+	defer c.fs.tasksMu.Unlock()
+
+	delete(srcI.ts, t)
+	c.ts[t] = struct{}{}
+}
+
+// AbortMigrate implements kernel.CgroupImpl.AbortMigrate.
+func (c *cgroupInode) AbortMigrate(t *kernel.Task, src *kernel.Cgroup) {
+	for srcType, srcCtl := range src.CgroupImpl.(*cgroupInode).controllers {
+		c.controllers[srcType].AbortMigrate(t, srcCtl)
+	}
+}
+
+// Cgroup returns the kernel.Cgroup backed by this inode, as reached through fd.
+func (c *cgroupInode) Cgroup(fd *vfs.FileDescription) kernel.Cgroup {
+	return kernel.Cgroup{
+		Dentry:     fd.Dentry().Impl().(*kernfs.Dentry),
+		CgroupImpl: c,
+	}
+}
+
 func sortTIDs(tids []kernel.ThreadID) {
 	sort.Slice(tids, func(i, j int) bool { return tids[i] < tids[j] })
 }
@@ -223,9 +293,16 @@ func (d *cgroupProcsData) Generate(ctx context.Context, buf *bytes.Buffer) error
 }
 
 // Write implements vfs.WritableDynamicBytesSource.Write.
-func (d *cgroupProcsData) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
-	// TODO(b/183137098): Payload is the pid for a process to add to this cgroup.
-	return src.NumBytes(), nil
+func (d *cgroupProcsData) Write(ctx context.Context, fd *vfs.FileDescription, src usermem.IOSequence, offset int64) (int64, error) {
+	tgid, n, err := parseInt64FromString(ctx, src)
+	if err != nil {
+		return n, err
+	}
+
+	t := kernel.TaskFromContext(ctx)
+	currPidns := t.ThreadGroup().PIDNamespace()
+	targetTG := currPidns.ThreadGroupWithID(kernel.ThreadID(tgid))
+	if targetTG == nil {
+		// No thread group with the written ID; fail the write rather than
+		// calling MigrateCgroup on a nil *ThreadGroup.
+		return n, linuxerr.ESRCH
+	}
+	return n, targetTG.MigrateCgroup(d.Cgroup(fd))
 }
 
 // +stateify savable
@@ -255,9 +332,16 @@ func (d *tasksData) Generate(ctx context.Context, buf *bytes.Buffer) error {
 }
 
 // Write implements vfs.WritableDynamicBytesSource.Write.
-func (d *tasksData) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
-	// TODO(b/183137098): Payload is the pid for a process to add to this cgroup.
-	return src.NumBytes(), nil
+func (d *tasksData) Write(ctx context.Context, fd *vfs.FileDescription, src usermem.IOSequence, offset int64) (int64, error) {
+	tid, n, err := parseInt64FromString(ctx, src)
+	if err != nil {
+		return n, err
+	}
+
+	t := kernel.TaskFromContext(ctx)
+	currPidns := t.ThreadGroup().PIDNamespace()
+	targetTask := currPidns.TaskWithID(kernel.ThreadID(tid))
+	if targetTask == nil {
+		// No task with the written ID; fail the write rather than calling
+		// MigrateCgroup on a nil *Task.
+		return n, linuxerr.ESRCH
+	}
+	return n, targetTask.MigrateCgroup(d.Cgroup(fd))
 }
 
 // parseInt64FromString interprets src as string encoding a int64 value, and
@@ -272,15 +356,30 @@ func parseInt64FromString(ctx context.Context, src usermem.IOSequence) (val, len
 	if err != nil {
 		return 0, int64(n), err
 	}
-	buf = buf[:n]
+	str := strings.TrimSpace(string(buf[:n]))
 
-	val, err = strconv.ParseInt(string(buf), 10, 64)
+	val, err = strconv.ParseInt(str, 10, 64)
 	if err != nil {
 		// Note: This also handles zero-len writes if offset is beyond the end
 		// of src, or src is empty.
-		ctx.Warningf("cgroupfs.parseInt64FromString: failed to parse %q: %v", string(buf), err)
+		ctx.Warningf("cgroupfs.parseInt64FromString: failed to parse %q: %v", str, err)
 		return 0, int64(n), linuxerr.EINVAL
 	}
 
 	return val, int64(n), nil
 }
+
+// controllerNoopMigrate partially implements controller. It stubs the migration
+// methods with noops for a stateless controller.
+type controllerNoopMigrate struct{}
+
+// PrepareMigrate implements controller.PrepareMigrate.
+func (*controllerNoopMigrate) PrepareMigrate(t *kernel.Task, src controller) error {
+	return nil
+}
+
+// CommitMigrate implements controller.CommitMigrate.
+func (*controllerNoopMigrate) CommitMigrate(t *kernel.Task, src controller) {}
+
+// AbortMigrate implements controller.AbortMigrate.
+func (*controllerNoopMigrate) AbortMigrate(t *kernel.Task, src controller) {}
diff --git a/pkg/sentry/fsimpl/cgroupfs/cpu.go b/pkg/sentry/fsimpl/cgroupfs/cpu.go
index d81bc3e6d..8fd9b5dec 100644
--- a/pkg/sentry/fsimpl/cgroupfs/cpu.go
+++ b/pkg/sentry/fsimpl/cgroupfs/cpu.go
@@ -26,6 +26,7 @@ import (
 // +stateify savable
 type cpuController struct {
 	controllerCommon
+	controllerNoopMigrate
 
 	// CFS bandwidth control parameters, values in microseconds.
 	cfsPeriod int64
diff --git a/pkg/sentry/fsimpl/cgroupfs/cpuacct.go b/pkg/sentry/fsimpl/cgroupfs/cpuacct.go
index 8f9818423..ae353fb33 100644
--- a/pkg/sentry/fsimpl/cgroupfs/cpuacct.go
+++ b/pkg/sentry/fsimpl/cgroupfs/cpuacct.go
@@ -28,6 +28,7 @@ import (
 // +stateify savable
 type cpuacctController struct {
 	controllerCommon
+	controllerNoopMigrate
 }
 
 var _ controller = (*cpuacctController)(nil)
diff --git a/pkg/sentry/fsimpl/cgroupfs/cpuset.go b/pkg/sentry/fsimpl/cgroupfs/cpuset.go
index f6d7cfc39..e6aa4a2a4 100644
--- a/pkg/sentry/fsimpl/cgroupfs/cpuset.go
+++ b/pkg/sentry/fsimpl/cgroupfs/cpuset.go
@@ -26,6 +26,7 @@ import (
 	"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
 	"gvisor.dev/gvisor/pkg/sentry/kernel"
 	"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+	"gvisor.dev/gvisor/pkg/sentry/vfs"
 	"gvisor.dev/gvisor/pkg/sync"
 	"gvisor.dev/gvisor/pkg/usermem"
 )
@@ -33,6 +34,7 @@ import (
 // +stateify savable
 type cpusetController struct {
 	controllerCommon
+	controllerNoopMigrate
 
 	maxCpus uint32
 	maxMems uint32
@@ -95,7 +97,7 @@ func (d *cpusData) Generate(ctx context.Context, buf *bytes.Buffer) error {
 }
 
 // Write implements vfs.WritableDynamicBytesSource.Write.
-func (d *cpusData) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
+func (d *cpusData) Write(ctx context.Context, _ *vfs.FileDescription, src usermem.IOSequence, offset int64) (int64, error) {
 	if src.NumBytes() > hostarch.PageSize {
 		return 0, linuxerr.EINVAL
 	}
@@ -139,7 +141,7 @@ func (d *memsData) Generate(ctx context.Context, buf *bytes.Buffer) error {
 }
 
 // Write implements vfs.WritableDynamicBytesSource.Write.
-func (d *memsData) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
+func (d *memsData) Write(ctx context.Context, _ *vfs.FileDescription, src usermem.IOSequence, offset int64) (int64, error) {
 	if src.NumBytes() > hostarch.PageSize {
 		return 0, linuxerr.EINVAL
 	}
diff --git a/pkg/sentry/fsimpl/cgroupfs/job.go b/pkg/sentry/fsimpl/cgroupfs/job.go
index 9b3cae2d3..fa0cb7242 100644
--- a/pkg/sentry/fsimpl/cgroupfs/job.go
+++ b/pkg/sentry/fsimpl/cgroupfs/job.go
@@ -21,12 +21,15 @@ import (
 	"gvisor.dev/gvisor/pkg/context"
 	"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
 	"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+	"gvisor.dev/gvisor/pkg/sentry/vfs"
 	"gvisor.dev/gvisor/pkg/usermem"
 )
 
 // +stateify savable
 type jobController struct {
 	controllerCommon
+	controllerNoopMigrate
+
 	id int64
 }
 
@@ -63,7 +66,7 @@ func (d *jobIDData) Generate(ctx context.Context, buf *bytes.Buffer) error {
 }
 
 // Write implements vfs.WritableDynamicBytesSource.Write.
-func (d *jobIDData) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
+func (d *jobIDData) Write(ctx context.Context, _ *vfs.FileDescription, src usermem.IOSequence, offset int64) (int64, error) {
 	val, n, err := parseInt64FromString(ctx, src)
 	if err != nil {
 		return n, err
diff --git a/pkg/sentry/fsimpl/cgroupfs/memory.go b/pkg/sentry/fsimpl/cgroupfs/memory.go
index aeefa01c6..0395cad48 100644
--- a/pkg/sentry/fsimpl/cgroupfs/memory.go
+++ b/pkg/sentry/fsimpl/cgroupfs/memory.go
@@ -30,6 +30,7 @@ import (
 // +stateify savable
 type memoryController struct {
 	controllerCommon
+	controllerNoopMigrate
 
 	limitBytes     int64
 	softLimitBytes int64
diff --git a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go
index 652ade564..17602c63c 100644
--- a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go
+++ b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go
@@ -101,7 +101,7 @@ func (fd *DynamicBytesFD) Init(m *vfs.Mount, d *Dentry, data vfs.DynamicBytesSou
 		return err
 	}
 	fd.inode = d.inode
-	fd.SetDataSource(data)
+	fd.DynamicBytesFileDescriptionImpl.Init(&fd.vfsfd, data)
 	return nil
 }
 
diff --git a/pkg/sentry/fsimpl/mqfs/queue.go b/pkg/sentry/fsimpl/mqfs/queue.go
index 2f4b96e44..7f12edf50 100644
--- a/pkg/sentry/fsimpl/mqfs/queue.go
+++ b/pkg/sentry/fsimpl/mqfs/queue.go
@@ -77,7 +77,7 @@ func (fd *queueFD) Init(m *vfs.Mount, d *kernfs.Dentry, data vfs.DynamicBytesSou
 		return err
 	}
 	fd.inode = d.Inode()
-	fd.SetDataSource(data)
+	fd.DynamicBytesFileDescriptionImpl.Init(&fd.vfsfd, data)
 	return nil
 }
 
diff --git a/pkg/sentry/fsimpl/proc/task_files.go b/pkg/sentry/fsimpl/proc/task_files.go
index d3f9cf489..e7cfef94b 100644
--- a/pkg/sentry/fsimpl/proc/task_files.go
+++ b/pkg/sentry/fsimpl/proc/task_files.go
@@ -303,6 +303,7 @@ type idMapData struct {
 }
 
 var _ dynamicInode = (*idMapData)(nil)
+var _ vfs.WritableDynamicBytesSource = (*idMapData)(nil)
 
 // Generate implements vfs.WritableDynamicBytesSource.Generate.
 func (d *idMapData) Generate(ctx context.Context, buf *bytes.Buffer) error {
@@ -319,7 +320,7 @@ func (d *idMapData) Generate(ctx context.Context, buf *bytes.Buffer) error {
 }
 
 // Write implements vfs.WritableDynamicBytesSource.Write.
-func (d *idMapData) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
+func (d *idMapData) Write(ctx context.Context, _ *vfs.FileDescription, src usermem.IOSequence, offset int64) (int64, error) {
 	// "In addition, the number of bytes written to the file must be less than
 	// the system page size, and the write must be performed at the start of
 	// the file ..." - user_namespaces(7)
@@ -718,7 +719,7 @@ func (s *statusInode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs
 	if err := fd.vfsfd.Init(fd, opts.Flags, rp.Mount(), d.VFSDentry(), &vfs.FileDescriptionOptions{}); err != nil {
 		return nil, err
 	}
-	fd.SetDataSource(fd)
+	fd.DynamicBytesFileDescriptionImpl.Init(&fd.vfsfd, fd)
 	return &fd.vfsfd, nil
 }
 
@@ -863,7 +864,7 @@ func (o *oomScoreAdj) Generate(ctx context.Context, buf *bytes.Buffer) error {
 }
 
 // Write implements vfs.WritableDynamicBytesSource.Write.
-func (o *oomScoreAdj) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
+func (o *oomScoreAdj) Write(ctx context.Context, _ *vfs.FileDescription, src usermem.IOSequence, offset int64) (int64, error) {
 	if src.NumBytes() == 0 {
 		return 0, nil
 	}
diff --git a/pkg/sentry/fsimpl/proc/tasks_sys.go b/pkg/sentry/fsimpl/proc/tasks_sys.go
index 82e2857b3..69555ebb2 100644
--- a/pkg/sentry/fsimpl/proc/tasks_sys.go
+++ b/pkg/sentry/fsimpl/proc/tasks_sys.go
@@ -209,7 +209,7 @@ func (d *tcpSackData) Generate(ctx context.Context, buf *bytes.Buffer) error {
 }
 
 // Write implements vfs.WritableDynamicBytesSource.Write.
-func (d *tcpSackData) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
+func (d *tcpSackData) Write(ctx context.Context, _ *vfs.FileDescription, src usermem.IOSequence, offset int64) (int64, error) {
 	if offset != 0 {
 		// No need to handle partial writes thus far.
 		return 0, linuxerr.EINVAL
@@ -257,7 +257,7 @@ func (d *tcpRecoveryData) Generate(ctx context.Context, buf *bytes.Buffer) error
 }
 
 // Write implements vfs.WritableDynamicBytesSource.Write.
-func (d *tcpRecoveryData) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
+func (d *tcpRecoveryData) Write(ctx context.Context, _ *vfs.FileDescription, src usermem.IOSequence, offset int64) (int64, error) {
 	if offset != 0 {
 		// No need to handle partial writes thus far.
 		return 0, linuxerr.EINVAL
@@ -311,7 +311,7 @@ func (d *tcpMemData) Generate(ctx context.Context, buf *bytes.Buffer) error {
 }
 
 // Write implements vfs.WritableDynamicBytesSource.Write.
-func (d *tcpMemData) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
+func (d *tcpMemData) Write(ctx context.Context, _ *vfs.FileDescription, src usermem.IOSequence, offset int64) (int64, error) {
 	if offset != 0 {
 		// No need to handle partial writes thus far.
 		return 0, linuxerr.EINVAL
@@ -396,7 +396,7 @@ func (ipf *ipForwarding) Generate(ctx context.Context, buf *bytes.Buffer) error
 }
 
 // Write implements vfs.WritableDynamicBytesSource.Write.
-func (ipf *ipForwarding) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
+func (ipf *ipForwarding) Write(ctx context.Context, _ *vfs.FileDescription, src usermem.IOSequence, offset int64) (int64, error) {
 	if offset != 0 {
 		// No need to handle partial writes thus far.
 		return 0, linuxerr.EINVAL
@@ -449,7 +449,7 @@ func (pr *portRange) Generate(ctx context.Context, buf *bytes.Buffer) error {
 }
 
 // Write implements vfs.WritableDynamicBytesSource.Write.
-func (pr *portRange) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
+func (pr *portRange) Write(ctx context.Context, _ *vfs.FileDescription, src usermem.IOSequence, offset int64) (int64, error) {
 	if offset != 0 {
 		// No need to handle partial writes thus far.
 		return 0, linuxerr.EINVAL
diff --git a/pkg/sentry/fsimpl/proc/tasks_sys_test.go b/pkg/sentry/fsimpl/proc/tasks_sys_test.go
index 19b012f7d..96b3d8b7a 100644
--- a/pkg/sentry/fsimpl/proc/tasks_sys_test.go
+++ b/pkg/sentry/fsimpl/proc/tasks_sys_test.go
@@ -136,7 +136,7 @@ func TestConfigureIPForwarding(t *testing.T) {
 		// Write the values.
 		src := usermem.BytesIOSequence([]byte(c.str))
-		if n, err := file.Write(ctx, src, 0); n != int64(len(c.str)) || err != nil {
+		if n, err := file.Write(ctx, nil, src, 0); n != int64(len(c.str)) || err != nil {
 			t.Errorf("file.Write(ctx, nil, %q, 0) = (%d, %v); want (%d, nil)", c.str, n, err, len(c.str))
 		}
diff --git a/pkg/sentry/fsimpl/proc/yama.go b/pkg/sentry/fsimpl/proc/yama.go
index 7240563d7..072ed1cbf 100644
--- a/pkg/sentry/fsimpl/proc/yama.go
+++ b/pkg/sentry/fsimpl/proc/yama.go
@@ -26,6 +26,7 @@ import (
 	"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
 	"gvisor.dev/gvisor/pkg/sentry/kernel"
 	"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+	"gvisor.dev/gvisor/pkg/sentry/vfs"
 	"gvisor.dev/gvisor/pkg/usermem"
 )
 
@@ -46,6 +47,8 @@ type yamaPtraceScope struct {
 	level *int32
 }
 
+var _ vfs.WritableDynamicBytesSource = (*yamaPtraceScope)(nil)
+
 // Generate implements vfs.DynamicBytesSource.Generate.
 func (s *yamaPtraceScope) Generate(ctx context.Context, buf *bytes.Buffer) error {
 	_, err := fmt.Fprintf(buf, "%d\n", atomic.LoadInt32(s.level))
@@ -53,7 +56,7 @@ func (s *yamaPtraceScope) Generate(ctx context.Context, buf *bytes.Buffer) error
 }
 
 // Write implements vfs.WritableDynamicBytesSource.Write.
-func (s *yamaPtraceScope) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
+func (s *yamaPtraceScope) Write(ctx context.Context, _ *vfs.FileDescription, src usermem.IOSequence, offset int64) (int64, error) {
 	if offset != 0 {
 		// Ignore partial writes.
 		return 0, linuxerr.EINVAL
diff --git a/pkg/sentry/kernel/cgroup.go b/pkg/sentry/kernel/cgroup.go
index a0e291f58..071b204c2 100644
--- a/pkg/sentry/kernel/cgroup.go
+++ b/pkg/sentry/kernel/cgroup.go
@@ -86,11 +86,46 @@ func (c *Cgroup) HierarchyID() uint32 {
 	return c.Controllers()[0].HierarchyID()
 }
 
+// CgroupMigrationContext represents an in-flight cgroup migration for
+// a single task.
+type CgroupMigrationContext struct {
+	src Cgroup
+	dst Cgroup
+	t   *Task
+}
+
+// Abort cancels a migration.
+func (ctx *CgroupMigrationContext) Abort() {
+	ctx.dst.AbortMigrate(ctx.t, &ctx.src)
+}
+
+// Commit completes a migration.
+func (ctx *CgroupMigrationContext) Commit() {
+	ctx.dst.CommitMigrate(ctx.t, &ctx.src)
+}
+
 // CgroupImpl is the common interface to cgroups.
 type CgroupImpl interface {
+	// Controllers lists the controllers associated with this cgroup.
 	Controllers() []CgroupController
+
+	// Enter moves t into this cgroup.
 	Enter(t *Task)
+
+	// Leave moves t out of this cgroup.
 	Leave(t *Task)
+
+	// PrepareMigrate initiates a migration of t from src to this cgroup. See
+	// cgroupfs.controller.PrepareMigrate.
+	PrepareMigrate(t *Task, src *Cgroup) error
+
+	// CommitMigrate completes an in-flight migration. See
+	// cgroupfs.controller.CommitMigrate.
+	CommitMigrate(t *Task, src *Cgroup)
+
+	// AbortMigrate cancels an in-flight migration. See
+	// cgroupfs.controller.AbortMigrate.
+	AbortMigrate(t *Task, src *Cgroup)
 }
 
 // hierarchy represents a cgroupfs filesystem instance, with a unique set of
diff --git a/pkg/sentry/kernel/task_cgroup.go b/pkg/sentry/kernel/task_cgroup.go
index 828b90014..68e8a3cfa 100644
--- a/pkg/sentry/kernel/task_cgroup.go
+++ b/pkg/sentry/kernel/task_cgroup.go
@@ -58,15 +58,9 @@ func (t *Task) EnterCgroup(c Cgroup) error {
 	defer t.mu.Unlock()
 
 	for oldCG, _ := range t.cgroups {
-		for _, oldCtl := range oldCG.Controllers() {
-			if _, ok := newControllers[oldCtl.Type()]; ok {
-				// Already in a cgroup with the same controller as one of the
-				// new ones. Requires migration between cgroups.
-				//
-				// TODO(b/183137098): Implement cgroup migration.
-				log.Warningf("Cgroup migration is not implemented")
-				return linuxerr.EBUSY
-			}
-		}
+		if oldCG.HierarchyID() == c.HierarchyID() {
+			log.Warningf("Cannot enter new cgroup %v due to conflicting controllers. Try migrating instead?", c)
+			return linuxerr.EBUSY
+		}
 	}
 
@@ -107,6 +101,82 @@ func (t *Task) leaveCgroupLocked(c Cgroup) {
 	c.decRef()
 }
 
+// +checklocks:t.mu
+func (t *Task) findCgroupWithMatchingHierarchyLocked(other Cgroup) (Cgroup, bool) {
+	for c, _ := range t.cgroups {
+		if c.HierarchyID() != other.HierarchyID() {
+			continue
+		}
+		return c, true
+	}
+	return Cgroup{}, false
+}
+
+// CgroupPrepareMigrate starts a cgroup migration for this task to dst. The
+// migration must be completed through the returned context.
+func (t *Task) CgroupPrepareMigrate(dst Cgroup) (*CgroupMigrationContext, error) {
+	t.mu.Lock()
+	defer t.mu.Unlock()
+	src, found := t.findCgroupWithMatchingHierarchyLocked(dst)
+	if !found {
+		log.Warningf("Cannot migrate to cgroup %v since task %v not currently in target hierarchy %v", dst, t, dst.HierarchyID())
+		return nil, linuxerr.EINVAL
+	}
+	if err := dst.PrepareMigrate(t, &src); err != nil {
+		return nil, err
+	}
+	return &CgroupMigrationContext{
+		src: src,
+		dst: dst,
+		t:   t,
+	}, nil
+}
+
+// MigrateCgroup migrates all tasks in tg to the dst cgroup. Either all tasks
+// are migrated, or none are. Atomicity of migrations wrt cgroup membership
+// (i.e. a task can't switch cgroups mid-migration due to another migration) is
+// guaranteed because migrations are serialized by TaskSet.mu.
+func (tg *ThreadGroup) MigrateCgroup(dst Cgroup) error {
+	tg.pidns.owner.mu.RLock()
+	defer tg.pidns.owner.mu.RUnlock()
+
+	var ctxs []*CgroupMigrationContext
+
+	// Prepare migrations. On partial failure, abort.
+	for t := tg.tasks.Front(); t != nil; t = t.Next() {
+		ctx, err := t.CgroupPrepareMigrate(dst)
+		if err != nil {
+			// Rollback.
+			for _, ctx := range ctxs {
+				ctx.Abort()
+			}
+			return err
+		}
+		ctxs = append(ctxs, ctx)
+	}
+
+	// All migrations are now guaranteed to succeed.
+
+	for _, ctx := range ctxs {
+		ctx.Commit()
+	}
+
+	return nil
+}
+
+// MigrateCgroup migrates this task to the dst cgroup.
+func (t *Task) MigrateCgroup(dst Cgroup) error {
+	t.tg.pidns.owner.mu.RLock()
+	defer t.tg.pidns.owner.mu.RUnlock()
+
+	ctx, err := t.CgroupPrepareMigrate(dst)
+	if err != nil {
+		return err
+	}
+	ctx.Commit()
+	return nil
+}
+
 // taskCgroupEntry represents a line in /proc/<pid>/cgroup, and is used to
 // format a cgroup for display.
 type taskCgroupEntry struct {
diff --git a/pkg/sentry/vfs/file_description_impl_util.go b/pkg/sentry/vfs/file_description_impl_util.go
index 56e93274b..19b5d9515 100644
--- a/pkg/sentry/vfs/file_description_impl_util.go
+++ b/pkg/sentry/vfs/file_description_impl_util.go
@@ -251,7 +251,7 @@ type WritableDynamicBytesSource interface {
 	DynamicBytesSource
 
 	// Write sends writes to the source.
-	Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error)
+	Write(ctx context.Context, fd *FileDescription, src usermem.IOSequence, offset int64) (int64, error)
 }
 
 // DynamicBytesFileDescriptionImpl may be embedded by implementations of
@@ -262,11 +262,12 @@ type WritableDynamicBytesSource interface {
 // If data additionally implements WritableDynamicBytesSource, writes are
 // dispatched to the implementer. The source data is not automatically modified.
 //
-// DynamicBytesFileDescriptionImpl.SetDataSource() must be called before first
+// DynamicBytesFileDescriptionImpl.Init() must be called before first
 // use.
 //
 // +stateify savable
 type DynamicBytesFileDescriptionImpl struct {
+	vfsfd *FileDescription   // immutable
 	data  DynamicBytesSource // immutable
 	mu    sync.Mutex         `state:"nosave"` // protects the following fields
 	buf   bytes.Buffer       `state:".([]byte)"`
@@ -282,8 +283,9 @@ func (fd *DynamicBytesFileDescriptionImpl) loadBuf(p []byte) {
 	fd.buf.Write(p)
 }
 
-// SetDataSource must be called exactly once on fd before first use.
-func (fd *DynamicBytesFileDescriptionImpl) SetDataSource(data DynamicBytesSource) {
+// Init must be called before first use.
+func (fd *DynamicBytesFileDescriptionImpl) Init(vfsfd *FileDescription, data DynamicBytesSource) {
+	fd.vfsfd = vfsfd
 	fd.data = data
 }
 
@@ -376,7 +378,7 @@ func (fd *DynamicBytesFileDescriptionImpl) pwriteLocked(ctx context.Context, src
 	if !ok {
 		return 0, linuxerr.EIO
 	}
-	n, err := writable.Write(ctx, src, offset)
+	n, err := writable.Write(ctx, fd.vfsfd, src, offset)
 	if err != nil {
 		return 0, err
 	}
diff --git a/pkg/sentry/vfs/file_description_impl_util_test.go b/pkg/sentry/vfs/file_description_impl_util_test.go
index e34a8c11b..5b6acefb8 100644
--- a/pkg/sentry/vfs/file_description_impl_util_test.go
+++ b/pkg/sentry/vfs/file_description_impl_util_test.go
@@ -61,7 +61,7 @@ func (d *storeData) Generate(ctx context.Context, buf *bytes.Buffer) error {
 }
 
 // Generate implements WritableDynamicBytesSource.
-func (d *storeData) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
+func (d *storeData) Write(ctx context.Context, _ *FileDescription, src usermem.IOSequence, offset int64) (int64, error) {
 	buf := make([]byte, src.NumBytes())
 	n, err := src.CopyIn(ctx, buf)
 	if err != nil {
@@ -84,9 +84,9 @@ func newTestFD(ctx context.Context, vfsObj *VirtualFilesystem, statusFlags uint3
 	vd := vfsObj.NewAnonVirtualDentry("genCountFD")
 	defer vd.DecRef(ctx)
 	var fd testFD
-	fd.vfsfd.Init(&fd, statusFlags, vd.Mount(), vd.Dentry(), &FileDescriptionOptions{})
-	fd.DynamicBytesFileDescriptionImpl.SetDataSource(data)
-	return &fd.vfsfd
+	fd.fileDescription.vfsfd.Init(&fd, statusFlags, vd.Mount(), vd.Dentry(), &FileDescriptionOptions{})
+	fd.DynamicBytesFileDescriptionImpl.Init(&fd.fileDescription.vfsfd, data)
+	return &fd.fileDescription.vfsfd
 }
 
 // Release implements FileDescriptionImpl.Release.
diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD
index bd2a307ed..500980de4 100644
--- a/test/syscalls/linux/BUILD
+++ b/test/syscalls/linux/BUILD
@@ -4497,6 +4497,7 @@ cc_binary(
         "//test/util:fs_util",
         "//test/util:mount_util",
         "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/synchronization",
        gtest,
        "//test/util:cleanup",
        "//test/util:posix_error",
diff --git a/test/syscalls/linux/cgroup.cc b/test/syscalls/linux/cgroup.cc
index 278c3c734..08ce85fd4 100644
--- a/test/syscalls/linux/cgroup.cc
+++ b/test/syscalls/linux/cgroup.cc
@@ -23,6 +23,7 @@
 #include "absl/container/flat_hash_map.h"
 #include "absl/container/flat_hash_set.h"
 #include "absl/strings/str_split.h"
+#include "absl/synchronization/notification.h"
 #include "test/util/capability_util.h"
 #include "test/util/cgroup_util.h"
 #include "test/util/cleanup.h"
@@ -53,6 +54,39 @@ bool CgroupsAvailable() {
         TEST_CHECK_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN));
 }
 
+// NoopThreads spawns a set of threads that do nothing until they're asked to
+// exit. Useful for testing functionality that requires a process with multiple
+// threads.
+class NoopThreads {
+ public:
+  NoopThreads(int count) {
+    auto noop = [this]() { exit_.WaitForNotification(); };
+
+    for (int i = 0; i < count; ++i) {
+      threads_.emplace_back(noop);
+    }
+  }
+
+  ~NoopThreads() { Join(); }
+
+  void Join() {
+    if (joined_) {
+      return;
+    }
+
+    joined_ = true;
+    exit_.Notify();
+    for (auto& thread : threads_) {
+      thread.Join();
+    }
+  }
+
+ private:
+  std::list<ScopedThread> threads_;
+  absl::Notification exit_;
+  bool joined_ = false;
+};
+
 TEST(Cgroup, MountSucceeds) {
   SKIP_IF(!CgroupsAvailable());
 
@@ -326,6 +360,49 @@ TEST(Cgroup, SubcontainersHaveIndependentState) {
   EXPECT_THAT(c.ReadIntegerControlFile("job.id"), IsPosixErrorOkAndHolds(5678));
 }
 
+TEST(Cgroup, MigrateToSubcontainer) {
+  SKIP_IF(!CgroupsAvailable());
+  Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()));
+  Cgroup c = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs(""));
+  Cgroup child = ASSERT_NO_ERRNO_AND_VALUE(c.CreateChild("child1"));
+
+  // Initially, the test process should be in the root cgroup c.
+  EXPECT_NO_ERRNO(c.ContainsCallingProcess());
+
+  pid_t pid = getpid();
+
+  EXPECT_NO_ERRNO(child.Enter(pid));
+
+  // After migration, child should contain the test process, and c should
+  // not.
+  EXPECT_NO_ERRNO(child.ContainsCallingProcess());
+  auto procs = ASSERT_NO_ERRNO_AND_VALUE(c.Procs());
+  EXPECT_FALSE(procs.contains(pid));
+}
+
+TEST(Cgroup, MigrateToSubcontainerThread) {
+  SKIP_IF(!CgroupsAvailable());
+  Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()));
+  Cgroup c = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs(""));
+  Cgroup child = ASSERT_NO_ERRNO_AND_VALUE(c.CreateChild("child1"));
+
+  // Ensure there are some threads for this process.
+  NoopThreads threads(10);
+
+  // Initially, the test thread should be in the root cgroup c.
+  EXPECT_NO_ERRNO(c.ContainsCallingThread());
+
+  const pid_t tid = syscall(SYS_gettid);
+
+  EXPECT_NO_ERRNO(child.EnterThread(tid));
+
+  // After migration, child should contain the test thread, and c should
+  // not.
+  EXPECT_NO_ERRNO(child.ContainsCallingThread());
+  auto tasks = ASSERT_NO_ERRNO_AND_VALUE(c.Tasks());
+  EXPECT_FALSE(tasks.contains(tid));
+}
+
 TEST(MemoryCgroup, MemoryUsageInBytes) {
   SKIP_IF(!CgroupsAvailable());
 
diff --git a/test/util/cgroup_util.cc b/test/util/cgroup_util.cc
index 0308c2153..4a88e0b1b 100644
--- a/test/util/cgroup_util.cc
+++ b/test/util/cgroup_util.cc
@@ -106,6 +106,24 @@ PosixError Cgroup::ContainsCallingProcess() const {
   return NoError();
 }
 
+PosixError Cgroup::ContainsCallingThread() const {
+  ASSIGN_OR_RETURN_ERRNO(const absl::flat_hash_set<pid_t> tasks, Tasks());
+  const pid_t tid = syscall(SYS_gettid);
+  if (!tasks.contains(tid)) {
+    return PosixError(ENOENT,
+                      absl::StrFormat("Cgroup doesn't contain task %d", tid));
+  }
+  return NoError();
+}
+
+PosixError Cgroup::Enter(pid_t pid) const {
+  return WriteIntegerControlFile("cgroup.procs", static_cast<int64_t>(pid));
+}
+
+PosixError Cgroup::EnterThread(pid_t pid) const {
+  return WriteIntegerControlFile("tasks", static_cast<int64_t>(pid));
+}
+
 PosixErrorOr<absl::flat_hash_set<pid_t>> Cgroup::ParsePIDList(
     absl::string_view data) const {
   absl::flat_hash_set<pid_t> res;
diff --git a/test/util/cgroup_util.h b/test/util/cgroup_util.h
index ccc7219e3..2781c0470 100644
--- a/test/util/cgroup_util.h
+++ b/test/util/cgroup_util.h
@@ -74,8 +74,19 @@ class Cgroup {
   PosixErrorOr<absl::flat_hash_set<pid_t>> Tasks() const;
 
   // ContainsCallingProcess checks whether the calling process is part of the
  // cgroup.
  PosixError ContainsCallingProcess() const;

+  // ContainsCallingThread checks whether the calling thread is part of the
+  // cgroup.
+  PosixError ContainsCallingThread() const;
+
+  // Moves process with the specified pid to this cgroup.
+  PosixError Enter(pid_t pid) const;
+
+  // Moves thread with the specified pid to this cgroup.
+  PosixError EnterThread(pid_t pid) const;
+
  private:
   PosixErrorOr<absl::flat_hash_set<pid_t>> ParsePIDList(
      absl::string_view data) const;
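
The controller-level hooks added in base.go are trivially satisfied by controllerNoopMigrate for today's stateless controllers. For a controller that does track per-task state, a minimal sketch of the intended division of labor might look like the following. This is illustrative only and not part of the change: chargeController and its charges map are hypothetical, the surrounding cgroupfs package context (controllerCommon, kernel, linuxerr, sync) is assumed, and AddControlFiles plumbing is omitted.

// Hypothetical stateful controller, assumed to live in the cgroupfs package.
type chargeController struct {
	controllerCommon

	mu      sync.Mutex
	charges map[*kernel.Task]int64 // per-task accounting state
}

// PrepareMigrate implements controller.PrepareMigrate. It only validates the
// migration; no state is moved until CommitMigrate.
func (c *chargeController) PrepareMigrate(t *kernel.Task, src controller) error {
	if _, ok := src.(*chargeController); !ok {
		return linuxerr.EINVAL // Refuse migration from an unrelated controller type.
	}
	return nil
}

// CommitMigrate implements controller.CommitMigrate, moving t's charge from
// src to this controller.
func (c *chargeController) CommitMigrate(t *kernel.Task, src controller) {
	srcC := src.(*chargeController)
	srcC.mu.Lock()
	charge := srcC.charges[t]
	delete(srcC.charges, t)
	srcC.mu.Unlock()

	c.mu.Lock()
	c.charges[t] = charge
	c.mu.Unlock()
}

// AbortMigrate implements controller.AbortMigrate. PrepareMigrate didn't
// mutate anything, so there is nothing to undo.
func (c *chargeController) AbortMigrate(t *kernel.Task, src controller) {}

The property cgroupInode.PrepareMigrate relies on is that a controller's PrepareMigrate is side-effect free (or fully undoable via AbortMigrate), so a failure while preparing one controller can roll back the ones already prepared.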
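ThreadGroup.MigrateCgroup in task_cgroup.go gets its all-or-nothing behavior from the prepare/commit/abort split: every task is prepared first, and commits only happen once no preparation can fail. The same control flow, reduced to a self-contained sketch; the op, prepare, and migrateAll names are illustrative and not from the patch.

package main

import (
	"errors"
	"fmt"
)

// op stands in for CgroupMigrationContext: once prepared, it must be either
// committed or aborted.
type op struct{ name string }

func (o op) Commit() { fmt.Println("commit", o.name) }
func (o op) Abort()  { fmt.Println("abort", o.name) }

// prepare stands in for Task.CgroupPrepareMigrate and may fail.
func prepare(name string, fail bool) (op, error) {
	if fail {
		return op{}, errors.New("prepare failed for " + name)
	}
	return op{name: name}, nil
}

// migrateAll mirrors ThreadGroup.MigrateCgroup: prepare every member, and only
// commit once all preparations succeed; otherwise abort what was prepared and
// report the error.
func migrateAll(names []string, failAt string) error {
	var prepared []op
	for _, n := range names {
		o, err := prepare(n, n == failAt)
		if err != nil {
			for _, p := range prepared {
				p.Abort()
			}
			return err
		}
		prepared = append(prepared, o)
	}
	for _, p := range prepared {
		p.Commit()
	}
	return nil
}

func main() {
	fmt.Println(migrateAll([]string{"t1", "t2", "t3"}, "t3")) // aborts t1 and t2
	fmt.Println(migrateAll([]string{"t1", "t2", "t3"}, ""))   // commits all three
}

In the real code the equivalent of migrateAll runs under TaskSet.mu, which is what keeps another migration from racing in between prepare and commit.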
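The vfs.WritableDynamicBytesSource.Write signature now threads the *vfs.FileDescription through to the data source, which is what lets cgroupfs recover the cgroup dentry in cgroupInode.Cgroup. A minimal sketch of an implementation against the new signature; the flagData type and its package are hypothetical, and only the interface shape comes from the patch.

package somefs // hypothetical sentry filesystem package

import (
	"bytes"
	"fmt"
	"strconv"
	"strings"

	"gvisor.dev/gvisor/pkg/context"
	"gvisor.dev/gvisor/pkg/errors/linuxerr"
	"gvisor.dev/gvisor/pkg/sentry/vfs"
	"gvisor.dev/gvisor/pkg/usermem"
)

// flagData exposes a single integer as a read-write dynamic bytes file.
type flagData struct {
	value int64
}

var _ vfs.WritableDynamicBytesSource = (*flagData)(nil)

// Generate implements vfs.DynamicBytesSource.Generate.
func (d *flagData) Generate(ctx context.Context, buf *bytes.Buffer) error {
	_, err := fmt.Fprintf(buf, "%d\n", d.value)
	return err
}

// Write implements vfs.WritableDynamicBytesSource.Write. Note the new fd
// parameter; implementations that don't need it simply ignore it, as the files
// updated in this change do with `_ *vfs.FileDescription`.
func (d *flagData) Write(ctx context.Context, fd *vfs.FileDescription, src usermem.IOSequence, offset int64) (int64, error) {
	buf := make([]byte, src.NumBytes())
	n, err := src.CopyIn(ctx, buf)
	if err != nil {
		return int64(n), err
	}
	val, err := strconv.ParseInt(strings.TrimSpace(string(buf[:n])), 10, 64)
	if err != nil {
		return int64(n), linuxerr.EINVAL
	}
	d.value = val
	return int64(n), nil
}

A control file that does care about which file was written can follow the cgroupfs pattern and pull state off fd (for example via fd.Dentry()), which is exactly what cgroup.procs and tasks need in order to know which cgroup directory they belong to.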