Hi, when running llama_train.py in a distributed setting on a v3-512 TPU pod with evaluation turned on (eval_steps > 0), I get this error:
RuntimeError: Running operations on `Array`s that are not fully addressable by this process (i.e. `Array`s with data sharded across multiple devices and processes.) is dangerous. It’s very important that all processes run the same cross-process computations in the same order otherwise it can lead to hangs. If you’re not already familiar with JAX’s multi-process programming model, please read https://jax.readthedocs.io/en/latest/multi_process.html. To fix this error, run your `jitted` computation inside `with jax.spmd_mode('allow_all'):` context manager.
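For reference, my understanding of the workaround suggested in the error message is roughly the sketch below. It is only an illustration: `eval_step` and `run_eval` are hypothetical stand-ins, not the actual functions in llama_train.py, and the fallback to a no-op context is an assumption for JAX versions that do not expose `jax.spmd_mode`. As the message warns, this only relaxes the check; all processes would still have to run the same computations in the same order.

```python
import contextlib

import jax
import jax.numpy as jnp


@jax.jit
def eval_step(batch):
    # Hypothetical stand-in for the real evaluation step: mean of squares.
    return jnp.mean(batch ** 2)


def run_eval(batch):
    # `jax.spmd_mode` is the context manager named in the error message.
    # Assumption: if this JAX version does not provide it, fall back to a
    # no-op context so the sketch still runs.
    spmd_mode = getattr(jax, "spmd_mode", None)
    if spmd_mode is not None:
        ctx = spmd_mode("allow_all")
    else:
        ctx = contextlib.nullcontext()
    with ctx:
        return eval_step(batch)


loss = run_eval(jnp.arange(4.0))
```

On a single host this just runs the jitted function as usual; whether it is safe on the pod depends on every process reaching the evaluation call in the same order.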
This happens at this line in the code:
Could you please help me with this? Thank you very much!