diff --git a/train.py b/train.py
index ed13870..e1150c2 100644
--- a/train.py
+++ b/train.py
@@ -335,108 +335,130 @@ def main():
     tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
     tokenizer.model_max_length = 2048000  # Set maximum sequence length to 2048k tokens

-    # Load the raw data
-    # data = load_data("../data/raw/enwik8.gz")
-
-    # Load datasets mentioned in the LongRoPE paper
+    # Load the PG19 dataset
     pg19_dataset = load_dataset("pg19", split="train")

-    # Set parameters for data preprocessing
-    max_length = 65536
-    overlap = 4096
+    # Define sequence lengths for progressive training
+    sequence_lengths = [2048, 128000, 256000, 2048000]

-    # Preprocess the data into sequences
-    logger.info("Preprocessing PG19 dataset...")
-    sequences = []
-    for item in pg19_dataset:
-        text = item["text"]
-        sequences.extend(preprocess_data(text, tokenizer, max_length, overlap))
-    logger.info(f"Total sequences after preprocessing: {len(sequences)}")
+    for length in sequence_lengths:
+        logger.info(f"Training on sequence length: {length}")

-    # Create target sequences (shifted by one token)
-    targets = [seq[1:] + [tokenizer.eos_token_id] for seq in sequences]
+        # Set parameters for data preprocessing
+        max_length = min(length, 65536)
+        overlap = 4096

-    # Validate that all target indices are within the vocabulary size
-    validate_targets(targets, tokenizer.vocab_size)
+        # Preprocess the data into sequences
+        logger.info(f"Preprocessing PG19 dataset for length {length}...")
+        sequences = []
+        for item in pg19_dataset:
+            text = item["text"]
+            sequences.extend(preprocess_data(text, tokenizer, max_length, overlap))
+        logger.info(f"Total sequences after preprocessing: {len(sequences)}")

-    # Create a custom dataset from sequences and targets
-    dataset = CustomDataset(sequences, targets)
+        # Create target sequences (shifted by one token)
+        targets = [seq[1:] + [tokenizer.eos_token_id] for seq in sequences]

-    # Split the dataset into training and validation sets
-    train_size = int(0.8 * len(dataset))
-    val_size = len(dataset) - train_size
-    train_dataset, val_dataset = torch.utils.data.random_split(
-        dataset, [train_size, val_size]
-    )
+        # Validate that all target indices are within the vocabulary size
+        validate_targets(targets, tokenizer.vocab_size)

-    # Create data loaders for training and validation
-    train_loader = DataLoader(
-        train_dataset, batch_size=8, shuffle=True, collate_fn=collate_fn
-    )
-    val_loader = DataLoader(val_dataset, batch_size=8, collate_fn=collate_fn)
-
-    # Initialize the LongRoPE model
-    model = LongRoPEModel(
-        d_model=4096,
-        n_heads=32,
-        num_layers=6,
-        vocab_size=tokenizer.vocab_size,
-        max_len=2048000,
-    )
+        # Create a custom dataset from sequences and targets
+        dataset = CustomDataset(sequences, targets)

-    # Set up optimizer, loss function, and learning rate scheduler
-    optimizer = optim.AdamW(model.parameters(), lr=1e-4)
-    criterion = nn.CrossEntropyLoss()
-    scheduler = CosineAnnealingLR(optimizer, T_max=10)
+        # Split the dataset into training and validation sets
+        train_size = int(0.8 * len(dataset))
+        val_size = len(dataset) - train_size
+        train_dataset, val_dataset = torch.utils.data.random_split(
+            dataset, [train_size, val_size]
+        )

-    # Prepare model, optimizer, data loaders, and scheduler for distributed training
-    model, optimizer, train_loader, val_loader, scheduler = accelerator.prepare(
-        model, optimizer, train_loader, val_loader, scheduler
-    )
+        # Create data loaders for training and validation
+        train_loader = DataLoader(
+            train_dataset, batch_size=8, shuffle=True, collate_fn=collate_fn
+        )
+        val_loader = DataLoader(val_dataset, batch_size=8, collate_fn=collate_fn)
+
+        # Initialize or extend the LongRoPE model based on the current sequence length
+        if length == 2048:
+            # Initialize the base LongRoPE model
+            model = LongRoPEModel(
+                d_model=4096,
+                n_heads=32,
+                num_layers=6,
+                vocab_size=tokenizer.vocab_size,
+                max_len=length,
+            )
+        else:
+            # Extend the context window of the model
+            model = model.extend_context(
+                data=sequences,
+                target_length=length,
+                max_sequence_length=max_length,
+                tokenizer=tokenizer,
+                population_size=64,
+                num_mutations=16,
+                num_crossovers=16,
+                max_iterations=10,
+            )

-    # Check for the latest checkpoint
-    latest_checkpoint = "checkpoint_latest.pt"
-    if os.path.exists(latest_checkpoint):
-        logger.info(f"Found checkpoint: {latest_checkpoint}")
-        resume_from_checkpoint = latest_checkpoint
-    else:
-        logger.info("No checkpoint found, starting training from scratch")
-        resume_from_checkpoint = None
-
-    # Extend the context window of the model
-    extended_model = model.extend_context(
-        data=sequences,
-        target_length=2048000,
-        max_sequence_length=65536,
-        tokenizer=tokenizer,
-        population_size=64,
-        num_mutations=16,
-        num_crossovers=16,
-        max_iterations=10,
-    )
+        # Set up optimizer, loss function, and learning rate scheduler
+        optimizer = optim.AdamW(model.parameters(), lr=1e-4)
+        criterion = nn.CrossEntropyLoss()
+        scheduler = CosineAnnealingLR(optimizer, T_max=10)

-    # Recover performance on shorter contexts
-    recovered_model = extended_model.recover_short_context(
-        data=sequences,
-        max_sequence_length=48192,
-        tokenizer=tokenizer,
-    )
+        # Prepare model, optimizer, data loaders, and scheduler for distributed training
+        model, optimizer, train_loader, val_loader, scheduler = accelerator.prepare(
+            model, optimizer, train_loader, val_loader, scheduler
+        )

-    # Create new optimizer and scheduler for the recovered model
-    optimizer = optim.AdamW(recovered_model.parameters(), lr=1e-4)
-    scheduler = CosineAnnealingLR(optimizer, T_max=10)
-
-    # Train the recovered model
-    train(
-        recovered_model,
-        train_loader,
-        val_loader,
-        optimizer,
-        criterion,
-        scheduler,
-        tokenizer,
-        resume_from_checkpoint=resume_from_checkpoint,
-    )
+        # Check for the latest checkpoint specific to this sequence length
+        latest_checkpoint = f"checkpoint_latest_{length}.pt"
+        if os.path.exists(latest_checkpoint):
+            logger.info(f"Found checkpoint for length {length}: {latest_checkpoint}")
+            resume_from_checkpoint = latest_checkpoint
+        else:
+            logger.info(
+                f"No checkpoint found for length {length}, starting training from scratch"
+            )
+            resume_from_checkpoint = None
+
+        # Perform training or fine-tuning based on the current sequence length
+        if length in [128000, 256000]:
+            # Fine-tune for a fixed step budget, as in the LongRoPE paper
+            fine_tune_steps = 400 if length == 128000 else 600
+            train(
+                model,
+                train_loader,
+                val_loader,
+                optimizer,
+                criterion,
+                scheduler,
+                tokenizer,
+                epochs=1,
+                gradient_accumulation_steps=max(1, fine_tune_steps // len(train_loader)),
+                resume_from_checkpoint=resume_from_checkpoint,
+                max_steps=fine_tune_steps,
+            )
+        else:
+            # Regular training for the other sequence lengths
+            train(
+                model,
+                train_loader,
+                val_loader,
+                optimizer,
+                criterion,
+                scheduler,
+                tokenizer,
+                resume_from_checkpoint=resume_from_checkpoint,
+            )
+
+        # Recover performance on shorter contexts after the 256k extension
+        if length == 256000:
+            model = model.recover_short_context(
+                data=sequences,
+                max_sequence_length=48192,
+                tokenizer=tokenizer,
+            )

     # Finish logging and close the Weights & Biases run
     wandb.finish()
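
The hunk above calls `train()` with new `epochs`, `gradient_accumulation_steps`, and `max_steps` keyword arguments, but the corresponding change to the `train()` helper is not part of this hunk. A minimal sketch of the signature and loop structure this diff appears to assume follows; the names, defaults, and control flow are illustrative, not the repository's actual implementation:

```python
# Hypothetical sketch of the train() helper assumed by this diff; the real
# function lives elsewhere in train.py and is not shown in this hunk.
def train(
    model,
    train_loader,
    val_loader,
    optimizer,
    criterion,
    scheduler,
    tokenizer,
    epochs=3,
    gradient_accumulation_steps=1,
    resume_from_checkpoint=None,
    max_steps=None,
):
    step = 0
    for epoch in range(epochs):
        model.train()
        for i, (inputs, targets) in enumerate(train_loader):
            logits = model(inputs)
            loss = criterion(logits.view(-1, logits.size(-1)), targets.view(-1))
            # Scale the loss so gradients accumulated over several batches
            # average out to one effective optimizer step
            (loss / gradient_accumulation_steps).backward()
            if (i + 1) % gradient_accumulation_steps == 0:
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                step += 1
            # Honor the fixed fine-tuning budget (400/600 steps) when set
            if max_steps is not None and step >= max_steps:
                return
```

Since the model and loaders are wrapped with `accelerator.prepare`, the real helper would presumably call `accelerator.backward(loss)` rather than `loss.backward()`, and it would also need to restore state from `resume_from_checkpoint`, which the sketch omits.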