Skip to content

Commit

Permalink
Fix mixed precision LP numerical issue (#843)
Browse files Browse the repository at this point in the history
Adding a threshold for mixed precision metric and scaling the metric in case some values are larger than the threshold, to prevent numerical issues.
This is supposed to address a possible issue with the pulp library's default LP solver.

---------

Co-authored-by: Ofir Gordon <Ofir.Gordon@altair-semi.com>
  • Loading branch information
ofirgo and Ofir Gordon authored Oct 29, 2023
1 parent 7030635 commit 424bb91
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ def __init__(self,
use_grad_based_weights: bool = True,
output_grad_factor: float = 0.1,
norm_weights: bool = True,
refine_mp_solution: bool = True):
refine_mp_solution: bool = True,
metric_normalization_threshold: float = 1e10):
"""
Class with mixed precision parameters to quantize the input model.
Unlike QuantizationConfig, number of bits for quantization is a list of possible bit widths to
Expand All @@ -49,6 +50,7 @@ def __init__(self,
output_grad_factor (float): A tuning parameter to be used for gradient-based weights.
norm_weights (bool): Whether to normalize the returned weights (to get values between 0 and 1).
refine_mp_solution (bool): Whether to try to improve the final mixed-precision configuration using a greedy algorithm that searches layers to increase their bit-width, or not.
metric_normalization_threshold (float): A threshold for checking the mixed precision distance metric values. In case of values larger than this threshold, the metric will be scaled to prevent numerical issues.
"""

Expand All @@ -72,6 +74,8 @@ def __init__(self,
Logger.info(f"Using gradient-based weights for mixed-precision distance metric with tuning factor "
f"{output_grad_factor}")

self.metric_normalization_threshold = metric_normalization_threshold


class MixedPrecisionQuantizationConfig(QuantizationConfig):

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,30 @@ def compute_kpi_for_config(self, config: List[int]) -> KPI:
config_kpi.set_kpi_by_target(kpis_dict)
return config_kpi

def finalize_distance_metric(self, layer_to_metrics_mapping: Dict[int, Dict[int, float]]):
    """
    Finalizes the distance metric building.
    Checks whether the maximal distance value is larger than the configured normalization
    threshold, and if so, scales all metric values (in-place) by the maximal value, to
    prevent possible numerical issues in the downstream LP solver.

    Args:
        layer_to_metrics_mapping: A mapping between a node index to a mapping between
            a bitwidth index to a distance value. Modified in-place; nothing is returned.
    """
    # Guard: max() over an empty mapping would raise ValueError.
    if not layer_to_metrics_mapping:
        return

    # Largest distance over all layers and all bitwidth candidates.
    max_dist = max(max(dists.values()) for dists in layer_to_metrics_mapping.values())

    # Normalize metric for numerical stability.
    if max_dist >= self.sensitivity_evaluator.quant_config.metric_normalization_threshold:
        Logger.warning("The mixed precision distance metric values indicate a large error in the quantized model. "
                       "This can cause numerical issues. "
                       "The program will proceed with mixed precision search after scaling the metric values, "
                       "which can lead to unstable results.")
        # Scale in-place so the maximal metric value becomes 1.0.
        for dists in layer_to_metrics_mapping.values():
            for b in dists:
                dists[b] /= max_dist


class ConfigReconstructionHelper:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -297,4 +297,7 @@ def _build_layer_to_metrics_mapping(search_manager: MixedPrecisionSearchManager,
[node_idx],
search_manager.max_kpi_config)

# Finalize distance metric mapping
search_manager.finalize_distance_metric(layer_to_metrics_mapping)

return layer_to_metrics_mapping
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@ def compute_kpi_matrix(self, target):

return np.array(kpi_matrix)

# Test stub: overrides the search manager's metric-normalization hook with a
# no-op, so LP-search tests run on raw (unscaled) distance values.
def finalize_distance_metric(self, d):
    return d


class TestLpSearchBitwidth(unittest.TestCase):

Expand Down

0 comments on commit 424bb91

Please sign in to comment.