Skip to content

Commit

Permalink
Introduce Kuberay RayJobs MultiKueue Adapter
Browse files Browse the repository at this point in the history
  • Loading branch information
mszadkow committed Dec 30, 2024
1 parent 4db76a9 commit d79b737
Show file tree
Hide file tree
Showing 4 changed files with 292 additions and 3 deletions.
4 changes: 3 additions & 1 deletion pkg/controller/jobframework/validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (

kfmpi "github.com/kubeflow/mpi-operator/pkg/apis/kubeflow/v2beta1"
kftraining "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1"
rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1"
batchv1 "k8s.io/api/batch/v1"
apivalidation "k8s.io/apimachinery/pkg/api/validation"
"k8s.io/apimachinery/pkg/util/sets"
Expand All @@ -47,7 +48,8 @@ var (
kftraining.SchemeGroupVersion.WithKind(kftraining.PaddleJobKind).String(),
kftraining.SchemeGroupVersion.WithKind(kftraining.PyTorchJobKind).String(),
kftraining.SchemeGroupVersion.WithKind(kftraining.XGBoostJobKind).String(),
kfmpi.SchemeGroupVersion.WithKind(kfmpi.Kind).String())
kfmpi.SchemeGroupVersion.WithKind(kfmpi.Kind).String(),
rayv1.SchemeGroupVersion.WithKind("RayJob").String())
)

// ValidateJobOnCreate encapsulates all GenericJob validations that must be performed on a Create operation
Expand Down
10 changes: 8 additions & 2 deletions pkg/controller/jobs/rayjob/rayjob_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/resource"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/types"
utilruntime "k8s.io/apimachinery/pkg/util/runtime"
Expand Down Expand Up @@ -55,12 +56,13 @@ func init() {
JobType: &rayv1.RayJob{},
AddToScheme: rayv1.AddToScheme,
IsManagingObjectsOwner: isRayJob,
MultiKueueAdapter: &multikueueAdapter{},
}))
}

// +kubebuilder:rbac:groups="",resources=events,verbs=create;watch;update
// +kubebuilder:rbac:groups=ray.io,resources=rayjobs,verbs=get;list;watch;update;patch
// +kubebuilder:rbac:groups=ray.io,resources=rayjobs/status,verbs=get;update
// +kubebuilder:rbac:groups=ray.io,resources=rayjobs/status,verbs=get;update;patch
// +kubebuilder:rbac:groups=ray.io,resources=rayjobs/finalizers,verbs=get;update
// +kubebuilder:rbac:groups=kueue.x-k8s.io,resources=workloads,verbs=get;list;watch;create;update;patch;delete
// +kubebuilder:rbac:groups=kueue.x-k8s.io,resources=workloads/status,verbs=get;update;patch
Expand All @@ -82,12 +84,16 @@ func (j *RayJob) Object() client.Object {
return (*rayv1.RayJob)(j)
}

func fromObject(obj runtime.Object) *RayJob {
return (*RayJob)(obj.(*rayv1.RayJob))
}

func (j *RayJob) IsSuspended() bool {
return j.Spec.Suspend
}

func (j *RayJob) IsActive() bool {
return j.Status.JobDeploymentStatus != rayv1.JobDeploymentStatusSuspended
return (j.Status.JobDeploymentStatus != rayv1.JobDeploymentStatusSuspended) && (j.Status.JobDeploymentStatus != rayv1.JobDeploymentStatusNew)
}

