Ensure the synchronization of parameters using zero offload (#3435)
Signed-off-by: Huan Zhao <quic_huzh@quicinc.com>
Signed-off-by: Kyunggeun Lee <quic_kyunggeu@quicinc.com>
Co-authored-by: Kyunggeun Lee <quic_kyunggeu@quicinc.com>
quic-huzh and quic-kyunggeu authored Oct 29, 2024
1 parent c22ffb9 commit e5a89ed
Showing 6 changed files with 291 additions and 33 deletions.
@@ -42,21 +42,27 @@

try:
from deepspeed.runtime.zero import ZeroParamStatus, GatheredParameters
from deepspeed.utils import safe_set_local_fp32_param

def gathered_parameters(params, *args, **kwargs):

class SafeGatheredParameters(GatheredParameters):
"""
Shallow wrapper around ref:`GatheredParameters`.
Unlike ref:`GatheredParameters`, this context manager can also be called
with parameters that are already all-gathered by the deepspeed zero3 or zero-offload runtime.
Additionally, it ensures the synchronization of parameters.
"""
params = [
p for p in params
# Ignore if the parameter is already all-gathered.
# deepspeed.zero.runtime.GatheredParameters assumes all the parameters to be "NOT_AVAILABLE"
# and can fail if some of them were already "AVAILABLE".
if getattr(p, 'ds_status', None) == ZeroParamStatus.NOT_AVAILABLE
]
return GatheredParameters(params, *args, **kwargs)
def __exit__(self, *exc):
super().__exit__(*exc)

if not self.enabled:
return

if self.src_rank is not None:
for param in self.params:
if hasattr(param, "_z3_optimizer"):
safe_set_local_fp32_param(param, param.ds_tensor)


@contextlib.contextmanager
def _do_patch_dummy_parameters(module):
@@ -82,9 +88,10 @@ def _do_patch_dummy_parameters(module):
getattr(module, name).data = data

except ImportError:
def gathered_parameters(*args, **kwargs): # pylint: disable=unused-argument
class SafeGatheredParameters(contextlib.nullcontext):
""" Dummy placeholder in case deepspeed doesn't exist """
return contextlib.nullcontext()
pass


def _do_patch_dummy_parameters(module): # pylint: disable=unused-argument
""" Dummy placeholder in case deepspeed doesn't exist """
@@ -94,7 +101,7 @@ def _do_patch_dummy_parameters(module): # pylint: disable=unused-argument
_ds_ctx = {}

def _all_gather(module, _):
ctx = gathered_parameters(module.parameters(recurse=False))
ctx = SafeGatheredParameters(module.parameters(recurse=False))
ctx.__enter__()
_ds_ctx[module] = ctx

@@ -111,9 +118,8 @@ def _restore(module, *_):

@contextlib.contextmanager
def _register_zero3_forward_hooks(model: torch.nn.Module, use_dummy_params: bool):
handles = []

# Temporarily materialize parameters to make forward runnable
handles = []
materialize_parameters = _patch_dummy_parameters if use_dummy_params else _all_gather
try:
for module in model.modules():
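
For context, a minimal usage sketch of the new wrapper (not part of this diff; the helper name rescale_weights and the factor argument are illustrative, and a DeepSpeed ZeRO-3 or ZeRO-Offload engine is assumed to have been initialized around model beforehand):

import torch
from aimet_torch.v2.deepspeed_utils import SafeGatheredParameters

def rescale_weights(model: torch.nn.Module, factor: float):
    # Gather the partitioned parameters and modify them in place; on exit,
    # modifier_rank=0 broadcasts rank 0's values back into every partition and
    # safe_set_local_fp32_param keeps the fp32 optimizer shards in sync.
    with torch.no_grad(), SafeGatheredParameters(model.parameters(recurse=False), modifier_rank=0):
        for param in model.parameters(recurse=False):
            param.mul_(factor)
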
4 changes: 2 additions & 2 deletions TrainingExtensions/torch/src/python/aimet_torch/v2/nn/base.py
@@ -53,7 +53,7 @@
_ContextManager,
flatten_nn_module_list,
)
from aimet_torch.v2.deepspeed_utils import gathered_parameters, _shallow_copy
from aimet_torch.v2.deepspeed_utils import SafeGatheredParameters, _shallow_copy

def _no_op(in_tensor):
return in_tensor
@@ -147,7 +147,7 @@ def _compute_param_encodings(self, overwrite: bool):
if not params:
return

with gathered_parameters(params.values()):
with SafeGatheredParameters(params.values()):
for param_qtzr, param in params.items():
with patch_attr(param_qtzr, "forward", _no_op), param_qtzr.compute_encodings():
_ = param_qtzr(param)
@@ -105,7 +105,13 @@ def _validate_arguments(tensor: torch.Tensor, scale: torch.Tensor,
f'and block size {block_size}')

elif not _is_expandable(scale.shape, tensor.shape):
raise RuntimeError(f"Scale of shape {scale.shape} cannot be expanded like input tensor of shape {tensor.shape}")
msg = f"Scale of shape {scale.shape} cannot be expanded like input tensor of shape {tensor.shape}. "
# Additional message if the tensor is empty
if tensor.numel() == 0:
msg += (f"Detected that the tensor is empty, which may be caused by the following reasons: "
f"1. The input tensor is incorrect. "
f"2. Improper use of model inference without initializing DeepSpeed after offloading parameters.")
raise RuntimeError(msg)

if qmin is not None and qmax is not None:
if qmin > qmax:
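
To illustrate the second cause named in the new message: under ZeRO-3 or ZeRO-Offload, a module whose parameters have been partitioned holds empty placeholder tensors until the DeepSpeed runtime gathers them. A hedged sketch of the wrong and right call patterns (model, ds_config, and x are placeholder names, not part of this diff):

import deepspeed as ds

# Wrong: the parameters were already partitioned/offloaded, but the module is
# called before deepspeed.initialize() has set up the runtime, so its weights
# are still empty and the shape check above fails.
# out = model(x)

# Right: initialize the DeepSpeed engine first, then run inference through it.
engine, *_ = ds.initialize(model=model, config=ds_config)
out = engine(x)
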
@@ -53,6 +53,7 @@
from aimet_torch.v2.quantization.base import QuantizerBase
from aimet_torch.v2.quantization.affine.backends import quantize, quantize_dequantize, torch_builtins, _derive_qmin_qmax
from aimet_torch.v2.utils import ste_round
from aimet_torch.v2.deepspeed_utils import SafeGatheredParameters
from ._utils import _GridMixin, _register_signature # pylint: disable=import-error


@@ -440,7 +441,7 @@ def set_range(self, min: torch.Tensor, max: torch.Tensor):
"""
Set quantization parameters to the given min-max range
"""
with torch.no_grad():
with torch.no_grad(), SafeGatheredParameters(self.parameters(recurse=False), modifier_rank=0):
self.min.copy_(min)
self.max.copy_(max)

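
With this change, set_range gathers its own partitioned parameters, so callers no longer need to wrap it in an explicit GatheredParameters context; modifier_rank=0 broadcasts the updated values back into every rank's partition on exit. A brief sketch of the resulting call pattern (assuming qdq_zero3 is a QuantizeDequantize instance managed by a DeepSpeed ZeRO-3 engine, as in the test below; the read-back assertions are illustrative):

import torch
from aimet_torch.v2.deepspeed_utils import SafeGatheredParameters

qdq_zero3.set_range(-1, 1)  # no explicit GatheredParameters wrapper needed anymore

# Reading the values back still requires gathering the partitions first.
with SafeGatheredParameters(qdq_zero3.parameters(recurse=False)):
    assert torch.all(qdq_zero3.min == -1)
    assert torch.all(qdq_zero3.max == 1)
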
@@ -883,20 +883,15 @@ def test_is_initialized_with_deepspeed_zero3(init_process_group, deepspeed_zero3
qdq = QuantizeDequantize((10,), bitwidth=8, symmetric=True, encoding_analyzer=MinMaxEncodingAnalyzer((10,)))
engine, *_ = ds.initialize(model=qdq, config=deepspeed_zero3_config)
qdq_zero3 = engine.module
with ds.zero.GatheredParameters(qdq_zero3.parameters(), modifier_rank=0):
qdq_zero3.set_range(-1, 1)
assert qdq_zero3.is_initialized()
qdq_zero3.set_range(-1, 1)
assert qdq_zero3.is_initialized()

# TODO (kyunggeu): Support the below use case
# qdq = QuantizeDequantize((10,), bitwidth=8, symmetric=True, encoding_analyzer=MinMaxEncodingAnalyzer((10,)))
# engine, *_ = ds.initialize(model=qdq, config=deepspeed_zero3_config)
# qdq_zero3 = engine.module
# with ds.zero.GatheredParameters(qdq_zero3.parameters(), modifier_rank=0):
# with qdq_zero3.compute_encodings():
# _ = qdq_zero3(torch.arange(-5, 5, dtype=torch.float, device='cuda:0'))
# assert qdq_zero3.is_initialized()
# assert qdq_zero3.is_initialized()
qdq = QuantizeDequantize((10,), bitwidth=8, symmetric=True, encoding_analyzer=MinMaxEncodingAnalyzer((10,)))
engine, *_ = ds.initialize(model=qdq, config=deepspeed_zero3_config)
qdq_zero3 = engine.module
with qdq_zero3.compute_encodings():
_ = qdq_zero3(torch.arange(-5, 5, dtype=torch.float, device='cuda:0'))
assert qdq_zero3.is_initialized()

"""
When: Gather the partitioned quantization parameters in writable mode but don't modify them