Skip to content

Commit

Permalink
Address code violations
Browse files Browse the repository at this point in the history
Signed-off-by: Priyanka Dangi <quic_pdangi@quicinc.com>
  • Loading branch information
quic-pdangi authored and quic-akhobare committed Feb 2, 2024
1 parent f011a84 commit 5a4df08
Showing 1 changed file with 3 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,7 @@ def __init__(self, shape: tuple, num_bins: int):
@torch.no_grad()
def collect_stats(self, input_tensor: torch.Tensor) -> List[_Histogram]:
if not _is_expandable(self.shape, input_tensor.shape):
raise RuntimeError(f"Shape {self.shape} is incompatible with "
f"input of shape {input_tensor.shape}")
raise RuntimeError(f"Shape {self.shape} is incompatible with input of shape {input_tensor.shape}")

hist_stats = []
input_shape = tuple(input_tensor.shape)
Expand Down Expand Up @@ -178,12 +177,13 @@ def _get_bin_num(self, bin_width: int, curr_min, data):
return bin_width

# pylint: disable=arguments-differ
# pylint: disable=too-many-locals
@torch.no_grad()
def merge_stats(self, new_stats_list: List[_Histogram], input_tensor: torch.Tensor):
if self.stats[0].histogram is None:
self.stats = new_stats_list
return

hist_inputs = torch.reshape(input_tensor, (len(new_stats_list), -1))

for index, new_stats in enumerate(new_stats_list):
Expand Down Expand Up @@ -347,4 +347,3 @@ def compute_encodings_from_stats(self, stats: _Histogram, bitwidth: int, is_symm
-> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
# TODO
raise NotImplementedError

0 comments on commit 5a4df08

Please sign in to comment.