Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Fine Tune (Part 10) - Updating fine tuning py #371

Merged
merged 44 commits into from
May 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
08ef1a4
fine tuning updates
ishaansehgal99 May 3, 2024
009bbe0
remove
ishaansehgal99 May 3, 2024
bd9d2cc
add to docker
ishaansehgal99 May 3, 2024
f599790
combine tuning and inference image
ishaansehgal99 May 3, 2024
42a8ee1
update test
ishaansehgal99 May 3, 2024
4ad8e14
begin adding trainer types
ishaansehgal99 May 3, 2024
52465b0
renaming and file moving
ishaansehgal99 May 3, 2024
76f1cdf
separate requirements
ishaansehgal99 May 3, 2024
1760b55
add parser
ishaansehgal99 May 3, 2024
8df9694
typo
ishaansehgal99 May 4, 2024
2a3e5e5
typo
ishaansehgal99 May 4, 2024
1534183
typo
ishaansehgal99 May 4, 2024
4a3acdc
mime type
ishaansehgal99 May 4, 2024
f885c31
fix
ishaansehgal99 May 4, 2024
c649b89
add python magic
ishaansehgal99 May 6, 2024
e497413
simplify using filetype
ishaansehgal99 May 6, 2024
2fb7798
dataset
ishaansehgal99 May 7, 2024
ef7ab5c
Merge branch 'main' into Ishaan/fine-tuning-py
ishaansehgal99 May 7, 2024
3b66555
Remove all CLI parser logic (unused) and add support for SFT_Trainer …
ishaansehgal99 May 8, 2024
7c7b7b1
Merge branch 'Ishaan/fine-tuning-py' of https://github.com/Azure/kait…
ishaansehgal99 May 8, 2024
519f2a3
Separate out into dataset class
ishaansehgal99 May 8, 2024
5aac066
Add dataset
ishaansehgal99 May 8, 2024
784fe11
datasets support
ishaansehgal99 May 9, 2024
659decb
Add support for datasets
ishaansehgal99 May 9, 2024
c3ed446
header
ishaansehgal99 May 9, 2024
156e62b
feat: format and preprocess
ishaansehgal99 May 10, 2024
49fcb2c
fix some edge cases
ishaansehgal99 May 10, 2024
9fb46aa
chore: Use image enum
ishaansehgal99 May 21, 2024
5ae940c
minor tweaks
ishaansehgal99 May 21, 2024
caa39eb
separate function
ishaansehgal99 May 21, 2024
fa0d4bd
add helpers
ishaansehgal99 May 22, 2024
cde8657
Remove manifests.go from PR
ishaansehgal99 May 22, 2024
7f88cfa
restore
ishaansehgal99 May 22, 2024
68c2f55
Merge branch 'main' of https://github.com/Azure/kaito into Ishaan/fin…
ishaansehgal99 May 22, 2024
dc01977
Merge branch 'main' into Ishaan/fine-tuning-py
ishaansehgal99 May 23, 2024
5eeff54
file rename
ishaansehgal99 May 23, 2024
2f48ef0
Merge branch 'Ishaan/fine-tuning-py' of https://github.com/Azure/kait…
ishaansehgal99 May 23, 2024
c7fe877
Dockerfile
ishaansehgal99 May 23, 2024
0c222c0
Handle custom output dir
ishaansehgal99 May 27, 2024
2c76c16
variable renamed
ishaansehgal99 May 27, 2024
baf5918
update defaults
ishaansehgal99 May 27, 2024
816d4b5
log msg update
ishaansehgal99 May 28, 2024
1d5e8b4
comments and nits
ishaansehgal99 May 28, 2024
1ba3cc4
Merge branch 'main' into Ishaan/fine-tuning-py
ishaansehgal99 May 28, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 53 additions & 5 deletions api/v1alpha1/params_validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,17 @@ package v1alpha1
import (
"context"
"fmt"
"reflect"

"github.com/azure/kaito/pkg/k8sclient"
"github.com/azure/kaito/pkg/utils"
"gopkg.in/yaml.v2"
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/runtime"
"knative.dev/pkg/apis"
"path/filepath"
"reflect"
"sigs.k8s.io/controller-runtime/pkg/client"
"strings"
)

