
Commit

fix some tests
ofirgo committed Jan 2, 2025
1 parent 683d95a commit 8b4503e
Showing 4 changed files with 12 additions and 3 deletions.
@@ -55,8 +55,12 @@ def attach(self, tpc_model: TargetPlatformModel,
             elif opset.name in self._opset2layer:
                 # Note that if the user provided a custom operator set with a name that exists in our
                 # pre-defined set of operator sets, we prioritize the user's custom opset definition
-                attr_mapping = self._opset2attr_mapping.get(opset.name)
-                OperationsSetToLayers(opset.name, self._opset2layer[opset.name], attr_mapping=attr_mapping)
+                layers = self._opset2layer[opset.name]
+                if len(layers) > 0:
+                    # If the framework does not define any matching operators to a given operator set name that
+                    # appears in the TPC, then we just skip it
+                    attr_mapping = self._opset2attr_mapping.get(opset.name)
+                    OperationsSetToLayers(opset.name, layers, attr_mapping=attr_mapping)
             else:
                 Logger.critical(f'{opset.name} is defined in TargetPlatformModel, '
                                 f'but is not defined in the framework set of operators or in the provided '
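For context, a minimal standalone sketch of the behavior this hunk introduces: an operator set that maps to an empty list of framework layers is now skipped instead of being attached with an empty mapping. The names below (opset_to_layers, attach_opset) are illustrative stand-ins, not the MCT internals.

# Illustrative sketch, not MCT source code.
opset_to_layers = {
    "Conv": ["Conv2D", "DepthwiseConv2D"],  # framework layers exist for this opset
    "SSDPostProcess": [],                   # no matching framework layers defined
}

def attach_opset(opset_name, layers):
    # Stand-in for OperationsSetToLayers(...)
    print(f"attaching {opset_name} -> {layers}")

for name, layers in opset_to_layers.items():
    if len(layers) > 0:
        attach_opset(name, layers)
    else:
        # After the fix: an opset with no mapped framework operators is simply skipped.
        print(f"skipping {name}: no framework operators mapped")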
@@ -192,6 +192,7 @@ def generate_tp_model(default_config: OpQuantizationConfig,
     operator_set.append(schema.OperatorsSet(name=schema.OperatorSetNames.OPSET_CAST.value, qc_options=no_quantization_config))
     operator_set.append(schema.OperatorsSet(name=schema.OperatorSetNames.OPSET_COMBINED_NON_MAX_SUPPRESSION.value, qc_options=no_quantization_config))
     operator_set.append(schema.OperatorsSet(name=schema.OperatorSetNames.OPSET_FAKE_QUANT.value, qc_options=no_quantization_config))
+    operator_set.append(schema.OperatorsSet(name=schema.OperatorSetNames.OPSET_SSD_POST_PROCESS.value, qc_options=no_quantization_config))
 
     # Define operator sets that use mixed_precision_configuration_options:
     conv = schema.OperatorsSet(name=schema.OperatorSetNames.OPSET_CONV.value, qc_options=mixed_precision_configuration_options)
1 change: 1 addition & 0 deletions tests/common_tests/helpers/tpcs_for_tests/v1/tp_model.py
@@ -193,6 +193,7 @@ def generate_tp_model(default_config: OpQuantizationConfig,
     operator_set.append(schema.OperatorsSet(name=schema.OperatorSetNames.OPSET_CAST.value, qc_options=no_quantization_config))
     operator_set.append(schema.OperatorsSet(name=schema.OperatorSetNames.OPSET_COMBINED_NON_MAX_SUPPRESSION.value, qc_options=no_quantization_config))
     operator_set.append(schema.OperatorsSet(name=schema.OperatorSetNames.OPSET_FAKE_QUANT.value, qc_options=no_quantization_config))
+    operator_set.append(schema.OperatorsSet(name=schema.OperatorSetNames.OPSET_SSD_POST_PROCESS.value, qc_options=no_quantization_config))
 
     # Define operator sets that use mixed_precision_configuration_options:
     conv = schema.OperatorsSet(name=schema.OperatorSetNames.OPSET_CONV.value, qc_options=mixed_precision_configuration_options)
@@ -20,6 +20,7 @@
 import model_compression_toolkit as mct
 from sony_custom_layers.keras.object_detection.ssd_post_process import SSDPostProcess
 from mct_quantizers.keras.metadata import MetadataLayer
+from tests.common_tests.helpers.tpcs_for_tests.v4.tp_model import get_tp_model
 
 keras = tf.keras
 layers = keras.layers
@@ -54,7 +55,9 @@ def test_custom_layer(self):
         q_model, _ = mct.ptq.keras_post_training_quantization(model,
                                                                get_rep_dataset(2, (1, 8, 8, 3)),
                                                                core_config=core_config,
-                                                               target_resource_utilization=mct.core.ResourceUtilization(weights_memory=6000))
+                                                               target_resource_utilization=mct.core.ResourceUtilization(weights_memory=6000),
+                                                               target_platform_capabilities=get_tp_model()
+                                                               )
 
         # verify the custom layer is in the quantized model
         last_model_layer_index = -2 if isinstance(q_model.layers[-1], MetadataLayer) else -1
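As a usage note, here is a minimal sketch of the pattern the test now follows: pinning the TPC explicitly via target_platform_capabilities instead of relying on the default. The toy model and the random representative-data generator below are illustrative assumptions standing in for the test's own model and get_rep_dataset helper.

# Illustrative sketch of passing an explicit TPC to Keras PTQ.
import numpy as np
import tensorflow as tf
import model_compression_toolkit as mct
from tests.common_tests.helpers.tpcs_for_tests.v4.tp_model import get_tp_model

def rep_dataset(n_iters=2, shape=(1, 8, 8, 3)):
    # Stand-in for the test's representative dataset helper.
    for _ in range(n_iters):
        yield [np.random.randn(*shape).astype(np.float32)]

# Toy model used only for illustration.
model = tf.keras.Sequential([tf.keras.layers.Conv2D(4, 3, input_shape=(8, 8, 3))])

q_model, quantization_info = mct.ptq.keras_post_training_quantization(
    model,
    rep_dataset,
    target_platform_capabilities=get_tp_model())  # pin the v4 test TPC explicitly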
