Skip to content

Commit

Permalink
Removing percentile argument from compute encodings (#2768)
Browse files Browse the repository at this point in the history
Signed-off-by: Priyanka Dangi <quic_pdangi@quicinc.com>
  • Loading branch information
quic-pdangi authored Feb 20, 2024
1 parent 5aa5fcd commit 10f1f0d
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 93 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -341,37 +341,31 @@ class PercentileEncodingAnalyzer(EncodingAnalyzer[_Histogram]):
"""
Encoding Analyzer for Percentile calibration technique
"""
def __init__(self, shape: tuple, num_bins: int = 2048):
def __init__(self, shape: tuple, num_bins: int = 2048, percentile: float = 100):
if num_bins <= 0:
raise ValueError('Number of bins cannot be less than or equal to 0.')

if percentile < 50 or percentile > 100:
raise ValueError('Percentile value must be within 50-100 range')

observer = _HistogramObserver(shape=shape, num_bins=num_bins)
super().__init__(observer)
self.percentile = percentile

@torch.no_grad()
def update_stats(self, input_tensor: torch.Tensor) -> _Statistics:
new_stats = self.observer.collect_stats(input_tensor)
self.observer.merge_stats(new_stats, input_tensor)
return new_stats

def compute_dynamic_encodings(self, input_tensor: torch.Tensor, bitwidth: int,\
is_symmetric: bool, percentile: float)-> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
return self.compute_encodings_from_stats(
self.observer.collect_stats(input_tensor), bitwidth, is_symmetric, percentile)

def compute_encodings(self, bitwidth: int, is_symmetric: bool, percentile: float) -> torch.Tensor:
return self.compute_encodings_from_stats(self.observer.get_stats(), bitwidth, is_symmetric, percentile)

# pylint: disable=too-many-locals
@torch.no_grad()
def compute_encodings_from_stats(self, stats: List[_Histogram], bitwidth: int, is_symmetric: bool, percentile: float)\
def compute_encodings_from_stats(self, stats: List[_Histogram], bitwidth: int, is_symmetric: bool)\
-> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:

if bitwidth <= 0:
raise ValueError('Bitwidth cannot be less than or equal to 0.')

if percentile < 50 or percentile > 100:
raise ValueError('Percentile value must be within 50-100 range')

if stats[0].histogram is None:
raise StatisticsNotFoundError('No statistics present to compute encodings.')

Expand All @@ -381,10 +375,10 @@ def compute_encodings_from_stats(self, stats: List[_Histogram], bitwidth: int, i
for list_elem in stats:
cum_sum = torch.cumsum(list_elem.histogram, dim=0)
# trim percentile value from min and max
max_index = torch.searchsorted(cum_sum, torch.quantile(cum_sum, percentile/100))
min_index = torch.searchsorted(cum_sum, torch.quantile(cum_sum, 1 - percentile/100))
max_index = torch.searchsorted(cum_sum, torch.quantile(cum_sum, self.percentile/100))
min_index = torch.searchsorted(cum_sum, torch.quantile(cum_sum, 1 - self.percentile/100))

if percentile == 100:
if self.percentile == 100:
min_index = 0
max_index = -1
curr_min = list_elem.bin_edges[min_index]
Expand Down
Loading

0 comments on commit 10f1f0d

Please sign in to comment.