diff --git a/train.py b/train.py
index 262f3c0..028d040 100644
--- a/train.py
+++ b/train.py
@@ -10,7 +10,7 @@
 from torch.optim.lr_scheduler import CosineAnnealingLR
 import gzip
 from transformers import GPT2Tokenizer
-from datasets import load_dataset
+from datasets import load_dataset, concatenate_datasets
 from importlib import reload
 import src.main
 from accelerate import Accelerator
@@ -308,6 +308,11 @@ def main():
     # Load the raw data
     data = load_data("../data/raw/enwik8.gz")
 
+    # Load datasets mentioned in the LongRoPE paper
+    pg19_dataset = load_dataset("pg19", split="train")
+    arxiv_dataset = load_dataset("arxiv_dataset", split="train")
+    github_dataset = load_dataset("github_dataset", split="train")
+
     # Set parameters for data preprocessing
     max_length = 65536
     overlap = 4096