OOM trying to pretrain llama 7b on v4-256 #98

Open
redbrain opened this issue Oct 8, 2023 · 7 comments

Comments


redbrain commented Oct 8, 2023

Command
python -m EasyLM.models.llama.llama_train \
    --mesh_dim='-1,32,1' \
    --dtype='fp32' \
    --total_steps=250000 \
    --log_freq=50 \
    --save_model_freq=0 \
    --save_milestone_freq=2500 \
    --load_llama_config='7b' \
    --update_llama_config='' \
    --load_dataset_state='' \
    --load_checkpoint='' \
    --tokenizer.vocab_file='gs://.../tokenizer.model' \
    --optimizer.type='adamw' \
    --optimizer.adamw_optimizer.weight_decay=0.1 \
    --optimizer.adamw_optimizer.lr=3e-4 \
    --optimizer.adamw_optimizer.end_lr=3e-5 \
    --optimizer.adamw_optimizer.lr_warmup_steps=2000 \
    --optimizer.adamw_optimizer.lr_decay_steps=250000 \
    --train_dataset.type='json' \
    --train_dataset.text_processor.fields='text' \
    --train_dataset.json_dataset.path='gs://.../dataset.jsonl' \
    --train_dataset.json_dataset.seq_length=2048 \
    --train_dataset.json_dataset.batch_size=2048 \
    --train_dataset.json_dataset.tokenizer_processes=16 \
    --checkpointer.save_optimizer_state=True \
    --logger.online=True \
    --logger.prefix='devingulliver' \
    --logger.project="sl_llama_7b" \
    --logger.output_dir="gs://.../output/" \
    --logger.wandb_dir="$HOME/experiment_output/sl_llama_7b"
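Every shape in the log has a leading dimension of 16, and that number can be reproduced from the flags above. The following is a hypothetical back-of-the-envelope check, not part of EasyLM. It rests on three assumptions: a v4-256 slice exposes 128 JAX devices (one per chip), `--mesh_dim` lists the (dp, fsdp, mp) axes with `-1` meaning "infer from the device count", and the global batch is split evenly over dp × fsdp. The LLaMA 7B sizes (32 heads, feed-forward width 11008) are read off the logged shapes.

```python
# Hypothetical sanity check: recompute the per-device fp32 buffer sizes
# reported in the XLA OOM log from the command-line flags.
# Assumptions (not taken from EasyLM): 128 JAX devices on a v4-256 slice,
# mesh axes ordered (dp, fsdp, mp), batch sharded evenly over dp * fsdp.

GIB = 1024 ** 3
BYTES_F32 = 4  # --dtype='fp32'


def resolve_mesh(mesh_dims, num_devices):
    """Replace a single -1 entry so the mesh product equals num_devices."""
    known = 1
    for d in mesh_dims:
        if d != -1:
            known *= d
    assert num_devices % known == 0, "mesh does not divide the device count"
    return tuple(num_devices // known if d == -1 else d for d in mesh_dims)


def per_device_sizes(num_devices=128, mesh_dims=(-1, 32, 1), global_batch=2048,
                     seq_len=2048, n_heads=32, ffn_dim=11008):
    dp, fsdp, mp = resolve_mesh(mesh_dims, num_devices)
    local_batch = global_batch // (dp * fsdp)
    # Attention weights per layer: [batch, heads, q_len, k_len]
    attn_bytes = local_batch * n_heads * seq_len * seq_len * BYTES_F32
    # Feed-forward w1/w3 output per layer: [batch, seq_len, ffn_dim]
    ffn_bytes = local_batch * seq_len * ffn_dim * BYTES_F32
    return (dp, fsdp, mp), local_batch, attn_bytes / GIB, ffn_bytes / GIB


mesh, local_batch, attn_gib, ffn_gib = per_device_sizes()
print(mesh, local_batch, f"{attn_gib:.2f}", f"{ffn_gib:.2f}")
```

Under these assumptions, the script prints `(4, 32, 1) 16 8.00 1.34`, which matches the 8.00G attention and 1.34G feed-forward allocations in the log. Each of these fp32 buffers scales linearly with the per-device batch.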
Log
I1008 02:58:09.536914 139894565414912 mesh_utils.py:282] _create_device_mesh_for_nd_torus assignment: [(1,), (0, 2), ()]
  0% 0/250000 [02:56<?, ?it/s]
Traceback (most recent call last):
  File "/home/redbrain/EasyLM/EasyLM/models/llama/llama_train.py", line 267, in <module>
    mlxu.run(main)
  File "/home/redbrain/.local/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/redbrain/.local/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/home/redbrain/EasyLM/EasyLM/models/llama/llama_train.py", line 235, in main
    train_state, sharded_rng, metrics = sharded_train_step(
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 250, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 163, in _python_pjit_helper
    out_flat = pjit_p.bind(*args_flat, **params)
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/core.py", line 2677, in bind
    return self.bind_with_trace(top_trace, args, params)
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/core.py", line 383, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/core.py", line 815, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 1203, in _pjit_call_impl
    return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums,
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 1187, in call_impl_cache_miss
    out_flat, compiled = _pjit_call_impl_python(
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 1123, in _pjit_call_impl_python
    always_lower=False, lowering_platform=None).compile()
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 2323, in compile
    executable = UnloadedMeshExecutable.from_hlo(
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 2645, in from_hlo
    xla_executable, compile_options = _cached_compilation(
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 2555, in _cached_compilation
    xla_executable = dispatch.compile_or_get_cached(
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/dispatch.py", line 497, in compile_or_get_cached
    return backend_compile(backend, computation, compile_options,
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/dispatch.py", line 465, in backend_compile
    return backend.compile(built_c, compile_options=options)
jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: XLA:TPU compile permanent error. Ran out of memory in memory space hbm. Used 163.10G of 30.75G hbm. Exceeded hbm capacity by 132.35G.
Total hbm usage >= 164.35G:
    reserved          1.25G
    program         163.10G
    arguments            0B
Output size 0B; shares 0B with arguments.
Program hbm requirement 163.10G:
    global           20.85M
    scoped            1.19M
    HLO temp        163.08G (100.0% utilization: Unpadded (157.97G) Padded (157.97G), 3.1% fragmentation (5.11G))
  Largest program allocations in hbm:
  1. Size: 8.00G
     Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/22/attention/sub" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/attention.py" source_line=108
     Shape: f32[16,32,2048,2048]{2,3,1,0:T(8,128)}
     Unpadded size: 8.00G
     XLA label: fusion.8979.remat4 = fusion(fusion.112.remat), kind=kOutput, calls=fused_computation.6065.clone.clone.clone.clone
     Allocation type: HLO temp
     ==========================
  2. Size: 8.00G
     Operator: op_name="pjit(train_step)/jit(main)/transpose(jvp(FlaxLLaMAForCausalLMModule))/transformer/h/22/attention/...hqk,...khd->...qhd/dot_general[dimension_numbers=(((2,), (3,)), ((0, 1), (0, 2))) precision=None preferred_element_type=None]" source_file="/home/redbrain/EasyLM/EasyLM/models/llama/llama_model.py" source_line=569
     Shape: f32[16,32,2048,2048]{2,3,1,0:T(8,128)}
     Unpadded size: 8.00G
     XLA label: fusion.1599.remat = fusion(bitcast.2905, reshape.13130), kind=kOutput, calls=fused_computation.1400.clone
     Allocation type: HLO temp
     ==========================
  3. Size: 1.34G
     Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/19/feed_forward/w3/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206
     Shape: f32[16,2048,11008]{2,1,0:T(8,128)}
     Unpadded size: 1.34G
     XLA label: fusion.1888.remat4 = fusion(fusion.942.remat7.1.remat, copy-done.2890, copy-done.2506, all-gather.568.remat2), kind=kOutput, calls=fused_computation.1611.clone.clone.clone.clone
     Allocation type: HLO temp
     ==========================
  4. Size: 1.34G
     Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/29/feed_forward/w3/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206
     Shape: f32[16,2048,11008]{2,1,0:T(8,128)}
     Unpadded size: 1.34G
     XLA label: fusion.1968.remat4 = fusion(fusion.922.remat7.1.remat, copy-done.2812, copy-done.2536, all-gather.638.remat2), kind=kOutput, calls=fused_computation.1651.clone.clone.clone.clone
     Allocation type: HLO temp
     ==========================
  5. Size: 1.34G
     Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/18/feed_forward/w1/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206
     Shape: f32[16,2048,11008]{2,1,0:T(8,128)}
     Unpadded size: 1.34G
     XLA label: fusion.1882.remat4 = fusion(fusion.944.remat5.1.remat, copy-done.2897, copy-done.2503, copy-done.141), kind=kOutput, calls=fused_computation.1608.clone.clone.clone.clone
     Allocation type: HLO temp
     ==========================
  6. Size: 1.34G
     Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/18/feed_forward/w3/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206
     Shape: f32[16,2048,11008]{2,1,0:T(8,128)}
     Unpadded size: 1.34G
     XLA label: fusion.1880.remat4 = fusion(fusion.944.remat7.1.remat, copy-done.2897, copy-done.2503, all-gather.561.remat2), kind=kOutput, calls=fused_computation.1607.clone.clone.clone.clone
     Allocation type: HLO temp
     ==========================
  7. Size: 1.34G
     Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/19/feed_forward/w1/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206
     Shape: f32[16,2048,11008]{2,1,0:T(8,128)}
     Unpadded size: 1.34G
     XLA label: fusion.1890.remat4 = fusion(fusion.942.remat5.1.remat, copy-done.2890, copy-done.2506, copy-done.147), kind=kOutput, calls=fused_computation.1612.clone.clone.clone.clone
     Allocation type: HLO temp
     ==========================
  8. Size: 1.34G
     Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/20/feed_forward/w3/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206
     Shape: f32[16,2048,11008]{2,1,0:T(8,128)}
     Unpadded size: 1.34G
     XLA label: fusion.1896.remat4 = fusion(fusion.940.remat7.1.remat, copy-done.2875, copy-done.2509, all-gather.575.remat2), kind=kOutput, calls=fused_computation.1615.clone.clone.clone.clone
     Allocation type: HLO temp
     ==========================
  9. Size: 1.34G
     Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/20/feed_forward/w1/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206
     Shape: f32[16,2048,11008]{2,1,0:T(8,128)}
     Unpadded size: 1.34G
     XLA label: fusion.1898.remat4 = fusion(fusion.940.remat5.1.remat, copy-done.2875, copy-done.2509, copy-done.153), kind=kOutput, calls=fused_computation.1616.clone.clone.clone.clone
     Allocation type: HLO temp
     ==========================
  10. Size: 1.34G
     Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/17/feed_forward/w1/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206
     Shape: f32[16,2048,11008]{2,1,0:T(8,128)}
     Unpadded size: 1.34G
     XLA label: fusion.1874.remat4 = fusion(fusion.946.remat5.1.remat, copy-done.2904, copy-done.2500, copy-done.135), kind=kOutput, calls=fused_computation.1604.clone.clone.clone.clone
     Allocation type: HLO temp
     ==========================
  11. Size: 1.34G
     Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/17/feed_forward/w3/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206
     Shape: f32[16,2048,11008]{2,1,0:T(8,128)}
     Unpadded size: 1.34G
     XLA label: fusion.1872.remat4 = fusion(fusion.946.remat7.1.remat, copy-done.2904, copy-done.2500, all-gather.554.remat2), kind=kOutput, calls=fused_computation.1603.clone.clone.clone.clone
     Allocation type: HLO temp
     ==========================
  12. Size: 1.34G
     Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/6/feed_forward/w1/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206
     Shape: f32[16,2048,11008]{2,1,0:T(8,128)}
     Unpadded size: 1.34G
     XLA label: fusion.1786.remat4 = fusion(fusion.968.remat5.1.remat, copy-done.2768, copy-done.2426, copy-done.62), kind=kOutput, calls=fused_computation.1560.clone.clone.clone.clone
     Allocation type: HLO temp
     ==========================
  13. Size: 1.34G
     Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/6/feed_forward/w3/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206
     Shape: f32[16,2048,11008]{2,1,0:T(8,128)}
     Unpadded size: 1.34G
     XLA label: fusion.1784.remat4 = fusion(fusion.968.remat7.1.remat, copy-done.2768, copy-done.2426, all-gather.477.remat2), kind=kOutput, calls=fused_computation.1559.clone.clone.clone.clone
     Allocation type: HLO temp
     ==========================
  14. Size: 1.34G
     Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/21/feed_forward/w3/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206
     Shape: f32[16,2048,11008]{2,1,0:T(8,128)}
     Unpadded size: 1.34G
     XLA label: fusion.1904.remat4 = fusion(fusion.938.remat7.1.remat, copy-done.2868, copy-done.2512, all-gather.582.remat2), kind=kOutput, calls=fused_computation.1619.clone.clone.clone.clone
     Allocation type: HLO temp
     ==========================
  15. Size: 1.34G
     Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/21/feed_forward/w1/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206
     Shape: f32[16,2048,11008]{2,1,0:T(8,128)}
     Unpadded size: 1.34G
     XLA label: fusion.1906.remat4 = fusion(fusion.938.remat5.1.remat, copy-done.2868, copy-done.2512, copy-done.159), kind=kOutput, calls=fused_computation.1620.clone.clone.clone.clone
     Allocation type: HLO temp
     ==========================
  16. Size: 1.34G
     Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/16/feed_forward/w1/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206
     Shape: f32[16,2048,11008]{2,1,0:T(8,128)}
     Unpadded size: 1.34G
     XLA label: fusion.1866.remat4 = fusion(fusion.948.remat5.1.remat, copy-done.2911, copy-done.2497, copy-done.129), kind=kOutput, calls=fused_computation.1600.clone.clone.clone.clone
     Allocation type: HLO temp
     ==========================
  17. Size: 1.34G
     Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/16/feed_forward/w3/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206
     Shape: f32[16,2048,11008]{2,1,0:T(8,128)}
     Unpadded size: 1.34G
     XLA label: fusion.1864.remat4 = fusion(fusion.948.remat7.1.remat, copy-done.2911, copy-done.2497, all-gather.547.remat2), kind=kOutput, calls=fused_computation.1599.clone.clone.clone.clone
     Allocation type: HLO temp
     ==========================
  18. Size: 1.34G
     Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/22/feed_forward/w3/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206
     Shape: f32[16,2048,11008]{2,1,0:T(8,128)}
     Unpadded size: 1.34G
     XLA label: fusion.1912.remat4 = fusion(fusion.936.remat7.1.remat, copy-done.2861, copy-done.2515, all-gather.589.remat2), kind=kOutput, calls=fused_computation.1623.clone.clone.clone.clone
     Allocation type: HLO temp
     ==========================
  19. Size: 1.34G
     Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/22/feed_forward/w1/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206
     Shape: f32[16,2048,11008]{2,1,0:T(8,128)}
     Unpadded size: 1.34G
     XLA label: fusion.1914.remat4 = fusion(fusion.936.remat5.1.remat, copy-done.2861, copy-done.2515, copy-done.165), kind=kOutput, calls=fused_computation.1624.clone.clone.clone.clone
     Allocation type: HLO temp
     ==========================
  20. Size: 1.34G
     Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/15/feed_forward/w1/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206
     Shape: f32[16,2048,11008]{2,1,0:T(8,128)}
     Unpadded size: 1.34G
     XLA label: fusion.1858.remat4 = fusion(fusion.950.remat5.1.remat, copy-done.2918, copy-done.2482, copy-done.123), kind=kOutput, calls=fused_computation.1596.clone.clone.clone.clone
     Allocation type: HLO temp
     ==========================
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/redbrain/EasyLM/EasyLM/models/llama/llama_train.py", line 267, in <module>
    mlxu.run(main)
  File "/home/redbrain/.local/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/redbrain/.local/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/home/redbrain/EasyLM/EasyLM/models/llama/llama_train.py", line 235, in main
    train_state, sharded_rng, metrics = sharded_train_step(
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: XLA:TPU compile permanent error. Ran out of memory in memory space hbm. Used 163.10G of 30.75G hbm. Exceeded hbm capacity by 132.35G.
Traceback (most recent call last):
  File "/home/redbrain/EasyLM/EasyLM/models/llama/llama_train.py", line 267, in <module>
    mlxu.run(main)
  File "/home/redbrain/.local/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/redbrain/.local/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/home/redbrain/EasyLM/EasyLM/models/llama/llama_train.py", line 235, in main
    train_state, sharded_rng, metrics = sharded_train_step(
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 250, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 163, in _python_pjit_helper
    out_flat = pjit_p.bind(*args_flat, **params)
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/core.py", line 2677, in bind
    return self.bind_with_trace(top_trace, args, params)
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/core.py", line 383, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/core.py", line 815, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 1203, in _pjit_call_impl
    return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums,
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 1187, in call_impl_cache_miss
    out_flat, compiled = _pjit_call_impl_python(
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 1123, in _pjit_call_impl_python
    always_lower=False, lowering_platform=None).compile()
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 2323, in compile
    executable = UnloadedMeshExecutable.from_hlo(
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 2645, in from_hlo
    xla_executable, compile_options = _cached_compilation(
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 2555, in _cached_compilation
    xla_executable = dispatch.compile_or_get_cached(
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/dispatch.py", line 497, in compile_or_get_cached
    return backend_compile(backend, computation, compile_options,
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/dispatch.py", line 465, in backend_compile
    return backend.compile(built_c, compile_options=options)
jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: XLA:TPU compile permanent error. Ran out of memory in memory space hbm. Used 163.10G of 30.75G hbm. Exceeded hbm capacity by 132.35G.
Total hbm usage >= 164.35G:
    reserved          1.25G
    program         163.10G
    arguments            0B
Output size 0B; shares 0B with arguments.
Program hbm requirement 163.10G:
    global           20.85M
    scoped            1.19M
    HLO temp        163.08G (100.0% utilization: Unpadded (157.97G) Padded (157.97G), 3.1% fragmentation (5.11G))
  Largest program allocations in hbm:
  1. Size: 8.00G
     Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/22/attention/sub" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/attention.py" source_line=108
     Shape: f32[16,32,2048,2048]{2,3,1,0:T(8,128)}
     Unpadded size: 8.00G
     XLA label: fusion.8979.remat4 = fusion(fusion.112.remat), kind=kOutput, calls=fused_computation.6065.clone.clone.clone.clone
     Allocation type: HLO temp
     ==========================
  2. Size: 8.00G
     Operator: op_name="pjit(train_step)/jit(main)/transpose(jvp(FlaxLLaMAForCausalLMModule))/transformer/h/22/attention/...hqk,...khd->...qhd/dot_general[dimension_numbers=(((2,), (3,)), ((0, 1), (0, 2))) precision=None preferred_element_type=None]" source_file="/home/redbrain/EasyLM/EasyLM/models/llama/llama_model.py" source_line=569
     Shape: f32[16,32,2048,2048]{2,3,1,0:T(8,128)}
     Unpadded size: 8.00G
     XLA label: fusion.1599.remat = fusion(bitcast.2905, reshape.13130), kind=kOutput, calls=fused_computation.1400.clone
     Allocation type: HLO temp
     ==========================
  3. Size: 1.34G
     Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/19/feed_forward/w3/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206
     Shape: f32[16,2048,11008]{2,1,0:T(8,128)}
     Unpadded size: 1.34G
     XLA label: fusion.1888.remat4 = fusion(fusion.942.remat7.1.remat, copy-done.2890, copy-done.2506, all-gather.568.remat2), kind=kOutput, calls=fused_computation.1611.clone.clone.clone.clone
     Allocation type: HLO temp
     ==========================
  4. Size: 1.34G
     Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/29/feed_forward/w3/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206
     Shape: f32[16,2048,11008]{2,1,0:T(8,128)}
     Unpadded size: 1.34G
     XLA label: fusion.1968.remat4 = fusion(fusion.922.remat7.1.remat, copy-done.2812, copy-done.2536, all-gather.638.remat2), kind=kOutput, calls=fused_computation.1651.clone.clone.clone.clone
     Allocation type: HLO temp
     ==========================
  5. Size: 1.34G
     Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/18/feed_forward/w1/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206
     Shape: f32[16,2048,11008]{2,1,0:T(8,128)}
     Unpadded size: 1.34G
     XLA label: fusion.1882.remat4 = fusion(fusion.944.remat5.1.remat, copy-done.2897, copy-done.2503, copy-done.141), kind=kOutput, calls=fused_computation.1608.clone.clone.clone.clone
     Allocation type: HLO temp
     ==========================
  6. Size: 1.34G
     Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/18/feed_forward/w3/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206
     Shape: f32[16,2048,11008]{2,1,0:T(8,128)}
     Unpadded size: 1.34G
     XLA label: fusion.1880.remat4 = fusion(fusion.944.remat7.1.remat, copy-done.2897, copy-done.2503, all-gather.561.remat2), kind=kOutput, calls=fused_computation.1607.clone.clone.clone.clone
     Allocation type: HLO temp
     ==========================
  7. Size: 1.34G
     Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/19/feed_forward/w1/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206
     Shape: f32[16,2048,11008]{2,1,0:T(8,128)}
     Unpadded size: 1.34G
     XLA label: fusion.1890.remat4 = fusion(fusion.942.remat5.1.remat, copy-done.2890, copy-done.2506, copy-done.147), kind=kOutput, calls=fused_computation.1612.clone.clone.clone.clone
     Allocation type: HLO temp
     ==========================
  8. Size: 1.34G
     Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/20/feed_forward/w3/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206
     Shape: f32[16,2048,11008]{2,1,0:T(8,128)}
     Unpadded size: 1.34G
     XLA label: fusion.1896.remat4 = fusion(fusion.940.remat7.1.remat, copy-done.2875, copy-done.2509, all-gather.575.remat2), kind=kOutput, calls=fused_computation.1615.clone.clone.clone.clone
     Allocation type: HLO temp
     ==========================
  9. Size: 1.34G
     Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/20/feed_forward/w1/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206
     Shape: f32[16,2048,11008]{2,1,0:T(8,128)}
     Unpadded size: 1.34G
     XLA label: fusion.1898.remat4 = fusion(fusion.940.remat5.1.remat, copy-done.2875, copy-done.2509, copy-done.153), kind=kOutput, calls=fused_computation.1616.clone.clone.clone.clone
     Allocation type: HLO temp
     ==========================
  10. Size: 1.34G
     Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/17/feed_forward/w1/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206
     Shape: f32[16,2048,11008]{2,1,0:T(8,128)}
     Unpadded size: 1.34G
     XLA label: fusion.1874.remat4 = fusion(fusion.946.remat5.1.remat, copy-done.2904, copy-done.2500, copy-done.135), kind=kOutput, calls=fused_computation.1604.clone.clone.clone.clone
     Allocation type: HLO temp
     ==========================
  11. Size: 1.34G
     Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/17/feed_forward/w3/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206
     Shape: f32[16,2048,11008]{2,1,0:T(8,128)}
     Unpadded size: 1.34G
     XLA label: fusion.1872.remat4 = fusion(fusion.946.remat7.1.remat, copy-done.2904, copy-done.2500, all-gather.554.remat2), kind=kOutput, calls=fused_computation.1603.clone.clone.clone.clone
     Allocation type: HLO temp
     ==========================
  12. Size: 1.34G
     Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/6/feed_forward/w1/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206
     Shape: f32[16,2048,11008]{2,1,0:T(8,128)}
     Unpadded size: 1.34G
     XLA label: fusion.1786.remat4 = fusion(fusion.968.remat5.1.remat, copy-done.2768, copy-done.2426, copy-done.62), kind=kOutput, calls=fused_computation.1560.clone.clone.clone.clone
     Allocation type: HLO temp
     ==========================
  13. Size: 1.34G
     Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/6/feed_forward/w3/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206
     Shape: f32[16,2048,11008]{2,1,0:T(8,128)}
     Unpadded size: 1.34G
     XLA label: fusion.1784.remat4 = fusion(fusion.968.remat7.1.remat, copy-done.2768, copy-done.2426, all-gather.477.remat2), kind=kOutput, calls=fused_computation.1559.clone.clone.clone.clone
     Allocation type: HLO temp
     ==========================
  14. Size: 1.34G
     Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/21/feed_forward/w3/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206
     Shape: f32[16,2048,11008]{2,1,0:T(8,128)}
     Unpadded size: 1.34G
     XLA label: fusion.1904.remat4 = fusion(fusion.938.remat7.1.remat, copy-done.2868, copy-done.2512, all-gather.582.remat2), kind=kOutput, calls=fused_computation.1619.clone.clone.clone.clone
     Allocation type: HLO temp
     ==========================
  15. Size: 1.34G
     Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/21/feed_forward/w1/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206
     Shape: f32[16,2048,11008]{2,1,0:T(8,128)}
     Unpadded size: 1.34G
     XLA label: fusion.1906.remat4 = fusion(fusion.938.remat5.1.remat, copy-done.2868, copy-done.2512, copy-done.159), kind=kOutput, calls=fused_computation.1620.clone.clone.clone.clone
     Allocation type: HLO temp
     ==========================
  16. Size: 1.34G
     Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/16/feed_forward/w1/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206
     Shape: f32[16,2048,11008]{2,1,0:T(8,128)}
     Unpadded size: 1.34G
     XLA label: fusion.1866.remat4 = fusion(fusion.948.remat5.1.remat, copy-done.2911, copy-done.2497, copy-done.129), kind=kOutput, calls=fused_computation.1600.clone.clone.clone.clone
     Allocation type: HLO temp
     ==========================
  17. Size: 1.34G
     Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/16/feed_forward/w3/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206
     Shape: f32[16,2048,11008]{2,1,0:T(8,128)}
     Unpadded size: 1.34G
     XLA label: fusion.1864.remat4 = fusion(fusion.948.remat7.1.remat, copy-done.2911, copy-done.2497, all-gather.547.remat2), kind=kOutput, calls=fused_computation.1599.clone.clone.clone.clone
     Allocation type: HLO temp
     ==========================
  18. Size: 1.34G
     Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/22/feed_forward/w3/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206
     Shape: f32[16,2048,11008]{2,1,0:T(8,128)}
     Unpadded size: 1.34G
     XLA label: fusion.1912.remat4 = fusion(fusion.936.remat7.1.remat, copy-done.2861, copy-done.2515, all-gather.589.remat2), kind=kOutput, calls=fused_computation.1623.clone.clone.clone.clone
     Allocation type: HLO temp
     ==========================
  19. Size: 1.34G
     Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/22/feed_forward/w1/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206
     Shape: f32[16,2048,11008]{2,1,0:T(8,128)}
     Unpadded size: 1.34G
     XLA label: fusion.1914.remat4 = fusion(fusion.936.remat5.1.remat, copy-done.2861, copy-done.2515, copy-done.165), kind=kOutput, calls=fused_computation.1624.clone.clone.clone.clone
     Allocation type: HLO temp
     ==========================
  20. Size: 1.34G
     Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/15/feed_forward/w1/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206
     Shape: f32[16,2048,11008]{2,1,0:T(8,128)}
     Unpadded size: 1.34G
     XLA label: fusion.1858.remat4 = fusion(fusion.950.remat5.1.remat, copy-done.2918, copy-done.2482, copy-done.123), kind=kOutput, calls=fused_computation.1596.clone.clone.clone.clone
     Allocation type: HLO temp
     ==========================
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/redbrain/EasyLM/EasyLM/models/llama/llama_train.py", line 267, in <module>
    mlxu.run(main)
  File "/home/redbrain/.local/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/redbrain/.local/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/home/redbrain/EasyLM/EasyLM/models/llama/llama_train.py", line 235, in main
    train_state, sharded_rng, metrics = sharded_train_step(
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: XLA:TPU compile permanent error. Ran out of memory in memory space hbm. Used 163.10G of 30.75G hbm. Exceeded hbm capacity by 132.35G.
@young-geng
Owner

This is expected. A v4-256 is actually only 128 TPU-v4 chips: the number in the name counts TensorCores, a naming convention left over from before v4, when the two TensorCores on each chip were exposed as separate devices. Our OpenLLaMA 7B configuration actually uses a v4-512. If you want to train on a v4-256, consider using batch size 1024 and --mesh_dim='-1,64,1'.
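The arithmetic behind this suggestion can be sketched in a few lines. This is a minimal sketch, assuming the slice name counts two TensorCores per v4 chip, that a single -1 in the mesh spec is inferred from the device count, and that the global batch is split evenly over every chip. The helpers chips_in_v4_slice, resolve_mesh_dims, and per_device_batch are hypothetical names for illustration, not EasyLM functions.

```python
import math


def chips_in_v4_slice(slice_name: str) -> int:
    # Hypothetical helper: "v4-N" counts TensorCores, and each v4 chip has two.
    generation, cores = slice_name.split("-")
    if generation != "v4":
        raise ValueError("this sketch only covers TPU v4 slice names")
    return int(cores) // 2


def resolve_mesh_dims(mesh_dim: str, n_devices: int) -> tuple:
    # Replace a single -1 with whatever makes the product equal n_devices.
    dims = [int(d) for d in mesh_dim.split(",")]
    known = math.prod(d for d in dims if d != -1)
    if dims.count(-1) > 1 or n_devices % known:
        raise ValueError(f"{mesh_dim!r} does not fit {n_devices} devices")
    resolved = tuple(n_devices // known if d == -1 else d for d in dims)
    if math.prod(resolved) != n_devices:
        raise ValueError(f"{mesh_dim!r} does not fit {n_devices} devices")
    return resolved


def per_device_batch(global_batch: int, n_devices: int) -> int:
    # Assumes the batch is sharded over every device (all data/FSDP axes).
    if global_batch % n_devices:
        raise ValueError("global batch must divide evenly across devices")
    return global_batch // n_devices


chips = chips_in_v4_slice("v4-256")
print(chips)  # 128
print(resolve_mesh_dims("-1,64,1", chips))  # (2, 64, 1)
print(per_device_batch(1024, chips))  # 8
```

Under the same assumption, the original batch size of 2048 gives 16 sequences per chip, which matches the leading dimension of the f32[16,2048,11008] allocations in the first log.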

@redbrain
Author

redbrain commented Oct 23, 2023

This results in NotImplementedError: Failed to find assignment for logical_axis_index 1 of size 64 with remaining assignable mesh [4, 4, 8]. Any clue what went wrong?
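The error message says the 128 chips form a 4 x 4 x 8 physical mesh, and topology-aware mesh construction tries to build each logical axis out of whole physical axes. The sketch below is a simplified, hypothetical re-implementation of that idea (not JAX's actual code). It shows why 32 fits (4 x 8) while 64 does not: no product of whole axes from 4, 4, and 8 equals 64. For -1,32,1 it happens to produce the same [(1,), (0, 2), ()] assignment printed in the first log.

```python
from itertools import combinations
from math import prod


def assign_logical_axes(physical, logical):
    """Hypothetical sketch: map each logical mesh axis onto whole physical axes.

    Raises ValueError when a logical axis size is not a product of
    still-unassigned physical axis sizes (i.e. it would need a split).
    """
    remaining = list(range(len(physical)))
    assignment = [()] * len(logical)
    # Visit logical axes from last to first, trying the smallest combinations first.
    for idx in reversed(range(len(logical))):
        size = logical[idx]
        if size == 1:
            continue  # trivially satisfied, uses no physical axis
        match = next(
            (combo
             for k in range(1, len(remaining) + 1)
             for combo in combinations(remaining, k)
             if prod(physical[i] for i in combo) == size),
            None,
        )
        if match is None:
            left = [physical[i] for i in remaining]
            raise ValueError(
                f"logical axis {idx} of size {size} does not fit remaining mesh {left}")
        assignment[idx] = match
        remaining = [i for i in remaining if i not in match]
    return assignment


torus = [4, 4, 8]  # v4-256 physical mesh, as reported in the error message
print(assign_logical_axes(torus, [4, 32, 1]))  # [(1,), (0, 2), ()]
try:
    assign_logical_axes(torus, [2, 64, 1])
except ValueError as err:
    print(err)  # logical axis 1 of size 64 does not fit remaining mesh [4, 4, 8]
```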

@redbrain
Author

redbrain commented Nov 6, 2023

Since a v4-256 has half the chips of a v4-512, it appears the appropriate mesh shape would be -1,32,1. But running with that mesh, and with batch sizes of 1024 and even 512, still produces OOM errors. Any advice on how to fix this?
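The individual allocation sizes in the OOM dumps follow directly from the per-chip tensor shapes. A back-of-the-envelope sketch, assuming fp32 activations (4 bytes per element, as with --dtype='fp32') and the global batch split evenly over all 128 chips, with the head count (32), sequence length (2048), and FFN width (11008) taken from the shapes in the logs:

```python
def tensor_bytes(shape, bytes_per_element=4):
    # Size of a dense tensor; fp32 uses 4 bytes per element.
    n = 1
    for dim in shape:
        n *= dim
    return n * bytes_per_element


GiB, MiB = 2**30, 2**20
n_chips, seq_len, n_heads, ffn_dim = 128, 2048, 32, 11008  # values seen in the logs

for global_batch in (2048, 1024, 512):
    b = global_batch // n_chips  # local batch, assuming an even split over all chips
    scores = tensor_bytes((b, n_heads, seq_len, seq_len))  # attention scores
    ffn = tensor_bytes((b, seq_len, ffn_dim))  # one feed-forward activation
    print(f"batch {global_batch}: local {b}, "
          f"attention scores {scores / GiB:.2f} GiB, FFN activation {ffn / MiB:.0f} MiB")
# batch 2048: local 16, attention scores 8.00 GiB, FFN activation 1376 MiB
# batch 1024: local 8, attention scores 4.00 GiB, FFN activation 688 MiB
# batch 512: local 4, attention scores 2.00 GiB, FFN activation 344 MiB
```

Halving the batch halves every entry, but since many such temporaries are live at once and the OOM persists even at batch 512, shrinking the batch alone does not appear to be enough here.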

@young-geng
Owner

Oh, this is a known issue: by default, JAX will not split a physical axis across multiple logical axes. However, we can force it to do so by specifying --mesh_dim='!-1,64,1'.
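A minimal sketch of what a forced split amounts to, assuming the ! prefix simply swaps topology-aware construction (jax.experimental.mesh_utils.create_device_mesh) for a plain row-major reshape of jax.devices(). The function build_mesh and the axis names dp, fsdp, and mp are hypothetical here; EasyLM's own mesh helper may differ in detail.

```python
import jax
import numpy as np
from jax.experimental import mesh_utils
from jax.sharding import Mesh


def build_mesh(axis_dims: str, names=("dp", "fsdp", "mp")) -> Mesh:
    # Hypothetical helper: a leading "!" allows a logical axis to cut across physical axes.
    force_split = axis_dims.startswith("!")
    if force_split:
        axis_dims = axis_dims[1:]
    dims = [int(d) for d in axis_dims.split(",")]
    # Let numpy resolve a single -1 against the total device count.
    mesh_shape = np.arange(jax.device_count()).reshape(dims).shape
    if force_split:
        # Plain reshape of the flat device list: always succeeds, ignores the torus layout.
        devices = np.array(jax.devices()).reshape(mesh_shape)
    else:
        # Topology-aware: on TPU this can fail when a logical axis
        # cannot be built from whole physical axes.
        devices = mesh_utils.create_device_mesh(mesh_shape)
    return Mesh(devices, names)


mesh = build_mesh("!-1,1,1")
print(dict(mesh.shape))
```

The reshape sidesteps the assignment error, but a logical axis may then span chips that are not physical neighbors, which can make collectives slower. It also does not reduce per-chip memory by itself.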

@redbrain
Author

Still not working, even with the parameters you suggested for mesh_dim and batch_size.

Full command
export LIBTPU_INIT_ARGS='--xla_jf_spmd_threshold_for_windowed_einsum_mib=0 --xla_tpu_spmd_threshold_for_allgather_cse=10000 --xla_tpu_spmd_rewrite_einsum_with_reshape=true --xla_enable_async_all_gather=true --jax_enable_async_collective_offload=true --xla_tpu_enable_latency_hiding_scheduler=true TPU_MEGACORE=MEGACORE_DENSE'


python -m EasyLM.models.llama.llama_train \
    --mesh_dim='!-1,64,1' \
    --dtype='fp32' \
    --total_steps=250000 \
    --log_freq=50 \
    --save_model_freq=0 \
    --save_milestone_freq=2500 \
    --load_llama_config='7b' \
    --update_llama_config='' \
    --load_dataset_state='' \
    --load_checkpoint='' \
    --tokenizer.vocab_file='gs://.../tokenizer.model' \
    --optimizer.type='adamw' \
    --optimizer.adamw_optimizer.weight_decay=0.1 \
    --optimizer.adamw_optimizer.lr=3e-4 \
    --optimizer.adamw_optimizer.end_lr=3e-5 \
    --optimizer.adamw_optimizer.lr_warmup_steps=2000 \
    --optimizer.adamw_optimizer.lr_decay_steps=250000 \
    --train_dataset.type='json' \
    --train_dataset.text_processor.fields='text' \
    --train_dataset.json_dataset.path='gs://.../slimpajama.jsonl' \
    --train_dataset.json_dataset.seq_length=2048 \
    --train_dataset.json_dataset.batch_size=1024 \
    --train_dataset.json_dataset.tokenizer_processes=16 \
    --checkpointer.save_optimizer_state=True \
    --logger.online=True \
    --logger.prefix='devingulliver' \
    --logger.project="slender_llama_7b" \
    --logger.output_dir="gs://.../output/" \
    --logger.wandb_dir="$HOME/experiment_output/slender_llama_7b" \
|& tee $HOME/output.txt
Full output log

  0% 0/250000 [02:22<?, ?it/s]
Traceback (most recent call last):
  File "/home/redbrain/EasyLM/EasyLM/models/llama/llama_train.py", line 267, in <module>
    mlxu.run(main)
  File "/home/redbrain/.local/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/redbrain/.local/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/home/redbrain/EasyLM/EasyLM/models/llama/llama_train.py", line 235, in main
    train_state, sharded_rng, metrics = sharded_train_step(
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 250, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 163, in _python_pjit_helper
    out_flat = pjit_p.bind(*args_flat, **params)
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/core.py", line 2677, in bind
    return self.bind_with_trace(top_trace, args, params)
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/core.py", line 383, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/core.py", line 815, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 1203, in _pjit_call_impl
    return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums,
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 1187, in call_impl_cache_miss
    out_flat, compiled = _pjit_call_impl_python(
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 1123, in _pjit_call_impl_python
    always_lower=False, lowering_platform=None).compile()
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 2323, in compile
    executable = UnloadedMeshExecutable.from_hlo(
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 2645, in from_hlo
    xla_executable, compile_options = _cached_compilation(
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 2555, in _cached_compilation
    xla_executable = dispatch.compile_or_get_cached(
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/dispatch.py", line 497, in compile_or_get_cached
    return backend_compile(backend, computation, compile_options,
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/dispatch.py", line 465, in backend_compile
    return backend.compile(built_c, compile_options=options)
jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: XLA:TPU compile permanent error. Ran out of memory in memory space hbm. Used 87.08G of 30.75G hbm. Exceeded hbm capacity by 56.33G.
Total hbm usage >= 88.33G:
    reserved          1.25G
    program          87.08G
    arguments            0B
Output size 0B; shares 0B with arguments.
Program hbm requirement 87.08G:
    global           13.02M
    scoped           529.0K
    HLO temp         87.06G (100.0% utilization: Unpadded (86.24G) Padded (86.24G), 0.9% fragmentation (841.66M))
  Largest program allocations in hbm:
  1. Size: 4.00G
     Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/4/attention/sub" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/attention.py" source_line=108
     Shape: f32[8,32,2048,2048]{2,3,1,0:T(8,128)}
     Unpadded size: 4.00G
     XLA label: fusion.8917.remat7 = fusion(fusion.130.remat6), kind=kOutput, calls=fused_computation.6051.clone.clone.clone.clone.clone.clone.clone
     Allocation type: HLO temp
     ==========================
  2. Size: 4.00G
     Operator: op_name="pjit(train_step)/jit(main)/transpose(jvp(FlaxLLaMAForCausalLMModule))/transformer/h/0/attention/...hqk,...khd->...qhd/dot_general[dimension_numbers=(((2,), (3,)), ((0, 1), (0, 2))) precision=None preferred_element_type=None]" source_file="/home/redbrain/EasyLM/EasyLM/models/llama/llama_model.py" source_line=569
     Shape: f32[8,32,2048,2048]{2,3,1,0:T(8,128)}
     Unpadded size: 4.00G
     XLA label: fusion.136 = fusion(get-tuple-element.5938, copy-done.1220, bitcast.6320, copy-done.206), kind=kOutput, calls=fused_computation.136
     Allocation type: HLO temp
     ==========================
  3. Size: 4.00G
     Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/0/attention/sub" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/attention.py" source_line=108
     Shape: f32[8,32,2048,2048]{2,3,1,0:T(8,128)}
     Unpadded size: 4.00G
     XLA label: fusion.8913.remat5 = fusion(fusion.134.remat2), kind=kOutput, calls=fused_computation.6047.clone.clone.clone.clone.clone
     Allocation type: HLO temp
     ==========================
  4. Size: 688.00M
     Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/30/feed_forward/w1/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206
     Shape: f32[8,2048,11008]{2,1,0:T(8,128)}
     Unpadded size: 688.00M
     XLA label: fusion.1977.remat4 = fusion(fusion.926.remat5.1.remat, all-gather.623.remat2, param.536, fusion.6437), kind=kOutput, calls=fused_computation.1655.clone.clone.clone.clone
     Allocation type: HLO temp
     ==========================
  5. Size: 688.00M
     Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/31/feed_forward/w1/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206
     Shape: f32[8,2048,11008]{2,1,0:T(8,128)}
     Unpadded size: 688.00M
     XLA label: fusion.1985.remat4 = fusion(fusion.924.remat5.1.remat, all-gather.630.remat2, param.545, fusion.6439), kind=kOutput, calls=fused_computation.1659.clone.clone.clone.clone
     Allocation type: HLO temp
     ==========================
  6. Size: 688.00M
     Operator: op_name="pjit(train_step)/jit(main)/transpose(jvp(FlaxLLaMAForCausalLMModule))/transformer/h/0/feed_forward/w2/dot_general[dimension_numbers=(((2,), (1,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206
     Shape: f32[8,2048,11008]{2,1,0:T(8,128)}
     Unpadded size: 688.00M
     XLA label: fusion.2799.remat2 = fusion(fusion.1057, all-gather.748.remat2), kind=kOutput, calls=fused_computation.2158.clone.clone
     Allocation type: HLO temp
     ==========================
  7. Size: 688.00M
     Operator: op_name="pjit(train_step)/jit(main)/transpose(jvp(FlaxLLaMAForCausalLMModule))/transformer/h/30/feed_forward/w2/dot_general[dimension_numbers=(((2,), (1,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206
     Shape: f32[8,2048,11008]{2,1,0:T(8,128)}
     Unpadded size: 688.00M
     XLA label: fusion.2559.remat5 = fusion(fusion.1207, copy-done.454), kind=kOutput, calls=fused_computation.1918.clone.clone.clone.clone.clone
     Allocation type: HLO temp
     ==========================
  8. Size: 688.00M
     Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/1/feed_forward/w1/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206
     Shape: f32[8,2048,11008]{2,1,0:T(8,128)}
     Unpadded size: 688.00M
     XLA label: fusion.1745.remat3 = fusion(fusion.984.remat3, copy-done.464, copy-done.2938, copy-done.2521), kind=kOutput, calls=fused_computation.1539.clone.clone.clone
     Allocation type: HLO temp
     ==========================
  9. Size: 688.00M
     Operator: op_name="pjit(train_step)/jit(main)/transpose(jvp(FlaxLLaMAForCausalLMModule))/transformer/h/1/feed_forward/w2/dot_general[dimension_numbers=(((2,), (1,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206
     Shape: f32[8,2048,11008]{2,1,0:T(8,128)}
     Unpadded size: 688.00M
     XLA label: fusion.2791.remat2 = fusion(fusion.1062, all-gather.741.remat2), kind=kOutput, calls=fused_computation.2150.clone.clone
     Allocation type: HLO temp
     ==========================
  10. Size: 688.00M
     Operator: op_name="pjit(train_step)/jit(main)/transpose(jvp(FlaxLLaMAForCausalLMModule))/transformer/h/29/feed_forward/w2/dot_general[dimension_numbers=(((2,), (1,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206
     Shape: f32[8,2048,11008]{2,1,0:T(8,128)}
     Unpadded size: 688.00M
     XLA label: fusion.2567.remat5 = fusion(fusion.1202, copy-done.452), kind=kOutput, calls=fused_computation.1926.clone.clone.clone.clone.clone
     Allocation type: HLO temp
     ==========================
  11. Size: 688.00M
     Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/2/feed_forward/w1/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206
     Shape: f32[8,2048,11008]{2,1,0:T(8,128)}
     Unpadded size: 688.00M
     XLA label: fusion.1753.remat3 = fusion(fusion.982.remat3, copy-done.462, copy-done.2888, copy-done.2518), kind=kOutput, calls=fused_computation.1543.clone.clone.clone
     Allocation type: HLO temp
     ==========================
  12. Size: 688.00M
     Operator: op_name="pjit(train_step)/jit(main)/transpose(jvp(FlaxLLaMAForCausalLMModule))/transformer/h/2/feed_forward/w2/dot_general[dimension_numbers=(((2,), (1,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206
     Shape: f32[8,2048,11008]{2,1,0:T(8,128)}
     Unpadded size: 688.00M
     XLA label: fusion.2783.remat2 = fusion(fusion.1067, all-gather.734.remat2), kind=kOutput, calls=fused_computation.2142.clone.clone
     Allocation type: HLO temp
     ==========================
  13. Size: 688.00M
     Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/3/feed_forward/w3/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206
     Shape: f32[8,2048,11008]{2,1,0:T(8,128)}
     Unpadded size: 688.00M
     XLA label: fusion.1759.remat3 = fusion(fusion.980.remat3, all-gather.729.remat2, copy-done.2827, copy-done.2527), kind=kOutput, calls=fused_computation.1546.clone.clone.clone
     Allocation type: HLO temp
     ==========================
  14. Size: 688.00M
     Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/6/feed_forward/w3/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206
     Shape: f32[8,2048,11008]{2,1,0:T(8,128)}
     Unpadded size: 688.00M
     XLA label: fusion.1783.remat4 = fusion(fusion.974.remat7.1.remat, all-gather.456.remat2, copy-done.2801, copy-done.2540), kind=kOutput, calls=fused_computation.1558.clone.clone.clone.clone
     Allocation type: HLO temp
     ==========================
  15. Size: 688.00M
     Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/6/feed_forward/w1/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206
     Shape: f32[8,2048,11008]{2,1,0:T(8,128)}
     Unpadded size: 688.00M
     XLA label: fusion.1785.remat4 = fusion(fusion.974.remat5.1.remat, all-gather.455.remat2, param.320, fusion.6389), kind=kOutput, calls=fused_computation.1559.clone.clone.clone.clone
     Allocation type: HLO temp
     ==========================
  16. Size: 688.00M
     Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/3/feed_forward/w1/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206
     Shape: f32[8,2048,11008]{2,1,0:T(8,128)}
     Unpadded size: 688.00M
     XLA label: fusion.1761.remat3 = fusion(fusion.980.remat3, copy-done.460, copy-done.2826, copy-done.2527), kind=kOutput, calls=fused_computation.1547.clone.clone.clone
     Allocation type: HLO temp
     ==========================
  17. Size: 688.00M
     Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/8/feed_forward/w3/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206
     Shape: f32[8,2048,11008]{2,1,0:T(8,128)}
     Unpadded size: 688.00M
     XLA label: fusion.1799.remat4 = fusion(fusion.970.remat7.1.remat, all-gather.470.remat2, copy-done.2793, copy-done.2548), kind=kOutput, calls=fused_computation.1566.clone.clone.clone.clone
     Allocation type: HLO temp
     ==========================
  18. Size: 688.00M
     Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/8/feed_forward/w1/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206
     Shape: f32[8,2048,11008]{2,1,0:T(8,128)}
     Unpadded size: 688.00M
     XLA label: fusion.1801.remat4 = fusion(fusion.970.remat5.1.remat, all-gather.469.remat2, param.338, fusion.6393), kind=kOutput, calls=fused_computation.1567.clone.clone.clone.clone
     Allocation type: HLO temp
     ==========================
  19. Size: 688.00M
     Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/9/feed_forward/w1/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206
     Shape: f32[8,2048,11008]{2,1,0:T(8,128)}
     Unpadded size: 688.00M
     XLA label: fusion.1809.remat4 = fusion(fusion.968.remat5.1.remat, all-gather.476.remat2, param.347, fusion.6395), kind=kOutput, calls=fused_computation.1571.clone.clone.clone.clone
     Allocation type: HLO temp
     ==========================
  20. Size: 688.00M
     Operator: op_name="pjit(train_step)/jit(main)/transpose(jvp(FlaxLLaMAForCausalLMModule))/transformer/h/3/feed_forward/w2/dot_general[dimension_numbers=(((2,), (1,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206
     Shape: f32[8,2048,11008]{2,1,0:T(8,128)}
     Unpadded size: 688.00M
     XLA label: fusion.2775.remat2 = fusion(fusion.1072, all-gather.727.remat2), kind=kOutput, calls=fused_computation.2134.clone.clone
     Allocation type: HLO temp
     ==========================
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/redbrain/EasyLM/EasyLM/models/llama/llama_train.py", line 267, in <module>
    mlxu.run(main)
  File "/home/redbrain/.local/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/redbrain/.local/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/home/redbrain/EasyLM/EasyLM/models/llama/llama_train.py", line 235, in main
    train_state, sharded_rng, metrics = sharded_train_step(
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: XLA:TPU compile permanent error. Ran out of memory in memory space hbm. Used 87.08G of 30.75G hbm. Exceeded hbm capacity by 56.33G.
Traceback (most recent call last):
  File "/home/redbrain/EasyLM/EasyLM/models/llama/llama_train.py", line 267, in <module>
    mlxu.run(main)
  File "/home/redbrain/.local/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/redbrain/.local/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/home/redbrain/EasyLM/EasyLM/models/llama/llama_train.py", line 235, in main
    train_state, sharded_rng, metrics = sharded_train_step(
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 250, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 163, in _python_pjit_helper
    out_flat = pjit_p.bind(*args_flat, **params)
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/core.py", line 2677, in bind
    return self.bind_with_trace(top_trace, args, params)
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/core.py", line 383, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/core.py", line 815, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 1203, in _pjit_call_impl
    return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums,
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 1187, in call_impl_cache_miss
    out_flat, compiled = _pjit_call_impl_python(
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 1123, in _pjit_call_impl_python
    always_lower=False, lowering_platform=None).compile()
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 2323, in compile
    executable = UnloadedMeshExecutable.from_hlo(
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 2645, in from_hlo
    xla_executable, compile_options = _cached_compilation(
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 2555, in _cached_compilation
    xla_executable = dispatch.compile_or_get_cached(
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/dispatch.py", line 497, in compile_or_get_cached
    return backend_compile(backend, computation, compile_options,
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/dispatch.py", line 465, in backend_compile
    return backend.compile(built_c, compile_options=options)
jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: XLA:TPU compile permanent error. Ran out of memory in memory space hbm. Used 87.08G of 30.75G hbm. Exceeded hbm capacity by 56.33G.
Total hbm usage >= 88.33G:
    reserved          1.25G
    program          87.08G
    arguments            0B
Output size 0B; shares 0B with arguments.
Program hbm requirement 87.08G:
    global           13.02M
    scoped           529.0K
    HLO temp         87.06G (100.0% utilization: Unpadded (86.24G) Padded (86.24G), 0.9% fragmentation (841.66M))
  Largest program allocations in hbm:
  1. Size: 4.00G
     Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/4/attention/sub" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/attention.py" source_line=108
     Shape: f32[8,32,2048,2048]{2,3,1,0:T(8,128)}
     Unpadded size: 4.00G
     XLA label: fusion.8917.remat7 = fusion(fusion.130.remat6), kind=kOutput, calls=fused_computation.6051.clone.clone.clone.clone.clone.clone.clone
     Allocation type: HLO temp
     ==========================
  2. Size: 4.00G
     Operator: op_name="pjit(train_step)/jit(main)/transpose(jvp(FlaxLLaMAForCausalLMModule))/transformer/h/0/attention/...hqk,...khd->...qhd/dot_general[dimension_numbers=(((2,), (3,)), ((0, 1), (0, 2))) precision=None preferred_element_type=None]" source_file="/home/redbrain/EasyLM/EasyLM/models/llama/llama_model.py" source_line=569
     Shape: f32[8,32,2048,2048]{2,3,1,0:T(8,128)}
     Unpadded size: 4.00G
     XLA label: fusion.136 = fusion(get-tuple-element.5938, copy-done.1220, bitcast.6320, copy-done.206), kind=kOutput, calls=fused_computation.136
     Allocation type: HLO temp
     ==========================
  3. Size: 4.00G
     Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/0/attention/sub" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/attention.py" source_line=108
     Shape: f32[8,32,2048,2048]{2,3,1,0:T(8,128)}
     Unpadded size: 4.00G
     XLA label: fusion.8913.remat5 = fusion(fusion.134.remat2), kind=kOutput, calls=fused_computation.6047.clone.clone.clone.clone.clone
     Allocation type: HLO temp
     ==========================
  4. Size: 688.00M
     Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/30/feed_forward/w1/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206
     Shape: f32[8,2048,11008]{2,1,0:T(8,128)}
     Unpadded size: 688.00M
     XLA label: fusion.1977.remat4 = fusion(fusion.926.remat5.1.remat, all-gather.623.remat2, param.536, fusion.6437), kind=kOutput, calls=fused_computation.1655.clone.clone.clone.clone
     Allocation type: HLO temp
     ==========================
  5. Size: 688.00M
     Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/31/feed_forward/w1/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206
     Shape: f32[8,2048,11008]{2,1,0:T(8,128)}
     Unpadded size: 688.00M
     XLA label: fusion.1985.remat4 = fusion(fusion.924.remat5.1.remat, all-gather.630.remat2, param.545, fusion.6439), kind=kOutput, calls=fused_computation.1659.clone.clone.clone.clone
     Allocation type: HLO temp
     ==========================
  6. Size: 688.00M
     Operator: op_name="pjit(train_step)/jit(main)/transpose(jvp(FlaxLLaMAForCausalLMModule))/transformer/h/0/feed_forward/w2/dot_general[dimension_numbers=(((2,), (1,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206
     Shape: f32[8,2048,11008]{2,1,0:T(8,128)}
     Unpadded size: 688.00M
     XLA label: fusion.2799.remat2 = fusion(fusion.1057, all-gather.748.remat2), kind=kOutput, calls=fused_computation.2158.clone.clone
     Allocation type: HLO temp
     ==========================
  7. Size: 688.00M
     Operator: op_name="pjit(train_step)/jit(main)/transpose(jvp(FlaxLLaMAForCausalLMModule))/transformer/h/30/feed_forward/w2/dot_general[dimension_numbers=(((2,), (1,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206
     Shape: f32[8,2048,11008]{2,1,0:T(8,128)}
     Unpadded size: 688.00M
     XLA label: fusion.2559.remat5 = fusion(fusion.1207, copy-done.454), kind=kOutput, calls=fused_computation.1918.clone.clone.clone.clone.clone
     Allocation type: HLO temp
     ==========================
  8. Size: 688.00M
     Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/1/feed_forward/w1/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206
     Shape: f32[8,2048,11008]{2,1,0:T(8,128)}
     Unpadded size: 688.00M
     XLA label: fusion.1745.remat3 = fusion(fusion.984.remat3, copy-done.464, copy-done.2938, copy-done.2521), kind=kOutput, calls=fused_computation.1539.clone.clone.clone
     Allocation type: HLO temp
     ==========================
  9. Size: 688.00M
     Operator: op_name="pjit(train_step)/jit(main)/transpose(jvp(FlaxLLaMAForCausalLMModule))/transformer/h/1/feed_forward/w2/dot_general[dimension_numbers=(((2,), (1,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206
     Shape: f32[8,2048,11008]{2,1,0:T(8,128)}
     Unpadded size: 688.00M
     XLA label: fusion.2791.remat2 = fusion(fusion.1062, all-gather.741.remat2), kind=kOutput, calls=fused_computation.2150.clone.clone
     Allocation type: HLO temp
     ==========================
  10. Size: 688.00M
     Operator: op_name="pjit(train_step)/jit(main)/transpose(jvp(FlaxLLaMAForCausalLMModule))/transformer/h/29/feed_forward/w2/dot_general[dimension_numbers=(((2,), (1,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206
     Shape: f32[8,2048,11008]{2,1,0:T(8,128)}
     Unpadded size: 688.00M
     XLA label: fusion.2567.remat5 = fusion(fusion.1202, copy-done.452), kind=kOutput, calls=fused_computation.1926.clone.clone.clone.clone.clone
     Allocation type: HLO temp
     ==========================
  11. Size: 688.00M
     Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/2/feed_forward/w1/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206
     Shape: f32[8,2048,11008]{2,1,0:T(8,128)}
     Unpadded size: 688.00M
     XLA label: fusion.1753.remat3 = fusion(fusion.982.remat3, copy-done.462, copy-done.2888, copy-done.2518), kind=kOutput, calls=fused_computation.1543.clone.clone.clone
     Allocation type: HLO temp
     ==========================
  12. Size: 688.00M
     Operator: op_name="pjit(train_step)/jit(main)/transpose(jvp(FlaxLLaMAForCausalLMModule))/transformer/h/2/feed_forward/w2/dot_general[dimension_numbers=(((2,), (1,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206
     Shape: f32[8,2048,11008]{2,1,0:T(8,128)}
     Unpadded size: 688.00M
     XLA label: fusion.2783.remat2 = fusion(fusion.1067, all-gather.734.remat2), kind=kOutput, calls=fused_computation.2142.clone.clone
     Allocation type: HLO temp
     ==========================
  13. Size: 688.00M
     Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/3/feed_forward/w3/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206
     Shape: f32[8,2048,11008]{2,1,0:T(8,128)}
     Unpadded size: 688.00M
     XLA label: fusion.1759.remat3 = fusion(fusion.980.remat3, all-gather.729.remat2, copy-done.2827, copy-done.2527), kind=kOutput, calls=fused_computation.1546.clone.clone.clone
     Allocation type: HLO temp
     ==========================
  14. Size: 688.00M
     Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/6/feed_forward/w3/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206
     Shape: f32[8,2048,11008]{2,1,0:T(8,128)}
     Unpadded size: 688.00M
     XLA label: fusion.1783.remat4 = fusion(fusion.974.remat7.1.remat, all-gather.456.remat2, copy-done.2801, copy-done.2540), kind=kOutput, calls=fused_computation.1558.clone.clone.clone.clone
     Allocation type: HLO temp
     ==========================
  15. Size: 688.00M
     Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/6/feed_forward/w1/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206
     Shape: f32[8,2048,11008]{2,1,0:T(8,128)}
     Unpadded size: 688.00M
     XLA label: fusion.1785.remat4 = fusion(fusion.974.remat5.1.remat, all-gather.455.remat2, param.320, fusion.6389), kind=kOutput, calls=fused_computation.1559.clone.clone.clone.clone
     Allocation type: HLO temp
     ==========================
  16. Size: 688.00M
     Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/3/feed_forward/w1/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206
     Shape: f32[8,2048,11008]{2,1,0:T(8,128)}
     Unpadded size: 688.00M
     XLA label: fusion.1761.remat3 = fusion(fusion.980.remat3, copy-done.460, copy-done.2826, copy-done.2527), kind=kOutput, calls=fused_computation.1547.clone.clone.clone
     Allocation type: HLO temp
     ==========================
  17. Size: 688.00M
     Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/8/feed_forward/w3/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206
     Shape: f32[8,2048,11008]{2,1,0:T(8,128)}
     Unpadded size: 688.00M
     XLA label: fusion.1799.remat4 = fusion(fusion.970.remat7.1.remat, all-gather.470.remat2, copy-done.2793, copy-done.2548), kind=kOutput, calls=fused_computation.1566.clone.clone.clone.clone
     Allocation type: HLO temp
     ==========================
  18. Size: 688.00M
     Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/8/feed_forward/w1/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206
     Shape: f32[8,2048,11008]{2,1,0:T(8,128)}
     Unpadded size: 688.00M
     XLA label: fusion.1801.remat4 = fusion(fusion.970.remat5.1.remat, all-gather.469.remat2, param.338, fusion.6393), kind=kOutput, calls=fused_computation.1567.clone.clone.clone.clone
     Allocation type: HLO temp
     ==========================
  19. Size: 688.00M
     Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/9/feed_forward/w1/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206
     Shape: f32[8,2048,11008]{2,1,0:T(8,128)}
     Unpadded size: 688.00M
     XLA label: fusion.1809.remat4 = fusion(fusion.968.remat5.1.remat, all-gather.476.remat2, param.347, fusion.6395), kind=kOutput, calls=fused_computation.1571.clone.clone.clone.clone
     Allocation type: HLO temp
     ==========================
  20. Size: 688.00M
     Operator: op_name="pjit(train_step)/jit(main)/transpose(jvp(FlaxLLaMAForCausalLMModule))/transformer/h/3/feed_forward/w2/dot_general[dimension_numbers=(((2,), (1,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206
     Shape: f32[8,2048,11008]{2,1,0:T(8,128)}
     Unpadded size: 688.00M
     XLA label: fusion.2775.remat2 = fusion(fusion.1072, all-gather.727.remat2), kind=kOutput, calls=fused_computation.2134.clone.clone
     Allocation type: HLO temp
     ==========================
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/redbrain/EasyLM/EasyLM/models/llama/llama_train.py", line 267, in <module>
    mlxu.run(main)
  File "/home/redbrain/.local/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/redbrain/.local/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/home/redbrain/EasyLM/EasyLM/models/llama/llama_train.py", line 235, in main
    train_state, sharded_rng, metrics = sharded_train_step(
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: XLA:TPU compile permanent error. Ran out of memory in memory space hbm. Used 87.08G of 30.75G hbm. Exceeded hbm capacity by 56.33G.

@young-geng
Owner

young-geng commented Nov 28, 2023

This is quite strange. Maybe XLA is not smart enough when allocating memory. In this case I'd recommend tweaking the batch size and mesh size. For example, try an even smaller batch size of 512, or use !-1,128,1 as the mesh dim.
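The sizes in the allocation list follow directly from the buffer shapes printed in the log, which shows why a smaller batch helps. Below is a minimal sketch of that arithmetic. It assumes 4-byte f32 elements, as the log reports, and that the leading dimension (8) of each buffer is the per-device slice of the batch. That interpretation is an assumption, not something the log states.

```python
import math

BYTES_PER_F32 = 4  # every large buffer in the log is f32
GIB = 2**30
MIB = 2**20


def buffer_bytes(shape, bytes_per_elem=BYTES_PER_F32):
    """Size in bytes of a dense buffer with the given shape."""
    return math.prod(shape) * bytes_per_elem


# Shapes copied from the "Largest program allocations" list in the log.
# Axis labels are assumptions: [batch slice, heads, q_len, k_len]
attn_scores = (8, 32, 2048, 2048)
# [batch slice, seq_len, feed-forward hidden dim]
ffn_hidden = (8, 2048, 11008)

print(buffer_bytes(attn_scores) / GIB)  # 4.0
print(buffer_bytes(ffn_hidden) / MIB)   # 688.0

# Halving the leading dimension halves each buffer, while doubling the
# sequence length quadruples the attention-score buffer.
half_batch = (4, 32, 2048, 2048)
double_seq = (8, 32, 4096, 4096)
print(buffer_bytes(half_batch) / GIB)   # 2.0
print(buffer_bytes(double_seq) / GIB)   # 16.0
```

These temporaries scale linearly with the leading dimension. Reducing how much of the batch each device holds therefore shrinks them proportionally.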

@0x7o

0x7o commented Jul 26, 2024

I was able to train the 7B model on a TPU v4-256 with mesh_dim = !-1,16,4 and batch_size = 64, at 115,000 tokens per second.
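To compare the mesh shapes suggested in this thread, you can resolve the `-1` entry by hand. The sketch below uses a hypothetical helper, `resolve_mesh_dims`, which is not part of EasyLM. It makes three assumptions: a v4-256 slice exposes 128 JAX devices (one per chip), the three comma-separated values are EasyLM's data-parallel, FSDP and model-parallel axis sizes, and a leading `!` changes only how devices are assigned, not the mesh shape.

```python
import math


def resolve_mesh_dims(spec: str, n_devices: int) -> tuple:
    """Resolve a mesh_dim string such as '!-1,16,4' into concrete sizes.

    Hypothetical helper for illustration only. The leading '!' is dropped
    on the assumption that it does not affect the mesh shape.
    """
    if spec.startswith('!'):
        spec = spec[1:]
    dims = [int(d) for d in spec.split(',')]
    if dims.count(-1) > 1:
        raise ValueError('at most one dimension may be -1')
    known = math.prod(d for d in dims if d != -1)
    if -1 in dims:
        if n_devices % known != 0:
            raise ValueError(f'{n_devices} devices not divisible by {known}')
        dims[dims.index(-1)] = n_devices // known
    if math.prod(dims) != n_devices:
        raise ValueError(f'mesh {dims} does not use all {n_devices} devices')
    return tuple(dims)


N_DEVICES = 128  # assumed: v4-256 slice = 128 chips, one JAX device each

for spec in ['-1,32,1', '!-1,128,1', '!-1,16,4']:
    print(spec, '->', resolve_mesh_dims(spec, N_DEVICES))
# -1,32,1 -> (4, 32, 1)
# !-1,128,1 -> (1, 128, 1)
# !-1,16,4 -> (2, 16, 4)
```

The working configuration thus differs from the original in two ways. It splits the model across a model-parallel axis of 4, and it uses a much smaller global batch.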
