Skip to content

Commit

Permalink
Add StreamingDataset class for better handling of the long sequences
Browse files Browse the repository at this point in the history
  • Loading branch information
jshuadvd committed Jul 13, 2024
1 parent 506daf2 commit f07081c
Showing 1 changed file with 18 additions and 1 deletion.
19 changes: 18 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.utils.data import DataLoader, Dataset, IterableDataset
from torch.nn.utils.rnn import pad_sequence
from torch.cuda.amp import autocast, GradScaler
from torch.optim.lr_scheduler import CosineAnnealingLR
Expand Down Expand Up @@ -35,6 +35,23 @@


# %%
class StreamingDataset(IterableDataset):
def __init__(self, dataset, tokenizer, max_length, overlap):
self.dataset = dataset
self.tokenizer = tokenizer
self.max_length = max_length
self.overlap = overlap

def __iter__(self):
for item in self.dataset:
text = item["text"]
sequences = preprocess_data(
text, self.tokenizer, self.max_length, self.overlap
)
for seq in sequences:
yield seq, seq[1:] + [self.tokenizer.eos_token_id]


class CustomDataset(Dataset):
"""Custom dataset for handling sequences and targets."""

Expand Down

0 comments on commit f07081c

Please sign in to comment.