feat: Initialize Fine-Tuning Interface and Core Methods - Part 3 #308

Merged · 38 commits · Apr 1, 2024

Commits (38)
da63f93
feat: spec level validation
ishaansehgal99 Mar 18, 2024
a4f45e6
feat: Added validation checks for TuningSpec, DataSource, DataDestina…
ishaansehgal99 Mar 18, 2024
a9bbe7a
fix: prevent toggling
ishaansehgal99 Mar 18, 2024
d73ef65
fix: validation fixes
ishaansehgal99 Mar 19, 2024
3fa0e46
feat: Add UTs for workspace validation
ishaansehgal99 Mar 19, 2024
13153f3
Merge branch 'main' into Ishaan/fine-tuning
ishaansehgal99 Mar 19, 2024
5a6b64f
Merge branch 'main' of https://github.com/Azure/kaito into Ishaan/fin…
ishaansehgal99 Mar 19, 2024
392ff40
fix: Update CRD to use pointers
ishaansehgal99 Mar 20, 2024
b2e2b26
Merge branch 'main' of https://github.com/Azure/kaito into Ishaan/fin…
ishaansehgal99 Mar 20, 2024
6c347c9
fix: Add name flag
ishaansehgal99 Mar 20, 2024
1a14872
feat: Setup Interface for fine tuning
ishaansehgal99 Mar 20, 2024
b21b99f
feat: spec level validation
ishaansehgal99 Mar 18, 2024
75cacae
feat: Added validation checks for TuningSpec, DataSource, DataDestina…
ishaansehgal99 Mar 18, 2024
5f9e132
fix: prevent toggling
ishaansehgal99 Mar 18, 2024
9f4e820
fix: validation fixes
ishaansehgal99 Mar 19, 2024
a16351d
feat: Add UTs for workspace validation
ishaansehgal99 Mar 19, 2024
d2cd230
fix: Update CRD to use pointers
ishaansehgal99 Mar 20, 2024
082844d
fix: Add name flag
ishaansehgal99 Mar 20, 2024
56409a7
feat: Setup Interface for fine tuning
ishaansehgal99 Mar 20, 2024
ffa809a
Merge branch 'Ishaan/ft' of https://github.com/Azure/kaito into Ishaa…
ishaansehgal99 Mar 20, 2024
1a4b5ac
feat: Added validation checks for TuningSpec, DataSource, DataDestina…
ishaansehgal99 Mar 18, 2024
29210a7
fix: validation fixes
ishaansehgal99 Mar 19, 2024
4a93976
feat: Setup Interface for fine tuning
ishaansehgal99 Mar 20, 2024
36977c4
feat: spec level validation
ishaansehgal99 Mar 18, 2024
b738b47
feat: Added validation checks for TuningSpec, DataSource, DataDestina…
ishaansehgal99 Mar 18, 2024
1a73ca4
Merge branch 'Ishaan/ft' of https://github.com/Azure/kaito into Ishaa…
ishaansehgal99 Mar 20, 2024
33c6b24
fix: Add required training func for tests
ishaansehgal99 Mar 21, 2024
9b8ee56
fix: Add training func for phi
ishaansehgal99 Mar 21, 2024
bbfd6f9
fix: Add support training method
ishaansehgal99 Mar 21, 2024
931b037
fix: lint issue
ishaansehgal99 Mar 21, 2024
107c501
fix: workspace condition
ishaansehgal99 Mar 21, 2024
c80eada
fix: workspace condition
ishaansehgal99 Mar 21, 2024
3932bf4
Merge branch 'main' of https://github.com/Azure/kaito into Ishaan/ft
ishaansehgal99 Mar 26, 2024
565e2ae
Merge branch 'main' into Ishaan/ft
ishaansehgal99 Mar 29, 2024
860108f
fix: address comments
ishaansehgal99 Mar 29, 2024
76b10e4
Merge branch 'Ishaan/ft' of https://github.com/Azure/kaito into Ishaa…
ishaansehgal99 Mar 29, 2024
56609e9
Merge branch 'main' of https://github.com/Azure/kaito into Ishaan/ft
ishaansehgal99 Mar 29, 2024
6a47da0
fix: remove code
ishaansehgal99 Mar 29, 2024
29 changes: 25 additions & 4 deletions api/v1alpha1/workspace_validation_test.go
@@ -21,8 +21,15 @@ var perGPUMemoryRequirement string

type testModel struct{}

func (*testModel) GetInferenceParameters() *model.PresetInferenceParam {
return &model.PresetInferenceParam{
func (*testModel) GetInferenceParameters() *model.PresetParam {
return &model.PresetParam{
GPUCountRequirement: gpuCountRequirement,
TotalGPUMemoryRequirement: totalGPUMemoryRequirement,
PerGPUMemoryRequirement: perGPUMemoryRequirement,
}
}
func (*testModel) GetTuningParameters() *model.PresetParam {
return &model.PresetParam{
GPUCountRequirement: gpuCountRequirement,
TotalGPUMemoryRequirement: totalGPUMemoryRequirement,
PerGPUMemoryRequirement: perGPUMemoryRequirement,
@@ -31,11 +38,22 @@ func (*testModel) GetInferenceParameters() *model.PresetInferenceParam {
func (*testModel) SupportDistributedInference() bool {
return false
}
func (*testModel) SupportTuning() bool {
return true
}

type testModelPrivate struct{}

func (*testModelPrivate) GetInferenceParameters() *model.PresetInferenceParam {
return &model.PresetInferenceParam{
func (*testModelPrivate) GetInferenceParameters() *model.PresetParam {
return &model.PresetParam{
ImageAccessMode: "private",
GPUCountRequirement: gpuCountRequirement,
TotalGPUMemoryRequirement: totalGPUMemoryRequirement,
PerGPUMemoryRequirement: perGPUMemoryRequirement,
}
}
func (*testModelPrivate) GetTuningParameters() *model.PresetParam {
return &model.PresetParam{
ImageAccessMode: "private",
GPUCountRequirement: gpuCountRequirement,
TotalGPUMemoryRequirement: totalGPUMemoryRequirement,
@@ -45,6 +63,9 @@ func (*testModelPrivate) GetInferenceParameters() *model.PresetInferenceParam {
func (*testModelPrivate) SupportDistributedInference() bool {
return false
}
func (*testModelPrivate) SupportTuning() bool {
return true
}

func RegisterValidationTestModels() {
var test testModel
20 changes: 20 additions & 0 deletions examples/fine-tuning/kaito_workspace_tuning_falcon_7b.yaml
@@ -0,0 +1,20 @@
apiVersion: kaito.sh/v1alpha1
kind: Workspace
metadata:
name: workspace-tuning-falcon-7b
spec:
resource:
instanceType: "Standard_NC12s_v3"
labelSelector:
matchLabels:
app: tuning-falcon-7b
tuning:
preset:
name: falcon-7b
method: lora
config: tuning-config-map # ConfigMap containing tuning arguments
input:
name: tuning-data
hostPath: /path/to/your/input/data # dataset on node
output:
hostPath: /path/to/store/output # Tuning Output
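Assuming the CRD and controller from this PR are installed, the manifest above can be applied like any other workspace, e.g. kubectl apply -f examples/fine-tuning/kaito_workspace_tuning_falcon_7b.yaml. Note that the input hostPath must already hold the dataset on the node selected by the labelSelector.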
65 changes: 54 additions & 11 deletions pkg/controllers/workspace_controller.go
@@ -9,8 +9,8 @@ import (
"strings"
"time"

appsv1 "k8s.io/api/apps/v1"
"k8s.io/utils/clock"
"github.com/azure/kaito/pkg/tuning"
batchv1 "k8s.io/api/batch/v1"

"github.com/aws/karpenter-core/pkg/apis/v1alpha5"
kaitov1alpha1 "github.com/azure/kaito/api/v1alpha1"
@@ -21,13 +21,15 @@ import (
"github.com/azure/kaito/pkg/utils/plugin"
"github.com/go-logr/logr"
"github.com/samber/lo"
appsv1 "k8s.io/api/apps/v1"
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/errors"
apierrors "k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/client-go/tools/record"
"k8s.io/klog/v2"
"k8s.io/utils/clock"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/controller"
@@ -109,16 +111,22 @@ func (c *WorkspaceReconciler) addOrUpdateWorkspace(ctx context.Context, wObj *ka
return reconcile.Result{}, err
}

if err = c.applyInference(ctx, wObj); err != nil {
if updateErr := c.updateStatusConditionIfNotMatch(ctx, wObj, kaitov1alpha1.WorkspaceConditionTypeReady, metav1.ConditionFalse,
"workspaceFailed", err.Error()); updateErr != nil {
klog.ErrorS(updateErr, "failed to update workspace status", "workspace", klog.KObj(wObj))
return reconcile.Result{}, updateErr
if wObj.Tuning != nil {
if err = c.applyTuning(ctx, wObj); err != nil {
return reconcile.Result{}, err
}
}
if wObj.Inference != nil {
if err = c.applyInference(ctx, wObj); err != nil {
if updateErr := c.updateStatusConditionIfNotMatch(ctx, wObj, kaitov1alpha1.WorkspaceConditionTypeReady, metav1.ConditionFalse,
"workspaceFailed", err.Error()); updateErr != nil {
klog.ErrorS(updateErr, "failed to update workspace status", "workspace", klog.KObj(wObj))
return reconcile.Result{}, updateErr
}
return reconcile.Result{}, err
}
return reconcile.Result{}, err
}

// TODO apply TrainingSpec
if err = c.updateStatusConditionIfNotMatch(ctx, wObj, kaitov1alpha1.WorkspaceConditionTypeReady, metav1.ConditionTrue,
"workspaceReady", "workspace is ready"); err != nil {
klog.ErrorS(err, "failed to update workspace status", "workspace", klog.KObj(wObj))
@@ -423,6 +431,41 @@ func (c *WorkspaceReconciler) ensureService(ctx context.Context, wObj *kaitov1al
return nil
}

func (c *WorkspaceReconciler) applyTuning(ctx context.Context, wObj *kaitov1alpha1.Workspace) error {
var err error
func() {
if wObj.Tuning.Preset != nil {
presetName := string(wObj.Tuning.Preset.Name)
model := plugin.KaitoModelRegister.MustGet(presetName)

tuningParam := model.GetTuningParameters()
existingObj := &batchv1.Job{}
if err = resources.GetResource(ctx, wObj.Name, wObj.Namespace, c.Client, existingObj); err == nil {
klog.InfoS("A tuning workload already exists for workspace", "workspace", klog.KObj(wObj))
if err = resources.CheckResourceStatus(existingObj, c.Client, tuningParam.ReadinessTimeout); err != nil {
return
}
} else if apierrors.IsNotFound(err) {
var workloadObj client.Object
// Need to create a new workload
workloadObj, err = tuning.CreatePresetTuning(ctx, wObj, tuningParam, c.Client)
if err != nil {
return
}
if err = resources.CheckResourceStatus(workloadObj, c.Client, tuningParam.ReadinessTimeout); err != nil {
return
}
}
}
}()

if err != nil {
return err
}

return nil
}

// applyInference applies inference spec.
func (c *WorkspaceReconciler) applyInference(ctx context.Context, wObj *kaitov1alpha1.Workspace) error {
var err error
@@ -455,7 +498,7 @@ func (c *WorkspaceReconciler) applyInference(ctx context.Context, wObj *kaitov1a

if err = resources.GetResource(ctx, wObj.Name, wObj.Namespace, c.Client, existingObj); err == nil {
klog.InfoS("An inference workload already exists for workspace", "workspace", klog.KObj(wObj))
if err = resources.CheckResourceStatus(existingObj, c.Client, inferenceParam.DeploymentTimeout); err != nil {
if err = resources.CheckResourceStatus(existingObj, c.Client, inferenceParam.ReadinessTimeout); err != nil {
return
}
} else if apierrors.IsNotFound(err) {
Expand All @@ -465,7 +508,7 @@ func (c *WorkspaceReconciler) applyInference(ctx context.Context, wObj *kaitov1a
if err != nil {
return
}
if err = resources.CheckResourceStatus(workloadObj, c.Client, inferenceParam.DeploymentTimeout); err != nil {
if err = resources.CheckResourceStatus(workloadObj, c.Client, inferenceParam.ReadinessTimeout); err != nil {
return
}
}
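A note on the control flow above: applyTuning wraps its body in an immediately-invoked closure so that every early return still funnels through the single err check at the end, and the workload itself follows a get-or-create pattern against a batchv1.Job named after the workspace. The same shape, reduced to a minimal standalone sketch — getOrCreateTuningJob is a hypothetical stand-in for the resources.GetResource / tuning.CreatePresetTuning pair, not part of this PR:

package tuning

import (
	"context"

	batchv1 "k8s.io/api/batch/v1"
	apierrors "k8s.io/apimachinery/pkg/api/errors"
	"sigs.k8s.io/controller-runtime/pkg/client"
)

// getOrCreateTuningJob sketches the get-or-create idiom used by applyTuning.
func getOrCreateTuningJob(ctx context.Context, c client.Client, key client.ObjectKey,
	build func() *batchv1.Job) (*batchv1.Job, error) {
	job := &batchv1.Job{}
	err := c.Get(ctx, key, job)
	if err == nil {
		return job, nil // a tuning job already exists for this workspace; reuse it
	}
	if !apierrors.IsNotFound(err) {
		return nil, err // unexpected API error; surface it
	}
	job = build() // not found: construct the desired Job spec
	if err := c.Create(ctx, job); err != nil {
		return nil, err
	}
	return job, nil
}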
10 changes: 5 additions & 5 deletions pkg/inference/preset-inferences.go
@@ -67,7 +67,7 @@
}
)

func updateTorchParamsForDistributedInference(ctx context.Context, kubeClient client.Client, wObj *kaitov1alpha1.Workspace, inferenceObj *model.PresetInferenceParam) error {
func updateTorchParamsForDistributedInference(ctx context.Context, kubeClient client.Client, wObj *kaitov1alpha1.Workspace, inferenceObj *model.PresetParam) error {
existingService := &corev1.Service{}
err := resources.GetResource(ctx, wObj.Name, wObj.Namespace, kubeClient, existingService)
if err != nil {
@@ -92,7 +92,7 @@ func updateTorchParamsForDistributedInference(ctx context.Context, kubeClient cl
return nil
}

func GetImageInfo(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace, inferenceObj *model.PresetInferenceParam) (string, []corev1.LocalObjectReference) {
func GetImageInfo(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace, inferenceObj *model.PresetParam) (string, []corev1.LocalObjectReference) {
imageName := string(workspaceObj.Inference.Preset.Name)
imageTag := inferenceObj.Tag
imagePullSecretRefs := []corev1.LocalObjectReference{}
@@ -110,7 +110,7 @@ func GetImageInfo(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace, in
}

func CreatePresetInference(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace,
inferenceObj *model.PresetInferenceParam, supportDistributedInference bool, kubeClient client.Client) (client.Object, error) {
inferenceObj *model.PresetParam, supportDistributedInference bool, kubeClient client.Client) (client.Object, error) {
if inferenceObj.TorchRunParams != nil && supportDistributedInference {
if err := updateTorchParamsForDistributedInference(ctx, kubeClient, workspaceObj, inferenceObj); err != nil {
klog.ErrorS(err, "failed to update torch params", "workspace", workspaceObj)
@@ -141,7 +141,7 @@ func CreatePresetInference(ctx context.Context, workspaceObj *kaitov1alpha1.Work
// torchrun <TORCH_PARAMS> <OPTIONAL_RDZV_PARAMS> baseCommand <MODEL_PARAMS>
// and sets the GPU resources required for inference.
// Returns the command and resource configuration.
func prepareInferenceParameters(ctx context.Context, inferenceObj *model.PresetInferenceParam) ([]string, corev1.ResourceRequirements) {
func prepareInferenceParameters(ctx context.Context, inferenceObj *model.PresetParam) ([]string, corev1.ResourceRequirements) {
torchCommand := buildCommandStr(inferenceObj.BaseCommand, inferenceObj.TorchRunParams)
torchCommand = buildCommandStr(torchCommand, inferenceObj.TorchRunRdzvParams)
modelCommand := buildCommandStr(InferenceFile, inferenceObj.ModelRunParams)
@@ -159,7 +159,7 @@ func prepareInferenceParameters(ctx context.Context, inferenceObj *model.PresetI
return commands, resourceRequirements
}

func configVolume(wObj *kaitov1alpha1.Workspace, inferenceObj *model.PresetInferenceParam) ([]corev1.Volume, []corev1.VolumeMount) {
func configVolume(wObj *kaitov1alpha1.Workspace, inferenceObj *model.PresetParam) ([]corev1.Volume, []corev1.VolumeMount) {
volume := []corev1.Volume{}
volumeMount := []corev1.VolumeMount{}

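The comment block above prepareInferenceParameters pins down the command shape: torchrun <TORCH_PARAMS> <OPTIONAL_RDZV_PARAMS> baseCommand <MODEL_PARAMS>. A plausible shape for the buildCommandStr helper it chains — an assumption for illustration, not the repository's actual implementation — assuming each map entry is appended as a --key value flag:

package inference

import (
	"fmt"
	"sort"
	"strings"
)

// buildCommandStrSketch illustrates one plausible implementation of
// buildCommandStr (an assumption): append map entries to the base string
// as "--key value" flags, sorted for deterministic output.
func buildCommandStrSketch(base string, params map[string]string) string {
	keys := make([]string, 0, len(params))
	for k := range params {
		keys = append(keys, k)
	}
	sort.Strings(keys)
	var sb strings.Builder
	sb.WriteString(base)
	for _, k := range keys {
		sb.WriteString(fmt.Sprintf(" --%s", k))
		if v := params[k]; v != "" {
			sb.WriteString(" " + v) // flags without values stay bare
		}
	}
	return sb.String()
}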
2 changes: 1 addition & 1 deletion pkg/inference/preset-inferences_test.go
@@ -62,7 +62,7 @@ func TestCreatePresetInference(t *testing.T) {

useHeadlessSvc := false

var inferenceObj *model.PresetInferenceParam
var inferenceObj *model.PresetParam
model := plugin.KaitoModelRegister.MustGet(tc.modelName)
inferenceObj = model.GetInferenceParameters()

24 changes: 13 additions & 11 deletions pkg/model/interface.go
@@ -7,27 +7,29 @@ import (
)

type Model interface {
GetInferenceParameters() *PresetInferenceParam
GetInferenceParameters() *PresetParam
GetTuningParameters() *PresetParam
SupportDistributedInference() bool // If true, the model workload will be a StatefulSet, using the torch elastic runtime framework.
SupportTuning() bool
}

// PresetInferenceParam defines the preset inference parameters for a model.
type PresetInferenceParam struct {
// PresetParam defines the preset parameters for a model, covering both inference and tuning.
type PresetParam struct {
ModelFamilyName string // The name of the model family.
ImageAccessMode string // Defines whether the image is Public or Private.
DiskStorageRequirement string // Disk storage requirements for the model.
GPUCountRequirement string // Number of GPUs required for the inference.
TotalGPUMemoryRequirement string // Total GPU memory required for the inference.
GPUCountRequirement string // Number of GPUs required for the Preset.
TotalGPUMemoryRequirement string // Total GPU memory required for the Preset.
PerGPUMemoryRequirement string // GPU memory required per GPU.
TorchRunParams map[string]string // Parameters for configuring the torchrun command.
TorchRunRdzvParams map[string]string // Optional rendezvous parameters for distributed inference using torchrun (elastic).
ModelRunParams map[string]string // Parameters for running the model inference.
// DeploymentTimeout defines the maximum duration for pulling the Preset image.
TorchRunRdzvParams map[string]string // Optional rendezvous parameters for distributed training/inference using torchrun (elastic).
// BaseCommand is the initial command (e.g., 'torchrun', 'accelerate launch') used in the command line.
BaseCommand string
ModelRunParams map[string]string // Parameters for running the model training/inference.
// ReadinessTimeout defines the maximum duration for creating the workload.
// This timeout accommodates the size of the image, ensuring pull completion
// even under slower network conditions or unforeseen delays.
DeploymentTimeout time.Duration
// BaseCommand is the initial command (e.g., 'torchrun', 'accelerate launch') used in the command line.
BaseCommand string
ReadinessTimeout time.Duration
// WorldSize defines the number of processes required for distributed inference.
WorldSize int
Tag string // The model image tag
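With SupportTuning and GetTuningParameters added to the interface, callers can gate tuning requests per preset. A sketch of how a validation path might consume the widened interface — validateTuningSupport is hypothetical, while plugin.KaitoModelRegister.MustGet is the registry lookup the controller above already uses:

package validation // illustrative package name

import (
	"fmt"

	"github.com/azure/kaito/pkg/utils/plugin"
)

// validateTuningSupport sketches a guard built on the new interface methods.
func validateTuningSupport(presetName string) error {
	m := plugin.KaitoModelRegister.MustGet(presetName) // panics if the preset is unregistered
	if !m.SupportTuning() {
		return fmt.Errorf("preset %q does not support fine-tuning", presetName)
	}
	if m.GetTuningParameters() == nil {
		return fmt.Errorf("preset %q has no tuning parameters", presetName)
	}
	return nil
}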
21 changes: 21 additions & 0 deletions pkg/tuning/preset-tuning-types.go
@@ -0,0 +1,21 @@
package tuning

import corev1 "k8s.io/api/core/v1"

const (
DefaultNumProcesses = "1"
DefaultNumMachines = "1"
DefaultMachineRank = "0"
DefaultGPUIds = "all"
)

var (
DefaultAccelerateParams = map[string]string{
"num_processes": DefaultNumProcesses,
"num_machines": DefaultNumMachines,
"machine_rank": DefaultMachineRank,
"gpu_ids": DefaultGPUIds,
}

DefaultImagePullSecrets = []corev1.LocalObjectReference{}
)
14 changes: 14 additions & 0 deletions pkg/tuning/preset-tuning.go
@@ -0,0 +1,14 @@
package tuning

import (
"context"
kaitov1alpha1 "github.com/azure/kaito/api/v1alpha1"
"github.com/azure/kaito/pkg/model"
"sigs.k8s.io/controller-runtime/pkg/client"
)

func CreatePresetTuning(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace,
tuningObj *model.PresetParam, kubeClient client.Client) (client.Object, error) {
// TODO
return nil, nil
}
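CreatePresetTuning is left as a stub here, but applyTuning in the controller already expects it to produce a batchv1.Job. One plausible shape for a future implementation — everything below the metadata is an assumption for illustration, including the placeholder image registry:

package tuning

import (
	"context"
	"strings"

	kaitov1alpha1 "github.com/azure/kaito/api/v1alpha1"
	"github.com/azure/kaito/pkg/model"
	batchv1 "k8s.io/api/batch/v1"
	corev1 "k8s.io/api/core/v1"
	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
	"sigs.k8s.io/controller-runtime/pkg/client"
)

// createPresetTuningSketch is a hypothetical implementation: a Job named
// after the workspace that runs the preset image's tuning entrypoint.
func createPresetTuningSketch(ctx context.Context, wObj *kaitov1alpha1.Workspace,
	tuningObj *model.PresetParam, kubeClient client.Client) (client.Object, error) {
	// Placeholder registry; the real image location is not defined in this PR.
	image := "registry.example.com/" + string(wObj.Tuning.Preset.Name) + ":" + tuningObj.Tag
	job := &batchv1.Job{
		ObjectMeta: metav1.ObjectMeta{Name: wObj.Name, Namespace: wObj.Namespace},
		Spec: batchv1.JobSpec{
			Template: corev1.PodTemplateSpec{
				Spec: corev1.PodSpec{
					RestartPolicy: corev1.RestartPolicyNever, // tuning is run-to-completion
					Containers: []corev1.Container{{
						Name:    wObj.Name,
						Image:   image,
						Command: strings.Fields(tuningObj.BaseCommand), // e.g. "accelerate launch"
					}},
				},
			},
		},
	}
	if err := kubeClient.Create(ctx, job); err != nil {
		return nil, err
	}
	return job, nil
}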
30 changes: 24 additions & 6 deletions pkg/utils/testModel.go
@@ -12,27 +12,45 @@ import (

type testModel struct{}

func (*testModel) GetInferenceParameters() *model.PresetInferenceParam {
return &model.PresetInferenceParam{
func (*testModel) GetInferenceParameters() *model.PresetParam {
return &model.PresetParam{
GPUCountRequirement: "1",
DeploymentTimeout: time.Duration(30) * time.Minute,
ReadinessTimeout: time.Duration(30) * time.Minute,
}
}
func (*testModel) GetTuningParameters() *model.PresetParam {
return &model.PresetParam{
GPUCountRequirement: "1",
ReadinessTimeout: time.Duration(30) * time.Minute,
}
}
func (*testModel) SupportDistributedInference() bool {
return false
}
func (*testModel) SupportTuning() bool {
return true
}

type testDistributedModel struct{}

func (*testDistributedModel) GetInferenceParameters() *model.PresetInferenceParam {
return &model.PresetInferenceParam{
func (*testDistributedModel) GetInferenceParameters() *model.PresetParam {
return &model.PresetParam{
GPUCountRequirement: "1",
DeploymentTimeout: time.Duration(30) * time.Minute,
ReadinessTimeout: time.Duration(30) * time.Minute,
}
}
func (*testDistributedModel) GetTuningParameters() *model.PresetParam {
return &model.PresetParam{
GPUCountRequirement: "1",
ReadinessTimeout: time.Duration(30) * time.Minute,
}
}
func (*testDistributedModel) SupportDistributedInference() bool {
return true
}
func (*testDistributedModel) SupportTuning() bool {
return true
}

func RegisterTestModel() {
var test testModel
8 changes: 4 additions & 4 deletions presets/models/falcon/README.md
@@ -1,10 +1,10 @@
## Supported Models
|Model name| Model source | Sample workspace|Kubernetes Workload|Distributed inference|
|----|:----:|:----:| :----: |:----: |
|falcon-7b-instruct |[tiiuae](https://huggingface.co/tiiuae/falcon-7b-instruct)|[link](../../../examples/kaito_workspace_falcon_7b-instruct.yaml)|Deployment| false|
|falcon-7b |[tiiuae](https://huggingface.co/tiiuae/falcon-7b) |[link](../../../examples/kaito_workspace_falcon_7b.yaml)|Deployment| false|
|falcon-40b-instruct|[tiiuae](https://huggingface.co/tiiuae/falcon-40b-instruct) |[link](../../../examples/kaito_workspace_falcon_40b-instruct.yaml)|Deployment| false|
|falcon-40b |[tiiuae](https://huggingface.co/tiiuae/falcon-40b)|[link](../../../examples/kaito_workspace_falcon_40b.yaml)|Deployment| false|
|falcon-7b-instruct |[tiiuae](https://huggingface.co/tiiuae/falcon-7b-instruct)|[link](../../../examples/inference/kaito_workspace_falcon_7b-instruct.yaml)|Deployment| false|
|falcon-7b |[tiiuae](https://huggingface.co/tiiuae/falcon-7b) |[link](../../../examples/inference/kaito_workspace_falcon_7b.yaml)|Deployment| false|
|falcon-40b-instruct|[tiiuae](https://huggingface.co/tiiuae/falcon-40b-instruct) |[link](../../../examples/inference/kaito_workspace_falcon_40b-instruct.yaml)|Deployment| false|
|falcon-40b |[tiiuae](https://huggingface.co/tiiuae/falcon-40b)|[link](../../../examples/inference/kaito_workspace_falcon_40b.yaml)|Deployment| false|

## Image Source
- **Public**: Kaito maintainers manage the lifecycle of the inference service images that contain model weights. The images are available in Microsoft Container Registry (MCR).