From e5a89edaf3215c6534592c2ddac4fcc55f7a95c4 Mon Sep 17 00:00:00 2001
From: Huan Zhao
Date: Tue, 29 Oct 2024 10:49:32 +0800
Subject: [PATCH] Ensure the synchronization of parameters using zero offload (#3435)

Signed-off-by: Huan Zhao
Signed-off-by: Kyunggeun Lee
Co-authored-by: Kyunggeun Lee
---
 .../python/aimet_torch/v2/deepspeed_utils.py  |  34 ++-
 .../src/python/aimet_torch/v2/nn/base.py      |   4 +-
 .../affine/backends/torch_builtins.py         |   8 +-
 .../v2/quantization/affine/quantizer.py       |   3 +-
 .../affine/test_affine_quantizer.py           |  19 +-
 .../torch/test/python/v2/test_deepspeed.py    | 256 +++++++++++++++++-
 6 files changed, 291 insertions(+), 33 deletions(-)

diff --git a/TrainingExtensions/torch/src/python/aimet_torch/v2/deepspeed_utils.py b/TrainingExtensions/torch/src/python/aimet_torch/v2/deepspeed_utils.py
index 0185214c51d..91eb1d4ef3d 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/v2/deepspeed_utils.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/v2/deepspeed_utils.py
@@ -42,21 +42,27 @@
 try:
     from deepspeed.runtime.zero import ZeroParamStatus, GatheredParameters
+    from deepspeed.utils import safe_set_local_fp32_param
 
-    def gathered_parameters(params, *args, **kwargs):
+
+    class SafeGatheredParameters(GatheredParameters):
         """
         Shallow wrapper around ref:`GatheredParameters`.
         Unlike ref:`GatheredParameters`, this function can be also called with parameters
         that are already all-gathered by deepspeed zero3 or zero-offload runtime.
+        Additionally, this context manager ensures that parameters stay synchronized with the optimizer state on exit.
         """
-        params = [
-            p for p in params
-            # Ignore if the parameter is already all-gathered.
-            # deepspeed.zero.runtime.GatheredParameters assumes all the parameters to be "NOT_AVAILABLE"
-            # and can fail if some of them were already "AVAILABLE".
-            if getattr(p, 'ds_status', None) == ZeroParamStatus.NOT_AVAILABLE
-        ]
-        return GatheredParameters(params, *args, **kwargs)
+        def __exit__(self, *exc):
+            super().__exit__(*exc)
+
+            if not self.enabled:
+                return
+
+            if self.src_rank is not None:
+                for param in self.params:
+                    if hasattr(param, "_z3_optimizer"):
+                        safe_set_local_fp32_param(param, param.ds_tensor)
+
 
 
     @contextlib.contextmanager
     def _do_patch_dummy_parameters(module):
@@ -82,9 +88,10 @@ def _do_patch_dummy_parameters(module):
                 getattr(module, name).data = data
 
 except ImportError:
-    def gathered_parameters(*args, **kwargs): # pylint: disable=unused-argument
+    class SafeGatheredParameters(contextlib.nullcontext):
         """ Dummy placeholder in case deepspeed doesn't exist """
-        return contextlib.nullcontext()
+        pass
+
 
     def _do_patch_dummy_parameters(module): # pylint: disable=unused-argument
         """ Dummy placeholder in case deepspeed doesn't exist """
@@ -94,7 +101,7 @@ def _do_patch_dummy_parameters(module): # pylint: disable=unused-argument
 _ds_ctx = {}
 
 def _all_gather(module, _):
-    ctx = gathered_parameters(module.parameters(recurse=False))
+    ctx = SafeGatheredParameters(module.parameters(recurse=False))
     ctx.__enter__()
     _ds_ctx[module] = ctx
 
@@ -111,9 +118,8 @@ def _restore(module, *_):
 
 @contextlib.contextmanager
 def _register_zero3_forward_hooks(model: torch.nn.Module, use_dummy_params: bool):
-    handles = []
-
     # Temporarily materialize parameters to make forward runnable
+    handles = []
     materialize_parameters = _patch_dummy_parameters if use_dummy_params else _all_gather
     try:
         for module in model.modules():
diff --git a/TrainingExtensions/torch/src/python/aimet_torch/v2/nn/base.py b/TrainingExtensions/torch/src/python/aimet_torch/v2/nn/base.py
index 1c4f52ce3d3..8156aa64b31 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/v2/nn/base.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/v2/nn/base.py
@@ -53,7 +53,7 @@
     _ContextManager,
     flatten_nn_module_list,
 )
-from aimet_torch.v2.deepspeed_utils import gathered_parameters, _shallow_copy
+from aimet_torch.v2.deepspeed_utils import SafeGatheredParameters, _shallow_copy
 
 def _no_op(in_tensor):
     return in_tensor
@@ -147,7 +147,7 @@ def _compute_param_encodings(self, overwrite: bool):
         if not params:
             return
 
-        with gathered_parameters(params.values()):
+        with SafeGatheredParameters(params.values()):
             for param_qtzr, param in params.items():
                 with patch_attr(param_qtzr, "forward", _no_op), param_qtzr.compute_encodings():
                     _ = param_qtzr(param)
diff --git a/TrainingExtensions/torch/src/python/aimet_torch/v2/quantization/affine/backends/torch_builtins.py b/TrainingExtensions/torch/src/python/aimet_torch/v2/quantization/affine/backends/torch_builtins.py
index 7ba9c41bfa9..503a04c4350 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/v2/quantization/affine/backends/torch_builtins.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/v2/quantization/affine/backends/torch_builtins.py
@@ -105,7 +105,13 @@ def _validate_arguments(tensor: torch.Tensor, scale: torch.Tensor,
                            f'and block size {block_size}')
 
     elif not _is_expandable(scale.shape, tensor.shape):
-        raise RuntimeError(f"Scale of shape {scale.shape} cannot be expanded like input tensor of shape {tensor.shape}")
+        msg = f"Scale of shape {scale.shape} cannot be expanded like input tensor of shape {tensor.shape}. "
+        # Additional message if the tensor is empty
+        if tensor.numel() == 0:
+            msg += (f"Detected that the tensor is empty, which may be caused by the following reasons: "
+                    f"1. The input tensor is incorrect. "
+                    f"2. Model inference was run with DeepSpeed-offloaded (partitioned) parameters before DeepSpeed was initialized.")
+        raise RuntimeError(msg)
 
     if qmin is not None and qmax is not None:
         if qmin > qmax:
diff --git a/TrainingExtensions/torch/src/python/aimet_torch/v2/quantization/affine/quantizer.py b/TrainingExtensions/torch/src/python/aimet_torch/v2/quantization/affine/quantizer.py
index ebd421dd06e..6f33589e922 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/v2/quantization/affine/quantizer.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/v2/quantization/affine/quantizer.py
@@ -53,6 +53,7 @@
 from aimet_torch.v2.quantization.base import QuantizerBase
 from aimet_torch.v2.quantization.affine.backends import quantize, quantize_dequantize, torch_builtins, _derive_qmin_qmax
 from aimet_torch.v2.utils import ste_round
+from aimet_torch.v2.deepspeed_utils import SafeGatheredParameters
 from ._utils import _GridMixin, _register_signature
 
 # pylint: disable=import-error
@@ -440,7 +441,7 @@ def set_range(self, min: torch.Tensor, max: torch.Tensor):
         """
         Set quantization parameters to the given min-max range
         """
-        with torch.no_grad():
+        with torch.no_grad(), SafeGatheredParameters(self.parameters(recurse=False), modifier_rank=0):
             self.min.copy_(min)
             self.max.copy_(max)
diff --git a/TrainingExtensions/torch/test/python/v2/quantization/affine/test_affine_quantizer.py b/TrainingExtensions/torch/test/python/v2/quantization/affine/test_affine_quantizer.py
index b2d133359a1..5165b2b3c75 100644
--- a/TrainingExtensions/torch/test/python/v2/quantization/affine/test_affine_quantizer.py
+++ b/TrainingExtensions/torch/test/python/v2/quantization/affine/test_affine_quantizer.py
@@ -883,20 +883,15 @@ def test_is_initialized_with_deepspeed_zero3(init_process_group, deepspeed_zero3
     qdq = QuantizeDequantize((10,), bitwidth=8, symmetric=True, encoding_analyzer=MinMaxEncodingAnalyzer((10,)))
     engine, *_ = ds.initialize(model=qdq, config=deepspeed_zero3_config)
     qdq_zero3 = engine.module
-    with ds.zero.GatheredParameters(qdq_zero3.parameters(), modifier_rank=0):
-        qdq_zero3.set_range(-1, 1)
-        assert qdq_zero3.is_initialized()
+    qdq_zero3.set_range(-1, 1)
     assert qdq_zero3.is_initialized()
 
-    # TODO (kyunggeu): Support the below use case
-    # qdq = QuantizeDequantize((10,), bitwidth=8, symmetric=True, encoding_analyzer=MinMaxEncodingAnalyzer((10,)))
-    # engine, *_ = ds.initialize(model=qdq, config=deepspeed_zero3_config)
-    # qdq_zero3 = engine.module
-    # with ds.zero.GatheredParameters(qdq_zero3.parameters(), modifier_rank=0):
-    #     with qdq_zero3.compute_encodings():
-    #         _ = qdq_zero3(torch.arange(-5, 5, dtype=torch.float, device='cuda:0'))
-    #     assert qdq_zero3.is_initialized()
-    # assert qdq_zero3.is_initialized()
+    qdq = QuantizeDequantize((10,), bitwidth=8, symmetric=True, encoding_analyzer=MinMaxEncodingAnalyzer((10,)))
+    engine, *_ = ds.initialize(model=qdq, config=deepspeed_zero3_config)
+    qdq_zero3 = engine.module
+    with qdq_zero3.compute_encodings():
+        _ = qdq_zero3(torch.arange(-5, 5, dtype=torch.float, device='cuda:0'))
+    assert qdq_zero3.is_initialized()
 
     """
     When: Gather the partitioned quantization parameters in writable mode but don't modify them
diff --git a/TrainingExtensions/torch/test/python/v2/test_deepspeed.py b/TrainingExtensions/torch/test/python/v2/test_deepspeed.py
index a6cd1820a4c..199babbc6be 100644
--- a/TrainingExtensions/torch/test/python/v2/test_deepspeed.py
+++ b/TrainingExtensions/torch/test/python/v2/test_deepspeed.py
@@ -64,6 +64,7 @@
 from aimet_torch.v2.quantization.affine import QuantizeDequantize
 from aimet_torch.v2.quantization.base.quantizer import QuantizerBase
 from aimet_torch.v2.quantization import DequantizedTensor
+from aimet_torch.v2.deepspeed_utils import SafeGatheredParameters
 
 
 class Net(nn.Module):
@@ -248,6 +249,14 @@ def test_deepspeed_zero3_offload(unlabeled_data_loader,
                                  per_channel_quantsim_config,
                                  init_process_group,
                                  deepspeed_zero3_offload_config):
+    """
+    This test case demonstrates how to train a model using DeepSpeed Zero3 Offload in regular mode.
+    This mode offers better performance than the compatibility mode for PTQ operations.
+    Steps:
+    1. Create QuantSim
+    2. Initialize with ds.initialize
+    3. Compute encodings
+    """
     # Baseline model without deepsped
     model_baseline = Net().cuda().eval()
     baseline_state_dict = model_baseline.state_dict()
@@ -259,7 +268,249 @@ def test_deepspeed_zero3_offload(unlabeled_data_loader,
                                         in_place=True)
 
     """
-    Given: Model pre-partitioned with deepspeed zero3 offload
+    Given: Model pre-partitioned with DeepSpeed Zero3 offload
+    """
+    with ds.zero.Init(config_dict_or_path=deepspeed_zero3_offload_config):
+        # ds.zero.Init context pre-partitions the pytorch models at instantiation time.
+        # PyTorch modules instantiated under this context will only hold a partition
+        # of their parameters
+        model = Net().cuda().eval()
+        assert all(param.numel() == 0 for param in model.parameters()) # sanity check
+        assert all(hasattr(param, 'ds_shape') for param in model.parameters()) # sanity check
+
+    # Copy the parameters/buffers of the baseline model to the deepspeed pre-partitioned model to assert
+    # outputs to be equal with or without deepspeed
+    with ds.runtime.zero.GatheredParameters(model.parameters(), modifier_rank=0), torch.no_grad():
+        model.load_state_dict(baseline_state_dict)
+
+    """
+    When: Create quantsim from the pre-partitioned model
+    Then: Quantizers should be instantiated with correct shape
+    """
+    sim_deepspeed = QuantizationSimModel(model,
+                                         torch.randn(1, 1, 28, 28).cuda(),
+                                         default_param_bw=4,
+                                         config_file=per_channel_quantsim_config,
+                                         quant_scheme=QuantScheme.training_range_learning_with_tf_init,
+                                         in_place=True)
+
+    assert isinstance(sim_deepspeed.model.conv1.input_quantizers[0], QuantizeDequantize)
+    assert isinstance(sim_deepspeed.model.conv1.param_quantizers['weight'], QuantizeDequantize)
+    assert isinstance(sim_deepspeed.model.conv1.output_quantizers[0], QuantizeDequantize)
+    assert isinstance(sim_deepspeed.model.maxpool1.output_quantizers[0], QuantizeDequantize)
+    assert isinstance(sim_deepspeed.model.relu1.output_quantizers[0], QuantizeDequantize)
+    assert isinstance(sim_deepspeed.model.conv2.param_quantizers['weight'], QuantizeDequantize)
+    assert isinstance(sim_deepspeed.model.conv2.output_quantizers[0], QuantizeDequantize)
+    assert isinstance(sim_deepspeed.model.maxpool2.output_quantizers[0], QuantizeDequantize)
+    assert isinstance(sim_deepspeed.model.relu2.output_quantizers[0], QuantizeDequantize)
+    assert isinstance(sim_deepspeed.model.fc1.param_quantizers['weight'], QuantizeDequantize)
+    assert sim_deepspeed.model.fc1.output_quantizers[0] is None
+    assert isinstance(sim_deepspeed.model.relu3.output_quantizers[0], QuantizeDequantize)
+    assert isinstance(sim_deepspeed.model.fc2.param_quantizers['weight'], QuantizeDequantize)
+    assert isinstance(sim_deepspeed.model.fc2.output_quantizers[0], QuantizeDequantize)
+    assert isinstance(sim_deepspeed.model.log_softmax.output_quantizers[0], QuantizeDequantize)
+
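+    # NOTE: with per-channel quantization, each conv weight quantizer carries one encoding per
+    # output channel, hence the (32, 1, 1, 1) shapes asserted below.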
+    assert sim_deepspeed.model.conv1.param_quantizers['weight'].shape == (32, 1, 1, 1)
+    assert sim_deepspeed.model.conv2.param_quantizers['weight'].shape == (32, 1, 1, 1)
+
+    # NOTE: default per-channel quantsim config doesn't apply per-channel qtzn to nn.Linear
+    assert sim_deepspeed.model.fc1.param_quantizers['weight'].shape == ()
+    assert sim_deepspeed.model.fc2.param_quantizers['weight'].shape == ()
+
+    assert sim_deepspeed.model.conv1.input_quantizers[0].shape ==\
+           sim_deepspeed.model.conv1.output_quantizers[0].shape ==\
+           sim_deepspeed.model.maxpool1.output_quantizers[0].shape ==\
+           sim_deepspeed.model.relu1.output_quantizers[0].shape ==\
+           sim_deepspeed.model.conv2.output_quantizers[0].shape ==\
+           sim_deepspeed.model.maxpool2.output_quantizers[0].shape ==\
+           sim_deepspeed.model.relu2.output_quantizers[0].shape ==\
+           sim_deepspeed.model.relu3.output_quantizers[0].shape ==\
+           sim_deepspeed.model.fc2.output_quantizers[0].shape ==\
+           sim_deepspeed.model.log_softmax.output_quantizers[0].shape == ()
+
+
+    """
+    When: Initialize quantsim model with deepspeed zero3 offload
+    Then:
+        1) All parameters must be initialized with deepspeed zero3 parameter partitioning mechanism
+        2) Forward pass outputs must be equal with or without deepspeed
+    """
+    engine, ds_optimizer, *_ = ds.initialize(model=sim_deepspeed.model,
+                                             model_parameters=sim_deepspeed.model.parameters(),
+                                             config=deepspeed_zero3_offload_config,
+                                             mpu=CustomMPU(init_process_group))
+    # hasattr(param, 'ds_id') being True indicates the parameter was partitioned by DeepSpeed ZeRO stage 3
+    assert all(hasattr(param, 'ds_id') for param in model.parameters())
+
+
+    """
+    When: Compute encodings after deepspeed initialization
+    Then:
+        1) All quantizer encodings must be initialized
+        2) get_{encoding, scale, offset, min, max} returns real tensors, not empty tensors
+    """
+    with aimet.nn.compute_encodings(sim_deepspeed.model),\
+            aimet.nn.compute_encodings(sim_baseline.model):
+        for data in itertools.islice(unlabeled_data_loader, 3):
+            data = data.cuda()
+            _ = sim_deepspeed.model(data)
+            _ = sim_baseline.model(data)
+
+    for qtzr in sim_deepspeed.model.modules():
+        if isinstance(qtzr, QuantizerBase):
+            assert qtzr.is_initialized()
+
+    with torch.no_grad():
+        for data in unlabeled_data_loader:
+            data = data.cuda()
+            assert torch.equal(sim_deepspeed.model(data), sim_baseline.model(data))
+
+    """
+    When: Run training loop
+    Then: All trainable parameters must be updated by training in the (almost) same way
+          with or without deepspeed
+    """
+    with ds.runtime.zero.GatheredParameters(sim_deepspeed.model.parameters()):
+        ds_params_before = {
+            name: param.clone().detach() for name, param in sim_deepspeed.model.named_parameters()
+        }
+
+    target = torch.ones((1, 10)).float().cuda()
+    sim_deepspeed.model.train()
+    sim_baseline.model.train()
+    optimizer = torch.optim.AdamW([{
+        'params': sim_baseline.model.parameters(),
+        'lr': ds_optimizer.get_lr(),
+        'weight_decay': ds_optimizer.param_groups[0]['weight_decay'],
+        'betas': ds_optimizer.param_groups[0]['betas'],
+        'eps': ds_optimizer.param_groups[0]['eps'],
+        'bias_correction': True,
+    }])
+
+    for i, data in enumerate(unlabeled_data_loader):
+        output = sim_deepspeed.model(data.cuda())
+        output_baseline = sim_baseline.model(data.cuda())
+        assert torch.allclose(output, output_baseline, rtol=1e-3)
+        assert isinstance(output, DequantizedTensor)
+        assert output.encoding.scale.numel() == 1
+        assert output.encoding.offset.numel() == 1
+        loss = functional.mse_loss(output, target)
+        loss_baseline = functional.mse_loss(output_baseline, target)
+        engine.backward(loss)
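+        # NOTE: with DeepSpeed ZeRO, the backward pass must go through engine.backward() so that
+        # gradients are reduced and partitioned across ranks; the baseline model uses plain autograd.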
+        loss_baseline.backward()
+
+        if i == 0:
+            # Gradient checker
+            for param_ds, param_baseline in zip(sim_deepspeed.model.parameters(),
+                                                sim_baseline.model.parameters()):
+                grad_ds = ds.utils.safe_get_full_grad(param_ds)
+                assert torch.allclose(grad_ds, param_baseline.grad, rtol=1e-3)
+
+        ds_optimizer.step()
+        optimizer.step()
+        ds_optimizer.zero_grad()
+        optimizer.zero_grad()
+
+    with ds.runtime.zero.GatheredParameters(sim_deepspeed.model.parameters()):
+        ds_params_after = {
+            name: param.clone().detach() for name, param in sim_deepspeed.model.named_parameters()
+        }
+
+    assert ds_params_before.keys() == ds_params_after.keys()
+    for param_name in ds_params_before:
+        ds_before = ds_params_before[param_name]
+        ds_after = ds_params_after[param_name]
+        assert not torch.equal(ds_before, ds_after)
+
+
+@pytest.mark.cuda
+def test_deepspeed_zero3_offload_buckets_sync(unlabeled_data_loader,
+                                              per_channel_quantsim_config,
+                                              init_process_group,
+                                              deepspeed_zero3_offload_config):
+    """
+    Verify that the fp32_partitioned_groups_flat are synchronized with the values written
+    under SafeGatheredParameters.
+    """
+    # Given: Model pre-partitioned with DeepSpeed Zero3 offload
+    with ds.zero.Init(config_dict_or_path=deepspeed_zero3_offload_config):
+        # ds.zero.Init context pre-partitions the pytorch models at instantiation time.
+        # PyTorch modules instantiated under this context will only hold a partition
+        # of their parameters
+        model = Net().cuda().eval()
+        assert all(param.numel() == 0 for param in model.parameters()) # sanity check
+        assert all(hasattr(param, 'ds_shape') for param in model.parameters()) # sanity check
+
+    """
+    When: Create quantsim from the pre-partitioned model
+    Then: Quantizers should be instantiated with correct shape
+    """
+    sim_deepspeed = QuantizationSimModel(model,
+                                         torch.randn(1, 1, 28, 28).cuda(),
+                                         default_param_bw=4,
+                                         config_file=per_channel_quantsim_config,
+                                         quant_scheme=QuantScheme.training_range_learning_with_tf_init,
+                                         in_place=True)
+
+    """
+    When: Initialize quantsim model with deepspeed zero3 offload
+    Then:
+        1) All parameters must be initialized with deepspeed zero3 parameter partitioning mechanism
+    """
+    engine, ds_optimizer, *_ = ds.initialize(model=model,
+                                             model_parameters=model.parameters(),
+                                             config=deepspeed_zero3_offload_config,
+                                             mpu=CustomMPU(init_process_group))
+    # hasattr(param, 'ds_id') being True indicates the parameter was partitioned by DeepSpeed ZeRO stage 3
+    assert all(hasattr(param, 'ds_id') for param in model.parameters())
+
+    """
+    When: Compute encodings after deepspeed initialization
+    Then:
+        1) The parameter trace mode is set to ZeRoTraceMode.COMPLETE (2)
+        2) fp32_partitioned_groups_flat are synchronized with the written values
+    """
+
+    sim_deepspeed.model.eval()
+    with aimet.nn.compute_encodings(sim_deepspeed.model), torch.no_grad():
+        for data in unlabeled_data_loader:
+            data = data.cuda()
+            _ = sim_deepspeed.model(data)
+
+    param_coordinator = ds_optimizer._get_param_coordinator(False)
+    assert param_coordinator.is_complete_trace()
+
+    with SafeGatheredParameters(sim_deepspeed.model.parameters()), torch.no_grad():
+        for param in sim_deepspeed.model.parameters():
+            if param.requires_grad:
+                bucket_param, group_idx = param._z3_optimizer._get_fp32_opt_state_partition(param, None)
+                assert torch.all(bucket_param == param.flatten().cpu())
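+
+    # A matching fp32 partition confirms that parameter values written under SafeGatheredParameters
+    # (e.g. the quantizer min/max set during compute_encodings) were propagated back to the
+    # ZeRO-Offload optimizer state on context exit.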
+
+@pytest.mark.cuda
+def test_deepspeed_zero3_offload_fallback(unlabeled_data_loader,
+                                          per_channel_quantsim_config,
+                                          init_process_group,
+                                          deepspeed_zero3_offload_config):
+    """
+    This test case demonstrates how to train a model using DeepSpeed Zero3 Offload in compatibility mode.
+    This mode significantly decreases the performance of PTQ operations like compute_encodings.
+    However, it improves compatibility because fewer hooks and patched functions are applied.
+    Steps:
+    1. Create QuantSim
+    2. Compute encodings
+    3. Initialize with ds.initialize
+    """
+    # Baseline model without deepspeed
+    model_baseline = Net().cuda().eval()
+    baseline_state_dict = model_baseline.state_dict()
+    sim_baseline = QuantizationSimModel(model_baseline,
+                                        torch.randn(1, 1, 28, 28).cuda(),
+                                        default_param_bw=4,
+                                        config_file=per_channel_quantsim_config,
+                                        quant_scheme=QuantScheme.training_range_learning_with_tf_init,
+                                        in_place=True)
+
+    """
+    Given: Model pre-partitioned with DeepSpeed Zero3 offload
     """
     with ds.zero.Init(config_dict_or_path=deepspeed_zero3_offload_config):
         # ds.zero.Init context pre-partitoins the pytorch models at instantiation time.
@@ -347,7 +598,7 @@ def test_deepspeed_zero3_offload(unlabeled_data_loader,
                                              model_parameters=sim_deepspeed.model.parameters(),
                                              config=deepspeed_zero3_offload_config,
                                              mpu=CustomMPU(init_process_group))
-    assert all(hasattr(param, 'ds_shape') for param in model.parameters())
+    assert all(hasattr(param, 'ds_id') for param in model.parameters())
 
     with torch.no_grad():
         for data in unlabeled_data_loader:
@@ -411,7 +662,6 @@ def test_deepspeed_zero3_offload(unlabeled_data_loader,
         ds_after = ds_params_after[param_name]
         assert not torch.equal(ds_before, ds_after)
 
-
 @pytest.mark.cuda
 def test_conv_transpose(per_channel_quantsim_config,
                         init_process_group,