Align const quant with converter (#1099)
* Fix const quantization to support "per-axis", and add TPV.v3
* Select the quantization axis according to the lowest MSE when per_channel is enabled and the axis is None.
elad-c authored Jun 9, 2024
1 parent 12ba094 commit b0c2fdd
Showing 27 changed files with 1,316 additions and 248 deletions.
5 changes: 1 addition & 4 deletions model_compression_toolkit/core/common/graph/base_node.py
@@ -240,10 +240,7 @@ def insert_positional_weights_to_input_list(self, input_tensors: List) -> List:
if isinstance(pos, int)):
if pos > len(input_tensors):
Logger.critical("The positional weight index cannot exceed the number of input tensors to the node.") # pragma: no cover
-# Insert only positional weights that are not subject to quantization. If the positional weight is
-# subject to quantization, the quantization wrapper inserts the positional weight into the node.
-if not self.is_weights_quantization_enabled(pos):
-    input_tensors.insert(pos, weight)
+input_tensors.insert(pos, weight)

return input_tensors
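
For orientation, a minimal runnable sketch of the behavior after this change, with simplified, hypothetical names standing in for BaseNode.insert_positional_weights_to_input_list and its weights dictionary: every positional weight is inserted into the input list, whether or not it is quantized.

from typing import Any, Dict, List

def insert_positional_weights(input_tensors: List[Any],
                              positional_weights: Dict[int, Any]) -> List[Any]:
    # Insert every positional (constant) weight at its index, regardless of whether
    # it is subject to quantization.
    for pos, weight in sorted(positional_weights.items()):
        if pos > len(input_tensors):
            raise ValueError("Positional weight index exceeds the number of input tensors.")
        input_tensors.insert(pos, weight)
    return input_tensors

# A constant feeding input position 0 of a node with one dynamic input:
print(insert_positional_weights(["activation"], {0: "const_weight"}))  # ['const_weight', 'activation']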

@@ -326,13 +326,17 @@ def calculate_and_set_weights_params(self, tensor_data: np.ndarray, min_threshold
"""
assert self.enable_weights_quantization
assert not (self.weights_per_channel_threshold and self.weights_channels_axis is None), \
"Trying to calculate threshold per channel, channel axis in None."
if self.weights_quantization_params_fn is not None:
-self.set_weights_quantization_param(self.weights_quantization_params_fn(tensor_data,
-    p=self.l_p_value,
-    n_bits=self.weights_n_bits,
-    per_channel=self.weights_per_channel_threshold and self.weights_channels_axis is not None,
-    channel_axis=self.weights_channels_axis[0],  # output channel axis
-    min_threshold=min_threshold))
+self.set_weights_quantization_param(
+    self.weights_quantization_params_fn(tensor_data,
+        p=self.l_p_value,
+        n_bits=self.weights_n_bits,
+        per_channel=self.weights_per_channel_threshold and self.weights_channels_axis is not None,
+        channel_axis=self.weights_channels_axis[0],  # output channel axis
+        min_threshold=min_threshold)[0]  # Take only first output, the q-params, as axis is already chosen.
+)
else:
self.set_weights_quantization_param({})
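
To illustrate the convention behind the [0] indexing above: after this change the weights quantization-params functions return a tuple of the parameter dictionary and the channel axis that was used, and the node configuration keeps only the dictionary. The helper below is a hypothetical stand-in (a plain per-channel power-of-two threshold), not the library's actual weights_quantization_params_fn.

from typing import Dict, Tuple
import numpy as np

def example_params_fn(tensor_data: np.ndarray,
                      n_bits: int = 8,
                      channel_axis: int = 0) -> Tuple[Dict[str, np.ndarray], int]:
    # Per-channel absolute max along the chosen axis, then the covering power-of-two threshold.
    reduce_axes = tuple(i for i in range(tensor_data.ndim) if i != channel_axis)
    tensor_max = np.max(np.abs(tensor_data), axis=reduce_axes)
    threshold = np.power(2.0, np.ceil(np.log2(np.maximum(tensor_max, 2.0 ** -16))))
    return {"threshold": threshold}, channel_axis

weights = np.random.randn(16, 3, 3, 3).astype(np.float32)
qparams = example_params_fn(weights)[0]  # keep only the q-params; the axis was chosen earlier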

@@ -13,7 +13,7 @@
# limitations under the License.
# ==============================================================================

-from typing import Dict
+from typing import Dict, Tuple
import numpy as np
from sklearn.cluster import KMeans

@@ -42,7 +42,8 @@ def lut_kmeans_tensor(tensor_data: np.ndarray,
is_symmetric: bool = False,
node=None,
hessian_info_service: HessianInfoService = None,
-num_hessian_samples: int = NUM_QPARAM_HESSIAN_SAMPLES) -> Dict:
+num_hessian_samples: int = NUM_QPARAM_HESSIAN_SAMPLES,
+) -> Tuple[Dict[str, np.ndarray], int]:
"""
The quantizer first finds the closest max value per channel of tensor_data.
Now, we divide tensor_data with the threshold vector per channel. In addition, we scale the result to the range
@@ -70,27 +71,34 @@ def lut_kmeans_tensor(tensor_data: np.ndarray,
if n_bits >= LUT_VALUES_BITWIDTH:
Logger.critical(f'Look-Up-Table (LUT) bit configuration exceeds maximum: {n_bits} bits provided, must be less than {LUT_VALUES_BITWIDTH} bits.') # pragma: no cover
# TODO: need to set this externally
+n_data_points = len(np.unique(tensor_data.flatten()))
if len(np.unique(tensor_data.flatten())) < 2 ** n_bits:
-    n_clusters = len(np.unique(tensor_data.flatten()))
+    n_clusters = n_data_points
else:
n_clusters = 2 ** n_bits
kmeans = KMeans(n_clusters=n_clusters, n_init=10)

threshold_selection_tensor = symmetric_selection_tensor if is_symmetric else power_of_two_selection_tensor
-thresholds_per_channel = threshold_selection_tensor(tensor_data, p, n_bits, per_channel,
-    channel_axis, n_iter, min_threshold,
-    qc.QuantizationErrorMethod.NOCLIPPING)[THRESHOLD]

+_params, channel_axis = threshold_selection_tensor(tensor_data, p, n_bits, per_channel,
+    channel_axis, n_iter, min_threshold,
+    qc.QuantizationErrorMethod.NOCLIPPING)
+thresholds_per_channel = _params[THRESHOLD]

tensor_for_kmeans = int_quantization_with_threshold(tensor_data, thresholds_per_channel, LUT_VALUES_BITWIDTH)
kmeans.fit(tensor_for_kmeans.reshape(-1, 1))

# Add 0 to the LUT
cc = np.round(kmeans.cluster_centers_)
+if n_data_points < 2 ** n_bits and np.all(cc != 0):
+    # In case there are fewer data points than potential clusters, we can add the cluster 0.0
+    # to the original clusters array to improve quantization (i.e. no need to zero one of the clusters).
+    cc = np.concatenate([np.zeros([1, 1], dtype=cc.dtype), cc])
closest2zero_idx = (np.abs(cc - 0)).argmin()
cc[closest2zero_idx] = 0.0

return {LUT_VALUES: cc,
-        SCALE_PER_CHANNEL: thresholds_per_channel}
+        SCALE_PER_CHANNEL: thresholds_per_channel}, channel_axis


def lut_kmeans_histogram(bins: np.ndarray,
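
A runnable sketch of the cluster-center logic above, with the threshold scaling step omitted for brevity (so the centers stay in the raw data range rather than the LUT range): when a tensor has fewer unique values than 2**n_bits, a dedicated 0.0 center is prepended instead of overwriting the center closest to zero.

import numpy as np
from sklearn.cluster import KMeans

def lut_centers_with_zero(tensor_data: np.ndarray, n_bits: int = 4) -> np.ndarray:
    n_data_points = len(np.unique(tensor_data.flatten()))
    n_clusters = min(n_data_points, 2 ** n_bits)
    kmeans = KMeans(n_clusters=n_clusters, n_init=10)
    kmeans.fit(tensor_data.reshape(-1, 1))
    cc = np.round(kmeans.cluster_centers_)
    if n_data_points < 2 ** n_bits and np.all(cc != 0):
        # Fewer data points than potential clusters: add a 0.0 cluster instead of zeroing a real one.
        cc = np.concatenate([np.zeros([1, 1], dtype=cc.dtype), cc])
    closest2zero_idx = np.abs(cc).argmin()
    cc[closest2zero_idx] = 0.0  # a no-op when the 0.0 cluster was just prepended
    return cc

w = np.array([4.0, 12.0, -8.0, 20.0, 20.0], dtype=np.float32)
print(lut_centers_with_zero(w, n_bits=3))  # 4 unique values < 8, so a 0.0 center is added
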
@@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================
import numpy as np
+from typing import Union, Tuple, Dict

import model_compression_toolkit.core.common.quantization.quantization_config as qc
from model_compression_toolkit.constants import MIN_THRESHOLD, THRESHOLD, NUM_QPARAM_HESSIAN_SAMPLES
@@ -23,20 +24,22 @@
from model_compression_toolkit.core.common.quantization.quantization_params_generation.error_functions import \
get_threshold_selection_tensor_error_function, get_threshold_selection_histogram_error_function
from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
+from model_compression_toolkit.core.common.similarity_analyzer import compute_mse
+from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import quantize_tensor


def power_of_two_selection_tensor(tensor_data: np.ndarray,
p: int,
n_bits: int,
per_channel: bool = False,
-channel_axis: int = 1,
+channel_axis: Union[int, None] = 1,
n_iter: int = 10,
min_threshold: float = MIN_THRESHOLD,
quant_error_method: qc.QuantizationErrorMethod = qc.QuantizationErrorMethod.MSE,
node=None,
hessian_info_service: HessianInfoService = None,
num_hessian_samples: int = NUM_QPARAM_HESSIAN_SAMPLES,
-) -> dict:
+) -> Tuple[Dict[str, np.ndarray], int]:
"""
Compute the power of two threshold based on the provided QuantizationErrorMethod to quantize the tensor.
Different search is applied, depends on the value of the selected QuantizationErrorMethod.
@@ -46,7 +49,7 @@ def power_of_two_selection_tensor(tensor_data: np.ndarray,
p: p-norm to use for the Lp-norm distance.
n_bits: Number of bits to quantize the tensor.
per_channel: Whether the quantization should be per-channel or not.
-channel_axis: Output channel index.
+channel_axis: Output channel index. if None, search for best axis.
n_iter: Number of iterations to search for the optimal threshold (not used for this method).
min_threshold: Minimal threshold to use if threshold is too small (not used for this method).
quant_error_method: an error function to optimize the parameters' selection accordingly.
@@ -56,11 +59,24 @@ def power_of_two_selection_tensor(tensor_data: np.ndarray,
Returns:
Power of two threshold to quantize the tensor in a power of 2 manner.
+Selected quantization channel axis.
"""

if quant_error_method == qc.QuantizationErrorMethod.NOCLIPPING:
-    tensor_max = get_tensor_max(tensor_data, per_channel, channel_axis, n_bits)
-    threshold = max_power_of_two(tensor_max, min_threshold)
+    if channel_axis is None and per_channel:
+        total_error_list = []
+        th_list = []
+        for _axis in range(len(tensor_data.shape)):
+            tensor_max = get_tensor_max(tensor_data, per_channel, _axis, n_bits)
+            threshold = max_power_of_two(tensor_max, min_threshold)
+            q_tensor_data = quantize_tensor(tensor_data, threshold, n_bits, True)
+            total_error_list.append(compute_mse(tensor_data, q_tensor_data, norm=True))
+            th_list.append(threshold)
+        channel_axis = np.argmin(total_error_list)
+        threshold = th_list[channel_axis]
+    else:
+        tensor_max = get_tensor_max(tensor_data, per_channel, channel_axis, n_bits)
+        threshold = max_power_of_two(tensor_max, min_threshold)
else:
signed = True # weights are always signed
axis = -1 if per_channel else None
@@ -69,15 +85,15 @@ def power_of_two_selection_tensor(tensor_data: np.ndarray,
n_bits=n_bits, signed=signed, node=node,
hessian_info_service=hessian_info_service,
num_hessian_samples=num_hessian_samples)
-threshold = qparams_selection_tensor_search(error_function,
-    tensor_data,
-    n_bits,
-    per_channel=per_channel,
-    channel_axis=channel_axis,
-    n_iter=n_iter,
-    min_threshold=min_threshold,
-    signed=signed)
-return {THRESHOLD: threshold}
+threshold, channel_axis = qparams_selection_tensor_search(error_function,
+    tensor_data,
+    n_bits,
+    per_channel=per_channel,
+    channel_axis=channel_axis,
+    n_iter=n_iter,
+    min_threshold=min_threshold,
+    signed=signed)
+return {THRESHOLD: threshold}, channel_axis


def power_of_two_selection_histogram(bins: np.ndarray,
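
To make the axis search above concrete, here is a self-contained sketch under simplified assumptions: get_tensor_max, max_power_of_two, quantize_tensor and compute_mse are replaced with minimal local helpers (per-channel abs-max, a covering power of two, signed uniform quantization, and normalized MSE), so this illustrates the selection logic rather than reproducing the library's implementation.

import numpy as np
from typing import Tuple

def _max_per_channel(t: np.ndarray, axis: int) -> np.ndarray:
    # Per-channel absolute max, keeping dims so the threshold broadcasts back onto t.
    reduce_axes = tuple(i for i in range(t.ndim) if i != axis)
    return np.max(np.abs(t), axis=reduce_axes, keepdims=True)

def _pot_threshold(tensor_max: np.ndarray, min_threshold: float = 2.0 ** -16) -> np.ndarray:
    # Smallest power of two that covers the per-channel max.
    return np.power(2.0, np.ceil(np.log2(np.maximum(tensor_max, min_threshold))))

def _quantize_symmetric(t: np.ndarray, threshold: np.ndarray, n_bits: int) -> np.ndarray:
    # Signed uniform quantization into [-threshold, threshold).
    scale = threshold / (2 ** (n_bits - 1))
    return np.clip(np.round(t / scale), -2 ** (n_bits - 1), 2 ** (n_bits - 1) - 1) * scale

def select_axis_by_mse(tensor_data: np.ndarray, n_bits: int = 8) -> Tuple[np.ndarray, int]:
    # Try every axis as the channel axis and keep the one with the lowest normalized MSE.
    errors, thresholds = [], []
    for axis in range(tensor_data.ndim):
        th = _pot_threshold(_max_per_channel(tensor_data, axis))
        q = _quantize_symmetric(tensor_data, th, n_bits)
        errors.append(np.mean((tensor_data - q) ** 2) / np.mean(tensor_data ** 2))
        thresholds.append(th)
    best_axis = int(np.argmin(errors))
    return thresholds[best_axis], best_axis

const_weight = (np.random.randn(1, 1, 64, 3) * np.array([0.1, 1.0, 10.0])).astype(np.float32)
threshold, axis = select_axis_by_mse(const_weight)
print(axis)  # likely selects axis 3, whose channels differ most in scale
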
@@ -84,13 +84,14 @@ def calculate_quantization_params(graph: Graph,
mod_attr_cfg = copy.deepcopy(attr_cfg)
mod_attr_cfg.weights_error_method = QuantizationErrorMethod.MSE

-weights_params = get_weights_qparams(n.get_weights_by_keys(attr),
-    candidate_qc.weights_quantization_cfg,
-    mod_attr_cfg,
-    output_channels_axis,
-    node=n,
-    hessian_info_service=hessian_info_service,
-    num_hessian_samples=num_hessian_samples)
+weights_params, output_channels_axis = get_weights_qparams(n.get_weights_by_keys(attr),
+    candidate_qc.weights_quantization_cfg,
+    mod_attr_cfg,
+    output_channels_axis,
+    node=n,
+    hessian_info_service=hessian_info_service,
+    num_hessian_samples=num_hessian_samples)
+attr_cfg.weights_channels_axis = (output_channels_axis, attr_cfg.weights_channels_axis[1])
attr_cfg.set_weights_quantization_param(weights_params)

if n.is_activation_quantization_enabled():
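
A hypothetical, simplified illustration of the flow above, where weights_channels_axis is an (output_axis, input_axis) pair (as the "# output channel axis" comment earlier suggests): the q-params helper now also reports the axis it settled on, and the caller writes that axis back as the first tuple element. The helper here just picks the largest dimension, standing in for the real lowest-MSE search.

from typing import Dict, Optional, Tuple
import numpy as np

def get_weights_qparams_sketch(weights: np.ndarray,
                               output_channels_axis: Optional[int]) -> Tuple[Dict[str, np.ndarray], int]:
    # If no axis is known (e.g. a constant weight), pick one; the real code picks the lowest-MSE axis.
    if output_channels_axis is None:
        output_channels_axis = int(np.argmax(weights.shape))
    reduce_axes = tuple(i for i in range(weights.ndim) if i != output_channels_axis)
    thresholds = np.max(np.abs(weights), axis=reduce_axes)
    return {"threshold": thresholds}, output_channels_axis

weights_channels_axis = (None, None)  # axis unknown before parameter calculation
weights_params, output_channels_axis = get_weights_qparams_sketch(np.random.randn(3, 3, 8, 16),
                                                                   weights_channels_axis[0])
weights_channels_axis = (output_channels_axis, weights_channels_axis[1])
print(weights_channels_axis)  # (3, None): the selected output axis is stored back in the config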