Skip to content

Commit

Permalink
Move sampling a 1-batch to hessian info service (sony#830)
Browse files Browse the repository at this point in the history
To handle all possible types in the representative dataset (numpy array, tensorflow/pytorch tensors), the sampling method was moved to the hessian info service (instead of framework implementation) and changed to slice operation supported on these types of objects.

Co-authored-by: reuvenp <reuvenp@altair-semi.com>
  • Loading branch information
reuvenperetz and reuvenp authored Oct 17, 2023
1 parent fb5917a commit 1054b4e
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 44 deletions.
12 changes: 0 additions & 12 deletions model_compression_toolkit/core/common/framework_implementation.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,18 +67,6 @@ def get_trace_hessian_calculator(self,
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
f'framework\'s get_trace_hessian_calculator method.') # pragma: no cover

@abstractmethod
def sample_single_representative_dataset(self, representative_dataset: Callable):
"""
Get a single sample (namely, batch size of 1) from a representative dataset.
Args:
representative_dataset: Callable which returns the representative dataset at any batch size.
Returns: List of inputs from representative_dataset where each sample has a batch size of 1.
"""
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
f'framework\'s sample_single_representative_dataset method.') # pragma: no cover

@abstractmethod
def to_numpy(self, tensor: Any) -> np.ndarray:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,30 @@ def __init__(self,
self.graph = graph

# Create a representative_data_gen with batch size of 1
self.representative_dataset = partial(fw_impl.sample_single_representative_dataset,
self.representative_dataset = partial(self._sample_single_representative_dataset,
representative_dataset=representative_dataset)

self.fw_impl = fw_impl
self.num_iterations_for_approximation = num_iterations_for_approximation

self.trace_hessian_request_to_score_list = {}

def _sample_single_representative_dataset(self, representative_dataset: Callable):
"""
Get a single sample (namely, batch size of 1) from a representative dataset.
Args:
representative_dataset: Callable which returns the representative dataset at any batch size.
Returns: List of inputs from representative_dataset where each sample has a batch size of 1.
"""
images = next(representative_dataset())
if not isinstance(images, list):
Logger.error(f'Images expected to be a list but is of type {type(images)}')

# Ensure each image is a single sample, if not, take the first sample
return [image[0:1, ...] if image.shape[0] != 1 else image for image in images]

def _clear_saved_hessian_info(self):
"""Clears the saved info approximations."""
self.trace_hessian_request_to_score_list={}
Expand Down
16 changes: 0 additions & 16 deletions model_compression_toolkit/core/keras/keras_implementation.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,19 +591,3 @@ def sensitivity_eval_inference(self,
"""

return model(inputs)

def sample_single_representative_dataset(self, representative_dataset: Callable):
"""
Get a single sample (namely, batch size of 1) from a representative dataset.
Args:
representative_dataset: Callable which returns the representative dataset at any batch size.
Returns: List of inputs from representative_dataset where each sample has a batch size of 1.
"""
images = next(representative_dataset())
if not isinstance(images, list):
Logger.error(f'Images expected to be a list but is of type {type(images)}')

# Ensure each image is a single sample, if not, take the first sample
return [tf.expand_dims(image[0], 0) if image.shape[0] != 1 else image for image in images]
15 changes: 0 additions & 15 deletions model_compression_toolkit/core/pytorch/pytorch_implementation.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,18 +540,3 @@ def get_trace_hessian_calculator(self,
fw_impl=self,
num_iterations_for_approximation=num_iterations_for_approximation)

def sample_single_representative_dataset(self, representative_dataset: Callable):
"""
Get a single sample (namely, batch size of 1) from a representative dataset.
Args:
representative_dataset: Callable which returns the representative dataset at any batch size.
Returns: List of inputs from representative_dataset where each sample has a batch size of 1.
"""
images = next(representative_dataset())
if not isinstance(images, list):
Logger.error(f'Images expected to be a list but is of type {type(images)}')

# Ensure each image is a single sample, if not, take the first sample
return [torch.unsqueeze(image[0], 0) if image.shape[0] != 1 else image for image in images]

0 comments on commit 1054b4e

Please sign in to comment.