Merge branch 'master' into lyj/lm_head_replace
delock authored Jan 3, 2025
2 parents c9ac0c2 + 3573858 commit f214afb
Showing 25 changed files with 124 additions and 88 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/hpu-gaudi2-nightly.yml
@@ -21,7 +21,7 @@ jobs:
    # The type of runner that the job will run on
    runs-on: [self-hosted, intel, gaudi2]
    container:
-     image: vault.habana.ai/gaudi-docker/1.18.0/ubuntu22.04/habanalabs/pytorch-installer-2.4.0:latest
+     image: vault.habana.ai/gaudi-docker/1.19.0/ubuntu22.04/habanalabs/pytorch-installer-2.5.1:latest
      ports:
        - 80
      options: --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice
2 changes: 1 addition & 1 deletion .github/workflows/hpu-gaudi2.yml
@@ -39,7 +39,7 @@ jobs:
    # The type of runner that the job will run on
    runs-on: [self-hosted, intel, gaudi2]
    container:
-     image: vault.habana.ai/gaudi-docker/1.18.0/ubuntu22.04/habanalabs/pytorch-installer-2.4.0:latest
+     image: vault.habana.ai/gaudi-docker/1.19.0/ubuntu22.04/habanalabs/pytorch-installer-2.5.1:latest
      ports:
        - 80
      options: --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice
1 change: 1 addition & 0 deletions .github/workflows/no-torch.yml
@@ -4,6 +4,7 @@ on:
  workflow_dispatch:
  pull_request:
    paths:
+     - 'accelerator/**'
      - '.github/workflows/no-torch.yml'
      - 'op_builder/**'
  schedule:
45 changes: 24 additions & 21 deletions CODEOWNERS
@@ -8,49 +8,52 @@

# top-level repo folders
/.github/ @loadams
- /azure/ @awan-10
- /benchmarks/ @awan-10 @tjruwase
+ /azure/ @loadams
+ /benchmarks/ @guanhuawang @tjruwase
/bin/ @loadams
- /csrc/ @awan-10
+ /csrc/ @tjruwase
/deepspeed/ @loadams @tjruwase
- /docker/ @awan-10
+ /docker/ @loadams @guanhuawang
/docs/ @loadams @tjruwase
- /examples/ @awan-10 @tohtana
+ /examples/ @jomayeri @tohtana
/op_builder/ @loadams @tjruwase @jomayeri
- /release/ @loadams
+ /release/ @loadams @jomayeri
/requirements/ @loadams
- /scripts/ @awan-10
+ /scripts/ @loadams @tjruwase
/tests/ @tjruwase @loadams @tohtana

# deepspeed
/deepspeed/autotuning/ @loadams
/deepspeed/checkpoint/ @tjruwase
- /deepspeed/comm/ @awan-10
+ /deepspeed/comm/ @guanhuawang
/deepspeed/compression/ @tjruwase
- /deepspeed/elasticity/ @awan-10
+ /deepspeed/elasticity/ @tjruwase
/deepspeed/launcher/ @loadams
- /deepspeed/module_inject/ @awan-10
+ /deepspeed/module_inject/ @hwchen2017 @loadams
/deepspeed/moe/ @tohtana
- /deepspeed/monitor/ @awan-10
+ /deepspeed/monitor/ @tjruwase
/deepspeed/nebula/ @tjruwase
+ /deepspeed/nvme/ @tjruwase @jomayeri
/deepspeed/ops/ @tohtana
/deepspeed/pipe/ @tohtana @loadams
/deepspeed/profiling/ @loadams
- /deepspeed/utils/ @tjruwase @awan-10
+ /deepspeed/sequence/ @tohtana
+ /deepspeed/utils/ @tjruwase @tohtana

# inference
- /deepspeed/inference/ @awan-10
- /deepspeed/model_implementations/ @awan-10
+ /deepspeed/inference/ @hwchen2017 @tohtana
+ /deepspeed/model_implementations/ @tohtana @loadams

# training
/deepspeed/runtime/ @tjruwase @tohtana
/deepspeed/runtime/activation_checkpointing/ @tjruwase
/deepspeed/runtime/checkpoint_engine/ @tjruwase
- /deepspeed/runtime/comm/ @awan-10
- /deepspeed/runtime/compression/ @awan-10
+ /deepspeed/runtime/comm/ @guanhuawang
+ /deepspeed/runtime/compression/ @tjruwase
/deepspeed/runtime/data_pipeline/ @tjruwase
- /deepspeed/runtime/fp16/ @tjruwase
- /deepspeed/runtime/fp16/onebit/ @awan-10
- /deepspeed/runtime/pipe/ @loadams
- /deepspeed/runtime/swap_tensor/ @tjruwase
- /deepspeed/runtime/zero/ @tjruwase
+ /deepspeed/runtime/domino/ @guanhuawang @hwchen2017
+ /deepspeed/runtime/fp16/ @tjruwase @tohtana
+ /deepspeed/runtime/fp16/onebit/ @tjruwase
+ /deepspeed/runtime/pipe/ @loadams @tohtana
+ /deepspeed/runtime/swap_tensor/ @tjruwase @jomayeri
+ /deepspeed/runtime/zero/ @tjruwase @tohtana
8 changes: 7 additions & 1 deletion accelerator/cpu_accelerator.py
@@ -3,9 +3,15 @@

# DeepSpeed Team

- import torch
from .abstract_accelerator import DeepSpeedAccelerator

+ # During setup stage torch may not be installed, pass on no torch will
+ # allow op builder related API to be executed.
+ try:
+     import torch
+ except ImportError as e:
+     pass
+
try:
    import oneccl_bindings_for_pytorch  # noqa: F401 # type: ignore
    oneccl_imported_p = True
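Note: the guarded import above keeps the CPU accelerator importable in environments where torch is not installed yet (e.g. during setup), so op-builder-related APIs can still run. A minimal standalone sketch of the same pattern, with hypothetical helper names (not DeepSpeed's code):

```python
# Optional-dependency guard: the module stays importable without torch,
# and only code paths that actually need torch check the flag.
try:
    import torch
    HAS_TORCH = True
except ImportError:
    torch = None
    HAS_TORCH = False


def device_count() -> int:
    # Hypothetical helper: degrades gracefully when torch is absent.
    if not HAS_TORCH:
        return 0
    return torch.cuda.device_count() if torch.cuda.is_available() else 1
```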
2 changes: 1 addition & 1 deletion accelerator/hpu_accelerator.py
@@ -21,8 +21,8 @@ def __init__(self):
        self.apply_hpu_workarounds()
        try:
            import habana_frameworks.torch.hpu as hpu
-           hpu.setDeterministic(True)
            self.hpu = hpu
+           torch.use_deterministic_algorithms(True)
        except ImportError as e:
            raise ValueError(
                f"HPU_Accelerator requires habana_frameworks.torch.hpu, which is not installed on this system.")
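Note: the backend-specific hpu.setDeterministic(True) call is replaced by torch's framework-level determinism switch. A small, hedged illustration of what that switch does (not tied to HPU):

```python
import torch

# With this flag set, ops that lack a deterministic implementation raise a
# RuntimeError instead of silently producing run-to-run differences.
torch.use_deterministic_algorithms(True)
```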
50 changes: 28 additions & 22 deletions deepspeed/comm/torch.py
@@ -20,6 +20,12 @@
DS_COMM_REDUCE_OFF = False


+ def disable_compiler_collective(func):
+     if required_torch_version(min_version=2.3):
+         return func
+     return compiler.disable(func)
+
+
def build_shm_op():
    builder = get_accelerator().create_op_builder("ShareMemCommBuilder")
    if builder is None or not deepspeed.ops.__compatible_ops__[builder.NAME]:
@@ -114,7 +120,7 @@ def __init__(self, backend, timeout, init_method, rank=-1, world_size=-1, name='
self.shm_comm_op.initialize(self.get_world_size(), self.get_rank())

@classmethod
- @compiler.disable
+ @disable_compiler_collective
def get_all_gather_function(self):
if hasattr(torch.distributed, "all_gather_into_tensor"):
return torch.distributed.all_gather_into_tensor
@@ -123,7 +129,7 @@ def get_all_gather_function(self):
return None

@classmethod
- @compiler.disable
+ @disable_compiler_collective
def get_reduce_scatter_function(self):
if hasattr(torch.distributed, "reduce_scatter_tensor"):
return torch.distributed.reduce_scatter_tensor
@@ -146,7 +152,7 @@ def init_process_group(self, backend, timeout, init_method, rank, world_size):
world_size=world_size)
self.using_mpi = torch.distributed.get_backend() == 'mpi'

- @compiler.disable
+ @disable_compiler_collective
def all_reduce(self, tensor, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False):
op = self._reduce_op(op)
return torch.distributed.all_reduce(tensor=tensor, op=op, group=group, async_op=async_op)
@@ -158,7 +164,7 @@ def inference_all_reduce(self, tensor, op, group=None):
else:
return torch.ops.deepspeed.inference_all_reduce_(tensor)

- @compiler.disable
+ @disable_compiler_collective
def all_reduce_coalesced(self, tensors, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False):
""" proxy func to torch.distributed.all_reduce_coalesced,
which is included in PyTorch 1.13 and above
@@ -169,15 +175,15 @@ def all_reduce_coalesced(self, tensors, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False):
op = self._reduce_op(op)
return torch.distributed.all_reduce_coalesced(tensors=tensors, op=op, group=group, async_op=async_op)

- @compiler.disable
+ @disable_compiler_collective
def reduce(self, tensor, dst, op=ReduceOp.SUM, group=None, async_op=False):
if DS_COMM_REDUCE_OFF:
if int(os.getenv('RANK', '0')) == 0:
utils.logger.warning("REDUCE is OFF")
return Noop()
return torch.distributed.reduce(tensor=tensor, dst=dst, op=self._reduce_op(op), group=group, async_op=async_op)

- @compiler.disable
+ @disable_compiler_collective
def reduce_scatter(self, output, input_list, op=ReduceOp.SUM, group=None, async_op=False):
if DS_COMM_REDUCE_SCATTER_OFF:
if int(os.getenv('RANK', '0')) == 0:
@@ -190,7 +196,7 @@ def reduce_scatter(self, output, input_list, op=ReduceOp.SUM, group=None, async_op=False):
group=group,
async_op=async_op)

- @compiler.disable
+ @disable_compiler_collective
def broadcast(self, tensor, src, group=None, async_op=False):
if DS_COMM_BROADCAST_OFF:
if int(os.getenv('RANK', '0')) == 0:
@@ -199,7 +205,7 @@ def broadcast(self, tensor, src, group=None, async_op=False):
else:
return torch.distributed.broadcast(tensor=tensor, src=src, group=group, async_op=async_op)

- @compiler.disable
+ @disable_compiler_collective
def all_gather(self, tensor_list, tensor, group=None, async_op=False):
if DS_COMM_ALL_GATHER_OFF:
if int(os.getenv('RANK', '0')) == 0:
@@ -208,15 +214,15 @@ def all_gather(self, tensor_list, tensor, group=None, async_op=False):
else:
return torch.distributed.all_gather(tensor_list=tensor_list, tensor=tensor, group=group, async_op=async_op)

- @compiler.disable
+ @disable_compiler_collective
def all_gather_into_tensor(self, output_tensor, input_tensor, group=None, async_op=False):
if self.has_all_gather_into_tensor():
return self.all_gather_function(output_tensor=output_tensor,
input_tensor=input_tensor,
group=group,
async_op=async_op)

- @compiler.disable
+ @disable_compiler_collective
def all_gather_base(self, output_tensor, input_tensor, group=None, async_op=False):
if DS_COMM_ALL_GATHER_OFF:
if int(os.getenv('RANK', '0')) == 0:
@@ -234,7 +240,7 @@ def all_gather_base(self, output_tensor, input_tensor, group=None, async_op=False):
"please consider upgrading your pytorch installation.")
pass

- @compiler.disable
+ @disable_compiler_collective
def all_gather_coalesced(self, output_tensors, input_tensors, group=None, async_op=False):
""""""
assert len(output_tensors) == len(input_tensors), ""
@@ -258,7 +264,7 @@ def all_gather_coalesced(self, output_tensors, input_tensors, group=None, async_op=False):
else:
reqs[-1].wait()

- @compiler.disable
+ @disable_compiler_collective
def reduce_scatter_tensor(self, output_tensor, input_tensor, op=ReduceOp.SUM, group=None, async_op=False):
if self.has_reduce_scatter_tensor():
return self.reduce_scatter_function(output_tensor,
@@ -272,7 +278,7 @@ def reduce_scatter_tensor(self, output_tensor, input_tensor, op=ReduceOp.SUM, group=None, async_op=False):
"please consider upgrading your pytorch installation.")
pass

- @compiler.disable
+ @disable_compiler_collective
def all_to_all_single(self,
output,
input,
@@ -287,49 +293,49 @@ def all_to_all_single(self,
group=group,
async_op=async_op)

- @compiler.disable
+ @disable_compiler_collective
def all_to_all(self, output_tensor_list, input_tensor_list, group=None, async_op=False):
return torch.distributed.all_to_all(output_tensor_list, input_tensor_list, group=group, async_op=async_op)

- @compiler.disable
+ @disable_compiler_collective
def send(self, tensor, dst, group=None, tag=0):
return torch.distributed.send(tensor=tensor, dst=dst, group=group, tag=tag)

- @compiler.disable
+ @disable_compiler_collective
def recv(self, tensor, src=None, group=None, tag=0):
return torch.distributed.recv(tensor=tensor, src=src, group=group, tag=tag)

- @compiler.disable
+ @disable_compiler_collective
def isend(self, tensor, dst, group=None, tag=0):
return torch.distributed.isend(tensor=tensor, dst=dst, group=group, tag=tag)

- @compiler.disable
+ @disable_compiler_collective
def irecv(self, tensor, src=None, group=None, tag=0):
return torch.distributed.irecv(tensor=tensor, src=src, group=group, tag=tag)

- @compiler.disable
+ @disable_compiler_collective
def gather(self, tensor, gather_list=None, dst=0, group=None, async_op=False):
return torch.distributed.gather(tensor=tensor,
gather_list=gather_list,
dst=dst,
group=group,
async_op=async_op)

- @compiler.disable
+ @disable_compiler_collective
def scatter(self, tensor, scatter_list=None, src=0, group=None, async_op=False):
return torch.distributed.scatter(tensor=tensor,
scatter_list=scatter_list,
src=src,
group=group,
async_op=async_op)

- @compiler.disable
+ @disable_compiler_collective
def barrier(self, group=torch.distributed.GroupMember.WORLD, async_op=False, device_ids=None):
if group is None:
group = torch.distributed.GroupMember.WORLD
return torch.distributed.barrier(group=group, async_op=async_op, device_ids=device_ids)

- @compiler.disable
+ @disable_compiler_collective
def monitored_barrier(self, group=torch.distributed.GroupMember.WORLD, timeout=None, wait_all_ranks=False):
if group is None:
group = torch.distributed.GroupMember.WORLD
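Note on the change above: each collective wrapper swaps the unconditional @compiler.disable for the new disable_compiler_collective helper, which only disables compilation on torch versions older than 2.3 (newer versions can trace collectives). A standalone, hedged sketch of the pattern, assuming torch >= 2.1 for torch.compiler.disable; it is not the DeepSpeed implementation:

```python
import torch
import torch.distributed as dist


def disable_compiler_collective(func):
    # On torch >= 2.3 the collective is left traceable; on older versions the
    # call is wrapped so torch.compile inserts a graph break around it.
    major, minor = (int(v) for v in torch.__version__.split(".")[:2])
    if (major, minor) >= (2, 3):
        return func
    return torch.compiler.disable(func)


@disable_compiler_collective
def sum_across_ranks(tensor: torch.Tensor) -> torch.Tensor:
    # Assumes an initialized process group (e.g. via dist.init_process_group).
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    return tensor
```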
3 changes: 3 additions & 0 deletions deepspeed/inference/config.py
@@ -40,6 +40,9 @@ class DeepSpeedTPConfig(DeepSpeedConfigModel):
    tp_size: int = 1
    """ Number of devices to split the model across using tensor parallelism. """

+   tp_grain_size: int = 64
+   "Desired MLP/lm_head tp size granularity. DNN library favors tensor size in granularity of power of 2, we pick 64 as a default size."
+
    mpu: object = None
    """
    A model parallelism unit object that implements
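Note: a hedged sketch of how the new knob could be passed through the tensor_parallel section of an inference config; the model name is a placeholder and defaults may differ across DeepSpeed versions:

```python
import deepspeed
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("facebook/opt-1.3b", torch_dtype=torch.float16)

engine = deepspeed.init_inference(
    model,
    tensor_parallel={"tp_size": 2, "tp_grain_size": 64},  # tp_grain_size is the field added here
    dtype=torch.float16,
    replace_with_kernel_inject=False,  # AutoTP path, where grain-size sharding applies
)
```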
2 changes: 2 additions & 0 deletions deepspeed/launcher/multinode_runner.py
@@ -104,6 +104,8 @@ def get_cmd(self, environment, active_resources):
            deepspeed_launch.append("--no_local_rank")
        if self.args.save_pid:
            deepspeed_launch += ["--save_pid", f"{os.getpid()}"]
+       if self.args.enable_each_rank_log:
+           deepspeed_launch.append(f"--enable_each_rank_log={self.args.enable_each_rank_log}")
        if self.args.elastic_training:
            deepspeed_launch.append("--enable_elastic_training")
            deepspeed_launch.append(f"--max_elastic_nodes={self.args.max_elastic_nodes}")
6 changes: 5 additions & 1 deletion deepspeed/module_inject/auto_tp.py
@@ -346,11 +346,15 @@ def _replace(self, child, name, conv_linear_layer):
            weight, bias = shard_value_with_share_qk(child.weight.data, child.bias, dist.get_rank(),
                                                     dist.get_world_size(), False)
            return LinearAllreduce(weight, bias, self.mp_group)
+       # For Arctic model, bypass to all_reduce replacement for w2 weights
+       arctic_w2_all_reduce_linear = False
+       if 'Arctic' in str(self.module) and 'w2' in name:
+           arctic_w2_all_reduce_linear = True
        # For MLP including chunk layer.
        if 'gate_up_proj' in name or ('dense_h_to_4h' in name and 'GLM' in str(self.module)):
            weight, bias = shard_chunk_mlp(child.weight.data, child.bias, dist.get_rank(), dist.get_world_size())
            return LinearLayer(weight=weight, bias=bias)
-       if name in self.all_reduce_linears:
+       if name in self.all_reduce_linears or arctic_w2_all_reduce_linear:
            # if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size]
            # else [weight_shape[0], weight_shape[1] // mp_size]

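For background on why the Arctic w2 projections are routed to the all-reduce path (an illustration, not DeepSpeed's LinearAllreduce): down-projections are typically sharded along the input dimension, so each rank produces a partial result that must be summed across ranks.

```python
from typing import Optional

import torch
import torch.distributed as dist


class ToyRowParallelLinear(torch.nn.Module):
    """Toy row-parallel linear: each rank holds a slice of the input dim and
    the partial outputs are summed with an all-reduce."""

    def __init__(self, weight_shard: torch.Tensor, bias: Optional[torch.Tensor] = None):
        super().__init__()
        self.weight = torch.nn.Parameter(weight_shard)  # shape [out, in // world_size]
        self.bias = bias

    def forward(self, x_shard: torch.Tensor) -> torch.Tensor:
        partial = torch.nn.functional.linear(x_shard, self.weight)
        dist.all_reduce(partial, op=dist.ReduceOp.SUM)  # assumes an initialized group
        return partial if self.bias is None else partial + self.bias
```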
5 changes: 4 additions & 1 deletion deepspeed/module_inject/replace_module.py
@@ -17,7 +17,7 @@
from .layers import TensorParallelOcShardConv2d, TensorParallelIcShardConv2d

from deepspeed import comm as dist
- from deepspeed.module_inject.tp_shard import set_num_kv_heads, set_n_embd, set_num_attention_heads
+ from deepspeed.module_inject.tp_shard import set_num_kv_heads, set_n_embd, set_num_attention_heads, set_tp_grain_size

from .load_checkpoint import load_model_with_checkpoint
import time
@@ -303,6 +303,9 @@ def replace_wo_policy(module, all_reduce_linears, prefix="", state_dict=None):
    if hasattr(model_config, 'num_attention_heads'):
        set_num_attention_heads(getattr(model_config, 'num_attention_heads'))

+   # 4.4 set tp_grain_size
+   set_tp_grain_size(config.tensor_parallel.tp_grain_size)
+
    # 5. Set linear policies
    _autotp.update_linear_policies()

11 changes: 8 additions & 3 deletions deepspeed/module_inject/tp_shard.py
@@ -22,6 +22,11 @@ def set_n_embd(num):
    n_embd = num


+ def set_tp_grain_size(num):
+     global tp_grain_size
+     tp_grain_size = num
+
+
def get_num_kv_heads():
    global num_kv_heads
    if 'num_kv_heads' in globals():
@@ -45,9 +50,9 @@ def get_shard_size(total_size, mp_size, name=None, rank=None):
        my_slices = (num_kv_heads // mp_size) + (1 if rank < (num_kv_heads % mp_size) else 0)
        return total_size * my_slices // num_kv_heads
    else:
-       if total_size >= 64:
-           grain_size = total_size // 64
-           return (grain_size // mp_size + (1 if rank < (grain_size % mp_size) else 0)) * 64
+       if total_size >= tp_grain_size:
+           grain_size = total_size // tp_grain_size
+           return (grain_size // mp_size + (1 if rank < (grain_size % mp_size) else 0)) * tp_grain_size
        else:
            return total_size // mp_size + (1 if rank < (total_size % mp_size) else 0)

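Worked example of the sharding math above (a standalone sketch of the rounding logic, not the library function itself): shard sizes are kept as multiples of tp_grain_size, and any leftover grains go to the lowest-ranked workers.

```python
# Sketch of grain-based shard sizing: each rank gets a multiple of the grain
# size; remainder grains are assigned to the lowest ranks first.

def shard_size(total_size: int, mp_size: int, rank: int, tp_grain_size: int = 64) -> int:
    if total_size >= tp_grain_size:
        grains = total_size // tp_grain_size
        return (grains // mp_size + (1 if rank < (grains % mp_size) else 0)) * tp_grain_size
    # Small tensors fall back to plain element-wise splitting.
    return total_size // mp_size + (1 if rank < (total_size % mp_size) else 0)


if __name__ == "__main__":
    # Example: an 11008-wide MLP split across 3 ranks with 64-element grains.
    sizes = [shard_size(11008, 3, r) for r in range(3)]
    print(sizes)       # [3712, 3648, 3648] -- rank 0 absorbs the extra grain
    print(sum(sizes))  # 11008
```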
(The diff for the remaining changed files is not shown here.)
