Skip to content

Commit

Permalink
Adding threshold for mixed precision metric and scaling the metric in…
Browse files Browse the repository at this point in the history
… case there are values that are larger than the threshold, to prevent numerical issues.
  • Loading branch information
Ofir Gordon authored and Ofir Gordon committed Oct 29, 2023
1 parent c1bae64 commit e114128
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ def __init__(self,
use_grad_based_weights: bool = True,
output_grad_factor: float = 0.1,
norm_weights: bool = True,
refine_mp_solution: bool = True):
refine_mp_solution: bool = True,
metric_normalization_threshold: float = 1e10):
"""
Class with mixed precision parameters to quantize the input model.
Unlike QuantizationConfig, number of bits for quantization is a list of possible bit widths to
Expand All @@ -49,6 +50,7 @@ def __init__(self,
output_grad_factor (float): A tuning parameter to be used for gradient-based weights.
norm_weights (bool): Whether to normalize the returned weights (to get values between 0 and 1).
refine_mp_solution (bool): Whether to try to improve the final mixed-precision configuration using a greedy algorithm that searches layers to increase their bit-width, or not.
metric_normalization_threshold (float): A threshold for checking the mixed precision distance metric values, In case of values larger than this threshold, the metric will be scaled to prevent numerical issues.
"""

Expand All @@ -72,6 +74,8 @@ def __init__(self,
Logger.info(f"Using gradient-based weights for mixed-precision distance metric with tuning factor "
f"{output_grad_factor}")

self.metric_normalization_threshold = metric_normalization_threshold


class MixedPrecisionQuantizationConfig(QuantizationConfig):

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,18 @@ def mp_integer_programming_search(search_manager: MixedPrecisionSearchManager,

layer_to_metrics_mapping = _build_layer_to_metrics_mapping(search_manager, target_kpi)

# Init variables to find their values when solving the lp problem.
# normalize metric for numerical stability
max_dist = max([max([d for b, d in dists.items()]) for layer, dists in layer_to_metrics_mapping.items()])
if max_dist >= search_manager.sensitivity_evaluator.quant_config.metric_normalization_threshold:
Logger.warning(f"The mixed precision distance metric values indicate a large error in the quantized model."
f"this can cause numerical issues."
f"The program will proceed with mixed precision search after scaling the metric values,"
f"which can lead to unstable results.")
for layer, dists in layer_to_metrics_mapping.items():
for b, d in dists.items():
layer_to_metrics_mapping[layer][b] /= max_dist

# Init variables to find their values when solving the lp problem.
layer_to_indicator_vars_mapping, layer_to_objective_vars_mapping = _init_problem_vars(layer_to_metrics_mapping)

# Add all equations and inequalities that define the problem.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from tqdm import tqdm
from typing import List

from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
Expand Down Expand Up @@ -49,7 +50,7 @@ def calculate_quantization_params(graph: Graph,
# Create a list of nodes to compute their thresholds
nodes_list: List[BaseNode] = nodes if specific_nodes else graph.nodes()

for n in nodes_list: # iterate only nodes that we should compute their thresholds
for n in tqdm(nodes_list): # iterate only nodes that we should compute their thresholds
for candidate_qc in n.candidates_quantization_cfg:
if n.is_weights_quantization_enabled():
# If node's weights should be quantized, we compute its weights' quantization parameters
Expand Down

0 comments on commit e114128

Please sign in to comment.