type Config struct {
Expand Down Expand Up @@ -94,15 +95,59 @@ func (t *TrainingConfig) UnmarshalYAML(unmarshal func(interface{}) error) error
return nil
}

func validateMethodViaConfigMap(cm *corev1.ConfigMap, methodLowerCase string) *apis.FieldError {
// UnmarshalTrainingConfig extracts the 'training_config.yaml' entry from the
// given ConfigMap and parses it into a Config. It returns a FieldError when
// the entry is missing or when the YAML cannot be unmarshalled; on error the
// returned *Config is nil.
func UnmarshalTrainingConfig(cm *corev1.ConfigMap) (*Config, *apis.FieldError) {
	trainingConfigYAML, ok := cm.Data["training_config.yaml"]
	if !ok {
		return nil, apis.ErrGeneric(fmt.Sprintf("ConfigMap '%s' does not contain 'training_config.yaml' in namespace '%s'", cm.Name, cm.Namespace), "config")
	}

	var config Config
	if err := yaml.Unmarshal([]byte(trainingConfigYAML), &config); err != nil {
		return nil, apis.ErrGeneric(fmt.Sprintf("Failed to parse 'training_config.yaml' in ConfigMap '%s' in namespace '%s': %v", cm.Name, cm.Namespace, err), "config")
	}
	return &config, nil
}

// validateTrainingArgsViaConfigMap validates the optional TrainingArguments
// section of the tuning ConfigMap. When a user-specified 'output_dir' is
// present it must be a string, and it must resolve to a path strictly below
// the /mnt base directory (guarding against path traversal such as "../..").
func validateTrainingArgsViaConfigMap(cm *corev1.ConfigMap) *apis.FieldError {
	config, err := UnmarshalTrainingConfig(cm)
	if err != nil {
		return err
	}

	trainingArgs := config.TrainingConfig.TrainingArguments
	if trainingArgs == nil {
		return nil
	}

	trainingArgsRaw, trainingArgsExists := trainingArgs["TrainingArguments"]
	if !trainingArgsExists {
		return nil
	}

	// If specified, ensure output dir is of type string.
	// Note: searchErr is a plain error, distinct from the *apis.FieldError
	// named err above — keep the names separate to avoid shadowing.
	outputDirValue, found, searchErr := utils.SearchRawExtension(trainingArgsRaw, "output_dir")
	if searchErr != nil {
		return apis.ErrGeneric(fmt.Sprintf("Failed to parse 'output_dir' in ConfigMap '%s' in namespace '%s': %v", cm.Name, cm.Namespace, searchErr), "output_dir")
	}
	if found {
		userSpecifiedDir, ok := outputDirValue.(string)
		if !ok {
			return apis.ErrInvalidValue(fmt.Sprintf("output_dir is not a string in ConfigMap '%s' in namespace '%s'", cm.Name, cm.Namespace), "output_dir")
		}

		// Ensure the user-specified directory is under baseDir.
		// filepath.Clean collapses any ".." segments, so a traversal attempt
		// fails the prefix check. The prefix must include the separator:
		// a bare HasPrefix(cleanPath, "/mnt") would wrongly accept a path
		// like "/mntfoo" produced by Clean("/mnt/../mntfoo").
		baseDir := "/mnt"
		cleanPath := filepath.Clean(filepath.Join(baseDir, userSpecifiedDir))
		if cleanPath == baseDir || !strings.HasPrefix(cleanPath, baseDir+string(filepath.Separator)) {
			return apis.ErrInvalidValue(fmt.Sprintf("Invalid output_dir specified: '%s', must be a directory", userSpecifiedDir), "output_dir")
		}
	}

	// TODO: Here we perform the tuning GPU Memory Checks!
	// (Removed leftover debug fmt.Println of the raw training arguments —
	// webhook validation must not write to stdout.)
	return nil
}

func validateMethodViaConfigMap(cm *corev1.ConfigMap, methodLowerCase string) *apis.FieldError {
config, err := UnmarshalTrainingConfig(cm)
if err != nil {
return err
}

// Validate QuantizationConfig if it exists
Expand Down Expand Up @@ -225,6 +270,9 @@ func (r *TuningSpec) validateConfigMap(ctx context.Context, namespace string, me
if err := validateMethodViaConfigMap(&cm, methodLowerCase); err != nil {
errs = errs.Also(err)
}
if err := validateTrainingArgsViaConfigMap(&cm); err != nil {
errs = errs.Also(err)
}
}
return errs
}
6 changes: 3 additions & 3 deletions api/v1alpha1/workspace_validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -336,12 +336,12 @@ func (i *InferenceSpec) validateCreate() (errs *apis.FieldError) {
errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("Unsupported inference preset name %s", presetName), "presetName"))
}
// Validate private preset has private image specified
if plugin.KaitoModelRegister.MustGet(string(i.Preset.Name)).GetInferenceParameters().ImageAccessMode == "private" &&
i.Preset.PresetMeta.AccessMode != "private" {
if plugin.KaitoModelRegister.MustGet(string(i.Preset.Name)).GetInferenceParameters().ImageAccessMode == string(ModelImageAccessModePrivate) &&
i.Preset.PresetMeta.AccessMode != ModelImageAccessModePrivate {
errs = errs.Also(apis.ErrGeneric("This preset only supports private AccessMode, AccessMode must be private to continue"))
}
// Additional validations for Preset
if i.Preset.PresetMeta.AccessMode == "private" && i.Preset.PresetOptions.Image == "" {
if i.Preset.PresetMeta.AccessMode == ModelImageAccessModePrivate && i.Preset.PresetOptions.Image == "" {
errs = errs.Also(apis.ErrGeneric("When AccessMode is private, an image must be provided in PresetOptions"))
}
// Note: we don't enforce private access mode to have image secrets, in case anonymous pulling is enabled
Expand Down
12 changes: 6 additions & 6 deletions api/v1alpha1/workspace_validation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,15 +55,15 @@ type testModelPrivate struct{}

func (*testModelPrivate) GetInferenceParameters() *model.PresetParam {
return &model.PresetParam{
ImageAccessMode: "private",
ImageAccessMode: string(ModelImageAccessModePrivate),
GPUCountRequirement: gpuCountRequirement,
TotalGPUMemoryRequirement: totalGPUMemoryRequirement,
PerGPUMemoryRequirement: perGPUMemoryRequirement,
}
}
func (*testModelPrivate) GetTuningParameters() *model.PresetParam {
return &model.PresetParam{
ImageAccessMode: "private",
ImageAccessMode: string(ModelImageAccessModePrivate),
GPUCountRequirement: gpuCountRequirement,
TotalGPUMemoryRequirement: totalGPUMemoryRequirement,
PerGPUMemoryRequirement: perGPUMemoryRequirement,
Expand Down Expand Up @@ -121,7 +121,7 @@ func defaultConfigMapManifest() *v1.ConfigMap {
bias: "none"

TrainingArguments:
output_dir: "."
output_dir: "output"
num_train_epochs: 4
auto_find_batch_size: true
ddp_find_unused_parameters: false
Expand Down Expand Up @@ -168,7 +168,7 @@ func qloraConfigMapManifest() *v1.ConfigMap {
bias: "none"

TrainingArguments:
output_dir: "."
output_dir: "output"
num_train_epochs: 4
auto_find_batch_size: true
ddp_find_unused_parameters: false
Expand Down Expand Up @@ -461,7 +461,7 @@ func TestInferenceSpecValidateCreate(t *testing.T) {
Preset: &PresetSpec{
PresetMeta: PresetMeta{
Name: ModelName("test-validation"),
AccessMode: "private",
AccessMode: ModelImageAccessModePrivate,
},
PresetOptions: PresetOptions{},
},
Expand All @@ -488,7 +488,7 @@ func TestInferenceSpecValidateCreate(t *testing.T) {
Preset: &PresetSpec{
PresetMeta: PresetMeta{
Name: ModelName("test-validation"),
AccessMode: "public",
AccessMode: ModelImageAccessModePublic,
},
},
},
Expand Down
9 changes: 4 additions & 5 deletions charts/kaito/workspace/templates/lora-params.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,13 @@ data:
load_in_4bit: false

LoraConfig:
r: 16
lora_alpha: 32
r: 8
lora_alpha: 8
target_modules: "query_key_value"
lora_dropout: 0.05
bias: "none"
lora_dropout: 0.0

TrainingArguments:
output_dir: "."
output_dir: "/mnt/results"
num_train_epochs: 4
auto_find_batch_size: true
ddp_find_unused_parameters: false
Expand Down
9 changes: 4 additions & 5 deletions charts/kaito/workspace/templates/qlora-params.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,13 @@ data:
bnb_4bit_use_double_quant: true

LoraConfig:
r: 16
lora_alpha: 32
r: 8
lora_alpha: 8
target_modules: "query_key_value"
lora_dropout: 0.05
bias: "none"
lora_dropout: 0.0

TrainingArguments:
output_dir: "."
output_dir: "/mnt/results"
ishaansehgal99 marked this conversation as resolved.
Show resolved Hide resolved
num_train_epochs: 4
auto_find_batch_size: true
ddp_find_unused_parameters: false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,20 @@ RUN echo $VERSION > /workspace/tfs/version.txt
# First, copy just the preset files and install dependencies
# This is done before copying the code to utilize Docker's layer caching and
# avoid reinstalling dependencies unless the requirements file changes.
COPY kaito/presets/inference/${MODEL_TYPE}/requirements.txt /workspace/tfs/requirements.txt
RUN pip install --no-cache-dir -r requirements.txt
# Inference
COPY kaito/presets/inference/${MODEL_TYPE}/requirements.txt /workspace/tfs/inference-requirements.txt
RUN pip install --no-cache-dir -r inference-requirements.txt

COPY kaito/presets/inference/${MODEL_TYPE}/inference_api.py /workspace/tfs/inference_api.py

# Fine Tuning
COPY kaito/presets/tuning/${MODEL_TYPE}/requirements.txt /workspace/tfs/tuning-requirements.txt
RUN pip install --no-cache-dir -r tuning-requirements.txt

COPY kaito/presets/tuning/${MODEL_TYPE}/cli.py /workspace/tfs/cli.py
COPY kaito/presets/tuning/${MODEL_TYPE}/fine_tuning.py /workspace/tfs/fine_tuning.py
COPY kaito/presets/tuning/${MODEL_TYPE}/parser.py /workspace/tfs/parser.py
COPY kaito/presets/tuning/${MODEL_TYPE}/dataset.py /workspace/tfs/dataset.py

# Copy the entire model weights to the weights directory
COPY ${WEIGHTS_PATH} /workspace/tfs/weights
23 changes: 0 additions & 23 deletions docker/presets/tuning/Dockerfile

This file was deleted.

2 changes: 1 addition & 1 deletion pkg/inference/preset-inferences.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ func updateTorchParamsForDistributedInference(ctx context.Context, kubeClient cl

func GetInferenceImageInfo(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace, presetObj *model.PresetParam) (string, []corev1.LocalObjectReference) {
imagePullSecretRefs := []corev1.LocalObjectReference{}
if presetObj.ImageAccessMode == "private" {
if presetObj.ImageAccessMode == string(kaitov1alpha1.ModelImageAccessModePrivate) {
imageName := workspaceObj.Inference.Preset.PresetOptions.Image
for _, secretName := range workspaceObj.Inference.Preset.PresetOptions.ImagePullSecrets {
imagePullSecretRefs = append(imagePullSecretRefs, corev1.LocalObjectReference{Name: secretName})
Expand Down
Loading
Loading