Update the main() function to make the training process more robust and easier to monitor.

Key differences and improvements:

Weights & Biases Integration:
The updated function initializes wandb for logging and visualization, making training progress and results easier to track (a minimal sketch of the run lifecycle follows).
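As a rough illustration, the wandb lifecycle looks like this; the project and entity names come from the diff, while the metric key and values are placeholders rather than the repo's actual logging:

    import wandb

    wandb.init(project="longrope", entity="your-entity-name")
    for step in range(3):
        wandb.log({"train_loss": 1.0 / (step + 1)})  # placeholder metric
    wandb.finish()  # flush all logs and close the run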

Batch Size:
The batch size is reduced from 32 to 8, presumably to fit the larger model in memory; a smaller per-step batch also pairs naturally with gradient accumulation (sketched below).
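For context, gradient accumulation can recover the old effective batch size. This is a self-contained toy, not code from this repo (the diff itself does not add accumulation):

    import torch
    from torch import nn, optim

    # Toy stand-ins so the loop runs on its own.
    model = nn.Linear(16, 4)
    optimizer = optim.AdamW(model.parameters(), lr=1e-4)
    criterion = nn.CrossEntropyLoss()
    batches = [(torch.randn(8, 16), torch.randint(0, 4, (8,))) for _ in range(8)]

    accumulation_steps = 4  # 8 * 4 = effective batch of 32, matching the old setting
    optimizer.zero_grad()
    for i, (inputs, targets) in enumerate(batches):
        loss = criterion(model(inputs), targets) / accumulation_steps
        loss.backward()
        if (i + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()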

Optimizer:
Changed from Adam to AdamW. AdamW decouples weight decay from the gradient update instead of folding it in as an L2 penalty, which makes regularization behave more predictably (see the snippet below).
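The decay strength is controlled by weight_decay; the diff relies on the default, so the explicit value here is only illustrative:

    import torch
    from torch import nn, optim

    model = nn.Linear(16, 4)  # stand-in model
    # AdamW applies weight decay directly to the weights each step,
    # separately from the Adam moment estimates.
    optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)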

Learning Rate Scheduler:
Added a CosineAnnealingLR scheduler with T_max=10. Annealing the learning rate along a cosine curve over training often yields better convergence than a fixed rate.
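A runnable toy showing what the schedule does to the learning rate; the stand-in parameter is arbitrary:

    import torch
    from torch import optim
    from torch.optim.lr_scheduler import CosineAnnealingLR

    params = [torch.nn.Parameter(torch.zeros(1))]  # stand-in parameters
    optimizer = optim.AdamW(params, lr=1e-4)
    scheduler = CosineAnnealingLR(optimizer, T_max=10)  # one half-cosine over 10 steps

    for step in range(10):
        # ... one epoch of training would run here ...
        scheduler.step()
        print(step, scheduler.get_last_lr())  # decays from 1e-4 toward eta_min (default 0)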

Accelerator Preparation:
The scheduler is now passed through accelerator.prepare() along with the model, optimizer, and dataloaders. This ensures the scheduler steps correctly in distributed training setups.
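A self-contained sketch of the prepare() call; the toy model, shapes, and hyperparameters are placeholders, not the repo's values:

    import torch
    from torch import nn, optim
    from torch.optim.lr_scheduler import CosineAnnealingLR
    from torch.utils.data import DataLoader, TensorDataset
    from accelerate import Accelerator

    # Toy stand-ins so the call runs end to end.
    model = nn.Linear(16, 4)
    optimizer = optim.AdamW(model.parameters(), lr=1e-4)
    scheduler = CosineAnnealingLR(optimizer, T_max=10)
    dataset = TensorDataset(torch.randn(32, 16), torch.randint(0, 4, (32,)))
    train_loader = DataLoader(dataset, batch_size=8)
    val_loader = DataLoader(dataset, batch_size=8)

    accelerator = Accelerator()
    # Preparing the scheduler alongside everything else lets accelerate
    # wrap it so scheduler.step() behaves consistently across processes.
    model, optimizer, train_loader, val_loader, scheduler = accelerator.prepare(
        model, optimizer, train_loader, val_loader, scheduler
    )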

Model Extension:
recover_short_context is now called on extended_model rather than on the original model, so short-context recovery runs on the context-extended weights.

Final Training:
The train function now takes the scheduler as an argument, allowing the learning rate to be adjusted as training proceeds (a sketch of the likely shape follows).
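The repo's train() body is not shown in this diff, so the following is only a plausible outline of how the scheduler slots in: optimizer steps per batch, one scheduler step per epoch.

    # Hypothetical outline, not the repository's actual train() implementation.
    def train(model, train_loader, val_loader, optimizer, criterion, scheduler, epochs=10):
        for epoch in range(epochs):
            model.train()
            for inputs, targets in train_loader:
                optimizer.zero_grad()
                loss = criterion(model(inputs), targets)
                loss.backward()
                optimizer.step()
            scheduler.step()  # anneal the learning rate once per epoch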

Wandb Finish:
The wandb.finish() call at the end ensures that all logs are properly synced and the run is closed.
jshuadvd committed Jul 5, 2024
1 parent 05ca736 commit 306ba56
Showing 1 changed file with 19 additions and 17 deletions: train.py
@@ -244,10 +244,12 @@ def train(
 # %%
 def main():
     """Main function to setup and run training."""
 
+    wandb.init(project="longrope", entity="your-entity-name")
+
     tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
-    tokenizer.model_max_length = (
-        2048000  # Set the maximum sequence length for the tokenizer
-    )
+    tokenizer.model_max_length = 2048000
 
     data = load_data("../data/raw/enwik8.gz")
 
     max_length = 65536
@@ -258,8 +260,6 @@ def main():
 
     validate_targets(targets, tokenizer.vocab_size)
 
-    print(f"Validated: {validate_targets(targets, tokenizer.vocab_size)}")
-
     dataset = CustomDataset(sequences, targets)
     train_size = int(0.8 * len(dataset))
     val_size = len(dataset) - train_size
@@ -268,29 +268,29 @@ def main():
     )
 
     train_loader = DataLoader(
-        train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn
+        train_dataset, batch_size=8, shuffle=True, collate_fn=collate_fn
     )
-    val_loader = DataLoader(val_dataset, batch_size=32, collate_fn=collate_fn)
+    val_loader = DataLoader(val_dataset, batch_size=8, collate_fn=collate_fn)
 
     model = LongRoPEModel(
         d_model=4096,
         n_heads=32,
         num_layers=6,
         vocab_size=tokenizer.vocab_size,
-        max_len=2048000,  # Set max_len to 2048k tokens
+        max_len=2048000,
     )
 
-    optimizer = optim.Adam(model.parameters(), lr=1e-4)
+    optimizer = optim.AdamW(model.parameters(), lr=1e-4)
     criterion = nn.CrossEntropyLoss()
+    scheduler = CosineAnnealingLR(optimizer, T_max=10)
 
     # Prepare everything with accelerator
-    model, optimizer, train_loader, val_loader = accelerator.prepare(
-        model, optimizer, train_loader, val_loader
+    model, optimizer, train_loader, val_loader, scheduler = accelerator.prepare(
+        model, optimizer, train_loader, val_loader, scheduler
     )
 
     extended_model = model.extend_context(
         data_path="../data/raw/enwik8.gz",
-        target_length=2048000,  # Set target_length to 2048k tokens
+        target_length=2048000,
         max_sequence_length=65536,
         tokenizer=tokenizer,
         population_size=64,
@@ -299,16 +299,18 @@ def main():
         max_iterations=10,
     )
 
-    recovered_model = model.recover_short_context(
+    recovered_model = extended_model.recover_short_context(
         data_path="../data/raw/enwik8.gz",
         max_sequence_length=48192,
         tokenizer=tokenizer,
     )
 
-    optimizer = optim.Adam(recovered_model.parameters(), lr=1e-4)
-    criterion = nn.CrossEntropyLoss()
+    optimizer = optim.AdamW(recovered_model.parameters(), lr=1e-4)
+    scheduler = CosineAnnealingLR(optimizer, T_max=10)
 
-    train(recovered_model, train_loader, val_loader, optimizer, criterion)
+    train(recovered_model, train_loader, val_loader, optimizer, criterion, scheduler)
+    wandb.finish()
 
 
 if __name__ == "__main__":
