Skip to content

Commit

Permalink
test updates
Browse files Browse the repository at this point in the history
  • Loading branch information
smritidahal653 committed Jun 7, 2024
1 parent 1be92bb commit aa202c6
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 178 deletions.
85 changes: 0 additions & 85 deletions api/v1alpha1/sku_config.go

This file was deleted.

20 changes: 12 additions & 8 deletions api/v1alpha1/workspace_validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ func (r *TuningSpec) validateCreate(ctx context.Context, workspaceNamespace stri
// Currently require a preset to specified, in future we can consider defining a template
if r.Preset == nil {
errs = errs.Also(apis.ErrMissingField("Preset"))
} else if presetName := string(r.Preset.Name); !isValidPreset(presetName) {
} else if presetName := string(r.Preset.Name); !utils.IsValidPreset(presetName) {
errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("Unsupported tuning preset name %s", presetName), "presetName"))
}
return errs
Expand Down Expand Up @@ -258,8 +258,15 @@ func (r *ResourceSpec) validateCreate(inference InferenceSpec) (errs *apis.Field
}
instanceType := string(r.InstanceType)

// Check if instancetype exists in our SKUs map
if skuConfig, exists := SupportedGPUConfigs[instanceType]; exists {
skuHandler, err := utils.GetSKUHandler()
if err != nil {
errs = errs.Also(apis.ErrGeneric(fmt.Sprintf("Failed to get SKU handler: %v", err), "instanceType"))
return errs
}
gpuConfigs := skuHandler.GetGPUConfigs()

// Check if instancetype exists in our SKUs map for the particular skuHandler
if skuConfig, exists := gpuConfigs[instanceType]; exists {
if inference.Preset != nil {
model := plugin.KaitoModelRegister.MustGet(presetName) // InferenceSpec has been validated so the name is valid.
// Validate GPU count for given SKU
Expand All @@ -284,10 +291,7 @@ func (r *ResourceSpec) validateCreate(inference InferenceSpec) (errs *apis.Field
}
}
} else {
// Check for other instance types pattern matches
if !strings.HasPrefix(instanceType, N_SERIES_PREFIX) && !strings.HasPrefix(instanceType, D_SERIES_PREFIX) {
errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("Unsupported instance type %s. Supported SKUs: %s", instanceType, getSupportedSKUs()), "instanceType"))
}
errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("Unsupported instance type %s. Supported SKUs: %s", instanceType, skuHandler.GetSupportedSKUs()), "instanceType"))
}

