# Re-expanded from a whitespace-collapsed diff of train.py (a39aad2..1b5efc4).
# Only the lines visible in that patch are reproduced here, in post-image form.
from itertools import islice

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, IterableDataset, get_worker_info
from torch.nn.utils.rnn import pad_sequence
from torch.cuda.amp import autocast, GradScaler
from torch.optim.lr_scheduler import CosineAnnealingLR

# %%


class StreamingDataset(IterableDataset):
    """Lazily turn text records into (input, target) training pairs.

    Each record of ``dataset`` is read through its ``"text"`` key and passed
    to ``preprocess_data(text, tokenizer, max_length, overlap)``.  Every
    sequence returned is yielded as ``(seq, seq[1:] + [eos_token_id])``,
    i.e. the target is the input shifted left by one position with the
    tokenizer's EOS id appended.

    Args:
        dataset: Iterable of mappings that each provide a ``"text"`` entry.
        tokenizer: Object forwarded to ``preprocess_data``; must expose
            ``eos_token_id``.
        max_length: Forwarded unchanged to ``preprocess_data``.
        overlap: Forwarded unchanged to ``preprocess_data``.
        shard_by_worker: When true and iteration happens inside a
            ``DataLoader`` worker process, worker ``i`` of ``n`` only
            processes records ``i, i + n, i + 2n, ...``.  Without this,
            every worker walks the full ``dataset`` and each sample is
            emitted once per worker.  Leave it disabled (the default,
            which keeps the previous behaviour) if ``dataset`` already
            partitions itself per worker, otherwise records would be
            dropped.

    Yields:
        ``(seq, target)`` tuples as described above.
    """

    def __init__(self, dataset, tokenizer, max_length, overlap, *, shard_by_worker=False):
        super().__init__()
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.overlap = overlap
        self.shard_by_worker = shard_by_worker

    def _records(self):
        """Return the records this process (or worker) should consume."""
        # get_worker_info() is None in the main process, so single-process
        # loading always sees the whole dataset.
        worker = get_worker_info() if self.shard_by_worker else None
        if worker is None:
            return self.dataset
        # Stripe the stream so the workers cover disjoint records.
        return islice(self.dataset, worker.id, None, worker.num_workers)

    def __iter__(self):
        for item in self._records():
            # NOTE(review): assumes preprocess_data returns an iterable of
            # token-id lists (list concatenation is used below) - confirm.
            sequences = preprocess_data(
                item["text"], self.tokenizer, self.max_length, self.overlap
            )
            for seq in sequences:
                yield seq, seq[1:] + [self.tokenizer.eos_token_id]


# NOTE(review): only the header and docstring of this class are visible in
# the patch excerpt; nothing else is reproduced here.
class CustomDataset(Dataset):
    """Custom dataset for handling sequences and targets."""