From eb6337d86b6b0c94e2c35e76ae08de06d91074ee Mon Sep 17 00:00:00 2001 From: Kevin Hsieh Date: Thu, 26 Sep 2024 14:00:13 -0700 Subject: [PATCH] Enable seq mse with bq/lpbq Signed-off-by: Kevin Hsieh --- .../torch/src/python/aimet_torch/seq_mse.py | 8 +- .../src/python/aimet_torch/v2/seq_mse.py | 255 +++++++++++++++--- .../torch/test/python/v2/test_seq_mse_.py | 46 ++-- 3 files changed, 256 insertions(+), 53 deletions(-) diff --git a/TrainingExtensions/torch/src/python/aimet_torch/seq_mse.py b/TrainingExtensions/torch/src/python/aimet_torch/seq_mse.py index 128871b6fd1..c0cd35eb3dc 100644 --- a/TrainingExtensions/torch/src/python/aimet_torch/seq_mse.py +++ b/TrainingExtensions/torch/src/python/aimet_torch/seq_mse.py @@ -539,12 +539,12 @@ def compute_outputs(cls, module = cls._get_original_module(quant_module) if isinstance(module, torch.nn.Linear): - xqwq = functional.linear(xq, wq, module.bias) - xw = functional.linear(x, w, module.bias) + xqwq = functional.linear(xq, wq) + xw = functional.linear(x, w) elif isinstance(module, torch.nn.Conv2d): - xqwq = functional.conv2d(xq, wq, bias=module.bias, stride=module.stride, dilation=module.dilation, + xqwq = functional.conv2d(xq, wq, stride=module.stride, dilation=module.dilation, padding=module.padding, groups=module.groups) - xw = functional.conv2d(x, w, bias=module.bias, stride=module.stride, dilation=module.dilation, + xw = functional.conv2d(x, w, stride=module.stride, dilation=module.dilation, padding=module.padding, groups=module.groups) # [N, C, H, W] --> [N, H, W, C], so that loss can be computed across channel dimension. diff --git a/TrainingExtensions/torch/src/python/aimet_torch/v2/seq_mse.py b/TrainingExtensions/torch/src/python/aimet_torch/v2/seq_mse.py index 0f431587b88..58f7f6c2a55 100644 --- a/TrainingExtensions/torch/src/python/aimet_torch/v2/seq_mse.py +++ b/TrainingExtensions/torch/src/python/aimet_torch/v2/seq_mse.py @@ -42,36 +42,71 @@ import contextlib import torch from torch import nn +from torch.utils.data import DataLoader +from aimet_common.utils import AimetLogger from aimet_torch.seq_mse import SequentialMse as V1SequentialMse from aimet_torch.seq_mse import SeqMseParams as V1SeqMseParams from aimet_torch.seq_mse import SUPPORTED_MODULES from aimet_torch.v2.quantization.base import QuantizerBase -from aimet_torch.v2.quantization.affine import AffineQuantizerBase -from aimet_torch.v2.quantization.encoding_analyzer import MinMaxEncodingAnalyzer +from aimet_torch.v2.quantization.affine import AffineQuantizerBase, QuantizeDequantize, GroupedBlockQuantizeDequantize +from aimet_torch.v2.quantization.affine.backends import torch_builtins from aimet_torch.v2.nn.base import BaseQuantizationMixin from aimet_torch.v2.quantsim import QuantizationSimModel from aimet_torch.v2.utils import reduce, _is_reducible SeqMseParams = V1SeqMseParams - - -def _observe(x_min: torch.Tensor, - x_max: torch.Tensor, - num_steps: int, - symmetric: bool) -> Tuple[torch.Tensor, torch.Tensor]: - encoding_analyzer = MinMaxEncodingAnalyzer(x_min.shape) - min, max = encoding_analyzer.compute_dynamic_encodings(torch.stack([x_min, x_max]), - num_steps=num_steps, - is_symmetric=symmetric) - return min, max - +_logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.SeqMse) class SequentialMse(V1SequentialMse): """ Sequentially minimizing activation MSE loss in layer-wise way to decide optimal param quantization encodings. 
""" + + @classmethod + def apply_seq_mse(cls, + model: torch.nn.Module, + sim: QuantizationSimModel, + data_loader: DataLoader, + params: SeqMseParams, + modules_to_exclude: Optional[List[torch.nn.Module]] = None, + checkpoints_config: Optional[str] = None): + if not modules_to_exclude: + modules_to_exclude = [] + modules_to_exclude.extend(cls._get_grouped_convs_with_blockwise_quantization(sim)) + with cls._handle_grouped_block_quantizers(sim): + super().apply_seq_mse(model, sim, data_loader, params, modules_to_exclude, checkpoints_config) + + @staticmethod + def _get_grouped_convs_with_blockwise_quantization(sim): + """ Return a list of all grouped conv modules using blockwise quantization for weights """ + grouped_convs_with_blockwise_quantization = [] + for module in sim.model.modules(): + if isinstance(module, torch.nn.Conv2d) and \ + isinstance(module, BaseQuantizationMixin) and \ + module.groups != 1 and \ + module.param_quantizers['weight'].block_size is not None and \ + module.param_quantizers['weight'].block_size[1] != module.weight.shape[1]: + grouped_convs_with_blockwise_quantization.append(module) + return grouped_convs_with_blockwise_quantization + + @staticmethod + @contextlib.contextmanager + def _handle_grouped_block_quantizers(sim: QuantizationSimModel): + """ Set all grouped block quantizers to regular blockwise quantization for the duration of the context manager + """ + grouped_block_quantize_dequantizers = [] + for module in sim.model.modules(): + if isinstance(module, GroupedBlockQuantizeDequantize): + grouped_block_quantize_dequantizers.append((module, module.block_grouping)) + module.block_grouping = tuple(1 for _ in enumerate(module.shape)) + + yield + + for (module, block_grouping) in grouped_block_quantize_dequantizers: + module.block_grouping = block_grouping + @staticmethod def compute_all_param_encodings(sim: QuantizationSimModel): """ @@ -143,27 +178,15 @@ def compute_param_encodings(quantizer: QuantizerBase, :param x_min: min values :param x_max: max values """ - # Unsqueeze x_min/x_max until they become reducible to quantizer.min/max - while x_min.dim() < quantizer.min.dim(): - x_min = x_min[..., None] - while x_max.dim() < quantizer.max.dim(): - x_max = x_max[..., None] - assert _is_reducible(x_min.shape, quantizer.min.shape) - assert _is_reducible(x_max.shape, quantizer.max.shape) - - x_min = reduce(x_min, quantizer.shape, torch.min).values - x_max = reduce(x_max, quantizer.shape, torch.max).values + quantize_dequantize = QuantizeDequantize(quantizer.shape, quantizer.bitwidth, quantizer.symmetric, + block_size=quantizer.block_size).to(x_min.device) - num_steps = 2 ** quantizer.bitwidth - 1 - symmetric = quantizer.symmetric - - # The values of x_min and x_max don't necessarily satisfy the symmetry constraints. - # Therefore, we need to adjust their values to ensure min and max are in symmetric grids. 
-        min, max = _observe(x_min, x_max, num_steps=num_steps, symmetric=symmetric)
+        with quantize_dequantize.compute_encodings():
+            _ = quantize_dequantize(torch.stack([x_min, x_max]))
 
         with torch.no_grad():
-            quantizer.min.copy_(min)
-            quantizer.max.copy_(max)
+            quantizer.min.copy_(quantize_dequantize.min)
+            quantizer.max.copy_(quantize_dequantize.max)
 
     @staticmethod
     def _is_symmetric_quantizer(quantizer: AffineQuantizerBase):
@@ -185,6 +208,174 @@ def _get_quantized_weight(quant_module: BaseQuantizationMixin):
     def _get_original_module(quant_module: BaseQuantizationMixin):
         return quant_module
 
+    @staticmethod
+    def _get_input_channel_block_size(quant_module):
+        """ Return the weight block size along the input channel dimension """
+        if not isinstance(quant_module, (torch.nn.Linear, torch.nn.Conv2d)):
+            raise NotImplementedError(f'Unsupported module type: {type(quant_module)}')
+        if quant_module.param_quantizers['weight'].block_size is None:
+            # Per tensor or per channel case. For either one, treat loss computation as per channel
+            return quant_module.weight.shape[1]
+        return quant_module.weight.shape[1] // quant_module.param_quantizers['weight'].shape[1]
+
+    @staticmethod
+    def _get_indices_to_reduce(block_size, reshaped_weight):
+        """
+        Return the indices in reshaped_weight corresponding to the block_size dimensions. reshaped_weight is
+        expected to contain alternating num_blocks and block_size dimensions.
+        """
+        indices_to_reduce = []
+        for idx, _ in enumerate(block_size):
+            indices_to_reduce.insert(0, (len(reshaped_weight.shape) - 2 * idx) - 1)
+        return indices_to_reduce
+
+    @classmethod
+    def get_min_and_max_for_candidate_selection(cls, quant_module: BaseQuantizationMixin) -> \
+            Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Get min/max values for candidate selection.
+
+        :param quant_module: Quant module to be optimized
+        :return: Tuple of min and max values for candidate selection.
+        """
+        # pylint: disable=protected-access
+        assert hasattr(quant_module.param_quantizers['weight'], 'block_size')
+        if not isinstance(quant_module, (torch.nn.Conv2d, torch.nn.Linear)):
+            raise ValueError(f'Unsupported module: {quant_module}')
+
+        block_size = quant_module.param_quantizers['weight'].block_size
+        if block_size is None:
+            # Per tensor or per channel case
+            assert _is_reducible(quant_module.weight.shape, quant_module.param_quantizers['weight'].min.shape)
+            if cls._is_symmetric_quantizer(quant_module.param_quantizers['weight']):
+                max_tensor = reduce(quant_module.weight.abs(),
+                                    quant_module.param_quantizers['weight'].shape, torch.max).values
+                min_tensor = -max_tensor
+            else:
+                max_tensor = reduce(quant_module.weight,
+                                    quant_module.param_quantizers['weight'].shape, torch.max).values
+                min_tensor = reduce(quant_module.weight,
+                                    quant_module.param_quantizers['weight'].shape, torch.min).values
+        else:
+            # Reshape tensor so each dimension is split into (num_blocks, block_size)
+            weight = torch_builtins.reshape_tensor_for_blocks(quant_module.weight,
+                                                              quant_module.param_quantizers['weight'].shape,
+                                                              block_size)
+            indices_to_reduce = cls._get_indices_to_reduce(block_size, weight)
+
+            # Obtain max_tensor and min_tensor which are equivalent in shape to the original weight, but with block
+            # values modified to be the block minimum and maximum.
+            # For example, assume the original weight has 1 output channel and 6 input channels, with block size 2:
+            # Original weight: [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]]
+            # Then, max tensor would be: [[2.0, 2.0, 4.0, 4.0, 6.0, 6.0]]
+            # and min tensor (in the asymmetric case) would be: [[1.0, 1.0, 3.0, 3.0, 5.0, 5.0]]
+            if cls._is_symmetric_quantizer(quant_module.param_quantizers['weight']):
+                max_tensor = torch.maximum(weight,
+                                           torch.amax(weight.abs(),
+                                                      indices_to_reduce,
+                                                      keepdim=True)).reshape(quant_module.weight.shape)
+                min_tensor = -max_tensor
+            else:
+                max_tensor = torch.maximum(weight,
+                                           torch.amax(weight,
+                                                      indices_to_reduce,
+                                                      keepdim=True)).reshape(quant_module.weight.shape)
+                min_tensor = torch.minimum(weight,
+                                           torch.amin(weight,
+                                                      indices_to_reduce,
+                                                      keepdim=True)).reshape(quant_module.weight.shape)
+
+        return min_tensor, max_tensor
+
+    @classmethod
+    def _get_candidate(cls, candidate_idx: int, num_candidates: int, min_tensor: torch.Tensor,
+                       max_tensor: torch.Tensor):
+        """
+        Get candidate min and max tensors
+        """
+        cand_max = max_tensor / num_candidates * (candidate_idx + 1)
+        cand_min = min_tensor / num_candidates * (candidate_idx + 1)
+        return cand_min, cand_max
+
+    @classmethod
+    def _compute_loss(cls,
+                      quant_module: BaseQuantizationMixin,
+                      x: torch.Tensor,
+                      xq: torch.Tensor,
+                      w: torch.Tensor,
+                      wq: torch.Tensor,
+                      params: SeqMseParams) -> torch.Tensor:
+        """
+        Compute loss for the given (x, w) and (xq, wq) input/weight pairs. Assumes that the block size, if any,
+        lies on the input_channel dimension.
+        """
+        # pylint: disable=too-many-locals
+        # General strategy: split weights and input per block, and run a separate forward pass for each split.
+        # In the case of per tensor and per channel, the entire input channel is treated as one block.
+        block_size = cls._get_input_channel_block_size(quant_module)
+        w_blocks = torch.split(w, block_size, dim=1)
+        wq_blocks = torch.split(wq, block_size, dim=1)
+        if isinstance(quant_module, torch.nn.Linear):
+            x_blocks = torch.split(x, block_size, dim=-1)
+            xq_blocks = torch.split(xq, block_size, dim=-1)
+        else:
+            x_blocks = torch.split(x, block_size, dim=-3)
+            xq_blocks = torch.split(xq, block_size, dim=-3)
+
+        block_losses = []
+        for idx, x_block in enumerate(x_blocks):
+            xqwq, xw = cls.compute_outputs(quant_module, x_block, xq_blocks[idx], w_blocks[idx], wq_blocks[idx])
+            block_losses.append(cls.compute_recon_loss(xqwq, xw, params))
+        # Stack losses in the input channel dimension
+        block_losses = torch.stack(block_losses, dim=-1)
+        return block_losses
+
+    @classmethod
+    def optimize_module(cls,
+                        quant_module: BaseQuantizationMixin,
+                        x: torch.Tensor,
+                        xq: torch.Tensor,
+                        params: SeqMseParams):
+        """
+        Find and freeze optimal parameter encodings candidate for given module.
+
+        :param quant_module: Quant module to be optimized
+        :param x: Inputs to module from FP32 model
+        :param xq: Inputs to module from QuantSim model
+        :param params: Sequential MSE parameters
+        """
+        # pylint: disable=too-many-locals
+        min_tensor, max_tensor = cls.get_min_and_max_for_candidate_selection(quant_module)
+
+        total_loss = []
+        for i in range(params.num_candidates):
+            cand_min, cand_max = cls._get_candidate(i, params.num_candidates, min_tensor, max_tensor)
+            cls.compute_param_encodings(quant_module.param_quantizers['weight'], cand_min, cand_max)
+            w = quant_module.weight
+            wq = cls._get_quantized_weight(quant_module)
+            with torch.no_grad():
+                for batch_idx in range(params.num_batches):
+                    if batch_idx == 0:
+                        loss = cls._compute_loss(quant_module, x[batch_idx], xq[batch_idx], w, wq, params)
+                    else:
+                        loss += cls._compute_loss(quant_module, x[batch_idx], xq[batch_idx], w, wq, params)
+                total_loss.append(loss)
+
+        best_indices = torch.stack(total_loss).min(0)[1]
+        block_size = cls._get_input_channel_block_size(quant_module)
+        # In the input_channels dimension, best_indices is of size num_blocks. Use repeat_interleave to expand
+        # each blockwise index into block_size indices, so the input_channels dimension of best_indices becomes
+        # size num_blocks * block_size and allows elementwise operations with min_tensor and max_tensor.
+        if block_size != quant_module.weight.shape[1]:
+            best_indices = best_indices.repeat_interleave(block_size, dim=-1)
+
+        # Unsqueeze best_indices until it matches the dim length of max_tensor
+        while best_indices.dim() < max_tensor.dim():
+            best_indices = best_indices[..., None]
+
+        min_tensor, max_tensor = cls._get_candidate(best_indices, params.num_candidates, min_tensor, max_tensor)
+
+        # Compute and freeze parameter encodings using the best candidate
+        cls.compute_param_encodings(quant_module.param_quantizers['weight'], min_tensor, max_tensor)
+        cls._freeze_quantizer_encoding(quant_module.param_quantizers['weight'])
 
 # Global variables for compatibility
 apply_seq_mse = SequentialMse.apply_seq_mse
diff --git a/TrainingExtensions/torch/test/python/v2/test_seq_mse_.py b/TrainingExtensions/torch/test/python/v2/test_seq_mse_.py
index 2fc801d8129..5f12097424f 100644
--- a/TrainingExtensions/torch/test/python/v2/test_seq_mse_.py
+++ b/TrainingExtensions/torch/test/python/v2/test_seq_mse_.py
@@ -47,6 +47,7 @@
 from aimet_torch.utils import create_fake_data_loader
 from aimet_torch.v2.quantsim import QuantizationSimModel
+from aimet_torch.v2.quantsim.config_utils import set_grouped_blockwise_quantization_for_weights
 from aimet_torch.v2.nn import QuantizationMixin
 from aimet_torch.v2.quantization.affine import QuantizeDequantize
 from aimet_torch.v2.seq_mse import apply_seq_mse, get_candidates, optimize_module, SeqMseParams, SequentialMse
@@ -152,21 +153,20 @@ def test_seq_mse(self):
         assert list(cand_max.size())[0] == linear.out_features
         assert list(cand_min.size())[0] == linear.out_features
 
-    @pytest.mark.parametrize("enable_pcq", [True, False])
+    @pytest.mark.parametrize("quantizer_shape, block_size", [[[], None],
+                                                             [[128, 1], None],
+                                                             [[128, 8], [-1, -1]]])
     @pytest.mark.parametrize("param_bw", [4, 16])
     @pytest.mark.parametrize("loss_fn", ['mse', 'l1', 'sqnr'])
     @pytest.mark.parametrize("qparam_requires_grad", [True, False])
-    def test_optimize_module_linear(self, enable_pcq, param_bw, loss_fn, qparam_requires_grad):
+    def test_optimize_module_linear(self, quantizer_shape, block_size, param_bw, loss_fn, qparam_requires_grad):
         """ test optimize module for linear """
         torch.manual_seed(0)
         linear = torch.nn.Linear(64, 128)
         wrapper = QuantizationMixin.from_module(linear)
-        if enable_pcq:
-            quantizer_shape = [linear.weight.shape[0], 1]
-        else:
-            quantizer_shape = []
-        wrapper.param_quantizers['weight'] = QuantizeDequantize(shape=quantizer_shape, bitwidth=param_bw, symmetric=True)
+        wrapper.param_quantizers['weight'] = QuantizeDequantize(shape=quantizer_shape, bitwidth=param_bw,
+                                                                symmetric=True, block_size=block_size)
         wrapper.param_quantizers['weight'].min.requires_grad = qparam_requires_grad
         wrapper.param_quantizers['weight'].max.requires_grad = qparam_requires_grad
 
@@ -187,23 +187,20 @@ def test_optimize_module_linear(self, enable_pcq, param_bw, loss_fn, qparam_requ
         assert not torch.allclose(before.min, after.min)
         assert not torch.allclose(before.max, after.max)
 
-    @pytest.mark.parametrize("enable_pcq", [True, False])
+    @pytest.mark.parametrize("quantizer_shape, block_size", [[[1, ], None],
+                                                             [[32, 1, 1, 1], None],
+                                                             [[32, 3, 1, 1], [1, 2, -1, -1]]])
     @pytest.mark.parametrize("param_bw", [4, 16])
     @pytest.mark.parametrize("loss_fn", ['mse', 'l1', 'sqnr'])
-    def test_optimize_module_conv(self, enable_pcq, param_bw, loss_fn):
+    def test_optimize_module_conv(self, quantizer_shape, block_size, param_bw, loss_fn):
         """ test optimize module for conv """
         torch.manual_seed(0)
-        conv = torch.nn.Conv2d(3, 32, 3)
+        conv = torch.nn.Conv2d(6, 32, 3)
         wrapper = QuantizationMixin.from_module(conv)
-        if enable_pcq:
-            quantizer_shape = [conv.weight.shape[0], 1, 1, 1]
-        else:
-            quantizer_shape = [1, ]
-        wrapper.param_quantizers['weight'] = QuantizeDequantize(shape=quantizer_shape, bitwidth=param_bw,
-                                                                symmetric=True)
+        wrapper.param_quantizers['weight'] = QuantizeDequantize(shape=quantizer_shape, bitwidth=param_bw,
+                                                                symmetric=True, block_size=block_size)
 
-        xq = torch.randn(32, 1, 3, 10, 10)
+        xq = torch.randn(32, 1, 6, 10, 10)
         with wrapper.param_quantizers['weight'].compute_encodings():
             _ = wrapper.param_quantizers['weight'](wrapper.weight.data)
         before = wrapper.param_quantizers['weight'].get_encoding()
@@ -322,3 +319,18 @@ def test_compute_param_encodings(self, qtzr, range):
         SequentialMse.compute_param_encodings(qtzr, x_min, x_max)
 
         assert torch.all(torch.isclose(qtzr.get_max(), x_max) | torch.isclose(qtzr.get_min(), x_min))
+
+    def test_handle_grouped_block_quantizers(self):
+        torch.manual_seed(0)
+        model = Net().eval()
+        dummy_input = torch.randn(1, 1, 28, 28)
+        sim = QuantizationSimModel(model, dummy_input, default_param_bw=4)
+        set_grouped_blockwise_quantization_for_weights(sim, lambda m: m != sim.model.conv1, 4, True, 8, 4)
+        sim.compute_encodings(lambda m, _: m(dummy_input), None)
+        out = sim.model(dummy_input)
+        with SequentialMse._handle_grouped_block_quantizers(sim):
+            out_2 = sim.model(dummy_input)
+        out_3 = sim.model(dummy_input)
+
+        assert torch.equal(out, out_3)
+        assert not torch.equal(out, out_2)
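
Usage sketch: a minimal end-to-end example of the path this patch enables, mirroring the pattern in
the new test_handle_grouped_block_quantizers test. `model` and `data_loader` are placeholders for
any FP32 torch.nn.Module and a calibration data loader; the positional arguments to
set_grouped_blockwise_quantization_for_weights follow the call in the test
(bitwidth=4, symmetric=True, decompressed_bw=8, block_size=4).

    import torch
    from aimet_torch.v2.quantsim import QuantizationSimModel
    from aimet_torch.v2.quantsim.config_utils import set_grouped_blockwise_quantization_for_weights
    from aimet_torch.v2.seq_mse import apply_seq_mse, SeqMseParams

    dummy_input = torch.randn(1, 1, 28, 28)
    sim = QuantizationSimModel(model, dummy_input, default_param_bw=4)

    # Configure grouped blockwise (LPBQ) quantization for all weight quantizers.
    set_grouped_blockwise_quantization_for_weights(sim, lambda _: True, 4, True, 8, 4)

    # Seq MSE now handles BQ/LPBQ weights: grouped block quantizers temporarily fall back to
    # regular blockwise quantization, and grouped convs with blockwise weight quantization
    # are excluded from optimization.
    apply_seq_mse(model, sim, data_loader, SeqMseParams(num_batches=4))

    # Compute encodings for the remaining (non-frozen) quantizers as usual.
    sim.compute_encodings(lambda m, _: m(dummy_input), None)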