
Separate out into dataset class
ishaansehgal99 committed May 8, 2024
1 parent 7c7b7b1 commit 519f2a3
Showing 4 changed files with 108 additions and 94 deletions.
13 changes: 6 additions & 7 deletions presets/tuning/text-generation/cli.py
@@ -11,20 +11,19 @@
from transformers import (BitsAndBytesConfig, DataCollatorForLanguageModeling,
PreTrainedTokenizer, TrainerCallback)


class TrainerTypes(Enum):
TRAINER = "Trainer"
SFT_TRAINER = "SFTTrainer"
# class TrainerTypes(Enum):
# TRAINER = "Trainer"
# SFT_TRAINER = "SFTTrainer"
# TODO: Future Support for other trainers
# DPO_TRAINER = "DPOTrainer"
# REWARD_TRAINER = "RewardTrainer"
# PPO_TRAINER = "PPOTrainer"
# CPO_TRAINER = "CPOTrainer"
# ORPO_TRAINER = "ORPOTrainer"

@dataclass
class TrainerType:
trainer_type: TrainerTypes = field(default=TrainerTypes.SFT_TRAINER, metadata={"help": "Type of trainer to use for fine-tuning."})
# @dataclass
# class TrainerType:
# trainer_type: TrainerTypes = field(default=TrainerTypes.SFT_TRAINER, metadata={"help": "Type of trainer to use for fine-tuning."})

@dataclass
class ExtDataCollator(DataCollatorForLanguageModeling):
77 changes: 77 additions & 0 deletions presets/tuning/text-generation/dataset.py
@@ -0,0 +1,77 @@
import os
from typing import Optional

from datasets import load_dataset

SUPPORTED_EXTENSIONS = {'csv', 'json', 'parquet', 'arrow', 'webdataset'}

class DatasetManager:
def __init__(self, config, tokenizer_params, preprocess_function: Optional[callable] = None):
self.config = config
self.tokenizer_params = tokenizer_params
self.preprocess_function = preprocess_function
self.dataset = None

def load_dataset(self):
if self.config.dataset_path:
dataset_path = os.path.join("/mnt", self.config.dataset_path.strip("/"))
else:
dataset_path = self.find_valid_dataset(os.environ.get('DATASET_FOLDER_PATH', '/mnt/data'))
if not dataset_path:
raise ValueError("Unable to find a valid dataset file.")

file_ext = self.config.dataset_extension if self.config.dataset_extension else self.get_file_extension(dataset_path)
try:
self.dataset = load_dataset(file_ext, data_files=dataset_path, split="train")
print(f"Dataset loaded successfully from {dataset_path} with file type '{file_ext}'.")
except Exception as e:
print(f"Error loading dataset: {e}")
raise ValueError(f"Unable to load dataset {dataset_path} with file type '{file_ext}'")

def find_valid_dataset(self, data_dir):
""" Searches for a file with a valid dataset type in the given directory. """
for root, dirs, files in os.walk(data_dir):
for file in files:
filename_lower = file.lower() # Convert to lowercase once per filename
for ext in SUPPORTED_EXTENSIONS:
if ext in filename_lower:
return os.path.join(root, file)
return None

def get_file_extension(self, file_path):
""" Returns the file extension based on filetype guess or filename. """
filename_lower = os.path.basename(file_path).lower()
for ext in SUPPORTED_EXTENSIONS:
if ext in filename_lower:
return ext
_, file_ext = os.path.splitext(file_path)
return file_ext[1:] # Remove leading "."

    def preprocess_data(self, preprocess_function: Optional[callable] = None):
        if self.dataset is None:
            raise ValueError("Dataset is not loaded.")
        # A function passed here (e.g. a prompt-formatting lambda) overrides the one supplied
        # at construction time and is applied as-is, without the tokenizer keyword arguments.
        if preprocess_function is not None:
            self.dataset = self.dataset.map(preprocess_function)
        elif self.preprocess_function:
            self.dataset = self.dataset.map(self.preprocess_function, batched=True, fn_kwargs=self.tokenizer_params)

def shuffle_dataset(self, seed=None):
if self.dataset is None:
raise ValueError("Dataset is not loaded.")
self.dataset = self.dataset.shuffle(seed=seed)

def split_dataset(self):
if self.dataset is None:
raise ValueError("Dataset is not loaded.")
assert 0 < self.config.train_test_split <= 1, "Train/Test split needs to be between 0 and 1"
if self.config.train_test_split < 1:
split_dataset = self.dataset.train_test_split(
test_size=1-self.config.train_test_split,
seed=self.config.shuffle_seed
)
return split_dataset['train'], split_dataset['test']
else:
return self.dataset, None

def get_dataset(self):
if self.dataset is None:
raise ValueError("Dataset is not loaded.")
return self.dataset
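
For orientation, here is a minimal usage sketch of the new DatasetManager. The DummyDatasetConfig below is illustrative only: it carries the fields that load_dataset, shuffle_dataset, and split_dataset read, and the file path is assumed to exist under /mnt.

from dataclasses import dataclass

from dataset import DatasetManager


@dataclass
class DummyDatasetConfig:
    # Illustrative stand-in for the parsed dataset config; field names mirror what DatasetManager reads.
    dataset_path: str = "data/train.json"    # resolved to /mnt/data/train.json by load_dataset
    dataset_extension: str = ""              # empty -> guessed from the filename
    shuffle_seed: int = 42
    train_test_split: float = 0.8            # 1.0 keeps the whole dataset for training


tk_params = {"truncation": True, "max_length": 512}  # forwarded to map() for tokenizer-style preprocessing

dm = DatasetManager(DummyDatasetConfig(), tk_params)
dm.load_dataset()                                     # raises ValueError if no supported file is found
dm.shuffle_dataset(seed=dm.config.shuffle_seed)
train_dataset, eval_dataset = dm.split_dataset()      # eval_dataset is None when train_test_split == 1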
110 changes: 24 additions & 86 deletions presets/tuning/text-generation/fine_tuning_api.py
@@ -9,21 +9,20 @@
import torch
import transformers
from accelerate import Accelerator
from cli import TrainerTypes
from dataset import DatasetManager
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import (AutoModelForCausalLM, AutoTokenizer,
BitsAndBytesConfig, HfArgumentParser, Trainer,
TrainingArguments)
from trl import SFTTrainer

DATASET_PATH = os.environ.get('DATASET_FOLDER_PATH', '/mnt/data')
CONFIG_YAML = os.environ.get('YAML_FILE_PATH', '/mnt/config/training_config.yaml')
TRAINER_CLASS_MAP = {
TrainerTypes.TRAINER: Trainer,
TrainerTypes.SFT_TRAINER: SFTTrainer,
# Additional mappings as needed
}
# TRAINER_CLASS_MAP = {
# TrainerTypes.TRAINER: Trainer,
# TrainerTypes.SFT_TRAINER: SFTTrainer,
# # Additional mappings as needed
# }

parsed_configs = parse_configs(CONFIG_YAML)

@@ -96,99 +95,38 @@ def preprocess_data(example):
prompt = f"human: {example[ds_config.context_column]}\n bot: {example[ds_config.response_column]}".strip()
return tokenizer(prompt, **tk_params)

SUPPORTED_EXTENSIONS = {'csv', 'json', 'parquet', 'arrow', 'webdataset'}

def find_valid_dataset(data_dir):
""" Searches for a file with a valid dataset type in the given directory. """
for root, dirs, files in os.walk(data_dir):
for file in files:
filename_lower = file.lower() # Convert to lowercase once per filename
for ext in SUPPORTED_EXTENSIONS:
if ext in filename_lower:
return os.path.join(root, file)
return None

def get_file_extension(file_path):
""" Returns the file extension based on filetype guess or filename. """
filename_lower = os.path.basename(file_path).lower()
for ext in SUPPORTED_EXTENSIONS:
if ext in filename_lower:
return ext
_, file_ext = os.path.splitext(file_path)
return file_ext[1:] # Remove leading "."

# Loading the dataset
if ds_config.dataset_path:
DATASET_PATH = os.path.join("/mnt", ds_config.dataset_path.strip("/"))
else:
DATASET_PATH = find_valid_dataset(DATASET_PATH)
if not DATASET_PATH:
raise ValueError("Unable to find a valid dataset file.")

# Determine the file extension
file_ext = ds_config.dataset_extension if ds_config.dataset_extension else get_file_extension(DATASET_PATH)

dataset = load_dataset(file_ext, data_files=DATASET_PATH, split="train")
if dataset:
print(f"Dataset loaded successfully from {DATASET_PATH} with file type '{file_ext}'.")
else:
dm = DatasetManager(ds_config, tk_params)
# Load the dataset
dm.load_dataset()
if not dm.dataset:
print("Failed to load dataset.")
raise ValueError("Unable to load the dataset.")

# Shuffling the dataset (if needed)
if ds_config.shuffle_dataset:
dataset = dataset.shuffle(seed=ds_config.shuffle_seed)

if tt_args.trainer_type == TrainerTypes.SFT_TRAINER:
text_mapping_func = lambda x: {
'text': f"<s>[INST]{('<<SYS>>' + x[ds_config.instruction_column] + '<</SYS>>') if ds_config.instruction_column in x else ''}{x[ds_config.context_column]} [/INST]{x[ds_config.response_column]} </s>"
}
dataset = dataset.map(text_mapping_func)

# Preprocessing the data
dataset = dataset.map(preprocess_data)
dm.shuffle_dataset()

assert 0 < ds_config.train_test_split <= 1, "Train/Test split needs to be between 0 and 1"

# Initialize variables for train and eval datasets
train_dataset, eval_dataset = dataset, None
text_mapping_func = lambda x: {
'text': f"<s>[INST]{('<<SYS>>' + x[ds_config.instruction_column] + '<</SYS>>') if ds_config.instruction_column in x else ''}{x[ds_config.context_column]} [/INST]{x[ds_config.response_column]} </s>"
}
dm.preprocess_data(text_mapping_func)

if ds_config.train_test_split < 1:
# Splitting the dataset into training and test sets
split_dataset = dataset.train_test_split(
test_size=1-ds_config.train_test_split,
seed=ds_config.shuffle_seed
)
train_dataset, eval_dataset = split_dataset['train'], split_dataset['test']
print("Training Dataset Dimensions: ", train_dataset.shape)
print("Test Dataset Dimensions: ", eval_dataset.shape)
else:
print(f"Using full dataset for training. Dimensions: {train_dataset.shape}")
train_dataset, eval_dataset = dm.split_dataset()

# checkpoint_callback = CheckpointCallback()

# Prepare for training
torch.cuda.set_device(accelerator.process_index)
torch.cuda.empty_cache()
# Training the Model
if tt_args.trainer_type == TrainerTypes.TRAINER:
trainer = accelerator.prepare(trainer_class(
model=model,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
args=ta_args,
data_collator=dc_args,
# callbacks=[checkpoint_callback]
))
elif tt_args.trainer_type == TrainerTypes.SFT_TRAINER:
trainer = accelerator.prepare(trainer_class(
model=model,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
args=ta_args,
data_collator=dc_args,
dataset_text_field="text"
))
trainer = accelerator.prepare(SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    args=ta_args,
    data_collator=dc_args,
    dataset_text_field="text",  # column produced by text_mapping_func above
    # metrics = "tensorboard" or "wandb" # TODO
))
trainer.train()
os.makedirs(ta_args.output_dir, exist_ok=True)
trainer.save_model(ta_args.output_dir)
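As an aside, a small sketch of the Llama-2-style prompt string that text_mapping_func assembles; the record and column names below are made up purely for illustration.

# Hypothetical record and column names, only to show the shape of the generated 'text' field.
example = {"instruction": "Answer concisely.", "context": "What is 2+2?", "response": "4"}
instruction_column, context_column, response_column = "instruction", "context", "response"

text = (
    "<s>[INST]"
    + (f"<<SYS>>{example[instruction_column]}<</SYS>>" if instruction_column in example else "")
    + f"{example[context_column]} [/INST]{example[response_column]} </s>"
)
print(text)
# <s>[INST]<<SYS>>Answer concisely.<</SYS>>What is 2+2? [/INST]4 </s>
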
2 changes: 1 addition & 1 deletion presets/tuning/text-generation/requirements.txt
@@ -1,6 +1,6 @@
datasets==2.16.1
peft==0.8.2
transformers==4.38.2
transformers==4.40.2
torch==2.2.0
accelerate==0.27.2
fastapi==0.109.1
