Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use normalize MSE in mixed precision sensitivity evaluation #1082

Merged
merged 2 commits into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -348,26 +348,28 @@ def count_node_for_mixed_precision_interest_points(self, node: BaseNode) -> bool
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
f'framework\'s count_node_for_mixed_precision_interest_points method.') # pragma: no cover

def get_node_distance_fn(self, layer_class: type,
def get_mp_node_distance_fn(self, layer_class: type,
framework_attrs: Dict[str, Any],
compute_distance_fn: Callable = None,
axis: int = None) -> Callable:
axis: int = None,
norm_mse: bool = False) -> Callable:
"""
A mapping between layers' types and a distance function for computing the distance between
two tensors (for loss computation purposes). Returns a specific function if node of specific types is
two tensors in mixed precision (for loss computation purposes). Returns a specific function if node of specific types is
given, or a default (normalized MSE) function otherwise.

Args:
layer_class: Class path of a model's layer.
framework_attrs: Framework attributes the layer had which the graph node holds.
compute_distance_fn: An optional distance function to use globally for all nodes.
axis: The axis on which the operation is performed (if specified).
norm_mse: whether to normalize mse distance function.

Returns: A distance function between two tensors.
"""

raise NotImplemented(f'{self.__class__.__name__} have to implement the '
f'framework\'s get_node_distance_fn method.') # pragma: no cover
f'framework\'s get_mp_node_distance_fn method.') # pragma: no cover


@abstractmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,13 @@ def __init__(self,
fw_impl.count_node_for_mixed_precision_interest_points,
quant_config.num_interest_points_factor)

self.ips_distance_fns, self.ips_axis = self._init_metric_points_lists(self.interest_points)
# We use normalized MSE when not running hessian-based. For Hessian-based, normalized MSE is not needed
# because hessian weights already do normalization.
use_normalized_mse = self.quant_config.use_hessian_based_scores is False
self.ips_distance_fns, self.ips_axis = self._init_metric_points_lists(self.interest_points, use_normalized_mse)

self.output_points = get_output_nodes_for_metric(graph)
self.out_ps_distance_fns, self.out_ps_axis = self._init_metric_points_lists(self.output_points)
self.out_ps_distance_fns, self.out_ps_axis = self._init_metric_points_lists(self.output_points, use_normalized_mse)

# Setting lists with relative position of the interest points
# and output points in the list of all mp model activation tensors
Expand Down Expand Up @@ -128,14 +131,15 @@ def __init__(self,
self.interest_points_hessians = self._compute_hessian_based_scores()
self.quant_config.distance_weighting_method = lambda d: self.interest_points_hessians

def _init_metric_points_lists(self, points: List[BaseNode]) -> Tuple[List[Callable], List[int]]:
def _init_metric_points_lists(self, points: List[BaseNode], norm_mse: bool = False) -> Tuple[List[Callable], List[int]]:
"""
Initiates required lists for future use when computing the sensitivity metric.
Each point on which the metric is computed uses a dedicated distance function based on its type.
In addition, all distance functions perform batch computation. Axis is needed only for KL Divergence computation.

Args:
points: The set of nodes in the graph for which we need to initiate the lists.
norm_mse: whether to normalize mse distance function.

Returns: Lists of distance functions and an axis list, one entry per node.

Expand All @@ -144,11 +148,12 @@ def _init_metric_points_lists(self, points: List[BaseNode]) -> Tuple[List[Callab
axis_list = []
for n in points:
axis = n.framework_attr.get(AXIS) if not isinstance(n, FunctionalNode) else n.op_call_kwargs.get(AXIS)
distance_fn = self.fw_impl.get_node_distance_fn(
distance_fn = self.fw_impl.get_mp_node_distance_fn(
layer_class=n.layer_class,
framework_attrs=n.framework_attr,
compute_distance_fn=self.quant_config.compute_distance_fn,
axis=axis)
axis=axis,
norm_mse=norm_mse)
distance_fns_list.append(distance_fn)
# Axis is needed only for KL Divergence calculation, otherwise we use per-tensor computation
axis_list.append(axis if distance_fn==compute_kl_divergence else None)
Expand Down
10 changes: 6 additions & 4 deletions model_compression_toolkit/core/keras/keras_implementation.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,20 +421,22 @@ def count_node_for_mixed_precision_interest_points(self, node: BaseNode) -> bool

return False

def get_node_distance_fn(self, layer_class: type,
def get_mp_node_distance_fn(self, layer_class: type,
framework_attrs: Dict[str, Any],
compute_distance_fn: Callable = None,
axis: int = None) -> Callable:
axis: int = None,
norm_mse: bool = False) -> Callable:
"""
A mapping between layers' types and a distance function for computing the distance between
two tensors (for loss computation purposes). Returns a specific function if node of specific types is
two tensors in mixed precision (for loss computation purposes). Returns a specific function if node of specific types is
given, or a default (normalized MSE) function otherwise.

Args:
layer_class: Class path of a model's layer.
framework_attrs: Framework attributes the layer had which the graph node holds.
compute_distance_fn: An optional distance function to use globally for all nodes.
axis: The axis on which the operation is performed (if specified).
norm_mse: whether to normalize mse distance function.

Returns: A distance function between two tensors.
"""
Expand All @@ -456,7 +458,7 @@ def get_node_distance_fn(self, layer_class: type,
return compute_cs
elif layer_class == Dense:
return compute_cs
return compute_mse
return partial(compute_mse, norm=norm_mse)

def get_trace_hessian_calculator(self,
graph: Graph,
Expand Down
10 changes: 6 additions & 4 deletions model_compression_toolkit/core/pytorch/pytorch_implementation.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,20 +403,22 @@ def count_node_for_mixed_precision_interest_points(self, node: BaseNode) -> bool
return True
return False

def get_node_distance_fn(self, layer_class: type,
def get_mp_node_distance_fn(self, layer_class: type,
framework_attrs: Dict[str, Any],
compute_distance_fn: Callable = None,
axis: int = None) -> Callable:
axis: int = None,
norm_mse: bool = False) -> Callable:
"""
A mapping between layers' types and a distance function for computing the distance between
two tensors (for loss computation purposes). Returns a specific function if node of specific types is
two tensors in mixed precision (for loss computation purposes). Returns a specific function if node of specific types is
given, or a default (normalized MSE) function otherwise.

Args:
layer_class: Class path of a model's layer.
framework_attrs: Framework attributes the layer had which the graph node holds.
compute_distance_fn: An optional distance function to use globally for all nodes.
axis: The axis on which the operation is performed (if specified).
norm_mse: whether to normalize mse distance function.

Returns: A distance function between two tensors.
"""
Expand All @@ -430,7 +432,7 @@ def get_node_distance_fn(self, layer_class: type,
return compute_cs
elif layer_class == Linear:
return compute_cs
return compute_mse
return partial(compute_mse, norm=norm_mse)

def is_output_node_compatible_for_hessian_score_computation(self,
node: BaseNode) -> bool:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,9 @@ def test_softmax_interest_point(self):
if axis is None:
axis = sn.op_call_kwargs.get(AXIS)

distance_fn = KerasImplementation().get_node_distance_fn(layer_class=sn.layer_class,
framework_attrs=sn.framework_attr,
axis=axis)
distance_fn = KerasImplementation().get_mp_node_distance_fn(layer_class=sn.layer_class,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this change makes the lines not aligned

framework_attrs=sn.framework_attr,
axis=axis)
self.assertEqual(distance_fn, compute_kl_divergence,
f"Softmax node should use KL Divergence for distance computation.")

Expand Down
Loading