// Validate labelSelector
Expand Down Expand Up @@ -332,7 +336,7 @@ func (i *InferenceSpec) validateCreate() (errs *apis.FieldError) {
if i.Preset != nil {
presetName := string(i.Preset.Name)
// Validate preset name
if !isValidPreset(presetName) {
if !utils.IsValidPreset(presetName) {
errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("Unsupported inference preset name %s", presetName), "presetName"))
}
// Validate private preset has private image specified
Expand Down
54 changes: 4 additions & 50 deletions api/v1alpha1/workspace_validation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ package v1alpha1
import (
"context"
"os"
"reflect"
"sort"
"strings"
"testing"

Expand Down Expand Up @@ -204,7 +202,7 @@ func TestResourceSpecValidateCreate(t *testing.T) {
{
name: "Insufficient total GPU memory",
resourceSpec: &ResourceSpec{
InstanceType: "Standard_NC6",
InstanceType: "Standard_NC6s_v3",
Count: pointerToInt(1),
},
modelGPUCount: "1",
Expand All @@ -231,7 +229,7 @@ func TestResourceSpecValidateCreate(t *testing.T) {
{
name: "Insufficient per GPU memory",
resourceSpec: &ResourceSpec{
InstanceType: "Standard_NC6",
InstanceType: "Standard_NC6s_v3",
Count: pointerToInt(2),
},
modelGPUCount: "1",
Expand Down Expand Up @@ -282,6 +280,8 @@ func TestResourceSpecValidateCreate(t *testing.T) {
},
}

os.Setenv("CLOUD_PROVIDER", "azure")

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
var spec InferenceSpec
Expand Down Expand Up @@ -1228,49 +1228,3 @@ func TestDataDestinationValidateUpdate(t *testing.T) {
})
}
}

func TestGetSupportedSKUs(t *testing.T) {
tests := []struct {
name string
gpuConfigs map[string]GPUConfig
expectedResult []string // changed to a slice for deterministic ordering
}{
{
name: "no SKUs supported",
gpuConfigs: map[string]GPUConfig{},
expectedResult: []string{""},
},
{
name: "one SKU supported",
gpuConfigs: map[string]GPUConfig{
"Standard_NC6": {SKU: "Standard_NC6"},
},
expectedResult: []string{"Standard_NC6"},
},
{
name: "multiple SKUs supported",
gpuConfigs: map[string]GPUConfig{
"Standard_NC6": {SKU: "Standard_NC6"},
"Standard_NC12": {SKU: "Standard_NC12"},
},
expectedResult: []string{"Standard_NC6", "Standard_NC12"},
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
SupportedGPUConfigs = tc.gpuConfigs

resultSlice := strings.Split(getSupportedSKUs(), ", ")
sort.Strings(resultSlice)

// Sort the expectedResult for comparison
expectedResultSlice := tc.expectedResult
sort.Strings(expectedResultSlice)

if !reflect.DeepEqual(resultSlice, expectedResultSlice) {
t.Errorf("getSupportedSKUs() = %v, want %v", resultSlice, expectedResultSlice)
}
})
}
}
20 changes: 0 additions & 20 deletions api/v1alpha1/zz_generated.deepcopy.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 7 additions & 2 deletions pkg/controllers/workspace_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"context"
"fmt"
"sort"
"strings"
"time"

"github.com/azure/kaito/pkg/featuregates"
Expand Down Expand Up @@ -257,8 +256,14 @@ func (c *WorkspaceReconciler) applyWorkspaceResource(ctx context.Context, wObj *
}
}

skuHandler, err := utils.GetSKUHandler()
if err != nil {
return err
}
gpuConfigs := skuHandler.GetGPUConfigs()

// Ensure all gpu plugins are running successfully.
if strings.Contains(wObj.Resource.InstanceType, gpuSkuPrefix) { // GPU skus
if _, exists := gpuConfigs[wObj.Resource.InstanceType]; exists {
for i := range selectedNodes {
err = c.ensureNodePlugins(ctx, wObj, selectedNodes[i])
if err != nil {
Expand Down
20 changes: 7 additions & 13 deletions pkg/tuning/preset-tuning_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,6 @@ import (
"k8s.io/utils/pointer"
)

// Mocking the SupportedGPUConfigs to be used in test scenarios.
var mockSupportedGPUConfigs = map[string]kaitov1alpha1.GPUConfig{
"sku1": {GPUCount: 2},
"sku2": {GPUCount: 4},
"sku3": {GPUCount: 0},
}

func normalize(s string) string {
return strings.Join(strings.Fields(s), " ")
}
Expand All @@ -54,18 +47,19 @@ func saveEnv(key string) func() {
}

func TestGetInstanceGPUCount(t *testing.T) {
kaitov1alpha1.SupportedGPUConfigs = mockSupportedGPUConfigs
os.Setenv("CLOUD_PROVIDER", "azure")

testcases := map[string]struct {
sku string
expectedGPUCount int
}{
"SKU Exists With Multiple GPUs": {
sku: "sku1",
expectedGPUCount: 2,
sku: "Standard_NC24s_v3",
expectedGPUCount: 4,
},
"SKU Exists With Zero GPUs": {
sku: "sku3",
expectedGPUCount: 0,
"SKU Exists With Two GPUs": {
sku: "Standard_NC12s_v3",
expectedGPUCount: 2,
},
"SKU Does Not Exist": {
sku: "sku_unknown",
Expand Down

0 comments on commit aa202c6

Please sign in to comment.