Merge branch 'master' into loadams/transformers-inference
loadams authored Jan 2, 2025
2 parents cfc4448 + 3573858 commit 0877e1c
Showing 8 changed files with 35 additions and 20 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/hpu-gaudi2-nightly.yml
@@ -21,7 +21,7 @@ jobs:
# The type of runner that the job will run on
runs-on: [self-hosted, intel, gaudi2]
container:
- image: vault.habana.ai/gaudi-docker/1.18.0/ubuntu22.04/habanalabs/pytorch-installer-2.4.0:latest
+ image: vault.habana.ai/gaudi-docker/1.19.0/ubuntu22.04/habanalabs/pytorch-installer-2.5.1:latest
ports:
- 80
options: --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice
2 changes: 1 addition & 1 deletion .github/workflows/hpu-gaudi2.yml
@@ -39,7 +39,7 @@ jobs:
# The type of runner that the job will run on
runs-on: [self-hosted, intel, gaudi2]
container:
- image: vault.habana.ai/gaudi-docker/1.18.0/ubuntu22.04/habanalabs/pytorch-installer-2.4.0:latest
+ image: vault.habana.ai/gaudi-docker/1.19.0/ubuntu22.04/habanalabs/pytorch-installer-2.5.1:latest
ports:
- 80
options: --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice
2 changes: 1 addition & 1 deletion accelerator/hpu_accelerator.py
@@ -21,8 +21,8 @@ def __init__(self):
self.apply_hpu_workarounds()
try:
import habana_frameworks.torch.hpu as hpu
- hpu.setDeterministic(True)
self.hpu = hpu
+ torch.use_deterministic_algorithms(True)
except ImportError as e:
raise ValueError(
f"HPU_Accelerator requires habana_frameworks.torch.hpu, which is not installed on this system.")
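Side note on the hunk above: determinism is now requested through PyTorch's framework-level switch rather than the HPU-specific setter. A minimal standalone sketch of that pattern, assuming an environment where the Habana stack may or may not be installed (the fallback branch is illustrative, not part of this commit):

```python
import torch

try:
    # Optional HPU backend; only needed when running on Gaudi hardware.
    import habana_frameworks.torch.hpu as hpu
except ImportError:
    hpu = None  # CPU/GPU environment without the Habana software stack

# Framework-level determinism switch; applies to every backend, HPU included.
torch.use_deterministic_algorithms(True)
```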
21 changes: 10 additions & 11 deletions deepspeed/runtime/domino/transformer.py
@@ -6,8 +6,7 @@
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter
- import deepspeed
- from deepspeed import comm as dist
+ import deepspeed.comm as dist
from deepspeed.accelerator import get_accelerator


@@ -97,7 +96,7 @@ def backward(ctx, grad_output):
return grad_output

# Async All-reduce.
- handle = deepspeed.comm.all_reduce(grad_output, group=ctx.mpu.get_tensor_model_parallel_group(), async_op=True)
+ handle = dist.all_reduce(grad_output, group=ctx.mpu.get_tensor_model_parallel_group(), async_op=True)
ctx.handle_dic[ctx.h_id] = handle
return None, grad_output, None, None

@@ -249,6 +248,10 @@ def __init__(self,
output_bias=None):
super(DominoTransformerLayer, self).__init__()

+ if not dist.is_initialized():
+ dist.init_distributed()
+ assert dist.is_initialized(), "deepspeed.comm is not initialized!"
+
self.llama_model = config.llama_model
self.layer_number = layer_number
self.layer_type = layer_type
@@ -358,18 +361,14 @@ def forward(self, hidden_states, attention_mask, rotary_pos_emb=None):
layernorm_output0,
attention_mask,
rotary_pos_emb=rotary_pos_emb)
- handle0 = deepspeed.comm.all_reduce(attention_output0,
- group=self.mpu.get_tensor_model_parallel_group(),
- async_op=True)
+ handle0 = dist.all_reduce(attention_output0, group=self.mpu.get_tensor_model_parallel_group(), async_op=True)

attention_output1, attention_bias1 = \
self.self_attention(
layernorm_output1,
attention_mask,
rotary_pos_emb=rotary_pos_emb)
- handle1 = deepspeed.comm.all_reduce(attention_output1,
- group=self.mpu.get_tensor_model_parallel_group(),
- async_op=True)
+ handle1 = dist.all_reduce(attention_output1, group=self.mpu.get_tensor_model_parallel_group(), async_op=True)
handle0.wait()

# Residual0 connection.
@@ -413,7 +412,7 @@ def forward(self, hidden_states, attention_mask, rotary_pos_emb=None):
output0 = output0 + bias_c
output0 = self.mlp_activation_func(output0)
output0 = torch.matmul(output0, self.weight_r.t())
- handle2 = deepspeed.comm.all_reduce(output0, group=self.mpu.get_tensor_model_parallel_group(), async_op=True)
+ handle2 = dist.all_reduce(output0, group=self.mpu.get_tensor_model_parallel_group(), async_op=True)

handle1.wait()

@@ -425,7 +424,7 @@ def forward(self, hidden_states, attention_mask, rotary_pos_emb=None):
if bias_c is not None:
output1 = output1 + bias_c
output1 = torch.matmul(output1, self.weight_r.t())
- deepspeed.comm.all_reduce(output1, group=self.mpu.get_tensor_model_parallel_group())
+ dist.all_reduce(output1, group=self.mpu.get_tensor_model_parallel_group())

handle2.wait()

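The edits in this file route every collective through the `dist` alias (`deepspeed.comm`) and guard initialization before use. A minimal sketch of the async all-reduce overlap pattern the forward pass relies on, assuming the script is launched as a distributed job (e.g. with `deepspeed` or `torchrun`); tensor names are illustrative:

```python
import torch
import deepspeed.comm as dist
from deepspeed.accelerator import get_accelerator

if not dist.is_initialized():
    dist.init_distributed()

x = torch.ones(4, device=get_accelerator().current_device_name())

# Kick off the all-reduce asynchronously, overlap it with independent work,
# then wait on the handle before consuming the reduced tensor; the same
# pattern as handle0/handle1/handle2 in the layer above.
handle = dist.all_reduce(x, async_op=True)
z = torch.full_like(x, float(dist.get_rank()))  # work that does not touch x
handle.wait()
out = x + z
```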
8 changes: 8 additions & 0 deletions deepspeed/runtime/pipe/module.py
@@ -662,3 +662,11 @@ def get_additional_losses(self):
Return a dictionary of {"loss name": loss_value} or None if no additional losses.
"""
return None

+ def compile(self, *args, **kwargs):
+ for idx, layer in enumerate(self.forward_funcs):
+ if isinstance(layer, nn.Module):
+ layer.compile(*args, **kwargs)
+ else:
+ new_layer = torch.compile(layer, *args, **kwargs)
+ self.forward_funcs[idx] = new_layer
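The new `compile` method walks `self.forward_funcs`, calling `nn.Module.compile` on module layers and wrapping plain callables with `torch.compile`. A hedged usage sketch, assuming PyTorch 2.x and a distributed launch (PipelineModule needs an initialized process group); the toy layers and backend choice are illustrative, not taken from this repo:

```python
import torch.nn as nn
from deepspeed.pipe import PipelineModule

# Toy two-stage pipeline; real layer specs come from the training script.
layers = [nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2)]
pipe = PipelineModule(layers=layers, num_stages=2)

# One call compiles every stage-local layer; args/kwargs are forwarded to
# nn.Module.compile or torch.compile per element of forward_funcs.
pipe.compile(backend="inductor")
```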
9 changes: 6 additions & 3 deletions deepspeed/runtime/zero/stage_1_and_2.py
@@ -310,6 +310,7 @@ def __init__(self,
for param in param_group['params']:
if param.requires_grad:
param.grad_accum = None
+ param.param_idx_in_group = len(trainable_parameters)
trainable_parameters.append(param)
self.bit16_groups.append(trainable_parameters)

@@ -961,7 +962,7 @@ def reduce_independent_p_g_buckets_and_remove_grads(self, param, i):
assert grad_reduc is not None, f"rank {dist.get_rank()} - Invalid to reduce Param {param_id} with None gradient"

self.grads_in_ipg_bucket.append(grad_reduc)
- self.params_in_ipg_bucket.append((i, param, param_id))
+ self.params_in_ipg_bucket.append((i, param.param_idx_in_group, param_id))

#make sure the average tensor function knows how to average the gradients
if is_moe_param(param):
@@ -1067,7 +1068,8 @@ def average_tensor(self, tensor):

process_group = self.dp_process_group
# count = 0
- for i, param, param_id in self.params_in_ipg_bucket:
+ for i, param_idx_in_group, param_id in self.params_in_ipg_bucket:
+ param = self.bit16_groups[i][param_idx_in_group]

process_group = self.dp_process_group

@@ -1383,7 +1385,8 @@ def reduce_ipg_grads(self):
stream = get_accelerator().current_stream()

with get_accelerator().stream(stream):
- for _, param, param_id in self.params_in_ipg_bucket:
+ for group_idx, param_idx_in_group, param_id in self.params_in_ipg_bucket:
+ param = self.bit16_groups[group_idx][param_idx_in_group]

assert self.params_already_reduced[param_id] == False, \
f"The parameter {param_id} has already been reduced. \
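These hunks change `params_in_ipg_bucket` to carry the parameter's index within its group rather than the parameter tensor itself, resolving the tensor back through `self.bit16_groups[group_idx][param_idx_in_group]` at reduction time. A small self-contained sketch of that indirection with made-up names (not the optimizer's real state):

```python
import torch

# Parameter groups as the optimizer keeps them (one group of four params here).
bit16_groups = [[torch.nn.Parameter(torch.randn(3)) for _ in range(4)]]

# Bucket entries record (group_idx, param_idx_in_group, param_id), so the bucket
# holds lightweight indices and the tensor is looked up only when it is reduced.
params_in_ipg_bucket = [(0, 2, 42), (0, 3, 43)]

for group_idx, param_idx_in_group, param_id in params_in_ipg_bucket:
    param = bit16_groups[group_idx][param_idx_in_group]
    print(param_id, tuple(param.shape))
```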
3 changes: 2 additions & 1 deletion tests/unit/moe/test_moe.py
@@ -93,7 +93,8 @@ def strict_average_tensor(tensor):
process_group = optimizer.dp_process_group
curr_size = 0
pg_offsets = []
- for i, param, param_id in optimizer.params_in_ipg_bucket:
+ for i, param_idx, param_id in optimizer.params_in_ipg_bucket:
+ param = optimizer.bit16_groups[i][param_idx]
process_group = optimizer.dp_process_group
if optimizer.ipg_bucket_has_moe_params:
process_group = optimizer.expert_dp_process_group[param.group_name] if is_moe_param(
8 changes: 6 additions & 2 deletions tests/unit/pipe/test_pipe_module.py
@@ -60,9 +60,12 @@ def batch_input():

class TestPipeModuleSequential(DistributedTest):
world_size = 2
+ # needs to be set for torch.compile: running torch.compile with daemonic process causes an error
+ non_daemonic_procs = True

@pytest.mark.parametrize("activation_checkpoints", [False, True])
- def test(self, sequential_model, simple_config, batch_input, activation_checkpoints):
+ @pytest.mark.parametrize("use_compile", [False, True])
+ def test(self, sequential_model, simple_config, batch_input, activation_checkpoints, use_compile):
base_model = copy.deepcopy(sequential_model)
base_input = batch_input.clone().detach()
base_output = base_model(base_input)
@@ -71,7 +74,8 @@ def test(self, sequential_model, simple_config, batch_input, activation_checkpoi

pipe_model = copy.deepcopy(sequential_model)
pipe_model = PipelineModule(layers=pipe_model, num_stages=2)

+ if (use_compile):
+ pipe_model.compile()
# Ensure all parameters are accounted for.
my_params = sum(p.numel() for p in pipe_model.parameters())
total_pipe_params = torch.LongTensor([my_params]).to(get_accelerator().device_name())
