diff --git a/runsc/cgroup/cgroup_v2.go b/runsc/cgroup/cgroup_v2.go index 94a81a3d15..13aef972dc 100644 --- a/runsc/cgroup/cgroup_v2.go +++ b/runsc/cgroup/cgroup_v2.go @@ -177,6 +177,17 @@ func (c *cgroupV2) Update(res *specs.LinuxResources) error { return fmt.Errorf("mandatory cgroup controller %q is missing for %q", controllerName, path) } } + // Override any values set above with the unified resource, if set. + if res != nil { + for k, v := range res.Unified { + if strings.Contains(k, "/") { + return fmt.Errorf("unified resource %q must be a file name (no slashes)", k) + } + if err := setValue(path, k, v); err != nil { + return fmt.Errorf("unable to set unified resource %q: %w", k, err) + } + } + } return nil } diff --git a/runsc/cgroup/cgroup_v2_test.go b/runsc/cgroup/cgroup_v2_test.go index dc07daeba3..f53a5bcb69 100644 --- a/runsc/cgroup/cgroup_v2_test.go +++ b/runsc/cgroup/cgroup_v2_test.go @@ -517,3 +517,96 @@ func TestUpdate(t *testing.T) { }) } } + +func TestUpdateUnified(t *testing.T) { + for _, tc := range []struct { + name string + resources *specs.LinuxResources + // pre-created files in the fake cgroup. + seedFiles []string + // expected file contents after Update() + wantFiles map[string]string + wantErrSub string + }{ + { + name: "sets arbitrary unified key", + resources: &specs.LinuxResources{ + Unified: map[string]string{ + "memory.high": "1000000", + }, + }, + seedFiles: []string{"memory.high"}, + wantFiles: map[string]string{ + "memory.high": "1000000", + }, + }, + { + name: "unified overrides controller-set value", + resources: &specs.LinuxResources{ + Memory: &specs.LinuxMemory{ + Limit: int64Ptr(2048), + }, + Unified: map[string]string{ + "memory.max": "4096", + }, + }, + seedFiles: []string{"memory.max"}, + wantFiles: map[string]string{ + "memory.max": "4096", + }, + }, + { + name: "rejects keys containing slashes", + resources: &specs.LinuxResources{ + Unified: map[string]string{ + "foo/bar": "1", + }, + }, + wantErrSub: "must be a file name", + }, + } { + t.Run(tc.name, func(t *testing.T) { + dir, err := os.MkdirTemp(testutil.TmpDir(), "cgroup") + if err != nil { + t.Fatalf("error creating temporary directory: %v", err) + } + defer os.RemoveAll(dir) + + cg := &cgroupV2{ + Mountpoint: dir, + Path: "user.slice", + Controllers: mandatoryControllers, + } + cgPath := filepath.Join(cg.Mountpoint, cg.Path) + if err := os.MkdirAll(cgPath, 0o777); err != nil { + t.Fatalf("os.MkdirAll(): %v", err) + } + for _, f := range tc.seedFiles { + if err := os.WriteFile(filepath.Join(cgPath, f), nil, 0o666); err != nil { + t.Fatalf("os.WriteFile(%q): %v", f, err) + } + } + + err = cg.Update(tc.resources) + if tc.wantErrSub != "" { + if err == nil || !strings.Contains(err.Error(), tc.wantErrSub) { + t.Fatalf("Update() error = %v, want substring %q", err, tc.wantErrSub) + } + return + } + if err != nil { + t.Fatalf("Update(): %v", err) + } + + for name, want := range tc.wantFiles { + got, err := os.ReadFile(filepath.Join(cgPath, name)) + if err != nil { + t.Fatalf("ReadFile(%q): %v", name, err) + } + if string(got) != want { + t.Errorf("file %q = %q, want %q", name, string(got), want) + } + } + }) + } +}