Skip to content

Commit

Permalink
feat: organize image info (#203)
Browse files Browse the repository at this point in the history
This PR removes Image and ImagePullSecrets from inferenceParam.
  • Loading branch information
ishaansehgal99 authored Jan 16, 2024
1 parent 85f80d8 commit 7c88c4e
Show file tree
Hide file tree
Showing 9 changed files with 100 additions and 80 deletions.
5 changes: 5 additions & 0 deletions api/v1alpha1/workspace_validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,11 @@ func (i *InferenceSpec) validateCreate() (errs *apis.FieldError) {
if !isValidPreset(presetName) {
errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("Unsupported preset name %s", presetName), "presetName"))
}
// Validate private preset has private image specified
if plugin.KaitoModelRegister.MustGet(string(i.Preset.Name)).GetInferenceParameters().ImageAccessMode == "private" &&
i.Preset.PresetMeta.AccessMode != "private" {
errs = errs.Also(apis.ErrGeneric("This preset only supports private AccessMode, AccessMode must be private to continue"))
}
// Additional validations for Preset
if i.Preset.PresetMeta.AccessMode == "private" && i.Preset.PresetOptions.Image == "" {
errs = errs.Also(apis.ErrGeneric("When AccessMode is private, an image must be provided in PresetOptions"))
Expand Down
52 changes: 48 additions & 4 deletions api/v1alpha1/workspace_validation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,20 +32,39 @@ func (*testModel) SupportDistributedInference() bool {
return false
}

func RegisterValidationTestModel() {
// testModelPrivate is a stub model used to exercise validation of presets
// whose image access mode is "private".
type testModelPrivate struct{}

// GetInferenceParameters returns preset inference parameters with a
// "private" ImageAccessMode and the shared test GPU requirement values.
func (*testModelPrivate) GetInferenceParameters() *model.PresetInferenceParam {
	return &model.PresetInferenceParam{
		ImageAccessMode:           "private",
		GPUCountRequirement:       gpuCountRequirement,
		TotalGPUMemoryRequirement: totalGPUMemoryRequirement,
		PerGPUMemoryRequirement:   perGPUMemoryRequirement,
	}
}

// SupportDistributedInference reports whether this test model supports
// distributed inference; the private stub does not.
func (*testModelPrivate) SupportDistributedInference() bool {
	return false
}

func RegisterValidationTestModels() {
var test testModel
var testPrivate testModelPrivate
plugin.KaitoModelRegister.Register(&plugin.Registration{
Name: "test-validation",
Instance: &test,
})
plugin.KaitoModelRegister.Register(&plugin.Registration{
Name: "private-test-validation",
Instance: &testPrivate,
})
}

// pointerToInt returns a pointer to a copy of i, handy for populating
// optional *int fields in test fixtures.
func pointerToInt(i int) *int {
	v := i
	return &v
}

func TestResourceSpecValidateCreate(t *testing.T) {
RegisterValidationTestModel()
RegisterValidationTestModels()
tests := []struct {
name string
resourceSpec *ResourceSpec
Expand Down Expand Up @@ -269,7 +288,7 @@ func TestResourceSpecValidateUpdate(t *testing.T) {
}

func TestInferenceSpecValidateCreate(t *testing.T) {
RegisterValidationTestModel()
RegisterValidationTestModels()
tests := []struct {
name string
inferenceSpec *InferenceSpec
Expand All @@ -285,7 +304,7 @@ func TestInferenceSpecValidateCreate(t *testing.T) {
},
},
},
errContent: "Unsupported preset name",
errContent: "model is not registered",
expectErrs: true,
},
{
Expand Down Expand Up @@ -329,6 +348,19 @@ func TestInferenceSpecValidateCreate(t *testing.T) {
errContent: "When AccessMode is private, an image must be provided",
expectErrs: true,
},
{
name: "Private Preset With Public AccessMode",
inferenceSpec: &InferenceSpec{
Preset: &PresetSpec{
PresetMeta: PresetMeta{
Name: ModelName("private-test-validation"),
},
PresetOptions: PresetOptions{},
},
},
errContent: "This preset only supports private AccessMode, AccessMode must be private to continue",
expectErrs: true,
},
{
name: "Valid Preset",
inferenceSpec: &InferenceSpec{
Expand All @@ -346,6 +378,18 @@ func TestInferenceSpecValidateCreate(t *testing.T) {

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// If the test expects an error, setup defer function to catch the panic.
if tc.expectErrs {
defer func() {
if r := recover(); r != nil {
// Check if the recovered panic matches the expected error content.
if errContent, ok := r.(string); ok && strings.Contains(errContent, tc.errContent) {
return
}
t.Errorf("unexpected panic: %v", r)
}
}()
}
errs := tc.inferenceSpec.validateCreate()
hasErrs := errs != nil
if hasErrs != tc.expectErrs {
Expand Down
15 changes: 0 additions & 15 deletions pkg/controllers/workspace_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ import (
kaitov1alpha1 "github.com/azure/kaito/api/v1alpha1"
"github.com/azure/kaito/pkg/inference"
"github.com/azure/kaito/pkg/machine"
"github.com/azure/kaito/pkg/model"
"github.com/azure/kaito/pkg/resources"
"github.com/azure/kaito/pkg/utils"
"github.com/azure/kaito/pkg/utils/plugin"
Expand Down Expand Up @@ -419,19 +418,6 @@ func (c *WorkspaceReconciler) ensureService(ctx context.Context, wObj *kaitov1al
return nil
}

// updateInferenceParamFromWorkspace overrides the preset inference
// parameters with image information taken from the workspace object.
// It always copies the workspace's AccessMode; when that mode is "private"
// and the user supplied a custom image in PresetOptions, the image and its
// pull secrets replace the preset defaults. Otherwise the image fields are
// left untouched.
//
// NOTE(review): ctx is currently unused; kept for signature consistency
// with the other reconciler helpers.
func (c *WorkspaceReconciler) updateInferenceParamFromWorkspace(ctx context.Context, wObj *kaitov1alpha1.Workspace, inferenceParam *model.PresetInferenceParam) {
	inferenceParam.ImageAccessMode = string(wObj.Inference.Preset.PresetMeta.AccessMode)
	if inferenceParam.ImageAccessMode == "private" && wObj.Inference.Preset.PresetOptions.Image != "" {
		inferenceParam.Image = wObj.Inference.Preset.PresetOptions.Image

		// Convert secret names into LocalObjectReference entries;
		// pre-size the slice to avoid repeated growth on append.
		secrets := wObj.Inference.Preset.PresetOptions.ImagePullSecrets
		imagePullSecretRefs := make([]corev1.LocalObjectReference, 0, len(secrets))
		for _, secretName := range secrets {
			imagePullSecretRefs = append(imagePullSecretRefs, corev1.LocalObjectReference{Name: secretName})
		}
		inferenceParam.ImagePullSecrets = imagePullSecretRefs
	}
}

// applyInference applies inference spec.
func (c *WorkspaceReconciler) applyInference(ctx context.Context, wObj *kaitov1alpha1.Workspace) error {
var err error
Expand All @@ -452,7 +438,6 @@ func (c *WorkspaceReconciler) applyInference(ctx context.Context, wObj *kaitov1a

inferenceParam := model.GetInferenceParameters()

c.updateInferenceParamFromWorkspace(ctx, wObj, inferenceParam)
// TODO: we only do create if it does not exist for preset model. Need to document it.

var existingObj client.Object
Expand Down
1 change: 0 additions & 1 deletion pkg/inference/preset-inference-types.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,5 @@ var (
"gpu_ids": DefaultGPUIds,
}

DefaultImageAccessMode = "public"
DefaultImagePullSecrets = []corev1.LocalObjectReference{}
)
28 changes: 23 additions & 5 deletions pkg/inference/preset-inferences.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package inference
import (
"context"
"fmt"
"os"
"strconv"

kaitov1alpha1 "github.com/azure/kaito/api/v1alpha1"
Expand All @@ -18,9 +19,9 @@ import (
)

const (
ProbePath = "/healthz"
Port5000 = int32(5000)
InferenceFile = "inference-api.py"
ProbePath = "/healthz"
Port5000 = int32(5000)
InferenceFile = "inference-api.py"
DefaultVolumeMountPath = "/dev/shm"
)

Expand Down Expand Up @@ -91,6 +92,22 @@ func updateTorchParamsForDistributedInference(ctx context.Context, kubeClient cl
return nil
}

// GetImageInfo resolves the container image name and image pull secrets for
// a preset inference workload.
//
// For a private preset (ImageAccessMode == "private") the image comes from
// the workspace's PresetOptions.Image and each configured pull-secret name
// is wrapped in a corev1.LocalObjectReference. For a public preset the image
// is derived from the PRESET_REGISTRY_NAME environment variable and the
// preset name, with no pull secrets.
//
// NOTE(review): ctx is currently unused. If PRESET_REGISTRY_NAME is unset,
// the public image name starts with an empty registry ("/kaito-...") —
// confirm the env var is always set in deployment.
func GetImageInfo(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace, inferenceObj *model.PresetInferenceParam) (string, []corev1.LocalObjectReference) {
	presetOptions := workspaceObj.Inference.Preset.PresetOptions

	if inferenceObj.ImageAccessMode == "private" {
		// Pre-size to avoid repeated slice growth on append.
		imagePullSecretRefs := make([]corev1.LocalObjectReference, 0, len(presetOptions.ImagePullSecrets))
		for _, secretName := range presetOptions.ImagePullSecrets {
			imagePullSecretRefs = append(imagePullSecretRefs, corev1.LocalObjectReference{Name: secretName})
		}
		return string(presetOptions.Image), imagePullSecretRefs
	}

	registryName := os.Getenv("PRESET_REGISTRY_NAME")
	imageName := fmt.Sprintf("%s/kaito-%s:0.0.1", registryName, string(workspaceObj.Inference.Preset.Name))
	return imageName, []corev1.LocalObjectReference{}
}

func CreatePresetInference(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace,
inferenceObj *model.PresetInferenceParam, supportDistributedInference bool, kubeClient client.Client) (client.Object, error) {
if inferenceObj.TorchRunParams != nil && supportDistributedInference {
Expand All @@ -102,13 +119,14 @@ func CreatePresetInference(ctx context.Context, workspaceObj *kaitov1alpha1.Work

volume, volumeMount := configVolume(workspaceObj, inferenceObj)
commands, resourceReq := prepareInferenceParameters(ctx, inferenceObj)
image, imagePullSecrets := GetImageInfo(ctx, workspaceObj, inferenceObj)

var depObj client.Object
if supportDistributedInference {
depObj = resources.GenerateStatefulSetManifest(ctx, workspaceObj, inferenceObj.Image, inferenceObj.ImagePullSecrets, *workspaceObj.Resource.Count, commands,
depObj = resources.GenerateStatefulSetManifest(ctx, workspaceObj, image, imagePullSecrets, *workspaceObj.Resource.Count, commands,
containerPorts, livenessProbe, readinessProbe, resourceReq, tolerations, volume, volumeMount)
} else {
depObj = resources.GenerateDeploymentManifest(ctx, workspaceObj, inferenceObj.Image, inferenceObj.ImagePullSecrets, *workspaceObj.Resource.Count, commands,
depObj = resources.GenerateDeploymentManifest(ctx, workspaceObj, image, imagePullSecrets, *workspaceObj.Resource.Count, commands,
containerPorts, livenessProbe, readinessProbe, resourceReq, tolerations, volume, volumeMount)
}
err := resources.CreateResource(ctx, depObj, kubeClient)
Expand Down
24 changes: 10 additions & 14 deletions pkg/model/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@ package model

import (
"time"

corev1 "k8s.io/api/core/v1"
)

type Model interface {
Expand All @@ -15,23 +13,21 @@ type Model interface {

// PresetInferenceParam defines the preset inference parameters for a model.
type PresetInferenceParam struct {
ModelFamilyName string // The name of the model family.
Image string // Docker image used for running the inference.
ImagePullSecrets []corev1.LocalObjectReference // Secrets for pulling the image from a private registry.
ImageAccessMode string // Defines where the Image is Public or Private.
DiskStorageRequirement string // Disk storage requirements for the model.
GPUCountRequirement string // Number of GPUs required for the inference.
TotalGPUMemoryRequirement string // Total GPU memory required for the inference.
PerGPUMemoryRequirement string // GPU memory required per GPU.
TorchRunParams map[string]string // Parameters for configuring the torchrun command.
TorchRunRdzvParams map[string]string // Optional rendezvous parameters for distributed inference using torchrun (elastic).
ModelRunParams map[string]string // Parameters for running the model inference.
ModelFamilyName string // The name of the model family.
ImageAccessMode string // Defines where the Image is Public or Private.
DiskStorageRequirement string // Disk storage requirements for the model.
GPUCountRequirement string // Number of GPUs required for the inference.
TotalGPUMemoryRequirement string // Total GPU memory required for the inference.
PerGPUMemoryRequirement string // GPU memory required per GPU.
TorchRunParams map[string]string // Parameters for configuring the torchrun command.
TorchRunRdzvParams map[string]string // Optional rendezvous parameters for distributed inference using torchrun (elastic).
ModelRunParams map[string]string // Parameters for running the model inference.
// DeploymentTimeout defines the maximum duration for pulling the Preset image.
// This timeout accommodates the size of the image, ensuring pull completion
// even under slower network conditions or unforeseen delays.
DeploymentTimeout time.Duration
// BaseCommand is the initial command (e.g., 'torchrun', 'accelerate launch') used in the command line.
BaseCommand string
// WorldSize defines the number of processes required for distributed inference.
WorldSize int
WorldSize int
}
29 changes: 6 additions & 23 deletions presets/models/falcon/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@
package falcon

import (
"fmt"
"os"
"time"

kaitov1alpha1 "github.com/azure/kaito/api/v1alpha1"
"github.com/azure/kaito/pkg/inference"
"github.com/azure/kaito/pkg/model"
"github.com/azure/kaito/pkg/utils/plugin"
Expand All @@ -32,19 +31,11 @@ func init() {
}

var (
registryName = os.Getenv("PRESET_REGISTRY_NAME")

PresetFalcon7BModel = "falcon-7b"
PresetFalcon40BModel = "falcon-40b"
PresetFalcon7BInstructModel = PresetFalcon7BModel + "-instruct"
PresetFalcon40BInstructModel = PresetFalcon40BModel + "-instruct"

presetFalcon7bImage = registryName + fmt.Sprintf("/kaito-%s:0.0.1", PresetFalcon7BModel)
presetFalcon7bInstructImage = registryName + fmt.Sprintf("/kaito-%s:0.0.1", PresetFalcon7BInstructModel)

presetFalcon40bImage = registryName + fmt.Sprintf("/kaito-%s:0.0.1", PresetFalcon40BModel)
presetFalcon40bInstructImage = registryName + fmt.Sprintf("/kaito-%s:0.0.1", PresetFalcon40BInstructModel)

baseCommandPresetFalcon = "accelerate launch --use_deepspeed"
falconRunParams = map[string]string{}
)
Expand All @@ -56,9 +47,7 @@ type falcon7b struct{}
func (*falcon7b) GetInferenceParameters() *model.PresetInferenceParam {
return &model.PresetInferenceParam{
ModelFamilyName: "Falcon",
Image: presetFalcon7bImage,
ImagePullSecrets: inference.DefaultImagePullSecrets,
ImageAccessMode: inference.DefaultImageAccessMode,
ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePublic),
DiskStorageRequirement: "50Gi",
GPUCountRequirement: "1",
TotalGPUMemoryRequirement: "14Gi",
Expand All @@ -81,9 +70,7 @@ type falcon7bInst struct{}
func (*falcon7bInst) GetInferenceParameters() *model.PresetInferenceParam {
return &model.PresetInferenceParam{
ModelFamilyName: "Falcon",
Image: presetFalcon7bInstructImage,
ImagePullSecrets: inference.DefaultImagePullSecrets,
ImageAccessMode: inference.DefaultImageAccessMode,
ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePublic),
DiskStorageRequirement: "50Gi",
GPUCountRequirement: "1",
TotalGPUMemoryRequirement: "14Gi",
Expand All @@ -106,9 +93,7 @@ type falcon40b struct{}
func (*falcon40b) GetInferenceParameters() *model.PresetInferenceParam {
return &model.PresetInferenceParam{
ModelFamilyName: "Falcon",
Image: presetFalcon40bImage,
ImagePullSecrets: inference.DefaultImagePullSecrets,
ImageAccessMode: inference.DefaultImageAccessMode,
ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePublic),
DiskStorageRequirement: "400",
GPUCountRequirement: "2",
TotalGPUMemoryRequirement: "90Gi",
Expand All @@ -131,9 +116,7 @@ type falcon40bInst struct{}
func (*falcon40bInst) GetInferenceParameters() *model.PresetInferenceParam {
return &model.PresetInferenceParam{
ModelFamilyName: "Falcon",
Image: presetFalcon40bInstructImage,
ImagePullSecrets: inference.DefaultImagePullSecrets,
ImageAccessMode: inference.DefaultImageAccessMode,
ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePublic),
DiskStorageRequirement: "400",
GPUCountRequirement: "2",
TotalGPUMemoryRequirement: "90Gi",
Expand All @@ -143,8 +126,8 @@ func (*falcon40bInst) GetInferenceParameters() *model.PresetInferenceParam {
DeploymentTimeout: time.Duration(30) * time.Minute,
BaseCommand: baseCommandPresetFalcon,
}

}

// SupportDistributedInference reports whether the falcon-40b-instruct preset
// supports multi-node distributed inference; it does not, so the workload is
// deployed as a single Deployment rather than a StatefulSet.
func (*falcon40bInst) SupportDistributedInference() bool {
	return false
}
13 changes: 4 additions & 9 deletions presets/models/llama2/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package llama2
import (
"time"

kaitov1alpha1 "github.com/azure/kaito/api/v1alpha1"
"github.com/azure/kaito/pkg/inference"
"github.com/azure/kaito/pkg/model"
"github.com/azure/kaito/pkg/utils/plugin"
Expand Down Expand Up @@ -40,9 +41,7 @@ type llama2Text7b struct{}
func (*llama2Text7b) GetInferenceParameters() *model.PresetInferenceParam {
return &model.PresetInferenceParam{
ModelFamilyName: "LLaMa2",
Image: "",
ImagePullSecrets: inference.DefaultImagePullSecrets,
ImageAccessMode: inference.DefaultImageAccessMode,
ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePrivate),
DiskStorageRequirement: "34Gi",
GPUCountRequirement: "1",
TotalGPUMemoryRequirement: "14Gi",
Expand All @@ -67,9 +66,7 @@ type llama2Text13b struct{}
func (*llama2Text13b) GetInferenceParameters() *model.PresetInferenceParam {
return &model.PresetInferenceParam{
ModelFamilyName: "LLaMa2",
Image: "",
ImagePullSecrets: inference.DefaultImagePullSecrets,
ImageAccessMode: inference.DefaultImageAccessMode,
ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePrivate),
DiskStorageRequirement: "46Gi",
GPUCountRequirement: "2",
TotalGPUMemoryRequirement: "30Gi",
Expand All @@ -93,9 +90,7 @@ type llama2Text70b struct{}
func (*llama2Text70b) GetInferenceParameters() *model.PresetInferenceParam {
return &model.PresetInferenceParam{
ModelFamilyName: "LLaMa2",
Image: "",
ImagePullSecrets: inference.DefaultImagePullSecrets,
ImageAccessMode: inference.DefaultImageAccessMode,
ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePrivate),
DiskStorageRequirement: "158Gi",
GPUCountRequirement: "8",
TotalGPUMemoryRequirement: "152Gi",
Expand Down
Loading

0 comments on commit 7c88c4e

Please sign in to comment.