Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor QuantizationConfigOptions #1088

Merged
merged 5 commits into from
May 29, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def __eq__(self, other):
self.simd_size == other.simd_size


class QuantizationConfigOptions(object):
class QuantizationConfigOptions:
"""

Wrap a set of quantization configurations to consider during the quantization
Expand All @@ -215,19 +215,24 @@ def __init__(self,
"""

assert isinstance(quantization_config_list,
list), f'\'QuantizationConfigOptions\' options list must be a list, but received: {type(quantization_config_list)}.'
assert len(quantization_config_list) > 0, f'Options list can not be empty.'
ofirgo marked this conversation as resolved.
Show resolved Hide resolved
list), f"'QuantizationConfigOptions' options list must be a list, but received: {type(quantization_config_list)}."
for cfg in quantization_config_list:
assert isinstance(cfg, OpQuantizationConfig), f'Each option must be an instance of \'OpQuantizationConfig\', but found an object of type: {type(cfg)}.'
assert isinstance(cfg, OpQuantizationConfig),\
f"Each option must be an instance of 'OpQuantizationConfig', but found an object of type: {type(cfg)}."
self.quantization_config_list = quantization_config_list
if len(quantization_config_list) > 1:
assert base_config is not None, f'For multiple configurations, a \'base_config\' is required for non-mixed-precision optimization.'
assert base_config in quantization_config_list, f"\'base_config\' must be included in the quantization config options list."
assert base_config is not None, \
f"For multiple configurations, a 'base_config' is required for non-mixed-precision optimization."
assert any([base_config is cfg for cfg in quantization_config_list]), \
f"'base_config' must be included in the quantization config options list."
# Enforce base_config to be a different instance from the one in quantization_config_list
elad-c marked this conversation as resolved.
Show resolved Hide resolved
self.base_config = base_config
elif len(quantization_config_list) == 1:
assert base_config is None or base_config == quantization_config_list[0], "'base_config' should be included in 'quantization_config_list'"
# Override base_config to be a different instance from the one in quantization_config_list
elad-c marked this conversation as resolved.
Show resolved Hide resolved
self.base_config = quantization_config_list[0]
else:
Logger.critical("\'QuantizationConfigOptions\' requires at least one \'OpQuantizationConfig\'; the provided list is empty.")
raise AssertionError("'QuantizationConfigOptions' requires at least one 'OpQuantizationConfig'. The provided list is empty.")

def __eq__(self, other):
"""
Expand Down
11 changes: 8 additions & 3 deletions tests/common_tests/helpers/generate_test_tp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,10 @@ def generate_mixed_precision_test_tp_model(base_cfg, default_config, mp_bitwidth
candidate_cfg = base_cfg.clone_and_edit(attr_to_edit={KERNEL_ATTR: {WEIGHTS_N_BITS: weights_n_bits}},
activation_n_bits=activation_n_bits)

mp_op_cfg_list.append(candidate_cfg)
if candidate_cfg == base_cfg:
mp_op_cfg_list.append(base_cfg)
elad-c marked this conversation as resolved.
Show resolved Hide resolved
else:
mp_op_cfg_list.append(candidate_cfg)

return generate_tp_model(default_config=default_config,
base_config=base_cfg,
Expand All @@ -85,8 +88,10 @@ def generate_tp_model_with_activation_mp(base_cfg, default_config, mp_bitwidth_c
**{k: v for k, v in base_cfg.attr_weights_configs_mapping.items() if
k != KERNEL_ATTR}},
activation_n_bits=activation_n_bits)

mp_op_cfg_list.append(candidate_cfg)
if candidate_cfg == base_cfg:
mp_op_cfg_list.append(base_cfg)
else:
mp_op_cfg_list.append(candidate_cfg)

base_tp_model = generate_tp_model(default_config=default_config,
base_config=base_cfg,
Expand Down
3 changes: 2 additions & 1 deletion tests/common_tests/test_tp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ class QCOptionsTest(unittest.TestCase):
def test_empty_qc_options(self):
with self.assertRaises(AssertionError) as e:
tp.QuantizationConfigOptions([])
self.assertEqual('Options list can not be empty.', str(e.exception))
self.assertEqual("'QuantizationConfigOptions' requires at least one 'OpQuantizationConfig'. The provided list is empty.",
str(e.exception))

def test_list_of_no_qc(self):
with self.assertRaises(AssertionError) as e:
Expand Down
Loading