Skip to content

Commit

Permalink
Refactor Target Platform Capabilities Design (#1276)
Browse files Browse the repository at this point in the history
* Refactor Target Platform Capabilities Design

- Create a new `schema` package to house all target platform modeling classes
- Introduce a new versioning system with minor and patch versions

Additional Changes:
- Created mct_current_schema.py as a single place to update the schema version, and updated all imports in MCT to use this location.
- Update existing target platform models to adhere to the new versioning convention
- Add necessary metadata
- Correct all import statements
- Update and enhance tests to reflect the design changes
- Remove unused files

---------

Co-authored-by: liord <lior.dikstein@altair-semi.com>
  • Loading branch information
lior-dikstein and liord authored Dec 2, 2024
1 parent 0c6d0b0 commit 28b461b
Show file tree
Hide file tree
Showing 75 changed files with 1,672 additions and 1,392 deletions.
3 changes: 0 additions & 3 deletions model_compression_toolkit/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@
TENSORFLOW = 'tensorflow'
PYTORCH = 'pytorch'

# Metadata fields
MCT_VERSION = 'mct_version'
TPC_VERSION = 'tpc_version'

WEIGHTS_SIGNED = True
# Minimal threshold to use for quantization ranges:
Expand Down
5 changes: 3 additions & 2 deletions model_compression_toolkit/core/common/graph/base_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@
ACTIVATION_N_BITS_ATTRIBUTE, FP32_BYTES_PER_PARAMETER
from model_compression_toolkit.core.common.quantization.node_quantization_config import WeightsAttrQuantizationConfig
from model_compression_toolkit.logger import Logger
from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationConfigOptions, \
TargetPlatformCapabilities, LayerFilterParams, OpQuantizationConfig
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import QuantizationConfigOptions, \
OpQuantizationConfig
from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities, LayerFilterParams


class BaseNode:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
from model_compression_toolkit.core.common.graph.edge import EDGE_SINK_INDEX
from model_compression_toolkit.core.graph_prep_runner import graph_preparation_runner
from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities, \
QuantizationConfigOptions
from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import QuantizationConfigOptions


def compute_resource_utilization_data(in_model: Any,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
from model_compression_toolkit.core import QuantizationConfig
from model_compression_toolkit.core.common.quantization.node_quantization_config import BaseNodeQuantizationConfig, \
NodeWeightsQuantizationConfig, NodeActivationQuantizationConfig
from model_compression_toolkit.target_platform_capabilities.target_platform import OpQuantizationConfig, \
AttributeQuantizationConfig
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import AttributeQuantizationConfig, \
OpQuantizationConfig
from model_compression_toolkit.logger import Logger


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@

from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig, \
QuantizationErrorMethod
from model_compression_toolkit.target_platform_capabilities.target_platform import OpQuantizationConfig, \
AttributeQuantizationConfig
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import AttributeQuantizationConfig, \
OpQuantizationConfig


##########################################
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
import numpy as np
from typing import Dict, Union

from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod, Signedness
from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import Signedness
from model_compression_toolkit.core.common.collectors.statistics_collector import BaseStatsCollector
from model_compression_toolkit.core.common.quantization import quantization_params_generation
from model_compression_toolkit.core.common.node_prior_info import NodePriorInfo
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from model_compression_toolkit.core.common.quantization.quantization_fn_selection import \
get_weights_quantization_fn
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
from model_compression_toolkit.target_platform_capabilities.target_platform.op_quantization_config import OpQuantizationConfig, \
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OpQuantizationConfig, \
QuantizationConfigOptions


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from model_compression_toolkit.core.common import BaseNode, Graph
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
from model_compression_toolkit.core.common.quantization.node_quantization_config import WeightsAttrQuantizationConfig
from model_compression_toolkit.target_platform_capabilities.target_platform import AttributeQuantizationConfig
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import AttributeQuantizationConfig


def apply_activation_bias_correction_to_graph(graph: Graph,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from model_compression_toolkit.core.common import Graph, BaseNode
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
from model_compression_toolkit.core.common.quantization.node_quantization_config import WeightsAttrQuantizationConfig
from model_compression_toolkit.target_platform_capabilities.target_platform import AttributeQuantizationConfig
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import AttributeQuantizationConfig


def apply_bias_correction_to_graph(graph_to_apply_bias_correction: Graph,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@
from model_compression_toolkit.core.common.graph.base_graph import Graph
from model_compression_toolkit.core.common.graph.base_node import BaseNode
from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod, \
AttributeQuantizationConfig
from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import AttributeQuantizationConfig


class BatchNormalizationReconstruction(common.BaseSubstitution):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
from model_compression_toolkit.core.common import FrameworkInfo, Graph, BaseNode
from model_compression_toolkit.constants import THRESHOLD, SIGNED, SHIFT_NEGATIVE_NON_LINEAR_NUM_BITS
from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod, \
AttributeQuantizationConfig
from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import AttributeQuantizationConfig
from model_compression_toolkit.core.common.quantization.set_node_quantization_config import create_node_activation_qc, \
set_quantization_configs_to_node
from model_compression_toolkit.core.common.quantization.core_config import CoreConfig
Expand Down
19 changes: 14 additions & 5 deletions model_compression_toolkit/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from dataclasses import dataclass, asdict

from typing import Dict, Any
from model_compression_toolkit.constants import MCT_VERSION, TPC_VERSION, OPERATORS_SCHEDULING, FUSED_NODES_MAPPING, \
CUTS, MAX_CUT, OP_ORDER, OP_RECORD, SHAPE, NODE_OUTPUT_INDEX, NODE_NAME, TOTAL_SIZE, MEM_ELEMENTS
from model_compression_toolkit.constants import OPERATORS_SCHEDULING, FUSED_NODES_MAPPING, CUTS, MAX_CUT, OP_ORDER, \
OP_RECORD, SHAPE, NODE_OUTPUT_INDEX, NODE_NAME, TOTAL_SIZE, MEM_ELEMENTS
from model_compression_toolkit.core.common.graph.memory_graph.compute_graph_max_cut import SchedulerInfo
from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities

Expand Down Expand Up @@ -43,13 +44,21 @@ def create_model_metadata(tpc: TargetPlatformCapabilities,
def get_versions_dict(tpc) -> Dict:
    """
    Builds the versions metadata dictionary for a given TPC.

    Args:
        tpc: A TargetPlatformCapabilities object; its `tp_model` is expected to
            carry `tpc_minor_version`, `tpc_patch_version`, `tpc_platform_type`
            and `SCHEMA_VERSION` fields.

    Returns:
        A dictionary with TPC, MCT and TPC-Schema versions.
    """
    # imported inside to avoid circular import error
    from model_compression_toolkit import __version__ as mct_version

    # A local dataclass gives us named fields plus asdict() for free; every
    # version field is stringified so the metadata stays serialization-safe.
    @dataclass
    class TPCVersions:
        mct_version: str
        tpc_minor_version: str = f'{tpc.tp_model.tpc_minor_version}'
        tpc_patch_version: str = f'{tpc.tp_model.tpc_patch_version}'
        tpc_platform_type: str = f'{tpc.tp_model.tpc_platform_type}'
        tpc_schema: str = f'{tpc.tp_model.SCHEMA_VERSION}'

    return asdict(TPCVersions(mct_version))


def get_scheduler_metadata(scheduler_info: SchedulerInfo) -> Dict[str, Any]:
Expand Down
2 changes: 1 addition & 1 deletion model_compression_toolkit/qat/keras/quantizer/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
Several training methods may be applied by the user to train the QAT ready model
created by `keras_quantization_aware_training_init` method in [`keras/quantization_facade`](../quantization_facade.py).
Each `TrainingMethod` (an enum defined in the [`qat_config`](../../common/qat_config.py))
and [`QuantizationMethod`](../../../target_platform_capabilities/target_platform/op_quantization_config.py)
and `QuantizationMethod`
selects a quantizer for weights and a quantizer for activations.

Currently, only the STE (straight through estimator) training method is implemented by the MCT.
Expand Down
2 changes: 1 addition & 1 deletion model_compression_toolkit/qat/pytorch/quantizer/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
Several training methods may be applied by the user to train the QAT ready model
created by `pytorch_quantization_aware_training_init` method in [`pytorch/quantization_facade`](../quantization_facade.py).
Each [`TrainingMethod`](../../../trainable_infrastructure/common/training_method.py)
and [`QuantizationMethod`](../../../target_platform_capabilities/target_platform/op_quantization_config.py)
and `QuantizationMethod`
selects a quantizer for weights and a quantizer for activations.

## Make your own training method
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
"""Single place that pins the current target-platform schema version.

All MCT code imports the target-platform modeling classes from this module,
so bumping the schema version only requires changing the import below.
"""
import model_compression_toolkit.target_platform_capabilities.schema.v1 as schema

Signedness = schema.Signedness
AttributeQuantizationConfig = schema.AttributeQuantizationConfig
OpQuantizationConfig = schema.OpQuantizationConfig
QuantizationConfigOptions = schema.QuantizationConfigOptions
OperatorsSetBase = schema.OperatorsSetBase
OperatorsSet = schema.OperatorsSet
OperatorSetConcat = schema.OperatorSetConcat
Fusing = schema.Fusing
TargetPlatformModel = schema.TargetPlatformModel
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import copy
from typing import Any, Dict


def clone_and_edit_object_params(obj: Any, **kwargs: Any) -> Any:
    """
    Clones the given object and edits some of its parameters.

    Args:
        obj: An object to clone.
        **kwargs: Keyword arguments to edit in the cloned object. Each key must
            name an existing attribute of the object.

    Returns:
        Edited deep copy of the given object; the original object is untouched.

    Raises:
        ValueError: If a keyword argument does not name an existing attribute
            of the object.
    """
    obj_copy = copy.deepcopy(obj)
    for k, v in kwargs.items():
        # Explicit check instead of assert: validation must survive `python -O`,
        # which strips assert statements.
        if not hasattr(obj_copy, k):
            raise ValueError(f'Edit parameter is possible only for existing parameters in the given object, '
                             f'but {k} is not a parameter of {obj_copy}.')
        setattr(obj_copy, k, v)
    return obj_copy
Loading

0 comments on commit 28b461b

Please sign in to comment.