Skip to content

Commit

Permalink
Merge branch 'master' into qanthony/fix-act-recomp
Browse files Browse the repository at this point in the history
  • Loading branch information
loadams authored Dec 30, 2024
2 parents cd6fa6d + 3573858 commit be5392e
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 4 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/hpu-gaudi2-nightly.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:
# The type of runner that the job will run on
runs-on: [self-hosted, intel, gaudi2]
container:
image: vault.habana.ai/gaudi-docker/1.18.0/ubuntu22.04/habanalabs/pytorch-installer-2.4.0:latest
image: vault.habana.ai/gaudi-docker/1.19.0/ubuntu22.04/habanalabs/pytorch-installer-2.5.1:latest
ports:
- 80
options: --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/hpu-gaudi2.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ jobs:
# The type of runner that the job will run on
runs-on: [self-hosted, intel, gaudi2]
container:
image: vault.habana.ai/gaudi-docker/1.18.0/ubuntu22.04/habanalabs/pytorch-installer-2.4.0:latest
image: vault.habana.ai/gaudi-docker/1.19.0/ubuntu22.04/habanalabs/pytorch-installer-2.5.1:latest
ports:
- 80
options: --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice
Expand Down
8 changes: 8 additions & 0 deletions deepspeed/runtime/pipe/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,3 +672,11 @@ def get_additional_losses(self):
Return a dictionary of {"loss name": loss_value} or None if no additional losses.
"""
return None

def compile(self, *args, **kwargs):
for idx, layer in enumerate(self.forward_funcs):
if isinstance(layer, nn.Module):
layer.compile(*args, **kwargs)
else:
new_layer = torch.compile(layer, *args, **kwargs)
self.forward_funcs[idx] = new_layer
8 changes: 6 additions & 2 deletions tests/unit/pipe/test_pipe_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,12 @@ def batch_input():

class TestPipeModuleSequential(DistributedTest):
world_size = 2
# needs to be set for torch.compile: running torch.compile with daemonic process causes an error
non_daemonic_procs = True

@pytest.mark.parametrize("activation_checkpoints", [False, True])
def test(self, sequential_model, simple_config, batch_input, activation_checkpoints):
@pytest.mark.parametrize("use_compile", [False, True])
def test(self, sequential_model, simple_config, batch_input, activation_checkpoints, use_compile):
base_model = copy.deepcopy(sequential_model)
base_input = batch_input.clone().detach()
base_output = base_model(base_input)
Expand All @@ -71,7 +74,8 @@ def test(self, sequential_model, simple_config, batch_input, activation_checkpoi

pipe_model = copy.deepcopy(sequential_model)
pipe_model = PipelineModule(layers=pipe_model, num_stages=2)

if (use_compile):
pipe_model.compile()
# Ensure all parameters are accounted for.
my_params = sum(p.numel() for p in pipe_model.parameters())
total_pipe_params = torch.LongTensor([my_params]).to(get_accelerator().device_name())
Expand Down

0 comments on commit be5392e

Please sign in to comment.