Skip to content

Commit

Permalink
add initialization check
Browse files Browse the repository at this point in the history
  • Loading branch information
hwchen2017 committed Dec 17, 2024
1 parent 2d3d688 commit 2833db7
Showing 1 changed file with 10 additions and 13 deletions.
23 changes: 10 additions & 13 deletions deepspeed/runtime/domino/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,12 @@
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter
import deepspeed.comm
from deepspeed.comm.comm import init_distributed
import deepspeed.comm as dist
from deepspeed.accelerator import get_accelerator


def is_rank_0():
if deepspeed.comm.get_rank() == 0:
if dist.get_rank() == 0:
return True


Expand Down Expand Up @@ -97,7 +96,7 @@ def backward(ctx, grad_output):
return grad_output

# Async All-reduce.
handle = deepspeed.comm.all_reduce(grad_output, group=ctx.mpu.get_tensor_model_parallel_group(), async_op=True)
handle = dist.all_reduce(grad_output, group=ctx.mpu.get_tensor_model_parallel_group(), async_op=True)
ctx.handle_dic[ctx.h_id] = handle
return None, grad_output, None, None

Expand Down Expand Up @@ -249,7 +248,9 @@ def __init__(self,
output_bias=None):
super(DominoTransformerLayer, self).__init__()

init_distributed()
if not dist.is_initialized():
dist.init_distributed()
assert dist.is_initialized(), "deepspeed.comm is not initialized!"

self.llama_model = config.llama_model
self.layer_number = layer_number
Expand Down Expand Up @@ -360,18 +361,14 @@ def forward(self, hidden_states, attention_mask, rotary_pos_emb=None):
layernorm_output0,
attention_mask,
rotary_pos_emb=rotary_pos_emb)
handle0 = deepspeed.comm.all_reduce(attention_output0,
group=self.mpu.get_tensor_model_parallel_group(),
async_op=True)
handle0 = dist.all_reduce(attention_output0, group=self.mpu.get_tensor_model_parallel_group(), async_op=True)

attention_output1, attention_bias1 = \
self.self_attention(
layernorm_output1,
attention_mask,
rotary_pos_emb=rotary_pos_emb)
handle1 = deepspeed.comm.all_reduce(attention_output1,
group=self.mpu.get_tensor_model_parallel_group(),
async_op=True)
handle1 = dist.all_reduce(attention_output1, group=self.mpu.get_tensor_model_parallel_group(), async_op=True)
handle0.wait()

# Residual0 connection.
Expand Down Expand Up @@ -415,7 +412,7 @@ def forward(self, hidden_states, attention_mask, rotary_pos_emb=None):
output0 = output0 + bias_c
output0 = self.mlp_activation_func(output0)
output0 = torch.matmul(output0, self.weight_r.t())
handle2 = deepspeed.comm.all_reduce(output0, group=self.mpu.get_tensor_model_parallel_group(), async_op=True)
handle2 = dist.all_reduce(output0, group=self.mpu.get_tensor_model_parallel_group(), async_op=True)

handle1.wait()

Expand All @@ -427,7 +424,7 @@ def forward(self, hidden_states, attention_mask, rotary_pos_emb=None):
if bias_c is not None:
output1 = output1 + bias_c
output1 = torch.matmul(output1, self.weight_r.t())
deepspeed.comm.all_reduce(output1, group=self.mpu.get_tensor_model_parallel_group())
dist.all_reduce(output1, group=self.mpu.get_tensor_model_parallel_group())

handle2.wait()

Expand Down

0 comments on commit 2833db7

Please sign in to comment.