OOM trying to pretrain llama 7b on v4-256 #98
This is expected. A v4-256 actually has only 128 TPU v4 chips (an odd naming convention: before v4, the two TensorCores on a chip were counted as separate devices), so our OpenLLaMA 7B configuration actually uses a v4-512. If you want to train on a v4-256, consider using batch size 1024 and …
This results in …
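For reference, a minimal JAX sketch (not from this thread) showing why the device count matters: a v4-256 exposes 128 devices to JAX, one per chip, so a mesh laid out for the 256 devices of a v4-512 cannot be formed.

import jax

# On a TPU v4-256 slice this reports 128 (one device per chip);
# on a v4-512 it reports 256, which is what the OpenLLaMA 7B config assumes.
print(jax.device_count())
print(jax.devices()[0].platform)  # "tpu"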
It appears that, since a v4-256 has half the chips of a v4-512, the appropriate mesh topology would be …
Oh, this is a known problem: by default JAX does not want to split a physical axis into multiple logical axes. However, we can force it to do that by prefixing the mesh_dim with '!' (as in --mesh_dim='!-1,64,1' below).
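To make the forcing concrete, here is a minimal sketch (with assumed axis names, not EasyLM's actual mesh_dim parsing): the topology-aware mesh builder can reject shapes that split a physical axis across logical axes, whereas reshaping the flat device list directly always yields the requested mesh.

import numpy as np
import jax
from jax.sharding import Mesh
# from jax.experimental import mesh_utils  # default, topology-aware path

axis_names = ("dp", "fsdp", "mp")  # illustrative axis names

# Default path (can refuse shapes that split a physical axis):
#   devices = mesh_utils.create_device_mesh((2, 64, 1))
# Forced path: reshape the flat device list, ignoring physical topology.
devices = np.array(jax.devices()).reshape(2, 64, 1)  # 128 chips on a v4-256
mesh = Mesh(devices, axis_names)
print(mesh.shape)  # axis sizes: dp=2, fsdp=64, mp=1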
Still not working, even with the parameters you suggested for mesh_dim and batch_size.
Full command:
export LIBTPU_INIT_ARGS='--xla_jf_spmd_threshold_for_windowed_einsum_mib=0 --xla_tpu_spmd_threshold_for_allgather_cse=10000 --xla_tpu_spmd_rewrite_einsum_with_reshape=true --xla_enable_async_all_gather=true --jax_enable_async_collective_offload=true --xla_tpu_enable_latency_hiding_scheduler=true TPU_MEGACORE=MEGACORE_DENSE'
python -m EasyLM.models.llama.llama_train \
--mesh_dim='!-1,64,1' \
--dtype='fp32' \
--total_steps=250000 \
--log_freq=50 \
--save_model_freq=0 \
--save_milestone_freq=2500 \
--load_llama_config='7b' \
--update_llama_config='' \
--load_dataset_state='' \
--load_checkpoint='' \
--tokenizer.vocab_file='gs://.../tokenizer.model' \
--optimizer.type='adamw' \
--optimizer.adamw_optimizer.weight_decay=0.1 \
--optimizer.adamw_optimizer.lr=3e-4 \
--optimizer.adamw_optimizer.end_lr=3e-5 \
--optimizer.adamw_optimizer.lr_warmup_steps=2000 \
--optimizer.adamw_optimizer.lr_decay_steps=250000 \
--train_dataset.type='json' \
--train_dataset.text_processor.fields='text' \
--train_dataset.json_dataset.path='gs://.../slimpajama.jsonl' \
--train_dataset.json_dataset.seq_length=2048 \
--train_dataset.json_dataset.batch_size=1024 \
--train_dataset.json_dataset.tokenizer_processes=16 \
--checkpointer.save_optimizer_state=True \
--logger.online=True \
--logger.prefix='devingulliver' \
--logger.project="slender_llama_7b" \
--logger.output_dir="gs://.../output/" \
--logger.wandb_dir="$HOME/experiment_output/slender_llama_7b" \
|& tee $HOME/output.txt
Full output log
This is quite strange. Maybe XLA is not smart enough about allocating memory. In that case I'd recommend tweaking the batch size and the mesh size. For example, try an even smaller batch size of 512, or use …
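To make that tweaking concrete, a purely illustrative sketch (these are not verified working configurations) that lists the mesh_dim factorizations of the 128 devices one could try, paired with the two batch sizes mentioned in this thread:

import itertools

# Enumerate (dp, fsdp, mp) shapes whose product is the 128 devices of a
# v4-256, and print the corresponding flag combinations to try.
N_DEVICES = 128
for dp, fsdp, mp in itertools.product((1, 2, 4), (16, 32, 64, 128), (1, 2, 4, 8)):
    if dp * fsdp * mp != N_DEVICES:
        continue
    for batch in (512, 1024):
        print(f"--mesh_dim='!{dp},{fsdp},{mp}' --train_dataset.json_dataset.batch_size={batch}")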
I was able to run 7B model training on TPU v4-256 with …
Command
Log