
[Question] Masked SDPA attention kernel? #1582

Closed

EricLBuehler opened this issue Nov 12, 2024 · 2 comments
@EricLBuehler

Hello MLX team!

Thank you for your excellent work here. We have been focusing on accelerating Candle performance on Metal and recently integrated some fast SDPA kernels from MLX!

I noticed that the scaled_dot_product_attention.metal file does not contain kernels for the causal, non-decode (prompt-processing) case. This seems to prevent using the fast kernels during the prompt step for models such as Llama, and the fallback code for the SDPA op seems to confirm this.

Are there any plans to add masking support to the full SDPA kernel, or are there existing methods for accelerating the prompt step?
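
For concreteness, here is a minimal sketch of the unfused path that masked prompt-step attention currently falls back to, assuming the usual `[batch, heads, seq, head_dim]` layout and an additive mask convention. The helper names (`causal_mask`, `sdpa_unfused`) are illustrative, not MLX internals:

```python
import mlx.core as mx

def causal_mask(T: int, dtype=mx.float32):
    # Additive causal mask: 0.0 on/below the diagonal, -inf above it,
    # so softmax assigns zero weight to future positions.
    idx = mx.arange(T)
    rows = idx.reshape(T, 1)
    cols = idx.reshape(1, T)
    return mx.where(rows >= cols, mx.array(0.0), mx.array(float("-inf"))).astype(dtype)

def sdpa_unfused(q, k, v, scale, mask=None):
    # Unfused reference path: scores -> (optional) mask -> softmax -> values.
    # q, k, v: [batch, heads, seq, head_dim]
    scores = (q * scale) @ mx.swapaxes(k, -1, -2)
    if mask is not None:
        scores = scores + mask  # mask broadcasts over batch and heads
    return mx.softmax(scores, axis=-1) @ v

# Usage: a causal prompt step over the whole sequence.
B, H, T, D = 1, 8, 128, 64
q = mx.random.normal((B, H, T, D))
k = mx.random.normal((B, H, T, D))
v = mx.random.normal((B, H, T, D))
out = sdpa_unfused(q, k, v, scale=D ** -0.5, mask=causal_mask(T))
```

Unlike this sketch, a fused kernel avoids materializing the full `[T, T]` score matrix, which is exactly why a masked variant matters for long prompts.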

awni (Member) commented Nov 12, 2024

> Are there any plans to add masking support to the full SDPA kernel, or are there existing methods for accelerating the prompt step?

This is a work in progress. @jagrit06 has been working on that. We're hoping to ship something soon, but I can't give you an exact timeline yet.

awni (Member) commented Nov 12, 2024

Closing as a duplicate of #129

awni closed this as completed Nov 12, 2024