From 738e3cce7bc63f0919edf3354d2d4299fefe0f91 Mon Sep 17 00:00:00 2001
From: elad-c
Date: Mon, 10 Jun 2024 17:53:09 +0300
Subject: [PATCH] Rename tensor_input_indices to tensor_input_allocs, since it
 no longer holds only integer indices but may also contain strings (kwarg
 names).

---
 .../core/common/graph/functional_node.py     |  6 ++---
 .../back2framework/pytorch_model_builder.py  | 22 +++++++++----------
 .../core/pytorch/constants.py                |  2 +-
 .../reshape_with_static_shapes.py            |  6 ++---
 .../core/pytorch/reader/graph_builders.py    | 10 ++++-----
 5 files changed, 23 insertions(+), 23 deletions(-)

diff --git a/model_compression_toolkit/core/common/graph/functional_node.py b/model_compression_toolkit/core/common/graph/functional_node.py
index ccc6e4bc2..9673b6488 100644
--- a/model_compression_toolkit/core/common/graph/functional_node.py
+++ b/model_compression_toolkit/core/common/graph/functional_node.py
@@ -25,7 +25,7 @@ def __init__(self,
                  functional_op: Any = None,
                  inputs_as_list: bool = False,
                  has_activation: bool = True,
-                 tensor_input_indices = None):
+                 tensor_input_allocs = None):
         """
         Init a FunctionalNode object.
 
@@ -44,7 +44,7 @@ def __init__(self,
             functional_op: The op the node implements.
             inputs_as_list: Whether to pass the node its input tensors as a list or not when calling the layer.
             has_activation: Whether the node has activations that we might want to quantize.
-            tensor_input_indices: A list of indices for activation tensors in the node's input tensor list
+            tensor_input_allocs: A list of allocations (positional indices or kwarg names) of activation tensors in the node's input tensor list and call kwargs
 
         """
 
@@ -63,7 +63,7 @@ def __init__(self,
         self.op_call_args = op_call_args
         self.functional_op = functional_op
         self.inputs_as_list = inputs_as_list
-        self.tensor_input_indices = [] if tensor_input_indices is None else tensor_input_indices
+        self.tensor_input_allocs = [] if tensor_input_allocs is None else tensor_input_allocs
 
     @property
     def type(self):
diff --git a/model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py b/model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py
index b3bccf743..76274ed6e 100644
--- a/model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py
+++ b/model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py
@@ -67,7 +67,7 @@ def _build_input_tensors_list(node: BaseNode,
 
 
 def _merge_inputs(_node: BaseNode, input_tensors: List, op_call_args: List, op_call_kwargs: Dict,
-                  tensor_input_indices: List = None) -> List:
+                  tensor_input_allocs: List = None) -> List:
     """
     Merge input tensors list with positional weights and op_call_args, according to correct order.
 
@@ -76,18 +76,18 @@ def _merge_inputs(_node: BaseNode, input_tensors: List, op_call_args: List, op_c
         input_tensors: activation input tensors to node.
         op_call_args: framework node call args.
         op_call_kwargs: framework node call kwargs.
-        tensor_input_indices: List of input indices to node.
+        tensor_input_allocs: List of input allocations (positional indices or kwarg names) to node.
 
     Returns:
         Combined list of input_tensors and op_call_args.
""" - if isinstance(_node, FunctionalNode) and _node.tensor_input_indices: + if isinstance(_node, FunctionalNode) and _node.tensor_input_allocs: _input_list = op_call_args.copy() - if tensor_input_indices is None: - tensor_input_indices = _node.tensor_input_indices - assert len(tensor_input_indices) == len(input_tensors), \ - f'Mismatch between input tensors ({len(tensor_input_indices)}) and indices {len(input_tensors)}' - for i, t in zip(tensor_input_indices, input_tensors): + if tensor_input_allocs is None: + tensor_input_allocs = _node.tensor_input_allocs + assert len(tensor_input_allocs) == len(input_tensors), \ + f'Mismatch between input tensors ({len(tensor_input_allocs)}) and indices {len(input_tensors)}' + for i, t in zip(tensor_input_allocs, input_tensors): if isinstance(i, str): if i in op_call_kwargs: a=1 @@ -133,15 +133,15 @@ def _run_operation(n: BaseNode, # list separately, because in FX the tensors are FX objects and fail to_torch_tensor input_tensors = [to_torch_tensor(t, numpy_type=t.dtype) if isinstance(t, np.ndarray) else t for t in input_tensors] - _tensor_input_indices = None + _tensor_input_allocs = None else: - _tensor_input_indices = [i for i in n.tensor_input_indices if i not in n.weights] + _tensor_input_allocs = [i for i in n.tensor_input_allocs if i not in n.weights] if isinstance(n, FunctionalNode) and n.inputs_as_list: out_tensors_of_n_float = op_func(input_tensors, *op_call_args, **functional_kwargs) else: merged_inputs, functional_kwargs = _merge_inputs(n, input_tensors, op_call_args, functional_kwargs.copy(), - tensor_input_indices=_tensor_input_indices) + tensor_input_allocs=_tensor_input_allocs) out_tensors_of_n_float = op_func(*merged_inputs, **functional_kwargs) # Add a fake quant node if the node has an activation threshold. diff --git a/model_compression_toolkit/core/pytorch/constants.py b/model_compression_toolkit/core/pytorch/constants.py index 99c33e852..1f7f74f43 100644 --- a/model_compression_toolkit/core/pytorch/constants.py +++ b/model_compression_toolkit/core/pytorch/constants.py @@ -40,7 +40,7 @@ OP_CALL_ARGS = 'op_call_args' OP_CALL_KWARGS = 'op_call_kwargs' INPUTS_AS_LIST = 'inputs_as_list' -TENSOR_INPUT_INDICES = 'tensor_input_indices' +TENSOR_INPUT_ALLOCS = 'tensor_input_allocs' INPLACE = 'inplace' HARDTANH_MIN_VAL = 'min_val' HARDTANH_MAX_VAL = 'max_val' diff --git a/model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py b/model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py index 709526e52..f03f6b52f 100644 --- a/model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +++ b/model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py @@ -65,10 +65,10 @@ def substitute(self, # When a "reshape" is called with multiple arguments (e.g. x.reshape(-1, channels, height, width) # this substitution converts it x.reshape((-1, channels, height, width)), so need to update the - # tensor_input_indices attribute. - # scalar argument's shape is [1] so remove those indices from tensor_input_indices + # tensor_input_allocs attribute. 
+        # scalar argument's shape is [1] so remove those indices from tensor_input_allocs
         # node.input_shape example: [[1, 32, 4, 32], [1], [1], [1]]
-        node.tensor_input_indices = node.tensor_input_indices[:sum([i != [1] for i in node.input_shape])]
+        node.tensor_input_allocs = node.tensor_input_allocs[:sum([i != [1] for i in node.input_shape])]
         # modify the node input info
         node.input_shape = [node.input_shape[0]]
 
diff --git a/model_compression_toolkit/core/pytorch/reader/graph_builders.py b/model_compression_toolkit/core/pytorch/reader/graph_builders.py
index 9b6fdfbcd..bf38a47a2 100644
--- a/model_compression_toolkit/core/pytorch/reader/graph_builders.py
+++ b/model_compression_toolkit/core/pytorch/reader/graph_builders.py
@@ -23,7 +23,7 @@
 from model_compression_toolkit.core.common.graph.edge import Edge
 from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
 from model_compression_toolkit.core.pytorch.constants import OUTPUT, PLACEHOLDER, TENSOR_META, CALL_FUNCTION, TYPE, \
-    CALL_METHOD, BIAS, FUNCTIONAL_OP, OP_CALL_KWARGS, OP_CALL_ARGS, INPUTS_AS_LIST, TENSOR_INPUT_INDICES, GET_ATTR
+    CALL_METHOD, BIAS, FUNCTIONAL_OP, OP_CALL_KWARGS, OP_CALL_ARGS, INPUTS_AS_LIST, TENSOR_INPUT_ALLOCS, GET_ATTR
 from model_compression_toolkit.core.pytorch.reader.node_holders import DummyPlaceHolder
 from model_compression_toolkit.logger import Logger
 
@@ -180,7 +180,7 @@ def nodes_builder(model: GraphModule,
                 [isinstance(n, torch.fx.node.Node) for n in node.args[0]])
             inputs_as_list = inputs_as_list1 or (len(node.args) > 0 and isinstance(node.args[0], Node) and
                                                  node.args[0].op == PLACEHOLDER and node.args[0].meta[TYPE] in (list, tuple))
-            tensor_input_index = []
+            tensor_input_alloc = []
             op_call_args = list(node.args)
             if inputs_as_list:
                 op_call_args.pop(0)
@@ -188,10 +188,10 @@ def nodes_builder(model: GraphModule,
             for in_node in node.all_input_nodes:
                 for i, arg in enumerate(node.args):
                     if arg == in_node:
-                        tensor_input_index.append(i)
+                        tensor_input_alloc.append(i)
                 for k, arg in framework_attr_nodes.items():
                     if arg == in_node:
-                        tensor_input_index.append(k)
+                        tensor_input_alloc.append(k)
 
             # remove torch.fx.node.Node from inputs to graph_node_type
             op_call_args = [arg for arg in op_call_args if not isinstance(arg, Node)]
@@ -203,7 +203,7 @@
                       OP_CALL_ARGS: op_call_args,
                       OP_CALL_KWARGS: node_kwargs,
                       INPUTS_AS_LIST: inputs_as_list,
-                      TENSOR_INPUT_INDICES: tensor_input_index}
+                      TENSOR_INPUT_ALLOCS: tensor_input_alloc}
         else:
             graph_node_type = BaseNode
             kwargs = {}
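
Illustrative appendix (not part of the patch): the sketch below shows the idea behind tensor_input_allocs after this rename. Each entry is either an integer position or a kwarg name that tells the model builder where to place an activation tensor when the framework op is called. The function merge_inputs_sketch, its signature, and the torch.add example values are hypothetical simplifications for illustration; they do not reproduce MCT's actual _merge_inputs implementation.

# Hypothetical, simplified sketch -- not MCT code.
from typing import Any, Dict, List, Tuple, Union


def merge_inputs_sketch(input_tensors: List[Any],
                        op_call_args: List[Any],
                        op_call_kwargs: Dict[str, Any],
                        tensor_input_allocs: List[Union[int, str]]) -> Tuple[List[Any], Dict[str, Any]]:
    """Route each activation tensor to its call-site slot: an int entry is a
    positional index, a str entry is a keyword-argument name."""
    assert len(tensor_input_allocs) == len(input_tensors), \
        f'Mismatch between tensor_input_allocs ({len(tensor_input_allocs)}) ' \
        f'and input tensors ({len(input_tensors)})'
    merged_args = list(op_call_args)
    merged_kwargs = dict(op_call_kwargs)
    for alloc, tensor in zip(tensor_input_allocs, input_tensors):
        if isinstance(alloc, str):        # the new string case: keyword input
            merged_kwargs[alloc] = tensor
        else:                             # the original integer case: positional input
            merged_args.insert(alloc, tensor)
    return merged_args, merged_kwargs


# Example: a call like torch.add(x, other=y, alpha=2) traced by FX could yield
# tensor_input_allocs == [0, 'other'] for its two activation tensors.
args, kwargs = merge_inputs_sketch(['x_tensor', 'y_tensor'], [], {'alpha': 2}, [0, 'other'])
assert args == ['x_tensor'] and kwargs == {'other': 'y_tensor', 'alpha': 2}

Strings enter tensor_input_allocs when an FX input node feeds a keyword argument (the framework_attr_nodes branch in nodes_builder above), which is why the old name tensor_input_indices no longer fit.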