Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Fine Tune (Part 10) - Updating fine tuning py #371

Merged
merged 44 commits into from
May 28, 2024
Merged
Show file tree
Hide file tree
Changes from 35 commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
08ef1a4
fine tuning updates
ishaansehgal99 May 3, 2024
009bbe0
remove
ishaansehgal99 May 3, 2024
bd9d2cc
add to docker
ishaansehgal99 May 3, 2024
f599790
combine tuning and inference image
ishaansehgal99 May 3, 2024
42a8ee1
update test
ishaansehgal99 May 3, 2024
4ad8e14
begin adding trainer types
ishaansehgal99 May 3, 2024
52465b0
renaming and file moving
ishaansehgal99 May 3, 2024
76f1cdf
separate requirements
ishaansehgal99 May 3, 2024
1760b55
add parser
ishaansehgal99 May 3, 2024
8df9694
typo
ishaansehgal99 May 4, 2024
2a3e5e5
typo
ishaansehgal99 May 4, 2024
1534183
typo
ishaansehgal99 May 4, 2024
4a3acdc
mime type
ishaansehgal99 May 4, 2024
f885c31
fix
ishaansehgal99 May 4, 2024
c649b89
add python magic
ishaansehgal99 May 6, 2024
e497413
simplify using filetype
ishaansehgal99 May 6, 2024
2fb7798
dataset
ishaansehgal99 May 7, 2024
ef7ab5c
Merge branch 'main' into Ishaan/fine-tuning-py
ishaansehgal99 May 7, 2024
3b66555
Remove all CLI parser logic (unused) and add support for SFT_Trainer …
ishaansehgal99 May 8, 2024
7c7b7b1
Merge branch 'Ishaan/fine-tuning-py' of https://github.com/Azure/kait…
ishaansehgal99 May 8, 2024
519f2a3
Separate out into dataset class
ishaansehgal99 May 8, 2024
5aac066
Add dataset
ishaansehgal99 May 8, 2024
784fe11
datasets support
ishaansehgal99 May 9, 2024
659decb
Add support for datasets
ishaansehgal99 May 9, 2024
c3ed446
header
ishaansehgal99 May 9, 2024
156e62b
feat: format and preprocess
ishaansehgal99 May 10, 2024
49fcb2c
fix some edge cases
ishaansehgal99 May 10, 2024
9fb46aa
chore: Use image enum
ishaansehgal99 May 21, 2024
5ae940c
minor tweaks
ishaansehgal99 May 21, 2024
caa39eb
separate function
ishaansehgal99 May 21, 2024
fa0d4bd
add helpers
ishaansehgal99 May 22, 2024
cde8657
Remove manifests.go from PR
ishaansehgal99 May 22, 2024
7f88cfa
restore
ishaansehgal99 May 22, 2024
68c2f55
Merge branch 'main' of https://github.com/Azure/kaito into Ishaan/fin…
ishaansehgal99 May 22, 2024
dc01977
Merge branch 'main' into Ishaan/fine-tuning-py
ishaansehgal99 May 23, 2024
5eeff54
file rename
ishaansehgal99 May 23, 2024
2f48ef0
Merge branch 'Ishaan/fine-tuning-py' of https://github.com/Azure/kait…
ishaansehgal99 May 23, 2024
c7fe877
Dockerfile
ishaansehgal99 May 23, 2024
0c222c0
Handle custom output dir
ishaansehgal99 May 27, 2024
2c76c16
variable renamed
ishaansehgal99 May 27, 2024
baf5918
update defaults
ishaansehgal99 May 27, 2024
816d4b5
log msg update
ishaansehgal99 May 28, 2024
1d5e8b4
comments and nits
ishaansehgal99 May 28, 2024
1ba3cc4
Merge branch 'main' into Ishaan/fine-tuning-py
ishaansehgal99 May 28, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions api/v1alpha1/workspace_validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -336,12 +336,12 @@ func (i *InferenceSpec) validateCreate() (errs *apis.FieldError) {
errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("Unsupported inference preset name %s", presetName), "presetName"))
}
// Validate private preset has private image specified
if plugin.KaitoModelRegister.MustGet(string(i.Preset.Name)).GetInferenceParameters().ImageAccessMode == "private" &&
i.Preset.PresetMeta.AccessMode != "private" {
if plugin.KaitoModelRegister.MustGet(string(i.Preset.Name)).GetInferenceParameters().ImageAccessMode == string(ModelImageAccessModePrivate) &&
i.Preset.PresetMeta.AccessMode != ModelImageAccessModePrivate {
errs = errs.Also(apis.ErrGeneric("This preset only supports private AccessMode, AccessMode must be private to continue"))
}
// Additional validations for Preset
if i.Preset.PresetMeta.AccessMode == "private" && i.Preset.PresetOptions.Image == "" {
if i.Preset.PresetMeta.AccessMode == ModelImageAccessModePrivate && i.Preset.PresetOptions.Image == "" {
errs = errs.Also(apis.ErrGeneric("When AccessMode is private, an image must be provided in PresetOptions"))
}
// Note: we don't enforce private access mode to have image secrets, in case anonymous pulling is enabled
Expand Down
8 changes: 4 additions & 4 deletions api/v1alpha1/workspace_validation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,15 +55,15 @@ type testModelPrivate struct{}

// GetInferenceParameters returns the inference preset parameters for the
// private test model. The image access mode is set to the private enum value
// so validation paths that require a private image are exercised.
func (*testModelPrivate) GetInferenceParameters() *model.PresetParam {
	return &model.PresetParam{
		ImageAccessMode:           string(ModelImageAccessModePrivate),
		GPUCountRequirement:       gpuCountRequirement,
		TotalGPUMemoryRequirement: totalGPUMemoryRequirement,
		PerGPUMemoryRequirement:   perGPUMemoryRequirement,
	}
}
func (*testModelPrivate) GetTuningParameters() *model.PresetParam {
return &model.PresetParam{
ImageAccessMode: "private",
ImageAccessMode: string(ModelImageAccessModePrivate),
GPUCountRequirement: gpuCountRequirement,
TotalGPUMemoryRequirement: totalGPUMemoryRequirement,
PerGPUMemoryRequirement: perGPUMemoryRequirement,
Expand Down Expand Up @@ -461,7 +461,7 @@ func TestInferenceSpecValidateCreate(t *testing.T) {
Preset: &PresetSpec{
PresetMeta: PresetMeta{
Name: ModelName("test-validation"),
AccessMode: "private",
AccessMode: ModelImageAccessModePrivate,
},
PresetOptions: PresetOptions{},
},
Expand All @@ -488,7 +488,7 @@ func TestInferenceSpecValidateCreate(t *testing.T) {
Preset: &PresetSpec{
PresetMeta: PresetMeta{
Name: ModelName("test-validation"),
AccessMode: "public",
AccessMode: ModelImageAccessModePublic,
},
},
},
Expand Down
2 changes: 1 addition & 1 deletion charts/kaito/workspace/templates/lora-params.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ data:
bias: "none"

TrainingArguments:
output_dir: "."
output_dir: "/mnt/results"
num_train_epochs: 4
auto_find_batch_size: true
ddp_find_unused_parameters: false
Expand Down
2 changes: 1 addition & 1 deletion charts/kaito/workspace/templates/qlora-params.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ data:
bias: "none"

TrainingArguments:
output_dir: "."
output_dir: "/mnt/results"
ishaansehgal99 marked this conversation as resolved.
Show resolved Hide resolved
num_train_epochs: 4
auto_find_batch_size: true
ddp_find_unused_parameters: false
Expand Down
22 changes: 0 additions & 22 deletions docker/presets/inference/tfs/Dockerfile

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,20 @@ RUN echo $VERSION > /workspace/tfs/version.txt
# First, copy just the preset files and install dependencies
# This is done before copying the code to utilize Docker's layer caching and
# avoid reinstalling dependencies unless the requirements file changes.
COPY kaito/presets/tuning/${MODEL_TYPE}/requirements.txt /workspace/tfs/requirements.txt
RUN pip install --no-cache-dir -r requirements.txt
# Inference
COPY kaito/presets/inference/${MODEL_TYPE}/requirements.txt /workspace/tfs/inference-requirements.txt
RUN pip install --no-cache-dir -r inference-requirements.txt

COPY kaito/presets/inference/${MODEL_TYPE}/inference_api.py /workspace/tfs/inference_api.py

# Fine Tuning
COPY kaito/presets/tuning/${MODEL_TYPE}/requirements.txt /workspace/tfs/tuning-requirements.txt
RUN pip install --no-cache-dir -r tuning-requirements.txt

COPY kaito/presets/tuning/${MODEL_TYPE}/cli.py /workspace/tfs/cli.py
COPY kaito/presets/tuning/${MODEL_TYPE}/fine_tuning_api.py /workspace/tfs/tuning_api.py
COPY kaito/presets/tuning/${MODEL_TYPE}/fine_tuning_api.py /workspace/tfs/fine_tuning_api.py
COPY kaito/presets/tuning/${MODEL_TYPE}/parser.py /workspace/tfs/parser.py
COPY kaito/presets/tuning/${MODEL_TYPE}/dataset.py /workspace/tfs/dataset.py

# Copy the entire model weights to the weights directory
COPY ${WEIGHTS_PATH} /workspace/tfs/weights
2 changes: 1 addition & 1 deletion pkg/inference/preset-inferences.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ func updateTorchParamsForDistributedInference(ctx context.Context, kubeClient cl

func GetInferenceImageInfo(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace, presetObj *model.PresetParam) (string, []corev1.LocalObjectReference) {
imagePullSecretRefs := []corev1.LocalObjectReference{}
if presetObj.ImageAccessMode == "private" {
if presetObj.ImageAccessMode == string(kaitov1alpha1.ModelImageAccessModePrivate) {
imageName := workspaceObj.Inference.Preset.PresetOptions.Image
for _, secretName := range workspaceObj.Inference.Preset.PresetOptions.ImagePullSecrets {
imagePullSecretRefs = append(imagePullSecretRefs, corev1.LocalObjectReference{Name: secretName})
Expand Down
23 changes: 19 additions & 4 deletions pkg/tuning/preset-tuning.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,21 @@ func getInstanceGPUCount(sku string) int {
return gpuConfig.GPUCount
}

func GetTuningImageInfo(ctx context.Context, wObj *kaitov1alpha1.Workspace, presetObj *model.PresetParam) string {
registryName := os.Getenv("PRESET_REGISTRY_NAME")
return fmt.Sprintf("%s/%s:%s", registryName, "kaito-tuning-"+string(wObj.Tuning.Preset.Name), presetObj.Tag)
func GetTuningImageInfo(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace, presetObj *model.PresetParam) (string, []corev1.LocalObjectReference) {
imagePullSecretRefs := []corev1.LocalObjectReference{}
if presetObj.ImageAccessMode == string(kaitov1alpha1.ModelImageAccessModePrivate) {
imageName := workspaceObj.Tuning.Preset.PresetOptions.Image
for _, secretName := range workspaceObj.Tuning.Preset.PresetOptions.ImagePullSecrets {
imagePullSecretRefs = append(imagePullSecretRefs, corev1.LocalObjectReference{Name: secretName})
}
return imageName, imagePullSecretRefs
} else {
imageName := string(workspaceObj.Tuning.Preset.Name)
imageTag := presetObj.Tag
registryName := os.Getenv("PRESET_REGISTRY_NAME")
imageName = fmt.Sprintf("%s/kaito-%s:%s", registryName, imageName, imageTag)
return imageName, imagePullSecretRefs
}
}

func GetDataSrcImageInfo(ctx context.Context, wObj *kaitov1alpha1.Workspace) (string, []corev1.LocalObjectReference) {
Expand Down Expand Up @@ -216,7 +228,10 @@ func CreatePresetTuning(ctx context.Context, workspaceObj *kaitov1alpha1.Workspa
return nil, err
}
commands, resourceReq := prepareTuningParameters(ctx, workspaceObj, modelCommand, tuningObj)
tuningImage := GetTuningImageInfo(ctx, workspaceObj, tuningObj)
tuningImage, tuningImagePullSecrets := GetTuningImageInfo(ctx, workspaceObj, tuningObj)
if tuningImagePullSecrets != nil {
imagePullSecrets = append(imagePullSecrets, tuningImagePullSecrets...)
}

jobObj := resources.GenerateTuningJobManifest(ctx, workspaceObj, tuningImage, imagePullSecrets, *workspaceObj.Resource.Count, commands,
containerPorts, nil, nil, resourceReq, tolerations, initContainers, sidecarContainers, volumes, volumeMounts)
Expand Down
6 changes: 3 additions & 3 deletions pkg/tuning/preset-tuning_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ func TestGetTuningImageInfo(t *testing.T) {
presetObj: &model.PresetParam{
Tag: "latest",
},
expected: "testregistry/kaito-tuning-testpreset:latest",
expected: "testregistry/kaito-testpreset:latest",
},
"Empty Registry Name": {
registryName: "",
Expand All @@ -124,14 +124,14 @@ func TestGetTuningImageInfo(t *testing.T) {
presetObj: &model.PresetParam{
Tag: "latest",
},
expected: "/kaito-tuning-testpreset:latest",
expected: "/kaito-testpreset:latest",
},
}

for name, tc := range testcases {
t.Run(name, func(t *testing.T) {
os.Setenv("PRESET_REGISTRY_NAME", tc.registryName)
result := GetTuningImageInfo(context.Background(), tc.wObj, tc.presetObj)
result, _ := GetTuningImageInfo(context.Background(), tc.wObj, tc.presetObj)
assert.Equal(t, tc.expected, result)
})
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,27 @@
import os
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum, auto
from typing import Any, Dict, List, Optional

import torch
from peft import LoraConfig
from transformers import (BitsAndBytesConfig, DataCollatorForLanguageModeling,
PreTrainedTokenizer, TrainerCallback)
PreTrainedTokenizer)

# Consider Future Support for other trainers
# class TrainerTypes(Enum):
# TRAINER = "Trainer"
# SFT_TRAINER = "SFTTrainer"
# DPO_TRAINER = "DPOTrainer"
# REWARD_TRAINER = "RewardTrainer"
# PPO_TRAINER = "PPOTrainer"
# CPO_TRAINER = "CPOTrainer"
# ORPO_TRAINER = "ORPOTrainer"

# @dataclass
# class TrainerType:
# trainer_type: TrainerTypes = field(default=TrainerTypes.SFT_TRAINER, metadata={"help": "Type of trainer to use for fine-tuning."})

@dataclass
class ExtDataCollator(DataCollatorForLanguageModeling):
Expand All @@ -24,33 +38,36 @@ class ExtLoraConfig(LoraConfig):
target_modules: Optional[List[str]] = field(default=None, metadata={"help": ("List of module names to replace with LoRA.")})
layers_to_transform: Optional[List[int]] = field(default=None, metadata={"help": "Layer indices to apply LoRA"})
layers_pattern: Optional[List[str]] = field(default=None, metadata={"help": "Pattern to match layers for LoRA"})
loftq_config: Dict[str, any] = field(default_factory=dict, metadata={"help": "LoftQ configuration for quantization"})
loftq_config: Dict[str, any] = field(default_factory=dict, metadata={"help": "LoftQ configuration for quantization"})

@dataclass
class DatasetConfig:
    """
    Configuration describing how the fine-tuning dataset is located, parsed,
    shuffled, and split into train/test portions.
    """
    dataset_name: str = field(metadata={"help": "Name of Dataset"})
    dataset_path: Optional[str] = field(default=None, metadata={"help": "Where dataset file is located in the /data folder. This path will be appended to /data. This path should be the absolute path in the image or host."})
    dataset_extension: Optional[str] = field(default=None, metadata={"help": "Optional explicit file extension of the dataset. If not provided, the extension will be derived from the dataset_path."})
    shuffle_dataset: bool = field(default=True, metadata={"help": "Whether to shuffle dataset"})
    shuffle_seed: int = field(default=42, metadata={"help": "Seed for shuffling data"})
    # instruction_column: Optional[str] = field(default=None, metadata={"help": "Optional column for detailed instructions, used in more structured tasks like Alpaca-style setups."})  # Consider including in V2
    context_column: Optional[str] = field(default=None, metadata={"help": "Column for additional context or prompts, used for generating responses based on scenarios."})
    response_column: str = field(default="text", metadata={"help": "Main text column for standalone entries or the response part in prompt-response setups."})
    messages_column: Optional[str] = field(default=None, metadata={"help": "Column containing structured conversational data in JSON format with roles and content, used for chatbot training."})
    train_test_split: float = field(default=0.8, metadata={"help": "Split between test and training data (e.g. 0.8 means 80/20% train/test split)"})

@dataclass
class TokenizerParams:
"""
Tokenizer params
Tokenizer params
"""
add_special_tokens: bool = field(default=True, metadata={"help": ""})
padding: bool = field(default=False, metadata={"help": ""})
truncation: bool = field(default=None, metadata={"help": ""})
max_length: Optional[int] = field(default=None, metadata={"help": ""})
stride: int = field(default=0, metadata={"help": ""})
is_split_into_words: bool = field(default=False, metadata={"help": ""})
tok_pad_to_multiple_of: Optional[int] = field(default=None, metadata={"help": ""})
tok_return_tensors: Optional[str] = field(default=None, metadata={"help": ""})
pad_to_multiple_of: Optional[int] = field(default=None, metadata={"help": ""})
return_tensors: Optional[str] = field(default=None, metadata={"help": ""})
return_token_type_ids: Optional[bool] = field(default=None, metadata={"help": ""})
return_attention_mask: Optional[bool] = field(default=None, metadata={"help": ""})
return_overflowing_tokens: bool = field(default=False, metadata={"help": ""})
Expand All @@ -72,21 +89,23 @@ class ModelConfig:
resume_download: bool = field(default=False, metadata={"help": "Resume an interrupted download"})
proxies: Optional[str] = field(default=None, metadata={"help": "Proxy configuration for downloading the model"})
output_loading_info: bool = field(default=False, metadata={"help": "Output additional loading information"})
allow_remote_files: bool = field(default=False, metadata={"help": "Allow using remote files, default is local only"})
m_revision: str = field(default="main", metadata={"help": "Specific model version to use"})
local_files_only: bool = field(default=False, metadata={"help": "Allow using remote files, default is local only"})
revision: str = field(default="main", metadata={"help": "Specific model version to use"})
trust_remote_code: bool = field(default=False, metadata={"help": "Enable trusting remote code when loading the model"})
m_load_in_4bit: bool = field(default=False, metadata={"help": "Load model in 4-bit mode"})
m_load_in_8bit: bool = field(default=False, metadata={"help": "Load model in 8-bit mode"})
load_in_4bit: bool = field(default=False, metadata={"help": "Load model in 4-bit mode"})
load_in_8bit: bool = field(default=False, metadata={"help": "Load model in 8-bit mode"})
torch_dtype: Optional[str] = field(default=None, metadata={"help": "The torch dtype for the pre-trained model"})
device_map: str = field(default="auto", metadata={"help": "The device map for the pre-trained model"})

def __post_init__(self):
"""
Post-initialization to validate some ModelConfig values
"""
if self.torch_dtype and not hasattr(torch, self.torch_dtype):
raise ValueError(f"Invalid torch dtype: {self.torch_dtype}")
self.torch_dtype = getattr(torch, self.torch_dtype) if self.torch_dtype else None
if self.torch_dtype:
if isinstance(self.torch_dtype, str) and hasattr(torch, self.torch_dtype):
self.torch_dtype = getattr(torch, self.torch_dtype)
elif not isinstance(self.torch_dtype, torch.dtype):
raise ValueError(f"Invalid torch dtype: {self.torch_dtype}")

@dataclass
class QuantizationConfig(BitsAndBytesConfig):
Expand All @@ -104,14 +123,6 @@ class QuantizationConfig(BitsAndBytesConfig):
bnb_4bit_quant_type: str = field(default="fp4", metadata={"help": "Quantization type for 4-bit"})
bnb_4bit_use_double_quant: bool = field(default=False, metadata={"help": "Use double quantization for 4-bit"})

@dataclass
class TrainingConfig:
"""
Configuration for fine_tuning process
"""
save_output_path: str = field(default=".", metadata={"help": "Path where fine_tuning output is saved"})
# Other fine_tuning-related configurations can go here

# class CheckpointCallback(TrainerCallback):
# def on_train_end(self, args, state, control, **kwargs):
# model_path = args.output_dir
Expand Down
Loading
Loading