Skip to content

Commit

Permalink
fix pytorch test
Browse files Browse the repository at this point in the history
  • Loading branch information
ofirgo committed Jan 13, 2025
1 parent eb27f2c commit 9b0fbb9
Showing 1 changed file with 6 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@

import model_compression_toolkit as mct
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities
from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
from tests.common_tests.helpers.generate_test_tp_model import generate_test_tp_model
from tests.common_tests.helpers.generate_test_tpc import generate_test_tpc
from tests.pytorch_tests.model_tests.base_pytorch_test import BasePytorchTest

"""
Expand Down Expand Up @@ -57,18 +56,18 @@ def create_inputs_shape(self):

def get_tpc(self):
tpc = {
'no_quantization': generate_test_tp_model({
'no_quantization': generate_test_tpc({
'weights_n_bits': 32,
'activation_n_bits': 32,
'enable_weights_quantization': False,
'enable_activation_quantization': False
})
}
if self.num_heads < 5:
tpc['all_4bit'] = generate_test_tp_model({'weights_n_bits': 4,
'activation_n_bits': 4,
'enable_weights_quantization': True,
'enable_activation_quantization': True})
tpc['all_4bit'] = generate_test_tpc({'weights_n_bits': 4,
'activation_n_bits': 4,
'enable_weights_quantization': True,
'enable_activation_quantization': True})
return tpc


Expand Down

0 comments on commit 9b0fbb9

Please sign in to comment.