diff --git a/tests/pytorch_tests/model_tests/feature_models/multi_head_attention_test.py b/tests/pytorch_tests/model_tests/feature_models/multi_head_attention_test.py index 545f36b19..ce98b4d9f 100644 --- a/tests/pytorch_tests/model_tests/feature_models/multi_head_attention_test.py +++ b/tests/pytorch_tests/model_tests/feature_models/multi_head_attention_test.py @@ -20,8 +20,7 @@ import model_compression_toolkit as mct from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities -from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO -from tests.common_tests.helpers.generate_test_tp_model import generate_test_tp_model +from tests.common_tests.helpers.generate_test_tpc import generate_test_tpc from tests.pytorch_tests.model_tests.base_pytorch_test import BasePytorchTest """ @@ -57,7 +56,7 @@ def create_inputs_shape(self): def get_tpc(self): tpc = { - 'no_quantization': generate_test_tp_model({ + 'no_quantization': generate_test_tpc({ 'weights_n_bits': 32, 'activation_n_bits': 32, 'enable_weights_quantization': False, @@ -65,10 +64,10 @@ def get_tpc(self): }) } if self.num_heads < 5: - tpc['all_4bit'] = generate_test_tp_model({'weights_n_bits': 4, - 'activation_n_bits': 4, - 'enable_weights_quantization': True, - 'enable_activation_quantization': True}) + tpc['all_4bit'] = generate_test_tpc({'weights_n_bits': 4, + 'activation_n_bits': 4, + 'enable_weights_quantization': True, + 'enable_activation_quantization': True}) return tpc