diff --git a/model_compression_toolkit/core/pytorch/constants.py b/model_compression_toolkit/core/pytorch/constants.py index f024e0d49..7fade1758 100644 --- a/model_compression_toolkit/core/pytorch/constants.py +++ b/model_compression_toolkit/core/pytorch/constants.py @@ -66,6 +66,10 @@ IN_FEATURES = 'in_features' OUT_FEATURES = 'out_features' +# # Reserved layer names +RESERVED_NAME_TO = 'to' +RESERVED_NAME_SUFFIX = 'node' + # torch devices CUDA = 'cuda' CPU = 'cpu' diff --git a/model_compression_toolkit/core/pytorch/reader/graph_builders.py b/model_compression_toolkit/core/pytorch/reader/graph_builders.py index 78b3b3400..ceacf220f 100644 --- a/model_compression_toolkit/core/pytorch/reader/graph_builders.py +++ b/model_compression_toolkit/core/pytorch/reader/graph_builders.py @@ -25,7 +25,8 @@ from model_compression_toolkit.core.common.graph.edge import Edge from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode from model_compression_toolkit.core.pytorch.constants import OUTPUT, PLACEHOLDER, TENSOR_META, CALL_FUNCTION, TYPE, \ - CALL_METHOD, BIAS, FUNCTIONAL_OP, OP_CALL_KWARGS, OP_CALL_ARGS, INPUTS_AS_LIST, TENSOR_INPUT_ALLOCS, GET_ATTR + CALL_METHOD, BIAS, FUNCTIONAL_OP, OP_CALL_KWARGS, OP_CALL_ARGS, INPUTS_AS_LIST, TENSOR_INPUT_ALLOCS, GET_ATTR, \ + RESERVED_NAME_TO, RESERVED_NAME_SUFFIX from model_compression_toolkit.core.pytorch.reader.node_holders import DummyPlaceHolder from model_compression_toolkit.logger import Logger @@ -284,6 +285,9 @@ def nodes_builder(model: GraphModule, graph_node_type = BaseNode kwargs = {} + if node.name == RESERVED_NAME_TO: + node.name = node.name + RESERVED_NAME_SUFFIX + graph_node = graph_node_type(name=node.name, framework_attr=framework_attr, input_shape=input_shape,