
Update the progressive training and the finetuning for different lengths in the main function
jshuadvd committed Jul 9, 2024
1 parent 96ad654 commit 118c02e
Showing 1 changed file with 111 additions and 89 deletions.
train.py: 200 changes (111 additions, 89 deletions)
@@ -335,108 +335,130 @@ def main():
     tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
     tokenizer.model_max_length = 2048000 # Set maximum sequence length to 2048k tokens

-    # Load the raw data
-    # data = load_data("../data/raw/enwik8.gz")
+    # Load datasets mentioned in the LongRoPE paper
+    # Load the PG19 dataset
+    pg19_dataset = load_dataset("pg19", split="train")

-    # Set parameters for data preprocessing
-    max_length = 65536
-    overlap = 4096
-
-    # Preprocess the data into sequences
-    logger.info("Preprocessing PG19 dataset...")
-    sequences = []
-    for item in pg19_dataset:
-        text = item["text"]
-        sequences.extend(preprocess_data(text, tokenizer, max_length, overlap))
-    logger.info(f"Total sequences after preprocessing: {len(sequences)}")
-
-    # Create target sequences (shifted by one token)
-    targets = [seq[1:] + [tokenizer.eos_token_id] for seq in sequences]
-
-    # Validate that all target indices are within the vocabulary size
-    validate_targets(targets, tokenizer.vocab_size)
-
-    # Create a custom dataset from sequences and targets
-    dataset = CustomDataset(sequences, targets)
-
-    # Split the dataset into training and validation sets
-    train_size = int(0.8 * len(dataset))
-    val_size = len(dataset) - train_size
-    train_dataset, val_dataset = torch.utils.data.random_split(
-        dataset, [train_size, val_size]
-    )
-
-    # Create data loaders for training and validation
-    train_loader = DataLoader(
-        train_dataset, batch_size=8, shuffle=True, collate_fn=collate_fn
-    )
-    val_loader = DataLoader(val_dataset, batch_size=8, collate_fn=collate_fn)
-
-    # Initialize the LongRoPE model
-    model = LongRoPEModel(
-        d_model=4096,
-        n_heads=32,
-        num_layers=6,
-        vocab_size=tokenizer.vocab_size,
-        max_len=2048000,
-    )
-
-    # Set up optimizer, loss function, and learning rate scheduler
-    optimizer = optim.AdamW(model.parameters(), lr=1e-4)
-    criterion = nn.CrossEntropyLoss()
-    scheduler = CosineAnnealingLR(optimizer, T_max=10)
-
-    # Prepare model, optimizer, data loaders, and scheduler for distributed training
-    model, optimizer, train_loader, val_loader, scheduler = accelerator.prepare(
-        model, optimizer, train_loader, val_loader, scheduler
-    )
-
-    # Check for the latest checkpoint
-    latest_checkpoint = "checkpoint_latest.pt"
-    if os.path.exists(latest_checkpoint):
-        logger.info(f"Found checkpoint: {latest_checkpoint}")
-        resume_from_checkpoint = latest_checkpoint
-    else:
-        logger.info("No checkpoint found, starting training from scratch")
-        resume_from_checkpoint = None
-
-    # Extend the context window of the model
-    extended_model = model.extend_context(
-        data=sequences,
-        target_length=2048000,
-        max_sequence_length=65536,
-        tokenizer=tokenizer,
-        population_size=64,
-        num_mutations=16,
-        num_crossovers=16,
-        max_iterations=10,
-    )
-
-    # Recover performance on shorter contexts
-    recovered_model = extended_model.recover_short_context(
-        data=sequences,
-        max_sequence_length=48192,
-        tokenizer=tokenizer,
-    )
-
-    # Create new optimizer and scheduler for the recovered model
-    optimizer = optim.AdamW(recovered_model.parameters(), lr=1e-4)
-    scheduler = CosineAnnealingLR(optimizer, T_max=10)
-
-    # Train the recovered model
-    train(
-        recovered_model,
-        train_loader,
-        val_loader,
-        optimizer,
-        criterion,
-        scheduler,
-        tokenizer,
-        resume_from_checkpoint=resume_from_checkpoint,
-    )
+    # Define sequence lengths for progressive training
+    sequence_lengths = [2048, 128000, 256000, 2048000]
+
+    for length in sequence_lengths:
+        logger.info(f"Training on sequence length: {length}")
+
+        # Set parameters for data preprocessing
+        max_length = min(length, 65536)
+        overlap = 4096
+
+        # Preprocess the data into sequences
+        logger.info(f"Preprocessing PG19 dataset for length {length}...")
+        sequences = []
+        for item in pg19_dataset:
+            text = item["text"]
+            sequences.extend(preprocess_data(text, tokenizer, max_length, overlap))
+        logger.info(f"Total sequences after preprocessing: {len(sequences)}")
+
+        # Create target sequences (shifted by one token)
+        targets = [seq[1:] + [tokenizer.eos_token_id] for seq in sequences]
+
+        # Validate that all target indices are within the vocabulary size
+        validate_targets(targets, tokenizer.vocab_size)
+
+        # Create a custom dataset from sequences and targets
+        dataset = CustomDataset(sequences, targets)
+
+        # Split the dataset into training and validation sets
+        train_size = int(0.8 * len(dataset))
+        val_size = len(dataset) - train_size
+        train_dataset, val_dataset = torch.utils.data.random_split(
+            dataset, [train_size, val_size]
+        )
+
+        # Create data loaders for training and validation
+        train_loader = DataLoader(
+            train_dataset, batch_size=8, shuffle=True, collate_fn=collate_fn
+        )
+        val_loader = DataLoader(val_dataset, batch_size=8, collate_fn=collate_fn)
+
+        # Initialize or extend the LongRoPE model based on the current sequence length
+        if length == 2048:
+            # Initialize the base LongRoPE model
+            model = LongRoPEModel(
+                d_model=4096,
+                n_heads=32,
+                num_layers=6,
+                vocab_size=tokenizer.vocab_size,
+                max_len=length,
+            )
+        else:
+            # Extend the context window of the model
+            model = model.extend_context(
+                data=sequences,
+                target_length=length,
+                max_sequence_length=max_length,
+                tokenizer=tokenizer,
+                population_size=64,
+                num_mutations=16,
+                num_crossovers=16,
+                max_iterations=10,
+            )
+
+        # Set up optimizer, loss function, and learning rate scheduler
+        optimizer = optim.AdamW(model.parameters(), lr=1e-4)
+        criterion = nn.CrossEntropyLoss()
+        scheduler = CosineAnnealingLR(optimizer, T_max=10)
+
+        # Prepare model, optimizer, data loaders, and scheduler for distributed training
+        model, optimizer, train_loader, val_loader, scheduler = accelerator.prepare(
+            model, optimizer, train_loader, val_loader, scheduler
+        )
+
+        # Check for the latest checkpoint specific to this sequence length
+        latest_checkpoint = f"checkpoint_latest_{length}.pt"
+        if os.path.exists(latest_checkpoint):
+            logger.info(f"Found checkpoint for length {length}: {latest_checkpoint}")
+            resume_from_checkpoint = latest_checkpoint
+        else:
+            logger.info(
+                f"No checkpoint found for length {length}, starting training from scratch"
+            )
+            resume_from_checkpoint = None
+
+        # Perform training or fine-tuning based on the current sequence length
+        if length in [128000, 256000]:
+            # Fine-tuning for specific steps as mentioned in the LongRoPE paper
+            fine_tune_steps = 400 if length == 128000 else 600
+            train(
+                model,
+                train_loader,
+                val_loader,
+                optimizer,
+                criterion,
+                scheduler,
+                tokenizer,
+                epochs=1,
+                gradient_accumulation_steps=fine_tune_steps // len(train_loader),
+                resume_from_checkpoint=resume_from_checkpoint,
+                max_steps=fine_tune_steps,
+            )
+        else:
+            # Regular training for other sequence lengths
+            train(
+                model,
+                train_loader,
+                val_loader,
+                optimizer,
+                criterion,
+                scheduler,
+                tokenizer,
+                resume_from_checkpoint=resume_from_checkpoint,
+            )
+
+        # Recover performance on shorter contexts after 256k extension
+        if length == 256000:
+            model = model.recover_short_context(
+                data=sequences,
+                max_sequence_length=48192,
+                tokenizer=tokenizer,
+            )

     # Finish logging and close the Weights & Biases run
     wandb.finish()
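A note on the helpers the new loop leans on: `preprocess_data` is called with `max_length` and `overlap` but is not part of this diff. A minimal sketch of a sliding-window chunker with those semantics, assuming it returns plain lists of token ids; the function body below is illustrative, not the repo's implementation:

```python
# Illustrative sketch only -- the repo's preprocess_data is defined elsewhere
# and may differ. Assumes it returns a list of token-id lists, which is what
# the training loop above expects.
def preprocess_data_sketch(text, tokenizer, max_length, overlap):
    token_ids = tokenizer.encode(text, add_special_tokens=False)

    # Consecutive windows share `overlap` tokens; guard the stride for the
    # 2048-token stage, where overlap (4096) exceeds max_length (2048).
    stride = max_length - overlap if max_length > overlap else max_length

    sequences = []
    for start in range(0, len(token_ids), stride):
        chunk = token_ids[start : start + max_length]
        if len(chunk) < 2:  # too short to form an (input, target) pair
            break
        sequences.append(chunk)
    return sequences
```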
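`validate_targets` is likewise outside the diff; judging by the call site, it only needs to confirm that every target id indexes a valid vocabulary row. A possible shape, stated purely as an assumption about its contract:

```python
# Assumed behaviour of validate_targets, inferred from the call site above;
# the actual implementation in the repo is not shown in this commit.
def validate_targets_sketch(targets, vocab_size):
    for i, seq in enumerate(targets):
        out_of_range = [t for t in seq if t < 0 or t >= vocab_size]
        if out_of_range:
            raise ValueError(
                f"Sequence {i} has {len(out_of_range)} token ids outside [0, {vocab_size})"
            )
```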

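The DataLoaders also pass a `collate_fn` that lives outside this diff. Because the preprocessed sequences can have different lengths, something along these lines would be needed; padding with GPT-2's eos id (50256) is an assumption here, since GPT-2 ships without a pad token, and the `(input_ids, target_ids)` item format is inferred from `CustomDataset(sequences, targets)`:

```python
import torch

# Illustrative collate_fn; the repo's version may differ. Assumes each dataset
# item is an (input_ids, target_ids) pair of equal-length Python lists.
def collate_fn_sketch(batch, pad_id=50256):  # 50256 == GPT-2 eos, reused as padding
    inputs, targets = zip(*batch)
    max_len = max(len(seq) for seq in inputs)

    def pad(seqs):
        return torch.tensor(
            [list(seq) + [pad_id] * (max_len - len(seq)) for seq in seqs],
            dtype=torch.long,
        )

    # Two [batch_size, max_len] LongTensors, ready for the training step.
    return pad(inputs), pad(targets)
```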

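Finally, the 128k/256k branch uses the step budgets from the LongRoPE paper (400 and 600 fine-tuning steps) and derives `gradient_accumulation_steps` from the loader size, with `epochs=1` and `max_steps` capping the run. A worked example of that arithmetic with a hypothetical batch count, since the real `len(train_loader)` depends on how many PG19 sequences survive preprocessing:

```python
# Hypothetical numbers, purely to illustrate the expression
# gradient_accumulation_steps = fine_tune_steps // len(train_loader).
fine_tune_steps = 400   # budget for the 128k stage (600 for the 256k stage)
num_batches = 50        # assumed len(train_loader) for this illustration

gradient_accumulation_steps = fine_tune_steps // num_batches
print(gradient_accumulation_steps)  # -> 8: gradients accumulate over 8 batches per update

# Edge case worth noting: if the loader yields more than fine_tune_steps batches,
# the floor division evaluates to 0, so a guard such as max(1, ...) may be needed.
```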