diff --git a/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/encoding_analyzer.py b/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/encoding_analyzer.py index afbe95da336..eb448faf8be 100644 --- a/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/encoding_analyzer.py +++ b/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/encoding_analyzer.py @@ -394,11 +394,21 @@ class SqnrEncodingAnalyzer(EncodingAnalyzer[_Histogram]): """ Encoding Analyzer for SQNR Calibration technique """ - def __init__(self, shape: tuple, num_bins: int = 2048): + def __init__(self, + shape: tuple, + num_bins: int = 2048, *, + asymmetric_delta_candidates=17, + symmetric_delta_candidates=101, + offset_candidates=21, + gamma=3.0): if num_bins <= 0: raise ValueError('Number of bins cannot be less than or equal to 0.') observer = _HistogramObserver(shape=shape, num_bins=num_bins) super().__init__(observer) + self.asym_delta_candidates = asymmetric_delta_candidates + self.sym_delta_candidates = symmetric_delta_candidates + self.num_offset_candidates = offset_candidates + self.gamma = gamma @torch.no_grad() def update_stats(self, input_tensor: torch.Tensor) -> _Statistics: @@ -407,7 +417,150 @@ def update_stats(self, input_tensor: torch.Tensor) -> _Statistics: return new_stats @torch.no_grad() - def compute_encodings_from_stats(self, stats: _Histogram, bitwidth: int, is_symmetric: bool)\ + def compute_encodings_from_stats(self, stats: List[_Histogram], bitwidth: int, is_symmetric: bool)\ -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: - # TODO - raise NotImplementedError + """ + Searches for encodings which produce the lowest expected SQNR based on the histograms in stats + + :param stats: A list of _Histogram objects with length equal to the number of encodings to compute + :param bitwidth: The bitwidth of the computed encodings + :param is_symmetric: If True, computes symmetric encodings, else computes asymmetric encodings + :return: Tuple of computed encodings (min, max) as tensors with shape self.shape + """ + if stats[0].histogram is None: + raise StatisticsNotFoundError('No statistics present to compute encodings.') + if bitwidth <= 0: + raise ValueError('Bitwidth cannot be less than or equal to 0.') + dtype = stats[0].max.dtype + num_steps = 2 ** bitwidth - 1 + test_deltas, test_offsets = self._pick_test_candidates(stats, num_steps, is_symmetric) + best_delta, best_offset = self._select_best_candidates(test_deltas, test_offsets, stats, num_steps) + min_enc = best_offset * best_delta + max_enc = min_enc + num_steps * best_delta + shape = self.observer.shape + return min_enc.view(shape).to(dtype), max_enc.view(shape).to(dtype) + + def _pick_test_candidates(self, stats, num_steps, symmetric): + # min/max.shape = (num_histograms, ) + min_vals = torch.stack([stat.min for stat in stats]) + max_vals = torch.stack([stat.max for stat in stats]) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + max_vals = torch.max(max_vals, min_vals + torch.finfo(min_vals.dtype).tiny * num_steps) + if symmetric: + return self._pick_test_candidates_symmetric(min_vals, max_vals, num_steps) + return self._pick_test_candidates_asymmetric(min_vals, max_vals, num_steps) + + def _pick_test_candidates_asymmetric(self, min_vals, max_vals, num_steps): + """ + Selects the set of deltas and offsets over which to search for the optimal encodings + """ + # Note: casting to float32 for two reason: + # 1) float16 on CPU is not well-supported in pytorch + # 2) Computing int16 encodings using f16 can result in inf (2 ** 16 - 1 == inf in fp16) + tensor_kwargs = {"device": min_vals.device, "dtype": torch.float32} + max_delta = (max_vals - min_vals).to(torch.float32) / num_steps + observed_offset = torch.round(min_vals / max_delta) + observed_min = max_delta * observed_offset + observed_max = observed_min + max_delta * num_steps + num_deltas = self.asym_delta_candidates + search_space = torch.arange(start=1, end=(1 + num_deltas), step=1, **tensor_kwargs) + # test_deltas.shape = (num_histograms, num_tests) + test_deltas = max_delta[:, None] * search_space[None, :] / (num_deltas - 1) + # test_offsets.shape = (num_offsets) + num_offsets = min(num_steps + 2, self.num_offset_candidates) + test_offset_step = num_steps / (num_offsets - 2) # subtract 2 because we add the observed offset + test_offsets = torch.round(torch.arange(start=-num_steps, end=test_offset_step, step=test_offset_step, **tensor_kwargs)) + test_offsets = test_offsets[None, :].expand(min_vals.shape[0], -1) + # Add in the observed offset as a candidate, test_offsets.shape = (num_histograms, num_offsets + 1) + test_offsets = torch.concat((test_offsets, observed_offset[:, None]), dim=1) + return self._clamp_delta_offset_values(observed_min, observed_max, num_steps, test_deltas, test_offsets) + + def _pick_test_candidates_symmetric(self, min_vals, max_vals, num_steps): + """ + Selects the set of deltas over which to search for the optimal symmetric encodings + """ + tensor_kwargs = {"device": min_vals.device, "dtype": torch.float32} + max_delta = 2 * torch.max(max_vals, -min_vals).to(torch.float32) / num_steps + test_offsets = torch.full((1, ), (-num_steps) // 2, **tensor_kwargs) + num_deltas = self.sym_delta_candidates + search_space = torch.arange(start=1, end=(1 + num_deltas), step=1, **tensor_kwargs) + test_deltas = max_delta[:, None] * search_space[None, :] / (num_deltas - 1) + # test_deltas.shape = (num_histograms, num_deltas, 1) + # test_offsets.shape = (1, 1, 1) + min_delta = torch.Tensor([torch.finfo(test_deltas.dtype).tiny]).to(**tensor_kwargs) + test_deltas = torch.max(test_deltas, min_delta) + return test_deltas[:, :, None], test_offsets[:, None, None] + + @staticmethod + def _clamp_delta_offset_values(min_vals, max_vals, num_steps, test_deltas, test_offsets): + """ + Clamps delta/offset encodings such that represented range falls within the observed min/max range of inputs + """ + # test_min shape = (num_histograms, num_deltas, num_offsets) + test_min = test_deltas[:, :, None] * test_offsets[:, None, :] + test_max = test_min + test_deltas[:, :, None] * num_steps + # Clamp min/max to observed min/max + test_min = torch.max(min_vals[:, None, None], test_min) + test_max = torch.min(max_vals[:, None, None], test_max) + # Recompute delta/offset with clamped min/max + # Returned delta/offset shapes = (num_histograms, num_deltas, num_offsets) + test_deltas = (test_max - test_min) / num_steps + min_delta = torch.Tensor([torch.finfo(test_deltas.dtype).tiny]).to(device=test_deltas.device, + dtype=test_deltas.dtype) + test_deltas = torch.max(test_deltas, min_delta) + test_offsets = torch.round(test_min / test_deltas) + return test_deltas, test_offsets + + def _select_best_candidates(self, test_deltas, test_offsets, stats, num_steps): + """ + Searches all pairs of (delta, offset) in test_deltas, test_offsets to find the set with the lowest expected SQNR + """ + noise = self._estimate_clip_and_quant_noise(stats, test_deltas, test_offsets, num_steps, self.gamma) + _, min_idx = torch.min(noise.flatten(start_dim=1), dim=1) + best_delta = torch.gather(test_deltas.flatten(start_dim=1), dim=1, index=min_idx[:, None]) + if test_offsets.numel() == 1: + best_offset = test_offsets + else: + best_offset = torch.gather(test_offsets.flatten(start_dim=1), dim=1, index=min_idx[:, None]) + return best_delta, best_offset + + # pylint: disable=too-many-locals + @staticmethod + def _estimate_clip_and_quant_noise(stats: List[_Histogram], + test_deltas: torch.Tensor, + test_offsets: torch.Tensor, + num_steps: int, + gamma: float = 1.0): + """ + Calculates the error from quantization for each delta, offset pair in test_deltas, test_offsets. + We approximately reconstruct x from hists by assuming all elements within a given bin fall exactly on the + midpoint of that bin. + + :param stats: list of _Histogram objects of observed input values + :param test_deltas: Tensor holding the values of all deltas to search with shape (num_hists, num_deltas, num_offsets) + :param test_offsets: Tensor holding values of all offsets to search with shape (num_hists, num_deltas, num_offsets) + :param num_steps: Number of quantization steps, i.e., (2 ** bitwidth) - 1 + :param gamma: Fudge factor to trade off between saturation cost and quantization cost. When gamma=1.0, this + approximates the MSE of the quantization function + """ + tensor_kwargs = {"device": test_deltas.device, "dtype": test_deltas.dtype} + hists = torch.stack([stat.histogram for stat in stats]) + bin_edges = torch.stack([stat.bin_edges for stat in stats]) + hist_delta = bin_edges[:, 1] - bin_edges[:, 0] + # hist_midpoints is shape (hists, num_bins) + hist_offsets = hist_delta[:, None] * torch.arange(0, bin_edges.shape[1] - 1, **tensor_kwargs)[None, :] + hist_midpoints = (bin_edges[:, 0] + hist_delta/2)[:, None] + hist_offsets + # hists_midpoints_qdq is shape (hists, num_deltas, num_offsets, num_bins) + test_offsets_bcast = test_offsets[:, :, :, None] + test_deltas_bcast = test_deltas[:, :, :, None] + hist_midpoints_qdq = hist_midpoints[:, None, None, :].div(test_deltas_bcast).sub(test_offsets_bcast).round() + if gamma != 1.0: + clipped = torch.logical_or(hist_midpoints_qdq < 0, + hist_midpoints_qdq > num_steps) + hist_midpoints_qdq = hist_midpoints_qdq.clamp(0, num_steps).add(test_offsets_bcast).mul(test_deltas_bcast) + square_error = (hist_midpoints[:, None, None, :] - hist_midpoints_qdq).pow(2) * hists[:, None, None, :] + if gamma != 1.0: + # Apply the gamma "fudge factor" to the clipped errors + square_error = torch.where(clipped, square_error * gamma, square_error) + return torch.sum(square_error, dim=-1) diff --git a/TrainingExtensions/torch/test/python/experimental/v2/test_encoding_analyzer.py b/TrainingExtensions/torch/test/python/experimental/v2/test_encoding_analyzer.py index 88ac2f976dc..ae0ec42fd0a 100644 --- a/TrainingExtensions/torch/test/python/experimental/v2/test_encoding_analyzer.py +++ b/TrainingExtensions/torch/test/python/experimental/v2/test_encoding_analyzer.py @@ -50,10 +50,8 @@ class TestEncodingAnalyzer(): def encoding_analyzers(self): min_max_encoding_analyzer = MinMaxEncodingAnalyzer((1,)) percentile_encoding_analyzer = PercentileEncodingAnalyzer((1,), 3) - # TODO: Uncomment after implementation is complete - # sqnr_encoding_analyzer = SqnrEncodingAnalyzer() - # encoding_analyzer_list = [min_max_encoding_analyzer, percentile_encoding_analyzer, sqnr_encoding_analyzer] - encoding_analyzer_list = [min_max_encoding_analyzer, percentile_encoding_analyzer] + sqnr_encoding_analyzer = SqnrEncodingAnalyzer((1, )) + encoding_analyzer_list = [min_max_encoding_analyzer, percentile_encoding_analyzer, sqnr_encoding_analyzer] yield encoding_analyzer_list def test_compute_encodings_with_negative_bitwidth(self, encoding_analyzers): @@ -561,7 +559,7 @@ def test_compute_encodings_50_percentile(self): assert asymmetric_min == 0 assert asymmetric_max == mid_value -@pytest.mark.skip("Not implemented") + class TestSqnrEncodingAnalyzer: def test_computed_encodings_uniform_dist(self): @@ -610,7 +608,7 @@ def test_computed_encodings_with_outliers(self): encoding_analyzer.update_stats(x) outlier = torch.Tensor([outlier_val]).view(1, 1) encoding_analyzer.update_stats(outlier) - qmin, qmax = encoding_analyzer.compute_encodings(8, is_symmetric=True) + qmin, qmax = encoding_analyzer.compute_encodings(8, is_symmetric=False) expected_min = torch.Tensor([0]) expected_max = expected_delta * 255 assert torch.allclose(qmin, expected_min)