Skip to content

Commit

Permalink
remove configurable aggregation function
Browse files Browse the repository at this point in the history
  • Loading branch information
irenaby committed Jan 12, 2025
1 parent a38034d commit 488a285
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -99,25 +99,6 @@ def __lt__(self, other: 'Utilization'):
return self.bytes < other.bytes


class AggregationMethod(Enum):
SUM = sum
MAX = lambda seq: max(seq) if (seq := list(seq)) else 0 # walrus op for empty generator

def __call__(self, *args, **kwarg):
return self.value(*args, **kwarg)


# default aggregation methods
# TODO This is used by mp to use the same aggregation. Except that for total it must do its own thing (add indicators
# to weights before summation). So maybe just get rid of it altogether? If it ever becomes configurable we can add it.
ru_target_aggregation_fn = {
RUTarget.WEIGHTS: AggregationMethod.SUM,
RUTarget.ACTIVATION: AggregationMethod.MAX,
RUTarget.TOTAL: AggregationMethod.SUM,
RUTarget.BOPS: AggregationMethod.SUM
}


class ResourceUtilizationCalculator:
""" Resource utilization calculator. """

Expand Down Expand Up @@ -226,8 +207,7 @@ def compute_weights_utilization(self,
util_per_node[n] = node_weights_util
util_per_node_per_weight[n] = per_weight_util

aggregate_fn = ru_target_aggregation_fn[RUTarget.WEIGHTS]
total_util = aggregate_fn(util_per_node.values())
total_util = sum(util_per_node.values())
return total_util.bytes, util_per_node, util_per_node_per_weight

def compute_node_weights_utilization(self,
Expand Down Expand Up @@ -334,8 +314,7 @@ def compute_cut_activation_utilization(self,
bitwidth_mode, qc)
util_per_cut[cut] = sum(util_per_cut_per_node[cut].values()) # type: ignore

aggregate_fn = ru_target_aggregation_fn[RUTarget.ACTIVATION]
total_util = aggregate_fn(util_per_cut.values())
total_util = max(util_per_cut.values())
return total_util.bytes, util_per_cut, util_per_cut_per_node

def compute_activation_tensors_utilization(self,
Expand Down Expand Up @@ -369,8 +348,7 @@ def compute_activation_tensors_utilization(self,
util = self.compute_node_activation_tensor_utilization(n, None, bitwidth_mode, qc)
util_per_node[n] = util

aggregate_fn = ru_target_aggregation_fn[RUTarget.ACTIVATION]
total_util = aggregate_fn(util_per_node.values())
total_util = max(util_per_node.values())
return total_util.bytes, util_per_node

def compute_node_activation_tensor_utilization(self,
Expand Down Expand Up @@ -438,8 +416,7 @@ def compute_bops(self,
w_qc = w_qcs.get(n) if w_qcs else None
nodes_bops[n] = self.compute_node_bops(n, bitwidth_mode, act_qcs=act_qcs, w_qc=w_qc)

aggregate_fn = ru_target_aggregation_fn[RUTarget.BOPS]
return aggregate_fn(nodes_bops.values()), nodes_bops
return sum(nodes_bops.values()), nodes_bops

def compute_node_bops(self,
n: BaseNode,
Expand Down Expand Up @@ -621,8 +598,7 @@ def _get_activation_nbits(cls,
if act_qc:
if bitwidth_mode != BitwidthMode.QCustom:
raise ValueError(f'Activation config is not expected for non-custom bit mode {bitwidth_mode}')
assert act_qc.enable_activation_quantization or act_qc.activation_n_bits == FLOAT_BITWIDTH
return act_qc.activation_n_bits
return act_qc.activation_n_bits if act_qc.enable_activation_quantization else FLOAT_BITWIDTH

if bitwidth_mode == BitwidthMode.Float or not n.is_activation_quantization_enabled():
return FLOAT_BITWIDTH
Expand Down Expand Up @@ -667,8 +643,7 @@ def _get_weight_nbits(cls,
if bitwidth_mode != BitwidthMode.QCustom:
raise ValueError('Weight config is not expected for non-custom bit mode {bitwidth_mode}')
attr_cfg = w_qc.get_attr_config(w_attr)
assert attr_cfg.enable_weights_quantization or attr_cfg.weights_n_bits == FLOAT_BITWIDTH
return attr_cfg.weights_n_bits
return attr_cfg.weights_n_bits if attr_cfg.enable_weights_quantization else FLOAT_BITWIDTH

if bitwidth_mode == BitwidthMode.Float or not n.is_weights_quantization_enabled(w_attr):
return FLOAT_BITWIDTH
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
from tqdm import tqdm
from typing import Dict, Tuple

from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization_calculator import \
ru_target_aggregation_fn, AggregationMethod
from model_compression_toolkit.logger import Logger
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization, RUTarget
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_search_manager import MixedPrecisionSearchManager
Expand Down Expand Up @@ -250,14 +248,14 @@ def _aggregate_for_lp(ru_vec, target: RUTarget) -> list:
w = lpSum(v[0] for v in ru_vec)
return [w + v[1] for v in ru_vec]

if ru_target_aggregation_fn[target] == AggregationMethod.SUM:
if target in [RUTarget.WEIGHTS, RUTarget.BOPS]:
return [lpSum(ru_vec)]

if ru_target_aggregation_fn[target] == AggregationMethod.MAX:
if target == RUTarget.ACTIVATION:
# for max aggregation, each value constitutes a separate constraint
return list(ru_vec)

raise NotImplementedError(f'Cannot define lp constraints with unsupported aggregation function '
f'{ru_target_aggregation_fn[target]}') # pragma: no cover
raise ValueError(f'Unexpected target {target}.')


def _build_layer_to_metrics_mapping(search_manager: MixedPrecisionSearchManager,
Expand Down

0 comments on commit 488a285

Please sign in to comment.