From c0cd2920db22eb9ff317f729c8ff94cdc46c77dd Mon Sep 17 00:00:00 2001
From: Mykhailo Bobrovskyi
Date: Wed, 13 Nov 2024 11:18:16 +0200
Subject: [PATCH] Allow mutating queue name in StatefulSet Webhook.

---
 pkg/controller/jobs/pod/pod_controller.go    |  26 +-
 pkg/controller/jobs/pod/pod_webhook.go       |  13 -
 .../statefulset/statefulset_reconciler.go    |   2 +-
 .../jobs/statefulset/statefulset_webhook.go  |  54 ++--
 pkg/util/pod/pod.go                          |  39 +++
 test/e2e/singlecluster/statefulset_test.go   | 249 ++++++++++++++----
 6 files changed, 267 insertions(+), 116 deletions(-)

diff --git a/pkg/controller/jobs/pod/pod_controller.go b/pkg/controller/jobs/pod/pod_controller.go
index 67e6231470..d5203b845f 100644
--- a/pkg/controller/jobs/pod/pod_controller.go
+++ b/pkg/controller/jobs/pod/pod_controller.go
@@ -19,8 +19,6 @@ package pod
 import (
 	"cmp"
 	"context"
-	"crypto/sha256"
-	"encoding/json"
 	"errors"
 	"fmt"
 	"slices"
@@ -563,29 +561,7 @@ func getRoleHash(p corev1.Pod) (string, error) {
 	if roleHash, ok := p.Annotations[RoleHashAnnotation]; ok {
 		return roleHash, nil
 	}
-
-	shape := map[string]interface{}{
-		"spec": map[string]interface{}{
-			"initContainers":            containersShape(p.Spec.InitContainers),
-			"containers":                containersShape(p.Spec.Containers),
-			"nodeSelector":              p.Spec.NodeSelector,
-			"affinity":                  p.Spec.Affinity,
-			"tolerations":               p.Spec.Tolerations,
-			"runtimeClassName":          p.Spec.RuntimeClassName,
-			"priority":                  p.Spec.Priority,
-			"topologySpreadConstraints": p.Spec.TopologySpreadConstraints,
-			"overhead":                  p.Spec.Overhead,
-			"resourceClaims":            p.Spec.ResourceClaims,
-		},
-	}
-
-	shapeJSON, err := json.Marshal(shape)
-	if err != nil {
-		return "", err
-	}
-
-	// Trim hash to 8 characters and return
-	return fmt.Sprintf("%x", sha256.Sum256(shapeJSON))[:8], nil
+	return utilpod.GenerateShape(p.Spec)
 }
 
 // Load loads all pods in the group
diff --git a/pkg/controller/jobs/pod/pod_webhook.go b/pkg/controller/jobs/pod/pod_webhook.go
index 23b01a01ca..8eaf30ba9c 100644
--- a/pkg/controller/jobs/pod/pod_webhook.go
+++ b/pkg/controller/jobs/pod/pod_webhook.go
@@ -113,19 +113,6 @@ func getPodOptions(integrationOpts map[string]any) (*configapi.PodIntegrationOpt
 
 var _ admission.CustomDefaulter = &PodWebhook{}
 
-func containersShape(containers []corev1.Container) (result []map[string]interface{}) {
-	for _, c := range containers {
-		result = append(result, map[string]interface{}{
-			"resources": map[string]interface{}{
-				"requests": c.Resources.Requests,
-			},
-			"ports": c.Ports,
-		})
-	}
-
-	return result
-}
-
 // addRoleHash calculates the role hash and adds it to the pod's annotations
 func (p *Pod) addRoleHash() error {
 	if p.pod.Annotations == nil {
diff --git a/pkg/controller/jobs/statefulset/statefulset_reconciler.go b/pkg/controller/jobs/statefulset/statefulset_reconciler.go
index e6c39fce72..76be2814c8 100644
--- a/pkg/controller/jobs/statefulset/statefulset_reconciler.go
+++ b/pkg/controller/jobs/statefulset/statefulset_reconciler.go
@@ -67,7 +67,7 @@ func (r *Reconciler) Reconcile(ctx context.Context, req reconcile.Request) (reco
 func (r *Reconciler) fetchAndFinalizePods(ctx context.Context, namespace, statefulSetName string) error {
 	podList := &corev1.PodList{}
 	if err := r.client.List(ctx, podList, client.InNamespace(namespace), client.MatchingLabels{
-		pod.GroupNameLabel: GetWorkloadName(statefulSetName),
+		StatefulSetNameLabel: statefulSetName,
 	}); err != nil {
 		return err
 	}
diff --git a/pkg/controller/jobs/statefulset/statefulset_webhook.go b/pkg/controller/jobs/statefulset/statefulset_webhook.go
index 3121fa3ef6..bfecf147a4 100644
--- a/pkg/controller/jobs/statefulset/statefulset_webhook.go
+++ b/pkg/controller/jobs/statefulset/statefulset_webhook.go
@@ -34,6 +34,11 @@ import (
 	"sigs.k8s.io/kueue/pkg/controller/constants"
 	"sigs.k8s.io/kueue/pkg/controller/jobframework"
 	"sigs.k8s.io/kueue/pkg/controller/jobs/pod"
+	utilpod "sigs.k8s.io/kueue/pkg/util/pod"
+)
+
+const (
+	StatefulSetNameLabel = "kueue.x-k8s.io/statefulset-name"
 )
 
 type Webhook struct {
@@ -69,10 +74,15 @@ func (wh *Webhook) Default(ctx context.Context, obj runtime.Object) error {
 	}
 
 	if ss.Spec.Template.Labels == nil {
-		ss.Spec.Template.Labels = make(map[string]string, 2)
+		ss.Spec.Template.Labels = make(map[string]string, 3)
 	}
+	ss.Spec.Template.Labels[StatefulSetNameLabel] = ss.Name
 	ss.Spec.Template.Labels[constants.QueueLabel] = queueName
-	ss.Spec.Template.Labels[pod.GroupNameLabel] = GetWorkloadName(ss.Name)
+	groupName, err := GetWorkloadName(obj.(*appsv1.StatefulSet))
+	if err != nil {
+		return err
+	}
+	ss.Spec.Template.Labels[pod.GroupNameLabel] = groupName
 
 	if ss.Spec.Template.Annotations == nil {
 		ss.Spec.Template.Annotations = make(map[string]string, 4)
@@ -119,33 +129,12 @@ func (wh *Webhook) ValidateUpdate(ctx context.Context, oldObj, newObj runtime.Ob
 	oldQueueName := jobframework.QueueNameForObject(oldStatefulSet.Object())
 	newQueueName := jobframework.QueueNameForObject(newStatefulSet.Object())
 
-	allErrs := apivalidation.ValidateImmutableField(oldQueueName, newQueueName, queueNameLabelPath)
-	allErrs = append(allErrs, apivalidation.ValidateImmutableField(
-		newStatefulSet.Spec.Template.GetLabels()[constants.QueueLabel],
-		oldStatefulSet.Spec.Template.GetLabels()[constants.QueueLabel],
-		podSpecQueueNameLabelPath,
-	)...)
-	allErrs = append(allErrs, apivalidation.ValidateImmutableField(
-		newStatefulSet.GetLabels()[pod.GroupNameLabel],
-		oldStatefulSet.GetLabels()[pod.GroupNameLabel],
-		groupNameLabelPath,
-	)...)
-
-	oldReplicas := ptr.Deref(oldStatefulSet.Spec.Replicas, 1)
-	newReplicas := ptr.Deref(newStatefulSet.Spec.Replicas, 1)
-
-	// Allow only scale down to zero and scale up from zero.
-	// TODO(#3279): Support custom resizes later
-	if newReplicas != 0 && oldReplicas != 0 {
-		allErrs = append(allErrs, apivalidation.ValidateImmutableField(
-			newStatefulSet.Spec.Replicas,
-			oldStatefulSet.Spec.Replicas,
-			replicasPath,
-		)...)
-	}
+	allErrs := jobframework.ValidateQueueName(newStatefulSet.Object())
 
-	if oldReplicas == 0 && newReplicas > 0 && newStatefulSet.Status.Replicas > 0 {
-		allErrs = append(allErrs, field.Forbidden(replicasPath, "scaling down is still in progress"))
+	// Prevents updating the queue-name if at least one Pod is not suspended
+	// or if the queue-name has been deleted.
+	if oldStatefulSet.Status.ReadyReplicas > 0 || newQueueName == "" {
+		allErrs = append(allErrs, apivalidation.ValidateImmutableField(oldQueueName, newQueueName, queueNameLabelPath)...)
 	}
 
 	return warnings, allErrs.ToAggregate()
@@ -155,7 +144,12 @@ func (wh *Webhook) ValidateDelete(context.Context, runtime.Object) (warnings adm
 	return nil, nil
 }
 
-func GetWorkloadName(statefulSetName string) string {
+func GetWorkloadName(sts *appsv1.StatefulSet) (string, error) {
+	shape, err := utilpod.GenerateShape(sts.Spec.Template.Spec)
+	if err != nil {
+		return "", err
+	}
+	ownerName := fmt.Sprintf("%s-%s", sts.Name, shape)
 	// Passing empty UID as it is not available before object creation
-	return jobframework.GetWorkloadNameForOwnerWithGVK(statefulSetName, "", gvk)
+	return jobframework.GetWorkloadNameForOwnerWithGVK(ownerName, "", gvk), nil
 }
diff --git a/pkg/util/pod/pod.go b/pkg/util/pod/pod.go
index c8d6b54d38..c573e0659e 100644
--- a/pkg/util/pod/pod.go
+++ b/pkg/util/pod/pod.go
@@ -17,6 +17,8 @@ limitations under the License.
 package pod
 
 import (
+	"crypto/sha256"
+	"encoding/json"
 	"errors"
 	"fmt"
 	"math"
@@ -104,3 +106,40 @@ func readUIntFromStringBelowBound(value string, bound int) (*int, error) {
 	}
 	return ptr.To(int(uintValue)), nil
 }
+
+func GenerateShape(podSpec corev1.PodSpec) (string, error) {
+	shape := map[string]interface{}{
+		"spec": map[string]interface{}{
+			"initContainers":            containersShape(podSpec.InitContainers),
+			"containers":                containersShape(podSpec.Containers),
+			"nodeSelector":              podSpec.NodeSelector,
+			"affinity":                  podSpec.Affinity,
+			"tolerations":               podSpec.Tolerations,
+			"runtimeClassName":          podSpec.RuntimeClassName,
+			"priority":                  podSpec.Priority,
+			"topologySpreadConstraints": podSpec.TopologySpreadConstraints,
+			"overhead":                  podSpec.Overhead,
+			"resourceClaims":            podSpec.ResourceClaims,
+		},
+	}
+
+	shapeJSON, err := json.Marshal(shape)
+	if err != nil {
+		return "", err
+	}
+
+	// Trim hash to 8 characters and return
+	return fmt.Sprintf("%x", sha256.Sum256(shapeJSON))[:8], nil
+}
+
+func containersShape(containers []corev1.Container) (result []map[string]interface{}) {
+	for _, c := range containers {
+		result = append(result, map[string]interface{}{
+			"resources": map[string]interface{}{
+				"requests": c.Resources.Requests,
+			},
+			"ports": c.Ports,
+		})
+	}
+	return result
+}
diff --git a/test/e2e/singlecluster/statefulset_test.go b/test/e2e/singlecluster/statefulset_test.go
index 44697adc85..fff62bcf6e 100644
--- a/test/e2e/singlecluster/statefulset_test.go
+++ b/test/e2e/singlecluster/statefulset_test.go
@@ -82,52 +82,64 @@ var _ = ginkgo.Describe("StatefulSet integration", func() {
 	})
 
 	ginkgo.When("StatefulSet created", func() {
-		ginkgo.It("should admit group that fits", func() {
+		ginkgo.It("should admit groups that fits with potentially conflicting StatefulSet", func() {
 			statefulSet := statefulsettesting.MakeStatefulSet("sts", ns.Name).
 				Image(util.E2eTestSleepImage, []string{"10m"}).
 				Request(corev1.ResourceCPU, "100m").
 				Replicas(3).
 				Queue(lq.Name).
 				Obj()
-			wlLookupKey := types.NamespacedName{Name: statefulset.GetWorkloadName(statefulSet.Name), Namespace: ns.Name}
-			gomega.Expect(k8sClient.Create(ctx, statefulSet)).To(gomega.Succeed())
-
-			gomega.Eventually(func(g gomega.Gomega) {
-				createdStatefulSet := &appsv1.StatefulSet{}
-				g.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(statefulSet), createdStatefulSet)).
-					To(gomega.Succeed())
-				g.Expect(createdStatefulSet.Status.ReadyReplicas).To(gomega.Equal(int32(3)))
-			}, util.LongTimeout, util.Interval).Should(gomega.Succeed())
+			ginkgo.By("Create StatefulSet", func() {
+				gomega.Expect(k8sClient.Create(ctx, statefulSet)).To(gomega.Succeed())
+			})
+
+			createdStatefulSet := &appsv1.StatefulSet{}
+			ginkgo.By("Waiting for replicas is ready", func() {
+				gomega.Eventually(func(g gomega.Gomega) {
+					g.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(statefulSet), createdStatefulSet)).
+						To(gomega.Succeed())
+					g.Expect(createdStatefulSet.Status.ReadyReplicas).To(gomega.Equal(int32(3)))
+				}, util.LongTimeout, util.Interval).Should(gomega.Succeed())
+			})
+
 			createdWorkload := &kueue.Workload{}
-			gomega.Expect(k8sClient.Get(ctx, wlLookupKey, createdWorkload)).To(gomega.Succeed())
-
-			ginkgo.By("Creating potentially conflicting stateful-set", func() {
-				conflictingStatefulSet := statefulsettesting.MakeStatefulSet("sts-conflict", ns.Name).
-					Image(util.E2eTestSleepImage, []string{"10m"}).
-					Request(corev1.ResourceCPU, "100m").
-					Replicas(1).
-					Queue(lq.Name).
-					Obj()
-				conflictingWlLookupKey := types.NamespacedName{
-					Name:      statefulset.GetWorkloadName(conflictingStatefulSet.Name),
-					Namespace: ns.Name,
-				}
+			ginkgo.By("Check the Workload is created", func() {
+				gomega.Expect(k8sClient.Get(ctx, getWorkloadKeyForStatefulSet(statefulSet), createdWorkload)).To(gomega.Succeed())
+			})
+
+			conflictingStatefulSet := statefulsettesting.MakeStatefulSet("sts-conflict", ns.Name).
+				Image(util.E2eTestSleepImage, []string{"10m"}).
+				Request(corev1.ResourceCPU, "100m").
+				Replicas(1).
+				Queue(lq.Name).
+				Obj()
+			ginkgo.By("Creating a potentially conflicting StatefulSet", func() {
 				gomega.Expect(k8sClient.Create(ctx, conflictingStatefulSet)).To(gomega.Succeed())
+			})
+
+			ginkgo.By("Waiting for replicas is ready in the conflicting StatefulSet", func() {
 				gomega.Eventually(func(g gomega.Gomega) {
-					createdStatefulSet := &appsv1.StatefulSet{}
 					g.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(conflictingStatefulSet), createdStatefulSet)).
 						To(gomega.Succeed())
 					g.Expect(createdStatefulSet.Status.ReadyReplicas).To(gomega.Equal(int32(1)))
 				}, util.LongTimeout, util.Interval).Should(gomega.Succeed())
-				conflictingWorkload := &kueue.Workload{}
-				gomega.Expect(k8sClient.Get(ctx, conflictingWlLookupKey, conflictingWorkload)).To(gomega.Succeed())
+			})
+
+			conflictingWorkload := &kueue.Workload{}
+			ginkgo.By("Check the Workload for the conflicting StatefulSet is created with a different name", func() {
+				gomega.Expect(k8sClient.Get(ctx, getWorkloadKeyForStatefulSet(conflictingStatefulSet), conflictingWorkload)).To(gomega.Succeed())
 				gomega.Expect(createdWorkload.Name).ToNot(gomega.Equal(conflictingWorkload.Name))
+			})
+
+			ginkgo.By("Check the conflicting Workload is deleted after the StatefulSet is deleted", func() {
 				util.ExpectObjectToBeDeleted(ctx, k8sClient, conflictingStatefulSet, true)
 				util.ExpectObjectToBeDeletedWithTimeout(ctx, k8sClient, conflictingWorkload, false, util.LongTimeout)
 			})
-			util.ExpectObjectToBeDeleted(ctx, k8sClient, statefulSet, true)
-			util.ExpectObjectToBeDeletedWithTimeout(ctx, k8sClient, createdWorkload, false, util.LongTimeout)
+			ginkgo.By("Check the Workload is deleted after the StatefulSet is deleted", func() {
+				util.ExpectObjectToBeDeleted(ctx, k8sClient, statefulSet, true)
+				util.ExpectObjectToBeDeletedWithTimeout(ctx, k8sClient, createdWorkload, false, util.LongTimeout)
+			})
 		})
 
 		ginkgo.It("should allow to update the PodTemplate in StatefulSet", func() {
@@ -187,15 +199,14 @@ var _ = ginkgo.Describe("StatefulSet integration", func() {
 				Replicas(3).
 				Queue(lq.Name).
 				Obj()
-			wlLookupKey := types.NamespacedName{Name: statefulset.GetWorkloadName(statefulSet.Name), Namespace: ns.Name}
 
 			ginkgo.By("Create StatefulSet", func() {
 				gomega.Expect(k8sClient.Create(ctx, statefulSet)).To(gomega.Succeed())
 			})
 
+			createdStatefulSet := &appsv1.StatefulSet{}
 			ginkgo.By("Waiting for replicas is ready", func() {
 				gomega.Eventually(func(g gomega.Gomega) {
-					createdStatefulSet := &appsv1.StatefulSet{}
 					g.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(statefulSet), createdStatefulSet)).To(gomega.Succeed())
 					g.Expect(createdStatefulSet.Status.ReadyReplicas).To(gomega.Equal(int32(3)))
 				}, util.LongTimeout, util.Interval).Should(gomega.Succeed())
 			})
@@ -203,12 +214,11 @@ var _ = ginkgo.Describe("StatefulSet integration", func() {
 			createdWorkload := &kueue.Workload{}
 			ginkgo.By("Check workload is created", func() {
-				gomega.Expect(k8sClient.Get(ctx, wlLookupKey, createdWorkload)).To(gomega.Succeed())
+				gomega.Expect(k8sClient.Get(ctx, getWorkloadKeyForStatefulSet(statefulSet), createdWorkload)).To(gomega.Succeed())
 			})
 
 			ginkgo.By("Scale down replicas to zero", func() {
 				gomega.Eventually(func(g gomega.Gomega) {
-					createdStatefulSet := &appsv1.StatefulSet{}
 					g.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(statefulSet), createdStatefulSet)).To(gomega.Succeed())
 					createdStatefulSet.Spec.Replicas = ptr.To[int32](0)
 					g.Expect(k8sClient.Update(ctx, createdStatefulSet)).To(gomega.Succeed())
@@ -217,7 +227,6 @@
 			ginkgo.By("Waiting for replicas is deleted", func() {
 				gomega.Eventually(func(g gomega.Gomega) {
-					createdStatefulSet := &appsv1.StatefulSet{}
 					g.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(statefulSet), createdStatefulSet)).To(gomega.Succeed())
 					g.Expect(createdStatefulSet.Status.ReadyReplicas).To(gomega.Equal(int32(0)))
 				}, util.LongTimeout, util.Interval).Should(gomega.Succeed())
 			})
@@ -235,15 +244,14 @@
 				Replicas(0).
 				Queue(lq.Name).
 				Obj()
-			wlLookupKey := types.NamespacedName{Name: statefulset.GetWorkloadName(statefulSet.Name), Namespace: ns.Name}
 
 			ginkgo.By("Create StatefulSet", func() {
 				gomega.Expect(k8sClient.Create(ctx, statefulSet)).To(gomega.Succeed())
 			})
 
+			createdStatefulSet := &appsv1.StatefulSet{}
 			ginkgo.By("Scale up replicas", func() {
 				gomega.Eventually(func(g gomega.Gomega) {
-					createdStatefulSet := &appsv1.StatefulSet{}
 					g.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(statefulSet), createdStatefulSet)).To(gomega.Succeed())
 					createdStatefulSet.Spec.Replicas = ptr.To[int32](3)
 					g.Expect(k8sClient.Update(ctx, createdStatefulSet)).To(gomega.Succeed())
@@ -252,34 +260,77 @@
 			ginkgo.By("Waiting for replicas is ready", func() {
 				gomega.Eventually(func(g gomega.Gomega) {
-					createdStatefulSet := &appsv1.StatefulSet{}
 					g.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(statefulSet), createdStatefulSet)).To(gomega.Succeed())
 					g.Expect(createdStatefulSet.Status.ReadyReplicas).To(gomega.Equal(int32(3)))
 				}, util.LongTimeout, util.Interval).Should(gomega.Succeed())
 			})
 
-			ginkgo.By("Check workload is created", func() {
-				createdWorkload := &kueue.Workload{}
-				gomega.Expect(k8sClient.Get(ctx, wlLookupKey, createdWorkload)).To(gomega.Succeed())
+			createdWorkload := &kueue.Workload{}
+			ginkgo.By("Check the Workload is created", func() {
+				gomega.Expect(k8sClient.Get(ctx, getWorkloadKeyForStatefulSet(statefulSet), createdWorkload)).To(gomega.Succeed())
+			})
+		})
+
+		ginkgo.It("should allow to scale up", func() {
+			statefulSet := statefulsettesting.MakeStatefulSet("sts", ns.Name).
+				Image(util.E2eTestSleepImage, []string{"10m"}).
+				Request(corev1.ResourceCPU, "100m").
+				Replicas(1).
+				Queue(lq.Name).
+				Obj()
+
+			ginkgo.By("Create StatefulSet", func() {
+				gomega.Expect(k8sClient.Create(ctx, statefulSet)).To(gomega.Succeed())
+			})
+
+			createdStatefulSet := &appsv1.StatefulSet{}
+			ginkgo.By("Waiting for replicas is ready", func() {
+				gomega.Eventually(func(g gomega.Gomega) {
+					g.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(statefulSet), createdStatefulSet)).To(gomega.Succeed())
+					g.Expect(createdStatefulSet.Status.ReadyReplicas).To(gomega.Equal(int32(1)))
+				}, util.LongTimeout, util.Interval).Should(gomega.Succeed())
+			})
+
+			createdWorkload := &kueue.Workload{}
+			ginkgo.By("Check the Workload is created", func() {
+				gomega.Expect(k8sClient.Get(ctx, getWorkloadKeyForStatefulSet(statefulSet), createdWorkload)).To(gomega.Succeed())
+			})
+
+			ginkgo.By("Scale up replicas", func() {
+				gomega.Eventually(func(g gomega.Gomega) {
+					g.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(statefulSet), createdStatefulSet)).To(gomega.Succeed())
+					createdStatefulSet.Spec.Replicas = ptr.To[int32](3)
+					g.Expect(k8sClient.Update(ctx, createdStatefulSet)).To(gomega.Succeed())
+				}, util.Timeout, util.Interval).Should(gomega.Succeed())
+			})
+
+			ginkgo.By("Waiting for replicas is ready", func() {
+				gomega.Eventually(func(g gomega.Gomega) {
+					g.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(statefulSet), createdStatefulSet)).To(gomega.Succeed())
+					g.Expect(createdStatefulSet.Status.ReadyReplicas).To(gomega.Equal(int32(3)))
+				}, util.LongTimeout, util.Interval).Should(gomega.Succeed())
+			})
+
+			ginkgo.By("Check the previous Workload is deleted", func() {
+				util.ExpectObjectToBeDeletedWithTimeout(ctx, k8sClient, createdWorkload, false, util.LongTimeout)
+			})
 		})
 
-		ginkgo.It("should allow to scale up after scale down to zero", func() {
+		ginkgo.It("should allow to scale up after partially scaled down", func() {
 			statefulSet := statefulsettesting.MakeStatefulSet("sts", ns.Name).
 				Image(util.E2eTestSleepImage, []string{"10m"}).
 				Request(corev1.ResourceCPU, "100m").
 				Replicas(3).
 				Queue(lq.Name).
 				Obj()
-			wlLookupKey := types.NamespacedName{Name: statefulset.GetWorkloadName(statefulSet.Name), Namespace: ns.Name}
 
 			ginkgo.By("Create StatefulSet", func() {
 				gomega.Expect(k8sClient.Create(ctx, statefulSet)).To(gomega.Succeed())
 			})
 
+			createdStatefulSet := &appsv1.StatefulSet{}
 			ginkgo.By("Waiting for replicas is ready", func() {
 				gomega.Eventually(func(g gomega.Gomega) {
-					createdStatefulSet := &appsv1.StatefulSet{}
 					g.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(statefulSet), createdStatefulSet)).To(gomega.Succeed())
 					g.Expect(createdStatefulSet.Status.ReadyReplicas).To(gomega.Equal(int32(3)))
 				}, util.LongTimeout, util.Interval).Should(gomega.Succeed())
@@ -287,12 +338,11 @@
 			createdWorkload := &kueue.Workload{}
 			ginkgo.By("Check workload is created", func() {
-				gomega.Expect(k8sClient.Get(ctx, wlLookupKey, createdWorkload)).To(gomega.Succeed())
+				gomega.Expect(k8sClient.Get(ctx, getWorkloadKeyForStatefulSet(statefulSet), createdWorkload)).To(gomega.Succeed())
 			})
 
 			ginkgo.By("Scale down replicas to zero", func() {
 				gomega.Eventually(func(g gomega.Gomega) {
-					createdStatefulSet := &appsv1.StatefulSet{}
 					g.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(statefulSet), createdStatefulSet)).To(gomega.Succeed())
 					createdStatefulSet.Spec.Replicas = ptr.To[int32](0)
 					g.Expect(k8sClient.Update(ctx, createdStatefulSet)).To(gomega.Succeed())
@@ -301,7 +351,6 @@
 			ginkgo.By("Wait for ReadyReplicas < 3", func() {
 				gomega.Eventually(func(g gomega.Gomega) {
-					createdStatefulSet := &appsv1.StatefulSet{}
 					g.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(statefulSet), createdStatefulSet)).To(gomega.Succeed())
 					g.Expect(createdStatefulSet.Status.ReadyReplicas).To(gomega.BeNumerically("<", 3))
 					g.Expect(k8sClient.Update(ctx, createdStatefulSet)).To(gomega.Succeed())
@@ -310,7 +359,6 @@
 			ginkgo.By("Scale up replicas to zero - retry as it may not be possible immediately", func() {
 				gomega.Eventually(func(g gomega.Gomega) {
-					createdStatefulSet := &appsv1.StatefulSet{}
 					g.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(statefulSet), createdStatefulSet)).To(gomega.Succeed())
 					createdStatefulSet.Spec.Replicas = ptr.To[int32](3)
 					g.Expect(k8sClient.Update(ctx, createdStatefulSet)).To(gomega.Succeed())
@@ -319,11 +367,118 @@
 			ginkgo.By("Waiting for replicas is ready", func() {
 				gomega.Eventually(func(g gomega.Gomega) {
-					createdStatefulSet := &appsv1.StatefulSet{}
 					g.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(statefulSet), createdStatefulSet)).To(gomega.Succeed())
 					g.Expect(createdStatefulSet.Status.ReadyReplicas).To(gomega.Equal(int32(3)))
 				}, util.LongTimeout, util.Interval).Should(gomega.Succeed())
 			})
 		})
+
+		ginkgo.It("should allow to scale down", func() {
+			statefulSet := statefulsettesting.MakeStatefulSet("sts", ns.Name).
+				Image(util.E2eTestSleepImage, []string{"10m"}).
+				Request(corev1.ResourceCPU, "100m").
+				Replicas(3).
+				Queue(lq.Name).
+				Obj()
+
+			ginkgo.By("Create StatefulSet", func() {
+				gomega.Expect(k8sClient.Create(ctx, statefulSet)).To(gomega.Succeed())
+			})
+
+			createdStatefulSet := &appsv1.StatefulSet{}
+			ginkgo.By("Waiting for replicas is ready", func() {
+				gomega.Eventually(func(g gomega.Gomega) {
+					g.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(statefulSet), createdStatefulSet)).To(gomega.Succeed())
+					g.Expect(createdStatefulSet.Status.ReadyReplicas).To(gomega.Equal(int32(3)))
+				}, util.LongTimeout, util.Interval).Should(gomega.Succeed())
+			})
+
+			createdWorkload := &kueue.Workload{}
+			ginkgo.By("Check the Workload is created", func() {
+				gomega.Expect(k8sClient.Get(ctx, getWorkloadKeyForStatefulSet(statefulSet), createdWorkload)).To(gomega.Succeed())
+			})
+
+			ginkgo.By("Scale down replicas", func() {
+				gomega.Eventually(func(g gomega.Gomega) {
+					g.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(statefulSet), createdStatefulSet)).To(gomega.Succeed())
+					createdStatefulSet.Spec.Replicas = ptr.To[int32](1)
+					g.Expect(k8sClient.Update(ctx, createdStatefulSet)).To(gomega.Succeed())
+				}, util.Timeout, util.Interval).Should(gomega.Succeed())
+			})
+
+			ginkgo.By("Waiting for replicas is ready", func() {
+				gomega.Eventually(func(g gomega.Gomega) {
+					g.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(statefulSet), createdStatefulSet)).To(gomega.Succeed())
+					g.Expect(createdStatefulSet.Status.ReadyReplicas).To(gomega.Equal(int32(1)))
+				}, util.LongTimeout, util.Interval).Should(gomega.Succeed())
+			})
+
+			ginkgo.By("Check the previous Workload is deleted", func() {
+				util.ExpectObjectToBeDeletedWithTimeout(ctx, k8sClient, createdWorkload, false, util.LongTimeout)
+			})
+		})
+
+		ginkgo.It("should allow to scale down after partially scale up", func() {
+			statefulSet := statefulsettesting.MakeStatefulSet("sts", ns.Name).
+				Image(util.E2eTestSleepImage, []string{"10m"}).
+				Request(corev1.ResourceCPU, "100m").
+				Replicas(3).
+				Queue(lq.Name).
+				Obj()
+
+			ginkgo.By("Create StatefulSet", func() {
+				gomega.Expect(k8sClient.Create(ctx, statefulSet)).To(gomega.Succeed())
+			})
+
+			createdWorkload := &kueue.Workload{}
+			ginkgo.By("Check the Workload is created", func() {
+				gomega.Eventually(func(g gomega.Gomega) {
+					g.Expect(k8sClient.Get(ctx, getWorkloadKeyForStatefulSet(statefulSet), createdWorkload)).To(gomega.Succeed())
+				}, util.LongTimeout, util.Interval).Should(gomega.Succeed())
+			})
+
+			createdStatefulSet := &appsv1.StatefulSet{}
+			ginkgo.By("Wait for ReadyReplicas > 1", func() {
+				gomega.Eventually(func(g gomega.Gomega) {
+					g.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(statefulSet), createdStatefulSet)).To(gomega.Succeed())
+					g.Expect(createdStatefulSet.Status.ReadyReplicas).To(gomega.BeNumerically(">", 1))
+					g.Expect(k8sClient.Update(ctx, createdStatefulSet)).To(gomega.Succeed())
+				}, util.LongTimeout, util.Interval).Should(gomega.Succeed())
+			})
+
+			ginkgo.By("Scale down replicas", func() {
+				gomega.Eventually(func(g gomega.Gomega) {
+					g.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(statefulSet), createdStatefulSet)).To(gomega.Succeed())
+					createdStatefulSet.Spec.Replicas = ptr.To[int32](1)
+					g.Expect(k8sClient.Update(ctx, createdStatefulSet)).To(gomega.Succeed())
+				}, util.Timeout, util.Interval).Should(gomega.Succeed())
+			})
+
+			ginkgo.By("Waiting for replicas to be scaled down", func() {
+				gomega.Eventually(func(g gomega.Gomega) {
+					g.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(statefulSet), createdStatefulSet)).To(gomega.Succeed())
+					g.Expect(createdStatefulSet.Status.ReadyReplicas).To(gomega.Equal(int32(1)))
+				}, util.LongTimeout, util.Interval).Should(gomega.Succeed())
+			})
+
+			pods := &corev1.PodList{}
+			gomega.Expect(k8sClient.List(ctx, pods, client.InNamespace(ns.Namespace),
+				client.MatchingLabels(map[string]string{statefulset.StatefulSetNameLabel: statefulSet.Name}),
+			)).To(gomega.Succeed())
+			gomega.Expect(pods.Items).To(gomega.HaveLen(1))
+
+			ginkgo.By("Check the previous Workload is deleted", func() {
+				util.ExpectObjectToBeDeletedWithTimeout(ctx, k8sClient, createdWorkload, false, util.LongTimeout)
+			})
+		})
 	})
 })
+
+func getWorkloadKeyForStatefulSet(sts *appsv1.StatefulSet) types.NamespacedName {
+	workloadName, err := statefulset.GetWorkloadName(sts)
+	gomega.Expect(err).NotTo(gomega.HaveOccurred())
+	return types.NamespacedName{
+		Name:      workloadName,
+		Namespace: sts.Namespace,
+	}
+}
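Note: the sketch below is not part of the patch; it is a stdlib-only illustration of the two behaviours the diff introduces. The shape map mirrors the fields hashed by utilpod.GenerateShape, and queueNameUpdateAllowed restates the rule enforced in ValidateUpdate; the helper name queueNameUpdateAllowed and the literal values are illustrative assumptions, not part of the Kueue API.

package main

import (
	"crypto/sha256"
	"encoding/json"
	"fmt"
)

// queueNameUpdateAllowed restates the ValidateUpdate rule above: the queue-name
// label may only change while no Pod of the StatefulSet is ready, and it may
// never be cleared.
func queueNameUpdateAllowed(oldName, newName string, readyReplicas int32) bool {
	if oldName == newName {
		return true
	}
	return readyReplicas == 0 && newName != ""
}

func main() {
	// A reduced "shape" in the spirit of utilpod.GenerateShape: only
	// scheduling-relevant pod-template fields are hashed, so changing the
	// queue-name label alone does not change the hash.
	shape := map[string]interface{}{
		"spec": map[string]interface{}{
			"containers": []map[string]interface{}{
				{"resources": map[string]interface{}{"requests": map[string]string{"cpu": "100m"}}},
			},
		},
	}
	shapeJSON, err := json.Marshal(shape)
	if err != nil {
		panic(err)
	}
	// Same construction as GenerateShape: hex-encoded SHA-256 trimmed to 8 characters.
	hash := fmt.Sprintf("%x", sha256.Sum256(shapeJSON))[:8]

	// GetWorkloadName now passes "<StatefulSet name>-<hash>" as the owner name,
	// so a StatefulSet whose pod template changes maps to a new Workload name.
	fmt.Println("owner name:", "sts-"+hash)

	fmt.Println("queue-name mutable while suspended:", queueNameUpdateAllowed("lq-a", "lq-b", 0)) // true
	fmt.Println("queue-name mutable while running:", queueNameUpdateAllowed("lq-a", "lq-b", 3))   // false
}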