diff --git a/model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py b/model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py
index 047745ca7..360d97112 100644
--- a/model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py
+++ b/model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py
@@ -26,7 +26,7 @@ from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import \
     RUTarget, ResourceUtilization
 from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization_calculator import \
-    ResourceUtilizationCalculator, TargetInclusionCriterion, BitwidthMode
+    TargetInclusionCriterion, BitwidthMode
 from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.ru_methods import \
     MixedPrecisionRUHelper
 from model_compression_toolkit.core.common.mixed_precision.sensitivity_evaluation import SensitivityEvaluation
@@ -67,13 +67,19 @@ def __init__(self,
         self.compute_metric_fn = self.get_sensitivity_metric()
         self._cuts = None

-        self.ru_metrics = target_resource_utilization.get_restricted_metrics()
+        # To define RU Total constraints, we need to compute weights and activations even if they have no constraints.
+        # TODO currently this logic is duplicated in linear_programming.py
+        targets = target_resource_utilization.get_restricted_metrics()
+        if RUTarget.TOTAL in targets:
+            targets = targets.union({RUTarget.ACTIVATION, RUTarget.WEIGHTS}) - {RUTarget.TOTAL}
+        self.ru_targets_to_compute = targets
+
         self.ru_helper = MixedPrecisionRUHelper(graph, fw_info, fw_impl)
         self.target_resource_utilization = target_resource_utilization
         self.min_ru_config = self.graph.get_min_candidates_config(fw_info)
         self.max_ru_config = self.graph.get_max_candidates_config(fw_info)
-        self.min_ru = self.ru_helper.compute_utilization(self.ru_metrics, self.min_ru_config)
-        self.non_conf_ru_dict = self._non_configurable_nodes_ru()
+        self.min_ru = self.ru_helper.compute_utilization(self.ru_targets_to_compute, self.min_ru_config)
+        self.non_conf_ru_dict = self.ru_helper.compute_utilization(self.ru_targets_to_compute, None)

         self.config_reconstruction_helper = ConfigReconstructionHelper(virtual_graph=self.graph,
                                                                        original_graph=self.original_graph)
@@ -111,18 +117,14 @@ def get_sensitivity_metric(self) -> Callable:
     def compute_resource_utilization_matrix(self, target: RUTarget) -> np.ndarray:
         """
         Computes and builds a resource utilization matrix, to be used for the mixed-precision search problem formalization.
-        The matrix is constructed as follows (for a given target):
-        - Each row represents the set of resource utilization values for a specific resource utilization
-        measure (number of rows should be equal to the length of the output of the respective target compute_ru function).
-        - Each entry in a specific column represents the resource utilization value of a given configuration
-        (single layer is configured with specific candidate, all other layer are at the minimal resource
-        utilization configuration) for the resource utilization measure of the respective row.
+        Utilization is computed relative to the minimal configuration, i.e. utilization for it will be 0.

         Args:
             target: The resource target for which the resource utilization is calculated (a RUTarget value).

-        Returns: A resource utilization matrix.
-
+        Returns:
+            A resource utilization matrix of shape (num memory elements, num configurations). The number of memory
+            elements depends on the target, e.g. the number of nodes or cuts for which utilization is computed.
         """
         assert isinstance(target, RUTarget), f"{target} is not a valid resource target"
@@ -132,21 +134,14 @@ def compute_resource_utilization_matrix(self, target: RUTarget) -> np.ndarray:
         for c, c_n in enumerate(configurable_sorted_nodes):
             for candidate_idx in range(len(c_n.candidates_quantization_cfg)):
                 if candidate_idx == self.min_ru_config[c]:
-                    # skip ru computation for min configuration. Since we compute the difference from min_ru it'll
-                    # always be 0 for all entries in the results vector.
-                    candidate_rus = np.zeros(shape=self.min_ru[target].shape)
+                    candidate_rus = self.min_ru[target]
                 else:
-                    candidate_rus = self.compute_node_ru_for_candidate(c, candidate_idx, target) - self.min_ru[target]
+                    candidate_rus = self.compute_node_ru_for_candidate(c, candidate_idx, target)

                 ru_matrix.append(np.asarray(candidate_rus))

-        # We need to transpose the calculated ru matrix to allow later multiplication with
-        # the indicators' diagonal matrix.
-        # We only move the first axis (num of configurations) to be last,
-        # the remaining axes include the metric specific nodes (rows dimension of the new tensor)
-        # and the ru metric values (if they are non-scalars)
-        np_ru_matrix = np.array(ru_matrix)
-        return np.moveaxis(np_ru_matrix, source=0, destination=len(np_ru_matrix.shape) - 1)
+        np_ru_matrix = np.array(ru_matrix) - self.min_ru[target]    # num configurations X num elements
+        return np_ru_matrix.T

     def compute_node_ru_for_candidate(self, conf_node_idx: int, candidate_idx: int, target: RUTarget) -> np.ndarray:
         """
@@ -162,7 +157,6 @@ def compute_node_ru_for_candidate(self, conf_node_idx: int, candidate_idx: int,
         """
         cfg = self.replace_config_in_index(self.min_ru_config, conf_node_idx, candidate_idx)
-        # TODO compute for all targets at once. Currently the way up to add_set_of_ru_constraints is per target.
         return self.ru_helper.compute_utilization({target}, cfg)[target]

     @staticmethod
@@ -183,18 +177,6 @@ def replace_config_in_index(mp_cfg: List[int], idx: int, value: int) -> List[int
         updated_cfg[idx] = value
         return updated_cfg

-    def _non_configurable_nodes_ru(self) -> Dict[RUTarget, np.ndarray]:
-        """
-        Computes a resource utilization vector of all non-configurable nodes in the given graph for each of the
-        resource utilization targets.
-
-        Returns: A mapping between a RUTarget and its non-configurable nodes' resource utilization vector.
-        """
-        ru_metrics = self.ru_metrics - {RUTarget.BOPS}
-        ru = self.ru_helper.compute_utilization(ru_targets=ru_metrics, mp_cfg=None)
-        ru[RUTarget.BOPS] = None
-        return ru
-
     def compute_resource_utilization_for_config(self, config: List[int]) -> ResourceUtilization:
         """
         Computes the resource utilization values for a given mixed-precision configuration.
@@ -206,7 +188,7 @@ def compute_resource_utilization_for_config(self, config: List[int]) -> Resource
             with the given config.
         """
-        act_qcs, w_qcs = self.ru_helper.get_configurable_qcs(config)
+        act_qcs, w_qcs = self.ru_helper.get_quantization_candidates(config)
         ru = self.ru_helper.ru_calculator.compute_resource_utilization(
             target_criterion=TargetInclusionCriterion.AnyQuantized, bitwidth_mode=BitwidthMode.QCustom, act_qcs=act_qcs,
             w_qcs=w_qcs)
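To make the new matrix semantics concrete, here is a minimal standalone numpy sketch (toy values and names, not MCT API) of what `compute_resource_utilization_matrix` now produces: one row per memory element, one column per (node, candidate) configuration, every entry relative to the minimal configuration so that min-candidate columns are all zeros.

```python
import numpy as np

# Toy setup: 2 configurable nodes; per-candidate utilization in bytes, one memory element per node.
# The minimal configuration picks the smallest candidate of every node.
candidates_ru = [np.array([10., 6., 4.]),   # node 0 candidates
                 np.array([20., 12.])]      # node 1 candidates
min_ru = np.array([4., 12.])                # utilization vector at the minimal configuration

rows = []
for node_idx, cand_rus in enumerate(candidates_ru):
    for cand_ru in cand_rus:
        vec = min_ru.copy()
        vec[node_idx] = cand_ru             # single node switched to this candidate, rest at min
        rows.append(vec)

# Subtract min and transpose: shape (num memory elements, num configurations).
ru_matrix = (np.array(rows) - min_ru).T
print(ru_matrix)    # columns for min candidates are all zeros
```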
""" - act_qcs, w_qcs = self.ru_helper.get_configurable_qcs(config) + act_qcs, w_qcs = self.ru_helper.get_quantization_candidates(config) ru = self.ru_helper.ru_calculator.compute_resource_utilization( target_criterion=TargetInclusionCriterion.AnyQuantized, bitwidth_mode=BitwidthMode.QCustom, act_qcs=act_qcs, w_qcs=w_qcs) diff --git a/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py b/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py index b99b2f55d..9e22cd5f7 100644 --- a/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +++ b/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py @@ -152,10 +152,10 @@ def compute_resource_utilization(self, elif w_qcs is not None: # pragma: no cover raise ValueError('Weight configuration passed but no relevant metric requested.') - if act_qcs and not {RUTarget.ACTIVATION, RUTarget.TOTAL}.intersection(ru_targets): # pragma: no cover - raise ValueError('Activation configuration passed but no relevant metric requested.') - if RUTarget.ACTIVATION in ru_targets: + if {RUTarget.ACTIVATION, RUTarget.TOTAL}.intersection(ru_targets): a_total = self.compute_activations_utilization(target_criterion, bitwidth_mode, act_qcs) + elif act_qcs is not None: # pragma: no cover + raise ValueError('Activation configuration passed but no relevant metric requested.') ru = ResourceUtilization() if RUTarget.WEIGHTS in ru_targets: @@ -163,9 +163,7 @@ def compute_resource_utilization(self, if RUTarget.ACTIVATION in ru_targets: ru.activation_memory = a_total if RUTarget.TOTAL in ru_targets: - # TODO use maxcut - act_tensors_total, *_ = self.compute_activation_tensors_utilization(target_criterion, bitwidth_mode, act_qcs) - ru.total_memory = w_total + act_tensors_total + ru.total_memory = w_total + a_total if RUTarget.BOPS in ru_targets: ru.bops, _ = self.compute_bops(target_criterion=target_criterion, bitwidth_mode=bitwidth_mode, act_qcs=act_qcs, w_qcs=w_qcs) diff --git a/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_methods.py b/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_methods.py index b3605089f..24350ee29 100644 --- a/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_methods.py +++ b/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_methods.py @@ -12,14 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
diff --git a/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_methods.py b/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_methods.py
index b3605089f..24350ee29 100644
--- a/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_methods.py
+++ b/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_methods.py
@@ -12,14 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-from typing import List, Set, Dict, Optional, Tuple
+from typing import List, Set, Dict, Optional, Tuple, Any

 import numpy as np

 from model_compression_toolkit.core import FrameworkInfo
 from model_compression_toolkit.core.common import Graph, BaseNode
 from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
-from model_compression_toolkit.core.common.graph.memory_graph.cut import Cut
 from model_compression_toolkit.core.common.graph.virtual_activation_weights_node import VirtualActivationWeightsNode
 from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import \
     RUTarget
@@ -44,9 +43,8 @@ def __init__(self, graph: Graph, fw_info: FrameworkInfo, fw_impl: FrameworkImple
     def compute_utilization(self, ru_targets: Set[RUTarget], mp_cfg: Optional[List[int]]) -> Dict[RUTarget, np.ndarray]:
         """
         Compute utilization of requested targets for a specific configuration in the format expected by LP problem
-        formulation, namely an array of ru values corresponding to graph's configurable nodes in the topological order.
-        For activation target, the array contains values for activation cuts in unspecified order (as long as it is
-        consistent between configurations).
+        formulation, namely a vector of ru values for the relevant memory elements (nodes or cuts) in a constant order
+        (consistent between calls).

         Args:
             ru_targets: resource utilization targets to compute.
@@ -57,33 +55,26 @@ def compute_utilization(self, ru_targets: Set[RUTarget], mp_cfg: Optional[List[i
         """
         ru = {}
-
-        act_qcs, w_qcs = self.get_configurable_qcs(mp_cfg) if mp_cfg else (None, None)
-        w_util = None
+        act_qcs, w_qcs = self.get_quantization_candidates(mp_cfg) if mp_cfg else (None, None)
         if RUTarget.WEIGHTS in ru_targets:
-            w_util = self._weights_utilization(w_qcs)
-            ru[RUTarget.WEIGHTS] = np.array(list(w_util.values()))
+            wu = self._weights_utilization(w_qcs)
+            ru[RUTarget.WEIGHTS] = np.array(list(wu.values()))

-        # TODO make mp agnostic to activation method
         if RUTarget.ACTIVATION in ru_targets:
-            act_util = self._activation_maxcut_utilization(act_qcs)
-            ru[RUTarget.ACTIVATION] = np.array(list(act_util.values()))
-
-        # TODO use maxcut
-        if RUTarget.TOTAL in ru_targets:
-            act_tensors_util = self._activation_tensor_utilization(act_qcs)
-            w_util = w_util or self._weights_utilization(w_qcs)
-            total = {n: (w_util.get(n, 0), act_tensors_util.get(n, 0))
-                     # for n in self.graph.nodes if n in act_tensors_util or n in w_util}
-                     for n in self.graph.get_topo_sorted_nodes() if n in act_tensors_util or n in w_util}
-            ru[RUTarget.TOTAL] = np.array(list(total.values()))
+            au = self._activation_utilization(act_qcs)
+            ru[RUTarget.ACTIVATION] = np.array(list(au.values()))

         if RUTarget.BOPS in ru_targets:
             ru[RUTarget.BOPS] = self._bops_utilization(mp_cfg)

+        if RUTarget.TOTAL in ru_targets:
+            raise ValueError('Total target should be computed based on weights and activations targets.')
+
+        assert len(ru) == len(ru_targets), (f'Mismatch between the number of computed and requested metrics. '
+                                            f'Requested {ru_targets}')
         return ru

-    def get_configurable_qcs(self, mp_cfg) \
+    def get_quantization_candidates(self, mp_cfg) \
             -> Tuple[Dict[BaseNode, NodeActivationQuantizationConfig], Dict[BaseNode, NodeWeightsQuantizationConfig]]:
         """
         Retrieve quantization candidates objects for weights and activations from the configuration list.
@@ -92,15 +83,13 @@ def get_quantization_candidates(self, mp_cfg) \
             mp_cfg: a list of candidates indices for configurable layers.

         Returns:
-            Mapping between nodes to weights quantization config, and a mapping between nodes and activation
+            A mapping between nodes and weights quantization config, and a mapping between nodes and activation
             quantization config.
         """
         mp_nodes = self.graph.get_configurable_sorted_nodes(self.fw_info)
         node_qcs = {n: n.candidates_quantization_cfg[mp_cfg[i]] for i, n in enumerate(mp_nodes)}
-        act_qcs = {n: node_qcs[n].activation_quantization_cfg
-                   for n in self.graph.get_activation_configurable_nodes()}
-        w_qcs = {n: node_qcs[n].weights_quantization_cfg
-                 for n in self.graph.get_weights_configurable_nodes(self.fw_info)}
+        act_qcs = {n: cfg.activation_quantization_cfg for n, cfg in node_qcs.items()}
+        w_qcs = {n: cfg.weights_quantization_cfg for n, cfg in node_qcs.items()}
         return act_qcs, w_qcs

     def _weights_utilization(self, w_qcs: Optional[Dict[BaseNode, NodeWeightsQuantizationConfig]]) -> Dict[BaseNode, float]:
@@ -127,8 +116,8 @@ def _weights_utilization(self, w_qcs: Optional[Dict[BaseNode, NodeWeightsQuantiz
             nodes_util = {n: u.bytes for n, u in nodes_util.items()}
         return nodes_util

-    def _activation_maxcut_utilization(self, act_qcs: Optional[Dict[BaseNode, NodeActivationQuantizationConfig]]) \
-            -> Optional[Dict[Cut, float]]:
+    def _activation_utilization(self, act_qcs: Optional[Dict[BaseNode, NodeActivationQuantizationConfig]]) \
+            -> Optional[Dict[Any, float]]:
         """
         Compute activation utilization using MaxCut for all quantized nodes if configuration is passed.
@@ -138,41 +127,18 @@ def _activation_maxcut_utilization(self, act_qcs: Optional[Dict[BaseNode, NodeAc
         Returns:
             Activation utilization per cut, or empty dict if no configuration was passed.
         """
-        if act_qcs:
-            _, cuts_util, _ = self.ru_calculator.compute_cut_activation_utilization(TargetInclusionCriterion.AnyQuantized,
-                                                                                    bitwidth_mode=BitwidthMode.QCustom,
-                                                                                    act_qcs=act_qcs)
-            cuts_util = {c: u.bytes for c, u in cuts_util.items()}
-            return cuts_util
-
-        # Computing non-configurable nodes resource utilization for max-cut is included in the calculation of the
-        # configurable nodes.
-        return {}
-
-    def _activation_tensor_utilization(self, act_qcs: Optional[Dict[BaseNode, NodeActivationQuantizationConfig]]):
-        """
-        Compute activation tensors utilization fo configurable nodes if configuration is passed or
-        for non-configurable nodes otherwise.
-
-        Args:
-            act_qcs: activation quantization configuration or None.
-
-        Returns:
-            Activation utilization per node.
-        """
-        if act_qcs:
-            target_criterion = TargetInclusionCriterion.QConfigurable
-            bitwidth_mode = BitwidthMode.QCustom
-        else:
-            target_criterion = TargetInclusionCriterion.QNonConfigurable
-            bitwidth_mode = BitwidthMode.QDefaultSP
-
-        _, nodes_util = self.ru_calculator.compute_activation_tensors_utilization(target_criterion=target_criterion,
-                                                                                  bitwidth_mode=bitwidth_mode,
-                                                                                  act_qcs=act_qcs)
-        return {n: u.bytes for n, u in nodes_util.items()}
-
-    def _bops_utilization(self, mp_cfg: List[int]):
+        # MaxCut activation utilization is computed for all quantized nodes, so non-configurable memory is already
+        # covered by the computation of configurable activations.
+        if not act_qcs:
+            return {}
+
+        _, cuts_util, *_ = self.ru_calculator.compute_cut_activation_utilization(TargetInclusionCriterion.AnyQuantized,
+                                                                                 bitwidth_mode=BitwidthMode.QCustom,
+                                                                                 act_qcs=act_qcs)
+        cuts_util = {c: u.bytes for c, u in cuts_util.items()}
+        return cuts_util
+
+    def _bops_utilization(self, mp_cfg: List[int]) -> np.ndarray:
         """
         Computes a resource utilization vector with the respective bit-operations (BOPS) count for each configurable node,
         according to the given mixed-precision configuration of a virtual graph with composed nodes.
@@ -180,15 +146,15 @@ def _bops_utilization(self, mp_cfg: List[int]):
         Args:
             mp_cfg: A mixed-precision configuration (list of candidates index for each configurable node)

-        Returns: A vector of node's BOPS count.
-        Note that the vector is not necessarily of the same length as the given config.
-
+        Returns:
+            A vector of per-node BOPS counts.
         """
-        # TODO keeping old implementation for now
-
-        # BOPs utilization method considers non-configurable nodes, therefore, it doesn't need separate implementation
-        # for non-configurable nodes for setting a constraint (no need for separate implementation for len(mp_cfg) = 0).
+        # BOPS is computed for all nodes, so non-configurable memory is already covered by the computation of
+        # configurable nodes.
+        if not mp_cfg:
+            return np.array([])

+        # TODO keeping old implementation for now
         virtual_bops_nodes = [n for n in self.graph.get_topo_sorted_nodes() if isinstance(n, VirtualActivationWeightsNode)]

         mp_nodes = self.graph.get_configurable_sorted_nodes_names(self.fw_info)
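To summarize the helper's new contract: `compute_utilization` returns one flat vector per requested target, each over that target's own memory elements, and refuses `TOTAL` outright. A hypothetical illustration of the returned shapes (illustrative numbers; only the `RUTarget` import is real MCT API):

```python
import numpy as np
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import \
    RUTarget

# Hypothetical result of MixedPrecisionRUHelper.compute_utilization for a graph with
# 3 configurable nodes and 4 activation cuts (element order is fixed between calls):
ru = {
    RUTarget.WEIGHTS: np.array([1024., 2048., 512.]),         # bytes per configurable node
    RUTarget.ACTIVATION: np.array([300., 700., 450., 220.]),  # bytes per activation cut
    RUTarget.BOPS: np.array([1e9, 4e9, 2e9]),                 # BOPS per configurable node
}
# RUTarget.TOTAL is never returned: compute_utilization raises if it is requested, and the
# callers derive the total constraint from the WEIGHTS and ACTIVATION vectors instead.
```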
" "A valid 'target_resource_utilization' is required.") return lp_problem -def _add_set_of_ru_constraints(search_manager: MixedPrecisionSearchManager, - target: RUTarget, - target_resource_utilization_value: float, - indicators_matrix: np.ndarray, - lp_problem: LpProblem, - non_conf_ru_vector: np.ndarray): +def _add_ru_constraints(search_manager: MixedPrecisionSearchManager, + target_resource_utilization: ResourceUtilization, + indicators_matrix: np.ndarray, + lp_problem: LpProblem, + non_conf_ru_dict: Optional[Dict[RUTarget, np.ndarray]]): """ - Adding a constraint for the Lp problem for the given target resource utilization. + Adding targets constraints for the Lp problem for the given target resource utilization. The update to the Lp problem object is done inplace. Args: search_manager: MixedPrecisionSearchManager object to be used for resource utilization constraints formalization. - target: A RUTarget. - target_resource_utilization_value: Target resource utilization value of the given target resource utilization - for which the constraint is added. + target_resource_utilization: Target resource utilization. indicators_matrix: A diagonal matrix of the Lp problem's indicators. lp_problem: An Lp problem object to add constraint to. - non_conf_ru_vector: A non-configurable nodes' resource utilization vector. - + non_conf_ru_dict: A non-configurable nodes' resource utilization vectors for the constrained targets. """ + ru_indicated_vectors = {} + # targets to add constraints for + constraints_targets = target_resource_utilization.get_restricted_metrics() + # to add constraints for Total target we need to compute weight and activation + targets_to_compute = constraints_targets + if RUTarget.TOTAL in constraints_targets: + targets_to_compute = targets_to_compute.union({RUTarget.ACTIVATION, RUTarget.WEIGHTS}) - {RUTarget.TOTAL} + + for target in targets_to_compute: + ru_matrix = search_manager.compute_resource_utilization_matrix(target) # num elements X num configurations + indicated_ru_matrix = np.matmul(ru_matrix, indicators_matrix) # num elements X num configurations + + # Sum the indicated values over all configurations, and add the value for minimal configuration once. + # Indicated utilization values are relative to the minimal configuration, i.e. they represent the extra memory + # that would be required if that configuration is selected). + # Each element in a vector is an lp object representing the configurations sum term for a memory element. + ru_vec = indicated_ru_matrix.sum(axis=1) + search_manager.min_ru[target] + + non_conf_ru_vec = non_conf_ru_dict[target] + if non_conf_ru_vec is not None and non_conf_ru_vec.size: + # add non-conf value as additional mem elements so that they get aggregated + ru_vec = np.concatenate([ru_vec, non_conf_ru_vec]) + ru_indicated_vectors[target] = ru_vec + + # add constraints only for the restricted targets in target resource utilization. 
+ for target in constraints_targets: + target_resource_utilization_value = target_resource_utilization.get_resource_utilization_dict()[target] + aggr_ru = _aggregate_for_lp(ru_indicated_vectors, target) + for v in aggr_ru: + if isinstance(v, float): + if v > target_resource_utilization_value: + Logger.critical( + f"The model cannot be quantized to meet the specified target resource utilization {target.value} " + f"with the value {target_resource_utilization_value}.") # pragma: no cover + else: + lp_problem += v <= target_resource_utilization_value + - ru_matrix = search_manager.compute_resource_utilization_matrix(target) - indicated_ru_matrix = np.matmul(ru_matrix, indicators_matrix) - # Need to re-organize the tensor such that the configurations' axis will be second, - # and all metric values' axis will come afterword - indicated_ru_matrix = np.moveaxis(indicated_ru_matrix, source=len(indicated_ru_matrix.shape) - 1, destination=1) - - # In order to get the result resource utilization according to a chosen set of indicators, we sum each row in - # the result matrix. Each row represents the resource utilization values for a specific resource utilization metric, - # such that only elements corresponding to a configuration which implied by the set of indicators will have some - # positive value different than 0 (and will contribute to the total resource utilization). - ru_sum_vector = np.array([ - np.sum(indicated_ru_matrix[i], axis=0) + # sum of metric values over all configurations in a row - search_manager.min_ru[target][i] for i in range(indicated_ru_matrix.shape[0])]) - - ru_vec = ru_sum_vector - if non_conf_ru_vector is not None and non_conf_ru_vector.size: - ru_vec = np.concatenate([ru_vec, non_conf_ru_vector]) - - aggr_ru = _aggregate_for_lp(ru_vec, target) - for v in aggr_ru: - if isinstance(v, float): - if v > target_resource_utilization_value: - Logger.critical( - f"The model cannot be quantized to meet the specified target resource utilization {target.value} " - f"with the value {target_resource_utilization_value}.") # pragma: no cover - else: - lp_problem += v <= target_resource_utilization_value - - -def _aggregate_for_lp(ru_vec, target: RUTarget) -> list: +def _aggregate_for_lp(targets_ru_vec: Dict[RUTarget, Any], target: RUTarget) -> list: """ Aggregate resource utilization values for the LP. Args: - ru_vec: a vector of resource utilization values. + targets_ru_vec: resource utilization vectors for all precomputed targets. target: resource utilization target. Returns: Aggregated resource utilization. 
""" if target == RUTarget.TOTAL: - w = lpSum(v[0] for v in ru_vec) - return [w + v[1] for v in ru_vec] + w = lpSum(targets_ru_vec[RUTarget.WEIGHTS]) + act_ru_vec = targets_ru_vec[RUTarget.ACTIVATION] + return [w + v for v in act_ru_vec] if target in [RUTarget.WEIGHTS, RUTarget.BOPS]: - return [lpSum(ru_vec)] + return [lpSum(targets_ru_vec[target])] if target == RUTarget.ACTIVATION: # for max aggregation, each value constitutes a separate constraint - return list(ru_vec) + return list(targets_ru_vec[target]) raise ValueError(f'Unexpected target {target}.') diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/mixed_precision_tests.py b/tests/keras_tests/feature_networks_tests/feature_networks/mixed_precision_tests.py index 0fd4ee3af..242c20f59 100644 --- a/tests/keras_tests/feature_networks_tests/feature_networks/mixed_precision_tests.py +++ b/tests/keras_tests/feature_networks_tests/feature_networks/mixed_precision_tests.py @@ -174,13 +174,12 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info= # test with its current setup (therefore, we don't check the input layer's bitwidth) self.unit_test.assertTrue((activation_bits == [4, 8])) - # TODO maxcut: restore this test after total_memory is fixed to be the sum of weight & activation metrics. - # # Verify final resource utilization - # self.unit_test.assertTrue( - # quantization_info.final_resource_utilization.total_memory == - # quantization_info.final_resource_utilization.weights_memory + quantization_info.final_resource_utilization.activation_memory, - # "Running weights and activation mixed-precision, " - # "final total memory should be equal to sum of weights and activation memory.") + # Verify final resource utilization + self.unit_test.assertTrue( + quantization_info.final_resource_utilization.total_memory == + quantization_info.final_resource_utilization.weights_memory + quantization_info.final_resource_utilization.activation_memory, + "Running weights and activation mixed-precision, " + "final total memory should be equal to sum of weights and activation memory.") class MixedPrecisionActivationSearch2BitsAvgTest(MixedPrecisionActivationBaseTest): @@ -207,13 +206,12 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info= activation_layers_idx=self.activation_layers_idx, unique_tensor_values=4) - # TODO maxcut: restore this test after total_memory is fixed to be the sum of weight & activation metrics. 
diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/mixed_precision_tests.py b/tests/keras_tests/feature_networks_tests/feature_networks/mixed_precision_tests.py
index 0fd4ee3af..242c20f59 100644
--- a/tests/keras_tests/feature_networks_tests/feature_networks/mixed_precision_tests.py
+++ b/tests/keras_tests/feature_networks_tests/feature_networks/mixed_precision_tests.py
@@ -174,13 +174,12 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info=
         # test with its current setup (therefore, we don't check the input layer's bitwidth)
         self.unit_test.assertTrue((activation_bits == [4, 8]))

-        # TODO maxcut: restore this test after total_memory is fixed to be the sum of weight & activation metrics.
-        # # Verify final resource utilization
-        # self.unit_test.assertTrue(
-        #     quantization_info.final_resource_utilization.total_memory ==
-        #     quantization_info.final_resource_utilization.weights_memory + quantization_info.final_resource_utilization.activation_memory,
-        #     "Running weights and activation mixed-precision, "
-        #     "final total memory should be equal to sum of weights and activation memory.")
+        # Verify final resource utilization
+        self.unit_test.assertTrue(
+            quantization_info.final_resource_utilization.total_memory ==
+            quantization_info.final_resource_utilization.weights_memory + quantization_info.final_resource_utilization.activation_memory,
+            "Running weights and activation mixed-precision, "
+            "final total memory should be equal to sum of weights and activation memory.")


 class MixedPrecisionActivationSearch2BitsAvgTest(MixedPrecisionActivationBaseTest):
@@ -207,13 +206,12 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info=
                                  activation_layers_idx=self.activation_layers_idx,
                                  unique_tensor_values=4)

-        # TODO maxcut: restore this test after total_memory is fixed to be the sum of weight & activation metrics.
-        # # Verify final resource utilization
-        # self.unit_test.assertTrue(
-        #     quantization_info.final_resource_utilization.total_memory ==
-        #     quantization_info.final_resource_utilization.weights_memory + quantization_info.final_resource_utilization.activation_memory,
-        #     "Running weights and activation mixed-precision, "
-        #     "final total memory should be equal to sum of weights and activation memory.")
+        # Verify final resource utilization
+        self.unit_test.assertTrue(
+            quantization_info.final_resource_utilization.total_memory ==
+            quantization_info.final_resource_utilization.weights_memory + quantization_info.final_resource_utilization.activation_memory,
+            "Running weights and activation mixed-precision, "
+            "final total memory should be equal to sum of weights and activation memory.")


 class MixedPrecisionActivationDepthwiseTest(MixedPrecisionActivationBaseTest):
@@ -343,13 +341,12 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info=
                                  activation_layers_idx=self.activation_layers_idx,
                                  unique_tensor_values=256)

-        # TODO maxcut: restore this test after total_memory is fixed to be the sum of weight & activation metrics.
-        # # Verify final ResourceUtilization
-        # self.unit_test.assertTrue(
-        #     quantization_info.final_resource_utilization.activation_memory + quantization_info.final_resource_utilization.weights_memory ==
-        #     quantization_info.final_resource_utilization.total_memory,
-        #     "Running activation mixed-precision with unconstrained weights and total resource utilization, "
-        #     "final total memory should be equal to the sum of activation and weights memory.")
+        # Verify final ResourceUtilization
+        self.unit_test.assertTrue(
+            quantization_info.final_resource_utilization.activation_memory + quantization_info.final_resource_utilization.weights_memory ==
+            quantization_info.final_resource_utilization.total_memory,
+            "Running activation mixed-precision with unconstrained weights and total resource utilization, "
+            "final total memory should be equal to the sum of activation and weights memory.")


 class MixedPrecisionActivationOnlyWeightsDisabledTest(MixedPrecisionActivationBaseTest):
@@ -472,7 +469,8 @@ def __init__(self, unit_test):
         super().__init__(unit_test, activation_layers_idx=[2, 4])

     def get_resource_utilization(self):
-        return ResourceUtilization(np.inf, np.inf, total_memory=(17920 + 5408) * 4 / 8)
+        # 17920: 8-bit weights, 6176: max cut of input+conv_bn
+        return ResourceUtilization(np.inf, np.inf, total_memory=(17920 + 6176) * 4 / 8)

     def compare(self, quantized_model, float_model, input_x=None, quantization_info: UserInformation = None):
         # verify chosen activation bitwidth config
@@ -486,13 +484,12 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info:
                                  activation_layers_idx=self.activation_layers_idx,
                                  unique_tensor_values=16)

-        # TODO maxcut: restore this test after total_memory is fixed to be the sum of weight & activation metrics.
-        # # Verify final ResourceUtilization
-        # self.unit_test.assertTrue(
-        #     quantization_info.final_resource_utilization.total_memory ==
-        #     quantization_info.final_resource_utilization.weights_memory + quantization_info.final_resource_utilization.activation_memory,
-        #     "Running weights and activation mixed-precision, "
-        #     "final total memory should be equal to sum of weights and activation memory.")
+        # Verify final ResourceUtilization
+        self.unit_test.assertTrue(
+            quantization_info.final_resource_utilization.total_memory ==
+            quantization_info.final_resource_utilization.weights_memory + quantization_info.final_resource_utilization.activation_memory,
+            "Running weights and activation mixed-precision, "
+            "final total memory should be equal to sum of weights and activation memory.")


 class MixedPrecisionMultipleResourcesTightUtilizationSearchTest(MixedPrecisionActivationBaseTest):
@@ -501,7 +498,8 @@ def __init__(self, unit_test):

     def get_resource_utilization(self):
         weights = 17920 * 4 / 8
-        activation = 4000
+        # activation = 4000
+        activation = 6176 * 4 / 8
         return ResourceUtilization(weights, activation, total_memory=weights + activation)

     def compare(self, quantized_model, float_model, input_x=None, quantization_info: UserInformation = None):
@@ -510,21 +508,22 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info:
         activation_bits = [layer.activation_holder_quantizer.get_config()['num_bits'] for layer in holder_layers]
         # TODO maxcut: restore activation_bits == [4, 4] and unique_tensor_values=16 when maxcut calculates tensor sizes
         # of fused nodes correctly.
-        self.unit_test.assertTrue((activation_bits == [4, 8]))
+        # TODO: maxcut Test updated but lowered activation ru (how can 4000 enforce 4,4??). Not sure what the fused nodes
+        # comment is about so I might be missing something. Elad?
+        self.unit_test.assertTrue((activation_bits == [4, 4]))

         self.verify_quantization(quantized_model, input_x,
                                  weights_layers_idx=[2, 3],
                                  weights_layers_channels_size=[32, 32],
                                  activation_layers_idx=self.activation_layers_idx,
-                                 unique_tensor_values=256)
+                                 unique_tensor_values=16)

-        # TODO maxcut: restore this test after total_memory is fixed to be the sum of weight & activation metrics.
-        # # Verify final ResourceUtilization
-        # self.unit_test.assertTrue(
-        #     quantization_info.final_resource_utilization.total_memory ==
-        #     quantization_info.final_resource_utilization.weights_memory + quantization_info.final_resource_utilization.activation_memory,
-        #     "Running weights and activation mixed-precision, "
-        #     "final total memory should be equal to sum of weights and activation memory.")
+        # Verify final ResourceUtilization
+        self.unit_test.assertTrue(
+            quantization_info.final_resource_utilization.total_memory ==
+            quantization_info.final_resource_utilization.weights_memory + quantization_info.final_resource_utilization.activation_memory,
+            "Running weights and activation mixed-precision, "
+            "final total memory should be equal to sum of weights and activation memory.")


 class MixedPrecisionReducedTotalMemorySearchTest(MixedPrecisionActivationBaseTest):
@@ -533,7 +532,7 @@ def __init__(self, unit_test):

     def get_resource_utilization(self):
         weights = 17920 * 4 / 8
-        activation = 5408 * 4 / 8
+        activation = 6176 * 4 / 8  # max cut of input + conv_bn
         return ResourceUtilization(weights, activation, total_memory=(weights + activation) / 2)

     def compare(self, quantized_model, float_model, input_x=None, quantization_info: UserInformation = None):
@@ -548,13 +547,12 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info:
                                  activation_layers_idx=self.activation_layers_idx,
                                  unique_tensor_values=16)

-        # TODO maxcut: restore this test after total_memory is fixed to be the sum of weight & activation metrics.
-        # # Verify final ResourceUtilization
-        # self.unit_test.assertTrue(
-        #     quantization_info.final_resource_utilization.total_memory ==
-        #     quantization_info.final_resource_utilization.weights_memory + quantization_info.final_resource_utilization.activation_memory,
-        #     "Running weights and activation mixed-precision, "
-        #     "final total memory should be equal to sum of weights and activation memory.")
+        # Verify final ResourceUtilization
+        self.unit_test.assertTrue(
+            quantization_info.final_resource_utilization.total_memory ==
+            quantization_info.final_resource_utilization.weights_memory + quantization_info.final_resource_utilization.activation_memory,
+            "Running weights and activation mixed-precision, "
+            "final total memory should be equal to sum of weights and activation memory.")


 class MixedPrecisionDistanceSoftmaxTest(MixedPrecisionActivationBaseTest):
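A quick sanity check of the budget arithmetic used by the updated tests above (plain Python; 17920 is the weights parameter count and 6176 the maximal activation cut in elements, both scaled by bit-width/8 to get bytes):

```python
weights_params = 17920   # total weight parameters
max_cut_elems = 6176     # elements in the maximal activation cut (input + conv_bn)

weights_4bit = weights_params * 4 / 8     # 8960.0 bytes
activation_4bit = max_cut_elems * 4 / 8   # 3088.0 bytes

# budget used by the total_memory test above:
total_budget = (weights_params + max_cut_elems) * 4 / 8   # 12048.0 bytes
assert total_budget == weights_4bit + activation_4bit
```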
diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/weights_mixed_precision_tests.py b/tests/keras_tests/feature_networks_tests/feature_networks/weights_mixed_precision_tests.py
index 077e91db2..9ed6e079a 100644
--- a/tests/keras_tests/feature_networks_tests/feature_networks/weights_mixed_precision_tests.py
+++ b/tests/keras_tests/feature_networks_tests/feature_networks/weights_mixed_precision_tests.py
@@ -127,13 +127,12 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info=
         self.unit_test.assertTrue(
             np.unique(conv_layers[1].get_quantized_weights()['kernel'][:, :, :, i]).flatten().shape[0] <= 256)

-        # TODO maxcut: restore this test after total_memory is fixed to be the sum of weight & activation metrics.
-        # # Verify final ResourceUtilization
-        # self.unit_test.assertTrue(
-        #     quantization_info.final_resource_utilization.weights_memory + quantization_info.final_resource_utilization.activation_memory ==
-        #     quantization_info.final_resource_utilization.total_memory,
-        #     "Running weights mixed-precision with unconstrained ResourceUtilization, "
-        #     "final weights and activation memory sum should be equal to total memory.")
+        # Verify final ResourceUtilization
+        self.unit_test.assertTrue(
+            quantization_info.final_resource_utilization.weights_memory + quantization_info.final_resource_utilization.activation_memory ==
+            quantization_info.final_resource_utilization.total_memory,
+            "Running weights mixed-precision with unconstrained ResourceUtilization, "
+            "final weights and activation memory sum should be equal to total memory.")


 class MixedPrecisionWithHessianScoresTest(MixedPrecisionBaseTest):
@@ -161,13 +160,12 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info=
         self.unit_test.assertTrue(
             np.unique(conv_layers[1].get_quantized_weights()['kernel'][:, :, :, i]).flatten().shape[0] <= 256)

-        # TODO maxcut: restore this test after total_memory is fixed to be the sum of weight & activation metrics.
-        # # Verify final ResourceUtilization
-        # self.unit_test.assertTrue(
-        #     quantization_info.final_resource_utilization.weights_memory + quantization_info.final_resource_utilization.activation_memory ==
-        #     quantization_info.final_resource_utilization.total_memory,
-        #     "Running weights mixed-precision with unconstrained ResourceUtilization, "
-        #     "final weights and activation memory sum should be equal to total memory.")
+        # Verify final ResourceUtilization
+        self.unit_test.assertTrue(
+            quantization_info.final_resource_utilization.weights_memory + quantization_info.final_resource_utilization.activation_memory ==
+            quantization_info.final_resource_utilization.total_memory,
+            "Running weights mixed-precision with unconstrained ResourceUtilization, "
+            "final weights and activation memory sum should be equal to total memory.")


 class MixedPrecisionSearchPartWeightsLayersTest(MixedPrecisionBaseTest):
@@ -255,13 +253,12 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info=
         self.unit_test.assertTrue(
             np.unique(conv_layers[1].get_quantized_weights()['kernel'][:, :, :, i]).flatten().shape[0] <= 16)

-        # TODO maxcut: restore this test after total_memory is fixed to be the sum of weight & activation metrics.
-        # # Verify final ResourceUtilization
-        # self.unit_test.assertTrue(
-        #     quantization_info.final_resource_utilization.weights_memory + quantization_info.final_resource_utilization.activation_memory ==
-        #     quantization_info.final_resource_utilization.total_memory,
-        #     "Running weights mixed-precision with unconstrained ResourceUtilization, "
-        #     "final weights and activation memory sum should be equal to total memory.")
+        # Verify final ResourceUtilization
+        self.unit_test.assertTrue(
+            quantization_info.final_resource_utilization.weights_memory + quantization_info.final_resource_utilization.activation_memory ==
+            quantization_info.final_resource_utilization.total_memory,
+            "Running weights mixed-precision with unconstrained ResourceUtilization, "
+            "final weights and activation memory sum should be equal to total memory.")


 class MixedPrecisionCombinedNMSTest(MixedPrecisionBaseTest):
@@ -296,13 +293,12 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info=
             np.unique(conv_layers[0].get_quantized_weights()['kernel'][:, :, :, i]).flatten().shape[0] <= 16 or
             np.unique(conv_layers[1].get_quantized_weights()['kernel'][:, :, :, i]).flatten().shape[0] <= 16)

-        # TODO maxcut: restore this test after total_memory is fixed to be the sum of weight & activation metrics.
-        # # Verify final ResourceUtilization
-        # self.unit_test.assertTrue(
-        #     quantization_info.final_resource_utilization.weights_memory + quantization_info.final_resource_utilization.activation_memory ==
-        #     quantization_info.final_resource_utilization.total_memory,
-        #     "Running weights mixed-precision with unconstrained ResourceUtilization, "
-        #     "final weights and activation memory sum should be equal to total memory.")
+        # Verify final ResourceUtilization
+        self.unit_test.assertTrue(
+            quantization_info.final_resource_utilization.weights_memory + quantization_info.final_resource_utilization.activation_memory ==
+            quantization_info.final_resource_utilization.total_memory,
+            "Running weights mixed-precision with unconstrained ResourceUtilization, "
+            "final weights and activation memory sum should be equal to total memory.")


 class MixedPrecisionSearch2BitsAvgTest(MixedPrecisionBaseTest):
@@ -323,13 +319,12 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info=
         self.unit_test.assertTrue(
             np.unique(conv_layers[1].get_quantized_weights()['kernel'][:, :, :, i]).flatten().shape[0] <= 4)

-        # TODO maxcut: restore this test after total_memory is fixed to be the sum of weight & activation metrics.
-        # # Verify final ResourceUtilization
-        # self.unit_test.assertTrue(
-        #     quantization_info.final_resource_utilization.weights_memory + quantization_info.final_resource_utilization.activation_memory ==
-        #     quantization_info.final_resource_utilization.total_memory,
-        #     "Running weights mixed-precision with unconstrained ResourceUtilization, "
-        #     "final weights and activation memory sum should be equal to total memory.")
+        # Verify final ResourceUtilization
+        self.unit_test.assertTrue(
+            quantization_info.final_resource_utilization.weights_memory + quantization_info.final_resource_utilization.activation_memory ==
+            quantization_info.final_resource_utilization.total_memory,
+            "Running weights mixed-precision with unconstrained ResourceUtilization, "
+            "final weights and activation memory sum should be equal to total memory.")


 class MixedPrecisionSearchActivationNonConfNodesTest(MixedPrecisionBaseTest):
@@ -347,19 +342,18 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info=
         self.unit_test.assertTrue(quantization_info.final_resource_utilization.activation_memory <=
                                   self.target_total_ru.activation_memory)

-        # TODO maxcut: restore this test after total_memory is fixed to be the sum of weight & activation metrics.
-        # self.unit_test.assertTrue(
-        #     quantization_info.final_resource_utilization.weights_memory + quantization_info.final_resource_utilization.activation_memory ==
-        #     quantization_info.final_resource_utilization.total_memory,
-        #     "Running weights mixed-precision with unconstrained Resource Utilization, "
-        #     "final weights and activation memory sum should be equal to total memory.")
+        self.unit_test.assertTrue(
+            quantization_info.final_resource_utilization.weights_memory + quantization_info.final_resource_utilization.activation_memory ==
+            quantization_info.final_resource_utilization.total_memory,
+            "Running weights mixed-precision with unconstrained Resource Utilization, "
+            "final weights and activation memory sum should be equal to total memory.")


 class MixedPrecisionSearchTotalMemoryNonConfNodesTest(MixedPrecisionBaseTest):
     def __init__(self, unit_test):
         super().__init__(unit_test)
         # Total ResourceUtilization for weights in 2 bit avg and non-configurable activation in 8 bit
-        self.target_total_ru = ResourceUtilization(total_memory=17920 * 2 / 8 + 5408)
+        self.target_total_ru = ResourceUtilization(total_memory=17920 * 2 / 8 + 6176)

     def get_resource_utilization(self):
         return self.target_total_ru
@@ -369,12 +363,11 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info=
         # we're only interested in the ResourceUtilization
         self.unit_test.assertTrue(
             quantization_info.final_resource_utilization.total_memory <= self.target_total_ru.total_memory)
-        # TODO maxcut: restore this test after total_memory is fixed to be the sum of weight & activation metrics.
-        # self.unit_test.assertTrue(
-        #     quantization_info.final_resource_utilization.weights_memory + quantization_info.final_resource_utilization.activation_memory ==
-        #     quantization_info.final_resource_utilization.total_memory,
-        #     "Running weights mixed-precision with unconstrained ResourceUtilization, "
-        #     "final weights and activation memory sum should be equal to total memory.")
+        self.unit_test.assertTrue(
+            quantization_info.final_resource_utilization.weights_memory + quantization_info.final_resource_utilization.activation_memory ==
+            quantization_info.final_resource_utilization.total_memory,
+            "Running weights mixed-precision with unconstrained ResourceUtilization, "
+            "final weights and activation memory sum should be equal to total memory.")


 class MixedPrecisionDepthwiseTest(MixedPrecisionBaseTest):
@@ -479,13 +472,12 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info=
         self.unit_test.assertTrue(
             np.unique(conv_layers[1].get_quantized_weights()['kernel'][:, :, :, i]).flatten().shape[0] <= 256)

-        # TODO maxcut: restore this test after total_memory is fixed to be the sum of weight & activation metrics.
-        # # Verify final Resource Utilization
-        # self.unit_test.assertTrue(
-        #     quantization_info.final_resource_utilization.weights_memory + quantization_info.final_resource_utilization.activation_memory ==
-        #     quantization_info.final_resource_utilization.total_memory,
-        #     "Running weights mixed-precision with unconstrained Resource Utilization, "
-        #     "final weights and activation memory sum should be equal to total memory.")
+        # Verify final Resource Utilization
+        self.unit_test.assertTrue(
+            quantization_info.final_resource_utilization.weights_memory + quantization_info.final_resource_utilization.activation_memory ==
+            quantization_info.final_resource_utilization.total_memory,
+            "Running weights mixed-precision with unconstrained Resource Utilization, "
+            "final weights and activation memory sum should be equal to total memory.")


 class MixedPrecisionWeightsOnlyConfigurableActivationsTest(MixedPrecisionBaseTest):
diff --git a/tests/keras_tests/feature_networks_tests/test_features_runner.py b/tests/keras_tests/feature_networks_tests/test_features_runner.py
index e34525ac1..3fe47c6fc 100644
--- a/tests/keras_tests/feature_networks_tests/test_features_runner.py
+++ b/tests/keras_tests/feature_networks_tests/test_features_runner.py
@@ -246,11 +246,10 @@ def test_mixed_precision_search(self):
     def test_mixed_precision_weights_only_activation_conf(self):
         MixedPrecisionWeightsOnlyConfigurableActivationsTest(self).run_test()

-    def test_requires_mixed_recision(self):
+    def test_requires_mixed_precision(self):
         RequiresMixedPrecisionWeights(self, weights_memory=True).run_test()
         RequiresMixedPrecision(self, activation_memory=True).run_test()
-        # TODO maxcut: restore this test after total_memory is fixed to be the sum of weight & activation metrics.
-        # RequiresMixedPrecision(self, total_memory=True).run_test()
+        RequiresMixedPrecision(self, total_memory=True).run_test()
         RequiresMixedPrecision(self, bops=True).run_test()
         RequiresMixedPrecision(self).run_test()
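As a reading aid for the new `test_search_weights_only_with_non_conf` in the next file: the weights budget of 2 + 11 is exactly the chosen candidate's weights memory (2) plus the two non-configurable entries (5 + 6), so lowering the budget to 2 + 10.9 makes that candidate infeasible and forces a different solution. A quick arithmetic check (plain Python; values mirror the mock's layer_to_ru_mapping, not MCT API):

```python
non_conf = [5, 6]                 # non-configurable weights utilization entries in the mock
candidate_w = {0: 3, 1: 2, 2: 1}  # configurable layer's weights memory per candidate index

# Original budget: candidate 1 (weights_memory=2) fits exactly.
assert candidate_w[1] + sum(non_conf) <= 2 + 11        # 13 <= 13

# Lowered budget: candidate 1 no longer fits, so the search must pick a cheaper candidate.
assert not candidate_w[1] + sum(non_conf) <= 2 + 10.9  # 13 > 12.9
assert candidate_w[2] + sum(non_conf) <= 2 + 10.9      # 12 <= 12.9
```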
diff --git a/tests/keras_tests/non_parallel_tests/test_lp_search_bitwidth.py b/tests/keras_tests/non_parallel_tests/test_lp_search_bitwidth.py
index a922b68bd..0daddd449 100644
--- a/tests/keras_tests/non_parallel_tests/test_lp_search_bitwidth.py
+++ b/tests/keras_tests/non_parallel_tests/test_lp_search_bitwidth.py
@@ -61,14 +61,13 @@ def __init__(self, layer_to_ru_mapping):
         self.layer_to_bitwidth_mapping = {0: [0, 1, 2]}
         self.layer_to_ru_mapping = layer_to_ru_mapping
         self.compute_metric_fn = lambda x, y=None, z=None: {0: 2, 1: 1, 2: 0}[x[0]]
-        self.min_ru = {RUTarget.WEIGHTS: [[1], [1], [1]],
-                       RUTarget.ACTIVATION: [[1], [1], [1]],
-                       RUTarget.TOTAL: [[1, 1], [1, 1], [1, 1]],
-                       RUTarget.BOPS: [[1], [1], [1]]}  # minimal resource utilization in the tests layer_to_ru_mapping
+        self.min_ru = {RUTarget.WEIGHTS: [1],
+                       RUTarget.ACTIVATION: [1],
+                       RUTarget.BOPS: [1]}  # minimal resource utilization in the test's layer_to_ru_mapping

         self.max_ru_config = [0]
         self.config_reconstruction_helper = MockReconstructionHelper()
-        self.non_conf_ru_dict = None
+        self.non_conf_ru_dict = {RUTarget.WEIGHTS: None, RUTarget.ACTIVATION: None, RUTarget.BOPS: None}

     def compute_resource_utilization_matrix(self, target):
         # minus 1 is normalization by the minimal resource utilization (which is always 1 in this test)
@@ -76,15 +75,10 @@ def compute_resource_utilization_matrix(self, target):
         if target == RUTarget.WEIGHTS:
             ru_matrix = [np.flip(np.array([ru.weights_memory - 1 for _, ru in self.layer_to_ru_mapping[0].items()]))]
         elif target == RUTarget.ACTIVATION:
             ru_matrix = [np.flip(np.array([ru.activation_memory - 1 for _, ru in self.layer_to_ru_mapping[0].items()]))]
-        elif target == RUTarget.TOTAL:
-            ru_matrix = [[np.flip(np.array([ru.weights_memory - 1 for _, ru in self.layer_to_ru_mapping[0].items()])),
-                          np.flip(np.array([ru.activation_memory - 1 for _, ru in self.layer_to_ru_mapping[0].items()]))]]
         elif target == RUTarget.BOPS:
             ru_matrix = [np.flip(np.array([ru.bops - 1 for _, ru in self.layer_to_ru_mapping[0].items()]))]
         else:
-            # not supposed to get here
-            ru_matrix = []
-
+            raise ValueError('Not supposed to get here')
         return np.array(ru_matrix)

     def finalize_distance_metric(self, d):
@@ -122,6 +116,26 @@ def test_search_weights_only(self):
         bit_cfg = mp_integer_programming_search(mock_search_manager,
                                                 target_resource_utilization=target_resource_utilization)

+    def test_search_weights_only_with_non_conf(self):
+        target_resource_utilization = ResourceUtilization(weights_memory=2+11)
+        layer_to_ru_mapping = {0: {2: ResourceUtilization(weights_memory=1),
+                                   1: ResourceUtilization(weights_memory=2),
+                                   0: ResourceUtilization(weights_memory=3)}
+                               }
+        mock_search_manager = MockMixedPrecisionSearchManager(layer_to_ru_mapping)
+        mock_search_manager.non_conf_ru_dict = {RUTarget.WEIGHTS: np.array([5, 6])}
+        bit_cfg = mp_integer_programming_search(mock_search_manager,
+                                                target_resource_utilization=target_resource_utilization)
+
+        self.assertTrue(len(bit_cfg) == 1)
+        self.assertTrue(bit_cfg[0] == 1)
+
+        # make sure non_conf was taken into account and a lower target has a different solution
+        target_resource_utilization = ResourceUtilization(weights_memory=2 + 10.9)
+        bit_cfg = mp_integer_programming_search(mock_search_manager,
+                                                target_resource_utilization=target_resource_utilization)
+        self.assertFalse(bit_cfg[0] == 1)
+
     def test_search_activation_only(self):
         target_resource_utilization = ResourceUtilization(activation_memory=2)
         layer_to_ru_mapping = {0: {2: ResourceUtilization(activation_memory=1),