Implement SQNR encoding analyzer
Signed-off-by: Michael Tuttle <quic_mtuttle@quicinc.com>
quic-mtuttle authored Feb 9, 2024
1 parent 75bc9df commit 9bf063d
Showing 2 changed files with 161 additions and 10 deletions.
@@ -394,11 +394,21 @@ class SqnrEncodingAnalyzer(EncodingAnalyzer[_Histogram]):
"""
Encoding Analyzer for SQNR Calibration technique
"""
def __init__(self, shape: tuple, num_bins: int = 2048):
def __init__(self,
shape: tuple,
num_bins: int = 2048, *,
asymmetric_delta_candidates=17,
symmetric_delta_candidates=101,
offset_candidates=21,
gamma=3.0):
if num_bins <= 0:
raise ValueError('Number of bins cannot be less than or equal to 0.')
observer = _HistogramObserver(shape=shape, num_bins=num_bins)
super().__init__(observer)
self.asym_delta_candidates = asymmetric_delta_candidates
self.sym_delta_candidates = symmetric_delta_candidates
self.num_offset_candidates = offset_candidates
self.gamma = gamma

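For orientation, a minimal usage sketch of the new constructor arguments. The import path is not shown in this diff and the tensor shapes are invented; update_stats and compute_encodings are used this way in the test changes below.

import torch
# from <module containing SqnrEncodingAnalyzer> import SqnrEncodingAnalyzer  # path not shown in this diff

analyzer = SqnrEncodingAnalyzer(shape=(1,), num_bins=2048, gamma=3.0)
analyzer.update_stats(torch.randn(1000))                        # accumulate a histogram of observed values
qmin, qmax = analyzer.compute_encodings(8, is_symmetric=False)  # 8-bit (min, max) encoding
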
     @torch.no_grad()
     def update_stats(self, input_tensor: torch.Tensor) -> _Statistics:
@@ -407,7 +417,150 @@ def update_stats(self, input_tensor: torch.Tensor) -> _Statistics:
         return new_stats

     @torch.no_grad()
-    def compute_encodings_from_stats(self, stats: _Histogram, bitwidth: int, is_symmetric: bool)\
+    def compute_encodings_from_stats(self, stats: List[_Histogram], bitwidth: int, is_symmetric: bool)\
             -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
-        # TODO
-        raise NotImplementedError
"""
Searches for encodings which produce the lowest expected SQNR based on the histograms in stats
:param stats: A list of _Histogram objects with length equal to the number of encodings to compute
:param bitwidth: The bitwidth of the computed encodings
:param is_symmetric: If True, computes symmetric encodings, else computes asymmetric encodings
:return: Tuple of computed encodings (min, max) as tensors with shape self.shape
"""
if stats[0].histogram is None:
raise StatisticsNotFoundError('No statistics present to compute encodings.')
if bitwidth <= 0:
raise ValueError('Bitwidth cannot be less than or equal to 0.')
dtype = stats[0].max.dtype
num_steps = 2 ** bitwidth - 1
test_deltas, test_offsets = self._pick_test_candidates(stats, num_steps, is_symmetric)
best_delta, best_offset = self._select_best_candidates(test_deltas, test_offsets, stats, num_steps)
min_enc = best_offset * best_delta
max_enc = min_enc + num_steps * best_delta
shape = self.observer.shape
return min_enc.view(shape).to(dtype), max_enc.view(shape).to(dtype)

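A scalar sanity check of the (delta, offset) parameterization used above, with invented values:

num_steps = 2 ** 8 - 1                  # 255 steps for bitwidth 8
delta, offset = 0.1, -128.0             # invented candidate pair
min_enc = offset * delta                # -12.8
max_enc = min_enc + num_steps * delta   # 12.7
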
+    def _pick_test_candidates(self, stats, num_steps, symmetric):
+        # min/max.shape = (num_histograms, )
+        min_vals = torch.stack([stat.min for stat in stats])
+        max_vals = torch.stack([stat.max for stat in stats])
+        min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
+        max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
+        max_vals = torch.max(max_vals, min_vals + torch.finfo(min_vals.dtype).tiny * num_steps)
+        if symmetric:
+            return self._pick_test_candidates_symmetric(min_vals, max_vals, num_steps)
+        return self._pick_test_candidates_asymmetric(min_vals, max_vals, num_steps)

+    def _pick_test_candidates_asymmetric(self, min_vals, max_vals, num_steps):
+        """
+        Selects the set of deltas and offsets over which to search for the optimal encodings
+        """
+        # Note: casting to float32 for two reasons:
+        #   1) float16 on CPU is not well-supported in pytorch
+        #   2) Computing int16 encodings using fp16 can result in inf (2 ** 16 - 1 == inf in fp16)
+        tensor_kwargs = {"device": min_vals.device, "dtype": torch.float32}
+        max_delta = (max_vals - min_vals).to(torch.float32) / num_steps
+        observed_offset = torch.round(min_vals / max_delta)
+        observed_min = max_delta * observed_offset
+        observed_max = observed_min + max_delta * num_steps
+        num_deltas = self.asym_delta_candidates
+        search_space = torch.arange(start=1, end=(1 + num_deltas), step=1, **tensor_kwargs)
+        # test_deltas.shape = (num_histograms, num_tests)
+        test_deltas = max_delta[:, None] * search_space[None, :] / (num_deltas - 1)
+        # test_offsets.shape = (num_offsets, )
+        num_offsets = min(num_steps + 2, self.num_offset_candidates)
+        test_offset_step = num_steps / (num_offsets - 2)  # subtract 2 because we add the observed offset
+        test_offsets = torch.round(torch.arange(start=-num_steps, end=test_offset_step, step=test_offset_step, **tensor_kwargs))
+        test_offsets = test_offsets[None, :].expand(min_vals.shape[0], -1)
+        # Add in the observed offset as a candidate, test_offsets.shape = (num_histograms, num_offsets + 1)
+        test_offsets = torch.concat((test_offsets, observed_offset[:, None]), dim=1)
+        return self._clamp_delta_offset_values(observed_min, observed_max, num_steps, test_deltas, test_offsets)

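A standalone sketch of how the asymmetric candidate grid above is built (invented sizes; the real method additionally appends the observed offset as an extra candidate and clamps everything via _clamp_delta_offset_values):

import torch

num_steps = 255                                  # 8-bit
max_delta = torch.tensor([0.2])                  # one histogram: observed (max - min) / num_steps
num_deltas, num_offsets = 17, 21
# Deltas: fractions 1/16 ... 17/16 of the observed max delta
test_deltas = max_delta[:, None] * torch.arange(1, num_deltas + 1) / (num_deltas - 1)
step = num_steps / (num_offsets - 2)             # leave room for the observed offset
test_offsets = torch.round(torch.arange(-num_steps, step, step))
# test_offsets: ~20 evenly spaced offsets spanning [-255, 0]
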
+    def _pick_test_candidates_symmetric(self, min_vals, max_vals, num_steps):
+        """
+        Selects the set of deltas over which to search for the optimal symmetric encodings
+        """
+        tensor_kwargs = {"device": min_vals.device, "dtype": torch.float32}
+        max_delta = 2 * torch.max(max_vals, -min_vals).to(torch.float32) / num_steps
+        test_offsets = torch.full((1, ), (-num_steps) // 2, **tensor_kwargs)
+        num_deltas = self.sym_delta_candidates
+        search_space = torch.arange(start=1, end=(1 + num_deltas), step=1, **tensor_kwargs)
+        test_deltas = max_delta[:, None] * search_space[None, :] / (num_deltas - 1)
+        # test_deltas.shape = (num_histograms, num_deltas, 1)
+        # test_offsets.shape = (1, 1, 1)
+        min_delta = torch.Tensor([torch.finfo(test_deltas.dtype).tiny]).to(**tensor_kwargs)
+        test_deltas = torch.max(test_deltas, min_delta)
+        return test_deltas[:, :, None], test_offsets[:, None, None]

+    @staticmethod
+    def _clamp_delta_offset_values(min_vals, max_vals, num_steps, test_deltas, test_offsets):
+        """
+        Clamps delta/offset encodings such that the represented range falls within the observed min/max range of inputs
+        """
+        # test_min shape = (num_histograms, num_deltas, num_offsets)
+        test_min = test_deltas[:, :, None] * test_offsets[:, None, :]
+        test_max = test_min + test_deltas[:, :, None] * num_steps
+        # Clamp min/max to observed min/max
+        test_min = torch.max(min_vals[:, None, None], test_min)
+        test_max = torch.min(max_vals[:, None, None], test_max)
+        # Recompute delta/offset with clamped min/max
+        # Returned delta/offset shapes = (num_histograms, num_deltas, num_offsets)
+        test_deltas = (test_max - test_min) / num_steps
+        min_delta = torch.Tensor([torch.finfo(test_deltas.dtype).tiny]).to(device=test_deltas.device,
+                                                                           dtype=test_deltas.dtype)
+        test_deltas = torch.max(test_deltas, min_delta)
+        test_offsets = torch.round(test_min / test_deltas)
+        return test_deltas, test_offsets

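A scalar walk-through of the clamping, with invented numbers:

num_steps = 255
delta, offset = 0.2, -200.0
test_min = delta * offset                   # -40.0
test_max = test_min + delta * num_steps     # 11.0
observed_min, observed_max = -10.0, 11.0
test_min = max(observed_min, test_min)      # clamped up to -10.0
delta = (test_max - test_min) / num_steps   # ~0.0824, recomputed from the clamped range
offset = round(test_min / delta)            # ~-121
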
+    def _select_best_candidates(self, test_deltas, test_offsets, stats, num_steps):
+        """
+        Searches all pairs of (delta, offset) in test_deltas, test_offsets to find the pair with the lowest
+        expected quantization noise (i.e., the highest SQNR)
+        """
+        noise = self._estimate_clip_and_quant_noise(stats, test_deltas, test_offsets, num_steps, self.gamma)
+        _, min_idx = torch.min(noise.flatten(start_dim=1), dim=1)
+        best_delta = torch.gather(test_deltas.flatten(start_dim=1), dim=1, index=min_idx[:, None])
+        if test_offsets.numel() == 1:
+            best_offset = test_offsets
+        else:
+            best_offset = torch.gather(test_offsets.flatten(start_dim=1), dim=1, index=min_idx[:, None])
+        return best_delta, best_offset

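Shape-wise, the selection flattens each histogram's (num_deltas, num_offsets) grid and gathers at the per-histogram argmin; a minimal sketch with invented sizes:

import torch

noise = torch.rand(4, 17, 22)        # (num_histograms, num_deltas, num_offsets)
deltas = torch.rand(4, 17, 22)
_, min_idx = torch.min(noise.flatten(start_dim=1), dim=1)
best_delta = torch.gather(deltas.flatten(start_dim=1), dim=1, index=min_idx[:, None])
# best_delta.shape == (4, 1): the lowest-noise delta for each histogram
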
+    # pylint: disable=too-many-locals
+    @staticmethod
+    def _estimate_clip_and_quant_noise(stats: List[_Histogram],
+                                       test_deltas: torch.Tensor,
+                                       test_offsets: torch.Tensor,
+                                       num_steps: int,
+                                       gamma: float = 1.0):
+        """
+        Calculates the error from quantization for each delta, offset pair in test_deltas, test_offsets.
+        We approximately reconstruct x from hists by assuming all elements within a given bin fall exactly on the
+        midpoint of that bin.
+        :param stats: list of _Histogram objects of observed input values
+        :param test_deltas: Tensor holding the values of all deltas to search with shape (num_hists, num_deltas, num_offsets)
+        :param test_offsets: Tensor holding values of all offsets to search with shape (num_hists, num_deltas, num_offsets)
+        :param num_steps: Number of quantization steps, i.e., (2 ** bitwidth) - 1
+        :param gamma: Fudge factor to trade off between saturation cost and quantization cost. When gamma=1.0, this
+                      approximates the MSE of the quantization function
+        """
+        tensor_kwargs = {"device": test_deltas.device, "dtype": test_deltas.dtype}
+        hists = torch.stack([stat.histogram for stat in stats])
+        bin_edges = torch.stack([stat.bin_edges for stat in stats])
+        hist_delta = bin_edges[:, 1] - bin_edges[:, 0]
+        # hist_midpoints is shape (hists, num_bins)
+        hist_offsets = hist_delta[:, None] * torch.arange(0, bin_edges.shape[1] - 1, **tensor_kwargs)[None, :]
+        hist_midpoints = (bin_edges[:, 0] + hist_delta/2)[:, None] + hist_offsets
+        # hist_midpoints_qdq is shape (hists, num_deltas, num_offsets, num_bins)
+        test_offsets_bcast = test_offsets[:, :, :, None]
+        test_deltas_bcast = test_deltas[:, :, :, None]
+        hist_midpoints_qdq = hist_midpoints[:, None, None, :].div(test_deltas_bcast).sub(test_offsets_bcast).round()
+        if gamma != 1.0:
+            clipped = torch.logical_or(hist_midpoints_qdq < 0,
+                                       hist_midpoints_qdq > num_steps)
+        hist_midpoints_qdq = hist_midpoints_qdq.clamp(0, num_steps).add(test_offsets_bcast).mul(test_deltas_bcast)
+        square_error = (hist_midpoints[:, None, None, :] - hist_midpoints_qdq).pow(2) * hists[:, None, None, :]
+        if gamma != 1.0:
+            # Apply the gamma "fudge factor" to the clipped errors
+            square_error = torch.where(clipped, square_error * gamma, square_error)
+        return torch.sum(square_error, dim=-1)
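In scalar form, the estimate above is: each bin contributes count * (midpoint - qdq(midpoint))^2, with bins that clip outside [0, num_steps] weighted by gamma. A minimal unvectorized sketch of the same computation (an illustration, not the implementation above):

def estimate_noise_scalar(counts, midpoints, delta, offset, num_steps, gamma=1.0):
    # Expected squared error for one histogram and one (delta, offset) candidate
    total = 0.0
    for count, mid in zip(counts, midpoints):
        q = round(mid / delta - offset)                       # quantize
        clipped = q < 0 or q > num_steps
        qdq = (min(max(q, 0), num_steps) + offset) * delta    # clamp + dequantize
        err = count * (mid - qdq) ** 2
        total += gamma * err if clipped else err
    return total
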
(Second changed file: encoding analyzer unit tests)
@@ -50,10 +50,8 @@ class TestEncodingAnalyzer():
     def encoding_analyzers(self):
         min_max_encoding_analyzer = MinMaxEncodingAnalyzer((1,))
         percentile_encoding_analyzer = PercentileEncodingAnalyzer((1,), 3)
-        # TODO: Uncomment after implementation is complete
-        # sqnr_encoding_analyzer = SqnrEncodingAnalyzer()
-        # encoding_analyzer_list = [min_max_encoding_analyzer, percentile_encoding_analyzer, sqnr_encoding_analyzer]
-        encoding_analyzer_list = [min_max_encoding_analyzer, percentile_encoding_analyzer]
+        sqnr_encoding_analyzer = SqnrEncodingAnalyzer((1, ))
+        encoding_analyzer_list = [min_max_encoding_analyzer, percentile_encoding_analyzer, sqnr_encoding_analyzer]
         yield encoding_analyzer_list

     def test_compute_encodings_with_negative_bitwidth(self, encoding_analyzers):
@@ -561,7 +559,7 @@ def test_compute_encodings_50_percentile(self):
         assert asymmetric_min == 0
         assert asymmetric_max == mid_value

-@pytest.mark.skip("Not implemented")
+
 class TestSqnrEncodingAnalyzer:

     def test_computed_encodings_uniform_dist(self):
@@ -610,7 +608,7 @@ def test_computed_encodings_with_outliers(self):
         encoding_analyzer.update_stats(x)
         outlier = torch.Tensor([outlier_val]).view(1, 1)
         encoding_analyzer.update_stats(outlier)
-        qmin, qmax = encoding_analyzer.compute_encodings(8, is_symmetric=True)
+        qmin, qmax = encoding_analyzer.compute_encodings(8, is_symmetric=False)
         expected_min = torch.Tensor([0])
         expected_max = expected_delta * 255
         assert torch.allclose(qmin, expected_min)
