diff --git a/cgroup1/memory.go b/cgroup1/memory.go index caf5e9a7..52fe6907 100644 --- a/cgroup1/memory.go +++ b/cgroup1/memory.go @@ -454,6 +454,9 @@ func getOomControlValue(mem *specs.LinuxMemory) *int64 { if mem.DisableOOMKiller != nil && *mem.DisableOOMKiller { i := int64(1) return &i + } else if mem.DisableOOMKiller != nil && !*mem.DisableOOMKiller { + i := int64(0) + return &i } return nil } diff --git a/cgroup1/memory_test.go b/cgroup1/memory_test.go index fd16a084..9f7391c3 100644 --- a/cgroup1/memory_test.go +++ b/cgroup1/memory_test.go @@ -24,6 +24,7 @@ import ( "testing" v1 "github.com/containerd/cgroups/v3/cgroup1/stats" + specs "github.com/opencontainers/runtime-spec/specs-go" ) const memoryData = `cache 1 @@ -286,3 +287,59 @@ func buildMemoryMetrics(t *testing.T, modules []string, metrics []string) string } return tmpRoot } + +func Test_getOomControlValue(t *testing.T) { + var ( + oneInt64 int64 = 1 + zeroInt64 int64 = 0 + trueBool bool = true + falseBool bool = false + ) + + type args struct { + mem *specs.LinuxMemory + } + tests := []struct { + name string + args args + want *int64 + }{ + { + name: "enable", + args: args{ + mem: &specs.LinuxMemory{ + DisableOOMKiller: &falseBool, + }, + }, + want: &zeroInt64, + }, + { + name: "disable", + args: args{ + mem: &specs.LinuxMemory{ + DisableOOMKiller: &trueBool, + }, + }, + want: &oneInt64, + }, + { + name: "nil", + args: args{ + mem: &specs.LinuxMemory{}, + }, + want: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := getOomControlValue(tt.args.mem) + if (got == nil || tt.want == nil) && got != tt.want { + t.Errorf("getOomControlValue() = %v, want %v", got, tt.want) + return + } + if !(got == nil || tt.want == nil) && *got != *tt.want { + t.Errorf("getOomControlValue() = %v, want %v", got, tt.want) + } + }) + } +}