Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix checkpointable_layers Logic #6881

Merged
merged 8 commits into from
Jan 4, 2025

Conversation

Quentin-Anthony
Copy link
Contributor

Problem

There's an edge-case in DeepSpeed, where if all three of the following are true:

  1. Deepspeed activation checkpointing is applied
  2. The user passes checkpointable_layers (e.g. https://github.com/EleutherAI/gpt-neox/blob/f5325805678c2b9e35aae4528283e0132c5f5bbc/megatron/model/gpt2_model.py#L175)
  3. The user's model class contains GPT2ModelPipe or GPTModelPipe`

Then the checkpointable_layers will not be activation checkpointed.

Reason

This is because in the current logic, _is_checkpointable will short-circuit to just return layers matching ParallelTransformerLayerPipe in the case of self.__class__.__name__ in ('GPTModelPipe', 'GPT2ModelPipe'). See

return all('ParallelTransformerLayerPipe' in f.__class__.__name__ for f in funcs)

Proposed Fixes

I think that checkpointable_layers should always be checked for, and added logic to this effect. I also found the documentation for checkpointable_layers confusing and contradictory, so I updated the docstring. Lastly, I added a unit test for checkpointable_layers.

@loadams loadams enabled auto-merge January 3, 2025 00:59
@loadams loadams added this pull request to the merge queue Jan 4, 2025
Merged via the queue into microsoft:master with commit 0dbbb70 Jan 4, 2025
10 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants