Add SDPA support for LayoutLMv3 model #35469

Open
stancld wants to merge 1 commit into main from ds/feat/layoutlmv3-flash-attn

Contributor

@stancld stancld commented Dec 31, 2024

What does this PR do?

Part of #35467.

Performance benchmark

Training speed and peak memory consumption, measured on token-classification training of a LayoutLMv3-like model with multilingual support, various auxiliary tasks, and masked language modelling.

GPU: 1x A100 80 GB
Batch size: 16, Accumulated gradient batches: 8

| Impl. | Speed | Peak memory |
|-------|-----------|----------|
| Eager | ~2.0 it/s | 66.7 GiB |
| SDPA  | ~3.0 it/s | 47.2 GiB |

Overall, SDPA gives a ~50% speed-up (2.0 → 3.0 it/s) and a ~29% reduction in peak memory (66.7 → 47.2 Gi).
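For anyone who wants to check the summary against the table, here is a minimal sketch of the arithmetic. The figures are copied from the benchmark above. The `relative_change` helper and the effective batch size calculation are illustrative assumptions, not part of the PR (the latter assumes a single GPU and counts samples per optimizer step).

```python
# Figures copied from the benchmark table in this PR description.
eager_speed, sdpa_speed = 2.0, 3.0   # iterations per second
eager_mem, sdpa_mem = 66.7, 47.2     # peak memory, same unit for both rows


def relative_change(before: float, after: float) -> float:
    """Hypothetical helper: fractional change from `before` to `after`."""
    return (after - before) / before


speedup = relative_change(eager_speed, sdpa_speed)
mem_reduction = -relative_change(eager_mem, sdpa_mem)

# Assumption: one GPU, so samples per optimizer step is
# batch size times the number of accumulated gradient batches.
effective_batch = 16 * 8

print(f"speed-up: {speedup:.0%}")
print(f"peak memory reduction: {mem_reduction:.1%}")
print(f"effective batch size: {effective_batch}")
```

So the ~50% figure applies to throughput, while peak memory drops by a bit under a third.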

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

cc: @ArthurZucker

@stancld stancld changed the title [WIP] Add SDPA support for LayoutLMv3 model Add SDPA support for LayoutLMv3 model Dec 31, 2024
@stancld stancld force-pushed the ds/feat/layoutlmv3-flash-attn branch 4 times, most recently from 1ef5b7e to c5de661 Compare December 31, 2024 13:01
@stancld stancld force-pushed the ds/feat/layoutlmv3-flash-attn branch from c5de661 to 923cdea Compare January 2, 2025 10:03