func (j *RayJob) Suspend() {
Expand Down
123 changes: 123 additions & 0 deletions pkg/controller/jobs/rayjob/rayjob_multikueue_adapter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
/*
Copyright 2024 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package rayjob

import (
"context"
"errors"
"fmt"

rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/types"
"k8s.io/klog/v2"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/client"

kueue "sigs.k8s.io/kueue/apis/kueue/v1beta1"
"sigs.k8s.io/kueue/pkg/controller/constants"
"sigs.k8s.io/kueue/pkg/controller/jobframework"
"sigs.k8s.io/kueue/pkg/util/api"
clientutil "sigs.k8s.io/kueue/pkg/util/client"
)

type multikueueAdapter struct{}

var _ jobframework.MultiKueueAdapter = (*multikueueAdapter)(nil)

func (b *multikueueAdapter) SyncJob(ctx context.Context, localClient client.Client, remoteClient client.Client, key types.NamespacedName, workloadName, origin string) error {
log := ctrl.LoggerFrom(ctx)

localJob := rayv1.RayJob{}
err := localClient.Get(ctx, key, &localJob)
if err != nil {
return err
}

remoteJob := rayv1.RayJob{}
err = remoteClient.Get(ctx, key, &remoteJob)
if client.IgnoreNotFound(err) != nil {
return err
}

// if the remote exists, just copy the status
if err == nil {
if fromObject(&localJob).IsSuspended() {
// Ensure the job is unsuspended before updating its status; otherwise, it will fail when patching the spec.
log.V(2).Info("Skipping the sync since the local job is still suspended")
return nil
}
return clientutil.PatchStatus(ctx, localClient, &localJob, func() (bool, error) {
localJob.Status = remoteJob.Status
return true, nil
})
}

remoteJob = rayv1.RayJob{
ObjectMeta: api.CloneObjectMetaForCreation(&localJob.ObjectMeta),
Spec: *localJob.Spec.DeepCopy(),
}

// add the prebuilt workload
if remoteJob.Labels == nil {
remoteJob.Labels = make(map[string]string, 2)
}
remoteJob.Labels[constants.PrebuiltWorkloadLabel] = workloadName
remoteJob.Labels[kueue.MultiKueueOriginLabel] = origin

return remoteClient.Create(ctx, &remoteJob)
}

func (b *multikueueAdapter) DeleteRemoteObject(ctx context.Context, remoteClient client.Client, key types.NamespacedName) error {
job := rayv1.RayJob{}
job.SetName(key.Name)
job.SetNamespace(key.Namespace)
return client.IgnoreNotFound(remoteClient.Delete(ctx, &job))
}

func (b *multikueueAdapter) KeepAdmissionCheckPending() bool {
return false
}

func (b *multikueueAdapter) IsJobManagedByKueue(ctx context.Context, c client.Client, key types.NamespacedName) (bool, string, error) {
return true, "", nil
}

func (b *multikueueAdapter) GVK() schema.GroupVersionKind {
return gvk
}

var _ jobframework.MultiKueueWatcher = (*multikueueAdapter)(nil)

func (*multikueueAdapter) GetEmptyList() client.ObjectList {
return &rayv1.RayJobList{}
}

func (*multikueueAdapter) WorkloadKeyFor(o runtime.Object) (types.NamespacedName, error) {
job, isJob := o.(*rayv1.RayJob)
if !isJob {
return types.NamespacedName{}, errors.New("not a rayjob")
}

prebuiltWl, hasPrebuiltWorkload := job.Labels[constants.PrebuiltWorkloadLabel]
if !hasPrebuiltWorkload {
return types.NamespacedName{}, fmt.Errorf("no prebuilt workload found for rayjob: %s", klog.KObj(job))
}

return types.NamespacedName{Name: prebuiltWl, Namespace: job.Namespace}, nil
}
158 changes: 158 additions & 0 deletions pkg/controller/jobs/rayjob/rayjob_multikueue_adapter_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
/*
Copyright 2024 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package rayjob

import (
"context"
"testing"

"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/types"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/client/interceptor"

kueue "sigs.k8s.io/kueue/apis/kueue/v1beta1"
"sigs.k8s.io/kueue/pkg/controller/constants"
"sigs.k8s.io/kueue/pkg/util/slices"
utiltesting "sigs.k8s.io/kueue/pkg/util/testing"
utiltestingrayjob "sigs.k8s.io/kueue/pkg/util/testingjobs/rayjob"
)

const (
TestNamespace = "ns"
)

func TestMultikueueAdapter(t *testing.T) {
objCheckOpts := []cmp.Option{
cmpopts.IgnoreFields(metav1.ObjectMeta{}, "ResourceVersion"),
cmpopts.EquateEmpty(),
}

rayJobBuilder := utiltestingrayjob.MakeJob("rayjob1", TestNamespace).Suspend(false)

cases := map[string]struct {
managersRayJobs []rayv1.RayJob
workerRayJobs []rayv1.RayJob

operation func(ctx context.Context, adapter *multikueueAdapter, managerClient, workerClient client.Client) error

wantError error
wantManagersRayJobs []rayv1.RayJob
wantWorkerRayJobs []rayv1.RayJob
}{
"sync creates missing remote rayjob": {
managersRayJobs: []rayv1.RayJob{
*rayJobBuilder.DeepCopy(),
},
operation: func(ctx context.Context, adapter *multikueueAdapter, managerClient, workerClient client.Client) error {
return adapter.SyncJob(ctx, managerClient, workerClient, types.NamespacedName{Name: "rayjob1", Namespace: TestNamespace}, "wl1", "origin1")
},

wantManagersRayJobs: []rayv1.RayJob{
*rayJobBuilder.DeepCopy(),
},
wantWorkerRayJobs: []rayv1.RayJob{
*rayJobBuilder.Clone().
Label(constants.PrebuiltWorkloadLabel, "wl1").
Label(kueue.MultiKueueOriginLabel, "origin1").
Obj(),
},
},
"sync status from remote rayjob": {
managersRayJobs: []rayv1.RayJob{
*rayJobBuilder.DeepCopy(),
},
workerRayJobs: []rayv1.RayJob{
*rayJobBuilder.Clone().
Label(constants.PrebuiltWorkloadLabel, "wl1").
Label(kueue.MultiKueueOriginLabel, "origin1").
JobDeploymentStatus(rayv1.JobDeploymentStatusComplete).
Obj(),
},
operation: func(ctx context.Context, adapter *multikueueAdapter, managerClient, workerClient client.Client) error {
return adapter.SyncJob(ctx, managerClient, workerClient, types.NamespacedName{Name: "rayjob1", Namespace: TestNamespace}, "wl1", "origin1")
},

wantManagersRayJobs: []rayv1.RayJob{
*rayJobBuilder.Clone().
JobDeploymentStatus(rayv1.JobDeploymentStatusComplete).
Obj(),
},
wantWorkerRayJobs: []rayv1.RayJob{
*rayJobBuilder.Clone().
Label(constants.PrebuiltWorkloadLabel, "wl1").
Label(kueue.MultiKueueOriginLabel, "origin1").
JobDeploymentStatus(rayv1.JobDeploymentStatusComplete).
Obj(),
},
},
"remote rayjob is deleted": {
workerRayJobs: []rayv1.RayJob{
*rayJobBuilder.Clone().
Label(constants.PrebuiltWorkloadLabel, "wl1").
Label(kueue.MultiKueueOriginLabel, "origin1").
Obj(),
},
operation: func(ctx context.Context, adapter *multikueueAdapter, managerClient, workerClient client.Client) error {
return adapter.DeleteRemoteObject(ctx, workerClient, types.NamespacedName{Name: "rayjob1", Namespace: TestNamespace})
},
},
}
for name, tc := range cases {
t.Run(name, func(t *testing.T) {
managerBuilder := utiltesting.NewClientBuilder(rayv1.AddToScheme).WithInterceptorFuncs(interceptor.Funcs{SubResourcePatch: utiltesting.TreatSSAAsStrategicMerge})
managerBuilder = managerBuilder.WithLists(&rayv1.RayJobList{Items: tc.managersRayJobs})
managerBuilder = managerBuilder.WithStatusSubresource(slices.Map(tc.managersRayJobs, func(w *rayv1.RayJob) client.Object { return w })...)
managerClient := managerBuilder.Build()

workerBuilder := utiltesting.NewClientBuilder(rayv1.AddToScheme).WithInterceptorFuncs(interceptor.Funcs{SubResourcePatch: utiltesting.TreatSSAAsStrategicMerge})
workerBuilder = workerBuilder.WithLists(&rayv1.RayJobList{Items: tc.workerRayJobs})
workerClient := workerBuilder.Build()

ctx, _ := utiltesting.ContextWithLog(t)

adapter := &multikueueAdapter{}

gotErr := tc.operation(ctx, adapter, managerClient, workerClient)

if diff := cmp.Diff(tc.wantError, gotErr, cmpopts.EquateErrors()); diff != "" {
t.Errorf("unexpected error (-want/+got):\n%s", diff)
}

gotManagersRayJobs := &rayv1.RayJobList{}
if err := managerClient.List(ctx, gotManagersRayJobs); err != nil {
t.Errorf("unexpected list manager's rayjobs error %s", err)
} else {
if diff := cmp.Diff(tc.wantManagersRayJobs, gotManagersRayJobs.Items, objCheckOpts...); diff != "" {
t.Errorf("unexpected manager's rayjobs (-want/+got):\n%s", diff)
}
}

gotWorkerRayJobs := &rayv1.RayJobList{}
if err := workerClient.List(ctx, gotWorkerRayJobs); err != nil {
t.Errorf("unexpected list worker's rayjobs error %s", err)
} else {
if diff := cmp.Diff(tc.wantWorkerRayJobs, gotWorkerRayJobs.Items, objCheckOpts...); diff != "" {
t.Errorf("unexpected worker's rayjobs (-want/+got):\n%s", diff)
}
}
})
}
}

0 comments on commit d79b737

Please sign in to comment.