Hello MLX team!
Thank you for your excellent work here. We have been focusing on accelerating Candle performance on Metal and recently integrated some fast SDPA kernels from MLX!
I noticed that the scaled_dot_product_attention.metal file does not contain kernels for the causal, non-decode (prompt) case. This seems to rule out their use during the prompt step for models such as Llama, and the fallback code for the SDPA op appears to confirm this.
Are there any plans to add masking support to the full SDPA kernel, or are there existing approaches for accelerating the prompt step?
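For context, here is a minimal NumPy sketch of the causal-masked computation such a prompt-step kernel would need to cover (the function name and shapes are illustrative, not taken from the MLX or Candle sources):

```python
# Minimal sketch of causal SDPA for the prompt step (shapes are illustrative).
import numpy as np

def causal_sdpa(q, k, v, scale):
    # q, k, v: (num_heads, seq_len, head_dim); seq_len > 1 in the prompt step.
    scores = scale * (q @ k.transpose(0, 2, 1))           # (heads, L, L)
    L = scores.shape[-1]
    # Causal mask: position i may only attend to positions j <= i.
    future = np.triu(np.ones((L, L), dtype=bool), k=1)
    scores = np.where(future, -np.inf, scores)
    # Numerically stable softmax over the key dimension.
    scores -= scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v                                    # (heads, L, head_dim)
```

This unfused path materializes the full L×L score matrix, which is presumably the cost a fused, mask-aware kernel would avoid.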