From f07081cdff414ad49a6c2e487fc4d71e990d6c84 Mon Sep 17 00:00:00 2001
From: Joshua David
Date: Fri, 12 Jul 2024 22:43:56 -0700
Subject: [PATCH] Add StreamingDataset class for better handling of long sequences

---
Reviewer note: line breaks were restored from a whitespace-collapsed copy of
this patch. The position of the blank context lines in the second hunk is a
best guess, and the `index` blob ids are those of the original revision;
regenerate with `git format-patch` if the patch does not apply cleanly.

 train.py | 44 +++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 43 insertions(+), 1 deletion(-)

diff --git a/train.py b/train.py
index a39aad2..1b5efc4 100644
--- a/train.py
+++ b/train.py
@@ -4,7 +4,7 @@
 import torch
 import torch.nn as nn
 import torch.optim as optim
-from torch.utils.data import DataLoader, Dataset
+from torch.utils.data import DataLoader, Dataset, IterableDataset
 from torch.nn.utils.rnn import pad_sequence
 from torch.cuda.amp import autocast, GradScaler
 from torch.optim.lr_scheduler import CosineAnnealingLR
@@ -35,6 +35,48 @@
 # %%
 
 
+class StreamingDataset(IterableDataset):
+    """Lazily yield (input, target) sequence pairs built from a stream of texts.
+
+    Every item of ``dataset`` is read through ``item["text"]`` and handed to
+    ``preprocess_data``; each sequence it returns is yielded together with the
+    same sequence minus its first element plus the tokenizer's EOS id.
+
+    NOTE(review): ``__iter__`` never consults
+    ``torch.utils.data.get_worker_info()``. Unless ``dataset`` shards itself per
+    worker, a ``DataLoader`` with ``num_workers > 0`` would presumably yield
+    every sample once per worker -- confirm before enabling workers.
+    """
+
+    def __init__(self, dataset, tokenizer, max_length, overlap):
+        """Store the text source and the settings forwarded to ``preprocess_data``.
+
+        Args:
+            dataset: Iterable whose items support ``item["text"]``.
+            tokenizer: Object exposing ``eos_token_id``; also passed to
+                ``preprocess_data``.
+            max_length: Passed through to ``preprocess_data`` unchanged.
+            overlap: Passed through to ``preprocess_data`` unchanged.
+        """
+        super().__init__()
+        self.dataset = dataset
+        self.tokenizer = tokenizer
+        self.max_length = max_length
+        self.overlap = overlap
+
+    def __iter__(self):
+        """Yield ``(sequence, shifted_sequence)`` for every preprocessed sequence."""
+        for item in self.dataset:
+            text = item["text"]
+            sequences = preprocess_data(
+                text, self.tokenizer, self.max_length, self.overlap
+            )
+            for seq in sequences:
+                # Target drops the first element and appends EOS; the ``+`` assumes
+                # ``seq`` is a list -- TODO confirm what preprocess_data returns.
+                yield seq, seq[1:] + [self.tokenizer.eos_token_id]
+
+
 class CustomDataset(Dataset):
     """Custom dataset for handling sequences and targets."""
 