From c2016ea936209eff0cb20619489beecfd9a544f8 Mon Sep 17 00:00:00 2001 From: Ofir Gordon Date: Mon, 6 Jan 2025 09:23:14 +0200 Subject: [PATCH] align tests and use CustomOpsetLayers --- .../feature_networks_tests/feature_networks/qat/qat_test.py | 4 ++-- .../keras_tests/function_tests/test_cfg_candidates_filter.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/qat/qat_test.py b/tests/keras_tests/feature_networks_tests/feature_networks/qat/qat_test.py index 4d1f586bb..fd27723ae 100644 --- a/tests/keras_tests/feature_networks_tests/feature_networks/qat/qat_test.py +++ b/tests/keras_tests/feature_networks_tests/feature_networks/qat/qat_test.py @@ -23,7 +23,7 @@ from mct_quantizers.common.base_inferable_quantizer import QuantizerID from mct_quantizers.common.get_all_subclasses import get_all_subclasses from mct_quantizers.keras.quantizers import BaseKerasInferableQuantizer -from model_compression_toolkit.core import QuantizationConfig +from model_compression_toolkit.core import QuantizationConfig, CustomOpsetLayers from model_compression_toolkit.qat.keras.quantizer.base_keras_qat_weight_quantizer import \ BaseKerasQATWeightTrainableQuantizer from model_compression_toolkit.trainable_infrastructure import TrainingMethod, KerasTrainableQuantizationWrapper, \ @@ -294,7 +294,7 @@ def __init__(self, unit_test, ru_weights=17919, ru_activation=5407, expected_mp_ def run_test(self, **kwargs): model_float = self.create_networks() config = mct.core.CoreConfig( - quantization_config=QuantizationConfig(custom_tpc_opset_to_layer={"Input": ([layers.InputLayer],)}) + quantization_config=QuantizationConfig(custom_tpc_opset_to_layer={"Input": CustomOpsetLayers([layers.InputLayer])}) ) qat_ready_model, quantization_info, custom_objects = mct.qat.keras_quantization_aware_training_init_experimental( model_float, diff --git a/tests/keras_tests/function_tests/test_cfg_candidates_filter.py b/tests/keras_tests/function_tests/test_cfg_candidates_filter.py index 7a382c447..421ee46ba 100644 --- a/tests/keras_tests/function_tests/test_cfg_candidates_filter.py +++ b/tests/keras_tests/function_tests/test_cfg_candidates_filter.py @@ -19,6 +19,7 @@ import model_compression_toolkit as mct from model_compression_toolkit.constants import FLOAT_BITWIDTH +from model_compression_toolkit.core import CustomOpsetLayers from model_compression_toolkit.core.common.quantization.filter_nodes_candidates import filter_nodes_candidates from model_compression_toolkit.core.common.quantization.set_node_quantization_config import \ set_quantization_configuration_to_graph @@ -51,7 +52,7 @@ def prepare_graph(in_model, base_config, default_config, bitwidth_candidates): graph = keras_impl.model_reader(in_model, None) # model reading attach2keras = AttachTpcToKeras() - tpc = attach2keras.attach(tpc, custom_opset2layer={"Input": ([InputLayer],)}) + tpc = attach2keras.attach(tpc, custom_opset2layer={"Input": CustomOpsetLayers([InputLayer])}) graph.set_tpc(tpc) graph.set_fw_info(fw_info)