From da63f9352a4159cca70f40781fca23c84045cd28 Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Mon, 18 Mar 2024 10:37:08 -0700 Subject: [PATCH 01/29] feat: spec level validation --- api/v1alpha1/workspace_types.go | 4 ++-- api/v1alpha1/workspace_validation.go | 28 +++++++++++++++++++++++++++- 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/api/v1alpha1/workspace_types.go b/api/v1alpha1/workspace_types.go index 2f966f647..e451e1066 100644 --- a/api/v1alpha1/workspace_types.go +++ b/api/v1alpha1/workspace_types.go @@ -13,7 +13,7 @@ const ( ModelImageAccessModePrivate ModelImageAccessMode = "private" ) -// ResourceSpec desicribes the resource requirement of running the workload. +// ResourceSpec describes the resource requirement of running the workload. // If the number of nodes in the cluster that meet the InstanceType and // LabelSelector requirements is small than the Count, controller // will provision new nodes before deploying the workload. @@ -51,7 +51,7 @@ type PresetMeta struct { // AccessMode specifies whether the containerized model image is accessible via public registry // or private registry. This field defaults to "public" if not specified. // If this field is "private", user needs to provide the private image information in PresetOptions. 
- // +bebuilder:default:="public" + // +kubebuilder:default:="public" // +optional AccessMode ModelImageAccessMode `json:"accessMode,omitempty"` } diff --git a/api/v1alpha1/workspace_validation.go b/api/v1alpha1/workspace_validation.go index 16576f684..732526447 100644 --- a/api/v1alpha1/workspace_validation.go +++ b/api/v1alpha1/workspace_validation.go @@ -35,6 +35,7 @@ func (w *Workspace) Validate(ctx context.Context) (errs *apis.FieldError) { if base == nil { klog.InfoS("Validate creation", "workspace", fmt.Sprintf("%s/%s", w.Namespace, w.Name)) errs = errs.Also( + w.validateCreate().ViaField("spec"), w.Inference.validateCreate().ViaField("inference"), w.Resource.validateCreate(w.Inference).ViaField("resource"), ) @@ -42,6 +43,7 @@ func (w *Workspace) Validate(ctx context.Context) (errs *apis.FieldError) { klog.InfoS("Validate update", "workspace", fmt.Sprintf("%s/%s", w.Namespace, w.Name)) old := base.(*Workspace) errs = errs.Also( + w.validateUpdate(old).ViaField("spec"), w.Resource.validateUpdate(&old.Resource).ViaField("resource"), w.Inference.validateUpdate(&old.Inference).ViaField("inference"), ) @@ -49,6 +51,15 @@ func (w *Workspace) Validate(ctx context.Context) (errs *apis.FieldError) { return errs } +func (w *Workspace) validateCreate() (errs *apis.FieldError) { + inferenceSpecified := w.Inference.Preset != nil || w.Inference.Template != nil + tuningSpecified := w.Tuning.Input != nil + if inferenceSpecified != tuningSpecified { + return errs.Also(apis.ErrGeneric("Either Inference or Tuning must be specified, but not both", "")) + } + return errs +} + func (r *ResourceSpec) validateCreate(inference InferenceSpec) (errs *apis.FieldError) { var presetName string if inference.Preset != nil { @@ -96,6 +107,21 @@ func (r *ResourceSpec) validateCreate(inference InferenceSpec) (errs *apis.Field return errs } +func (w *Workspace) validateUpdate(old *Workspace) (errs *apis.FieldError) { + // Check inference specified + oldInferenceSpecified := 
old.Inference.Preset != nil || old.Inference.Template != nil + inferenceSpecified := w.Inference.Preset != nil || w.Inference.Template != nil + // Check tuning specified + oldTuningSpecified := old.Tuning.Input != nil + tuningSpecified := w.Tuning.Input != nil + + // inference/tuning can be changed, but cannot be set/unset. + if (!oldInferenceSpecified && inferenceSpecified) || (!oldTuningSpecified && tuningSpecified) { + errs = errs.Also(apis.ErrGeneric("field cannot be unset/set if it was set/unset", "spec")) + } + return errs +} + func (r *ResourceSpec) validateUpdate(old *ResourceSpec) (errs *apis.FieldError) { // We disable changing node count for now. if r.Count != nil && old.Count != nil && *r.Count != *old.Count { @@ -151,7 +177,7 @@ func (i *InferenceSpec) validateUpdate(old *InferenceSpec) (errs *apis.FieldErro if !reflect.DeepEqual(i.Preset, old.Preset) { errs = errs.Also(apis.ErrGeneric("field is immutable", "preset")) } - //inference.template can be changed, but cannot be unset. + // inference.template can be changed, but cannot be set/unset. if (i.Template != nil && old.Template == nil) || (i.Template == nil && old.Template != nil) { errs = errs.Also(apis.ErrGeneric("field cannot be unset/set if it was set/unset", "template")) } From a4f45e6ee20222d2d43edc6f904f2a43c16b3883 Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Mon, 18 Mar 2024 15:01:25 -0700 Subject: [PATCH 02/29] feat: Added validation checks for TuningSpec, DataSource, DataDestination --- api/v1alpha1/workspace_types.go | 2 +- api/v1alpha1/workspace_validation.go | 138 +++++++++++++++++++++++---- 2 files changed, 122 insertions(+), 18 deletions(-) diff --git a/api/v1alpha1/workspace_types.go b/api/v1alpha1/workspace_types.go index e451e1066..71e9f829c 100644 --- a/api/v1alpha1/workspace_types.go +++ b/api/v1alpha1/workspace_types.go @@ -106,7 +106,7 @@ type DataSource struct { // URLs specifies the links to the public data sources. E.g., files in a public github repository. 
// +optional URLs []string `json:"urls,omitempty"` - // The directory in the hsot that contains the data. + // The directory in the host that contains the data. // +optional HostPath string `json:"hostPath,omitempty"` // The name of the image that contains the source data. The assumption is that the source data locates in the diff --git a/api/v1alpha1/workspace_validation.go b/api/v1alpha1/workspace_validation.go index 732526447..81d9353b4 100644 --- a/api/v1alpha1/workspace_validation.go +++ b/api/v1alpha1/workspace_validation.go @@ -37,6 +37,9 @@ func (w *Workspace) Validate(ctx context.Context) (errs *apis.FieldError) { errs = errs.Also( w.validateCreate().ViaField("spec"), w.Inference.validateCreate().ViaField("inference"), + w.Tuning.validateCreate().ViaField("tuning"), + w.Tuning.Input.validateCreate().ViaField("input"), + w.Tuning.Output.validateCreate().ViaField("output"), w.Resource.validateCreate(w.Inference).ViaField("resource"), ) } else { @@ -44,8 +47,11 @@ func (w *Workspace) Validate(ctx context.Context) (errs *apis.FieldError) { old := base.(*Workspace) errs = errs.Also( w.validateUpdate(old).ViaField("spec"), - w.Resource.validateUpdate(&old.Resource).ViaField("resource"), w.Inference.validateUpdate(&old.Inference).ViaField("inference"), + w.Tuning.validateUpdate(&old.Tuning).ViaField("tuning"), + w.Tuning.Input.validateUpdate(old.Tuning.Input).ViaField("input"), + w.Tuning.Output.validateUpdate(old.Tuning.Output).ViaField("output"), + w.Resource.validateUpdate(&old.Resource).ViaField("resource"), ) } return errs @@ -55,11 +61,124 @@ func (w *Workspace) validateCreate() (errs *apis.FieldError) { inferenceSpecified := w.Inference.Preset != nil || w.Inference.Template != nil tuningSpecified := w.Tuning.Input != nil if inferenceSpecified != tuningSpecified { - return errs.Also(apis.ErrGeneric("Either Inference or Tuning must be specified, but not both", "")) + errs = errs.Also(apis.ErrGeneric("Either Inference or Tuning must be specified, but not 
both", "")) } return errs } +func (w *Workspace) validateUpdate(old *Workspace) (errs *apis.FieldError) { + // Check inference specified + oldInferenceSpecified := old.Inference.Preset != nil || old.Inference.Template != nil + inferenceSpecified := w.Inference.Preset != nil || w.Inference.Template != nil + // Check tuning specified + oldTuningSpecified := old.Tuning.Input != nil + tuningSpecified := w.Tuning.Input != nil + + // inference/tuning can be changed, but cannot be set/unset. + if (!oldInferenceSpecified && inferenceSpecified) || (!oldTuningSpecified && tuningSpecified) { + errs = errs.Also(apis.ErrGeneric("field cannot be unset/set if it was set/unset", "spec")) + } + return errs +} + +func (r *TuningSpec) validateCreate() (errs *apis.FieldError) { + if r.Input == nil { + errs = errs.Also(apis.ErrMissingField("Input")) + } + if r.Output == nil { + errs = errs.Also(apis.ErrMissingField("Output")) + } + // Currently require a preset to specified, in future we can consider defining a template + if r.Preset == nil { + errs = errs.Also(apis.ErrMissingField("Preset")) + } + // TODO: We have to register training plugins and check if it preset exists in plugins here + methodLowerCase := strings.ToLower(string(r.Method)) + if methodLowerCase != string(TuningMethodLora) && methodLowerCase != string(TuningMethodQLora) { + errs = errs.Also(apis.ErrInvalidValue(r.Method, "Method")) + } + return errs +} + +func (r *TuningSpec) validateUpdate(old *TuningSpec) (errs *apis.FieldError) { + if !reflect.DeepEqual(old.Input, r.Input) { + errs = errs.Also(apis.ErrGeneric("Input field cannot be changed", "Input")) + } + if !reflect.DeepEqual(old.Output, r.Output) { + errs = errs.Also(apis.ErrGeneric("Output field cannot be changed", "Output")) + } + if !reflect.DeepEqual(old.Preset, r.Preset) { + errs = errs.Also(apis.ErrGeneric("Preset cannot be changed", "Preset")) + } + // We will have to consider supporting tuning method and config fields changing + methodLowerCase := 
strings.ToLower(string(r.Method)) + if methodLowerCase != string(TuningMethodLora) && methodLowerCase != string(TuningMethodQLora) { + errs = errs.Also(apis.ErrInvalidValue(r.Method, "Method")) + } + return errs +} + +func (r *DataSource) validateCreate() (errs *apis.FieldError) { + sourcesSpecified := 0 + if len(r.URLs) > 0 { + sourcesSpecified++ + } + if r.HostPath != "" { + sourcesSpecified++ + } + if r.Image != "" { + sourcesSpecified++ + } + + // Ensure exactly one of URLs, HostPath, or Image is specified + if sourcesSpecified != 1 { + errs = errs.Also(apis.ErrGeneric("Exactly one of URLs, HostPath, or Image must be specified", "URLs", "HostPath", "Image")) + } + + return errs +} + +func (r *DataSource) validateUpdate(old *DataSource) (errs *apis.FieldError) { + if !reflect.DeepEqual(old.URLs, r.URLs) { + errs = errs.Also(apis.ErrInvalidValue("URLs field cannot be changed once set", "URLs")) + } + if old.HostPath != r.HostPath { + errs = errs.Also(apis.ErrInvalidValue("HostPath field cannot be changed once set", "HostPath")) + } + if old.Image != r.Image { + errs = errs.Also(apis.ErrInvalidValue("Image field cannot be changed once set", "Image")) + } + // TODO: Ensure ImageSecrets can be changed + return errs +} + +func (r *DataDestination) validateCreate() (errs *apis.FieldError) { + destinationsSpecified := 0 + if r.HostPath != "" { + destinationsSpecified++ + } + if r.Image != "" { + destinationsSpecified++ + } + + // If no destination is specified, return an error + if destinationsSpecified == 0 { + errs = errs.Also(apis.ErrMissingField("At least one of HostPath or Image must be specified")) + } + return errs +} + +func (r *DataDestination) validateUpdate(old *DataDestination) (errs *apis.FieldError) { + if old.HostPath != r.HostPath { + errs = errs.Also(apis.ErrInvalidValue("HostPath field cannot be changed once set", "HostPath")) + } + if old.Image != r.Image { + errs = errs.Also(apis.ErrInvalidValue("Image field cannot be changed once set", "Image")) + 
} + // TODO: Ensure ImageSecrets can be changed + return errs +} + func (r *ResourceSpec) validateCreate(inference InferenceSpec) (errs *apis.FieldError) { var presetName string if inference.Preset != nil { @@ -107,21 +226,6 @@ func (r *ResourceSpec) validateCreate(inference InferenceSpec) (errs *apis.Field return errs } -func (w *Workspace) validateUpdate(old *Workspace) (errs *apis.FieldError) { - // Check inference specified - oldInferenceSpecified := old.Inference.Preset != nil || old.Inference.Template != nil - inferenceSpecified := w.Inference.Preset != nil || w.Inference.Template != nil - // Check tuning specified - oldTuningSpecified := old.Tuning.Input != nil - tuningSpecified := w.Tuning.Input != nil - - // inference/tuning can be changed, but cannot be set/unset. - if (!oldInferenceSpecified && inferenceSpecified) || (!oldTuningSpecified && tuningSpecified) { - errs = errs.Also(apis.ErrGeneric("field cannot be unset/set if it was set/unset", "spec")) - } - return errs -} - func (r *ResourceSpec) validateUpdate(old *ResourceSpec) (errs *apis.FieldError) { // We disable changing node count for now. 
if r.Count != nil && old.Count != nil && *r.Count != *old.Count { From a9bbe7ad0e3867c29c1648f60733aacad29765fa Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Mon, 18 Mar 2024 15:08:46 -0700 Subject: [PATCH 03/29] fix: prevent toggling --- api/v1alpha1/workspace_validation.go | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/api/v1alpha1/workspace_validation.go b/api/v1alpha1/workspace_validation.go index 81d9353b4..39fbf87b7 100644 --- a/api/v1alpha1/workspace_validation.go +++ b/api/v1alpha1/workspace_validation.go @@ -73,10 +73,12 @@ func (w *Workspace) validateUpdate(old *Workspace) (errs *apis.FieldError) { // Check tuning specified oldTuningSpecified := old.Tuning.Input != nil tuningSpecified := w.Tuning.Input != nil + if (!oldInferenceSpecified && inferenceSpecified) || (oldInferenceSpecified && !inferenceSpecified) { + errs = errs.Also(apis.ErrGeneric("Inference field cannot be toggled once set", "inference")) + } - // inference/tuning can be changed, but cannot be set/unset. 
- if (!oldInferenceSpecified && inferenceSpecified) || (!oldTuningSpecified && tuningSpecified) { - errs = errs.Also(apis.ErrGeneric("field cannot be unset/set if it was set/unset", "spec")) + if (!oldTuningSpecified && tuningSpecified) || (oldTuningSpecified && !tuningSpecified) { + errs = errs.Also(apis.ErrGeneric("Tuning field cannot be toggled once set", "tuning")) } return errs } From d73ef65ec6846dd4a055062935d9a788a05397ce Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Mon, 18 Mar 2024 18:06:49 -0700 Subject: [PATCH 04/29] fix: validation fixes --- api/v1alpha1/workspace_validation.go | 47 ++++++++++++++++++++++------ 1 file changed, 37 insertions(+), 10 deletions(-) diff --git a/api/v1alpha1/workspace_validation.go b/api/v1alpha1/workspace_validation.go index 39fbf87b7..9034deae3 100644 --- a/api/v1alpha1/workspace_validation.go +++ b/api/v1alpha1/workspace_validation.go @@ -7,6 +7,7 @@ import ( "context" "fmt" "reflect" + "sort" "strings" "github.com/azure/kaito/pkg/utils/plugin" @@ -60,7 +61,10 @@ func (w *Workspace) Validate(ctx context.Context) (errs *apis.FieldError) { func (w *Workspace) validateCreate() (errs *apis.FieldError) { inferenceSpecified := w.Inference.Preset != nil || w.Inference.Template != nil tuningSpecified := w.Tuning.Input != nil - if inferenceSpecified != tuningSpecified { + if !inferenceSpecified && !tuningSpecified { + errs = errs.Also(apis.ErrGeneric("Either Inference or Tuning must be specified, not neither", "")) + } + if inferenceSpecified && tuningSpecified { errs = errs.Also(apis.ErrGeneric("Either Inference or Tuning must be specified, but not both", "")) } return errs @@ -93,8 +97,9 @@ func (r *TuningSpec) validateCreate() (errs *apis.FieldError) { // Currently require a preset to specified, in future we can consider defining a template if r.Preset == nil { errs = errs.Also(apis.ErrMissingField("Preset")) + } else if presetName := string(r.Preset.Name); !isValidPreset(presetName) { + errs = 
errs.Also(apis.ErrInvalidValue(fmt.Sprintf("Unsupported tuning preset name %s", presetName), "presetName")) } - // TODO: We have to register training plugins and check if it preset exists in plugins here methodLowerCase := strings.ToLower(string(r.Method)) if methodLowerCase != string(TuningMethodLora) && methodLowerCase != string(TuningMethodQLora) { errs = errs.Also(apis.ErrInvalidValue(r.Method, "Method")) @@ -112,11 +117,11 @@ func (r *TuningSpec) validateUpdate(old *TuningSpec) (errs *apis.FieldError) { if !reflect.DeepEqual(old.Preset, r.Preset) { errs = errs.Also(apis.ErrGeneric("Preset cannot be changed", "Preset")) } - // We will have to consider supporting tuning method and config fields changing - methodLowerCase := strings.ToLower(string(r.Method)) - if methodLowerCase != string(TuningMethodLora) && methodLowerCase != string(TuningMethodQLora) { - errs = errs.Also(apis.ErrInvalidValue(r.Method, "Method")) + oldMethod, newMethod := strings.ToLower(string(old.Method)), strings.ToLower(string(r.Method)) + if !reflect.DeepEqual(oldMethod, newMethod) { + errs = errs.Also(apis.ErrGeneric("Method cannot be changed", "Method")) } + // Consider supporting config fields changing return errs } @@ -141,7 +146,15 @@ func (r *DataSource) validateCreate() (errs *apis.FieldError) { } func (r *DataSource) validateUpdate(old *DataSource) (errs *apis.FieldError) { - if !reflect.DeepEqual(old.URLs, r.URLs) { + oldURLs := make([]string, len(old.URLs)) + copy(oldURLs, old.URLs) + sort.Strings(old.URLs) + + newURLs := make([]string, len(r.URLs)) + copy(newURLs, r.URLs) + sort.Strings(r.URLs) + + if !reflect.DeepEqual(oldURLs, newURLs) { errs = errs.Also(apis.ErrInvalidValue("URLs field cannot be changed once set", "URLs")) } if old.HostPath != r.HostPath { @@ -150,7 +163,18 @@ func (r *DataSource) validateUpdate(old *DataSource) (errs *apis.FieldError) { if old.Image != r.Image { errs = errs.Also(apis.ErrInvalidValue("Image field cannot be changed once set", "Image")) } - // 
TODO: Ensure ImageSecrets can be changed + + oldSecrets := make([]string, len(old.ImagePullSecrets)) + copy(oldSecrets, old.ImagePullSecrets) + sort.Strings(oldSecrets) + + newSecrets := make([]string, len(r.ImagePullSecrets)) + copy(newSecrets, r.ImagePullSecrets) + sort.Strings(newSecrets) + + if !reflect.DeepEqual(oldSecrets, newSecrets) { + errs = errs.Also(apis.ErrInvalidValue("ImagePullSecrets field cannot be changed once set", "ImagePullSecrets")) + } return errs } @@ -177,7 +201,10 @@ func (r *DataDestination) validateUpdate(old *DataDestination) (errs *apis.Field if old.Image != r.Image { errs = errs.Also(apis.ErrInvalidValue("Image field cannot be changed once set", "Image")) } - // TODO: Ensure ImageSecrets can be changed + + if old.ImagePushSecret != r.ImagePushSecret { + errs = errs.Also(apis.ErrInvalidValue("ImagePushSecret field cannot be changed once set", "ImagePushSecret")) + } return errs } @@ -263,7 +290,7 @@ func (i *InferenceSpec) validateCreate() (errs *apis.FieldError) { presetName := string(i.Preset.Name) // Validate preset name if !isValidPreset(presetName) { - errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("Unsupported preset name %s", presetName), "presetName")) + errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("Unsupported inference preset name %s", presetName), "presetName")) } // Validate private preset has private image specified if plugin.KaitoModelRegister.MustGet(string(i.Preset.Name)).GetInferenceParameters().ImageAccessMode == "private" && From 3fa0e46f1096bd6ed4cb309c58a365b8166fe18c Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Tue, 19 Mar 2024 11:50:00 -0700 Subject: [PATCH 05/29] feat: Add UTs for workspace validation --- api/v1alpha1/workspace_validation.go | 4 +- api/v1alpha1/workspace_validation_test.go | 615 ++++++++++++++++++++++ 2 files changed, 617 insertions(+), 2 deletions(-) diff --git a/api/v1alpha1/workspace_validation.go b/api/v1alpha1/workspace_validation.go index 9034deae3..5a8269353 100644 --- 
a/api/v1alpha1/workspace_validation.go +++ b/api/v1alpha1/workspace_validation.go @@ -148,11 +148,11 @@ func (r *DataSource) validateCreate() (errs *apis.FieldError) { func (r *DataSource) validateUpdate(old *DataSource) (errs *apis.FieldError) { oldURLs := make([]string, len(old.URLs)) copy(oldURLs, old.URLs) - sort.Strings(old.URLs) + sort.Strings(oldURLs) newURLs := make([]string, len(r.URLs)) copy(newURLs, r.URLs) - sort.Strings(r.URLs) + sort.Strings(newURLs) if !reflect.DeepEqual(oldURLs, newURLs) { errs = errs.Also(apis.ErrInvalidValue("URLs field cannot be changed once set", "URLs")) diff --git a/api/v1alpha1/workspace_validation_test.go b/api/v1alpha1/workspace_validation_test.go index 0a3fa2de1..6c2f1a650 100644 --- a/api/v1alpha1/workspace_validation_test.go +++ b/api/v1alpha1/workspace_validation_test.go @@ -488,6 +488,621 @@ func TestInferenceSpecValidateUpdate(t *testing.T) { } } +func TestWorkspaceValidateCreate(t *testing.T) { + tests := []struct { + name string + workspace *Workspace + wantErr bool + errField string + }{ + { + name: "Neither Inference nor Tuning specified", + workspace: &Workspace{ + Inference: InferenceSpec{}, + Tuning: TuningSpec{}, + }, + wantErr: true, + errField: "neither", + }, + { + name: "Both Inference and Tuning specified", + workspace: &Workspace{ + Inference: InferenceSpec{Preset: &PresetSpec{}}, + Tuning: TuningSpec{Input: &DataSource{}}, + }, + wantErr: true, + errField: "both", + }, + { + name: "Only Inference specified", + workspace: &Workspace{ + Inference: InferenceSpec{Preset: &PresetSpec{}}, + }, + wantErr: false, + errField: "", + }, + { + name: "Only Tuning specified", + workspace: &Workspace{ + Tuning: TuningSpec{Input: &DataSource{}}, + }, + wantErr: false, + errField: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + errs := tt.workspace.validateCreate() + if (errs != nil) != tt.wantErr { + t.Errorf("validateCreate() error = %v, wantErr %v", errs, tt.wantErr) + } + if errs 
!= nil && !strings.Contains(errs.Error(), tt.errField) { + t.Errorf("validateCreate() expected error to contain field %s, but got %s", tt.errField, errs.Error()) + } + }) + } +} + +func TestWorkspaceValidateUpdate(t *testing.T) { + tests := []struct { + name string + oldWorkspace *Workspace + newWorkspace *Workspace + expectErrs bool + errFields []string // Fields we expect to have errors + }{ + { + name: "Inference toggled on", + oldWorkspace: &Workspace{ + Inference: InferenceSpec{}, + }, + newWorkspace: &Workspace{ + Inference: InferenceSpec{Preset: &PresetSpec{}}, + }, + expectErrs: true, + errFields: []string{"inference"}, + }, + { + name: "Inference toggled off", + oldWorkspace: &Workspace{ + Inference: InferenceSpec{Preset: &PresetSpec{}}, + }, + newWorkspace: &Workspace{ + Inference: InferenceSpec{}, + }, + expectErrs: true, + errFields: []string{"inference"}, + }, + { + name: "Tuning toggled on", + oldWorkspace: &Workspace{ + Tuning: TuningSpec{}, + }, + newWorkspace: &Workspace{ + Tuning: TuningSpec{Input: &DataSource{}}, + }, + expectErrs: true, + errFields: []string{"tuning"}, + }, + { + name: "Tuning toggled off", + oldWorkspace: &Workspace{ + Tuning: TuningSpec{Input: &DataSource{}}, + }, + newWorkspace: &Workspace{ + Tuning: TuningSpec{}, + }, + expectErrs: true, + errFields: []string{"tuning"}, + }, + { + name: "No toggling", + oldWorkspace: &Workspace{ + Tuning: TuningSpec{Input: &DataSource{}}, + }, + newWorkspace: &Workspace{ + Tuning: TuningSpec{Input: &DataSource{}}, + }, + expectErrs: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + errs := tt.newWorkspace.validateUpdate(tt.oldWorkspace) + hasErrs := errs != nil + + if hasErrs != tt.expectErrs { + t.Errorf("validateUpdate() errors = %v, expectErrs %v", errs, tt.expectErrs) + } + + if hasErrs { + for _, field := range tt.errFields { + if !strings.Contains(errs.Error(), field) { + t.Errorf("validateUpdate() expected errors to contain field %s, but got %s", 
field, errs.Error()) + } + } + } + }) + } +} + +func TestTuningSpecValidateCreate(t *testing.T) { + RegisterValidationTestModels() + tests := []struct { + name string + tuningSpec *TuningSpec + wantErr bool + errFields []string // Fields we expect to have errors + }{ + { + name: "All fields valid", + tuningSpec: &TuningSpec{ + Input: &DataSource{Name: "valid-input"}, + Output: &DataDestination{HostPath: "valid-output"}, + Preset: &PresetSpec{PresetMeta: PresetMeta{Name: ModelName("test-validation")}}, + Method: TuningMethodLora, + }, + wantErr: false, + errFields: nil, + }, + { + name: "Missing Input", + tuningSpec: &TuningSpec{ + Output: &DataDestination{HostPath: "valid-output"}, + Preset: &PresetSpec{PresetMeta: PresetMeta{Name: ModelName("test-validation")}}, + Method: TuningMethodLora, + }, + wantErr: true, + errFields: []string{"Input"}, + }, + { + name: "Missing Output", + tuningSpec: &TuningSpec{ + Input: &DataSource{Name: "valid-input"}, + Preset: &PresetSpec{PresetMeta: PresetMeta{Name: ModelName("test-validation")}}, + Method: TuningMethodLora, + }, + wantErr: true, + errFields: []string{"Output"}, + }, + { + name: "Missing Preset", + tuningSpec: &TuningSpec{ + Input: &DataSource{Name: "valid-input"}, + Output: &DataDestination{HostPath: "valid-output"}, + Method: TuningMethodLora, + }, + wantErr: true, + errFields: []string{"Preset"}, + }, + { + name: "Invalid Preset", + tuningSpec: &TuningSpec{ + Input: &DataSource{Name: "valid-input"}, + Output: &DataDestination{HostPath: "valid-output"}, + Preset: &PresetSpec{PresetMeta: PresetMeta{Name: ModelName("invalid-preset")}}, + Method: TuningMethodLora, + }, + wantErr: true, + errFields: []string{"presetName"}, + }, + { + name: "Invalid Method", + tuningSpec: &TuningSpec{ + Input: &DataSource{Name: "valid-input"}, + Output: &DataDestination{HostPath: "valid-output"}, + Preset: &PresetSpec{PresetMeta: PresetMeta{Name: ModelName("test-validation")}}, + Method: "invalid-method", + }, + wantErr: true, + 
errFields: []string{"Method"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + errs := tt.tuningSpec.validateCreate() + hasErrs := errs != nil + + if hasErrs != tt.wantErr { + t.Errorf("validateCreate() errors = %v, wantErr %v", errs, tt.wantErr) + } + + if hasErrs { + for _, field := range tt.errFields { + if !strings.Contains(errs.Error(), field) { + t.Errorf("validateCreate() expected errors to contain field %s, but got %s", field, errs.Error()) + } + } + } + }) + } +} + +func TestTuningSpecValidateUpdate(t *testing.T) { + RegisterValidationTestModels() + tests := []struct { + name string + oldTuning *TuningSpec + newTuning *TuningSpec + expectErrs bool + errFields []string // Fields we expect to have errors + }{ + { + name: "No changes", + oldTuning: &TuningSpec{ + Input: &DataSource{Name: "input1"}, + Output: &DataDestination{HostPath: "path1"}, + Preset: &PresetSpec{PresetMeta: PresetMeta{Name: ModelName("test-validation")}}, + Method: TuningMethodLora, + }, + newTuning: &TuningSpec{ + Input: &DataSource{Name: "input1"}, + Output: &DataDestination{HostPath: "path1"}, + Preset: &PresetSpec{PresetMeta: PresetMeta{Name: ModelName("test-validation")}}, + Method: TuningMethodLora, + }, + expectErrs: false, + }, + { + name: "Input changed", + oldTuning: &TuningSpec{ + Input: &DataSource{Name: "input1"}, + }, + newTuning: &TuningSpec{ + Input: &DataSource{Name: "input2"}, + }, + expectErrs: true, + errFields: []string{"Input"}, + }, + { + name: "Output changed", + oldTuning: &TuningSpec{ + Output: &DataDestination{HostPath: "path1"}, + }, + newTuning: &TuningSpec{ + Output: &DataDestination{HostPath: "path2"}, + }, + expectErrs: true, + errFields: []string{"Output"}, + }, + { + name: "Preset changed", + oldTuning: &TuningSpec{ + Preset: &PresetSpec{PresetMeta: PresetMeta{Name: ModelName("test-validation")}}, + }, + newTuning: &TuningSpec{ + Preset: &PresetSpec{PresetMeta: PresetMeta{Name: ModelName("invalid-preset")}}, + }, + 
expectErrs: true, + errFields: []string{"Preset"}, + }, + { + name: "Method changed", + oldTuning: &TuningSpec{ + Method: TuningMethodLora, + }, + newTuning: &TuningSpec{ + Method: TuningMethodQLora, + }, + expectErrs: true, + errFields: []string{"Method"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + errs := tt.newTuning.validateUpdate(tt.oldTuning) + hasErrs := errs != nil + + if hasErrs != tt.expectErrs { + t.Errorf("validateUpdate() errors = %v, expectErrs %v", errs, tt.expectErrs) + } + + if hasErrs { + for _, field := range tt.errFields { + if !strings.Contains(errs.Error(), field) { + t.Errorf("validateUpdate() expected errors to contain field %s, but got %s", field, errs.Error()) + } + } + } + }) + } +} + +func TestDataSourceValidateCreate(t *testing.T) { + tests := []struct { + name string + dataSource *DataSource + wantErr bool + errField string // The field we expect to have an error on + }{ + { + name: "URLs specified only", + dataSource: &DataSource{ + URLs: []string{"http://example.com/data"}, + }, + wantErr: false, + }, + { + name: "HostPath specified only", + dataSource: &DataSource{ + HostPath: "/data/path", + }, + wantErr: false, + }, + { + name: "Image specified only", + dataSource: &DataSource{ + Image: "data-image:latest", + }, + wantErr: false, + }, + { + name: "None specified", + dataSource: &DataSource{}, + wantErr: true, + errField: "Exactly one of URLs, HostPath, or Image must be specified", + }, + { + name: "URLs and HostPath specified", + dataSource: &DataSource{ + URLs: []string{"http://example.com/data"}, + HostPath: "/data/path", + }, + wantErr: true, + errField: "Exactly one of URLs, HostPath, or Image must be specified", + }, + { + name: "All fields specified", + dataSource: &DataSource{ + URLs: []string{"http://example.com/data"}, + HostPath: "/data/path", + Image: "data-image:latest", + }, + wantErr: true, + errField: "Exactly one of URLs, HostPath, or Image must be specified", + }, + } + + for 
_, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + errs := tt.dataSource.validateCreate() + hasErrs := errs != nil + + if hasErrs != tt.wantErr { + t.Errorf("validateCreate() error = %v, wantErr %v", errs, tt.wantErr) + } + + if hasErrs && tt.errField != "" && !strings.Contains(errs.Error(), tt.errField) { + t.Errorf("validateCreate() expected error to contain %s, but got %s", tt.errField, errs.Error()) + } + }) + } +} + +func TestDataSourceValidateUpdate(t *testing.T) { + tests := []struct { + name string + oldSource *DataSource + newSource *DataSource + wantErr bool + errFields []string // Fields we expect to have errors + }{ + { + name: "No changes", + oldSource: &DataSource{ + URLs: []string{"http://example.com/data1", "http://example.com/data2"}, + HostPath: "/data/path", + Image: "data-image:latest", + ImagePullSecrets: []string{"secret1", "secret2"}, + }, + newSource: &DataSource{ + URLs: []string{"http://example.com/data2", "http://example.com/data1"}, // Note the different order, should not matter + HostPath: "/data/path", + Image: "data-image:latest", + ImagePullSecrets: []string{"secret2", "secret1"}, // Note the different order, should not matter + }, + wantErr: false, + }, + { + name: "URLs changed", + oldSource: &DataSource{ + URLs: []string{"http://example.com/old"}, + }, + newSource: &DataSource{ + URLs: []string{"http://example.com/new"}, + }, + wantErr: true, + errFields: []string{"URLs"}, + }, + { + name: "HostPath changed", + oldSource: &DataSource{ + HostPath: "/old/path", + }, + newSource: &DataSource{ + HostPath: "/new/path", + }, + wantErr: true, + errFields: []string{"HostPath"}, + }, + { + name: "Image changed", + oldSource: &DataSource{ + Image: "old-image:latest", + }, + newSource: &DataSource{ + Image: "new-image:latest", + }, + wantErr: true, + errFields: []string{"Image"}, + }, + { + name: "ImagePullSecrets changed", + oldSource: &DataSource{ + ImagePullSecrets: []string{"old-secret"}, + }, + newSource: &DataSource{ + 
ImagePullSecrets: []string{"new-secret"}, + }, + wantErr: true, + errFields: []string{"ImagePullSecrets"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + errs := tt.newSource.validateUpdate(tt.oldSource) + hasErrs := errs != nil + + if hasErrs != tt.wantErr { + t.Errorf("validateUpdate() error = %v, wantErr %v", errs, tt.wantErr) + } + + if hasErrs { + for _, field := range tt.errFields { + if !strings.Contains(errs.Error(), field) { + t.Errorf("validateUpdate() expected errors to contain field %s, but got %s", field, errs.Error()) + } + } + } + }) + } +} + +func TestDataDestinationValidateCreate(t *testing.T) { + tests := []struct { + name string + dataDestination *DataDestination + wantErr bool + errField string // The field we expect to have an error on + }{ + { + name: "No fields specified", + dataDestination: &DataDestination{}, + wantErr: true, + errField: "At least one of HostPath or Image must be specified", + }, + { + name: "HostPath specified only", + dataDestination: &DataDestination{ + HostPath: "/data/path", + }, + wantErr: false, + }, + { + name: "Image specified only", + dataDestination: &DataDestination{ + Image: "data-image:latest", + }, + wantErr: false, + }, + { + name: "Both fields specified", + dataDestination: &DataDestination{ + HostPath: "/data/path", + Image: "data-image:latest", + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + errs := tt.dataDestination.validateCreate() + hasErrs := errs != nil + + if hasErrs != tt.wantErr { + t.Errorf("validateCreate() error = %v, wantErr %v", errs, tt.wantErr) + } + + if hasErrs && tt.errField != "" && !strings.Contains(errs.Error(), tt.errField) { + t.Errorf("validateCreate() expected error to contain %s, but got %s", tt.errField, errs.Error()) + } + }) + } +} + +func TestDataDestinationValidateUpdate(t *testing.T) { + tests := []struct { + name string + oldDest *DataDestination + newDest *DataDestination + wantErr 
bool + errFields []string // Fields we expect to have errors + }{ + { + name: "No changes", + oldDest: &DataDestination{ + HostPath: "/data/old", + Image: "old-image:latest", + ImagePushSecret: "old-secret", + }, + newDest: &DataDestination{ + HostPath: "/data/old", + Image: "old-image:latest", + ImagePushSecret: "old-secret", + }, + wantErr: false, + }, + { + name: "HostPath changed", + oldDest: &DataDestination{ + HostPath: "/data/old", + }, + newDest: &DataDestination{ + HostPath: "/data/new", + }, + wantErr: true, + errFields: []string{"HostPath"}, + }, + { + name: "Image changed", + oldDest: &DataDestination{ + Image: "old-image:latest", + }, + newDest: &DataDestination{ + Image: "new-image:latest", + }, + wantErr: true, + errFields: []string{"Image"}, + }, + { + name: "ImagePushSecret changed", + oldDest: &DataDestination{ + ImagePushSecret: "old-secret", + }, + newDest: &DataDestination{ + ImagePushSecret: "new-secret", + }, + wantErr: true, + errFields: []string{"ImagePushSecret"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + errs := tt.newDest.validateUpdate(tt.oldDest) + hasErrs := errs != nil + + if hasErrs != tt.wantErr { + t.Errorf("validateUpdate() error = %v, wantErr %v", errs, tt.wantErr) + } + + if hasErrs { + for _, field := range tt.errFields { + if !strings.Contains(errs.Error(), field) { + t.Errorf("validateUpdate() expected errors to contain field %s, but got %s", field, errs.Error()) + } + } + } + }) + } +} + func TestGetSupportedSKUs(t *testing.T) { tests := []struct { name string From 392ff401d4537ede687bac8393f8195abb80fa50 Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Tue, 19 Mar 2024 19:07:09 -0700 Subject: [PATCH 06/29] fix: Update CRD to use pointers --- api/v1alpha1/workspace_types.go | 8 +- api/v1alpha1/workspace_validation.go | 55 ++++--- api/v1alpha1/workspace_validation_test.go | 69 ++++---- api/v1alpha1/zz_generated.deepcopy.go | 12 +- .../workspace/crds/kaito.sh_workspaces.yaml | 147 
+++++++++++++++++- config/crd/bases/kaito.sh_workspaces.yaml | 11 +- pkg/utils/testUtils.go | 6 +- test/e2e/preset_test.go | 19 ++- test/e2e/utils/utils.go | 30 ++-- 9 files changed, 252 insertions(+), 105 deletions(-) diff --git a/api/v1alpha1/workspace_types.go b/api/v1alpha1/workspace_types.go index 71e9f829c..4484b8250 100644 --- a/api/v1alpha1/workspace_types.go +++ b/api/v1alpha1/workspace_types.go @@ -150,9 +150,9 @@ type TuningSpec struct { // +optional Config string `json:"config,omitempty"` // Input describes the input used by the tuning method. - Input *DataSource `json:"input,omitempty"` + Input *DataSource `json:"input"` // Output specified where to store the tuning output. - Output *DataDestination `json:"output,omitempty"` + Output *DataDestination `json:"output"` } // WorkspaceStatus defines the observed state of Workspace @@ -181,8 +181,8 @@ type Workspace struct { metav1.ObjectMeta `json:"metadata,omitempty"` Resource ResourceSpec `json:"resource,omitempty"` - Inference InferenceSpec `json:"inference,omitempty"` - Tuning TuningSpec `json:"tuning,omitempty"` + Inference *InferenceSpec `json:"inference,omitempty"` + Tuning *TuningSpec `json:"tuning,omitempty"` Status WorkspaceStatus `json:"status,omitempty"` } diff --git a/api/v1alpha1/workspace_validation.go b/api/v1alpha1/workspace_validation.go index 5a8269353..b135f5886 100644 --- a/api/v1alpha1/workspace_validation.go +++ b/api/v1alpha1/workspace_validation.go @@ -37,51 +37,48 @@ func (w *Workspace) Validate(ctx context.Context) (errs *apis.FieldError) { klog.InfoS("Validate creation", "workspace", fmt.Sprintf("%s/%s", w.Namespace, w.Name)) errs = errs.Also( w.validateCreate().ViaField("spec"), - w.Inference.validateCreate().ViaField("inference"), - w.Tuning.validateCreate().ViaField("tuning"), - w.Tuning.Input.validateCreate().ViaField("input"), - w.Tuning.Output.validateCreate().ViaField("output"), - w.Resource.validateCreate(w.Inference).ViaField("resource"), + // TODO: Consider validate 
resource based on Tuning Spec + w.Resource.validateCreate(*w.Inference).ViaField("resource"), ) + if w.Inference != nil { + errs = errs.Also(w.Inference.validateCreate().ViaField("inference")) + } + if w.Tuning != nil { + errs = errs.Also(w.Tuning.validateCreate().ViaField("tuning")) + } } else { klog.InfoS("Validate update", "workspace", fmt.Sprintf("%s/%s", w.Namespace, w.Name)) old := base.(*Workspace) errs = errs.Also( w.validateUpdate(old).ViaField("spec"), - w.Inference.validateUpdate(&old.Inference).ViaField("inference"), - w.Tuning.validateUpdate(&old.Tuning).ViaField("tuning"), - w.Tuning.Input.validateUpdate(old.Tuning.Input).ViaField("input"), - w.Tuning.Output.validateUpdate(old.Tuning.Output).ViaField("output"), w.Resource.validateUpdate(&old.Resource).ViaField("resource"), ) + if w.Inference != nil { + errs = errs.Also(w.Inference.validateUpdate(old.Inference).ViaField("inference")) + } + if w.Tuning != nil { + errs = errs.Also(w.Tuning.validateUpdate(old.Tuning).ViaField("tuning")) + } } return errs } func (w *Workspace) validateCreate() (errs *apis.FieldError) { - inferenceSpecified := w.Inference.Preset != nil || w.Inference.Template != nil - tuningSpecified := w.Tuning.Input != nil - if !inferenceSpecified && !tuningSpecified { + if w.Inference == nil && w.Tuning == nil { errs = errs.Also(apis.ErrGeneric("Either Inference or Tuning must be specified, not neither", "")) } - if inferenceSpecified && tuningSpecified { + if w.Inference != nil && w.Tuning != nil { errs = errs.Also(apis.ErrGeneric("Either Inference or Tuning must be specified, but not both", "")) } return errs } func (w *Workspace) validateUpdate(old *Workspace) (errs *apis.FieldError) { - // Check inference specified - oldInferenceSpecified := old.Inference.Preset != nil || old.Inference.Template != nil - inferenceSpecified := w.Inference.Preset != nil || w.Inference.Template != nil - // Check tuning specified - oldTuningSpecified := old.Tuning.Input != nil - tuningSpecified := 
w.Tuning.Input != nil - if (!oldInferenceSpecified && inferenceSpecified) || (oldInferenceSpecified && !inferenceSpecified) { + if (old.Inference == nil && w.Inference != nil) || (old.Inference != nil && w.Inference == nil) { errs = errs.Also(apis.ErrGeneric("Inference field cannot be toggled once set", "inference")) } - if (!oldTuningSpecified && tuningSpecified) || (oldTuningSpecified && !tuningSpecified) { + if (old.Tuning == nil && w.Tuning != nil) || (old.Tuning != nil && w.Tuning == nil) { errs = errs.Also(apis.ErrGeneric("Tuning field cannot be toggled once set", "tuning")) } return errs @@ -90,9 +87,13 @@ func (w *Workspace) validateUpdate(old *Workspace) (errs *apis.FieldError) { func (r *TuningSpec) validateCreate() (errs *apis.FieldError) { if r.Input == nil { errs = errs.Also(apis.ErrMissingField("Input")) + } else { + errs = errs.Also(r.Input.validateCreate().ViaField("Input")) } if r.Output == nil { errs = errs.Also(apis.ErrMissingField("Output")) + } else { + errs = errs.Also(r.Output.validateCreate().ViaField("Output")) } // Currently require a preset to specified, in future we can consider defining a template if r.Preset == nil { @@ -108,11 +109,15 @@ func (r *TuningSpec) validateCreate() (errs *apis.FieldError) { } func (r *TuningSpec) validateUpdate(old *TuningSpec) (errs *apis.FieldError) { - if !reflect.DeepEqual(old.Input, r.Input) { - errs = errs.Also(apis.ErrGeneric("Input field cannot be changed", "Input")) + if r.Input == nil { + errs = errs.Also(apis.ErrMissingField("Input")) + } else { + errs = errs.Also(r.Input.validateUpdate(old.Input).ViaField("Input")) } - if !reflect.DeepEqual(old.Output, r.Output) { - errs = errs.Also(apis.ErrGeneric("Output field cannot be changed", "Output")) + if r.Output == nil { + errs = errs.Also(apis.ErrMissingField("Output")) + } else { + errs = errs.Also(r.Output.validateUpdate(old.Output).ViaField("Output")) } if !reflect.DeepEqual(old.Preset, r.Preset) { errs = errs.Also(apis.ErrGeneric("Preset cannot be 
changed", "Preset")) diff --git a/api/v1alpha1/workspace_validation_test.go b/api/v1alpha1/workspace_validation_test.go index 6c2f1a650..d1cea034d 100644 --- a/api/v1alpha1/workspace_validation_test.go +++ b/api/v1alpha1/workspace_validation_test.go @@ -496,19 +496,16 @@ func TestWorkspaceValidateCreate(t *testing.T) { errField string }{ { - name: "Neither Inference nor Tuning specified", - workspace: &Workspace{ - Inference: InferenceSpec{}, - Tuning: TuningSpec{}, - }, - wantErr: true, - errField: "neither", + name: "Neither Inference nor Tuning specified", + workspace: &Workspace{}, + wantErr: true, + errField: "neither", }, { name: "Both Inference and Tuning specified", workspace: &Workspace{ - Inference: InferenceSpec{Preset: &PresetSpec{}}, - Tuning: TuningSpec{Input: &DataSource{}}, + Inference: &InferenceSpec{}, + Tuning: &TuningSpec{}, }, wantErr: true, errField: "both", @@ -516,7 +513,7 @@ func TestWorkspaceValidateCreate(t *testing.T) { { name: "Only Inference specified", workspace: &Workspace{ - Inference: InferenceSpec{Preset: &PresetSpec{}}, + Inference: &InferenceSpec{}, }, wantErr: false, errField: "", @@ -524,7 +521,7 @@ func TestWorkspaceValidateCreate(t *testing.T) { { name: "Only Tuning specified", workspace: &Workspace{ - Tuning: TuningSpec{Input: &DataSource{}}, + Tuning: &TuningSpec{Input: &DataSource{}}, }, wantErr: false, errField: "", @@ -553,12 +550,10 @@ func TestWorkspaceValidateUpdate(t *testing.T) { errFields []string // Fields we expect to have errors }{ { - name: "Inference toggled on", - oldWorkspace: &Workspace{ - Inference: InferenceSpec{}, - }, + name: "Inference toggled on", + oldWorkspace: &Workspace{}, newWorkspace: &Workspace{ - Inference: InferenceSpec{Preset: &PresetSpec{}}, + Inference: &InferenceSpec{}, }, expectErrs: true, errFields: []string{"inference"}, @@ -566,21 +561,17 @@ func TestWorkspaceValidateUpdate(t *testing.T) { { name: "Inference toggled off", oldWorkspace: &Workspace{ - Inference: InferenceSpec{Preset: 
&PresetSpec{}}, - }, - newWorkspace: &Workspace{ - Inference: InferenceSpec{}, + Inference: &InferenceSpec{Preset: &PresetSpec{}}, }, - expectErrs: true, - errFields: []string{"inference"}, + newWorkspace: &Workspace{}, + expectErrs: true, + errFields: []string{"inference"}, }, { - name: "Tuning toggled on", - oldWorkspace: &Workspace{ - Tuning: TuningSpec{}, - }, + name: "Tuning toggled on", + oldWorkspace: &Workspace{}, newWorkspace: &Workspace{ - Tuning: TuningSpec{Input: &DataSource{}}, + Tuning: &TuningSpec{Input: &DataSource{}}, }, expectErrs: true, errFields: []string{"tuning"}, @@ -588,21 +579,19 @@ func TestWorkspaceValidateUpdate(t *testing.T) { { name: "Tuning toggled off", oldWorkspace: &Workspace{ - Tuning: TuningSpec{Input: &DataSource{}}, - }, - newWorkspace: &Workspace{ - Tuning: TuningSpec{}, + Tuning: &TuningSpec{Input: &DataSource{}}, }, - expectErrs: true, - errFields: []string{"tuning"}, + newWorkspace: &Workspace{}, + expectErrs: true, + errFields: []string{"tuning"}, }, { name: "No toggling", oldWorkspace: &Workspace{ - Tuning: TuningSpec{Input: &DataSource{}}, + Tuning: &TuningSpec{Input: &DataSource{}}, }, newWorkspace: &Workspace{ - Tuning: TuningSpec{Input: &DataSource{}}, + Tuning: &TuningSpec{Input: &DataSource{}}, }, expectErrs: false, }, @@ -639,7 +628,7 @@ func TestTuningSpecValidateCreate(t *testing.T) { { name: "All fields valid", tuningSpec: &TuningSpec{ - Input: &DataSource{Name: "valid-input"}, + Input: &DataSource{Name: "valid-input", HostPath: "valid-input"}, Output: &DataDestination{HostPath: "valid-output"}, Preset: &PresetSpec{PresetMeta: PresetMeta{Name: ModelName("test-validation")}}, Method: TuningMethodLora, @@ -749,13 +738,15 @@ func TestTuningSpecValidateUpdate(t *testing.T) { { name: "Input changed", oldTuning: &TuningSpec{ - Input: &DataSource{Name: "input1"}, + Input: &DataSource{Name: "input", HostPath: "inputpath"}, + Output: &DataDestination{HostPath: "outputpath"}, }, newTuning: &TuningSpec{ - Input: 
&DataSource{Name: "input2"}, + Input: &DataSource{Name: "input", HostPath: "randompath"}, + Output: &DataDestination{HostPath: "outputpath"}, }, expectErrs: true, - errFields: []string{"Input"}, + errFields: []string{"HostPath"}, }, { name: "Output changed", diff --git a/api/v1alpha1/zz_generated.deepcopy.go b/api/v1alpha1/zz_generated.deepcopy.go index a9d662c0f..6c3ee1eb8 100644 --- a/api/v1alpha1/zz_generated.deepcopy.go +++ b/api/v1alpha1/zz_generated.deepcopy.go @@ -249,8 +249,16 @@ func (in *Workspace) DeepCopyInto(out *Workspace) { out.TypeMeta = in.TypeMeta in.ObjectMeta.DeepCopyInto(&out.ObjectMeta) in.Resource.DeepCopyInto(&out.Resource) - in.Inference.DeepCopyInto(&out.Inference) - in.Tuning.DeepCopyInto(&out.Tuning) + if in.Inference != nil { + in, out := &in.Inference, &out.Inference + *out = new(InferenceSpec) + (*in).DeepCopyInto(*out) + } + if in.Tuning != nil { + in, out := &in.Tuning, &out.Tuning + *out = new(TuningSpec) + (*in).DeepCopyInto(*out) + } in.Status.DeepCopyInto(&out.Status) } diff --git a/charts/kaito/workspace/crds/kaito.sh_workspaces.yaml b/charts/kaito/workspace/crds/kaito.sh_workspaces.yaml index 40908f609..a4103a897 100644 --- a/charts/kaito/workspace/crds/kaito.sh_workspaces.yaml +++ b/charts/kaito/workspace/crds/kaito.sh_workspaces.yaml @@ -47,11 +47,56 @@ spec: type: string inference: properties: + adapters: + description: Adapters are integrated into the base model for inference. + Users can specify multiple adapters for the model and the respective + weight of using each of them. + items: + properties: + source: + description: Source describes where to obtain the adapter data. + properties: + hostPath: + description: The directory in the host that contains the + data. + type: string + image: + description: The name of the image that contains the source + data. The assumption is that the source data locates in + the `data` directory in the image. 
+ type: string + imagePullSecrets: + description: ImagePullSecrets is a list of secret names + in the same namespace used for pulling the data image. + items: + type: string + type: array + name: + description: The name of the dataset. The same name will + be used as a container name. It must be a valid DNS subdomain + value, + type: string + urls: + description: URLs specifies the links to the public data + sources. E.g., files in a public github repository. + items: + type: string + type: array + type: object + strength: + description: Strength specifies the default multiplier for applying + the adapter weights to the raw model weights. It is usually + a float number between 0 and 1. It is defined as a string + type to be language agnostic. + type: string + type: object + type: array preset: - description: Preset describles the model that will be deployed with - preset configurations. + description: Preset describes the base model that will be deployed + with preset configurations. properties: accessMode: + default: public description: AccessMode specifies whether the containerized model image is accessible via public registry or private registry. This field defaults to "public" if not specified. If this field @@ -72,7 +117,7 @@ spec: type: string imagePullSecrets: description: ImagePullSecrets is a list of secret names in - the same namespace used for pulling the image. + the same namespace used for pulling the model image. items: type: string type: array @@ -95,7 +140,7 @@ spec: metadata: type: object resource: - description: ResourceSpec desicribes the resource requirement of running + description: ResourceSpec describes the resource requirement of running the workload. If the number of nodes in the cluster that meet the InstanceType and LabelSelector requirements is small than the Count, controller will provision new nodes before deploying the workload. 
The final list of @@ -245,6 +290,100 @@ spec: type: string type: array type: object + tuning: + properties: + config: + description: Config specifies the name of the configmap in the same + namespace that contains the arguments used by the tuning method. + If not specified, a default configmap is used based on the specified + method. + type: string + input: + description: Input describes the input used by the tuning method. + properties: + hostPath: + description: The directory in the host that contains the data. + type: string + image: + description: The name of the image that contains the source data. + The assumption is that the source data locates in the `data` + directory in the image. + type: string + imagePullSecrets: + description: ImagePullSecrets is a list of secret names in the + same namespace used for pulling the data image. + items: + type: string + type: array + name: + description: The name of the dataset. The same name will be used + as a container name. It must be a valid DNS subdomain value, + type: string + urls: + description: URLs specifies the links to the public data sources. + E.g., files in a public github repository. + items: + type: string + type: array + type: object + method: + description: Method specifies the Parameter-Efficient Fine-Tuning(PEFT) + method, such as lora, qlora, used for the tuning. + type: string + output: + description: Output specified where to store the tuning output. + properties: + hostPath: + description: The directory in the host that contains the output + data. + type: string + image: + description: Name of the image where the output data is pushed + to. + type: string + imagePushSecret: + description: ImagePushSecret is the name of the secret in the + same namespace that contains the authentication information + that is needed for running `docker push`. + type: string + type: object + preset: + description: Preset describes which model to load for tuning. 
+ properties: + accessMode: + default: public + description: AccessMode specifies whether the containerized model + image is accessible via public registry or private registry. + This field defaults to "public" if not specified. If this field + is "private", user needs to provide the private image information + in PresetOptions. + enum: + - public + - private + type: string + name: + description: Name of the supported models with preset configurations. + type: string + presetOptions: + properties: + image: + description: Image is the name of the containerized model + image. + type: string + imagePullSecrets: + description: ImagePullSecrets is a list of secret names in + the same namespace used for pulling the model image. + items: + type: string + type: array + type: object + required: + - name + type: object + required: + - input + - output + type: object type: object served: true storage: true diff --git a/config/crd/bases/kaito.sh_workspaces.yaml b/config/crd/bases/kaito.sh_workspaces.yaml index b3af23a76..a4103a897 100644 --- a/config/crd/bases/kaito.sh_workspaces.yaml +++ b/config/crd/bases/kaito.sh_workspaces.yaml @@ -57,7 +57,7 @@ spec: description: Source describes where to obtain the adapter data. properties: hostPath: - description: The directory in the hsot that contains the + description: The directory in the host that contains the data. type: string image: @@ -96,6 +96,7 @@ spec: with preset configurations. properties: accessMode: + default: public description: AccessMode specifies whether the containerized model image is accessible via public registry or private registry. This field defaults to "public" if not specified. If this field @@ -139,7 +140,7 @@ spec: metadata: type: object resource: - description: ResourceSpec desicribes the resource requirement of running + description: ResourceSpec describes the resource requirement of running the workload. 
If the number of nodes in the cluster that meet the InstanceType and LabelSelector requirements is small than the Count, controller will provision new nodes before deploying the workload. The final list of @@ -301,7 +302,7 @@ spec: description: Input describes the input used by the tuning method. properties: hostPath: - description: The directory in the hsot that contains the data. + description: The directory in the host that contains the data. type: string image: description: The name of the image that contains the source data. @@ -350,6 +351,7 @@ spec: description: Preset describes which model to load for tuning. properties: accessMode: + default: public description: AccessMode specifies whether the containerized model image is accessible via public registry or private registry. This field defaults to "public" if not specified. If this field @@ -378,6 +380,9 @@ spec: required: - name type: object + required: + - input + - output type: object type: object served: true diff --git a/pkg/utils/testUtils.go b/pkg/utils/testUtils.go index f88b35a4f..5ef34af1d 100644 --- a/pkg/utils/testUtils.go +++ b/pkg/utils/testUtils.go @@ -35,7 +35,7 @@ var ( }, }, }, - Inference: v1alpha1.InferenceSpec{ + Inference: &v1alpha1.InferenceSpec{ Preset: &v1alpha1.PresetSpec{ PresetMeta: v1alpha1.PresetMeta{ Name: "test-distributed-model", @@ -60,7 +60,7 @@ var ( }, }, }, - Inference: v1alpha1.InferenceSpec{ + Inference: &v1alpha1.InferenceSpec{ Preset: &v1alpha1.PresetSpec{ PresetMeta: v1alpha1.PresetMeta{ Name: "test-model", @@ -85,7 +85,7 @@ var ( }, }, }, - Inference: v1alpha1.InferenceSpec{ + Inference: &v1alpha1.InferenceSpec{ Template: &corev1.PodTemplateSpec{}, }, } diff --git a/test/e2e/preset_test.go b/test/e2e/preset_test.go index eb0333df4..e8f262ef0 100644 --- a/test/e2e/preset_test.go +++ b/test/e2e/preset_test.go @@ -26,13 +26,13 @@ import ( ) const ( - PresetLlama2AChat = "llama-2-7b-chat" - PresetLlama2BChat = "llama-2-13b-chat" - PresetFalcon7BModel = "falcon-7b" - 
PresetFalcon40BModel = "falcon-40b" - PresetMistral7BModel = "mistral-7b" + PresetLlama2AChat = "llama-2-7b-chat" + PresetLlama2BChat = "llama-2-13b-chat" + PresetFalcon7BModel = "falcon-7b" + PresetFalcon40BModel = "falcon-40b" + PresetMistral7BModel = "mistral-7b" PresetMistral7BInstructModel = "mistral-7b-instruct" - PresetPhi2Model = "phi-2" + PresetPhi2Model = "phi-2" ) func createFalconWorkspaceWithPresetPublicMode(numOfNode int) *kaitov1alpha1.Workspace { @@ -348,17 +348,17 @@ var _ = Describe("Workspace Preset", func() { fmt.Print("Error: RUN_LLAMA_13B ENV Variable not set") runLlama13B = false } - + aiModelsRegistry = utils.GetEnv("AI_MODELS_REGISTRY") aiModelsRegistrySecret = utils.GetEnv("AI_MODELS_REGISTRY_SECRET") - + // Load stable model versions configs, err := utils.GetModelConfigInfo("/home/runner/work/kaito/kaito/presets/models/supported_models.yaml") if err != nil { fmt.Printf("Failed to load model configs: %v\n", err) os.Exit(1) } - + modelInfo, err = utils.ExtractModelVersion(configs) if err != nil { fmt.Printf("Failed to extract stable model versions: %v\n", err) @@ -404,7 +404,6 @@ var _ = Describe("Workspace Preset", func() { validateWorkspaceReadiness(workspaceObj) }) - It("should create a Phi-2 workspace with preset public mode successfully", func() { numOfNode := 1 workspaceObj := createPhi2WorkspaceWithPresetPublicMode(numOfNode) diff --git a/test/e2e/utils/utils.go b/test/e2e/utils/utils.go index 3914f00eb..38388374f 100644 --- a/test/e2e/utils/utils.go +++ b/test/e2e/utils/utils.go @@ -60,23 +60,23 @@ func ExtractModelVersion(configs map[string]interface{}) (map[string]string, err } for _, modelItem := range models { - model, ok := modelItem.(map[interface{}]interface{}) - if !ok { - return nil, fmt.Errorf("model item is not a map") - } + model, ok := modelItem.(map[interface{}]interface{}) + if !ok { + return nil, fmt.Errorf("model item is not a map") + } - modelName, ok := model["name"].(string) - if !ok { - return nil, 
fmt.Errorf("model name is not a string or not found") - } + modelName, ok := model["name"].(string) + if !ok { + return nil, fmt.Errorf("model name is not a string or not found") + } - modelTag, ok := model["tag"].(string) // Using 'tag' as the version - if !ok { - return nil, fmt.Errorf("model version for %s is not a string or not found", modelName) - } + modelTag, ok := model["tag"].(string) // Using 'tag' as the version + if !ok { + return nil, fmt.Errorf("model version for %s is not a string or not found", modelName) + } - modelsInfo[modelName] = modelTag - } + modelsInfo[modelName] = modelTag + } return modelsInfo, nil } @@ -117,7 +117,7 @@ func GenerateWorkspaceManifest(name, namespace, imageName string, resourceCount workspaceInference.Template = podTemplate } - workspace.Inference = workspaceInference + workspace.Inference = &workspaceInference return workspace } From 6c347c99fd4447c7bcff414207d8fad5bee67ce4 Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Tue, 19 Mar 2024 19:33:01 -0700 Subject: [PATCH 07/29] fix: Add name flag --- api/v1alpha1/workspace_validation.go | 9 +++++++-- api/v1alpha1/workspace_validation_test.go | 13 ++++++++++++- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/api/v1alpha1/workspace_validation.go b/api/v1alpha1/workspace_validation.go index b135f5886..79b27e1b9 100644 --- a/api/v1alpha1/workspace_validation.go +++ b/api/v1alpha1/workspace_validation.go @@ -41,6 +41,7 @@ func (w *Workspace) Validate(ctx context.Context) (errs *apis.FieldError) { w.Resource.validateCreate(*w.Inference).ViaField("resource"), ) if w.Inference != nil { + // TODO: Add Adapter Spec Validation - Including DataSource Validation for Adapter errs = errs.Also(w.Inference.validateCreate().ViaField("inference")) } if w.Tuning != nil { @@ -54,6 +55,7 @@ func (w *Workspace) Validate(ctx context.Context) (errs *apis.FieldError) { w.Resource.validateUpdate(&old.Resource).ViaField("resource"), ) if w.Inference != nil { + // TODO: Add Adapter 
Spec Validation - Including DataSource Validation for Adapter errs = errs.Also(w.Inference.validateUpdate(old.Inference).ViaField("inference")) } if w.Tuning != nil { @@ -112,7 +114,7 @@ func (r *TuningSpec) validateUpdate(old *TuningSpec) (errs *apis.FieldError) { if r.Input == nil { errs = errs.Also(apis.ErrMissingField("Input")) } else { - errs = errs.Also(r.Input.validateUpdate(old.Input).ViaField("Input")) + errs = errs.Also(r.Input.validateUpdate(old.Input, true).ViaField("Input")) } if r.Output == nil { errs = errs.Also(apis.ErrMissingField("Output")) @@ -150,7 +152,10 @@ func (r *DataSource) validateCreate() (errs *apis.FieldError) { return errs } -func (r *DataSource) validateUpdate(old *DataSource) (errs *apis.FieldError) { +func (r *DataSource) validateUpdate(old *DataSource, isTuning bool) (errs *apis.FieldError) { + if isTuning && !reflect.DeepEqual(old.Name, r.Name) { + errs = errs.Also(apis.ErrInvalidValue("During tuning Name field cannot be changed once set", "Name")) + } oldURLs := make([]string, len(old.URLs)) copy(oldURLs, old.URLs) sort.Strings(oldURLs) diff --git a/api/v1alpha1/workspace_validation_test.go b/api/v1alpha1/workspace_validation_test.go index d1cea034d..11631e67b 100644 --- a/api/v1alpha1/workspace_validation_test.go +++ b/api/v1alpha1/workspace_validation_test.go @@ -898,6 +898,17 @@ func TestDataSourceValidateUpdate(t *testing.T) { }, wantErr: false, }, + { + name: "Name changed", + oldSource: &DataSource{ + Name: "original-dataset", + }, + newSource: &DataSource{ + Name: "new-dataset", + }, + wantErr: true, + errFields: []string{"Name"}, + }, { name: "URLs changed", oldSource: &DataSource{ @@ -946,7 +957,7 @@ func TestDataSourceValidateUpdate(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - errs := tt.newSource.validateUpdate(tt.oldSource) + errs := tt.newSource.validateUpdate(tt.oldSource, true) hasErrs := errs != nil if hasErrs != tt.wantErr { From 1a14872eeea12cad24acd45918d9da55ecebe4ae Mon 
Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Wed, 20 Mar 2024 13:37:35 -0700 Subject: [PATCH 08/29] feat: Setup Interface for fine tuning --- api/v1alpha1/workspace_condition_types.go | 3 + api/v1alpha1/workspace_validation_test.go | 8 +- .../kaito_workspace_tuning_falcon_7b.yaml | 20 +++++ .../kaito_workspace_falcon_40b-instruct.yaml | 0 .../kaito_workspace_falcon_40b.yaml | 0 .../kaito_workspace_falcon_7b-instruct.yaml | 0 .../kaito_workspace_falcon_7b.yaml | 0 .../kaito_workspace_llama2_13b-chat.yaml | 0 .../kaito_workspace_llama2_13b.yaml | 0 .../kaito_workspace_llama2_70b-chat.yaml | 0 .../kaito_workspace_llama2_70b.yaml | 0 .../kaito_workspace_llama2_7b-chat.yaml | 0 .../kaito_workspace_llama2_7b.yaml | 0 .../kaito_workspace_mistral_7b-instruct.yaml | 0 .../kaito_workspace_mistral_7b.yaml | 0 .../kaito_workspace_phi-2.yaml | 0 pkg/controllers/workspace_controller.go | 75 +++++++++++++++++-- pkg/inference/preset-inferences.go | 10 +-- pkg/inference/preset-inferences_test.go | 2 +- pkg/model/interface.go | 7 +- pkg/tuning/preset-tuning-types.go | 21 ++++++ pkg/tuning/preset-tuning.go | 29 +++++++ pkg/utils/testModel.go | 8 +- presets/models/falcon/README.md | 8 +- presets/models/falcon/model.go | 55 +++++++++++--- presets/models/llama2/README.md | 6 +- presets/models/llama2/model.go | 21 ++++-- presets/models/llama2chat/README.md | 6 +- presets/models/llama2chat/model.go | 22 ++++-- presets/models/mistral/README.md | 4 +- presets/models/mistral/model.go | 27 ++++++- presets/models/phi/README.md | 2 +- presets/models/phi/model.go | 4 +- 33 files changed, 272 insertions(+), 66 deletions(-) create mode 100644 examples/fine-tuning/kaito_workspace_tuning_falcon_7b.yaml rename examples/{ => inference}/kaito_workspace_falcon_40b-instruct.yaml (100%) rename examples/{ => inference}/kaito_workspace_falcon_40b.yaml (100%) rename examples/{ => inference}/kaito_workspace_falcon_7b-instruct.yaml (100%) rename examples/{ => inference}/kaito_workspace_falcon_7b.yaml (100%) 
rename examples/{ => inference}/kaito_workspace_llama2_13b-chat.yaml (100%) rename examples/{ => inference}/kaito_workspace_llama2_13b.yaml (100%) rename examples/{ => inference}/kaito_workspace_llama2_70b-chat.yaml (100%) rename examples/{ => inference}/kaito_workspace_llama2_70b.yaml (100%) rename examples/{ => inference}/kaito_workspace_llama2_7b-chat.yaml (100%) rename examples/{ => inference}/kaito_workspace_llama2_7b.yaml (100%) rename examples/{ => inference}/kaito_workspace_mistral_7b-instruct.yaml (100%) rename examples/{ => inference}/kaito_workspace_mistral_7b.yaml (100%) rename examples/{ => inference}/kaito_workspace_phi-2.yaml (100%) create mode 100644 pkg/tuning/preset-tuning-types.go create mode 100644 pkg/tuning/preset-tuning.go diff --git a/api/v1alpha1/workspace_condition_types.go b/api/v1alpha1/workspace_condition_types.go index 762d8dafc..9845b8a0c 100644 --- a/api/v1alpha1/workspace_condition_types.go +++ b/api/v1alpha1/workspace_condition_types.go @@ -16,6 +16,9 @@ const ( // WorkspaceConditionTypeInferenceStatus is the state when Inference has been created. WorkspaceConditionTypeInferenceStatus = ConditionType("InferenceReady") + // WorkspaceConditionTypeTuningStatus is the state when Tuning has been created. + WorkspaceConditionTypeTuningStatus = ConditionType("TuningReady") + //WorkspaceConditionTypeDeleting is the Workspace state when starts to get deleted. 
WorkspaceConditionTypeDeleting = ConditionType("WorkspaceDeleting") diff --git a/api/v1alpha1/workspace_validation_test.go b/api/v1alpha1/workspace_validation_test.go index 11631e67b..695d42298 100644 --- a/api/v1alpha1/workspace_validation_test.go +++ b/api/v1alpha1/workspace_validation_test.go @@ -21,8 +21,8 @@ var perGPUMemoryRequirement string type testModel struct{} -func (*testModel) GetInferenceParameters() *model.PresetInferenceParam { - return &model.PresetInferenceParam{ +func (*testModel) GetInferenceParameters() *model.PresetParam { + return &model.PresetParam{ GPUCountRequirement: gpuCountRequirement, TotalGPUMemoryRequirement: totalGPUMemoryRequirement, PerGPUMemoryRequirement: perGPUMemoryRequirement, @@ -34,8 +34,8 @@ func (*testModel) SupportDistributedInference() bool { type testModelPrivate struct{} -func (*testModelPrivate) GetInferenceParameters() *model.PresetInferenceParam { - return &model.PresetInferenceParam{ +func (*testModelPrivate) GetInferenceParameters() *model.PresetParam { + return &model.PresetParam{ ImageAccessMode: "private", GPUCountRequirement: gpuCountRequirement, TotalGPUMemoryRequirement: totalGPUMemoryRequirement, diff --git a/examples/fine-tuning/kaito_workspace_tuning_falcon_7b.yaml b/examples/fine-tuning/kaito_workspace_tuning_falcon_7b.yaml new file mode 100644 index 000000000..6d6ed7831 --- /dev/null +++ b/examples/fine-tuning/kaito_workspace_tuning_falcon_7b.yaml @@ -0,0 +1,20 @@ +apiVersion: kaito.sh/v1alpha1 +kind: Workspace +metadata: + name: workspace-tuning-falcon-7b +spec: + resource: + instanceType: "Standard_NC12s_v3" + labelSelector: + matchLabels: + app: tuning-falcon-7b + tuning: + preset: + name: falcon-7b + method: lora + config: tuning-config-map # ConfigMap containing tuning arguments + input: + name: tuning-data + hostPath: /path/to/your/input/data # dataset on node + output: + hostPath: /path/to/store/output # Tuning Output diff --git a/examples/kaito_workspace_falcon_40b-instruct.yaml 
b/examples/inference/kaito_workspace_falcon_40b-instruct.yaml similarity index 100% rename from examples/kaito_workspace_falcon_40b-instruct.yaml rename to examples/inference/kaito_workspace_falcon_40b-instruct.yaml diff --git a/examples/kaito_workspace_falcon_40b.yaml b/examples/inference/kaito_workspace_falcon_40b.yaml similarity index 100% rename from examples/kaito_workspace_falcon_40b.yaml rename to examples/inference/kaito_workspace_falcon_40b.yaml diff --git a/examples/kaito_workspace_falcon_7b-instruct.yaml b/examples/inference/kaito_workspace_falcon_7b-instruct.yaml similarity index 100% rename from examples/kaito_workspace_falcon_7b-instruct.yaml rename to examples/inference/kaito_workspace_falcon_7b-instruct.yaml diff --git a/examples/kaito_workspace_falcon_7b.yaml b/examples/inference/kaito_workspace_falcon_7b.yaml similarity index 100% rename from examples/kaito_workspace_falcon_7b.yaml rename to examples/inference/kaito_workspace_falcon_7b.yaml diff --git a/examples/kaito_workspace_llama2_13b-chat.yaml b/examples/inference/kaito_workspace_llama2_13b-chat.yaml similarity index 100% rename from examples/kaito_workspace_llama2_13b-chat.yaml rename to examples/inference/kaito_workspace_llama2_13b-chat.yaml diff --git a/examples/kaito_workspace_llama2_13b.yaml b/examples/inference/kaito_workspace_llama2_13b.yaml similarity index 100% rename from examples/kaito_workspace_llama2_13b.yaml rename to examples/inference/kaito_workspace_llama2_13b.yaml diff --git a/examples/kaito_workspace_llama2_70b-chat.yaml b/examples/inference/kaito_workspace_llama2_70b-chat.yaml similarity index 100% rename from examples/kaito_workspace_llama2_70b-chat.yaml rename to examples/inference/kaito_workspace_llama2_70b-chat.yaml diff --git a/examples/kaito_workspace_llama2_70b.yaml b/examples/inference/kaito_workspace_llama2_70b.yaml similarity index 100% rename from examples/kaito_workspace_llama2_70b.yaml rename to examples/inference/kaito_workspace_llama2_70b.yaml diff --git 
a/examples/kaito_workspace_llama2_7b-chat.yaml b/examples/inference/kaito_workspace_llama2_7b-chat.yaml similarity index 100% rename from examples/kaito_workspace_llama2_7b-chat.yaml rename to examples/inference/kaito_workspace_llama2_7b-chat.yaml diff --git a/examples/kaito_workspace_llama2_7b.yaml b/examples/inference/kaito_workspace_llama2_7b.yaml similarity index 100% rename from examples/kaito_workspace_llama2_7b.yaml rename to examples/inference/kaito_workspace_llama2_7b.yaml diff --git a/examples/kaito_workspace_mistral_7b-instruct.yaml b/examples/inference/kaito_workspace_mistral_7b-instruct.yaml similarity index 100% rename from examples/kaito_workspace_mistral_7b-instruct.yaml rename to examples/inference/kaito_workspace_mistral_7b-instruct.yaml diff --git a/examples/kaito_workspace_mistral_7b.yaml b/examples/inference/kaito_workspace_mistral_7b.yaml similarity index 100% rename from examples/kaito_workspace_mistral_7b.yaml rename to examples/inference/kaito_workspace_mistral_7b.yaml diff --git a/examples/kaito_workspace_phi-2.yaml b/examples/inference/kaito_workspace_phi-2.yaml similarity index 100% rename from examples/kaito_workspace_phi-2.yaml rename to examples/inference/kaito_workspace_phi-2.yaml diff --git a/pkg/controllers/workspace_controller.go b/pkg/controllers/workspace_controller.go index a2e4fc18d..db1150ad5 100644 --- a/pkg/controllers/workspace_controller.go +++ b/pkg/controllers/workspace_controller.go @@ -5,6 +5,7 @@ package controllers import ( "context" "fmt" + "github.com/azure/kaito/pkg/tuning" "sort" "strings" "time" @@ -109,16 +110,27 @@ func (c *WorkspaceReconciler) addOrUpdateWorkspace(ctx context.Context, wObj *ka return reconcile.Result{}, err } - if err = c.applyInference(ctx, wObj); err != nil { - if updateErr := c.updateStatusConditionIfNotMatch(ctx, wObj, kaitov1alpha1.WorkspaceConditionTypeReady, metav1.ConditionFalse, - "workspaceFailed", err.Error()); updateErr != nil { - klog.ErrorS(updateErr, "failed to update 
workspace status", "workspace", klog.KObj(wObj)) - return reconcile.Result{}, updateErr + if wObj.Tuning != nil { + if err = c.applyTuning(ctx, wObj); err != nil { + if updateErr := c.updateStatusConditionIfNotMatch(ctx, wObj, kaitov1alpha1.WorkspaceConditionTypeReady, metav1.ConditionFalse, + "workspaceFailed", err.Error()); updateErr != nil { + klog.ErrorS(updateErr, "failed to update workspace status", "workspace", klog.KObj(wObj)) + return reconcile.Result{}, updateErr + } + return reconcile.Result{}, err + } + } + if wObj.Inference != nil { + if err = c.applyInference(ctx, wObj); err != nil { + if updateErr := c.updateStatusConditionIfNotMatch(ctx, wObj, kaitov1alpha1.WorkspaceConditionTypeReady, metav1.ConditionFalse, + "workspaceFailed", err.Error()); updateErr != nil { + klog.ErrorS(updateErr, "failed to update workspace status", "workspace", klog.KObj(wObj)) + return reconcile.Result{}, updateErr + } + return reconcile.Result{}, err } - return reconcile.Result{}, err } - // TODO apply TrainingSpec if err = c.updateStatusConditionIfNotMatch(ctx, wObj, kaitov1alpha1.WorkspaceConditionTypeReady, metav1.ConditionTrue, "workspaceReady", "workspace is ready"); err != nil { klog.ErrorS(err, "failed to update workspace status", "workspace", klog.KObj(wObj)) @@ -423,6 +435,55 @@ func (c *WorkspaceReconciler) ensureService(ctx context.Context, wObj *kaitov1al return nil } +func (c *WorkspaceReconciler) applyTuning(ctx context.Context, wObj *kaitov1alpha1.Workspace) error { + var err error + func() { + if wObj.Tuning.Preset != nil { + presetName := string(wObj.Tuning.Preset.Name) + model := plugin.KaitoModelRegister.MustGet(presetName) + + trainingParam := model.GetTrainingParameters() + + var existingObj client.Object + existingObj = &appsv1.Deployment{} + if err = resources.GetResource(ctx, wObj.Name, wObj.Namespace, c.Client, existingObj); err == nil { + klog.InfoS("A training workload already exists for workspace", "workspace", klog.KObj(wObj)) + if err = 
resources.CheckResourceStatus(existingObj, c.Client, trainingParam.DeploymentTimeout); err != nil { + return + } + } else if apierrors.IsNotFound(err) { + var workloadObj client.Object + // Need to create a new workload + workloadObj, err = tuning.CreatePresetTuning(ctx, wObj, trainingParam, c.Client) + if err != nil { + return + } + if err = resources.CheckResourceStatus(workloadObj, c.Client, trainingParam.DeploymentTimeout); err != nil { + return + } + } + } + }() + + if err != nil { + if updateErr := c.updateStatusConditionIfNotMatch(ctx, wObj, kaitov1alpha1.WorkspaceConditionTypeTuningStatus, metav1.ConditionFalse, + "WorkspaceTuningStatusFailed", err.Error()); updateErr != nil { + klog.ErrorS(updateErr, "failed to update workspace status", "workspace", klog.KObj(wObj)) + return updateErr + } else { + return err + + } + } + + if err := c.updateStatusConditionIfNotMatch(ctx, wObj, kaitov1alpha1.WorkspaceConditionTypeTuningStatus, metav1.ConditionTrue, + "WorkspaceTuningStatusSuccess", "Tuning has been deployed successfully"); err != nil { + klog.ErrorS(err, "failed to update workspace status", "workspace", klog.KObj(wObj)) + return err + } + return nil +} + // applyInference applies inference spec. 
func (c *WorkspaceReconciler) applyInference(ctx context.Context, wObj *kaitov1alpha1.Workspace) error { var err error diff --git a/pkg/inference/preset-inferences.go b/pkg/inference/preset-inferences.go index 9b02012b7..4c4792b54 100644 --- a/pkg/inference/preset-inferences.go +++ b/pkg/inference/preset-inferences.go @@ -67,7 +67,7 @@ var ( } ) -func updateTorchParamsForDistributedInference(ctx context.Context, kubeClient client.Client, wObj *kaitov1alpha1.Workspace, inferenceObj *model.PresetInferenceParam) error { +func updateTorchParamsForDistributedInference(ctx context.Context, kubeClient client.Client, wObj *kaitov1alpha1.Workspace, inferenceObj *model.PresetParam) error { existingService := &corev1.Service{} err := resources.GetResource(ctx, wObj.Name, wObj.Namespace, kubeClient, existingService) if err != nil { @@ -92,7 +92,7 @@ func updateTorchParamsForDistributedInference(ctx context.Context, kubeClient cl return nil } -func GetImageInfo(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace, inferenceObj *model.PresetInferenceParam) (string, []corev1.LocalObjectReference) { +func GetImageInfo(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace, inferenceObj *model.PresetParam) (string, []corev1.LocalObjectReference) { imageName := string(workspaceObj.Inference.Preset.Name) imageTag := inferenceObj.Tag imagePullSecretRefs := []corev1.LocalObjectReference{} @@ -110,7 +110,7 @@ func GetImageInfo(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace, in } func CreatePresetInference(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace, - inferenceObj *model.PresetInferenceParam, supportDistributedInference bool, kubeClient client.Client) (client.Object, error) { + inferenceObj *model.PresetParam, supportDistributedInference bool, kubeClient client.Client) (client.Object, error) { if inferenceObj.TorchRunParams != nil && supportDistributedInference { if err := updateTorchParamsForDistributedInference(ctx, kubeClient, workspaceObj, 
inferenceObj); err != nil { klog.ErrorS(err, "failed to update torch params", "workspace", workspaceObj) @@ -141,7 +141,7 @@ func CreatePresetInference(ctx context.Context, workspaceObj *kaitov1alpha1.Work // torchrun baseCommand // and sets the GPU resources required for inference. // Returns the command and resource configuration. -func prepareInferenceParameters(ctx context.Context, inferenceObj *model.PresetInferenceParam) ([]string, corev1.ResourceRequirements) { +func prepareInferenceParameters(ctx context.Context, inferenceObj *model.PresetParam) ([]string, corev1.ResourceRequirements) { torchCommand := buildCommandStr(inferenceObj.BaseCommand, inferenceObj.TorchRunParams) torchCommand = buildCommandStr(torchCommand, inferenceObj.TorchRunRdzvParams) modelCommand := buildCommandStr(InferenceFile, inferenceObj.ModelRunParams) @@ -159,7 +159,7 @@ func prepareInferenceParameters(ctx context.Context, inferenceObj *model.PresetI return commands, resourceRequirements } -func configVolume(wObj *kaitov1alpha1.Workspace, inferenceObj *model.PresetInferenceParam) ([]corev1.Volume, []corev1.VolumeMount) { +func configVolume(wObj *kaitov1alpha1.Workspace, inferenceObj *model.PresetParam) ([]corev1.Volume, []corev1.VolumeMount) { volume := []corev1.Volume{} volumeMount := []corev1.VolumeMount{} diff --git a/pkg/inference/preset-inferences_test.go b/pkg/inference/preset-inferences_test.go index cd8df067c..31bf0551e 100644 --- a/pkg/inference/preset-inferences_test.go +++ b/pkg/inference/preset-inferences_test.go @@ -62,7 +62,7 @@ func TestCreatePresetInference(t *testing.T) { useHeadlessSvc := false - var inferenceObj *model.PresetInferenceParam + var inferenceObj *model.PresetParam model := plugin.KaitoModelRegister.MustGet(tc.modelName) inferenceObj = model.GetInferenceParameters() diff --git a/pkg/model/interface.go b/pkg/model/interface.go index 217c1f889..2763b5eec 100644 --- a/pkg/model/interface.go +++ b/pkg/model/interface.go @@ -7,12 +7,13 @@ import ( ) type Model 
interface { - GetInferenceParameters() *PresetInferenceParam + GetInferenceParameters() *PresetParam + GetTrainingParameters() *PresetParam SupportDistributedInference() bool //If true, the model workload will be a StatefulSet, using the torch elastic runtime framework. } -// PresetInferenceParam defines the preset inference parameters for a model. -type PresetInferenceParam struct { +// PresetParam defines the preset inference parameters for a model. +type PresetParam struct { ModelFamilyName string // The name of the model family. ImageAccessMode string // Defines where the Image is Public or Private. DiskStorageRequirement string // Disk storage requirements for the model. diff --git a/pkg/tuning/preset-tuning-types.go b/pkg/tuning/preset-tuning-types.go new file mode 100644 index 000000000..51f36511d --- /dev/null +++ b/pkg/tuning/preset-tuning-types.go @@ -0,0 +1,21 @@ +package tuning + +import corev1 "k8s.io/api/core/v1" + +const ( + DefaultNumProcesses = "1" + DefaultNumMachines = "1" + DefaultMachineRank = "0" + DefaultGPUIds = "all" +) + +var ( + DefaultAccelerateParams = map[string]string{ + "num_processes": DefaultNumProcesses, + "num_machines": DefaultNumMachines, + "machine_rank": DefaultMachineRank, + "gpu_ids": DefaultGPUIds, + } + + DefaultImagePullSecrets = []corev1.LocalObjectReference{} +) diff --git a/pkg/tuning/preset-tuning.go b/pkg/tuning/preset-tuning.go new file mode 100644 index 000000000..cbbb55a06 --- /dev/null +++ b/pkg/tuning/preset-tuning.go @@ -0,0 +1,29 @@ +package tuning + +import ( + "context" + kaitov1alpha1 "github.com/azure/kaito/api/v1alpha1" + "github.com/azure/kaito/pkg/model" + "sigs.k8s.io/controller-runtime/pkg/client" +) + +func CreatePresetTuning(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace, + tuningObj *model.PresetParam, kubeClient client.Client) (client.Object, error) { + // TODO + + // e.g. 
example from Inference + //volume, volumeMount := configVolume(workspaceObj, inferenceObj) + //commands, resourceReq := prepareInferenceParameters(ctx, inferenceObj) + //image, imagePullSecrets := GetImageInfo(ctx, workspaceObj, inferenceObj) + // + //depObj = resources.GenerateDeploymentManifest(ctx, workspaceObj, image, imagePullSecrets, *workspaceObj.Resource.Count, commands, + // containerPorts, livenessProbe, readinessProbe, resourceReq, tolerations, volume, volumeMount) + // + //err := resources.CreateResource(ctx, depObj, kubeClient) + //if client.IgnoreAlreadyExists(err) != nil { + // return nil, err + //} + //return depObj, nil + + return nil, nil +} diff --git a/pkg/utils/testModel.go b/pkg/utils/testModel.go index 99e3d8aca..50f6c9175 100644 --- a/pkg/utils/testModel.go +++ b/pkg/utils/testModel.go @@ -12,8 +12,8 @@ import ( type testModel struct{} -func (*testModel) GetInferenceParameters() *model.PresetInferenceParam { - return &model.PresetInferenceParam{ +func (*testModel) GetInferenceParameters() *model.PresetParam { + return &model.PresetParam{ GPUCountRequirement: "1", DeploymentTimeout: time.Duration(30) * time.Minute, } @@ -24,8 +24,8 @@ func (*testModel) SupportDistributedInference() bool { type testDistributedModel struct{} -func (*testDistributedModel) GetInferenceParameters() *model.PresetInferenceParam { - return &model.PresetInferenceParam{ +func (*testDistributedModel) GetInferenceParameters() *model.PresetParam { + return &model.PresetParam{ GPUCountRequirement: "1", DeploymentTimeout: time.Duration(30) * time.Minute, } diff --git a/presets/models/falcon/README.md b/presets/models/falcon/README.md index 81a1ced6f..e8cd895e6 100644 --- a/presets/models/falcon/README.md +++ b/presets/models/falcon/README.md @@ -1,10 +1,10 @@ ## Supported Models |Model name| Model source | Sample workspace|Kubernetes Workload|Distributed inference| |----|:----:|:----:| :----: |:----: | -|falcon-7b-instruct 
|[tiiuae](https://huggingface.co/tiiuae/falcon-7b-instruct)|[link](../../../examples/kaito_workspace_falcon_7b-instruct.yaml)|Deployment| false| -|falcon-7b |[tiiuae](https://huggingface.co/tiiuae/falcon-7b) |[link](../../../examples/kaito_workspace_falcon_7b.yaml)|Deployment| false| -|falcon-40b-instruct|[tiiuae](https://huggingface.co/tiiuae/falcon-40b-instruct) |[link](../../../examples/kaito_workspace_falcon_40b-instruct.yaml)|Deployment| false| -|falcon-40b |[tiiuae](https://huggingface.co/tiiuae/falcon-40b)|[link](../../../examples/kaito_workspace_falcon_40b.yaml)|Deployment| false| +|falcon-7b-instruct |[tiiuae](https://huggingface.co/tiiuae/falcon-7b-instruct)|[link](../../../examples/inference/kaito_workspace_falcon_7b-instruct.yaml)|Deployment| false| +|falcon-7b |[tiiuae](https://huggingface.co/tiiuae/falcon-7b) |[link](../../../examples/inference/kaito_workspace_falcon_7b.yaml)|Deployment| false| +|falcon-40b-instruct|[tiiuae](https://huggingface.co/tiiuae/falcon-40b-instruct) |[link](../../../examples/inference/kaito_workspace_falcon_40b-instruct.yaml)|Deployment| false| +|falcon-40b |[tiiuae](https://huggingface.co/tiiuae/falcon-40b)|[link](../../../examples/inference/kaito_workspace_falcon_40b.yaml)|Deployment| false| ## Image Source - **Public**: Kaito maintainers manage the lifecycle of the inference service images that contain model weights. The images are available in Microsoft Container Registry (MCR). 
diff --git a/presets/models/falcon/model.go b/presets/models/falcon/model.go index 7501dce23..c0f07495f 100644 --- a/presets/models/falcon/model.go +++ b/presets/models/falcon/model.go @@ -54,8 +54,8 @@ var falconA falcon7b type falcon7b struct{} -func (*falcon7b) GetInferenceParameters() *model.PresetInferenceParam { - return &model.PresetInferenceParam{ +func (*falcon7b) GetInferenceParameters() *model.PresetParam { + return &model.PresetParam{ ModelFamilyName: "Falcon", ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePublic), DiskStorageRequirement: "50Gi", @@ -68,8 +68,23 @@ func (*falcon7b) GetInferenceParameters() *model.PresetInferenceParam { BaseCommand: baseCommandPresetFalcon, Tag: PresetFalconTagMap["Falcon7B"], } - } +func (*falcon7b) GetTrainingParameters() *model.PresetParam { + return &model.PresetParam{ + ModelFamilyName: "Falcon", + ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePublic), + DiskStorageRequirement: "50Gi", + GPUCountRequirement: "2", + TotalGPUMemoryRequirement: "16Gi", + PerGPUMemoryRequirement: "16Gi", + //TorchRunParams: tuning.DefaultAccelerateParams, // TODO + //ModelRunParams: falconRunTuningParams, // TODO + DeploymentTimeout: time.Duration(30) * time.Minute, + BaseCommand: baseCommandPresetFalcon, + Tag: PresetFalconTagMap["Falcon7B"], + } +} + func (*falcon7b) SupportDistributedInference() bool { return false } @@ -78,8 +93,8 @@ var falconB falcon7bInst type falcon7bInst struct{} -func (*falcon7bInst) GetInferenceParameters() *model.PresetInferenceParam { - return &model.PresetInferenceParam{ +func (*falcon7bInst) GetInferenceParameters() *model.PresetParam { + return &model.PresetParam{ ModelFamilyName: "Falcon", ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePublic), DiskStorageRequirement: "50Gi", @@ -94,6 +109,9 @@ func (*falcon7bInst) GetInferenceParameters() *model.PresetInferenceParam { } } +func (*falcon7bInst) GetTrainingParameters() *model.PresetParam { + return nil // It is not
recommended/ideal to further fine-tune instruct models - Already been fine-tuned +} func (*falcon7bInst) SupportDistributedInference() bool { return false } @@ -102,8 +120,8 @@ var falconC falcon40b type falcon40b struct{} -func (*falcon40b) GetInferenceParameters() *model.PresetInferenceParam { - return &model.PresetInferenceParam{ +func (*falcon40b) GetInferenceParameters() *model.PresetParam { + return &model.PresetParam{ ModelFamilyName: "Falcon", ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePublic), DiskStorageRequirement: "400", @@ -118,6 +136,21 @@ func (*falcon40b) GetInferenceParameters() *model.PresetInferenceParam { } } +func (*falcon40b) GetTrainingParameters() *model.PresetParam { + return &model.PresetParam{ + ModelFamilyName: "Falcon", + ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePublic), + DiskStorageRequirement: "50Gi", + GPUCountRequirement: "2", + TotalGPUMemoryRequirement: "90Gi", + PerGPUMemoryRequirement: "16Gi", + //TorchRunParams: tuning.DefaultAccelerateParams, // TODO + //ModelRunParams: falconRunTuningParams, // TODO + DeploymentTimeout: time.Duration(30) * time.Minute, + BaseCommand: baseCommandPresetFalcon, + Tag: PresetFalconTagMap["Falcon40B"], + } +} func (*falcon40b) SupportDistributedInference() bool { return false } @@ -126,8 +159,8 @@ var falconD falcon40bInst type falcon40bInst struct{} -func (*falcon40bInst) GetInferenceParameters() *model.PresetInferenceParam { - return &model.PresetInferenceParam{ +func (*falcon40bInst) GetInferenceParameters() *model.PresetParam { + return &model.PresetParam{ ModelFamilyName: "Falcon", ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePublic), DiskStorageRequirement: "400", @@ -141,7 +174,9 @@ func (*falcon40bInst) GetInferenceParameters() *model.PresetInferenceParam { Tag: PresetFalconTagMap["Falcon40BInstruct"], } } - +func (*falcon40bInst) GetTrainingParameters() *model.PresetParam { + return nil // It is not recommended/ideal to further fine-tune
instruct models - Already been fine-tuned +} func (*falcon40bInst) SupportDistributedInference() bool { return false } diff --git a/presets/models/llama2/README.md b/presets/models/llama2/README.md index e6a40563a..ba2646a2b 100644 --- a/presets/models/llama2/README.md +++ b/presets/models/llama2/README.md @@ -1,9 +1,9 @@ ## Supported Models |Model name| Model source | Sample workspace|Kubernetes Workload|Distributed inference| |----|:----:|:----:| :----: |:----: | -|llama2-7b |[meta](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md)|[link](../../../examples/kaito_workspace_llama2_7b.yaml)|Deployment| false| -|llama2-13b|[meta](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md)|[link](../../../examples/kaito_workspace_llama2_13b.yaml)|StatefulSet| true| -|llama2-70b|[meta](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md)|[link](../../../examples/kaito_workspace_llama2_70b.yaml)|StatefulSet| true| +|llama2-7b |[meta](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md)|[link](../../../examples/inference/kaito_workspace_llama2_7b.yaml)|Deployment| false| +|llama2-13b|[meta](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md)|[link](../../../examples/inference/kaito_workspace_llama2_13b.yaml)|StatefulSet| true| +|llama2-70b|[meta](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md)|[link](../../../examples/inference/kaito_workspace_llama2_70b.yaml)|StatefulSet| true| ## Image Source - **Private**: User needs to manage the lifecycle of the inference service images that contain model weights (e.g., managing image tags). The images are available in user's private container registry. 
diff --git a/presets/models/llama2/model.go b/presets/models/llama2/model.go index 30c97b7fd..673a67da3 100644 --- a/presets/models/llama2/model.go +++ b/presets/models/llama2/model.go @@ -38,8 +38,8 @@ var llama2A llama2Text7b type llama2Text7b struct{} -func (*llama2Text7b) GetInferenceParameters() *model.PresetInferenceParam { - return &model.PresetInferenceParam{ +func (*llama2Text7b) GetInferenceParameters() *model.PresetParam { + return &model.PresetParam{ ModelFamilyName: "LLaMa2", ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePrivate), DiskStorageRequirement: "34Gi", @@ -56,6 +56,9 @@ func (*llama2Text7b) GetInferenceParameters() *model.PresetInferenceParam { } } +func (*llama2Text7b) GetTrainingParameters() *model.PresetParam { + return nil // Currently doesn't support fine-tuning +} func (*llama2Text7b) SupportDistributedInference() bool { return false } @@ -64,8 +67,8 @@ var llama2B llama2Text13b type llama2Text13b struct{} -func (*llama2Text13b) GetInferenceParameters() *model.PresetInferenceParam { - return &model.PresetInferenceParam{ +func (*llama2Text13b) GetInferenceParameters() *model.PresetParam { + return &model.PresetParam{ ModelFamilyName: "LLaMa2", ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePrivate), DiskStorageRequirement: "46Gi", @@ -81,6 +84,9 @@ func (*llama2Text13b) GetInferenceParameters() *model.PresetInferenceParam { // Tag: llama has private image access mode. The image tag is determined by the user. 
} } +func (*llama2Text13b) GetTrainingParameters() *model.PresetParam { + return nil // Currently doesn't support fine-tuning +} func (*llama2Text13b) SupportDistributedInference() bool { return true } @@ -89,8 +95,8 @@ var llama2C llama2Text70b type llama2Text70b struct{} -func (*llama2Text70b) GetInferenceParameters() *model.PresetInferenceParam { - return &model.PresetInferenceParam{ +func (*llama2Text70b) GetInferenceParameters() *model.PresetParam { + return &model.PresetParam{ ModelFamilyName: "LLaMa2", ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePrivate), DiskStorageRequirement: "158Gi", @@ -106,6 +112,9 @@ func (*llama2Text70b) GetInferenceParameters() *model.PresetInferenceParam { // Tag: llama has private image access mode. The image tag is determined by the user. } } +func (*llama2Text70b) GetTrainingParameters() *model.PresetParam { + return nil // Currently doesn't support fine-tuning +} func (*llama2Text70b) SupportDistributedInference() bool { return true } diff --git a/presets/models/llama2chat/README.md b/presets/models/llama2chat/README.md index 53e241fab..0cf9ec3be 100644 --- a/presets/models/llama2chat/README.md +++ b/presets/models/llama2chat/README.md @@ -1,9 +1,9 @@ ## Supported Models |Model name| Model source | Sample workspace|Kubernetes Workload|Distributed inference| |----|:----:|:----:| :----: |:----: | -|llama2-7b-chat |[meta](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md)|[link](../../../examples/kaito_workspace_llama2_7b-chat.yaml)|Deployment| false| -|llama2-13b-chat|[meta](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md)|[link](../../../examples/kaito_workspace_llama2_13b-chat.yaml)|StatefulSet| true| -|llama2-70b-chat|[meta](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md)|[link](../../../examples/kaito_workspace_llama2_70b-chat.yaml)|StatefulSet| true| +|llama2-7b-chat 
|[meta](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md)|[link](../../../examples/inference/kaito_workspace_llama2_7b-chat.yaml)|Deployment| false| +|llama2-13b-chat|[meta](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md)|[link](../../../examples/inference/kaito_workspace_llama2_13b-chat.yaml)|StatefulSet| true| +|llama2-70b-chat|[meta](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md)|[link](../../../examples/inference/kaito_workspace_llama2_70b-chat.yaml)|StatefulSet| true| ## Image Source - **Private**: User needs to manage the lifecycle of the inference service images that contain model weights (e.g., managing image tags). The images are available in user's private container registry. diff --git a/presets/models/llama2chat/model.go b/presets/models/llama2chat/model.go index cc0d8d4c6..a555ebc07 100644 --- a/presets/models/llama2chat/model.go +++ b/presets/models/llama2chat/model.go @@ -38,8 +38,8 @@ var llama2chatA llama2Chat7b type llama2Chat7b struct{} -func (*llama2Chat7b) GetInferenceParameters() *model.PresetInferenceParam { - return &model.PresetInferenceParam{ +func (*llama2Chat7b) GetInferenceParameters() *model.PresetParam { + return &model.PresetParam{ ModelFamilyName: "LLaMa2", ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePrivate), DiskStorageRequirement: "34Gi", @@ -54,7 +54,9 @@ func (*llama2Chat7b) GetInferenceParameters() *model.PresetInferenceParam { WorldSize: 1, // Tag: llama has private image access mode. The image tag is determined by the user. 
} - +} +func (*llama2Chat7b) GetTrainingParameters() *model.PresetParam { + return nil // Currently doesn't support fine-tuning } func (*llama2Chat7b) SupportDistributedInference() bool { return false @@ -64,8 +66,8 @@ var llama2chatB llama2Chat13b type llama2Chat13b struct{} -func (*llama2Chat13b) GetInferenceParameters() *model.PresetInferenceParam { - return &model.PresetInferenceParam{ +func (*llama2Chat13b) GetInferenceParameters() *model.PresetParam { + return &model.PresetParam{ ModelFamilyName: "LLaMa2", ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePrivate), DiskStorageRequirement: "46Gi", @@ -81,6 +83,9 @@ func (*llama2Chat13b) GetInferenceParameters() *model.PresetInferenceParam { // Tag: llama has private image access mode. The image tag is determined by the user. } } +func (*llama2Chat13b) GetTrainingParameters() *model.PresetParam { + return nil // Currently doesn't support fine-tuning +} func (*llama2Chat13b) SupportDistributedInference() bool { return true } @@ -89,8 +94,8 @@ var llama2chatC llama2Chat70b type llama2Chat70b struct{} -func (*llama2Chat70b) GetInferenceParameters() *model.PresetInferenceParam { - return &model.PresetInferenceParam{ +func (*llama2Chat70b) GetInferenceParameters() *model.PresetParam { + return &model.PresetParam{ ModelFamilyName: "LLaMa2", ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePrivate), DiskStorageRequirement: "158Gi", @@ -106,6 +111,9 @@ func (*llama2Chat70b) GetInferenceParameters() *model.PresetInferenceParam { // Tag: llama has private image access mode. The image tag is determined by the user. 
} } +func (*llama2Chat70b) GetTrainingParameters() *model.PresetParam { + return nil // Currently doesn't support fine-tuning +} func (*llama2Chat70b) SupportDistributedInference() bool { return true } diff --git a/presets/models/mistral/README.md b/presets/models/mistral/README.md index 4d0c56ba6..2d037f7a4 100644 --- a/presets/models/mistral/README.md +++ b/presets/models/mistral/README.md @@ -1,8 +1,8 @@ ## Supported Models |Model name| Model source | Sample workspace|Kubernetes Workload|Distributed inference| |----|:----:|:----:| :----: |:----: | -|mistral-7b-instruct |[mistralai](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2)|[link](../../../examples/kaito_workspace_mistral_7b-instruct.yaml)|Deployment| false| -|mistral-7b |[mistralai](https://huggingface.co/mistralai/Mistral-7B-v0.1) |[link](../../../examples/kaito_workspace_mistral_7b.yaml)|Deployment| false| +|mistral-7b-instruct |[mistralai](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2)|[link](../../../examples/inference/kaito_workspace_mistral_7b-instruct.yaml)|Deployment| false| +|mistral-7b |[mistralai](https://huggingface.co/mistralai/Mistral-7B-v0.1) |[link](../../../examples/inference/kaito_workspace_mistral_7b.yaml)|Deployment| false| ## Image Source diff --git a/presets/models/mistral/model.go b/presets/models/mistral/model.go index 7089eafb6..bcf06203b 100644 --- a/presets/models/mistral/model.go +++ b/presets/models/mistral/model.go @@ -42,8 +42,8 @@ var mistralA mistral7b type mistral7b struct{} -func (*mistral7b) GetInferenceParameters() *model.PresetInferenceParam { - return &model.PresetInferenceParam{ +func (*mistral7b) GetInferenceParameters() *model.PresetParam { + return &model.PresetParam{ ModelFamilyName: "Mistral", ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePublic), DiskStorageRequirement: "100Gi", @@ -58,6 +58,22 @@ func (*mistral7b) GetInferenceParameters() *model.PresetInferenceParam { } } +func (*mistral7b) GetTrainingParameters() 
*model.PresetParam { + return &model.PresetParam{ + ModelFamilyName: "Mistral", + ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePublic), + DiskStorageRequirement: "100Gi", + GPUCountRequirement: "1", + TotalGPUMemoryRequirement: "16Gi", + PerGPUMemoryRequirement: "16Gi", // We run Mistral using native vertical model parallel, no per GPU memory requirement. + //TorchRunParams: tuning.DefaultAccelerateParams, + //ModelRunParams: mistralRunParams, + DeploymentTimeout: time.Duration(30) * time.Minute, + BaseCommand: baseCommandPresetMistral, + Tag: PresetMistralTagMap["Mistral7B"], + } +} + func (*mistral7b) SupportDistributedInference() bool { return false } @@ -66,8 +82,8 @@ var mistralB mistral7bInst type mistral7bInst struct{} -func (*mistral7bInst) GetInferenceParameters() *model.PresetInferenceParam { - return &model.PresetInferenceParam{ +func (*mistral7bInst) GetInferenceParameters() *model.PresetParam { + return &model.PresetParam{ ModelFamilyName: "Mistral", ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePublic), DiskStorageRequirement: "100Gi", @@ -82,6 +98,9 @@ func (*mistral7bInst) GetInferenceParameters() *model.PresetInferenceParam { } } +func (*mistral7bInst) GetTrainingParameters() *model.PresetParam { + return nil // It is not recommended/ideal to further fine-tune instruct models - Already been fine-tuned +} func (*mistral7bInst) SupportDistributedInference() bool { return false } diff --git a/presets/models/phi/README.md b/presets/models/phi/README.md index 7caeadb84..1e77252a5 100644 --- a/presets/models/phi/README.md +++ b/presets/models/phi/README.md @@ -1,7 +1,7 @@ ## Supported Models |Model name| Model source | Sample workspace|Kubernetes Workload|Distributed inference| |----|:----:|:----:| :----: |:----: | -|phi-2 |[microsoft](https://huggingface.co/microsoft/phi-2)|[link](../../../examples/kaito_workspace_phi-2.yaml)|Deployment| false| +|phi-2 
|[microsoft](https://huggingface.co/microsoft/phi-2)|[link](../../../examples/inference/kaito_workspace_phi-2.yaml)|Deployment| false| ## Image Source diff --git a/presets/models/phi/model.go b/presets/models/phi/model.go index 2e54dce38..37d92e673 100644 --- a/presets/models/phi/model.go +++ b/presets/models/phi/model.go @@ -36,8 +36,8 @@ var phiA phi2 type phi2 struct{} -func (*phi2) GetInferenceParameters() *model.PresetInferenceParam { - return &model.PresetInferenceParam{ +func (*phi2) GetInferenceParameters() *model.PresetParam { + return &model.PresetParam{ ModelFamilyName: "Phi", ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePublic), DiskStorageRequirement: "50Gi", From b21b99f59a583f9265d898e65d5e08aee080b1ba Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Mon, 18 Mar 2024 10:37:08 -0700 Subject: [PATCH 09/29] feat: spec level validation --- api/v1alpha1/workspace_types.go | 4 ++-- api/v1alpha1/workspace_validation.go | 28 +++++++++++++++++++++++++++- 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/api/v1alpha1/workspace_types.go b/api/v1alpha1/workspace_types.go index 2f966f647..e451e1066 100644 --- a/api/v1alpha1/workspace_types.go +++ b/api/v1alpha1/workspace_types.go @@ -13,7 +13,7 @@ const ( ModelImageAccessModePrivate ModelImageAccessMode = "private" ) -// ResourceSpec desicribes the resource requirement of running the workload. +// ResourceSpec describes the resource requirement of running the workload. // If the number of nodes in the cluster that meet the InstanceType and // LabelSelector requirements is small than the Count, controller // will provision new nodes before deploying the workload. @@ -51,7 +51,7 @@ type PresetMeta struct { // AccessMode specifies whether the containerized model image is accessible via public registry // or private registry. This field defaults to "public" if not specified. // If this field is "private", user needs to provide the private image information in PresetOptions. 
- // +bebuilder:default:="public" + // +kubebuilder:default:="public" // +optional AccessMode ModelImageAccessMode `json:"accessMode,omitempty"` } diff --git a/api/v1alpha1/workspace_validation.go b/api/v1alpha1/workspace_validation.go index 16576f684..732526447 100644 --- a/api/v1alpha1/workspace_validation.go +++ b/api/v1alpha1/workspace_validation.go @@ -35,6 +35,7 @@ func (w *Workspace) Validate(ctx context.Context) (errs *apis.FieldError) { if base == nil { klog.InfoS("Validate creation", "workspace", fmt.Sprintf("%s/%s", w.Namespace, w.Name)) errs = errs.Also( + w.validateCreate().ViaField("spec"), w.Inference.validateCreate().ViaField("inference"), w.Resource.validateCreate(w.Inference).ViaField("resource"), ) @@ -42,6 +43,7 @@ func (w *Workspace) Validate(ctx context.Context) (errs *apis.FieldError) { klog.InfoS("Validate update", "workspace", fmt.Sprintf("%s/%s", w.Namespace, w.Name)) old := base.(*Workspace) errs = errs.Also( + w.validateUpdate(old).ViaField("spec"), w.Resource.validateUpdate(&old.Resource).ViaField("resource"), w.Inference.validateUpdate(&old.Inference).ViaField("inference"), ) @@ -49,6 +51,15 @@ func (w *Workspace) Validate(ctx context.Context) (errs *apis.FieldError) { return errs } +func (w *Workspace) validateCreate() (errs *apis.FieldError) { + inferenceSpecified := w.Inference.Preset != nil || w.Inference.Template != nil + tuningSpecified := w.Tuning.Input != nil + if inferenceSpecified != tuningSpecified { + return errs.Also(apis.ErrGeneric("Either Inference or Tuning must be specified, but not both", "")) + } + return errs +} + func (r *ResourceSpec) validateCreate(inference InferenceSpec) (errs *apis.FieldError) { var presetName string if inference.Preset != nil { @@ -96,6 +107,21 @@ func (r *ResourceSpec) validateCreate(inference InferenceSpec) (errs *apis.Field return errs } +func (w *Workspace) validateUpdate(old *Workspace) (errs *apis.FieldError) { + // Check inference specified + oldInferenceSpecified := 
old.Inference.Preset != nil || old.Inference.Template != nil + inferenceSpecified := w.Inference.Preset != nil || w.Inference.Template != nil + // Check tuning specified + oldTuningSpecified := old.Tuning.Input != nil + tuningSpecified := w.Tuning.Input != nil + + // inference/tuning can be changed, but cannot be set/unset. + if (!oldInferenceSpecified && inferenceSpecified) || (!oldTuningSpecified && tuningSpecified) { + errs = errs.Also(apis.ErrGeneric("field cannot be unset/set if it was set/unset", "spec")) + } + return errs +} + func (r *ResourceSpec) validateUpdate(old *ResourceSpec) (errs *apis.FieldError) { // We disable changing node count for now. if r.Count != nil && old.Count != nil && *r.Count != *old.Count { @@ -151,7 +177,7 @@ func (i *InferenceSpec) validateUpdate(old *InferenceSpec) (errs *apis.FieldErro if !reflect.DeepEqual(i.Preset, old.Preset) { errs = errs.Also(apis.ErrGeneric("field is immutable", "preset")) } - //inference.template can be changed, but cannot be unset. + // inference.template can be changed, but cannot be set/unset. if (i.Template != nil && old.Template == nil) || (i.Template == nil && old.Template != nil) { errs = errs.Also(apis.ErrGeneric("field cannot be unset/set if it was set/unset", "template")) } From 75cacae1e1e8ac2d3ab822c4109d294305360c3f Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Mon, 18 Mar 2024 15:01:25 -0700 Subject: [PATCH 10/29] feat: Added validation checks for TuningSpec, DataSource, DataDestination --- api/v1alpha1/workspace_types.go | 2 +- api/v1alpha1/workspace_validation.go | 138 +++++++++++++++++++++++---- 2 files changed, 122 insertions(+), 18 deletions(-) diff --git a/api/v1alpha1/workspace_types.go b/api/v1alpha1/workspace_types.go index e451e1066..71e9f829c 100644 --- a/api/v1alpha1/workspace_types.go +++ b/api/v1alpha1/workspace_types.go @@ -106,7 +106,7 @@ type DataSource struct { // URLs specifies the links to the public data sources. E.g., files in a public github repository. 
// +optional URLs []string `json:"urls,omitempty"` - // The directory in the hsot that contains the data. + // The directory in the host that contains the data. // +optional HostPath string `json:"hostPath,omitempty"` // The name of the image that contains the source data. The assumption is that the source data locates in the diff --git a/api/v1alpha1/workspace_validation.go b/api/v1alpha1/workspace_validation.go index 732526447..81d9353b4 100644 --- a/api/v1alpha1/workspace_validation.go +++ b/api/v1alpha1/workspace_validation.go @@ -37,6 +37,9 @@ func (w *Workspace) Validate(ctx context.Context) (errs *apis.FieldError) { errs = errs.Also( w.validateCreate().ViaField("spec"), w.Inference.validateCreate().ViaField("inference"), + w.Tuning.validateCreate().ViaField("tuning"), + w.Tuning.Input.validateCreate().ViaField("input"), + w.Tuning.Output.validateCreate().ViaField("output"), w.Resource.validateCreate(w.Inference).ViaField("resource"), ) } else { @@ -44,8 +47,11 @@ func (w *Workspace) Validate(ctx context.Context) (errs *apis.FieldError) { old := base.(*Workspace) errs = errs.Also( w.validateUpdate(old).ViaField("spec"), - w.Resource.validateUpdate(&old.Resource).ViaField("resource"), w.Inference.validateUpdate(&old.Inference).ViaField("inference"), + w.Tuning.validateUpdate(&old.Tuning).ViaField("tuning"), + w.Tuning.Input.validateUpdate(old.Tuning.Input).ViaField("input"), + w.Tuning.Output.validateUpdate(old.Tuning.Output).ViaField("output"), + w.Resource.validateUpdate(&old.Resource).ViaField("resource"), ) } return errs @@ -55,11 +61,124 @@ func (w *Workspace) validateCreate() (errs *apis.FieldError) { inferenceSpecified := w.Inference.Preset != nil || w.Inference.Template != nil tuningSpecified := w.Tuning.Input != nil if inferenceSpecified != tuningSpecified { - return errs.Also(apis.ErrGeneric("Either Inference or Tuning must be specified, but not both", "")) + errs = errs.Also(apis.ErrGeneric("Either Inference or Tuning must be specified, but not 
both", "")) } return errs } +func (w *Workspace) validateUpdate(old *Workspace) (errs *apis.FieldError) { + // Check inference specified + oldInferenceSpecified := old.Inference.Preset != nil || old.Inference.Template != nil + inferenceSpecified := w.Inference.Preset != nil || w.Inference.Template != nil + // Check tuning specified + oldTuningSpecified := old.Tuning.Input != nil + tuningSpecified := w.Tuning.Input != nil + + // inference/tuning can be changed, but cannot be set/unset. + if (!oldInferenceSpecified && inferenceSpecified) || (!oldTuningSpecified && tuningSpecified) { + errs = errs.Also(apis.ErrGeneric("field cannot be unset/set if it was set/unset", "spec")) + } + return errs +} + +func (r *TuningSpec) validateCreate() (errs *apis.FieldError) { + if r.Input == nil { + errs = errs.Also(apis.ErrMissingField("Input")) + } + if r.Output == nil { + errs = errs.Also(apis.ErrMissingField("Output")) + } + // Currently require a preset to specified, in future we can consider defining a template + if r.Preset == nil { + errs = errs.Also(apis.ErrMissingField("Preset")) + } + // TODO: We have to register training plugins and check if it preset exists in plugins here + methodLowerCase := strings.ToLower(string(r.Method)) + if methodLowerCase != string(TuningMethodLora) && methodLowerCase != string(TuningMethodQLora) { + errs = errs.Also(apis.ErrInvalidValue(r.Method, "Method")) + } + return errs +} + +func (r *TuningSpec) validateUpdate(old *TuningSpec) (errs *apis.FieldError) { + if !reflect.DeepEqual(old.Input, r.Input) { + errs = errs.Also(apis.ErrGeneric("Input field cannot be changed", "Input")) + } + if !reflect.DeepEqual(old.Output, r.Output) { + errs = errs.Also(apis.ErrGeneric("Output field cannot be changed", "Output")) + } + if !reflect.DeepEqual(old.Preset, r.Preset) { + errs = errs.Also(apis.ErrGeneric("Preset cannot be changed", "Preset")) + } + // We will have to consider supporting tuning method and config fields changing + methodLowerCase := 
strings.ToLower(string(r.Method)) + if methodLowerCase != string(TuningMethodLora) && methodLowerCase != string(TuningMethodQLora) { + errs = errs.Also(apis.ErrInvalidValue(r.Method, "Method")) + } + return errs +} + +func (r *DataSource) validateCreate() (errs *apis.FieldError) { + sourcesSpecified := 0 + if len(r.URLs) > 0 { + sourcesSpecified++ + } + if r.HostPath != "" { + sourcesSpecified++ + } + if r.Image != "" { + sourcesSpecified++ + } + + // Ensure exactly one of URLs, HostPath, or Image is specified + if sourcesSpecified != 1 { + errs = errs.Also(apis.ErrGeneric("Exactly one of URLs, HostPath, or Image must be specified", "URLs", "HostPath", "Image")) + } + + return errs +} + +func (r *DataSource) validateUpdate(old *DataSource) (errs *apis.FieldError) { + if !reflect.DeepEqual(old.URLs, r.URLs) { + errs = errs.Also(apis.ErrInvalidValue("URLs field cannot be changed once set", "URLs")) + } + if old.HostPath != r.HostPath { + errs = errs.Also(apis.ErrInvalidValue("HostPath field cannot be changed once set", "HostPath")) + } + if old.Image != r.Image { + errs = errs.Also(apis.ErrInvalidValue("Image field cannot be changed once set", "Image")) + } + // TODO: Ensure ImageSecrets can be changed + return errs +} + +func (r *DataDestination) validateCreate() (errs *apis.FieldError) { + destinationsSpecified := 0 + if r.HostPath != "" { + destinationsSpecified++ + } + if r.Image != "" { + destinationsSpecified++ + } + + // If no destination is specified, return an error + if destinationsSpecified == 0 { + errs = errs.Also(apis.ErrMissingField("At least one of HostPath or Image must be specified")) + } + return errs +} + +func (r *DataDestination) validateUpdate(old *DataDestination) (errs *apis.FieldError) { + if old.HostPath != r.HostPath { + errs = errs.Also(apis.ErrInvalidValue("HostPath field cannot be changed once set", "HostPath")) + } + if old.Image != r.Image { + errs = errs.Also(apis.ErrInvalidValue("Image field cannot be changed once set", "Image")) + 
} + // TODO: Ensure ImageSecrets can be changed + return errs +} + func (r *ResourceSpec) validateCreate(inference InferenceSpec) (errs *apis.FieldError) { var presetName string if inference.Preset != nil { @@ -107,21 +226,6 @@ func (r *ResourceSpec) validateCreate(inference InferenceSpec) (errs *apis.Field return errs } -func (w *Workspace) validateUpdate(old *Workspace) (errs *apis.FieldError) { - // Check inference specified - oldInferenceSpecified := old.Inference.Preset != nil || old.Inference.Template != nil - inferenceSpecified := w.Inference.Preset != nil || w.Inference.Template != nil - // Check tuning specified - oldTuningSpecified := old.Tuning.Input != nil - tuningSpecified := w.Tuning.Input != nil - - // inference/tuning can be changed, but cannot be set/unset. - if (!oldInferenceSpecified && inferenceSpecified) || (!oldTuningSpecified && tuningSpecified) { - errs = errs.Also(apis.ErrGeneric("field cannot be unset/set if it was set/unset", "spec")) - } - return errs -} - func (r *ResourceSpec) validateUpdate(old *ResourceSpec) (errs *apis.FieldError) { // We disable changing node count for now. 
if r.Count != nil && old.Count != nil && *r.Count != *old.Count { From 5f9e132ae0ff61902333336895793ba6796c4b42 Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Mon, 18 Mar 2024 15:08:46 -0700 Subject: [PATCH 11/29] fix: prevent toggling --- api/v1alpha1/workspace_validation.go | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/api/v1alpha1/workspace_validation.go b/api/v1alpha1/workspace_validation.go index 81d9353b4..39fbf87b7 100644 --- a/api/v1alpha1/workspace_validation.go +++ b/api/v1alpha1/workspace_validation.go @@ -73,10 +73,12 @@ func (w *Workspace) validateUpdate(old *Workspace) (errs *apis.FieldError) { // Check tuning specified oldTuningSpecified := old.Tuning.Input != nil tuningSpecified := w.Tuning.Input != nil + if (!oldInferenceSpecified && inferenceSpecified) || (oldInferenceSpecified && !inferenceSpecified) { + errs = errs.Also(apis.ErrGeneric("Inference field cannot be toggled once set", "inference")) + } - // inference/tuning can be changed, but cannot be set/unset. 
- if (!oldInferenceSpecified && inferenceSpecified) || (!oldTuningSpecified && tuningSpecified) { - errs = errs.Also(apis.ErrGeneric("field cannot be unset/set if it was set/unset", "spec")) + if (!oldTuningSpecified && tuningSpecified) || (oldTuningSpecified && !tuningSpecified) { + errs = errs.Also(apis.ErrGeneric("Tuning field cannot be toggled once set", "tuning")) } return errs } From 9f4e820fe10ecdbe5b2c4fc5546de6e713fd4b81 Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Mon, 18 Mar 2024 18:06:49 -0700 Subject: [PATCH 12/29] fix: validation fixes --- api/v1alpha1/workspace_validation.go | 47 ++++++++++++++++++++++------ 1 file changed, 37 insertions(+), 10 deletions(-) diff --git a/api/v1alpha1/workspace_validation.go b/api/v1alpha1/workspace_validation.go index 39fbf87b7..9034deae3 100644 --- a/api/v1alpha1/workspace_validation.go +++ b/api/v1alpha1/workspace_validation.go @@ -7,6 +7,7 @@ import ( "context" "fmt" "reflect" + "sort" "strings" "github.com/azure/kaito/pkg/utils/plugin" @@ -60,7 +61,10 @@ func (w *Workspace) Validate(ctx context.Context) (errs *apis.FieldError) { func (w *Workspace) validateCreate() (errs *apis.FieldError) { inferenceSpecified := w.Inference.Preset != nil || w.Inference.Template != nil tuningSpecified := w.Tuning.Input != nil - if inferenceSpecified != tuningSpecified { + if !inferenceSpecified && !tuningSpecified { + errs = errs.Also(apis.ErrGeneric("Either Inference or Tuning must be specified, not neither", "")) + } + if inferenceSpecified && tuningSpecified { errs = errs.Also(apis.ErrGeneric("Either Inference or Tuning must be specified, but not both", "")) } return errs @@ -93,8 +97,9 @@ func (r *TuningSpec) validateCreate() (errs *apis.FieldError) { // Currently require a preset to specified, in future we can consider defining a template if r.Preset == nil { errs = errs.Also(apis.ErrMissingField("Preset")) + } else if presetName := string(r.Preset.Name); !isValidPreset(presetName) { + errs = 
errs.Also(apis.ErrInvalidValue(fmt.Sprintf("Unsupported tuning preset name %s", presetName), "presetName")) } - // TODO: We have to register training plugins and check if it preset exists in plugins here methodLowerCase := strings.ToLower(string(r.Method)) if methodLowerCase != string(TuningMethodLora) && methodLowerCase != string(TuningMethodQLora) { errs = errs.Also(apis.ErrInvalidValue(r.Method, "Method")) @@ -112,11 +117,11 @@ func (r *TuningSpec) validateUpdate(old *TuningSpec) (errs *apis.FieldError) { if !reflect.DeepEqual(old.Preset, r.Preset) { errs = errs.Also(apis.ErrGeneric("Preset cannot be changed", "Preset")) } - // We will have to consider supporting tuning method and config fields changing - methodLowerCase := strings.ToLower(string(r.Method)) - if methodLowerCase != string(TuningMethodLora) && methodLowerCase != string(TuningMethodQLora) { - errs = errs.Also(apis.ErrInvalidValue(r.Method, "Method")) + oldMethod, newMethod := strings.ToLower(string(old.Method)), strings.ToLower(string(r.Method)) + if !reflect.DeepEqual(oldMethod, newMethod) { + errs = errs.Also(apis.ErrGeneric("Method cannot be changed", "Method")) } + // Consider supporting config fields changing return errs } @@ -141,7 +146,15 @@ func (r *DataSource) validateCreate() (errs *apis.FieldError) { } func (r *DataSource) validateUpdate(old *DataSource) (errs *apis.FieldError) { - if !reflect.DeepEqual(old.URLs, r.URLs) { + oldURLs := make([]string, len(old.URLs)) + copy(oldURLs, old.URLs) + sort.Strings(old.URLs) + + newURLs := make([]string, len(r.URLs)) + copy(newURLs, r.URLs) + sort.Strings(r.URLs) + + if !reflect.DeepEqual(oldURLs, newURLs) { errs = errs.Also(apis.ErrInvalidValue("URLs field cannot be changed once set", "URLs")) } if old.HostPath != r.HostPath { @@ -150,7 +163,18 @@ func (r *DataSource) validateUpdate(old *DataSource) (errs *apis.FieldError) { if old.Image != r.Image { errs = errs.Also(apis.ErrInvalidValue("Image field cannot be changed once set", "Image")) } - // 
TODO: Ensure ImageSecrets can be changed + + oldSecrets := make([]string, len(old.ImagePullSecrets)) + copy(oldSecrets, old.ImagePullSecrets) + sort.Strings(oldSecrets) + + newSecrets := make([]string, len(r.ImagePullSecrets)) + copy(newSecrets, r.ImagePullSecrets) + sort.Strings(newSecrets) + + if !reflect.DeepEqual(oldSecrets, newSecrets) { + errs = errs.Also(apis.ErrInvalidValue("ImagePullSecrets field cannot be changed once set", "ImagePullSecrets")) + } return errs } @@ -177,7 +201,10 @@ func (r *DataDestination) validateUpdate(old *DataDestination) (errs *apis.Field if old.Image != r.Image { errs = errs.Also(apis.ErrInvalidValue("Image field cannot be changed once set", "Image")) } - // TODO: Ensure ImageSecrets can be changed + + if old.ImagePushSecret != r.ImagePushSecret { + errs = errs.Also(apis.ErrInvalidValue("ImagePushSecret field cannot be changed once set", "ImagePushSecret")) + } return errs } @@ -263,7 +290,7 @@ func (i *InferenceSpec) validateCreate() (errs *apis.FieldError) { presetName := string(i.Preset.Name) // Validate preset name if !isValidPreset(presetName) { - errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("Unsupported preset name %s", presetName), "presetName")) + errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("Unsupported inference preset name %s", presetName), "presetName")) } // Validate private preset has private image specified if plugin.KaitoModelRegister.MustGet(string(i.Preset.Name)).GetInferenceParameters().ImageAccessMode == "private" && From a16351dab65067143908e0ae0f63518f497cc21f Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Tue, 19 Mar 2024 11:50:00 -0700 Subject: [PATCH 13/29] feat: Add UTs for workspace validation --- api/v1alpha1/workspace_validation.go | 4 +- api/v1alpha1/workspace_validation_test.go | 615 ++++++++++++++++++++++ 2 files changed, 617 insertions(+), 2 deletions(-) diff --git a/api/v1alpha1/workspace_validation.go b/api/v1alpha1/workspace_validation.go index 9034deae3..5a8269353 100644 --- 
a/api/v1alpha1/workspace_validation.go +++ b/api/v1alpha1/workspace_validation.go @@ -148,11 +148,11 @@ func (r *DataSource) validateCreate() (errs *apis.FieldError) { func (r *DataSource) validateUpdate(old *DataSource) (errs *apis.FieldError) { oldURLs := make([]string, len(old.URLs)) copy(oldURLs, old.URLs) - sort.Strings(old.URLs) + sort.Strings(oldURLs) newURLs := make([]string, len(r.URLs)) copy(newURLs, r.URLs) - sort.Strings(r.URLs) + sort.Strings(newURLs) if !reflect.DeepEqual(oldURLs, newURLs) { errs = errs.Also(apis.ErrInvalidValue("URLs field cannot be changed once set", "URLs")) diff --git a/api/v1alpha1/workspace_validation_test.go b/api/v1alpha1/workspace_validation_test.go index 0a3fa2de1..6c2f1a650 100644 --- a/api/v1alpha1/workspace_validation_test.go +++ b/api/v1alpha1/workspace_validation_test.go @@ -488,6 +488,621 @@ func TestInferenceSpecValidateUpdate(t *testing.T) { } } +func TestWorkspaceValidateCreate(t *testing.T) { + tests := []struct { + name string + workspace *Workspace + wantErr bool + errField string + }{ + { + name: "Neither Inference nor Tuning specified", + workspace: &Workspace{ + Inference: InferenceSpec{}, + Tuning: TuningSpec{}, + }, + wantErr: true, + errField: "neither", + }, + { + name: "Both Inference and Tuning specified", + workspace: &Workspace{ + Inference: InferenceSpec{Preset: &PresetSpec{}}, + Tuning: TuningSpec{Input: &DataSource{}}, + }, + wantErr: true, + errField: "both", + }, + { + name: "Only Inference specified", + workspace: &Workspace{ + Inference: InferenceSpec{Preset: &PresetSpec{}}, + }, + wantErr: false, + errField: "", + }, + { + name: "Only Tuning specified", + workspace: &Workspace{ + Tuning: TuningSpec{Input: &DataSource{}}, + }, + wantErr: false, + errField: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + errs := tt.workspace.validateCreate() + if (errs != nil) != tt.wantErr { + t.Errorf("validateCreate() error = %v, wantErr %v", errs, tt.wantErr) + } + if errs 
!= nil && !strings.Contains(errs.Error(), tt.errField) { + t.Errorf("validateCreate() expected error to contain field %s, but got %s", tt.errField, errs.Error()) + } + }) + } +} + +func TestWorkspaceValidateUpdate(t *testing.T) { + tests := []struct { + name string + oldWorkspace *Workspace + newWorkspace *Workspace + expectErrs bool + errFields []string // Fields we expect to have errors + }{ + { + name: "Inference toggled on", + oldWorkspace: &Workspace{ + Inference: InferenceSpec{}, + }, + newWorkspace: &Workspace{ + Inference: InferenceSpec{Preset: &PresetSpec{}}, + }, + expectErrs: true, + errFields: []string{"inference"}, + }, + { + name: "Inference toggled off", + oldWorkspace: &Workspace{ + Inference: InferenceSpec{Preset: &PresetSpec{}}, + }, + newWorkspace: &Workspace{ + Inference: InferenceSpec{}, + }, + expectErrs: true, + errFields: []string{"inference"}, + }, + { + name: "Tuning toggled on", + oldWorkspace: &Workspace{ + Tuning: TuningSpec{}, + }, + newWorkspace: &Workspace{ + Tuning: TuningSpec{Input: &DataSource{}}, + }, + expectErrs: true, + errFields: []string{"tuning"}, + }, + { + name: "Tuning toggled off", + oldWorkspace: &Workspace{ + Tuning: TuningSpec{Input: &DataSource{}}, + }, + newWorkspace: &Workspace{ + Tuning: TuningSpec{}, + }, + expectErrs: true, + errFields: []string{"tuning"}, + }, + { + name: "No toggling", + oldWorkspace: &Workspace{ + Tuning: TuningSpec{Input: &DataSource{}}, + }, + newWorkspace: &Workspace{ + Tuning: TuningSpec{Input: &DataSource{}}, + }, + expectErrs: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + errs := tt.newWorkspace.validateUpdate(tt.oldWorkspace) + hasErrs := errs != nil + + if hasErrs != tt.expectErrs { + t.Errorf("validateUpdate() errors = %v, expectErrs %v", errs, tt.expectErrs) + } + + if hasErrs { + for _, field := range tt.errFields { + if !strings.Contains(errs.Error(), field) { + t.Errorf("validateUpdate() expected errors to contain field %s, but got %s", 
field, errs.Error()) + } + } + } + }) + } +} + +func TestTuningSpecValidateCreate(t *testing.T) { + RegisterValidationTestModels() + tests := []struct { + name string + tuningSpec *TuningSpec + wantErr bool + errFields []string // Fields we expect to have errors + }{ + { + name: "All fields valid", + tuningSpec: &TuningSpec{ + Input: &DataSource{Name: "valid-input"}, + Output: &DataDestination{HostPath: "valid-output"}, + Preset: &PresetSpec{PresetMeta: PresetMeta{Name: ModelName("test-validation")}}, + Method: TuningMethodLora, + }, + wantErr: false, + errFields: nil, + }, + { + name: "Missing Input", + tuningSpec: &TuningSpec{ + Output: &DataDestination{HostPath: "valid-output"}, + Preset: &PresetSpec{PresetMeta: PresetMeta{Name: ModelName("test-validation")}}, + Method: TuningMethodLora, + }, + wantErr: true, + errFields: []string{"Input"}, + }, + { + name: "Missing Output", + tuningSpec: &TuningSpec{ + Input: &DataSource{Name: "valid-input"}, + Preset: &PresetSpec{PresetMeta: PresetMeta{Name: ModelName("test-validation")}}, + Method: TuningMethodLora, + }, + wantErr: true, + errFields: []string{"Output"}, + }, + { + name: "Missing Preset", + tuningSpec: &TuningSpec{ + Input: &DataSource{Name: "valid-input"}, + Output: &DataDestination{HostPath: "valid-output"}, + Method: TuningMethodLora, + }, + wantErr: true, + errFields: []string{"Preset"}, + }, + { + name: "Invalid Preset", + tuningSpec: &TuningSpec{ + Input: &DataSource{Name: "valid-input"}, + Output: &DataDestination{HostPath: "valid-output"}, + Preset: &PresetSpec{PresetMeta: PresetMeta{Name: ModelName("invalid-preset")}}, + Method: TuningMethodLora, + }, + wantErr: true, + errFields: []string{"presetName"}, + }, + { + name: "Invalid Method", + tuningSpec: &TuningSpec{ + Input: &DataSource{Name: "valid-input"}, + Output: &DataDestination{HostPath: "valid-output"}, + Preset: &PresetSpec{PresetMeta: PresetMeta{Name: ModelName("test-validation")}}, + Method: "invalid-method", + }, + wantErr: true, + 
errFields: []string{"Method"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + errs := tt.tuningSpec.validateCreate() + hasErrs := errs != nil + + if hasErrs != tt.wantErr { + t.Errorf("validateCreate() errors = %v, wantErr %v", errs, tt.wantErr) + } + + if hasErrs { + for _, field := range tt.errFields { + if !strings.Contains(errs.Error(), field) { + t.Errorf("validateCreate() expected errors to contain field %s, but got %s", field, errs.Error()) + } + } + } + }) + } +} + +func TestTuningSpecValidateUpdate(t *testing.T) { + RegisterValidationTestModels() + tests := []struct { + name string + oldTuning *TuningSpec + newTuning *TuningSpec + expectErrs bool + errFields []string // Fields we expect to have errors + }{ + { + name: "No changes", + oldTuning: &TuningSpec{ + Input: &DataSource{Name: "input1"}, + Output: &DataDestination{HostPath: "path1"}, + Preset: &PresetSpec{PresetMeta: PresetMeta{Name: ModelName("test-validation")}}, + Method: TuningMethodLora, + }, + newTuning: &TuningSpec{ + Input: &DataSource{Name: "input1"}, + Output: &DataDestination{HostPath: "path1"}, + Preset: &PresetSpec{PresetMeta: PresetMeta{Name: ModelName("test-validation")}}, + Method: TuningMethodLora, + }, + expectErrs: false, + }, + { + name: "Input changed", + oldTuning: &TuningSpec{ + Input: &DataSource{Name: "input1"}, + }, + newTuning: &TuningSpec{ + Input: &DataSource{Name: "input2"}, + }, + expectErrs: true, + errFields: []string{"Input"}, + }, + { + name: "Output changed", + oldTuning: &TuningSpec{ + Output: &DataDestination{HostPath: "path1"}, + }, + newTuning: &TuningSpec{ + Output: &DataDestination{HostPath: "path2"}, + }, + expectErrs: true, + errFields: []string{"Output"}, + }, + { + name: "Preset changed", + oldTuning: &TuningSpec{ + Preset: &PresetSpec{PresetMeta: PresetMeta{Name: ModelName("test-validation")}}, + }, + newTuning: &TuningSpec{ + Preset: &PresetSpec{PresetMeta: PresetMeta{Name: ModelName("invalid-preset")}}, + }, + 
expectErrs: true, + errFields: []string{"Preset"}, + }, + { + name: "Method changed", + oldTuning: &TuningSpec{ + Method: TuningMethodLora, + }, + newTuning: &TuningSpec{ + Method: TuningMethodQLora, + }, + expectErrs: true, + errFields: []string{"Method"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + errs := tt.newTuning.validateUpdate(tt.oldTuning) + hasErrs := errs != nil + + if hasErrs != tt.expectErrs { + t.Errorf("validateUpdate() errors = %v, expectErrs %v", errs, tt.expectErrs) + } + + if hasErrs { + for _, field := range tt.errFields { + if !strings.Contains(errs.Error(), field) { + t.Errorf("validateUpdate() expected errors to contain field %s, but got %s", field, errs.Error()) + } + } + } + }) + } +} + +func TestDataSourceValidateCreate(t *testing.T) { + tests := []struct { + name string + dataSource *DataSource + wantErr bool + errField string // The field we expect to have an error on + }{ + { + name: "URLs specified only", + dataSource: &DataSource{ + URLs: []string{"http://example.com/data"}, + }, + wantErr: false, + }, + { + name: "HostPath specified only", + dataSource: &DataSource{ + HostPath: "/data/path", + }, + wantErr: false, + }, + { + name: "Image specified only", + dataSource: &DataSource{ + Image: "data-image:latest", + }, + wantErr: false, + }, + { + name: "None specified", + dataSource: &DataSource{}, + wantErr: true, + errField: "Exactly one of URLs, HostPath, or Image must be specified", + }, + { + name: "URLs and HostPath specified", + dataSource: &DataSource{ + URLs: []string{"http://example.com/data"}, + HostPath: "/data/path", + }, + wantErr: true, + errField: "Exactly one of URLs, HostPath, or Image must be specified", + }, + { + name: "All fields specified", + dataSource: &DataSource{ + URLs: []string{"http://example.com/data"}, + HostPath: "/data/path", + Image: "data-image:latest", + }, + wantErr: true, + errField: "Exactly one of URLs, HostPath, or Image must be specified", + }, + } + + for 
_, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + errs := tt.dataSource.validateCreate() + hasErrs := errs != nil + + if hasErrs != tt.wantErr { + t.Errorf("validateCreate() error = %v, wantErr %v", errs, tt.wantErr) + } + + if hasErrs && tt.errField != "" && !strings.Contains(errs.Error(), tt.errField) { + t.Errorf("validateCreate() expected error to contain %s, but got %s", tt.errField, errs.Error()) + } + }) + } +} + +func TestDataSourceValidateUpdate(t *testing.T) { + tests := []struct { + name string + oldSource *DataSource + newSource *DataSource + wantErr bool + errFields []string // Fields we expect to have errors + }{ + { + name: "No changes", + oldSource: &DataSource{ + URLs: []string{"http://example.com/data1", "http://example.com/data2"}, + HostPath: "/data/path", + Image: "data-image:latest", + ImagePullSecrets: []string{"secret1", "secret2"}, + }, + newSource: &DataSource{ + URLs: []string{"http://example.com/data2", "http://example.com/data1"}, // Note the different order, should not matter + HostPath: "/data/path", + Image: "data-image:latest", + ImagePullSecrets: []string{"secret2", "secret1"}, // Note the different order, should not matter + }, + wantErr: false, + }, + { + name: "URLs changed", + oldSource: &DataSource{ + URLs: []string{"http://example.com/old"}, + }, + newSource: &DataSource{ + URLs: []string{"http://example.com/new"}, + }, + wantErr: true, + errFields: []string{"URLs"}, + }, + { + name: "HostPath changed", + oldSource: &DataSource{ + HostPath: "/old/path", + }, + newSource: &DataSource{ + HostPath: "/new/path", + }, + wantErr: true, + errFields: []string{"HostPath"}, + }, + { + name: "Image changed", + oldSource: &DataSource{ + Image: "old-image:latest", + }, + newSource: &DataSource{ + Image: "new-image:latest", + }, + wantErr: true, + errFields: []string{"Image"}, + }, + { + name: "ImagePullSecrets changed", + oldSource: &DataSource{ + ImagePullSecrets: []string{"old-secret"}, + }, + newSource: &DataSource{ + 
ImagePullSecrets: []string{"new-secret"}, + }, + wantErr: true, + errFields: []string{"ImagePullSecrets"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + errs := tt.newSource.validateUpdate(tt.oldSource) + hasErrs := errs != nil + + if hasErrs != tt.wantErr { + t.Errorf("validateUpdate() error = %v, wantErr %v", errs, tt.wantErr) + } + + if hasErrs { + for _, field := range tt.errFields { + if !strings.Contains(errs.Error(), field) { + t.Errorf("validateUpdate() expected errors to contain field %s, but got %s", field, errs.Error()) + } + } + } + }) + } +} + +func TestDataDestinationValidateCreate(t *testing.T) { + tests := []struct { + name string + dataDestination *DataDestination + wantErr bool + errField string // The field we expect to have an error on + }{ + { + name: "No fields specified", + dataDestination: &DataDestination{}, + wantErr: true, + errField: "At least one of HostPath or Image must be specified", + }, + { + name: "HostPath specified only", + dataDestination: &DataDestination{ + HostPath: "/data/path", + }, + wantErr: false, + }, + { + name: "Image specified only", + dataDestination: &DataDestination{ + Image: "data-image:latest", + }, + wantErr: false, + }, + { + name: "Both fields specified", + dataDestination: &DataDestination{ + HostPath: "/data/path", + Image: "data-image:latest", + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + errs := tt.dataDestination.validateCreate() + hasErrs := errs != nil + + if hasErrs != tt.wantErr { + t.Errorf("validateCreate() error = %v, wantErr %v", errs, tt.wantErr) + } + + if hasErrs && tt.errField != "" && !strings.Contains(errs.Error(), tt.errField) { + t.Errorf("validateCreate() expected error to contain %s, but got %s", tt.errField, errs.Error()) + } + }) + } +} + +func TestDataDestinationValidateUpdate(t *testing.T) { + tests := []struct { + name string + oldDest *DataDestination + newDest *DataDestination + wantErr 
bool + errFields []string // Fields we expect to have errors + }{ + { + name: "No changes", + oldDest: &DataDestination{ + HostPath: "/data/old", + Image: "old-image:latest", + ImagePushSecret: "old-secret", + }, + newDest: &DataDestination{ + HostPath: "/data/old", + Image: "old-image:latest", + ImagePushSecret: "old-secret", + }, + wantErr: false, + }, + { + name: "HostPath changed", + oldDest: &DataDestination{ + HostPath: "/data/old", + }, + newDest: &DataDestination{ + HostPath: "/data/new", + }, + wantErr: true, + errFields: []string{"HostPath"}, + }, + { + name: "Image changed", + oldDest: &DataDestination{ + Image: "old-image:latest", + }, + newDest: &DataDestination{ + Image: "new-image:latest", + }, + wantErr: true, + errFields: []string{"Image"}, + }, + { + name: "ImagePushSecret changed", + oldDest: &DataDestination{ + ImagePushSecret: "old-secret", + }, + newDest: &DataDestination{ + ImagePushSecret: "new-secret", + }, + wantErr: true, + errFields: []string{"ImagePushSecret"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + errs := tt.newDest.validateUpdate(tt.oldDest) + hasErrs := errs != nil + + if hasErrs != tt.wantErr { + t.Errorf("validateUpdate() error = %v, wantErr %v", errs, tt.wantErr) + } + + if hasErrs { + for _, field := range tt.errFields { + if !strings.Contains(errs.Error(), field) { + t.Errorf("validateUpdate() expected errors to contain field %s, but got %s", field, errs.Error()) + } + } + } + }) + } +} + func TestGetSupportedSKUs(t *testing.T) { tests := []struct { name string From d2cd23045a451bfd9ac17523e4c1a9bf6ff2986f Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Tue, 19 Mar 2024 19:07:09 -0700 Subject: [PATCH 14/29] fix: Update CRD to use pointers --- api/v1alpha1/workspace_types.go | 8 +- api/v1alpha1/workspace_validation.go | 55 ++++--- api/v1alpha1/workspace_validation_test.go | 69 ++++---- api/v1alpha1/zz_generated.deepcopy.go | 12 +- .../workspace/crds/kaito.sh_workspaces.yaml | 147 
+++++++++++++++++- config/crd/bases/kaito.sh_workspaces.yaml | 11 +- pkg/utils/testUtils.go | 6 +- test/e2e/preset_test.go | 19 ++- test/e2e/utils/utils.go | 30 ++-- 9 files changed, 252 insertions(+), 105 deletions(-) diff --git a/api/v1alpha1/workspace_types.go b/api/v1alpha1/workspace_types.go index 71e9f829c..4484b8250 100644 --- a/api/v1alpha1/workspace_types.go +++ b/api/v1alpha1/workspace_types.go @@ -150,9 +150,9 @@ type TuningSpec struct { // +optional Config string `json:"config,omitempty"` // Input describes the input used by the tuning method. - Input *DataSource `json:"input,omitempty"` + Input *DataSource `json:"input"` // Output specified where to store the tuning output. - Output *DataDestination `json:"output,omitempty"` + Output *DataDestination `json:"output"` } // WorkspaceStatus defines the observed state of Workspace @@ -181,8 +181,8 @@ type Workspace struct { metav1.ObjectMeta `json:"metadata,omitempty"` Resource ResourceSpec `json:"resource,omitempty"` - Inference InferenceSpec `json:"inference,omitempty"` - Tuning TuningSpec `json:"tuning,omitempty"` + Inference *InferenceSpec `json:"inference,omitempty"` + Tuning *TuningSpec `json:"tuning,omitempty"` Status WorkspaceStatus `json:"status,omitempty"` } diff --git a/api/v1alpha1/workspace_validation.go b/api/v1alpha1/workspace_validation.go index 5a8269353..b135f5886 100644 --- a/api/v1alpha1/workspace_validation.go +++ b/api/v1alpha1/workspace_validation.go @@ -37,51 +37,48 @@ func (w *Workspace) Validate(ctx context.Context) (errs *apis.FieldError) { klog.InfoS("Validate creation", "workspace", fmt.Sprintf("%s/%s", w.Namespace, w.Name)) errs = errs.Also( w.validateCreate().ViaField("spec"), - w.Inference.validateCreate().ViaField("inference"), - w.Tuning.validateCreate().ViaField("tuning"), - w.Tuning.Input.validateCreate().ViaField("input"), - w.Tuning.Output.validateCreate().ViaField("output"), - w.Resource.validateCreate(w.Inference).ViaField("resource"), + // TODO: Consider validate 
resource based on Tuning Spec + w.Resource.validateCreate(*w.Inference).ViaField("resource"), ) + if w.Inference != nil { + errs = errs.Also(w.Inference.validateCreate().ViaField("inference")) + } + if w.Tuning != nil { + errs = errs.Also(w.Tuning.validateCreate().ViaField("tuning")) + } } else { klog.InfoS("Validate update", "workspace", fmt.Sprintf("%s/%s", w.Namespace, w.Name)) old := base.(*Workspace) errs = errs.Also( w.validateUpdate(old).ViaField("spec"), - w.Inference.validateUpdate(&old.Inference).ViaField("inference"), - w.Tuning.validateUpdate(&old.Tuning).ViaField("tuning"), - w.Tuning.Input.validateUpdate(old.Tuning.Input).ViaField("input"), - w.Tuning.Output.validateUpdate(old.Tuning.Output).ViaField("output"), w.Resource.validateUpdate(&old.Resource).ViaField("resource"), ) + if w.Inference != nil { + errs = errs.Also(w.Inference.validateUpdate(old.Inference).ViaField("inference")) + } + if w.Tuning != nil { + errs = errs.Also(w.Tuning.validateUpdate(old.Tuning).ViaField("tuning")) + } } return errs } func (w *Workspace) validateCreate() (errs *apis.FieldError) { - inferenceSpecified := w.Inference.Preset != nil || w.Inference.Template != nil - tuningSpecified := w.Tuning.Input != nil - if !inferenceSpecified && !tuningSpecified { + if w.Inference == nil && w.Tuning == nil { errs = errs.Also(apis.ErrGeneric("Either Inference or Tuning must be specified, not neither", "")) } - if inferenceSpecified && tuningSpecified { + if w.Inference != nil && w.Tuning != nil { errs = errs.Also(apis.ErrGeneric("Either Inference or Tuning must be specified, but not both", "")) } return errs } func (w *Workspace) validateUpdate(old *Workspace) (errs *apis.FieldError) { - // Check inference specified - oldInferenceSpecified := old.Inference.Preset != nil || old.Inference.Template != nil - inferenceSpecified := w.Inference.Preset != nil || w.Inference.Template != nil - // Check tuning specified - oldTuningSpecified := old.Tuning.Input != nil - tuningSpecified := 
w.Tuning.Input != nil - if (!oldInferenceSpecified && inferenceSpecified) || (oldInferenceSpecified && !inferenceSpecified) { + if (old.Inference == nil && w.Inference != nil) || (old.Inference != nil && w.Inference == nil) { errs = errs.Also(apis.ErrGeneric("Inference field cannot be toggled once set", "inference")) } - if (!oldTuningSpecified && tuningSpecified) || (oldTuningSpecified && !tuningSpecified) { + if (old.Tuning == nil && w.Tuning != nil) || (old.Tuning != nil && w.Tuning == nil) { errs = errs.Also(apis.ErrGeneric("Tuning field cannot be toggled once set", "tuning")) } return errs @@ -90,9 +87,13 @@ func (w *Workspace) validateUpdate(old *Workspace) (errs *apis.FieldError) { func (r *TuningSpec) validateCreate() (errs *apis.FieldError) { if r.Input == nil { errs = errs.Also(apis.ErrMissingField("Input")) + } else { + errs = errs.Also(r.Input.validateCreate().ViaField("Input")) } if r.Output == nil { errs = errs.Also(apis.ErrMissingField("Output")) + } else { + errs = errs.Also(r.Output.validateCreate().ViaField("Output")) } // Currently require a preset to specified, in future we can consider defining a template if r.Preset == nil { @@ -108,11 +109,15 @@ func (r *TuningSpec) validateCreate() (errs *apis.FieldError) { } func (r *TuningSpec) validateUpdate(old *TuningSpec) (errs *apis.FieldError) { - if !reflect.DeepEqual(old.Input, r.Input) { - errs = errs.Also(apis.ErrGeneric("Input field cannot be changed", "Input")) + if r.Input == nil { + errs = errs.Also(apis.ErrMissingField("Input")) + } else { + errs = errs.Also(r.Input.validateUpdate(old.Input).ViaField("Input")) } - if !reflect.DeepEqual(old.Output, r.Output) { - errs = errs.Also(apis.ErrGeneric("Output field cannot be changed", "Output")) + if r.Output == nil { + errs = errs.Also(apis.ErrMissingField("Output")) + } else { + errs = errs.Also(r.Output.validateUpdate(old.Output).ViaField("Output")) } if !reflect.DeepEqual(old.Preset, r.Preset) { errs = errs.Also(apis.ErrGeneric("Preset cannot be 
changed", "Preset")) diff --git a/api/v1alpha1/workspace_validation_test.go b/api/v1alpha1/workspace_validation_test.go index 6c2f1a650..d1cea034d 100644 --- a/api/v1alpha1/workspace_validation_test.go +++ b/api/v1alpha1/workspace_validation_test.go @@ -496,19 +496,16 @@ func TestWorkspaceValidateCreate(t *testing.T) { errField string }{ { - name: "Neither Inference nor Tuning specified", - workspace: &Workspace{ - Inference: InferenceSpec{}, - Tuning: TuningSpec{}, - }, - wantErr: true, - errField: "neither", + name: "Neither Inference nor Tuning specified", + workspace: &Workspace{}, + wantErr: true, + errField: "neither", }, { name: "Both Inference and Tuning specified", workspace: &Workspace{ - Inference: InferenceSpec{Preset: &PresetSpec{}}, - Tuning: TuningSpec{Input: &DataSource{}}, + Inference: &InferenceSpec{}, + Tuning: &TuningSpec{}, }, wantErr: true, errField: "both", @@ -516,7 +513,7 @@ func TestWorkspaceValidateCreate(t *testing.T) { { name: "Only Inference specified", workspace: &Workspace{ - Inference: InferenceSpec{Preset: &PresetSpec{}}, + Inference: &InferenceSpec{}, }, wantErr: false, errField: "", @@ -524,7 +521,7 @@ func TestWorkspaceValidateCreate(t *testing.T) { { name: "Only Tuning specified", workspace: &Workspace{ - Tuning: TuningSpec{Input: &DataSource{}}, + Tuning: &TuningSpec{Input: &DataSource{}}, }, wantErr: false, errField: "", @@ -553,12 +550,10 @@ func TestWorkspaceValidateUpdate(t *testing.T) { errFields []string // Fields we expect to have errors }{ { - name: "Inference toggled on", - oldWorkspace: &Workspace{ - Inference: InferenceSpec{}, - }, + name: "Inference toggled on", + oldWorkspace: &Workspace{}, newWorkspace: &Workspace{ - Inference: InferenceSpec{Preset: &PresetSpec{}}, + Inference: &InferenceSpec{}, }, expectErrs: true, errFields: []string{"inference"}, @@ -566,21 +561,17 @@ func TestWorkspaceValidateUpdate(t *testing.T) { { name: "Inference toggled off", oldWorkspace: &Workspace{ - Inference: InferenceSpec{Preset: 
&PresetSpec{}}, - }, - newWorkspace: &Workspace{ - Inference: InferenceSpec{}, + Inference: &InferenceSpec{Preset: &PresetSpec{}}, }, - expectErrs: true, - errFields: []string{"inference"}, + newWorkspace: &Workspace{}, + expectErrs: true, + errFields: []string{"inference"}, }, { - name: "Tuning toggled on", - oldWorkspace: &Workspace{ - Tuning: TuningSpec{}, - }, + name: "Tuning toggled on", + oldWorkspace: &Workspace{}, newWorkspace: &Workspace{ - Tuning: TuningSpec{Input: &DataSource{}}, + Tuning: &TuningSpec{Input: &DataSource{}}, }, expectErrs: true, errFields: []string{"tuning"}, @@ -588,21 +579,19 @@ func TestWorkspaceValidateUpdate(t *testing.T) { { name: "Tuning toggled off", oldWorkspace: &Workspace{ - Tuning: TuningSpec{Input: &DataSource{}}, - }, - newWorkspace: &Workspace{ - Tuning: TuningSpec{}, + Tuning: &TuningSpec{Input: &DataSource{}}, }, - expectErrs: true, - errFields: []string{"tuning"}, + newWorkspace: &Workspace{}, + expectErrs: true, + errFields: []string{"tuning"}, }, { name: "No toggling", oldWorkspace: &Workspace{ - Tuning: TuningSpec{Input: &DataSource{}}, + Tuning: &TuningSpec{Input: &DataSource{}}, }, newWorkspace: &Workspace{ - Tuning: TuningSpec{Input: &DataSource{}}, + Tuning: &TuningSpec{Input: &DataSource{}}, }, expectErrs: false, }, @@ -639,7 +628,7 @@ func TestTuningSpecValidateCreate(t *testing.T) { { name: "All fields valid", tuningSpec: &TuningSpec{ - Input: &DataSource{Name: "valid-input"}, + Input: &DataSource{Name: "valid-input", HostPath: "valid-input"}, Output: &DataDestination{HostPath: "valid-output"}, Preset: &PresetSpec{PresetMeta: PresetMeta{Name: ModelName("test-validation")}}, Method: TuningMethodLora, @@ -749,13 +738,15 @@ func TestTuningSpecValidateUpdate(t *testing.T) { { name: "Input changed", oldTuning: &TuningSpec{ - Input: &DataSource{Name: "input1"}, + Input: &DataSource{Name: "input", HostPath: "inputpath"}, + Output: &DataDestination{HostPath: "outputpath"}, }, newTuning: &TuningSpec{ - Input: 
&DataSource{Name: "input2"}, + Input: &DataSource{Name: "input", HostPath: "randompath"}, + Output: &DataDestination{HostPath: "outputpath"}, }, expectErrs: true, - errFields: []string{"Input"}, + errFields: []string{"HostPath"}, }, { name: "Output changed", diff --git a/api/v1alpha1/zz_generated.deepcopy.go b/api/v1alpha1/zz_generated.deepcopy.go index a9d662c0f..6c3ee1eb8 100644 --- a/api/v1alpha1/zz_generated.deepcopy.go +++ b/api/v1alpha1/zz_generated.deepcopy.go @@ -249,8 +249,16 @@ func (in *Workspace) DeepCopyInto(out *Workspace) { out.TypeMeta = in.TypeMeta in.ObjectMeta.DeepCopyInto(&out.ObjectMeta) in.Resource.DeepCopyInto(&out.Resource) - in.Inference.DeepCopyInto(&out.Inference) - in.Tuning.DeepCopyInto(&out.Tuning) + if in.Inference != nil { + in, out := &in.Inference, &out.Inference + *out = new(InferenceSpec) + (*in).DeepCopyInto(*out) + } + if in.Tuning != nil { + in, out := &in.Tuning, &out.Tuning + *out = new(TuningSpec) + (*in).DeepCopyInto(*out) + } in.Status.DeepCopyInto(&out.Status) } diff --git a/charts/kaito/workspace/crds/kaito.sh_workspaces.yaml b/charts/kaito/workspace/crds/kaito.sh_workspaces.yaml index 40908f609..a4103a897 100644 --- a/charts/kaito/workspace/crds/kaito.sh_workspaces.yaml +++ b/charts/kaito/workspace/crds/kaito.sh_workspaces.yaml @@ -47,11 +47,56 @@ spec: type: string inference: properties: + adapters: + description: Adapters are integrated into the base model for inference. + Users can specify multiple adapters for the model and the respective + weight of using each of them. + items: + properties: + source: + description: Source describes where to obtain the adapter data. + properties: + hostPath: + description: The directory in the host that contains the + data. + type: string + image: + description: The name of the image that contains the source + data. The assumption is that the source data locates in + the `data` directory in the image. 
+ type: string + imagePullSecrets: + description: ImagePullSecrets is a list of secret names + in the same namespace used for pulling the data image. + items: + type: string + type: array + name: + description: The name of the dataset. The same name will + be used as a container name. It must be a valid DNS subdomain + value, + type: string + urls: + description: URLs specifies the links to the public data + sources. E.g., files in a public github repository. + items: + type: string + type: array + type: object + strength: + description: Strength specifies the default multiplier for applying + the adapter weights to the raw model weights. It is usually + a float number between 0 and 1. It is defined as a string + type to be language agnostic. + type: string + type: object + type: array preset: - description: Preset describles the model that will be deployed with - preset configurations. + description: Preset describes the base model that will be deployed + with preset configurations. properties: accessMode: + default: public description: AccessMode specifies whether the containerized model image is accessible via public registry or private registry. This field defaults to "public" if not specified. If this field @@ -72,7 +117,7 @@ spec: type: string imagePullSecrets: description: ImagePullSecrets is a list of secret names in - the same namespace used for pulling the image. + the same namespace used for pulling the model image. items: type: string type: array @@ -95,7 +140,7 @@ spec: metadata: type: object resource: - description: ResourceSpec desicribes the resource requirement of running + description: ResourceSpec describes the resource requirement of running the workload. If the number of nodes in the cluster that meet the InstanceType and LabelSelector requirements is small than the Count, controller will provision new nodes before deploying the workload. 
The final list of @@ -245,6 +290,100 @@ spec: type: string type: array type: object + tuning: + properties: + config: + description: Config specifies the name of the configmap in the same + namespace that contains the arguments used by the tuning method. + If not specified, a default configmap is used based on the specified + method. + type: string + input: + description: Input describes the input used by the tuning method. + properties: + hostPath: + description: The directory in the host that contains the data. + type: string + image: + description: The name of the image that contains the source data. + The assumption is that the source data locates in the `data` + directory in the image. + type: string + imagePullSecrets: + description: ImagePullSecrets is a list of secret names in the + same namespace used for pulling the data image. + items: + type: string + type: array + name: + description: The name of the dataset. The same name will be used + as a container name. It must be a valid DNS subdomain value, + type: string + urls: + description: URLs specifies the links to the public data sources. + E.g., files in a public github repository. + items: + type: string + type: array + type: object + method: + description: Method specifies the Parameter-Efficient Fine-Tuning(PEFT) + method, such as lora, qlora, used for the tuning. + type: string + output: + description: Output specified where to store the tuning output. + properties: + hostPath: + description: The directory in the host that contains the output + data. + type: string + image: + description: Name of the image where the output data is pushed + to. + type: string + imagePushSecret: + description: ImagePushSecret is the name of the secret in the + same namespace that contains the authentication information + that is needed for running `docker push`. + type: string + type: object + preset: + description: Preset describes which model to load for tuning. 
+ properties: + accessMode: + default: public + description: AccessMode specifies whether the containerized model + image is accessible via public registry or private registry. + This field defaults to "public" if not specified. If this field + is "private", user needs to provide the private image information + in PresetOptions. + enum: + - public + - private + type: string + name: + description: Name of the supported models with preset configurations. + type: string + presetOptions: + properties: + image: + description: Image is the name of the containerized model + image. + type: string + imagePullSecrets: + description: ImagePullSecrets is a list of secret names in + the same namespace used for pulling the model image. + items: + type: string + type: array + type: object + required: + - name + type: object + required: + - input + - output + type: object type: object served: true storage: true diff --git a/config/crd/bases/kaito.sh_workspaces.yaml b/config/crd/bases/kaito.sh_workspaces.yaml index b3af23a76..a4103a897 100644 --- a/config/crd/bases/kaito.sh_workspaces.yaml +++ b/config/crd/bases/kaito.sh_workspaces.yaml @@ -57,7 +57,7 @@ spec: description: Source describes where to obtain the adapter data. properties: hostPath: - description: The directory in the hsot that contains the + description: The directory in the host that contains the data. type: string image: @@ -96,6 +96,7 @@ spec: with preset configurations. properties: accessMode: + default: public description: AccessMode specifies whether the containerized model image is accessible via public registry or private registry. This field defaults to "public" if not specified. If this field @@ -139,7 +140,7 @@ spec: metadata: type: object resource: - description: ResourceSpec desicribes the resource requirement of running + description: ResourceSpec describes the resource requirement of running the workload. 
If the number of nodes in the cluster that meet the InstanceType and LabelSelector requirements is small than the Count, controller will provision new nodes before deploying the workload. The final list of @@ -301,7 +302,7 @@ spec: description: Input describes the input used by the tuning method. properties: hostPath: - description: The directory in the hsot that contains the data. + description: The directory in the host that contains the data. type: string image: description: The name of the image that contains the source data. @@ -350,6 +351,7 @@ spec: description: Preset describes which model to load for tuning. properties: accessMode: + default: public description: AccessMode specifies whether the containerized model image is accessible via public registry or private registry. This field defaults to "public" if not specified. If this field @@ -378,6 +380,9 @@ spec: required: - name type: object + required: + - input + - output type: object type: object served: true diff --git a/pkg/utils/testUtils.go b/pkg/utils/testUtils.go index f88b35a4f..5ef34af1d 100644 --- a/pkg/utils/testUtils.go +++ b/pkg/utils/testUtils.go @@ -35,7 +35,7 @@ var ( }, }, }, - Inference: v1alpha1.InferenceSpec{ + Inference: &v1alpha1.InferenceSpec{ Preset: &v1alpha1.PresetSpec{ PresetMeta: v1alpha1.PresetMeta{ Name: "test-distributed-model", @@ -60,7 +60,7 @@ var ( }, }, }, - Inference: v1alpha1.InferenceSpec{ + Inference: &v1alpha1.InferenceSpec{ Preset: &v1alpha1.PresetSpec{ PresetMeta: v1alpha1.PresetMeta{ Name: "test-model", @@ -85,7 +85,7 @@ var ( }, }, }, - Inference: v1alpha1.InferenceSpec{ + Inference: &v1alpha1.InferenceSpec{ Template: &corev1.PodTemplateSpec{}, }, } diff --git a/test/e2e/preset_test.go b/test/e2e/preset_test.go index eb0333df4..e8f262ef0 100644 --- a/test/e2e/preset_test.go +++ b/test/e2e/preset_test.go @@ -26,13 +26,13 @@ import ( ) const ( - PresetLlama2AChat = "llama-2-7b-chat" - PresetLlama2BChat = "llama-2-13b-chat" - PresetFalcon7BModel = "falcon-7b" - 
PresetFalcon40BModel = "falcon-40b" - PresetMistral7BModel = "mistral-7b" + PresetLlama2AChat = "llama-2-7b-chat" + PresetLlama2BChat = "llama-2-13b-chat" + PresetFalcon7BModel = "falcon-7b" + PresetFalcon40BModel = "falcon-40b" + PresetMistral7BModel = "mistral-7b" PresetMistral7BInstructModel = "mistral-7b-instruct" - PresetPhi2Model = "phi-2" + PresetPhi2Model = "phi-2" ) func createFalconWorkspaceWithPresetPublicMode(numOfNode int) *kaitov1alpha1.Workspace { @@ -348,17 +348,17 @@ var _ = Describe("Workspace Preset", func() { fmt.Print("Error: RUN_LLAMA_13B ENV Variable not set") runLlama13B = false } - + aiModelsRegistry = utils.GetEnv("AI_MODELS_REGISTRY") aiModelsRegistrySecret = utils.GetEnv("AI_MODELS_REGISTRY_SECRET") - + // Load stable model versions configs, err := utils.GetModelConfigInfo("/home/runner/work/kaito/kaito/presets/models/supported_models.yaml") if err != nil { fmt.Printf("Failed to load model configs: %v\n", err) os.Exit(1) } - + modelInfo, err = utils.ExtractModelVersion(configs) if err != nil { fmt.Printf("Failed to extract stable model versions: %v\n", err) @@ -404,7 +404,6 @@ var _ = Describe("Workspace Preset", func() { validateWorkspaceReadiness(workspaceObj) }) - It("should create a Phi-2 workspace with preset public mode successfully", func() { numOfNode := 1 workspaceObj := createPhi2WorkspaceWithPresetPublicMode(numOfNode) diff --git a/test/e2e/utils/utils.go b/test/e2e/utils/utils.go index 3914f00eb..38388374f 100644 --- a/test/e2e/utils/utils.go +++ b/test/e2e/utils/utils.go @@ -60,23 +60,23 @@ func ExtractModelVersion(configs map[string]interface{}) (map[string]string, err } for _, modelItem := range models { - model, ok := modelItem.(map[interface{}]interface{}) - if !ok { - return nil, fmt.Errorf("model item is not a map") - } + model, ok := modelItem.(map[interface{}]interface{}) + if !ok { + return nil, fmt.Errorf("model item is not a map") + } - modelName, ok := model["name"].(string) - if !ok { - return nil, 
fmt.Errorf("model name is not a string or not found") - } + modelName, ok := model["name"].(string) + if !ok { + return nil, fmt.Errorf("model name is not a string or not found") + } - modelTag, ok := model["tag"].(string) // Using 'tag' as the version - if !ok { - return nil, fmt.Errorf("model version for %s is not a string or not found", modelName) - } + modelTag, ok := model["tag"].(string) // Using 'tag' as the version + if !ok { + return nil, fmt.Errorf("model version for %s is not a string or not found", modelName) + } - modelsInfo[modelName] = modelTag - } + modelsInfo[modelName] = modelTag + } return modelsInfo, nil } @@ -117,7 +117,7 @@ func GenerateWorkspaceManifest(name, namespace, imageName string, resourceCount workspaceInference.Template = podTemplate } - workspace.Inference = workspaceInference + workspace.Inference = &workspaceInference return workspace } From 082844d22d77eabdc96cefc06574495934b2394e Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Tue, 19 Mar 2024 19:33:01 -0700 Subject: [PATCH 15/29] fix: Add name flag --- api/v1alpha1/workspace_validation.go | 9 +++++++-- api/v1alpha1/workspace_validation_test.go | 13 ++++++++++++- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/api/v1alpha1/workspace_validation.go b/api/v1alpha1/workspace_validation.go index b135f5886..79b27e1b9 100644 --- a/api/v1alpha1/workspace_validation.go +++ b/api/v1alpha1/workspace_validation.go @@ -41,6 +41,7 @@ func (w *Workspace) Validate(ctx context.Context) (errs *apis.FieldError) { w.Resource.validateCreate(*w.Inference).ViaField("resource"), ) if w.Inference != nil { + // TODO: Add Adapter Spec Validation - Including DataSource Validation for Adapter errs = errs.Also(w.Inference.validateCreate().ViaField("inference")) } if w.Tuning != nil { @@ -54,6 +55,7 @@ func (w *Workspace) Validate(ctx context.Context) (errs *apis.FieldError) { w.Resource.validateUpdate(&old.Resource).ViaField("resource"), ) if w.Inference != nil { + // TODO: Add Adapter 
Spec Validation - Including DataSource Validation for Adapter errs = errs.Also(w.Inference.validateUpdate(old.Inference).ViaField("inference")) } if w.Tuning != nil { @@ -112,7 +114,7 @@ func (r *TuningSpec) validateUpdate(old *TuningSpec) (errs *apis.FieldError) { if r.Input == nil { errs = errs.Also(apis.ErrMissingField("Input")) } else { - errs = errs.Also(r.Input.validateUpdate(old.Input).ViaField("Input")) + errs = errs.Also(r.Input.validateUpdate(old.Input, true).ViaField("Input")) } if r.Output == nil { errs = errs.Also(apis.ErrMissingField("Output")) @@ -150,7 +152,10 @@ func (r *DataSource) validateCreate() (errs *apis.FieldError) { return errs } -func (r *DataSource) validateUpdate(old *DataSource) (errs *apis.FieldError) { +func (r *DataSource) validateUpdate(old *DataSource, isTuning bool) (errs *apis.FieldError) { + if isTuning && !reflect.DeepEqual(old.Name, r.Name) { + errs = errs.Also(apis.ErrInvalidValue("During tuning Name field cannot be changed once set", "Name")) + } oldURLs := make([]string, len(old.URLs)) copy(oldURLs, old.URLs) sort.Strings(oldURLs) diff --git a/api/v1alpha1/workspace_validation_test.go b/api/v1alpha1/workspace_validation_test.go index d1cea034d..11631e67b 100644 --- a/api/v1alpha1/workspace_validation_test.go +++ b/api/v1alpha1/workspace_validation_test.go @@ -898,6 +898,17 @@ func TestDataSourceValidateUpdate(t *testing.T) { }, wantErr: false, }, + { + name: "Name changed", + oldSource: &DataSource{ + Name: "original-dataset", + }, + newSource: &DataSource{ + Name: "new-dataset", + }, + wantErr: true, + errFields: []string{"Name"}, + }, { name: "URLs changed", oldSource: &DataSource{ @@ -946,7 +957,7 @@ func TestDataSourceValidateUpdate(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - errs := tt.newSource.validateUpdate(tt.oldSource) + errs := tt.newSource.validateUpdate(tt.oldSource, true) hasErrs := errs != nil if hasErrs != tt.wantErr { From 56409a71ad22ea171786c737d23ab8c8be967d3a Mon 
Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Wed, 20 Mar 2024 13:37:35 -0700 Subject: [PATCH 16/29] feat: Setup Interface for fine tuning --- api/v1alpha1/workspace_condition_types.go | 3 + api/v1alpha1/workspace_validation_test.go | 8 +- .../kaito_workspace_tuning_falcon_7b.yaml | 20 +++++ .../kaito_workspace_falcon_40b-instruct.yaml | 0 .../kaito_workspace_falcon_40b.yaml | 0 .../kaito_workspace_falcon_7b-instruct.yaml | 0 .../kaito_workspace_falcon_7b.yaml | 0 .../kaito_workspace_llama2_13b-chat.yaml | 0 .../kaito_workspace_llama2_13b.yaml | 0 .../kaito_workspace_llama2_70b-chat.yaml | 0 .../kaito_workspace_llama2_70b.yaml | 0 .../kaito_workspace_llama2_7b-chat.yaml | 0 .../kaito_workspace_llama2_7b.yaml | 0 .../kaito_workspace_mistral_7b-instruct.yaml | 0 .../kaito_workspace_mistral_7b.yaml | 0 .../kaito_workspace_phi-2.yaml | 0 pkg/controllers/workspace_controller.go | 75 +++++++++++++++++-- pkg/inference/preset-inferences.go | 10 +-- pkg/inference/preset-inferences_test.go | 2 +- pkg/model/interface.go | 7 +- pkg/tuning/preset-tuning-types.go | 21 ++++++ pkg/tuning/preset-tuning.go | 29 +++++++ pkg/utils/testModel.go | 8 +- presets/models/falcon/README.md | 8 +- presets/models/falcon/model.go | 55 +++++++++++--- presets/models/llama2/README.md | 6 +- presets/models/llama2/model.go | 21 ++++-- presets/models/llama2chat/README.md | 6 +- presets/models/llama2chat/model.go | 22 ++++-- presets/models/mistral/README.md | 4 +- presets/models/mistral/model.go | 27 ++++++- presets/models/phi/README.md | 2 +- presets/models/phi/model.go | 4 +- 33 files changed, 272 insertions(+), 66 deletions(-) create mode 100644 examples/fine-tuning/kaito_workspace_tuning_falcon_7b.yaml rename examples/{ => inference}/kaito_workspace_falcon_40b-instruct.yaml (100%) rename examples/{ => inference}/kaito_workspace_falcon_40b.yaml (100%) rename examples/{ => inference}/kaito_workspace_falcon_7b-instruct.yaml (100%) rename examples/{ => inference}/kaito_workspace_falcon_7b.yaml (100%) 
rename examples/{ => inference}/kaito_workspace_llama2_13b-chat.yaml (100%) rename examples/{ => inference}/kaito_workspace_llama2_13b.yaml (100%) rename examples/{ => inference}/kaito_workspace_llama2_70b-chat.yaml (100%) rename examples/{ => inference}/kaito_workspace_llama2_70b.yaml (100%) rename examples/{ => inference}/kaito_workspace_llama2_7b-chat.yaml (100%) rename examples/{ => inference}/kaito_workspace_llama2_7b.yaml (100%) rename examples/{ => inference}/kaito_workspace_mistral_7b-instruct.yaml (100%) rename examples/{ => inference}/kaito_workspace_mistral_7b.yaml (100%) rename examples/{ => inference}/kaito_workspace_phi-2.yaml (100%) create mode 100644 pkg/tuning/preset-tuning-types.go create mode 100644 pkg/tuning/preset-tuning.go diff --git a/api/v1alpha1/workspace_condition_types.go b/api/v1alpha1/workspace_condition_types.go index 762d8dafc..9845b8a0c 100644 --- a/api/v1alpha1/workspace_condition_types.go +++ b/api/v1alpha1/workspace_condition_types.go @@ -16,6 +16,9 @@ const ( // WorkspaceConditionTypeInferenceStatus is the state when Inference has been created. WorkspaceConditionTypeInferenceStatus = ConditionType("InferenceReady") + // WorkspaceConditionTypeTuningStatus is the state when Tuning has been created. + WorkspaceConditionTypeTuningStatus = ConditionType("TuningReady") + //WorkspaceConditionTypeDeleting is the Workspace state when starts to get deleted. 
WorkspaceConditionTypeDeleting = ConditionType("WorkspaceDeleting") diff --git a/api/v1alpha1/workspace_validation_test.go b/api/v1alpha1/workspace_validation_test.go index 11631e67b..695d42298 100644 --- a/api/v1alpha1/workspace_validation_test.go +++ b/api/v1alpha1/workspace_validation_test.go @@ -21,8 +21,8 @@ var perGPUMemoryRequirement string type testModel struct{} -func (*testModel) GetInferenceParameters() *model.PresetInferenceParam { - return &model.PresetInferenceParam{ +func (*testModel) GetInferenceParameters() *model.PresetParam { + return &model.PresetParam{ GPUCountRequirement: gpuCountRequirement, TotalGPUMemoryRequirement: totalGPUMemoryRequirement, PerGPUMemoryRequirement: perGPUMemoryRequirement, @@ -34,8 +34,8 @@ func (*testModel) SupportDistributedInference() bool { type testModelPrivate struct{} -func (*testModelPrivate) GetInferenceParameters() *model.PresetInferenceParam { - return &model.PresetInferenceParam{ +func (*testModelPrivate) GetInferenceParameters() *model.PresetParam { + return &model.PresetParam{ ImageAccessMode: "private", GPUCountRequirement: gpuCountRequirement, TotalGPUMemoryRequirement: totalGPUMemoryRequirement, diff --git a/examples/fine-tuning/kaito_workspace_tuning_falcon_7b.yaml b/examples/fine-tuning/kaito_workspace_tuning_falcon_7b.yaml new file mode 100644 index 000000000..6d6ed7831 --- /dev/null +++ b/examples/fine-tuning/kaito_workspace_tuning_falcon_7b.yaml @@ -0,0 +1,20 @@ +apiVersion: kaito.sh/v1alpha1 +kind: Workspace +metadata: + name: workspace-tuning-falcon-7b +spec: + resource: + instanceType: "Standard_NC12s_v3" + labelSelector: + matchLabels: + app: tuning-falcon-7b + tuning: + preset: + name: falcon-7b + method: lora + config: tuning-config-map # ConfigMap containing tuning arguments + input: + name: tuning-data + hostPath: /path/to/your/input/data # dataset on node + output: + hostPath: /path/to/store/output # Tuning Output diff --git a/examples/kaito_workspace_falcon_40b-instruct.yaml 
b/examples/inference/kaito_workspace_falcon_40b-instruct.yaml similarity index 100% rename from examples/kaito_workspace_falcon_40b-instruct.yaml rename to examples/inference/kaito_workspace_falcon_40b-instruct.yaml diff --git a/examples/kaito_workspace_falcon_40b.yaml b/examples/inference/kaito_workspace_falcon_40b.yaml similarity index 100% rename from examples/kaito_workspace_falcon_40b.yaml rename to examples/inference/kaito_workspace_falcon_40b.yaml diff --git a/examples/kaito_workspace_falcon_7b-instruct.yaml b/examples/inference/kaito_workspace_falcon_7b-instruct.yaml similarity index 100% rename from examples/kaito_workspace_falcon_7b-instruct.yaml rename to examples/inference/kaito_workspace_falcon_7b-instruct.yaml diff --git a/examples/kaito_workspace_falcon_7b.yaml b/examples/inference/kaito_workspace_falcon_7b.yaml similarity index 100% rename from examples/kaito_workspace_falcon_7b.yaml rename to examples/inference/kaito_workspace_falcon_7b.yaml diff --git a/examples/kaito_workspace_llama2_13b-chat.yaml b/examples/inference/kaito_workspace_llama2_13b-chat.yaml similarity index 100% rename from examples/kaito_workspace_llama2_13b-chat.yaml rename to examples/inference/kaito_workspace_llama2_13b-chat.yaml diff --git a/examples/kaito_workspace_llama2_13b.yaml b/examples/inference/kaito_workspace_llama2_13b.yaml similarity index 100% rename from examples/kaito_workspace_llama2_13b.yaml rename to examples/inference/kaito_workspace_llama2_13b.yaml diff --git a/examples/kaito_workspace_llama2_70b-chat.yaml b/examples/inference/kaito_workspace_llama2_70b-chat.yaml similarity index 100% rename from examples/kaito_workspace_llama2_70b-chat.yaml rename to examples/inference/kaito_workspace_llama2_70b-chat.yaml diff --git a/examples/kaito_workspace_llama2_70b.yaml b/examples/inference/kaito_workspace_llama2_70b.yaml similarity index 100% rename from examples/kaito_workspace_llama2_70b.yaml rename to examples/inference/kaito_workspace_llama2_70b.yaml diff --git 
a/examples/kaito_workspace_llama2_7b-chat.yaml b/examples/inference/kaito_workspace_llama2_7b-chat.yaml similarity index 100% rename from examples/kaito_workspace_llama2_7b-chat.yaml rename to examples/inference/kaito_workspace_llama2_7b-chat.yaml diff --git a/examples/kaito_workspace_llama2_7b.yaml b/examples/inference/kaito_workspace_llama2_7b.yaml similarity index 100% rename from examples/kaito_workspace_llama2_7b.yaml rename to examples/inference/kaito_workspace_llama2_7b.yaml diff --git a/examples/kaito_workspace_mistral_7b-instruct.yaml b/examples/inference/kaito_workspace_mistral_7b-instruct.yaml similarity index 100% rename from examples/kaito_workspace_mistral_7b-instruct.yaml rename to examples/inference/kaito_workspace_mistral_7b-instruct.yaml diff --git a/examples/kaito_workspace_mistral_7b.yaml b/examples/inference/kaito_workspace_mistral_7b.yaml similarity index 100% rename from examples/kaito_workspace_mistral_7b.yaml rename to examples/inference/kaito_workspace_mistral_7b.yaml diff --git a/examples/kaito_workspace_phi-2.yaml b/examples/inference/kaito_workspace_phi-2.yaml similarity index 100% rename from examples/kaito_workspace_phi-2.yaml rename to examples/inference/kaito_workspace_phi-2.yaml diff --git a/pkg/controllers/workspace_controller.go b/pkg/controllers/workspace_controller.go index a2e4fc18d..db1150ad5 100644 --- a/pkg/controllers/workspace_controller.go +++ b/pkg/controllers/workspace_controller.go @@ -5,6 +5,7 @@ package controllers import ( "context" "fmt" + "github.com/azure/kaito/pkg/tuning" "sort" "strings" "time" @@ -109,16 +110,27 @@ func (c *WorkspaceReconciler) addOrUpdateWorkspace(ctx context.Context, wObj *ka return reconcile.Result{}, err } - if err = c.applyInference(ctx, wObj); err != nil { - if updateErr := c.updateStatusConditionIfNotMatch(ctx, wObj, kaitov1alpha1.WorkspaceConditionTypeReady, metav1.ConditionFalse, - "workspaceFailed", err.Error()); updateErr != nil { - klog.ErrorS(updateErr, "failed to update 
workspace status", "workspace", klog.KObj(wObj)) - return reconcile.Result{}, updateErr + if wObj.Tuning != nil { + if err = c.applyTuning(ctx, wObj); err != nil { + if updateErr := c.updateStatusConditionIfNotMatch(ctx, wObj, kaitov1alpha1.WorkspaceConditionTypeReady, metav1.ConditionFalse, + "workspaceFailed", err.Error()); updateErr != nil { + klog.ErrorS(updateErr, "failed to update workspace status", "workspace", klog.KObj(wObj)) + return reconcile.Result{}, updateErr + } + return reconcile.Result{}, err + } + } + if wObj.Inference != nil { + if err = c.applyInference(ctx, wObj); err != nil { + if updateErr := c.updateStatusConditionIfNotMatch(ctx, wObj, kaitov1alpha1.WorkspaceConditionTypeReady, metav1.ConditionFalse, + "workspaceFailed", err.Error()); updateErr != nil { + klog.ErrorS(updateErr, "failed to update workspace status", "workspace", klog.KObj(wObj)) + return reconcile.Result{}, updateErr + } + return reconcile.Result{}, err } - return reconcile.Result{}, err } - // TODO apply TrainingSpec if err = c.updateStatusConditionIfNotMatch(ctx, wObj, kaitov1alpha1.WorkspaceConditionTypeReady, metav1.ConditionTrue, "workspaceReady", "workspace is ready"); err != nil { klog.ErrorS(err, "failed to update workspace status", "workspace", klog.KObj(wObj)) @@ -423,6 +435,55 @@ func (c *WorkspaceReconciler) ensureService(ctx context.Context, wObj *kaitov1al return nil } +func (c *WorkspaceReconciler) applyTuning(ctx context.Context, wObj *kaitov1alpha1.Workspace) error { + var err error + func() { + if wObj.Tuning.Preset != nil { + presetName := string(wObj.Tuning.Preset.Name) + model := plugin.KaitoModelRegister.MustGet(presetName) + + trainingParam := model.GetTrainingParameters() + + var existingObj client.Object + existingObj = &appsv1.Deployment{} + if err = resources.GetResource(ctx, wObj.Name, wObj.Namespace, c.Client, existingObj); err == nil { + klog.InfoS("A training workload already exists for workspace", "workspace", klog.KObj(wObj)) + if err = 
resources.CheckResourceStatus(existingObj, c.Client, trainingParam.DeploymentTimeout); err != nil { + return + } + } else if apierrors.IsNotFound(err) { + var workloadObj client.Object + // Need to create a new workload + workloadObj, err = tuning.CreatePresetTuning(ctx, wObj, trainingParam, c.Client) + if err != nil { + return + } + if err = resources.CheckResourceStatus(workloadObj, c.Client, trainingParam.DeploymentTimeout); err != nil { + return + } + } + } + }() + + if err != nil { + if updateErr := c.updateStatusConditionIfNotMatch(ctx, wObj, kaitov1alpha1.WorkspaceConditionTypeTuningStatus, metav1.ConditionFalse, + "WorkspaceTuningStatusFailed", err.Error()); updateErr != nil { + klog.ErrorS(updateErr, "failed to update workspace status", "workspace", klog.KObj(wObj)) + return updateErr + } else { + return err + + } + } + + if err := c.updateStatusConditionIfNotMatch(ctx, wObj, kaitov1alpha1.WorkspaceConditionTypeTuningStatus, metav1.ConditionTrue, + "WorkspaceTuningStatusSuccess", "Tuning has been deployed successfully"); err != nil { + klog.ErrorS(err, "failed to update workspace status", "workspace", klog.KObj(wObj)) + return err + } + return nil +} + // applyInference applies inference spec. 
func (c *WorkspaceReconciler) applyInference(ctx context.Context, wObj *kaitov1alpha1.Workspace) error { var err error diff --git a/pkg/inference/preset-inferences.go b/pkg/inference/preset-inferences.go index 9b02012b7..4c4792b54 100644 --- a/pkg/inference/preset-inferences.go +++ b/pkg/inference/preset-inferences.go @@ -67,7 +67,7 @@ var ( } ) -func updateTorchParamsForDistributedInference(ctx context.Context, kubeClient client.Client, wObj *kaitov1alpha1.Workspace, inferenceObj *model.PresetInferenceParam) error { +func updateTorchParamsForDistributedInference(ctx context.Context, kubeClient client.Client, wObj *kaitov1alpha1.Workspace, inferenceObj *model.PresetParam) error { existingService := &corev1.Service{} err := resources.GetResource(ctx, wObj.Name, wObj.Namespace, kubeClient, existingService) if err != nil { @@ -92,7 +92,7 @@ func updateTorchParamsForDistributedInference(ctx context.Context, kubeClient cl return nil } -func GetImageInfo(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace, inferenceObj *model.PresetInferenceParam) (string, []corev1.LocalObjectReference) { +func GetImageInfo(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace, inferenceObj *model.PresetParam) (string, []corev1.LocalObjectReference) { imageName := string(workspaceObj.Inference.Preset.Name) imageTag := inferenceObj.Tag imagePullSecretRefs := []corev1.LocalObjectReference{} @@ -110,7 +110,7 @@ func GetImageInfo(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace, in } func CreatePresetInference(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace, - inferenceObj *model.PresetInferenceParam, supportDistributedInference bool, kubeClient client.Client) (client.Object, error) { + inferenceObj *model.PresetParam, supportDistributedInference bool, kubeClient client.Client) (client.Object, error) { if inferenceObj.TorchRunParams != nil && supportDistributedInference { if err := updateTorchParamsForDistributedInference(ctx, kubeClient, workspaceObj, 
inferenceObj); err != nil { klog.ErrorS(err, "failed to update torch params", "workspace", workspaceObj) @@ -141,7 +141,7 @@ func CreatePresetInference(ctx context.Context, workspaceObj *kaitov1alpha1.Work // torchrun baseCommand // and sets the GPU resources required for inference. // Returns the command and resource configuration. -func prepareInferenceParameters(ctx context.Context, inferenceObj *model.PresetInferenceParam) ([]string, corev1.ResourceRequirements) { +func prepareInferenceParameters(ctx context.Context, inferenceObj *model.PresetParam) ([]string, corev1.ResourceRequirements) { torchCommand := buildCommandStr(inferenceObj.BaseCommand, inferenceObj.TorchRunParams) torchCommand = buildCommandStr(torchCommand, inferenceObj.TorchRunRdzvParams) modelCommand := buildCommandStr(InferenceFile, inferenceObj.ModelRunParams) @@ -159,7 +159,7 @@ func prepareInferenceParameters(ctx context.Context, inferenceObj *model.PresetI return commands, resourceRequirements } -func configVolume(wObj *kaitov1alpha1.Workspace, inferenceObj *model.PresetInferenceParam) ([]corev1.Volume, []corev1.VolumeMount) { +func configVolume(wObj *kaitov1alpha1.Workspace, inferenceObj *model.PresetParam) ([]corev1.Volume, []corev1.VolumeMount) { volume := []corev1.Volume{} volumeMount := []corev1.VolumeMount{} diff --git a/pkg/inference/preset-inferences_test.go b/pkg/inference/preset-inferences_test.go index cd8df067c..31bf0551e 100644 --- a/pkg/inference/preset-inferences_test.go +++ b/pkg/inference/preset-inferences_test.go @@ -62,7 +62,7 @@ func TestCreatePresetInference(t *testing.T) { useHeadlessSvc := false - var inferenceObj *model.PresetInferenceParam + var inferenceObj *model.PresetParam model := plugin.KaitoModelRegister.MustGet(tc.modelName) inferenceObj = model.GetInferenceParameters() diff --git a/pkg/model/interface.go b/pkg/model/interface.go index 217c1f889..2763b5eec 100644 --- a/pkg/model/interface.go +++ b/pkg/model/interface.go @@ -7,12 +7,13 @@ import ( ) type Model 
interface { - GetInferenceParameters() *PresetInferenceParam + GetInferenceParameters() *PresetParam + GetTrainingParameters() *PresetParam SupportDistributedInference() bool //If true, the model workload will be a StatefulSet, using the torch elastic runtime framework. } -// PresetInferenceParam defines the preset inference parameters for a model. -type PresetInferenceParam struct { +// PresetParam defines the preset inference parameters for a model. +type PresetParam struct { ModelFamilyName string // The name of the model family. ImageAccessMode string // Defines where the Image is Public or Private. DiskStorageRequirement string // Disk storage requirements for the model. diff --git a/pkg/tuning/preset-tuning-types.go b/pkg/tuning/preset-tuning-types.go new file mode 100644 index 000000000..51f36511d --- /dev/null +++ b/pkg/tuning/preset-tuning-types.go @@ -0,0 +1,21 @@ +package tuning + +import corev1 "k8s.io/api/core/v1" + +const ( + DefaultNumProcesses = "1" + DefaultNumMachines = "1" + DefaultMachineRank = "0" + DefaultGPUIds = "all" +) + +var ( + DefaultAccelerateParams = map[string]string{ + "num_processes": DefaultNumProcesses, + "num_machines": DefaultNumMachines, + "machine_rank": DefaultMachineRank, + "gpu_ids": DefaultGPUIds, + } + + DefaultImagePullSecrets = []corev1.LocalObjectReference{} +) diff --git a/pkg/tuning/preset-tuning.go b/pkg/tuning/preset-tuning.go new file mode 100644 index 000000000..cbbb55a06 --- /dev/null +++ b/pkg/tuning/preset-tuning.go @@ -0,0 +1,29 @@ +package tuning + +import ( + "context" + kaitov1alpha1 "github.com/azure/kaito/api/v1alpha1" + "github.com/azure/kaito/pkg/model" + "sigs.k8s.io/controller-runtime/pkg/client" +) + +func CreatePresetTuning(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace, + tuningObj *model.PresetParam, kubeClient client.Client) (client.Object, error) { + // TODO + + // e.g. 
example from Inference + //volume, volumeMount := configVolume(workspaceObj, inferenceObj) + //commands, resourceReq := prepareInferenceParameters(ctx, inferenceObj) + //image, imagePullSecrets := GetImageInfo(ctx, workspaceObj, inferenceObj) + // + //depObj = resources.GenerateDeploymentManifest(ctx, workspaceObj, image, imagePullSecrets, *workspaceObj.Resource.Count, commands, + // containerPorts, livenessProbe, readinessProbe, resourceReq, tolerations, volume, volumeMount) + // + //err := resources.CreateResource(ctx, depObj, kubeClient) + //if client.IgnoreAlreadyExists(err) != nil { + // return nil, err + //} + //return depObj, nil + + return nil, nil +} diff --git a/pkg/utils/testModel.go b/pkg/utils/testModel.go index 99e3d8aca..50f6c9175 100644 --- a/pkg/utils/testModel.go +++ b/pkg/utils/testModel.go @@ -12,8 +12,8 @@ import ( type testModel struct{} -func (*testModel) GetInferenceParameters() *model.PresetInferenceParam { - return &model.PresetInferenceParam{ +func (*testModel) GetInferenceParameters() *model.PresetParam { + return &model.PresetParam{ GPUCountRequirement: "1", DeploymentTimeout: time.Duration(30) * time.Minute, } @@ -24,8 +24,8 @@ func (*testModel) SupportDistributedInference() bool { type testDistributedModel struct{} -func (*testDistributedModel) GetInferenceParameters() *model.PresetInferenceParam { - return &model.PresetInferenceParam{ +func (*testDistributedModel) GetInferenceParameters() *model.PresetParam { + return &model.PresetParam{ GPUCountRequirement: "1", DeploymentTimeout: time.Duration(30) * time.Minute, } diff --git a/presets/models/falcon/README.md b/presets/models/falcon/README.md index 81a1ced6f..e8cd895e6 100644 --- a/presets/models/falcon/README.md +++ b/presets/models/falcon/README.md @@ -1,10 +1,10 @@ ## Supported Models |Model name| Model source | Sample workspace|Kubernetes Workload|Distributed inference| |----|:----:|:----:| :----: |:----: | -|falcon-7b-instruct 
|[tiiuae](https://huggingface.co/tiiuae/falcon-7b-instruct)|[link](../../../examples/kaito_workspace_falcon_7b-instruct.yaml)|Deployment| false| -|falcon-7b |[tiiuae](https://huggingface.co/tiiuae/falcon-7b) |[link](../../../examples/kaito_workspace_falcon_7b.yaml)|Deployment| false| -|falcon-40b-instruct|[tiiuae](https://huggingface.co/tiiuae/falcon-40b-instruct) |[link](../../../examples/kaito_workspace_falcon_40b-instruct.yaml)|Deployment| false| -|falcon-40b |[tiiuae](https://huggingface.co/tiiuae/falcon-40b)|[link](../../../examples/kaito_workspace_falcon_40b.yaml)|Deployment| false| +|falcon-7b-instruct |[tiiuae](https://huggingface.co/tiiuae/falcon-7b-instruct)|[link](../../../examples/inference/kaito_workspace_falcon_7b-instruct.yaml)|Deployment| false| +|falcon-7b |[tiiuae](https://huggingface.co/tiiuae/falcon-7b) |[link](../../../examples/inference/kaito_workspace_falcon_7b.yaml)|Deployment| false| +|falcon-40b-instruct|[tiiuae](https://huggingface.co/tiiuae/falcon-40b-instruct) |[link](../../../examples/inference/kaito_workspace_falcon_40b-instruct.yaml)|Deployment| false| +|falcon-40b |[tiiuae](https://huggingface.co/tiiuae/falcon-40b)|[link](../../../examples/inference/kaito_workspace_falcon_40b.yaml)|Deployment| false| ## Image Source - **Public**: Kaito maintainers manage the lifecycle of the inference service images that contain model weights. The images are available in Microsoft Container Registry (MCR). 
diff --git a/presets/models/falcon/model.go b/presets/models/falcon/model.go index 7501dce23..c0f07495f 100644 --- a/presets/models/falcon/model.go +++ b/presets/models/falcon/model.go @@ -54,8 +54,8 @@ var falconA falcon7b type falcon7b struct{} -func (*falcon7b) GetInferenceParameters() *model.PresetInferenceParam { - return &model.PresetInferenceParam{ +func (*falcon7b) GetInferenceParameters() *model.PresetParam { + return &model.PresetParam{ ModelFamilyName: "Falcon", ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePublic), DiskStorageRequirement: "50Gi", @@ -68,8 +68,23 @@ func (*falcon7b) GetInferenceParameters() *model.PresetInferenceParam { BaseCommand: baseCommandPresetFalcon, Tag: PresetFalconTagMap["Falcon7B"], } - } +func (*falcon7b) GetTrainingParameters() *model.PresetParam { + return &model.PresetParam{ + ModelFamilyName: "Falcon", + ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePublic), + DiskStorageRequirement: "50Gi", + GPUCountRequirement: "2", + TotalGPUMemoryRequirement: "16Gi", + PerGPUMemoryRequirement: "16Gi", + //TorchRunParams: tuning.DefaultAccelerateParams, // TODO + //ModelRunPrams: falconRunTuningParams, // TODO + DeploymentTimeout: time.Duration(30) * time.Minute, + BaseCommand: baseCommandPresetFalcon, + Tag: PresetFalconTagMap["Falcon7B"], + } +} + func (*falcon7b) SupportDistributedInference() bool { return false } @@ -78,8 +93,8 @@ var falconB falcon7bInst type falcon7bInst struct{} -func (*falcon7bInst) GetInferenceParameters() *model.PresetInferenceParam { - return &model.PresetInferenceParam{ +func (*falcon7bInst) GetInferenceParameters() *model.PresetParam { + return &model.PresetParam{ ModelFamilyName: "Falcon", ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePublic), DiskStorageRequirement: "50Gi", @@ -94,6 +109,9 @@ func (*falcon7bInst) GetInferenceParameters() *model.PresetInferenceParam { } } +func (*falcon7bInst) GetTrainingParameters() *model.PresetParam { + return nil // It is not 
recommended/ideal to further fine-tune instruct models - Already been fine-tuned +} func (*falcon7bInst) SupportDistributedInference() bool { return false } @@ -102,8 +120,8 @@ var falconC falcon40b type falcon40b struct{} -func (*falcon40b) GetInferenceParameters() *model.PresetInferenceParam { - return &model.PresetInferenceParam{ +func (*falcon40b) GetInferenceParameters() *model.PresetParam { + return &model.PresetParam{ ModelFamilyName: "Falcon", ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePublic), DiskStorageRequirement: "400", @@ -118,6 +136,21 @@ func (*falcon40b) GetInferenceParameters() *model.PresetInferenceParam { } } +func (*falcon40b) GetTrainingParameters() *model.PresetParam { + return &model.PresetParam{ + ModelFamilyName: "Falcon", + ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePublic), + DiskStorageRequirement: "50Gi", + GPUCountRequirement: "2", + TotalGPUMemoryRequirement: "90Gi", + PerGPUMemoryRequirement: "16Gi", + //TorchRunParams: tuning.DefaultAccelerateParams, // TODO + //ModelRunPrams: falconRunTuningParams, // TODO + DeploymentTimeout: time.Duration(30) * time.Minute, + BaseCommand: baseCommandPresetFalcon, + Tag: PresetFalconTagMap["Falcon40B"], + } +} func (*falcon40b) SupportDistributedInference() bool { return false } @@ -126,8 +159,8 @@ var falconD falcon40bInst type falcon40bInst struct{} -func (*falcon40bInst) GetInferenceParameters() *model.PresetInferenceParam { - return &model.PresetInferenceParam{ +func (*falcon40bInst) GetInferenceParameters() *model.PresetParam { + return &model.PresetParam{ ModelFamilyName: "Falcon", ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePublic), DiskStorageRequirement: "400", @@ -141,7 +174,9 @@ func (*falcon40bInst) GetInferenceParameters() *model.PresetInferenceParam { Tag: PresetFalconTagMap["Falcon40BInstruct"], } } - +func (*falcon40bInst) GetTrainingParameters() *model.PresetParam { + return nil // It is not recommended/ideal to further fine-tune 
instruct models - Already been fine-tuned +} func (*falcon40bInst) SupportDistributedInference() bool { return false } diff --git a/presets/models/llama2/README.md b/presets/models/llama2/README.md index e6a40563a..ba2646a2b 100644 --- a/presets/models/llama2/README.md +++ b/presets/models/llama2/README.md @@ -1,9 +1,9 @@ ## Supported Models |Model name| Model source | Sample workspace|Kubernetes Workload|Distributed inference| |----|:----:|:----:| :----: |:----: | -|llama2-7b |[meta](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md)|[link](../../../examples/kaito_workspace_llama2_7b.yaml)|Deployment| false| -|llama2-13b|[meta](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md)|[link](../../../examples/kaito_workspace_llama2_13b.yaml)|StatefulSet| true| -|llama2-70b|[meta](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md)|[link](../../../examples/kaito_workspace_llama2_70b.yaml)|StatefulSet| true| +|llama2-7b |[meta](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md)|[link](../../../examples/inference/kaito_workspace_llama2_7b.yaml)|Deployment| false| +|llama2-13b|[meta](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md)|[link](../../../examples/inference/kaito_workspace_llama2_13b.yaml)|StatefulSet| true| +|llama2-70b|[meta](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md)|[link](../../../examples/inference/kaito_workspace_llama2_70b.yaml)|StatefulSet| true| ## Image Source - **Private**: User needs to manage the lifecycle of the inference service images that contain model weights (e.g., managing image tags). The images are available in user's private container registry. 
diff --git a/presets/models/llama2/model.go b/presets/models/llama2/model.go index 30c97b7fd..673a67da3 100644 --- a/presets/models/llama2/model.go +++ b/presets/models/llama2/model.go @@ -38,8 +38,8 @@ var llama2A llama2Text7b type llama2Text7b struct{} -func (*llama2Text7b) GetInferenceParameters() *model.PresetInferenceParam { - return &model.PresetInferenceParam{ +func (*llama2Text7b) GetInferenceParameters() *model.PresetParam { + return &model.PresetParam{ ModelFamilyName: "LLaMa2", ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePrivate), DiskStorageRequirement: "34Gi", @@ -56,6 +56,9 @@ func (*llama2Text7b) GetInferenceParameters() *model.PresetInferenceParam { } } +func (*llama2Text7b) GetTrainingParameters() *model.PresetParam { + return nil // Currently doesn't support fine-tuning +} func (*llama2Text7b) SupportDistributedInference() bool { return false } @@ -64,8 +67,8 @@ var llama2B llama2Text13b type llama2Text13b struct{} -func (*llama2Text13b) GetInferenceParameters() *model.PresetInferenceParam { - return &model.PresetInferenceParam{ +func (*llama2Text13b) GetInferenceParameters() *model.PresetParam { + return &model.PresetParam{ ModelFamilyName: "LLaMa2", ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePrivate), DiskStorageRequirement: "46Gi", @@ -81,6 +84,9 @@ func (*llama2Text13b) GetInferenceParameters() *model.PresetInferenceParam { // Tag: llama has private image access mode. The image tag is determined by the user. 
} } +func (*llama2Text13b) GetTrainingParameters() *model.PresetParam { + return nil // Currently doesn't support fine-tuning +} func (*llama2Text13b) SupportDistributedInference() bool { return true } @@ -89,8 +95,8 @@ var llama2C llama2Text70b type llama2Text70b struct{} -func (*llama2Text70b) GetInferenceParameters() *model.PresetInferenceParam { - return &model.PresetInferenceParam{ +func (*llama2Text70b) GetInferenceParameters() *model.PresetParam { + return &model.PresetParam{ ModelFamilyName: "LLaMa2", ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePrivate), DiskStorageRequirement: "158Gi", @@ -106,6 +112,9 @@ func (*llama2Text70b) GetInferenceParameters() *model.PresetInferenceParam { // Tag: llama has private image access mode. The image tag is determined by the user. } } +func (*llama2Text70b) GetTrainingParameters() *model.PresetParam { + return nil // Currently doesn't support fine-tuning +} func (*llama2Text70b) SupportDistributedInference() bool { return true } diff --git a/presets/models/llama2chat/README.md b/presets/models/llama2chat/README.md index 53e241fab..0cf9ec3be 100644 --- a/presets/models/llama2chat/README.md +++ b/presets/models/llama2chat/README.md @@ -1,9 +1,9 @@ ## Supported Models |Model name| Model source | Sample workspace|Kubernetes Workload|Distributed inference| |----|:----:|:----:| :----: |:----: | -|llama2-7b-chat |[meta](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md)|[link](../../../examples/kaito_workspace_llama2_7b-chat.yaml)|Deployment| false| -|llama2-13b-chat|[meta](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md)|[link](../../../examples/kaito_workspace_llama2_13b-chat.yaml)|StatefulSet| true| -|llama2-70b-chat|[meta](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md)|[link](../../../examples/kaito_workspace_llama2_70b-chat.yaml)|StatefulSet| true| +|llama2-7b-chat 
|[meta](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md)|[link](../../../examples/inference/kaito_workspace_llama2_7b-chat.yaml)|Deployment| false| +|llama2-13b-chat|[meta](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md)|[link](../../../examples/inference/kaito_workspace_llama2_13b-chat.yaml)|StatefulSet| true| +|llama2-70b-chat|[meta](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md)|[link](../../../examples/inference/kaito_workspace_llama2_70b-chat.yaml)|StatefulSet| true| ## Image Source - **Private**: User needs to manage the lifecycle of the inference service images that contain model weights (e.g., managing image tags). The images are available in user's private container registry. diff --git a/presets/models/llama2chat/model.go b/presets/models/llama2chat/model.go index cc0d8d4c6..a555ebc07 100644 --- a/presets/models/llama2chat/model.go +++ b/presets/models/llama2chat/model.go @@ -38,8 +38,8 @@ var llama2chatA llama2Chat7b type llama2Chat7b struct{} -func (*llama2Chat7b) GetInferenceParameters() *model.PresetInferenceParam { - return &model.PresetInferenceParam{ +func (*llama2Chat7b) GetInferenceParameters() *model.PresetParam { + return &model.PresetParam{ ModelFamilyName: "LLaMa2", ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePrivate), DiskStorageRequirement: "34Gi", @@ -54,7 +54,9 @@ func (*llama2Chat7b) GetInferenceParameters() *model.PresetInferenceParam { WorldSize: 1, // Tag: llama has private image access mode. The image tag is determined by the user. 
} - +} +func (*llama2Chat7b) GetTrainingParameters() *model.PresetParam { + return nil // Currently doesn't support fine-tuning } func (*llama2Chat7b) SupportDistributedInference() bool { return false @@ -64,8 +66,8 @@ var llama2chatB llama2Chat13b type llama2Chat13b struct{} -func (*llama2Chat13b) GetInferenceParameters() *model.PresetInferenceParam { - return &model.PresetInferenceParam{ +func (*llama2Chat13b) GetInferenceParameters() *model.PresetParam { + return &model.PresetParam{ ModelFamilyName: "LLaMa2", ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePrivate), DiskStorageRequirement: "46Gi", @@ -81,6 +83,9 @@ func (*llama2Chat13b) GetInferenceParameters() *model.PresetInferenceParam { // Tag: llama has private image access mode. The image tag is determined by the user. } } +func (*llama2Chat13b) GetTrainingParameters() *model.PresetParam { + return nil // Currently doesn't support fine-tuning +} func (*llama2Chat13b) SupportDistributedInference() bool { return true } @@ -89,8 +94,8 @@ var llama2chatC llama2Chat70b type llama2Chat70b struct{} -func (*llama2Chat70b) GetInferenceParameters() *model.PresetInferenceParam { - return &model.PresetInferenceParam{ +func (*llama2Chat70b) GetInferenceParameters() *model.PresetParam { + return &model.PresetParam{ ModelFamilyName: "LLaMa2", ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePrivate), DiskStorageRequirement: "158Gi", @@ -106,6 +111,9 @@ func (*llama2Chat70b) GetInferenceParameters() *model.PresetInferenceParam { // Tag: llama has private image access mode. The image tag is determined by the user. 
} } +func (*llama2Chat70b) GetTrainingParameters() *model.PresetParam { + return nil // Currently doesn't support fine-tuning +} func (*llama2Chat70b) SupportDistributedInference() bool { return true } diff --git a/presets/models/mistral/README.md b/presets/models/mistral/README.md index 4d0c56ba6..2d037f7a4 100644 --- a/presets/models/mistral/README.md +++ b/presets/models/mistral/README.md @@ -1,8 +1,8 @@ ## Supported Models |Model name| Model source | Sample workspace|Kubernetes Workload|Distributed inference| |----|:----:|:----:| :----: |:----: | -|mistral-7b-instruct |[mistralai](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2)|[link](../../../examples/kaito_workspace_mistral_7b-instruct.yaml)|Deployment| false| -|mistral-7b |[mistralai](https://huggingface.co/mistralai/Mistral-7B-v0.1) |[link](../../../examples/kaito_workspace_mistral_7b.yaml)|Deployment| false| +|mistral-7b-instruct |[mistralai](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2)|[link](../../../examples/inference/kaito_workspace_mistral_7b-instruct.yaml)|Deployment| false| +|mistral-7b |[mistralai](https://huggingface.co/mistralai/Mistral-7B-v0.1) |[link](../../../examples/inference/kaito_workspace_mistral_7b.yaml)|Deployment| false| ## Image Source diff --git a/presets/models/mistral/model.go b/presets/models/mistral/model.go index 7089eafb6..bcf06203b 100644 --- a/presets/models/mistral/model.go +++ b/presets/models/mistral/model.go @@ -42,8 +42,8 @@ var mistralA mistral7b type mistral7b struct{} -func (*mistral7b) GetInferenceParameters() *model.PresetInferenceParam { - return &model.PresetInferenceParam{ +func (*mistral7b) GetInferenceParameters() *model.PresetParam { + return &model.PresetParam{ ModelFamilyName: "Mistral", ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePublic), DiskStorageRequirement: "100Gi", @@ -58,6 +58,22 @@ func (*mistral7b) GetInferenceParameters() *model.PresetInferenceParam { } } +func (*mistral7b) GetTrainingParameters() 
*model.PresetParam { + return &model.PresetParam{ + ModelFamilyName: "Mistral", + ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePublic), + DiskStorageRequirement: "100Gi", + GPUCountRequirement: "1", + TotalGPUMemoryRequirement: "16Gi", + PerGPUMemoryRequirement: "16Gi", // We run Mistral using native vertical model parallel, no per GPU memory requirement. + //TorchRunParams: tuning.DefaultAccelerateParams, + //ModelRunParams: mistralRunParams, + DeploymentTimeout: time.Duration(30) * time.Minute, + BaseCommand: baseCommandPresetMistral, + Tag: PresetMistralTagMap["Mistral7B"], + } +} + func (*mistral7b) SupportDistributedInference() bool { return false } @@ -66,8 +82,8 @@ var mistralB mistral7bInst type mistral7bInst struct{} -func (*mistral7bInst) GetInferenceParameters() *model.PresetInferenceParam { - return &model.PresetInferenceParam{ +func (*mistral7bInst) GetInferenceParameters() *model.PresetParam { + return &model.PresetParam{ ModelFamilyName: "Mistral", ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePublic), DiskStorageRequirement: "100Gi", @@ -82,6 +98,9 @@ func (*mistral7bInst) GetInferenceParameters() *model.PresetInferenceParam { } } +func (*mistral7bInst) GetTrainingParameters() *model.PresetParam { + return nil // It is not recommended/ideal to further fine-tune instruct models - Already been fine-tuned +} func (*mistral7bInst) SupportDistributedInference() bool { return false } diff --git a/presets/models/phi/README.md b/presets/models/phi/README.md index 7caeadb84..1e77252a5 100644 --- a/presets/models/phi/README.md +++ b/presets/models/phi/README.md @@ -1,7 +1,7 @@ ## Supported Models |Model name| Model source | Sample workspace|Kubernetes Workload|Distributed inference| |----|:----:|:----:| :----: |:----: | -|phi-2 |[microsoft](https://huggingface.co/microsoft/phi-2)|[link](../../../examples/kaito_workspace_phi-2.yaml)|Deployment| false| +|phi-2 
|[microsoft](https://huggingface.co/microsoft/phi-2)|[link](../../../examples/inference/kaito_workspace_phi-2.yaml)|Deployment| false| ## Image Source diff --git a/presets/models/phi/model.go b/presets/models/phi/model.go index 2e54dce38..37d92e673 100644 --- a/presets/models/phi/model.go +++ b/presets/models/phi/model.go @@ -36,8 +36,8 @@ var phiA phi2 type phi2 struct{} -func (*phi2) GetInferenceParameters() *model.PresetInferenceParam { - return &model.PresetInferenceParam{ +func (*phi2) GetInferenceParameters() *model.PresetParam { + return &model.PresetParam{ ModelFamilyName: "Phi", ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePublic), DiskStorageRequirement: "50Gi", From 1a4b5ac5d496c02712667fa8b2a389148c78ea36 Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Mon, 18 Mar 2024 15:01:25 -0700 Subject: [PATCH 17/29] feat: Added validation checks for TuningSpec, DataSource, DataDestination --- api/v1alpha1/workspace_validation.go | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/api/v1alpha1/workspace_validation.go b/api/v1alpha1/workspace_validation.go index 79b27e1b9..aa9b77316 100644 --- a/api/v1alpha1/workspace_validation.go +++ b/api/v1alpha1/workspace_validation.go @@ -218,6 +218,17 @@ func (r *DataDestination) validateUpdate(old *DataDestination) (errs *apis.Field return errs } +func (r *DataDestination) validateUpdate(old *DataDestination) (errs *apis.FieldError) { + if old.HostPath != r.HostPath { + errs = errs.Also(apis.ErrInvalidValue("HostPath field cannot be changed once set", "HostPath")) + } + if old.Image != r.Image { + errs = errs.Also(apis.ErrInvalidValue("Image field cannot be changed once set", "Image")) + } + // TODO: Ensure ImageSecrets can be changed + return errs +} + func (r *ResourceSpec) validateCreate(inference InferenceSpec) (errs *apis.FieldError) { var presetName string if inference.Preset != nil { From 29210a7c9737b107558fc08e3971f7a7e2fc01d1 Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Mon, 18 
Mar 2024 18:06:49 -0700 Subject: [PATCH 18/29] fix: validation fixes --- api/v1alpha1/workspace_validation.go | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/api/v1alpha1/workspace_validation.go b/api/v1alpha1/workspace_validation.go index aa9b77316..79b27e1b9 100644 --- a/api/v1alpha1/workspace_validation.go +++ b/api/v1alpha1/workspace_validation.go @@ -218,17 +218,6 @@ func (r *DataDestination) validateUpdate(old *DataDestination) (errs *apis.Field return errs } -func (r *DataDestination) validateUpdate(old *DataDestination) (errs *apis.FieldError) { - if old.HostPath != r.HostPath { - errs = errs.Also(apis.ErrInvalidValue("HostPath field cannot be changed once set", "HostPath")) - } - if old.Image != r.Image { - errs = errs.Also(apis.ErrInvalidValue("Image field cannot be changed once set", "Image")) - } - // TODO: Ensure ImageSecrets can be changed - return errs -} - func (r *ResourceSpec) validateCreate(inference InferenceSpec) (errs *apis.FieldError) { var presetName string if inference.Preset != nil { From 4a93976b609ed0ed61f0a59dddbe93ab1a1993cc Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Wed, 20 Mar 2024 13:37:35 -0700 Subject: [PATCH 19/29] feat: Setup Interface for fine tuning --- api/v1alpha1/workspace_condition_types.go | 3 + api/v1alpha1/workspace_validation_test.go | 8 +- .../kaito_workspace_tuning_falcon_7b.yaml | 20 +++++ .../kaito_workspace_falcon_40b-instruct.yaml | 0 .../kaito_workspace_falcon_40b.yaml | 0 .../kaito_workspace_falcon_7b-instruct.yaml | 0 .../kaito_workspace_falcon_7b.yaml | 0 .../kaito_workspace_llama2_13b-chat.yaml | 0 .../kaito_workspace_llama2_13b.yaml | 0 .../kaito_workspace_llama2_70b-chat.yaml | 0 .../kaito_workspace_llama2_70b.yaml | 0 .../kaito_workspace_llama2_7b-chat.yaml | 0 .../kaito_workspace_llama2_7b.yaml | 0 .../kaito_workspace_mistral_7b-instruct.yaml | 0 .../kaito_workspace_mistral_7b.yaml | 0 .../kaito_workspace_phi-2.yaml | 0 pkg/controllers/workspace_controller.go | 75 
+++++++++++++++++-- pkg/inference/preset-inferences.go | 10 +-- pkg/inference/preset-inferences_test.go | 2 +- pkg/model/interface.go | 7 +- pkg/tuning/preset-tuning-types.go | 21 ++++++ pkg/tuning/preset-tuning.go | 29 +++++++ pkg/utils/testModel.go | 8 +- presets/models/falcon/README.md | 8 +- presets/models/falcon/model.go | 55 +++++++++++--- presets/models/llama2/README.md | 6 +- presets/models/llama2/model.go | 21 ++++-- presets/models/llama2chat/README.md | 6 +- presets/models/llama2chat/model.go | 22 ++++-- presets/models/mistral/README.md | 4 +- presets/models/mistral/model.go | 27 ++++++- presets/models/phi/README.md | 2 +- presets/models/phi/model.go | 4 +- 33 files changed, 272 insertions(+), 66 deletions(-) create mode 100644 examples/fine-tuning/kaito_workspace_tuning_falcon_7b.yaml rename examples/{ => inference}/kaito_workspace_falcon_40b-instruct.yaml (100%) rename examples/{ => inference}/kaito_workspace_falcon_40b.yaml (100%) rename examples/{ => inference}/kaito_workspace_falcon_7b-instruct.yaml (100%) rename examples/{ => inference}/kaito_workspace_falcon_7b.yaml (100%) rename examples/{ => inference}/kaito_workspace_llama2_13b-chat.yaml (100%) rename examples/{ => inference}/kaito_workspace_llama2_13b.yaml (100%) rename examples/{ => inference}/kaito_workspace_llama2_70b-chat.yaml (100%) rename examples/{ => inference}/kaito_workspace_llama2_70b.yaml (100%) rename examples/{ => inference}/kaito_workspace_llama2_7b-chat.yaml (100%) rename examples/{ => inference}/kaito_workspace_llama2_7b.yaml (100%) rename examples/{ => inference}/kaito_workspace_mistral_7b-instruct.yaml (100%) rename examples/{ => inference}/kaito_workspace_mistral_7b.yaml (100%) rename examples/{ => inference}/kaito_workspace_phi-2.yaml (100%) create mode 100644 pkg/tuning/preset-tuning-types.go create mode 100644 pkg/tuning/preset-tuning.go diff --git a/api/v1alpha1/workspace_condition_types.go b/api/v1alpha1/workspace_condition_types.go index 762d8dafc..9845b8a0c 100644 --- 
a/api/v1alpha1/workspace_condition_types.go +++ b/api/v1alpha1/workspace_condition_types.go @@ -16,6 +16,9 @@ const ( // WorkspaceConditionTypeInferenceStatus is the state when Inference has been created. WorkspaceConditionTypeInferenceStatus = ConditionType("InferenceReady") + // WorkspaceConditionTypeTuningStatus is the state when Tuning has been created. + WorkspaceConditionTypeTuningStatus = ConditionType("TuningReady") + //WorkspaceConditionTypeDeleting is the Workspace state when starts to get deleted. WorkspaceConditionTypeDeleting = ConditionType("WorkspaceDeleting") diff --git a/api/v1alpha1/workspace_validation_test.go b/api/v1alpha1/workspace_validation_test.go index 11631e67b..695d42298 100644 --- a/api/v1alpha1/workspace_validation_test.go +++ b/api/v1alpha1/workspace_validation_test.go @@ -21,8 +21,8 @@ var perGPUMemoryRequirement string type testModel struct{} -func (*testModel) GetInferenceParameters() *model.PresetInferenceParam { - return &model.PresetInferenceParam{ +func (*testModel) GetInferenceParameters() *model.PresetParam { + return &model.PresetParam{ GPUCountRequirement: gpuCountRequirement, TotalGPUMemoryRequirement: totalGPUMemoryRequirement, PerGPUMemoryRequirement: perGPUMemoryRequirement, @@ -34,8 +34,8 @@ func (*testModel) SupportDistributedInference() bool { type testModelPrivate struct{} -func (*testModelPrivate) GetInferenceParameters() *model.PresetInferenceParam { - return &model.PresetInferenceParam{ +func (*testModelPrivate) GetInferenceParameters() *model.PresetParam { + return &model.PresetParam{ ImageAccessMode: "private", GPUCountRequirement: gpuCountRequirement, TotalGPUMemoryRequirement: totalGPUMemoryRequirement, diff --git a/examples/fine-tuning/kaito_workspace_tuning_falcon_7b.yaml b/examples/fine-tuning/kaito_workspace_tuning_falcon_7b.yaml new file mode 100644 index 000000000..6d6ed7831 --- /dev/null +++ b/examples/fine-tuning/kaito_workspace_tuning_falcon_7b.yaml @@ -0,0 +1,20 @@ +apiVersion: kaito.sh/v1alpha1 
+kind: Workspace +metadata: + name: workspace-tuning-falcon-7b +spec: + resource: + instanceType: "Standard_NC12s_v3" + labelSelector: + matchLabels: + app: tuning-falcon-7b + tuning: + preset: + name: falcon-7b + method: lora + config: tuning-config-map # ConfigMap containing tuning arguments + input: + name: tuning-data + hostPath: /path/to/your/input/data # dataset on node + output: + hostPath: /path/to/store/output # Tuning Output diff --git a/examples/kaito_workspace_falcon_40b-instruct.yaml b/examples/inference/kaito_workspace_falcon_40b-instruct.yaml similarity index 100% rename from examples/kaito_workspace_falcon_40b-instruct.yaml rename to examples/inference/kaito_workspace_falcon_40b-instruct.yaml diff --git a/examples/kaito_workspace_falcon_40b.yaml b/examples/inference/kaito_workspace_falcon_40b.yaml similarity index 100% rename from examples/kaito_workspace_falcon_40b.yaml rename to examples/inference/kaito_workspace_falcon_40b.yaml diff --git a/examples/kaito_workspace_falcon_7b-instruct.yaml b/examples/inference/kaito_workspace_falcon_7b-instruct.yaml similarity index 100% rename from examples/kaito_workspace_falcon_7b-instruct.yaml rename to examples/inference/kaito_workspace_falcon_7b-instruct.yaml diff --git a/examples/kaito_workspace_falcon_7b.yaml b/examples/inference/kaito_workspace_falcon_7b.yaml similarity index 100% rename from examples/kaito_workspace_falcon_7b.yaml rename to examples/inference/kaito_workspace_falcon_7b.yaml diff --git a/examples/kaito_workspace_llama2_13b-chat.yaml b/examples/inference/kaito_workspace_llama2_13b-chat.yaml similarity index 100% rename from examples/kaito_workspace_llama2_13b-chat.yaml rename to examples/inference/kaito_workspace_llama2_13b-chat.yaml diff --git a/examples/kaito_workspace_llama2_13b.yaml b/examples/inference/kaito_workspace_llama2_13b.yaml similarity index 100% rename from examples/kaito_workspace_llama2_13b.yaml rename to examples/inference/kaito_workspace_llama2_13b.yaml diff --git 
a/examples/kaito_workspace_llama2_70b-chat.yaml b/examples/inference/kaito_workspace_llama2_70b-chat.yaml similarity index 100% rename from examples/kaito_workspace_llama2_70b-chat.yaml rename to examples/inference/kaito_workspace_llama2_70b-chat.yaml diff --git a/examples/kaito_workspace_llama2_70b.yaml b/examples/inference/kaito_workspace_llama2_70b.yaml similarity index 100% rename from examples/kaito_workspace_llama2_70b.yaml rename to examples/inference/kaito_workspace_llama2_70b.yaml diff --git a/examples/kaito_workspace_llama2_7b-chat.yaml b/examples/inference/kaito_workspace_llama2_7b-chat.yaml similarity index 100% rename from examples/kaito_workspace_llama2_7b-chat.yaml rename to examples/inference/kaito_workspace_llama2_7b-chat.yaml diff --git a/examples/kaito_workspace_llama2_7b.yaml b/examples/inference/kaito_workspace_llama2_7b.yaml similarity index 100% rename from examples/kaito_workspace_llama2_7b.yaml rename to examples/inference/kaito_workspace_llama2_7b.yaml diff --git a/examples/kaito_workspace_mistral_7b-instruct.yaml b/examples/inference/kaito_workspace_mistral_7b-instruct.yaml similarity index 100% rename from examples/kaito_workspace_mistral_7b-instruct.yaml rename to examples/inference/kaito_workspace_mistral_7b-instruct.yaml diff --git a/examples/kaito_workspace_mistral_7b.yaml b/examples/inference/kaito_workspace_mistral_7b.yaml similarity index 100% rename from examples/kaito_workspace_mistral_7b.yaml rename to examples/inference/kaito_workspace_mistral_7b.yaml diff --git a/examples/kaito_workspace_phi-2.yaml b/examples/inference/kaito_workspace_phi-2.yaml similarity index 100% rename from examples/kaito_workspace_phi-2.yaml rename to examples/inference/kaito_workspace_phi-2.yaml diff --git a/pkg/controllers/workspace_controller.go b/pkg/controllers/workspace_controller.go index a2e4fc18d..db1150ad5 100644 --- a/pkg/controllers/workspace_controller.go +++ b/pkg/controllers/workspace_controller.go @@ -5,6 +5,7 @@ package controllers 
import ( "context" "fmt" + "github.com/azure/kaito/pkg/tuning" "sort" "strings" "time" @@ -109,16 +110,27 @@ func (c *WorkspaceReconciler) addOrUpdateWorkspace(ctx context.Context, wObj *ka return reconcile.Result{}, err } - if err = c.applyInference(ctx, wObj); err != nil { - if updateErr := c.updateStatusConditionIfNotMatch(ctx, wObj, kaitov1alpha1.WorkspaceConditionTypeReady, metav1.ConditionFalse, - "workspaceFailed", err.Error()); updateErr != nil { - klog.ErrorS(updateErr, "failed to update workspace status", "workspace", klog.KObj(wObj)) - return reconcile.Result{}, updateErr + if wObj.Tuning != nil { + if err = c.applyTuning(ctx, wObj); err != nil { + if updateErr := c.updateStatusConditionIfNotMatch(ctx, wObj, kaitov1alpha1.WorkspaceConditionTypeReady, metav1.ConditionFalse, + "workspaceFailed", err.Error()); updateErr != nil { + klog.ErrorS(updateErr, "failed to update workspace status", "workspace", klog.KObj(wObj)) + return reconcile.Result{}, updateErr + } + return reconcile.Result{}, err + } + } + if wObj.Inference != nil { + if err = c.applyInference(ctx, wObj); err != nil { + if updateErr := c.updateStatusConditionIfNotMatch(ctx, wObj, kaitov1alpha1.WorkspaceConditionTypeReady, metav1.ConditionFalse, + "workspaceFailed", err.Error()); updateErr != nil { + klog.ErrorS(updateErr, "failed to update workspace status", "workspace", klog.KObj(wObj)) + return reconcile.Result{}, updateErr + } + return reconcile.Result{}, err } - return reconcile.Result{}, err } - // TODO apply TrainingSpec if err = c.updateStatusConditionIfNotMatch(ctx, wObj, kaitov1alpha1.WorkspaceConditionTypeReady, metav1.ConditionTrue, "workspaceReady", "workspace is ready"); err != nil { klog.ErrorS(err, "failed to update workspace status", "workspace", klog.KObj(wObj)) @@ -423,6 +435,55 @@ func (c *WorkspaceReconciler) ensureService(ctx context.Context, wObj *kaitov1al return nil } +func (c *WorkspaceReconciler) applyTuning(ctx context.Context, wObj *kaitov1alpha1.Workspace) error { 
+ var err error + func() { + if wObj.Tuning.Preset != nil { + presetName := string(wObj.Tuning.Preset.Name) + model := plugin.KaitoModelRegister.MustGet(presetName) + + trainingParam := model.GetTrainingParameters() + + var existingObj client.Object + existingObj = &appsv1.Deployment{} + if err = resources.GetResource(ctx, wObj.Name, wObj.Namespace, c.Client, existingObj); err == nil { + klog.InfoS("A training workload already exists for workspace", "workspace", klog.KObj(wObj)) + if err = resources.CheckResourceStatus(existingObj, c.Client, trainingParam.DeploymentTimeout); err != nil { + return + } + } else if apierrors.IsNotFound(err) { + var workloadObj client.Object + // Need to create a new workload + workloadObj, err = tuning.CreatePresetTuning(ctx, wObj, trainingParam, c.Client) + if err != nil { + return + } + if err = resources.CheckResourceStatus(workloadObj, c.Client, trainingParam.DeploymentTimeout); err != nil { + return + } + } + } + }() + + if err != nil { + if updateErr := c.updateStatusConditionIfNotMatch(ctx, wObj, kaitov1alpha1.WorkspaceConditionTypeTuningStatus, metav1.ConditionFalse, + "WorkspaceTuningStatusFailed", err.Error()); updateErr != nil { + klog.ErrorS(updateErr, "failed to update workspace status", "workspace", klog.KObj(wObj)) + return updateErr + } else { + return err + + } + } + + if err := c.updateStatusConditionIfNotMatch(ctx, wObj, kaitov1alpha1.WorkspaceConditionTypeTuningStatus, metav1.ConditionTrue, + "WorkspaceTuningStatusSuccess", "Tuning has been deployed successfully"); err != nil { + klog.ErrorS(err, "failed to update workspace status", "workspace", klog.KObj(wObj)) + return err + } + return nil +} + // applyInference applies inference spec. 
func (c *WorkspaceReconciler) applyInference(ctx context.Context, wObj *kaitov1alpha1.Workspace) error { var err error diff --git a/pkg/inference/preset-inferences.go b/pkg/inference/preset-inferences.go index 9b02012b7..4c4792b54 100644 --- a/pkg/inference/preset-inferences.go +++ b/pkg/inference/preset-inferences.go @@ -67,7 +67,7 @@ var ( } ) -func updateTorchParamsForDistributedInference(ctx context.Context, kubeClient client.Client, wObj *kaitov1alpha1.Workspace, inferenceObj *model.PresetInferenceParam) error { +func updateTorchParamsForDistributedInference(ctx context.Context, kubeClient client.Client, wObj *kaitov1alpha1.Workspace, inferenceObj *model.PresetParam) error { existingService := &corev1.Service{} err := resources.GetResource(ctx, wObj.Name, wObj.Namespace, kubeClient, existingService) if err != nil { @@ -92,7 +92,7 @@ func updateTorchParamsForDistributedInference(ctx context.Context, kubeClient cl return nil } -func GetImageInfo(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace, inferenceObj *model.PresetInferenceParam) (string, []corev1.LocalObjectReference) { +func GetImageInfo(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace, inferenceObj *model.PresetParam) (string, []corev1.LocalObjectReference) { imageName := string(workspaceObj.Inference.Preset.Name) imageTag := inferenceObj.Tag imagePullSecretRefs := []corev1.LocalObjectReference{} @@ -110,7 +110,7 @@ func GetImageInfo(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace, in } func CreatePresetInference(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace, - inferenceObj *model.PresetInferenceParam, supportDistributedInference bool, kubeClient client.Client) (client.Object, error) { + inferenceObj *model.PresetParam, supportDistributedInference bool, kubeClient client.Client) (client.Object, error) { if inferenceObj.TorchRunParams != nil && supportDistributedInference { if err := updateTorchParamsForDistributedInference(ctx, kubeClient, workspaceObj, 
inferenceObj); err != nil { klog.ErrorS(err, "failed to update torch params", "workspace", workspaceObj) @@ -141,7 +141,7 @@ func CreatePresetInference(ctx context.Context, workspaceObj *kaitov1alpha1.Work // torchrun baseCommand // and sets the GPU resources required for inference. // Returns the command and resource configuration. -func prepareInferenceParameters(ctx context.Context, inferenceObj *model.PresetInferenceParam) ([]string, corev1.ResourceRequirements) { +func prepareInferenceParameters(ctx context.Context, inferenceObj *model.PresetParam) ([]string, corev1.ResourceRequirements) { torchCommand := buildCommandStr(inferenceObj.BaseCommand, inferenceObj.TorchRunParams) torchCommand = buildCommandStr(torchCommand, inferenceObj.TorchRunRdzvParams) modelCommand := buildCommandStr(InferenceFile, inferenceObj.ModelRunParams) @@ -159,7 +159,7 @@ func prepareInferenceParameters(ctx context.Context, inferenceObj *model.PresetI return commands, resourceRequirements } -func configVolume(wObj *kaitov1alpha1.Workspace, inferenceObj *model.PresetInferenceParam) ([]corev1.Volume, []corev1.VolumeMount) { +func configVolume(wObj *kaitov1alpha1.Workspace, inferenceObj *model.PresetParam) ([]corev1.Volume, []corev1.VolumeMount) { volume := []corev1.Volume{} volumeMount := []corev1.VolumeMount{} diff --git a/pkg/inference/preset-inferences_test.go b/pkg/inference/preset-inferences_test.go index cd8df067c..31bf0551e 100644 --- a/pkg/inference/preset-inferences_test.go +++ b/pkg/inference/preset-inferences_test.go @@ -62,7 +62,7 @@ func TestCreatePresetInference(t *testing.T) { useHeadlessSvc := false - var inferenceObj *model.PresetInferenceParam + var inferenceObj *model.PresetParam model := plugin.KaitoModelRegister.MustGet(tc.modelName) inferenceObj = model.GetInferenceParameters() diff --git a/pkg/model/interface.go b/pkg/model/interface.go index 217c1f889..2763b5eec 100644 --- a/pkg/model/interface.go +++ b/pkg/model/interface.go @@ -7,12 +7,13 @@ import ( ) type Model 
interface { - GetInferenceParameters() *PresetInferenceParam + GetInferenceParameters() *PresetParam + GetTrainingParameters() *PresetParam SupportDistributedInference() bool //If true, the model workload will be a StatefulSet, using the torch elastic runtime framework. } -// PresetInferenceParam defines the preset inference parameters for a model. -type PresetInferenceParam struct { +// PresetParam defines the preset inference and tuning parameters for a model. +type PresetParam struct { ModelFamilyName string // The name of the model family. ImageAccessMode string // Defines where the Image is Public or Private. DiskStorageRequirement string // Disk storage requirements for the model. diff --git a/pkg/tuning/preset-tuning-types.go b/pkg/tuning/preset-tuning-types.go new file mode 100644 index 000000000..51f36511d --- /dev/null +++ b/pkg/tuning/preset-tuning-types.go @@ -0,0 +1,21 @@ +package tuning + +import corev1 "k8s.io/api/core/v1" + +const ( + DefaultNumProcesses = "1" + DefaultNumMachines = "1" + DefaultMachineRank = "0" + DefaultGPUIds = "all" +) + +var ( + DefaultAccelerateParams = map[string]string{ + "num_processes": DefaultNumProcesses, + "num_machines": DefaultNumMachines, + "machine_rank": DefaultMachineRank, + "gpu_ids": DefaultGPUIds, + } + + DefaultImagePullSecrets = []corev1.LocalObjectReference{} +) diff --git a/pkg/tuning/preset-tuning.go b/pkg/tuning/preset-tuning.go new file mode 100644 index 000000000..cbbb55a06 --- /dev/null +++ b/pkg/tuning/preset-tuning.go @@ -0,0 +1,29 @@ +package tuning + +import ( + "context" + kaitov1alpha1 "github.com/azure/kaito/api/v1alpha1" + "github.com/azure/kaito/pkg/model" + "sigs.k8s.io/controller-runtime/pkg/client" +) + +func CreatePresetTuning(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace, + tuningObj *model.PresetParam, kubeClient client.Client) (client.Object, error) { + // TODO + + // e.g. 
example from Inference + //volume, volumeMount := configVolume(workspaceObj, inferenceObj) + //commands, resourceReq := prepareInferenceParameters(ctx, inferenceObj) + //image, imagePullSecrets := GetImageInfo(ctx, workspaceObj, inferenceObj) + // + //depObj = resources.GenerateDeploymentManifest(ctx, workspaceObj, image, imagePullSecrets, *workspaceObj.Resource.Count, commands, + // containerPorts, livenessProbe, readinessProbe, resourceReq, tolerations, volume, volumeMount) + // + //err := resources.CreateResource(ctx, depObj, kubeClient) + //if client.IgnoreAlreadyExists(err) != nil { + // return nil, err + //} + //return depObj, nil + + return nil, nil +} diff --git a/pkg/utils/testModel.go b/pkg/utils/testModel.go index 99e3d8aca..50f6c9175 100644 --- a/pkg/utils/testModel.go +++ b/pkg/utils/testModel.go @@ -12,8 +12,8 @@ import ( type testModel struct{} -func (*testModel) GetInferenceParameters() *model.PresetInferenceParam { - return &model.PresetInferenceParam{ +func (*testModel) GetInferenceParameters() *model.PresetParam { + return &model.PresetParam{ GPUCountRequirement: "1", DeploymentTimeout: time.Duration(30) * time.Minute, } @@ -24,8 +24,8 @@ func (*testModel) SupportDistributedInference() bool { type testDistributedModel struct{} -func (*testDistributedModel) GetInferenceParameters() *model.PresetInferenceParam { - return &model.PresetInferenceParam{ +func (*testDistributedModel) GetInferenceParameters() *model.PresetParam { + return &model.PresetParam{ GPUCountRequirement: "1", DeploymentTimeout: time.Duration(30) * time.Minute, } diff --git a/presets/models/falcon/README.md b/presets/models/falcon/README.md index 81a1ced6f..e8cd895e6 100644 --- a/presets/models/falcon/README.md +++ b/presets/models/falcon/README.md @@ -1,10 +1,10 @@ ## Supported Models |Model name| Model source | Sample workspace|Kubernetes Workload|Distributed inference| |----|:----:|:----:| :----: |:----: | -|falcon-7b-instruct 
|[tiiuae](https://huggingface.co/tiiuae/falcon-7b-instruct)|[link](../../../examples/kaito_workspace_falcon_7b-instruct.yaml)|Deployment| false| -|falcon-7b |[tiiuae](https://huggingface.co/tiiuae/falcon-7b) |[link](../../../examples/kaito_workspace_falcon_7b.yaml)|Deployment| false| -|falcon-40b-instruct|[tiiuae](https://huggingface.co/tiiuae/falcon-40b-instruct) |[link](../../../examples/kaito_workspace_falcon_40b-instruct.yaml)|Deployment| false| -|falcon-40b |[tiiuae](https://huggingface.co/tiiuae/falcon-40b)|[link](../../../examples/kaito_workspace_falcon_40b.yaml)|Deployment| false| +|falcon-7b-instruct |[tiiuae](https://huggingface.co/tiiuae/falcon-7b-instruct)|[link](../../../examples/inference/kaito_workspace_falcon_7b-instruct.yaml)|Deployment| false| +|falcon-7b |[tiiuae](https://huggingface.co/tiiuae/falcon-7b) |[link](../../../examples/inference/kaito_workspace_falcon_7b.yaml)|Deployment| false| +|falcon-40b-instruct|[tiiuae](https://huggingface.co/tiiuae/falcon-40b-instruct) |[link](../../../examples/inference/kaito_workspace_falcon_40b-instruct.yaml)|Deployment| false| +|falcon-40b |[tiiuae](https://huggingface.co/tiiuae/falcon-40b)|[link](../../../examples/inference/kaito_workspace_falcon_40b.yaml)|Deployment| false| ## Image Source - **Public**: Kaito maintainers manage the lifecycle of the inference service images that contain model weights. The images are available in Microsoft Container Registry (MCR). 
diff --git a/presets/models/falcon/model.go b/presets/models/falcon/model.go index 7501dce23..c0f07495f 100644 --- a/presets/models/falcon/model.go +++ b/presets/models/falcon/model.go @@ -54,8 +54,8 @@ var falconA falcon7b type falcon7b struct{} -func (*falcon7b) GetInferenceParameters() *model.PresetInferenceParam { - return &model.PresetInferenceParam{ +func (*falcon7b) GetInferenceParameters() *model.PresetParam { + return &model.PresetParam{ ModelFamilyName: "Falcon", ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePublic), DiskStorageRequirement: "50Gi", @@ -68,8 +68,23 @@ func (*falcon7b) GetInferenceParameters() *model.PresetInferenceParam { BaseCommand: baseCommandPresetFalcon, Tag: PresetFalconTagMap["Falcon7B"], } - } +func (*falcon7b) GetTrainingParameters() *model.PresetParam { + return &model.PresetParam{ + ModelFamilyName: "Falcon", + ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePublic), + DiskStorageRequirement: "50Gi", + GPUCountRequirement: "2", + TotalGPUMemoryRequirement: "16Gi", + PerGPUMemoryRequirement: "16Gi", + //TorchRunParams: tuning.DefaultAccelerateParams, // TODO + //ModelRunParams: falconRunTuningParams, // TODO + DeploymentTimeout: time.Duration(30) * time.Minute, + BaseCommand: baseCommandPresetFalcon, + Tag: PresetFalconTagMap["Falcon7B"], + } +} + func (*falcon7b) SupportDistributedInference() bool { return false } @@ -78,8 +93,8 @@ var falconB falcon7bInst type falcon7bInst struct{} -func (*falcon7bInst) GetInferenceParameters() *model.PresetInferenceParam { - return &model.PresetInferenceParam{ +func (*falcon7bInst) GetInferenceParameters() *model.PresetParam { + return &model.PresetParam{ ModelFamilyName: "Falcon", ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePublic), DiskStorageRequirement: "50Gi", @@ -94,6 +109,9 @@ func (*falcon7bInst) GetInferenceParameters() *model.PresetInferenceParam { } } +func (*falcon7bInst) GetTrainingParameters() *model.PresetParam { + return nil // It is not 
recommended/ideal to further fine-tune instruct models - Already been fine-tuned +} func (*falcon7bInst) SupportDistributedInference() bool { return false } @@ -102,8 +120,8 @@ var falconC falcon40b type falcon40b struct{} -func (*falcon40b) GetInferenceParameters() *model.PresetInferenceParam { - return &model.PresetInferenceParam{ +func (*falcon40b) GetInferenceParameters() *model.PresetParam { + return &model.PresetParam{ ModelFamilyName: "Falcon", ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePublic), DiskStorageRequirement: "400", @@ -118,6 +136,21 @@ func (*falcon40b) GetInferenceParameters() *model.PresetInferenceParam { } } +func (*falcon40b) GetTrainingParameters() *model.PresetParam { + return &model.PresetParam{ + ModelFamilyName: "Falcon", + ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePublic), + DiskStorageRequirement: "50Gi", + GPUCountRequirement: "2", + TotalGPUMemoryRequirement: "90Gi", + PerGPUMemoryRequirement: "16Gi", + //TorchRunParams: tuning.DefaultAccelerateParams, // TODO + //ModelRunParams: falconRunTuningParams, // TODO + DeploymentTimeout: time.Duration(30) * time.Minute, + BaseCommand: baseCommandPresetFalcon, + Tag: PresetFalconTagMap["Falcon40B"], + } +} func (*falcon40b) SupportDistributedInference() bool { return false } @@ -126,8 +159,8 @@ var falconD falcon40bInst type falcon40bInst struct{} -func (*falcon40bInst) GetInferenceParameters() *model.PresetInferenceParam { - return &model.PresetInferenceParam{ +func (*falcon40bInst) GetInferenceParameters() *model.PresetParam { + return &model.PresetParam{ ModelFamilyName: "Falcon", ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePublic), DiskStorageRequirement: "400", @@ -141,7 +174,9 @@ func (*falcon40bInst) GetInferenceParameters() *model.PresetInferenceParam { Tag: PresetFalconTagMap["Falcon40BInstruct"], } } - +func (*falcon40bInst) GetTrainingParameters() *model.PresetParam { + return nil // It is not recommended/ideal to further fine-tune 
instruct models - Already been fine-tuned +} func (*falcon40bInst) SupportDistributedInference() bool { return false } diff --git a/presets/models/llama2/README.md b/presets/models/llama2/README.md index e6a40563a..ba2646a2b 100644 --- a/presets/models/llama2/README.md +++ b/presets/models/llama2/README.md @@ -1,9 +1,9 @@ ## Supported Models |Model name| Model source | Sample workspace|Kubernetes Workload|Distributed inference| |----|:----:|:----:| :----: |:----: | -|llama2-7b |[meta](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md)|[link](../../../examples/kaito_workspace_llama2_7b.yaml)|Deployment| false| -|llama2-13b|[meta](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md)|[link](../../../examples/kaito_workspace_llama2_13b.yaml)|StatefulSet| true| -|llama2-70b|[meta](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md)|[link](../../../examples/kaito_workspace_llama2_70b.yaml)|StatefulSet| true| +|llama2-7b |[meta](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md)|[link](../../../examples/inference/kaito_workspace_llama2_7b.yaml)|Deployment| false| +|llama2-13b|[meta](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md)|[link](../../../examples/inference/kaito_workspace_llama2_13b.yaml)|StatefulSet| true| +|llama2-70b|[meta](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md)|[link](../../../examples/inference/kaito_workspace_llama2_70b.yaml)|StatefulSet| true| ## Image Source - **Private**: User needs to manage the lifecycle of the inference service images that contain model weights (e.g., managing image tags). The images are available in user's private container registry. 
diff --git a/presets/models/llama2/model.go b/presets/models/llama2/model.go index 30c97b7fd..673a67da3 100644 --- a/presets/models/llama2/model.go +++ b/presets/models/llama2/model.go @@ -38,8 +38,8 @@ var llama2A llama2Text7b type llama2Text7b struct{} -func (*llama2Text7b) GetInferenceParameters() *model.PresetInferenceParam { - return &model.PresetInferenceParam{ +func (*llama2Text7b) GetInferenceParameters() *model.PresetParam { + return &model.PresetParam{ ModelFamilyName: "LLaMa2", ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePrivate), DiskStorageRequirement: "34Gi", @@ -56,6 +56,9 @@ func (*llama2Text7b) GetInferenceParameters() *model.PresetInferenceParam { } } +func (*llama2Text7b) GetTrainingParameters() *model.PresetParam { + return nil // Currently doesn't support fine-tuning +} func (*llama2Text7b) SupportDistributedInference() bool { return false } @@ -64,8 +67,8 @@ var llama2B llama2Text13b type llama2Text13b struct{} -func (*llama2Text13b) GetInferenceParameters() *model.PresetInferenceParam { - return &model.PresetInferenceParam{ +func (*llama2Text13b) GetInferenceParameters() *model.PresetParam { + return &model.PresetParam{ ModelFamilyName: "LLaMa2", ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePrivate), DiskStorageRequirement: "46Gi", @@ -81,6 +84,9 @@ func (*llama2Text13b) GetInferenceParameters() *model.PresetInferenceParam { // Tag: llama has private image access mode. The image tag is determined by the user. 
} } +func (*llama2Text13b) GetTrainingParameters() *model.PresetParam { + return nil // Currently doesn't support fine-tuning +} func (*llama2Text13b) SupportDistributedInference() bool { return true } @@ -89,8 +95,8 @@ var llama2C llama2Text70b type llama2Text70b struct{} -func (*llama2Text70b) GetInferenceParameters() *model.PresetInferenceParam { - return &model.PresetInferenceParam{ +func (*llama2Text70b) GetInferenceParameters() *model.PresetParam { + return &model.PresetParam{ ModelFamilyName: "LLaMa2", ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePrivate), DiskStorageRequirement: "158Gi", @@ -106,6 +112,9 @@ func (*llama2Text70b) GetInferenceParameters() *model.PresetInferenceParam { // Tag: llama has private image access mode. The image tag is determined by the user. } } +func (*llama2Text70b) GetTrainingParameters() *model.PresetParam { + return nil // Currently doesn't support fine-tuning +} func (*llama2Text70b) SupportDistributedInference() bool { return true } diff --git a/presets/models/llama2chat/README.md b/presets/models/llama2chat/README.md index 53e241fab..0cf9ec3be 100644 --- a/presets/models/llama2chat/README.md +++ b/presets/models/llama2chat/README.md @@ -1,9 +1,9 @@ ## Supported Models |Model name| Model source | Sample workspace|Kubernetes Workload|Distributed inference| |----|:----:|:----:| :----: |:----: | -|llama2-7b-chat |[meta](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md)|[link](../../../examples/kaito_workspace_llama2_7b-chat.yaml)|Deployment| false| -|llama2-13b-chat|[meta](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md)|[link](../../../examples/kaito_workspace_llama2_13b-chat.yaml)|StatefulSet| true| -|llama2-70b-chat|[meta](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md)|[link](../../../examples/kaito_workspace_llama2_70b-chat.yaml)|StatefulSet| true| +|llama2-7b-chat 
|[meta](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md)|[link](../../../examples/inference/kaito_workspace_llama2_7b-chat.yaml)|Deployment| false| +|llama2-13b-chat|[meta](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md)|[link](../../../examples/inference/kaito_workspace_llama2_13b-chat.yaml)|StatefulSet| true| +|llama2-70b-chat|[meta](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md)|[link](../../../examples/inference/kaito_workspace_llama2_70b-chat.yaml)|StatefulSet| true| ## Image Source - **Private**: User needs to manage the lifecycle of the inference service images that contain model weights (e.g., managing image tags). The images are available in user's private container registry. diff --git a/presets/models/llama2chat/model.go b/presets/models/llama2chat/model.go index cc0d8d4c6..a555ebc07 100644 --- a/presets/models/llama2chat/model.go +++ b/presets/models/llama2chat/model.go @@ -38,8 +38,8 @@ var llama2chatA llama2Chat7b type llama2Chat7b struct{} -func (*llama2Chat7b) GetInferenceParameters() *model.PresetInferenceParam { - return &model.PresetInferenceParam{ +func (*llama2Chat7b) GetInferenceParameters() *model.PresetParam { + return &model.PresetParam{ ModelFamilyName: "LLaMa2", ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePrivate), DiskStorageRequirement: "34Gi", @@ -54,7 +54,9 @@ func (*llama2Chat7b) GetInferenceParameters() *model.PresetInferenceParam { WorldSize: 1, // Tag: llama has private image access mode. The image tag is determined by the user. 
} - +} +func (*llama2Chat7b) GetTrainingParameters() *model.PresetParam { + return nil // Currently doesn't support fine-tuning } func (*llama2Chat7b) SupportDistributedInference() bool { return false @@ -64,8 +66,8 @@ var llama2chatB llama2Chat13b type llama2Chat13b struct{} -func (*llama2Chat13b) GetInferenceParameters() *model.PresetInferenceParam { - return &model.PresetInferenceParam{ +func (*llama2Chat13b) GetInferenceParameters() *model.PresetParam { + return &model.PresetParam{ ModelFamilyName: "LLaMa2", ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePrivate), DiskStorageRequirement: "46Gi", @@ -81,6 +83,9 @@ func (*llama2Chat13b) GetInferenceParameters() *model.PresetInferenceParam { // Tag: llama has private image access mode. The image tag is determined by the user. } } +func (*llama2Chat13b) GetTrainingParameters() *model.PresetParam { + return nil // Currently doesn't support fine-tuning +} func (*llama2Chat13b) SupportDistributedInference() bool { return true } @@ -89,8 +94,8 @@ var llama2chatC llama2Chat70b type llama2Chat70b struct{} -func (*llama2Chat70b) GetInferenceParameters() *model.PresetInferenceParam { - return &model.PresetInferenceParam{ +func (*llama2Chat70b) GetInferenceParameters() *model.PresetParam { + return &model.PresetParam{ ModelFamilyName: "LLaMa2", ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePrivate), DiskStorageRequirement: "158Gi", @@ -106,6 +111,9 @@ func (*llama2Chat70b) GetInferenceParameters() *model.PresetInferenceParam { // Tag: llama has private image access mode. The image tag is determined by the user. 
} } +func (*llama2Chat70b) GetTrainingParameters() *model.PresetParam { + return nil // Currently doesn't support fine-tuning +} func (*llama2Chat70b) SupportDistributedInference() bool { return true } diff --git a/presets/models/mistral/README.md b/presets/models/mistral/README.md index 4d0c56ba6..2d037f7a4 100644 --- a/presets/models/mistral/README.md +++ b/presets/models/mistral/README.md @@ -1,8 +1,8 @@ ## Supported Models |Model name| Model source | Sample workspace|Kubernetes Workload|Distributed inference| |----|:----:|:----:| :----: |:----: | -|mistral-7b-instruct |[mistralai](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2)|[link](../../../examples/kaito_workspace_mistral_7b-instruct.yaml)|Deployment| false| -|mistral-7b |[mistralai](https://huggingface.co/mistralai/Mistral-7B-v0.1) |[link](../../../examples/kaito_workspace_mistral_7b.yaml)|Deployment| false| +|mistral-7b-instruct |[mistralai](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2)|[link](../../../examples/inference/kaito_workspace_mistral_7b-instruct.yaml)|Deployment| false| +|mistral-7b |[mistralai](https://huggingface.co/mistralai/Mistral-7B-v0.1) |[link](../../../examples/inference/kaito_workspace_mistral_7b.yaml)|Deployment| false| ## Image Source diff --git a/presets/models/mistral/model.go b/presets/models/mistral/model.go index 7089eafb6..bcf06203b 100644 --- a/presets/models/mistral/model.go +++ b/presets/models/mistral/model.go @@ -42,8 +42,8 @@ var mistralA mistral7b type mistral7b struct{} -func (*mistral7b) GetInferenceParameters() *model.PresetInferenceParam { - return &model.PresetInferenceParam{ +func (*mistral7b) GetInferenceParameters() *model.PresetParam { + return &model.PresetParam{ ModelFamilyName: "Mistral", ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePublic), DiskStorageRequirement: "100Gi", @@ -58,6 +58,22 @@ func (*mistral7b) GetInferenceParameters() *model.PresetInferenceParam { } } +func (*mistral7b) GetTrainingParameters() 
*model.PresetParam { + return &model.PresetParam{ + ModelFamilyName: "Mistral", + ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePublic), + DiskStorageRequirement: "100Gi", + GPUCountRequirement: "1", + TotalGPUMemoryRequirement: "16Gi", + PerGPUMemoryRequirement: "16Gi", // We run Mistral using native vertical model parallel, no per GPU memory requirement. + //TorchRunParams: tuning.DefaultAccelerateParams, + //ModelRunParams: mistralRunParams, + DeploymentTimeout: time.Duration(30) * time.Minute, + BaseCommand: baseCommandPresetMistral, + Tag: PresetMistralTagMap["Mistral7B"], + } +} + func (*mistral7b) SupportDistributedInference() bool { return false } @@ -66,8 +82,8 @@ var mistralB mistral7bInst type mistral7bInst struct{} -func (*mistral7bInst) GetInferenceParameters() *model.PresetInferenceParam { - return &model.PresetInferenceParam{ +func (*mistral7bInst) GetInferenceParameters() *model.PresetParam { + return &model.PresetParam{ ModelFamilyName: "Mistral", ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePublic), DiskStorageRequirement: "100Gi", @@ -82,6 +98,9 @@ func (*mistral7bInst) GetInferenceParameters() *model.PresetInferenceParam { } } +func (*mistral7bInst) GetTrainingParameters() *model.PresetParam { + return nil // It is not recommended/ideal to further fine-tune instruct models - Already been fine-tuned +} func (*mistral7bInst) SupportDistributedInference() bool { return false } diff --git a/presets/models/phi/README.md b/presets/models/phi/README.md index 7caeadb84..1e77252a5 100644 --- a/presets/models/phi/README.md +++ b/presets/models/phi/README.md @@ -1,7 +1,7 @@ ## Supported Models |Model name| Model source | Sample workspace|Kubernetes Workload|Distributed inference| |----|:----:|:----:| :----: |:----: | -|phi-2 |[microsoft](https://huggingface.co/microsoft/phi-2)|[link](../../../examples/kaito_workspace_phi-2.yaml)|Deployment| false| +|phi-2 
|[microsoft](https://huggingface.co/microsoft/phi-2)|[link](../../../examples/inference/kaito_workspace_phi-2.yaml)|Deployment| false| ## Image Source diff --git a/presets/models/phi/model.go b/presets/models/phi/model.go index 2e54dce38..37d92e673 100644 --- a/presets/models/phi/model.go +++ b/presets/models/phi/model.go @@ -36,8 +36,8 @@ var phiA phi2 type phi2 struct{} -func (*phi2) GetInferenceParameters() *model.PresetInferenceParam { - return &model.PresetInferenceParam{ +func (*phi2) GetInferenceParameters() *model.PresetParam { + return &model.PresetParam{ ModelFamilyName: "Phi", ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePublic), DiskStorageRequirement: "50Gi", From 36977c436a5c28e20033f4f67e6565b0235c2ddd Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Mon, 18 Mar 2024 10:37:08 -0700 Subject: [PATCH 20/29] feat: spec level validation --- api/v1alpha1/workspace_validation.go | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/api/v1alpha1/workspace_validation.go b/api/v1alpha1/workspace_validation.go index 79b27e1b9..50f4cd979 100644 --- a/api/v1alpha1/workspace_validation.go +++ b/api/v1alpha1/workspace_validation.go @@ -218,6 +218,15 @@ func (r *DataDestination) validateUpdate(old *DataDestination) (errs *apis.Field return errs } +func (w *Workspace) validateCreate() (errs *apis.FieldError) { + inferenceSpecified := w.Inference.Preset != nil || w.Inference.Template != nil + tuningSpecified := w.Tuning.Input != nil + if inferenceSpecified != tuningSpecified { + return errs.Also(apis.ErrGeneric("Either Inference or Tuning must be specified, but not both", "")) + } + return errs +} + func (r *ResourceSpec) validateCreate(inference InferenceSpec) (errs *apis.FieldError) { var presetName string if inference.Preset != nil { @@ -265,6 +274,21 @@ func (r *ResourceSpec) validateCreate(inference InferenceSpec) (errs *apis.Field return errs } +func (w *Workspace) validateUpdate(old *Workspace) (errs *apis.FieldError) { + 
// Check inference specified + oldInferenceSpecified := old.Inference.Preset != nil || old.Inference.Template != nil + inferenceSpecified := w.Inference.Preset != nil || w.Inference.Template != nil + // Check tuning specified + oldTuningSpecified := old.Tuning.Input != nil + tuningSpecified := w.Tuning.Input != nil + + // inference/tuning can be changed, but cannot be set/unset. + if (!oldInferenceSpecified && inferenceSpecified) || (!oldTuningSpecified && tuningSpecified) { + errs = errs.Also(apis.ErrGeneric("field cannot be unset/set if it was set/unset", "spec")) + } + return errs +} + func (r *ResourceSpec) validateUpdate(old *ResourceSpec) (errs *apis.FieldError) { // We disable changing node count for now. if r.Count != nil && old.Count != nil && *r.Count != *old.Count { From b738b47ae5a92bcd99d20e336f37b0e5176a9798 Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Mon, 18 Mar 2024 15:01:25 -0700 Subject: [PATCH 21/29] feat: Added validation checks for TuningSpec, DataSource, DataDestination --- api/v1alpha1/workspace_validation.go | 24 ------------------------ 1 file changed, 24 deletions(-) diff --git a/api/v1alpha1/workspace_validation.go b/api/v1alpha1/workspace_validation.go index 50f4cd979..79b27e1b9 100644 --- a/api/v1alpha1/workspace_validation.go +++ b/api/v1alpha1/workspace_validation.go @@ -218,15 +218,6 @@ func (r *DataDestination) validateUpdate(old *DataDestination) (errs *apis.Field return errs } -func (w *Workspace) validateCreate() (errs *apis.FieldError) { - inferenceSpecified := w.Inference.Preset != nil || w.Inference.Template != nil - tuningSpecified := w.Tuning.Input != nil - if inferenceSpecified != tuningSpecified { - return errs.Also(apis.ErrGeneric("Either Inference or Tuning must be specified, but not both", "")) - } - return errs -} - func (r *ResourceSpec) validateCreate(inference InferenceSpec) (errs *apis.FieldError) { var presetName string if inference.Preset != nil { @@ -274,21 +265,6 @@ func (r *ResourceSpec) 
validateCreate(inference InferenceSpec) (errs *apis.Field return errs } -func (w *Workspace) validateUpdate(old *Workspace) (errs *apis.FieldError) { - // Check inference specified - oldInferenceSpecified := old.Inference.Preset != nil || old.Inference.Template != nil - inferenceSpecified := w.Inference.Preset != nil || w.Inference.Template != nil - // Check tuning specified - oldTuningSpecified := old.Tuning.Input != nil - tuningSpecified := w.Tuning.Input != nil - - // inference/tuning can be changed, but cannot be set/unset. - if (!oldInferenceSpecified && inferenceSpecified) || (!oldTuningSpecified && tuningSpecified) { - errs = errs.Also(apis.ErrGeneric("field cannot be unset/set if it was set/unset", "spec")) - } - return errs -} - func (r *ResourceSpec) validateUpdate(old *ResourceSpec) (errs *apis.FieldError) { // We disable changing node count for now. if r.Count != nil && old.Count != nil && *r.Count != *old.Count { From 33c6b24bd70b79eaeaaf467f5d8afa29c8bc96eb Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Thu, 21 Mar 2024 09:25:31 -0400 Subject: [PATCH 22/29] fix: Add required training func for tests --- api/v1alpha1/workspace_validation_test.go | 15 +++++++++++++++ pkg/utils/testModel.go | 12 ++++++++++++ 2 files changed, 27 insertions(+) diff --git a/api/v1alpha1/workspace_validation_test.go b/api/v1alpha1/workspace_validation_test.go index 695d42298..93e456566 100644 --- a/api/v1alpha1/workspace_validation_test.go +++ b/api/v1alpha1/workspace_validation_test.go @@ -28,6 +28,13 @@ func (*testModel) GetInferenceParameters() *model.PresetParam { PerGPUMemoryRequirement: perGPUMemoryRequirement, } } +func (*testModel) GetTrainingParameters() *model.PresetParam { + return &model.PresetParam{ + GPUCountRequirement: gpuCountRequirement, + TotalGPUMemoryRequirement: totalGPUMemoryRequirement, + PerGPUMemoryRequirement: perGPUMemoryRequirement, + } +} func (*testModel) SupportDistributedInference() bool { return false } @@ -42,6 +49,14 @@ func 
(*testModelPrivate) GetInferenceParameters() *model.PresetParam { PerGPUMemoryRequirement: perGPUMemoryRequirement, } } +func (*testModelPrivate) GetTrainingParameters() *model.PresetParam { + return &model.PresetParam{ + ImageAccessMode: "private", + GPUCountRequirement: gpuCountRequirement, + TotalGPUMemoryRequirement: totalGPUMemoryRequirement, + PerGPUMemoryRequirement: perGPUMemoryRequirement, + } +} func (*testModelPrivate) SupportDistributedInference() bool { return false } diff --git a/pkg/utils/testModel.go b/pkg/utils/testModel.go index 50f6c9175..7dd04fa72 100644 --- a/pkg/utils/testModel.go +++ b/pkg/utils/testModel.go @@ -18,6 +18,12 @@ func (*testModel) GetInferenceParameters() *model.PresetParam { DeploymentTimeout: time.Duration(30) * time.Minute, } } +func (*testModel) GetTrainingParameters() *model.PresetParam { + return &model.PresetParam{ + GPUCountRequirement: "1", + DeploymentTimeout: time.Duration(30) * time.Minute, + } +} func (*testModel) SupportDistributedInference() bool { return false } @@ -30,6 +36,12 @@ func (*testDistributedModel) GetInferenceParameters() *model.PresetParam { DeploymentTimeout: time.Duration(30) * time.Minute, } } +func (*testDistributedModel) GetTrainingParameters() *model.PresetParam { + return &model.PresetParam{ + GPUCountRequirement: "1", + DeploymentTimeout: time.Duration(30) * time.Minute, + } +} func (*testDistributedModel) SupportDistributedInference() bool { return true } From 9b8ee5655012f3ba16e7a67d472fc5b25c7d00ce Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Thu, 21 Mar 2024 09:30:37 -0400 Subject: [PATCH 23/29] fix: Add training func for phi --- presets/models/phi/model.go | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/presets/models/phi/model.go b/presets/models/phi/model.go index 37d92e673..c9a67033c 100644 --- a/presets/models/phi/model.go +++ b/presets/models/phi/model.go @@ -50,7 +50,21 @@ func (*phi2) GetInferenceParameters() *model.PresetParam { 
BaseCommand: baseCommandPresetPhi, Tag: PresetPhiTagMap["Phi2"], } - +} +func (*phi2) GetTrainingParameters() *model.PresetParam { + return &model.PresetParam{ + ModelFamilyName: "Phi", + ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePublic), + DiskStorageRequirement: "50Gi", + GPUCountRequirement: "1", + TotalGPUMemoryRequirement: "16Gi", + PerGPUMemoryRequirement: "16Gi", // We run Phi using native vertical model parallel, no per GPU memory requirement. + // TorchRunParams: inference.DefaultAccelerateParams, + // ModelRunParams: phiRunParams, + DeploymentTimeout: time.Duration(30) * time.Minute, + BaseCommand: baseCommandPresetPhi, + Tag: PresetPhiTagMap["Phi2"], + } } func (*phi2) SupportDistributedInference() bool { return false From bbfd6f92b05304eba2ec1bbe600ba0872447ea67 Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Thu, 21 Mar 2024 09:46:18 -0400 Subject: [PATCH 24/29] fix: Add support training method --- api/v1alpha1/workspace_validation_test.go | 6 +++++ pkg/controllers/workspace_controller.go | 8 +++--- pkg/model/interface.go | 17 ++++++------ pkg/utils/testModel.go | 14 +++++++--- presets/models/falcon/model.go | 32 ++++++++++++++++------- presets/models/llama2/model.go | 15 ++++++++--- presets/models/llama2chat/model.go | 15 ++++++++--- presets/models/mistral/model.go | 16 ++++++++---- presets/models/phi/model.go | 11 +++++--- 9 files changed, 93 insertions(+), 41 deletions(-) diff --git a/api/v1alpha1/workspace_validation_test.go b/api/v1alpha1/workspace_validation_test.go index 93e456566..4fc193f52 100644 --- a/api/v1alpha1/workspace_validation_test.go +++ b/api/v1alpha1/workspace_validation_test.go @@ -38,6 +38,9 @@ func (*testModel) GetTrainingParameters() *model.PresetParam { func (*testModel) SupportDistributedInference() bool { return false } +func (*testModel) SupportTraining() bool { + return true +} type testModelPrivate struct{} @@ -60,6 +63,9 @@ func (*testModelPrivate) GetTrainingParameters() *model.PresetParam { func 
(*testModelPrivate) SupportDistributedInference() bool { return false } +func (*testModelPrivate) SupportTraining() bool { + return true +} func RegisterValidationTestModels() { var test testModel diff --git a/pkg/controllers/workspace_controller.go b/pkg/controllers/workspace_controller.go index db1150ad5..7d1cac238 100644 --- a/pkg/controllers/workspace_controller.go +++ b/pkg/controllers/workspace_controller.go @@ -448,7 +448,7 @@ func (c *WorkspaceReconciler) applyTuning(ctx context.Context, wObj *kaitov1alph existingObj = &appsv1.Deployment{} if err = resources.GetResource(ctx, wObj.Name, wObj.Namespace, c.Client, existingObj); err == nil { klog.InfoS("A training workload already exists for workspace", "workspace", klog.KObj(wObj)) - if err = resources.CheckResourceStatus(existingObj, c.Client, trainingParam.DeploymentTimeout); err != nil { + if err = resources.CheckResourceStatus(existingObj, c.Client, trainingParam.WorkloadTimeout); err != nil { return } } else if apierrors.IsNotFound(err) { @@ -458,7 +458,7 @@ func (c *WorkspaceReconciler) applyTuning(ctx context.Context, wObj *kaitov1alph if err != nil { return } - if err = resources.CheckResourceStatus(workloadObj, c.Client, trainingParam.DeploymentTimeout); err != nil { + if err = resources.CheckResourceStatus(workloadObj, c.Client, trainingParam.WorkloadTimeout); err != nil { return } } @@ -516,7 +516,7 @@ func (c *WorkspaceReconciler) applyInference(ctx context.Context, wObj *kaitov1a if err = resources.GetResource(ctx, wObj.Name, wObj.Namespace, c.Client, existingObj); err == nil { klog.InfoS("An inference workload already exists for workspace", "workspace", klog.KObj(wObj)) - if err = resources.CheckResourceStatus(existingObj, c.Client, inferenceParam.DeploymentTimeout); err != nil { + if err = resources.CheckResourceStatus(existingObj, c.Client, inferenceParam.WorkloadTimeout); err != nil { return } } else if apierrors.IsNotFound(err) { @@ -526,7 +526,7 @@ func (c *WorkspaceReconciler) 
applyInference(ctx context.Context, wObj *kaitov1a if err != nil { return } - if err = resources.CheckResourceStatus(workloadObj, c.Client, inferenceParam.DeploymentTimeout); err != nil { + if err = resources.CheckResourceStatus(workloadObj, c.Client, inferenceParam.WorkloadTimeout); err != nil { return } } diff --git a/pkg/model/interface.go b/pkg/model/interface.go index 2763b5eec..148d80fcf 100644 --- a/pkg/model/interface.go +++ b/pkg/model/interface.go @@ -10,6 +10,7 @@ type Model interface { GetInferenceParameters() *PresetParam GetTrainingParameters() *PresetParam SupportDistributedInference() bool //If true, the model workload will be a StatefulSet, using the torch elastic runtime framework. + SupportTraining() bool } // PresetParam defines the preset inference parameters for a model. @@ -17,18 +18,18 @@ type PresetParam struct { ModelFamilyName string // The name of the model family. ImageAccessMode string // Defines where the Image is Public or Private. DiskStorageRequirement string // Disk storage requirements for the model. - GPUCountRequirement string // Number of GPUs required for the inference. - TotalGPUMemoryRequirement string // Total GPU memory required for the inference. + GPUCountRequirement string // Number of GPUs required for the Preset. + TotalGPUMemoryRequirement string // Total GPU memory required for the Preset. PerGPUMemoryRequirement string // GPU memory required per GPU. TorchRunParams map[string]string // Parameters for configuring the torchrun command. - TorchRunRdzvParams map[string]string // Optional rendezvous parameters for distributed inference using torchrun (elastic). - ModelRunParams map[string]string // Parameters for running the model inference. - // DeploymentTimeout defines the maximum duration for pulling the Preset image. + TorchRunRdzvParams map[string]string // Optional rendezvous parameters for distributed training/inference using torchrun (elastic). 
+ // BaseCommand is the initial command (e.g., 'torchrun', 'accelerate launch') used in the command line. + BaseCommand string + ModelRunParams map[string]string // Parameters for running the model training/inference. + // WorkloadTimeout defines the maximum duration for creating the workload. // This timeout accommodates the size of the image, ensuring pull completion // even under slower network conditions or unforeseen delays. - DeploymentTimeout time.Duration - // BaseCommand is the initial command (e.g., 'torchrun', 'accelerate launch') used in the command line. - BaseCommand string + WorkloadTimeout time.Duration // WorldSize defines the number of processes required for distributed inference. WorldSize int Tag string // The model image tag diff --git a/pkg/utils/testModel.go b/pkg/utils/testModel.go index 7dd04fa72..fdf9423c3 100644 --- a/pkg/utils/testModel.go +++ b/pkg/utils/testModel.go @@ -15,36 +15,42 @@ type testModel struct{} func (*testModel) GetInferenceParameters() *model.PresetParam { return &model.PresetParam{ GPUCountRequirement: "1", - DeploymentTimeout: time.Duration(30) * time.Minute, + WorkloadTimeout: time.Duration(30) * time.Minute, } } func (*testModel) GetTrainingParameters() *model.PresetParam { return &model.PresetParam{ GPUCountRequirement: "1", - DeploymentTimeout: time.Duration(30) * time.Minute, + WorkloadTimeout: time.Duration(30) * time.Minute, } } func (*testModel) SupportDistributedInference() bool { return false } +func (*testModel) SupportTraining() bool { + return true +} type testDistributedModel struct{} func (*testDistributedModel) GetInferenceParameters() *model.PresetParam { return &model.PresetParam{ GPUCountRequirement: "1", - DeploymentTimeout: time.Duration(30) * time.Minute, + WorkloadTimeout: time.Duration(30) * time.Minute, } } func (*testDistributedModel) GetTrainingParameters() *model.PresetParam { return &model.PresetParam{ GPUCountRequirement: "1", - DeploymentTimeout: time.Duration(30) * time.Minute, + 
WorkloadTimeout: time.Duration(30) * time.Minute, } } func (*testDistributedModel) SupportDistributedInference() bool { return true } +func (*testDistributedModel) SupportTraining() bool { + return true +} func RegisterTestModel() { var test testModel diff --git a/presets/models/falcon/model.go b/presets/models/falcon/model.go index c0f07495f..f99cac208 100644 --- a/presets/models/falcon/model.go +++ b/presets/models/falcon/model.go @@ -64,7 +64,7 @@ func (*falcon7b) GetInferenceParameters() *model.PresetParam { PerGPUMemoryRequirement: "0Gi", // We run Falcon using native vertical model parallel, no per GPU memory requirement. TorchRunParams: inference.DefaultAccelerateParams, ModelRunParams: falconRunParams, - DeploymentTimeout: time.Duration(30) * time.Minute, + WorkloadTimeout: time.Duration(30) * time.Minute, BaseCommand: baseCommandPresetFalcon, Tag: PresetFalconTagMap["Falcon7B"], } @@ -79,15 +79,18 @@ func (*falcon7b) GetTrainingParameters() *model.PresetParam { PerGPUMemoryRequirement: "16Gi", //TorchRunParams: tuning.DefaultAccelerateParams, // TODO //ModelRunPrams: falconRunTuningParams, // TODO - DeploymentTimeout: time.Duration(30) * time.Minute, - BaseCommand: baseCommandPresetFalcon, - Tag: PresetFalconTagMap["Falcon7B"], + WorkloadTimeout: time.Duration(30) * time.Minute, + BaseCommand: baseCommandPresetFalcon, + Tag: PresetFalconTagMap["Falcon7B"], } } func (*falcon7b) SupportDistributedInference() bool { return false } +func (*falcon7b) SupportTraining() bool { + return true +} var falconB falcon7bInst @@ -103,7 +106,7 @@ func (*falcon7bInst) GetInferenceParameters() *model.PresetParam { PerGPUMemoryRequirement: "0Gi", // We run Falcon using native vertical model parallel, no per GPU memory requirement. 
TorchRunParams: inference.DefaultAccelerateParams, ModelRunParams: falconRunParams, - DeploymentTimeout: time.Duration(30) * time.Minute, + WorkloadTimeout: time.Duration(30) * time.Minute, BaseCommand: baseCommandPresetFalcon, Tag: PresetFalconTagMap["Falcon7BInstruct"], } @@ -115,6 +118,9 @@ func (*falcon7bInst) GetTrainingParameters() *model.PresetParam { func (*falcon7bInst) SupportDistributedInference() bool { return false } +func (*falcon7bInst) SupportTraining() bool { + return false +} var falconC falcon40b @@ -130,7 +136,7 @@ func (*falcon40b) GetInferenceParameters() *model.PresetParam { PerGPUMemoryRequirement: "0Gi", // We run Falcon using native vertical model parallel, no per GPU memory requirement. TorchRunParams: inference.DefaultAccelerateParams, ModelRunParams: falconRunParams, - DeploymentTimeout: time.Duration(30) * time.Minute, + WorkloadTimeout: time.Duration(30) * time.Minute, BaseCommand: baseCommandPresetFalcon, Tag: PresetFalconTagMap["Falcon40B"], } @@ -146,14 +152,17 @@ func (*falcon40b) GetTrainingParameters() *model.PresetParam { PerGPUMemoryRequirement: "16Gi", //TorchRunParams: tuning.DefaultAccelerateParams, // TODO //ModelRunPrams: falconRunTuningParams, // TODO - DeploymentTimeout: time.Duration(30) * time.Minute, - BaseCommand: baseCommandPresetFalcon, - Tag: PresetFalconTagMap["Falcon40B"], + WorkloadTimeout: time.Duration(30) * time.Minute, + BaseCommand: baseCommandPresetFalcon, + Tag: PresetFalconTagMap["Falcon40B"], } } func (*falcon40b) SupportDistributedInference() bool { return false } +func (*falcon40b) SupportTraining() bool { + return true +} var falconD falcon40bInst @@ -169,7 +178,7 @@ func (*falcon40bInst) GetInferenceParameters() *model.PresetParam { PerGPUMemoryRequirement: "0Gi", // We run Falcon using native vertical model parallel, no per GPU memory requirement. 
TorchRunParams: inference.DefaultAccelerateParams, ModelRunParams: falconRunParams, - DeploymentTimeout: time.Duration(30) * time.Minute, + WorkloadTimeout: time.Duration(30) * time.Minute, BaseCommand: baseCommandPresetFalcon, Tag: PresetFalconTagMap["Falcon40BInstruct"], } @@ -180,3 +189,6 @@ func (*falcon40bInst) GetTrainingParameters() *model.PresetParam { func (*falcon40bInst) SupportDistributedInference() bool { return false } +func (*falcon40bInst) SupportTraining() bool { + return false +} diff --git a/presets/models/llama2/model.go b/presets/models/llama2/model.go index 673a67da3..b1e1dc180 100644 --- a/presets/models/llama2/model.go +++ b/presets/models/llama2/model.go @@ -49,7 +49,7 @@ func (*llama2Text7b) GetInferenceParameters() *model.PresetParam { TorchRunParams: inference.DefaultTorchRunParams, TorchRunRdzvParams: inference.DefaultTorchRunRdzvParams, ModelRunParams: llamaRunParams, - DeploymentTimeout: time.Duration(10) * time.Minute, + WorkloadTimeout: time.Duration(10) * time.Minute, BaseCommand: baseCommandPresetLlama, WorldSize: 1, // Tag: llama has private image access mode. The image tag is determined by the user. @@ -62,6 +62,9 @@ func (*llama2Text7b) GetTrainingParameters() *model.PresetParam { func (*llama2Text7b) SupportDistributedInference() bool { return false } +func (*llama2Text7b) SupportTraining() bool { + return false +} var llama2B llama2Text13b @@ -78,7 +81,7 @@ func (*llama2Text13b) GetInferenceParameters() *model.PresetParam { TorchRunParams: inference.DefaultTorchRunParams, TorchRunRdzvParams: inference.DefaultTorchRunRdzvParams, ModelRunParams: llamaRunParams, - DeploymentTimeout: time.Duration(20) * time.Minute, + WorkloadTimeout: time.Duration(20) * time.Minute, BaseCommand: baseCommandPresetLlama, WorldSize: 2, // Tag: llama has private image access mode. The image tag is determined by the user. 
@@ -90,6 +93,9 @@ func (*llama2Text13b) GetTrainingParameters() *model.PresetParam { func (*llama2Text13b) SupportDistributedInference() bool { return true } +func (*llama2Text13b) SupportTraining() bool { + return false +} var llama2C llama2Text70b @@ -106,7 +112,7 @@ func (*llama2Text70b) GetInferenceParameters() *model.PresetParam { TorchRunParams: inference.DefaultTorchRunParams, TorchRunRdzvParams: inference.DefaultTorchRunRdzvParams, ModelRunParams: llamaRunParams, - DeploymentTimeout: time.Duration(30) * time.Minute, + WorkloadTimeout: time.Duration(30) * time.Minute, BaseCommand: baseCommandPresetLlama, WorldSize: 8, // Tag: llama has private image access mode. The image tag is determined by the user. @@ -118,3 +124,6 @@ func (*llama2Text70b) GetTrainingParameters() *model.PresetParam { func (*llama2Text70b) SupportDistributedInference() bool { return true } +func (*llama2Text70b) SupportTraining() bool { + return false +} diff --git a/presets/models/llama2chat/model.go b/presets/models/llama2chat/model.go index a555ebc07..9108a41d5 100644 --- a/presets/models/llama2chat/model.go +++ b/presets/models/llama2chat/model.go @@ -49,7 +49,7 @@ func (*llama2Chat7b) GetInferenceParameters() *model.PresetParam { TorchRunParams: inference.DefaultTorchRunParams, TorchRunRdzvParams: inference.DefaultTorchRunRdzvParams, ModelRunParams: llamaRunParams, - DeploymentTimeout: time.Duration(10) * time.Minute, + WorkloadTimeout: time.Duration(10) * time.Minute, BaseCommand: baseCommandPresetLlama, WorldSize: 1, // Tag: llama has private image access mode. The image tag is determined by the user. 
@@ -61,6 +61,9 @@ func (*llama2Chat7b) GetTrainingParameters() *model.PresetParam { func (*llama2Chat7b) SupportDistributedInference() bool { return false } +func (*llama2Chat7b) SupportTraining() bool { + return false +} var llama2chatB llama2Chat13b @@ -77,7 +80,7 @@ func (*llama2Chat13b) GetInferenceParameters() *model.PresetParam { TorchRunParams: inference.DefaultTorchRunParams, TorchRunRdzvParams: inference.DefaultTorchRunRdzvParams, ModelRunParams: llamaRunParams, - DeploymentTimeout: time.Duration(20) * time.Minute, + WorkloadTimeout: time.Duration(20) * time.Minute, BaseCommand: baseCommandPresetLlama, WorldSize: 2, // Tag: llama has private image access mode. The image tag is determined by the user. @@ -89,6 +92,9 @@ func (*llama2Chat13b) GetTrainingParameters() *model.PresetParam { func (*llama2Chat13b) SupportDistributedInference() bool { return true } +func (*llama2Chat13b) SupportTraining() bool { + return false +} var llama2chatC llama2Chat70b @@ -105,7 +111,7 @@ func (*llama2Chat70b) GetInferenceParameters() *model.PresetParam { TorchRunParams: inference.DefaultTorchRunParams, TorchRunRdzvParams: inference.DefaultTorchRunRdzvParams, ModelRunParams: llamaRunParams, - DeploymentTimeout: time.Duration(30) * time.Minute, + WorkloadTimeout: time.Duration(30) * time.Minute, BaseCommand: baseCommandPresetLlama, WorldSize: 8, // Tag: llama has private image access mode. The image tag is determined by the user. 
@@ -117,3 +123,6 @@ func (*llama2Chat70b) GetTrainingParameters() *model.PresetParam { func (*llama2Chat70b) SupportDistributedInference() bool { return true } +func (*llama2Chat70b) SupportTraining() bool { + return false +} diff --git a/presets/models/mistral/model.go b/presets/models/mistral/model.go index bcf06203b..c4c518feb 100644 --- a/presets/models/mistral/model.go +++ b/presets/models/mistral/model.go @@ -52,7 +52,7 @@ func (*mistral7b) GetInferenceParameters() *model.PresetParam { PerGPUMemoryRequirement: "0Gi", // We run Mistral using native vertical model parallel, no per GPU memory requirement. TorchRunParams: inference.DefaultAccelerateParams, ModelRunParams: mistralRunParams, - DeploymentTimeout: time.Duration(30) * time.Minute, + WorkloadTimeout: time.Duration(30) * time.Minute, BaseCommand: baseCommandPresetMistral, Tag: PresetMistralTagMap["Mistral7B"], } @@ -68,15 +68,18 @@ func (*mistral7b) GetTrainingParameters() *model.PresetParam { PerGPUMemoryRequirement: "16Gi", // We run Mistral using native vertical model parallel, no per GPU memory requirement. //TorchRunParams: tuning.DefaultAccelerateParams, //ModelRunParams: mistralRunParams, - DeploymentTimeout: time.Duration(30) * time.Minute, - BaseCommand: baseCommandPresetMistral, - Tag: PresetMistralTagMap["Mistral7B"], + WorkloadTimeout: time.Duration(30) * time.Minute, + BaseCommand: baseCommandPresetMistral, + Tag: PresetMistralTagMap["Mistral7B"], } } func (*mistral7b) SupportDistributedInference() bool { return false } +func (*mistral7b) SupportTraining() bool { + return true +} var mistralB mistral7bInst @@ -92,7 +95,7 @@ func (*mistral7bInst) GetInferenceParameters() *model.PresetParam { PerGPUMemoryRequirement: "0Gi", // We run mistral using native vertical model parallel, no per GPU memory requirement. 
TorchRunParams: inference.DefaultAccelerateParams, ModelRunParams: mistralRunParams, - DeploymentTimeout: time.Duration(30) * time.Minute, + WorkloadTimeout: time.Duration(30) * time.Minute, BaseCommand: baseCommandPresetMistral, Tag: PresetMistralTagMap["Mistral7BInstruct"], } @@ -104,3 +107,6 @@ func (*mistral7bInst) GetTrainingParameters() *model.PresetParam { func (*mistral7bInst) SupportDistributedInference() bool { return false } +func (*mistral7bInst) SupportTraining() bool { + return false +} diff --git a/presets/models/phi/model.go b/presets/models/phi/model.go index c9a67033c..6a1bfd109 100644 --- a/presets/models/phi/model.go +++ b/presets/models/phi/model.go @@ -46,7 +46,7 @@ func (*phi2) GetInferenceParameters() *model.PresetParam { PerGPUMemoryRequirement: "0Gi", // We run Phi using native vertical model parallel, no per GPU memory requirement. TorchRunParams: inference.DefaultAccelerateParams, ModelRunParams: phiRunParams, - DeploymentTimeout: time.Duration(30) * time.Minute, + WorkloadTimeout: time.Duration(30) * time.Minute, BaseCommand: baseCommandPresetPhi, Tag: PresetPhiTagMap["Phi2"], } @@ -61,11 +61,14 @@ func (*phi2) GetTrainingParameters() *model.PresetParam { PerGPUMemoryRequirement: "16Gi", // We run Phi using native vertical model parallel, no per GPU memory requirement. 
// TorchRunParams: inference.DefaultAccelerateParams, // ModelRunParams: phiRunParams, - DeploymentTimeout: time.Duration(30) * time.Minute, - BaseCommand: baseCommandPresetPhi, - Tag: PresetPhiTagMap["Phi2"], + WorkloadTimeout: time.Duration(30) * time.Minute, + BaseCommand: baseCommandPresetPhi, + Tag: PresetPhiTagMap["Phi2"], } } func (*phi2) SupportDistributedInference() bool { return false } +func (*phi2) SupportTraining() bool { + return true +} From 931b037817198e73ccf596df528eccead952e71b Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Thu, 21 Mar 2024 09:48:16 -0400 Subject: [PATCH 25/29] fix: lint issue --- pkg/controllers/workspace_controller.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pkg/controllers/workspace_controller.go b/pkg/controllers/workspace_controller.go index 7d1cac238..2ab2b4604 100644 --- a/pkg/controllers/workspace_controller.go +++ b/pkg/controllers/workspace_controller.go @@ -444,8 +444,7 @@ func (c *WorkspaceReconciler) applyTuning(ctx context.Context, wObj *kaitov1alph trainingParam := model.GetTrainingParameters() - var existingObj client.Object - existingObj = &appsv1.Deployment{} + existingObj := &appsv1.Deployment{} if err = resources.GetResource(ctx, wObj.Name, wObj.Namespace, c.Client, existingObj); err == nil { klog.InfoS("A training workload already exists for workspace", "workspace", klog.KObj(wObj)) if err = resources.CheckResourceStatus(existingObj, c.Client, trainingParam.WorkloadTimeout); err != nil { From 107c5011a969bf4759b8f1fa193042c51942f12b Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Thu, 21 Mar 2024 15:17:26 -0400 Subject: [PATCH 26/29] fix: workspace condition --- api/v1alpha1/workspace_condition_types.go | 10 +++++++-- api/v1alpha1/workspace_validation_test.go | 8 +++---- pkg/controllers/workspace_controller.go | 26 +++++++++++------------ pkg/model/interface.go | 4 ++-- pkg/utils/testModel.go | 8 +++---- presets/models/falcon/model.go | 16 +++++++------- 
presets/models/llama2/model.go | 12 +++++------ presets/models/llama2chat/model.go | 12 +++++------ presets/models/mistral/model.go | 8 +++---- presets/models/phi/model.go | 4 ++-- 10 files changed, 57 insertions(+), 51 deletions(-) diff --git a/api/v1alpha1/workspace_condition_types.go b/api/v1alpha1/workspace_condition_types.go index 9845b8a0c..e14995f12 100644 --- a/api/v1alpha1/workspace_condition_types.go +++ b/api/v1alpha1/workspace_condition_types.go @@ -16,8 +16,14 @@ const ( // WorkspaceConditionTypeInferenceStatus is the state when Inference has been created. WorkspaceConditionTypeInferenceStatus = ConditionType("InferenceReady") - // WorkspaceConditionTypeTuningStatus is the state when Tuning has been created. - WorkspaceConditionTypeTuningStatus = ConditionType("TuningReady") + // WorkspaceConditionTypeTuningStarted indicates that the tuning Job has been started. + WorkspaceConditionTypeTuningStarted = ConditionType("TuningStarted") + + // WorkspaceConditionTypeTuningComplete indicates that the tuning Job has completed successfully. + WorkspaceConditionTypeTuningComplete = ConditionType("TuningComplete") + + // WorkspaceConditionTypeTuningFailed indicates that the tuning Job has failed to complete. + WorkspaceConditionTypeTuningFailed = ConditionType("TuningFailed") //WorkspaceConditionTypeDeleting is the Workspace state when starts to get deleted. 
WorkspaceConditionTypeDeleting = ConditionType("WorkspaceDeleting") diff --git a/api/v1alpha1/workspace_validation_test.go b/api/v1alpha1/workspace_validation_test.go index 4fc193f52..d196beabc 100644 --- a/api/v1alpha1/workspace_validation_test.go +++ b/api/v1alpha1/workspace_validation_test.go @@ -28,7 +28,7 @@ func (*testModel) GetInferenceParameters() *model.PresetParam { PerGPUMemoryRequirement: perGPUMemoryRequirement, } } -func (*testModel) GetTrainingParameters() *model.PresetParam { +func (*testModel) GetTuningParameters() *model.PresetParam { return &model.PresetParam{ GPUCountRequirement: gpuCountRequirement, TotalGPUMemoryRequirement: totalGPUMemoryRequirement, @@ -38,7 +38,7 @@ func (*testModel) GetTrainingParameters() *model.PresetParam { func (*testModel) SupportDistributedInference() bool { return false } -func (*testModel) SupportTraining() bool { +func (*testModel) SupportTuning() bool { return true } @@ -52,7 +52,7 @@ func (*testModelPrivate) GetInferenceParameters() *model.PresetParam { PerGPUMemoryRequirement: perGPUMemoryRequirement, } } -func (*testModelPrivate) GetTrainingParameters() *model.PresetParam { +func (*testModelPrivate) GetTuningParameters() *model.PresetParam { return &model.PresetParam{ ImageAccessMode: "private", GPUCountRequirement: gpuCountRequirement, @@ -63,7 +63,7 @@ func (*testModelPrivate) GetTrainingParameters() *model.PresetParam { func (*testModelPrivate) SupportDistributedInference() bool { return false } -func (*testModelPrivate) SupportTraining() bool { +func (*testModelPrivate) SupportTuning() bool { return true } diff --git a/pkg/controllers/workspace_controller.go b/pkg/controllers/workspace_controller.go index 2ab2b4604..ef2e4c01e 100644 --- a/pkg/controllers/workspace_controller.go +++ b/pkg/controllers/workspace_controller.go @@ -5,13 +5,12 @@ package controllers import ( "context" "fmt" - "github.com/azure/kaito/pkg/tuning" "sort" "strings" "time" - appsv1 "k8s.io/api/apps/v1" - "k8s.io/utils/clock" + 
"github.com/azure/kaito/pkg/tuning" + batchv1 "k8s.io/api/batch/v1" "github.com/aws/karpenter-core/pkg/apis/v1alpha5" kaitov1alpha1 "github.com/azure/kaito/api/v1alpha1" @@ -22,6 +21,7 @@ import ( "github.com/azure/kaito/pkg/utils/plugin" "github.com/go-logr/logr" "github.com/samber/lo" + appsv1 "k8s.io/api/apps/v1" corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/errors" apierrors "k8s.io/apimachinery/pkg/api/errors" @@ -29,6 +29,7 @@ import ( "k8s.io/apimachinery/pkg/runtime" "k8s.io/client-go/tools/record" "k8s.io/klog/v2" + "k8s.io/utils/clock" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/controller" @@ -442,22 +443,21 @@ func (c *WorkspaceReconciler) applyTuning(ctx context.Context, wObj *kaitov1alph presetName := string(wObj.Tuning.Preset.Name) model := plugin.KaitoModelRegister.MustGet(presetName) - trainingParam := model.GetTrainingParameters() - - existingObj := &appsv1.Deployment{} + tuningParam := model.GetTuningParameters() + existingObj := &batchv1.Job{} if err = resources.GetResource(ctx, wObj.Name, wObj.Namespace, c.Client, existingObj); err == nil { - klog.InfoS("A training workload already exists for workspace", "workspace", klog.KObj(wObj)) - if err = resources.CheckResourceStatus(existingObj, c.Client, trainingParam.WorkloadTimeout); err != nil { + klog.InfoS("A tuning workload already exists for workspace", "workspace", klog.KObj(wObj)) + if err = resources.CheckResourceStatus(existingObj, c.Client, tuningParam.WorkloadTimeout); err != nil { return } } else if apierrors.IsNotFound(err) { var workloadObj client.Object // Need to create a new workload - workloadObj, err = tuning.CreatePresetTuning(ctx, wObj, trainingParam, c.Client) + workloadObj, err = tuning.CreatePresetTuning(ctx, wObj, tuningParam, c.Client) if err != nil { return } - if err = resources.CheckResourceStatus(workloadObj, c.Client, trainingParam.WorkloadTimeout); err != nil { + if err = 
resources.CheckResourceStatus(workloadObj, c.Client, tuningParam.WorkloadTimeout); err != nil { return } } @@ -465,7 +465,7 @@ func (c *WorkspaceReconciler) applyTuning(ctx context.Context, wObj *kaitov1alph }() if err != nil { - if updateErr := c.updateStatusConditionIfNotMatch(ctx, wObj, kaitov1alpha1.WorkspaceConditionTypeTuningStatus, metav1.ConditionFalse, + if updateErr := c.updateStatusConditionIfNotMatch(ctx, wObj, kaitov1alpha1.WorkspaceConditionTypeTuningFailed, metav1.ConditionFalse, "WorkspaceTuningStatusFailed", err.Error()); updateErr != nil { klog.ErrorS(updateErr, "failed to update workspace status", "workspace", klog.KObj(wObj)) return updateErr @@ -475,8 +475,8 @@ func (c *WorkspaceReconciler) applyTuning(ctx context.Context, wObj *kaitov1alph } } - if err := c.updateStatusConditionIfNotMatch(ctx, wObj, kaitov1alpha1.WorkspaceConditionTypeTuningStatus, metav1.ConditionTrue, - "WorkspaceTuningStatusSuccess", "Tuning has been deployed successfully"); err != nil { + if err := c.updateStatusConditionIfNotMatch(ctx, wObj, kaitov1alpha1.WorkspaceConditionTypeTuningStarted, metav1.ConditionTrue, + "WorkspaceTuningStatusStarted", "Tuning has been deployed successfully"); err != nil { klog.ErrorS(err, "failed to update workspace status", "workspace", klog.KObj(wObj)) return err } diff --git a/pkg/model/interface.go b/pkg/model/interface.go index 148d80fcf..eba2e0d0e 100644 --- a/pkg/model/interface.go +++ b/pkg/model/interface.go @@ -8,9 +8,9 @@ import ( type Model interface { GetInferenceParameters() *PresetParam - GetTrainingParameters() *PresetParam + GetTuningParameters() *PresetParam SupportDistributedInference() bool //If true, the model workload will be a StatefulSet, using the torch elastic runtime framework. - SupportTraining() bool + SupportTuning() bool } // PresetParam defines the preset inference parameters for a model. 
diff --git a/pkg/utils/testModel.go b/pkg/utils/testModel.go index fdf9423c3..f03633d7c 100644 --- a/pkg/utils/testModel.go +++ b/pkg/utils/testModel.go @@ -18,7 +18,7 @@ func (*testModel) GetInferenceParameters() *model.PresetParam { WorkloadTimeout: time.Duration(30) * time.Minute, } } -func (*testModel) GetTrainingParameters() *model.PresetParam { +func (*testModel) GetTuningParameters() *model.PresetParam { return &model.PresetParam{ GPUCountRequirement: "1", WorkloadTimeout: time.Duration(30) * time.Minute, @@ -27,7 +27,7 @@ func (*testModel) GetTrainingParameters() *model.PresetParam { func (*testModel) SupportDistributedInference() bool { return false } -func (*testModel) SupportTraining() bool { +func (*testModel) SupportTuning() bool { return true } @@ -39,7 +39,7 @@ func (*testDistributedModel) GetInferenceParameters() *model.PresetParam { WorkloadTimeout: time.Duration(30) * time.Minute, } } -func (*testDistributedModel) GetTrainingParameters() *model.PresetParam { +func (*testDistributedModel) GetTuningParameters() *model.PresetParam { return &model.PresetParam{ GPUCountRequirement: "1", WorkloadTimeout: time.Duration(30) * time.Minute, @@ -48,7 +48,7 @@ func (*testDistributedModel) GetTrainingParameters() *model.PresetParam { func (*testDistributedModel) SupportDistributedInference() bool { return true } -func (*testDistributedModel) SupportTraining() bool { +func (*testDistributedModel) SupportTuning() bool { return true } diff --git a/presets/models/falcon/model.go b/presets/models/falcon/model.go index f99cac208..e6774eed1 100644 --- a/presets/models/falcon/model.go +++ b/presets/models/falcon/model.go @@ -69,7 +69,7 @@ func (*falcon7b) GetInferenceParameters() *model.PresetParam { Tag: PresetFalconTagMap["Falcon7B"], } } -func (*falcon7b) GetTrainingParameters() *model.PresetParam { +func (*falcon7b) GetTuningParameters() *model.PresetParam { return &model.PresetParam{ ModelFamilyName: "Falcon", ImageAccessMode: 
string(kaitov1alpha1.ModelImageAccessModePublic), @@ -88,7 +88,7 @@ func (*falcon7b) GetTrainingParameters() *model.PresetParam { func (*falcon7b) SupportDistributedInference() bool { return false } -func (*falcon7b) SupportTraining() bool { +func (*falcon7b) SupportTuning() bool { return true } @@ -112,13 +112,13 @@ func (*falcon7bInst) GetInferenceParameters() *model.PresetParam { } } -func (*falcon7bInst) GetTrainingParameters() *model.PresetParam { +func (*falcon7bInst) GetTuningParameters() *model.PresetParam { return nil // It is not recommended/ideal to further fine-tune instruct models - Already been fine-tuned } func (*falcon7bInst) SupportDistributedInference() bool { return false } -func (*falcon7bInst) SupportTraining() bool { +func (*falcon7bInst) SupportTuning() bool { return false } @@ -142,7 +142,7 @@ func (*falcon40b) GetInferenceParameters() *model.PresetParam { } } -func (*falcon40b) GetTrainingParameters() *model.PresetParam { +func (*falcon40b) GetTuningParameters() *model.PresetParam { return &model.PresetParam{ ModelFamilyName: "Falcon", ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePublic), @@ -160,7 +160,7 @@ func (*falcon40b) GetTrainingParameters() *model.PresetParam { func (*falcon40b) SupportDistributedInference() bool { return false } -func (*falcon40b) SupportTraining() bool { +func (*falcon40b) SupportTuning() bool { return true } @@ -183,12 +183,12 @@ func (*falcon40bInst) GetInferenceParameters() *model.PresetParam { Tag: PresetFalconTagMap["Falcon40BInstruct"], } } -func (*falcon40bInst) GetTrainingParameters() *model.PresetParam { +func (*falcon40bInst) GetTuningParameters() *model.PresetParam { return nil // It is not recommended/ideal to further fine-tune instruct models - Already been fine-tuned } func (*falcon40bInst) SupportDistributedInference() bool { return false } -func (*falcon40bInst) SupportTraining() bool { +func (*falcon40bInst) SupportTuning() bool { return false } diff --git 
a/presets/models/llama2/model.go b/presets/models/llama2/model.go index b1e1dc180..c38d9ef4d 100644 --- a/presets/models/llama2/model.go +++ b/presets/models/llama2/model.go @@ -56,13 +56,13 @@ func (*llama2Text7b) GetInferenceParameters() *model.PresetParam { } } -func (*llama2Text7b) GetTrainingParameters() *model.PresetParam { +func (*llama2Text7b) GetTuningParameters() *model.PresetParam { return nil // Currently doesn't support fine-tuning } func (*llama2Text7b) SupportDistributedInference() bool { return false } -func (*llama2Text7b) SupportTraining() bool { +func (*llama2Text7b) SupportTuning() bool { return false } @@ -87,13 +87,13 @@ func (*llama2Text13b) GetInferenceParameters() *model.PresetParam { // Tag: llama has private image access mode. The image tag is determined by the user. } } -func (*llama2Text13b) GetTrainingParameters() *model.PresetParam { +func (*llama2Text13b) GetTuningParameters() *model.PresetParam { return nil // Currently doesn't support fine-tuning } func (*llama2Text13b) SupportDistributedInference() bool { return true } -func (*llama2Text13b) SupportTraining() bool { +func (*llama2Text13b) SupportTuning() bool { return false } @@ -118,12 +118,12 @@ func (*llama2Text70b) GetInferenceParameters() *model.PresetParam { // Tag: llama has private image access mode. The image tag is determined by the user. 
} } -func (*llama2Text70b) GetTrainingParameters() *model.PresetParam { +func (*llama2Text70b) GetTuningParameters() *model.PresetParam { return nil // Currently doesn't support fine-tuning } func (*llama2Text70b) SupportDistributedInference() bool { return true } -func (*llama2Text70b) SupportTraining() bool { +func (*llama2Text70b) SupportTuning() bool { return false } diff --git a/presets/models/llama2chat/model.go b/presets/models/llama2chat/model.go index 9108a41d5..1afc17655 100644 --- a/presets/models/llama2chat/model.go +++ b/presets/models/llama2chat/model.go @@ -55,13 +55,13 @@ func (*llama2Chat7b) GetInferenceParameters() *model.PresetParam { // Tag: llama has private image access mode. The image tag is determined by the user. } } -func (*llama2Chat7b) GetTrainingParameters() *model.PresetParam { +func (*llama2Chat7b) GetTuningParameters() *model.PresetParam { return nil // Currently doesn't support fine-tuning } func (*llama2Chat7b) SupportDistributedInference() bool { return false } -func (*llama2Chat7b) SupportTraining() bool { +func (*llama2Chat7b) SupportTuning() bool { return false } @@ -86,13 +86,13 @@ func (*llama2Chat13b) GetInferenceParameters() *model.PresetParam { // Tag: llama has private image access mode. The image tag is determined by the user. } } -func (*llama2Chat13b) GetTrainingParameters() *model.PresetParam { +func (*llama2Chat13b) GetTuningParameters() *model.PresetParam { return nil // Currently doesn't support fine-tuning } func (*llama2Chat13b) SupportDistributedInference() bool { return true } -func (*llama2Chat13b) SupportTraining() bool { +func (*llama2Chat13b) SupportTuning() bool { return false } @@ -117,12 +117,12 @@ func (*llama2Chat70b) GetInferenceParameters() *model.PresetParam { // Tag: llama has private image access mode. The image tag is determined by the user. 
} } -func (*llama2Chat70b) GetTrainingParameters() *model.PresetParam { +func (*llama2Chat70b) GetTuningParameters() *model.PresetParam { return nil // Currently doesn't support fine-tuning } func (*llama2Chat70b) SupportDistributedInference() bool { return true } -func (*llama2Chat70b) SupportTraining() bool { +func (*llama2Chat70b) SupportTuning() bool { return false } diff --git a/presets/models/mistral/model.go b/presets/models/mistral/model.go index c4c518feb..9a3dc8217 100644 --- a/presets/models/mistral/model.go +++ b/presets/models/mistral/model.go @@ -58,7 +58,7 @@ func (*mistral7b) GetInferenceParameters() *model.PresetParam { } } -func (*mistral7b) GetTrainingParameters() *model.PresetParam { +func (*mistral7b) GetTuningParameters() *model.PresetParam { return &model.PresetParam{ ModelFamilyName: "Mistral", ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePublic), @@ -77,7 +77,7 @@ func (*mistral7b) GetTrainingParameters() *model.PresetParam { func (*mistral7b) SupportDistributedInference() bool { return false } -func (*mistral7b) SupportTraining() bool { +func (*mistral7b) SupportTuning() bool { return true } @@ -101,12 +101,12 @@ func (*mistral7bInst) GetInferenceParameters() *model.PresetParam { } } -func (*mistral7bInst) GetTrainingParameters() *model.PresetParam { +func (*mistral7bInst) GetTuningParameters() *model.PresetParam { return nil // It is not recommended/ideal to further fine-tune instruct models - Already been fine-tuned } func (*mistral7bInst) SupportDistributedInference() bool { return false } -func (*mistral7bInst) SupportTraining() bool { +func (*mistral7bInst) SupportTuning() bool { return false } diff --git a/presets/models/phi/model.go b/presets/models/phi/model.go index 6a1bfd109..189b9d9ec 100644 --- a/presets/models/phi/model.go +++ b/presets/models/phi/model.go @@ -51,7 +51,7 @@ func (*phi2) GetInferenceParameters() *model.PresetParam { Tag: PresetPhiTagMap["Phi2"], } } -func (*phi2) GetTrainingParameters() 
*model.PresetParam { +func (*phi2) GetTuningParameters() *model.PresetParam { return &model.PresetParam{ ModelFamilyName: "Phi", ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePublic), @@ -69,6 +69,6 @@ func (*phi2) GetTrainingParameters() *model.PresetParam { func (*phi2) SupportDistributedInference() bool { return false } -func (*phi2) SupportTraining() bool { +func (*phi2) SupportTuning() bool { return true } From c80eadaa3afb6cd9f88d5dfcb246250acd1ed765 Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Thu, 21 Mar 2024 15:18:55 -0400 Subject: [PATCH 27/29] fix: workspace condition --- pkg/controllers/workspace_controller.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/controllers/workspace_controller.go b/pkg/controllers/workspace_controller.go index ef2e4c01e..c3766736f 100644 --- a/pkg/controllers/workspace_controller.go +++ b/pkg/controllers/workspace_controller.go @@ -465,7 +465,7 @@ func (c *WorkspaceReconciler) applyTuning(ctx context.Context, wObj *kaitov1alph }() if err != nil { - if updateErr := c.updateStatusConditionIfNotMatch(ctx, wObj, kaitov1alpha1.WorkspaceConditionTypeTuningFailed, metav1.ConditionFalse, + if updateErr := c.updateStatusConditionIfNotMatch(ctx, wObj, kaitov1alpha1.WorkspaceConditionTypeTuningStarted, metav1.ConditionFalse, "WorkspaceTuningStatusFailed", err.Error()); updateErr != nil { klog.ErrorS(updateErr, "failed to update workspace status", "workspace", klog.KObj(wObj)) return updateErr From 860108fe57cb97f01b714ccfa370401ff2272a1b Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Fri, 29 Mar 2024 15:41:16 -0700 Subject: [PATCH 28/29] fix: address comments --- api/v1alpha1/workspace_condition_types.go | 9 --------- pkg/controllers/workspace_controller.go | 13 ++++--------- pkg/model/interface.go | 4 ++-- pkg/tuning/preset-tuning.go | 15 --------------- pkg/utils/testModel.go | 8 ++++---- presets/models/falcon/model.go | 20 ++++++++++---------- presets/models/llama2/model.go | 6 +++--- 
presets/models/llama2chat/model.go | 6 +++--- presets/models/mistral/model.go | 10 +++++----- presets/models/phi/model.go | 8 ++++---- 10 files changed, 35 insertions(+), 64 deletions(-) diff --git a/api/v1alpha1/workspace_condition_types.go b/api/v1alpha1/workspace_condition_types.go index e14995f12..762d8dafc 100644 --- a/api/v1alpha1/workspace_condition_types.go +++ b/api/v1alpha1/workspace_condition_types.go @@ -16,15 +16,6 @@ const ( // WorkspaceConditionTypeInferenceStatus is the state when Inference has been created. WorkspaceConditionTypeInferenceStatus = ConditionType("InferenceReady") - // WorkspaceConditionTypeTuningStarted indicates that the tuning Job has been started. - WorkspaceConditionTypeTuningStarted = ConditionType("TuningStarted") - - // WorkspaceConditionTypeTuningComplete indicates that the tuning Job has completed successfully. - WorkspaceConditionTypeTuningComplete = ConditionType("TuningComplete") - - // WorkspaceConditionTypeTuningFailed indicates that the tuning Job has failed to complete. - WorkspaceConditionTypeTuningFailed = ConditionType("TuningFailed") - //WorkspaceConditionTypeDeleting is the Workspace state when starts to get deleted. 
WorkspaceConditionTypeDeleting = ConditionType("WorkspaceDeleting") diff --git a/pkg/controllers/workspace_controller.go b/pkg/controllers/workspace_controller.go index c3766736f..d5ffe5959 100644 --- a/pkg/controllers/workspace_controller.go +++ b/pkg/controllers/workspace_controller.go @@ -113,11 +113,6 @@ func (c *WorkspaceReconciler) addOrUpdateWorkspace(ctx context.Context, wObj *ka if wObj.Tuning != nil { if err = c.applyTuning(ctx, wObj); err != nil { - if updateErr := c.updateStatusConditionIfNotMatch(ctx, wObj, kaitov1alpha1.WorkspaceConditionTypeReady, metav1.ConditionFalse, - "workspaceFailed", err.Error()); updateErr != nil { - klog.ErrorS(updateErr, "failed to update workspace status", "workspace", klog.KObj(wObj)) - return reconcile.Result{}, updateErr - } return reconcile.Result{}, err } } @@ -447,7 +442,7 @@ func (c *WorkspaceReconciler) applyTuning(ctx context.Context, wObj *kaitov1alph existingObj := &batchv1.Job{} if err = resources.GetResource(ctx, wObj.Name, wObj.Namespace, c.Client, existingObj); err == nil { klog.InfoS("A tuning workload already exists for workspace", "workspace", klog.KObj(wObj)) - if err = resources.CheckResourceStatus(existingObj, c.Client, tuningParam.WorkloadTimeout); err != nil { + if err = resources.CheckResourceStatus(existingObj, c.Client, tuningParam.ReadinessTimeout); err != nil { return } } else if apierrors.IsNotFound(err) { @@ -457,7 +452,7 @@ func (c *WorkspaceReconciler) applyTuning(ctx context.Context, wObj *kaitov1alph if err != nil { return } - if err = resources.CheckResourceStatus(workloadObj, c.Client, tuningParam.WorkloadTimeout); err != nil { + if err = resources.CheckResourceStatus(workloadObj, c.Client, tuningParam.ReadinessTimeout); err != nil { return } } @@ -515,7 +510,7 @@ func (c *WorkspaceReconciler) applyInference(ctx context.Context, wObj *kaitov1a if err = resources.GetResource(ctx, wObj.Name, wObj.Namespace, c.Client, existingObj); err == nil { klog.InfoS("An inference workload already 
exists for workspace", "workspace", klog.KObj(wObj)) - if err = resources.CheckResourceStatus(existingObj, c.Client, inferenceParam.WorkloadTimeout); err != nil { + if err = resources.CheckResourceStatus(existingObj, c.Client, inferenceParam.ReadinessTimeout); err != nil { return } } else if apierrors.IsNotFound(err) { @@ -525,7 +520,7 @@ func (c *WorkspaceReconciler) applyInference(ctx context.Context, wObj *kaitov1a if err != nil { return } - if err = resources.CheckResourceStatus(workloadObj, c.Client, inferenceParam.WorkloadTimeout); err != nil { + if err = resources.CheckResourceStatus(workloadObj, c.Client, inferenceParam.ReadinessTimeout); err != nil { return } } diff --git a/pkg/model/interface.go b/pkg/model/interface.go index eba2e0d0e..3a054cf25 100644 --- a/pkg/model/interface.go +++ b/pkg/model/interface.go @@ -26,10 +26,10 @@ type PresetParam struct { // BaseCommand is the initial command (e.g., 'torchrun', 'accelerate launch') used in the command line. BaseCommand string ModelRunParams map[string]string // Parameters for running the model training/inference. - // WorkloadTimeout defines the maximum duration for creating the workload. + // ReadinessTimeout defines the maximum duration for creating the workload. // This timeout accommodates the size of the image, ensuring pull completion // even under slower network conditions or unforeseen delays. - WorkloadTimeout time.Duration + ReadinessTimeout time.Duration // WorldSize defines the number of processes required for distributed inference. WorldSize int Tag string // The model image tag diff --git a/pkg/tuning/preset-tuning.go b/pkg/tuning/preset-tuning.go index cbbb55a06..d9dfbd477 100644 --- a/pkg/tuning/preset-tuning.go +++ b/pkg/tuning/preset-tuning.go @@ -10,20 +10,5 @@ import ( func CreatePresetTuning(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace, tuningObj *model.PresetParam, kubeClient client.Client) (client.Object, error) { // TODO - - // e.g. 
example from Inference - //volume, volumeMount := configVolume(workspaceObj, inferenceObj) - //commands, resourceReq := prepareInferenceParameters(ctx, inferenceObj) - //image, imagePullSecrets := GetImageInfo(ctx, workspaceObj, inferenceObj) - // - //depObj = resources.GenerateDeploymentManifest(ctx, workspaceObj, image, imagePullSecrets, *workspaceObj.Resource.Count, commands, - // containerPorts, livenessProbe, readinessProbe, resourceReq, tolerations, volume, volumeMount) - // - //err := resources.CreateResource(ctx, depObj, kubeClient) - //if client.IgnoreAlreadyExists(err) != nil { - // return nil, err - //} - //return depObj, nil - return nil, nil } diff --git a/pkg/utils/testModel.go b/pkg/utils/testModel.go index f03633d7c..5acd05ac5 100644 --- a/pkg/utils/testModel.go +++ b/pkg/utils/testModel.go @@ -15,13 +15,13 @@ type testModel struct{} func (*testModel) GetInferenceParameters() *model.PresetParam { return &model.PresetParam{ GPUCountRequirement: "1", - WorkloadTimeout: time.Duration(30) * time.Minute, + ReadinessTimeout: time.Duration(30) * time.Minute, } } func (*testModel) GetTuningParameters() *model.PresetParam { return &model.PresetParam{ GPUCountRequirement: "1", - WorkloadTimeout: time.Duration(30) * time.Minute, + ReadinessTimeout: time.Duration(30) * time.Minute, } } func (*testModel) SupportDistributedInference() bool { @@ -36,13 +36,13 @@ type testDistributedModel struct{} func (*testDistributedModel) GetInferenceParameters() *model.PresetParam { return &model.PresetParam{ GPUCountRequirement: "1", - WorkloadTimeout: time.Duration(30) * time.Minute, + ReadinessTimeout: time.Duration(30) * time.Minute, } } func (*testDistributedModel) GetTuningParameters() *model.PresetParam { return &model.PresetParam{ GPUCountRequirement: "1", - WorkloadTimeout: time.Duration(30) * time.Minute, + ReadinessTimeout: time.Duration(30) * time.Minute, } } func (*testDistributedModel) SupportDistributedInference() bool { diff --git 
a/presets/models/falcon/model.go b/presets/models/falcon/model.go index e6774eed1..04f7bd980 100644 --- a/presets/models/falcon/model.go +++ b/presets/models/falcon/model.go @@ -64,7 +64,7 @@ func (*falcon7b) GetInferenceParameters() *model.PresetParam { PerGPUMemoryRequirement: "0Gi", // We run Falcon using native vertical model parallel, no per GPU memory requirement. TorchRunParams: inference.DefaultAccelerateParams, ModelRunParams: falconRunParams, - WorkloadTimeout: time.Duration(30) * time.Minute, + ReadinessTimeout: time.Duration(30) * time.Minute, BaseCommand: baseCommandPresetFalcon, Tag: PresetFalconTagMap["Falcon7B"], } @@ -79,9 +79,9 @@ func (*falcon7b) GetTuningParameters() *model.PresetParam { PerGPUMemoryRequirement: "16Gi", //TorchRunParams: tuning.DefaultAccelerateParams, // TODO //ModelRunPrams: falconRunTuningParams, // TODO - WorkloadTimeout: time.Duration(30) * time.Minute, - BaseCommand: baseCommandPresetFalcon, - Tag: PresetFalconTagMap["Falcon7B"], + ReadinessTimeout: time.Duration(30) * time.Minute, + BaseCommand: baseCommandPresetFalcon, + Tag: PresetFalconTagMap["Falcon7B"], } } @@ -106,7 +106,7 @@ func (*falcon7bInst) GetInferenceParameters() *model.PresetParam { PerGPUMemoryRequirement: "0Gi", // We run Falcon using native vertical model parallel, no per GPU memory requirement. TorchRunParams: inference.DefaultAccelerateParams, ModelRunParams: falconRunParams, - WorkloadTimeout: time.Duration(30) * time.Minute, + ReadinessTimeout: time.Duration(30) * time.Minute, BaseCommand: baseCommandPresetFalcon, Tag: PresetFalconTagMap["Falcon7BInstruct"], } @@ -136,7 +136,7 @@ func (*falcon40b) GetInferenceParameters() *model.PresetParam { PerGPUMemoryRequirement: "0Gi", // We run Falcon using native vertical model parallel, no per GPU memory requirement. 
TorchRunParams: inference.DefaultAccelerateParams, ModelRunParams: falconRunParams, - WorkloadTimeout: time.Duration(30) * time.Minute, + ReadinessTimeout: time.Duration(30) * time.Minute, BaseCommand: baseCommandPresetFalcon, Tag: PresetFalconTagMap["Falcon40B"], } @@ -152,9 +152,9 @@ func (*falcon40b) GetTuningParameters() *model.PresetParam { PerGPUMemoryRequirement: "16Gi", //TorchRunParams: tuning.DefaultAccelerateParams, // TODO //ModelRunPrams: falconRunTuningParams, // TODO - WorkloadTimeout: time.Duration(30) * time.Minute, - BaseCommand: baseCommandPresetFalcon, - Tag: PresetFalconTagMap["Falcon40B"], + ReadinessTimeout: time.Duration(30) * time.Minute, + BaseCommand: baseCommandPresetFalcon, + Tag: PresetFalconTagMap["Falcon40B"], } } func (*falcon40b) SupportDistributedInference() bool { @@ -178,7 +178,7 @@ func (*falcon40bInst) GetInferenceParameters() *model.PresetParam { PerGPUMemoryRequirement: "0Gi", // We run Falcon using native vertical model parallel, no per GPU memory requirement. TorchRunParams: inference.DefaultAccelerateParams, ModelRunParams: falconRunParams, - WorkloadTimeout: time.Duration(30) * time.Minute, + ReadinessTimeout: time.Duration(30) * time.Minute, BaseCommand: baseCommandPresetFalcon, Tag: PresetFalconTagMap["Falcon40BInstruct"], } diff --git a/presets/models/llama2/model.go b/presets/models/llama2/model.go index c38d9ef4d..6a62a8987 100644 --- a/presets/models/llama2/model.go +++ b/presets/models/llama2/model.go @@ -49,7 +49,7 @@ func (*llama2Text7b) GetInferenceParameters() *model.PresetParam { TorchRunParams: inference.DefaultTorchRunParams, TorchRunRdzvParams: inference.DefaultTorchRunRdzvParams, ModelRunParams: llamaRunParams, - WorkloadTimeout: time.Duration(10) * time.Minute, + ReadinessTimeout: time.Duration(10) * time.Minute, BaseCommand: baseCommandPresetLlama, WorldSize: 1, // Tag: llama has private image access mode. The image tag is determined by the user. 
@@ -81,7 +81,7 @@ func (*llama2Text13b) GetInferenceParameters() *model.PresetParam { TorchRunParams: inference.DefaultTorchRunParams, TorchRunRdzvParams: inference.DefaultTorchRunRdzvParams, ModelRunParams: llamaRunParams, - WorkloadTimeout: time.Duration(20) * time.Minute, + ReadinessTimeout: time.Duration(20) * time.Minute, BaseCommand: baseCommandPresetLlama, WorldSize: 2, // Tag: llama has private image access mode. The image tag is determined by the user. @@ -112,7 +112,7 @@ func (*llama2Text70b) GetInferenceParameters() *model.PresetParam { TorchRunParams: inference.DefaultTorchRunParams, TorchRunRdzvParams: inference.DefaultTorchRunRdzvParams, ModelRunParams: llamaRunParams, - WorkloadTimeout: time.Duration(30) * time.Minute, + ReadinessTimeout: time.Duration(30) * time.Minute, BaseCommand: baseCommandPresetLlama, WorldSize: 8, // Tag: llama has private image access mode. The image tag is determined by the user. diff --git a/presets/models/llama2chat/model.go b/presets/models/llama2chat/model.go index 1afc17655..89225bef5 100644 --- a/presets/models/llama2chat/model.go +++ b/presets/models/llama2chat/model.go @@ -49,7 +49,7 @@ func (*llama2Chat7b) GetInferenceParameters() *model.PresetParam { TorchRunParams: inference.DefaultTorchRunParams, TorchRunRdzvParams: inference.DefaultTorchRunRdzvParams, ModelRunParams: llamaRunParams, - WorkloadTimeout: time.Duration(10) * time.Minute, + ReadinessTimeout: time.Duration(10) * time.Minute, BaseCommand: baseCommandPresetLlama, WorldSize: 1, // Tag: llama has private image access mode. The image tag is determined by the user. 
@@ -80,7 +80,7 @@ func (*llama2Chat13b) GetInferenceParameters() *model.PresetParam { TorchRunParams: inference.DefaultTorchRunParams, TorchRunRdzvParams: inference.DefaultTorchRunRdzvParams, ModelRunParams: llamaRunParams, - WorkloadTimeout: time.Duration(20) * time.Minute, + ReadinessTimeout: time.Duration(20) * time.Minute, BaseCommand: baseCommandPresetLlama, WorldSize: 2, // Tag: llama has private image access mode. The image tag is determined by the user. @@ -111,7 +111,7 @@ func (*llama2Chat70b) GetInferenceParameters() *model.PresetParam { TorchRunParams: inference.DefaultTorchRunParams, TorchRunRdzvParams: inference.DefaultTorchRunRdzvParams, ModelRunParams: llamaRunParams, - WorkloadTimeout: time.Duration(30) * time.Minute, + ReadinessTimeout: time.Duration(30) * time.Minute, BaseCommand: baseCommandPresetLlama, WorldSize: 8, // Tag: llama has private image access mode. The image tag is determined by the user. diff --git a/presets/models/mistral/model.go b/presets/models/mistral/model.go index 9a3dc8217..343ba9883 100644 --- a/presets/models/mistral/model.go +++ b/presets/models/mistral/model.go @@ -52,7 +52,7 @@ func (*mistral7b) GetInferenceParameters() *model.PresetParam { PerGPUMemoryRequirement: "0Gi", // We run Mistral using native vertical model parallel, no per GPU memory requirement. TorchRunParams: inference.DefaultAccelerateParams, ModelRunParams: mistralRunParams, - WorkloadTimeout: time.Duration(30) * time.Minute, + ReadinessTimeout: time.Duration(30) * time.Minute, BaseCommand: baseCommandPresetMistral, Tag: PresetMistralTagMap["Mistral7B"], } @@ -68,9 +68,9 @@ func (*mistral7b) GetTuningParameters() *model.PresetParam { PerGPUMemoryRequirement: "16Gi", // We run Mistral using native vertical model parallel, no per GPU memory requirement. 
//TorchRunParams: tuning.DefaultAccelerateParams, //ModelRunParams: mistralRunParams, - WorkloadTimeout: time.Duration(30) * time.Minute, - BaseCommand: baseCommandPresetMistral, - Tag: PresetMistralTagMap["Mistral7B"], + ReadinessTimeout: time.Duration(30) * time.Minute, + BaseCommand: baseCommandPresetMistral, + Tag: PresetMistralTagMap["Mistral7B"], } } @@ -95,7 +95,7 @@ func (*mistral7bInst) GetInferenceParameters() *model.PresetParam { PerGPUMemoryRequirement: "0Gi", // We run mistral using native vertical model parallel, no per GPU memory requirement. TorchRunParams: inference.DefaultAccelerateParams, ModelRunParams: mistralRunParams, - WorkloadTimeout: time.Duration(30) * time.Minute, + ReadinessTimeout: time.Duration(30) * time.Minute, BaseCommand: baseCommandPresetMistral, Tag: PresetMistralTagMap["Mistral7BInstruct"], } diff --git a/presets/models/phi/model.go b/presets/models/phi/model.go index 189b9d9ec..3eb5388e5 100644 --- a/presets/models/phi/model.go +++ b/presets/models/phi/model.go @@ -46,7 +46,7 @@ func (*phi2) GetInferenceParameters() *model.PresetParam { PerGPUMemoryRequirement: "0Gi", // We run Phi using native vertical model parallel, no per GPU memory requirement. TorchRunParams: inference.DefaultAccelerateParams, ModelRunParams: phiRunParams, - WorkloadTimeout: time.Duration(30) * time.Minute, + ReadinessTimeout: time.Duration(30) * time.Minute, BaseCommand: baseCommandPresetPhi, Tag: PresetPhiTagMap["Phi2"], } @@ -61,9 +61,9 @@ func (*phi2) GetTuningParameters() *model.PresetParam { PerGPUMemoryRequirement: "16Gi", // We run Phi using native vertical model parallel, no per GPU memory requirement. 
// TorchRunParams: inference.DefaultAccelerateParams, // ModelRunParams: phiRunParams, - WorkloadTimeout: time.Duration(30) * time.Minute, - BaseCommand: baseCommandPresetPhi, - Tag: PresetPhiTagMap["Phi2"], + ReadinessTimeout: time.Duration(30) * time.Minute, + BaseCommand: baseCommandPresetPhi, + Tag: PresetPhiTagMap["Phi2"], } } func (*phi2) SupportDistributedInference() bool { From 6a47da02876aaa974c6b4b4411b743fe312eef9c Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Fri, 29 Mar 2024 15:45:02 -0700 Subject: [PATCH 29/29] fix: remove code --- pkg/controllers/workspace_controller.go | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/pkg/controllers/workspace_controller.go b/pkg/controllers/workspace_controller.go index d5ffe5959..042249f15 100644 --- a/pkg/controllers/workspace_controller.go +++ b/pkg/controllers/workspace_controller.go @@ -460,21 +460,9 @@ func (c *WorkspaceReconciler) applyTuning(ctx context.Context, wObj *kaitov1alph }() if err != nil { - if updateErr := c.updateStatusConditionIfNotMatch(ctx, wObj, kaitov1alpha1.WorkspaceConditionTypeTuningStarted, metav1.ConditionFalse, - "WorkspaceTuningStatusFailed", err.Error()); updateErr != nil { - klog.ErrorS(updateErr, "failed to update workspace status", "workspace", klog.KObj(wObj)) - return updateErr - } else { - return err - - } - } - - if err := c.updateStatusConditionIfNotMatch(ctx, wObj, kaitov1alpha1.WorkspaceConditionTypeTuningStarted, metav1.ConditionTrue, - "WorkspaceTuningStatusStarted", "Tuning has been deployed successfully"); err != nil { - klog.ErrorS(err, "failed to update workspace status", "workspace", klog.KObj(wObj)) return err } + return nil }