Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding initial Histogram Observer implementation #2700

Merged
merged 6 commits into from
Feb 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,19 @@

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TypeVar, Generic, Tuple, Optional
from typing import TypeVar, Generic, Tuple, Optional, List
import itertools
import numpy as np
import torch
from aimet_torch.experimental.v2.utils import reduce, StatisticsNotFoundError
from aimet_torch.experimental.v2.utils import reduce, StatisticsNotFoundError, _is_expandable


@dataclass
class _MinMaxRange:
    """Container for a running (min, max) statistic pair used by min-max calibration."""
    # Minimum tracked so far; None until statistics have been collected
    min: Optional[torch.Tensor] = None
    # Maximum tracked so far; None until statistics have been collected
    max: Optional[torch.Tensor] = None

@dataclass
class _Histogram:
histogram: torch.Tensor = None
bin_edges: torch.Tensor = None
Expand Down Expand Up @@ -123,29 +126,108 @@ def reset_stats(self):
def get_stats(self) -> _MinMaxRange:
    """Return the accumulated min/max statistics tracked by this observer."""
    return self.stats


class _HistogramObserver(_Observer[_Histogram]):
    """
    Observer for Histogram based calibration techniques (percentile, MSE).

    Tracks one histogram (with ``num_bins`` bins) per element of ``shape``,
    stored as a flat list of ``_Histogram`` objects in ``self.stats``.
    """
    def __init__(self, shape: tuple, num_bins: int):
        """
        :param shape: Shape of the encoding parameters; one histogram is kept per element.
        :param num_bins: Number of bins in each histogram.
        """
        super().__init__(shape)
        self.num_bins = num_bins
        self.num_histograms = np.prod(self.shape)
        self.stats = []
        for _ in range(self.num_histograms):
            self.stats.append(_Histogram())

    @torch.no_grad()
    def collect_stats(self, input_tensor: torch.Tensor) -> List[_Histogram]:
        """
        Compute one histogram per tracked channel from ``input_tensor``.

        :param input_tensor: Input whose shape must be expandable from ``self.shape``.
        :return: List of ``_Histogram`` objects, one per histogram tracked by this observer.
        :raises RuntimeError: If ``input_tensor``'s shape is incompatible with ``self.shape``.
        """
        if not _is_expandable(self.shape, input_tensor.shape):
            raise RuntimeError(f"Shape {self.shape} is incompatible with input of shape {input_tensor.shape}")

        hist_stats = []
        input_shape = tuple(input_tensor.shape)
        histogram_shape = self.shape

        # Left-pad the histogram shape with 1s so it aligns with the input's rank
        padded_histogram_shape = (
            *itertools.repeat(1, len(input_shape) - len(histogram_shape)),
            *histogram_shape
        )

        for hist_num in range(self.num_histograms):
            hist_input = input_tensor

            # Slice out the portion of the input that belongs to histogram `hist_num`
            for axis, dim in enumerate(padded_histogram_shape):
                if dim == 1:
                    continue
                # elements in current axis, ex: could be W*C, C, or 1 for input_shape [H, W, C]
                numel = np.prod(padded_histogram_shape[axis+1:], dtype=int)
                # index where hist_input at current dimension will be sliced at
                index = (hist_num // numel) % dim
                hist_input = hist_input.select(axis, index).unsqueeze(axis)

            histogram, bin_edges = torch.histogram(hist_input, self.num_bins)
            hist_stats.append(_Histogram(histogram, bin_edges, hist_input.min(), hist_input.max()))

        return hist_stats

    def _get_bin_num(self, bin_width: float, curr_min, data):
        """Return the (clamped) index of the bin that ``data`` falls into, or 0 when ``bin_width`` is 0."""
        if bin_width:
            return min(int((data - curr_min) / bin_width), self.num_bins - 1)
        # bin_width == 0 means a degenerate (single-point) range; everything maps to bin 0
        return bin_width

    # pylint: disable=arguments-differ
    # pylint: disable=too-many-locals
    @torch.no_grad()
    def merge_stats(self, new_stats_list: List[_Histogram], input_tensor: torch.Tensor):
        """
        Merge freshly collected histograms into the running per-channel histograms,
        rescaling the existing histogram when the observed range grows.

        :param new_stats_list: Histograms returned by :meth:`collect_stats` for ``input_tensor``.
        :param input_tensor: The tensor the new statistics were collected from.
        """
        if self.stats[0].histogram is None:
            # First observation: adopt the new statistics wholesale
            self.stats = new_stats_list
            return

        # One flattened row of input values per histogram
        hist_inputs = torch.reshape(input_tensor, (len(new_stats_list), -1))

        for index, new_stats in enumerate(new_stats_list):
            curr_stats = self.stats[index]
            curr_input = hist_inputs[index]

            updated_min = min(new_stats.min, curr_stats.min)
            updated_max = max(new_stats.max, curr_stats.max)

            # if the current histogram can capture new_stats within its range
            if updated_min == curr_stats.min and updated_max == curr_stats.max:
                histogram_updates = curr_stats.histogram
            else:
                # Range grew: redistribute the old histogram's counts into the new bins
                dest_bin_width = (updated_max - updated_min) / self.num_bins
                src_bin_width = (curr_stats.max - curr_stats.min) / self.num_bins
                histogram_updates = np.zeros(self.num_bins)

                for curr_bin in range(self.num_bins):
                    curr_hist = curr_stats.histogram[curr_bin]
                    if curr_hist:
                        src_bin_start = curr_stats.min + src_bin_width * curr_bin
                        bin_index = self._get_bin_num(dest_bin_width, updated_min, src_bin_start)
                        dest_bin_end = updated_min + dest_bin_width * (bin_index + 1)

                        # split curr_hist if values in source bin cannot neatly fold into dest bin
                        split_hist_value = torch.round(((dest_bin_end - src_bin_start) / src_bin_width) * curr_hist)
                        dest_bin_updated = min(split_hist_value, curr_hist)
                        # update appropriate bin with either the full or split curr_hist value
                        histogram_updates[bin_index] += dest_bin_updated
                        # if curr_hist is split, update other bin that the remaining values fall into
                        if dest_bin_updated < curr_hist:
                            bin_index = self._get_bin_num(dest_bin_width, updated_min, src_bin_start + dest_bin_width)
                            histogram_updates[bin_index] += curr_hist - dest_bin_updated
            # create histogram given input tensor and full range
            expanded_histogram, expanded_bin_edges = torch.histogram(curr_input, self.num_bins, range=(updated_min.item(), updated_max.item()))
            expanded_histogram += histogram_updates
            self.stats[index] = _Histogram(expanded_histogram, expanded_bin_edges, updated_min, updated_max)

    def reset_stats(self):
        """Discard all accumulated histograms, returning the observer to its initial state."""
        self.stats = []
        for _ in range(self.num_histograms):
            self.stats.append(_Histogram())

    def get_stats(self) -> List[_Histogram]:
        """Return the list of accumulated per-channel histograms."""
        return self.stats

class EncodingAnalyzer(Generic[_Statistics], ABC):
Expand Down Expand Up @@ -178,7 +260,7 @@ class MinMaxEncodingAnalyzer(EncodingAnalyzer[_MinMaxRange]):
"""
Encoding Analyzer for Min-Max calibration technique
"""
def __init__(self, shape: tuple):
    """
    :param shape: Shape of the encoding parameters tracked by the min/max observer.
    """
    observer = _MinMaxObserver(shape)
    super().__init__(observer)

Expand Down Expand Up @@ -226,11 +308,20 @@ class PercentileEncodingAnalyzer(EncodingAnalyzer[_Histogram]):
Encoding Analyzer for Percentile calibration technique
"""
def __init__(self, shape: tuple, num_bins: int = 2048):
    """
    :param shape: Shape of the encoding parameters.
    :param num_bins: Number of histogram bins; must be a positive integer.
    :raises ValueError: If ``num_bins`` is not positive.
    """
    if num_bins <= 0:
        raise ValueError('Number of bins cannot be less than or equal to 0.')
    super().__init__(_HistogramObserver(shape=shape, num_bins=num_bins))

def update_stats(self, input_tensor: torch.Tensor) -> _Statistics:
    """
    Collect histogram statistics for ``input_tensor`` and merge them into the
    observer's running statistics.

    :param input_tensor: Tensor to gather statistics from.
    :return: The statistics collected for this input (before merging).
    """
    new_stats = self.observer.collect_stats(input_tensor)
    self.observer.merge_stats(new_stats, input_tensor)
    return new_stats

# pylint: disable=arguments-differ
@torch.no_grad()
def compute_encodings_from_stats(self, stats: _Histogram, bitwidth: int, is_symmetric: bool, percentile: float)\
        -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
    """Compute (min, max) encodings from histogram statistics at the given percentile. Not yet implemented."""
    # TODO
    raise NotImplementedError
Expand All @@ -240,22 +331,16 @@ class SqnrEncodingAnalyzer(EncodingAnalyzer[_Histogram]):
Encoding Analyzer for SQNR Calibration technique
"""
def __init__(self, shape: tuple, num_bins: int = 2048):
    """
    :param shape: Shape of the encoding parameters.
    :param num_bins: Number of histogram bins; must be a positive integer.
    :raises ValueError: If ``num_bins`` is not positive.
    """
    if num_bins <= 0:
        raise ValueError('Number of bins cannot be less than or equal to 0.')
    histogram_observer = _HistogramObserver(shape=shape, num_bins=num_bins)
    super().__init__(histogram_observer)

@torch.no_grad()
def compute_encodings_from_stats(self, stats: _Histogram, bitwidth: int, is_symmetric: bool)\
    -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
    """Compute (min, max) encodings from histogram statistics using the SQNR criterion. Not yet implemented."""
    # TODO
    raise NotImplementedError

class MseEncodingAnalyzer(EncodingAnalyzer[_Histogram]):
"""
Encoding Analyzer for Mean Square Error (MSE) Calibration technique
"""
def __init__(self, shape: tuple, num_bins: int = 2048):
    """
    :param shape: Shape of the encoding parameters.
    :param num_bins: Number of histogram bins; must be a positive integer.
    :raises ValueError: If ``num_bins`` is not positive.
    """
    # Validate num_bins, consistent with the other histogram-based analyzers
    if num_bins <= 0:
        raise ValueError('Number of bins cannot be less than or equal to 0.')
    observer = _HistogramObserver(shape=shape, num_bins=num_bins)
    super().__init__(observer)
def update_stats(self, input_tensor: torch.Tensor) -> _Statistics:
    """
    Collect histogram statistics for ``input_tensor`` and fold them into the
    observer's running statistics.

    :param input_tensor: Tensor to gather statistics from.
    :return: The statistics collected for this input (before merging).
    """
    collected = self.observer.collect_stats(input_tensor)
    self.observer.merge_stats(collected, input_tensor)
    return collected

@torch.no_grad()
def compute_encodings_from_stats(self, stats: _Histogram, bitwidth: int, is_symmetric: bool)\
Expand Down
Loading
Loading