diff --git a/model_compression_toolkit/core/common/similarity_analyzer.py b/model_compression_toolkit/core/common/similarity_analyzer.py
index 83791186a..84fcf493e 100644
--- a/model_compression_toolkit/core/common/similarity_analyzer.py
+++ b/model_compression_toolkit/core/common/similarity_analyzer.py
@@ -241,4 +241,7 @@ def compute_kl_divergence(float_tensor: np.ndarray, fxp_tensor: np.ndarray, batc
     non_zero_fxp_tensor = fxp_flat.copy()
     non_zero_fxp_tensor[non_zero_fxp_tensor == 0] = EPS
 
-    return np.mean(np.sum(np.where(float_flat != 0, float_flat * np.log(float_flat / non_zero_fxp_tensor), 0), axis=-1))
+    prob_distance = np.where(float_flat != 0, float_flat * np.log(float_flat / non_zero_fxp_tensor), 0)
+    # The sum is part of the KL-Divergence formula.
+    # The mean aggregates the distances between the output probability vectors.
+    return np.mean(np.sum(prob_distance, axis=-1), axis=-1)
diff --git a/tests/keras_tests/function_tests/test_sensitivity_metric_interest_points.py b/tests/keras_tests/function_tests/test_sensitivity_metric_interest_points.py
index b742d4bcc..9d89dd473 100644
--- a/tests/keras_tests/function_tests/test_sensitivity_metric_interest_points.py
+++ b/tests/keras_tests/function_tests/test_sensitivity_metric_interest_points.py
@@ -85,6 +85,7 @@ def softmax_model(input_shape):
     return model
 
 
+
 class TestSensitivityMetricInterestPoints(unittest.TestCase):
 
     def test_filtered_interest_points_set(self):
@@ -148,7 +149,7 @@ def test_softmax_interest_point(self):
            distance_per_softmax_axis = distance_fn(t1, t2, batch=True, axis=axis)
            distance_global = distance_fn(t1, t2, batch=True, axis=None)
 
-            self.assertFalse(np.isclose(distance_per_softmax_axis, distance_global),
-                             f"Computing distance for softmax node on softmax activation axis should be different than "
-                             f"on than computing on the entire tensor.")
+            self.assertFalse(np.isclose(np.mean(distance_per_softmax_axis), distance_global),
+                             f"Computing distance for softmax node on softmax activation axis should be different "
+                             f"than computing on the entire tensor.")
 
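
For context, a minimal standalone sketch of what the reworked return statement computes: the per-row KL divergence KL(P||Q) = sum(p * log(p / q)) over the probability axis, averaged over the batch axis. The EPS value, the helper name `batched_kl`, and the toy inputs below are assumptions for illustration, not code from the repository.

```python
# Illustrative sketch only (not part of the patch): what the reworked return
# statement in compute_kl_divergence evaluates for batched probability vectors.
import numpy as np

EPS = 1e-8  # stand-in for the EPS constant used in similarity_analyzer.py

def batched_kl(float_flat: np.ndarray, fxp_flat: np.ndarray) -> float:
    # float_flat, fxp_flat: shape (batch, n), each row a probability vector.
    non_zero_fxp = fxp_flat.copy()
    non_zero_fxp[non_zero_fxp == 0] = EPS  # avoid division by zero inside the log
    with np.errstate(divide='ignore', invalid='ignore'):
        # Pointwise p * log(p / q), with the p == 0 terms defined as 0.
        prob_distance = np.where(float_flat != 0,
                                 float_flat * np.log(float_flat / non_zero_fxp),
                                 0)
    # Sum over the probability axis -> one KL(P||Q) value per row;
    # mean over the batch axis -> a single scalar distance.
    return np.mean(np.sum(prob_distance, axis=-1), axis=-1)

p = np.array([[0.7, 0.2, 0.1], [0.5, 0.5, 0.0]])
q = np.array([[0.6, 0.3, 0.1], [0.4, 0.4, 0.2]])
print(batched_kl(p, q))  # mean of the two per-row KL divergences
```

This also shows why the test above now wraps the per-axis distance in `np.mean(...)`: with `batch=True` the distance function returns one value per output vector, which must be reduced to a scalar before comparing against the global distance with `np.isclose`.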