rename tensor_input_indices to tensor_input_allocs, as it is no longer a list of integers only but may also contain strings.
elad-c committed Jun 10, 2024
1 parent d6519d4 commit 738e3cc
Showing 5 changed files with 23 additions and 23 deletions.
@@ -25,7 +25,7 @@ def __init__(self,
functional_op: Any = None,
inputs_as_list: bool = False,
has_activation: bool = True,
- tensor_input_indices = None):
+ tensor_input_allocs = None):
"""
Init a FunctionalNode object.
@@ -44,7 +44,7 @@ def __init__(self,
functional_op: The op the node implements.
inputs_as_list: Whether to pass the node its input tensors as a list or not when calling the layer.
has_activation: Whether the node has activations that we might want to quantize.
- tensor_input_indices: A list of indices for activation tensors in the node's input tensor list
+ tensor_input_allocs: A list of indices and keyword-argument names for activation tensors in the node's input tensor list
"""

@@ -63,7 +63,7 @@ def __init__(self,
self.op_call_args = op_call_args
self.functional_op = functional_op
self.inputs_as_list = inputs_as_list
- self.tensor_input_indices = [] if tensor_input_indices is None else tensor_input_indices
+ self.tensor_input_allocs = [] if tensor_input_allocs is None else tensor_input_allocs

@property
def type(self):
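For orientation, a simplified stand-in (not the actual FunctionalNode signature, which takes more required arguments) showing what the renamed attribute can now hold:

from typing import List, Union

class FunctionalNodeSketch:
    # Simplified stand-in for FunctionalNode, showing only the renamed attribute.
    def __init__(self, tensor_input_allocs: List[Union[int, str]] = None):
        # Default to an empty list so downstream code can iterate without None checks.
        self.tensor_input_allocs = [] if tensor_input_allocs is None else tensor_input_allocs

# Allocations may now mix positional indices (ints) and keyword-argument names (strings).
node = FunctionalNodeSketch(tensor_input_allocs=[0, 2, 'other'])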
@@ -67,7 +67,7 @@ def _build_input_tensors_list(node: BaseNode,


def _merge_inputs(_node: BaseNode, input_tensors: List, op_call_args: List, op_call_kwargs: Dict,
- tensor_input_indices: List = None) -> List:
+ tensor_input_allocs: List = None) -> List:
"""
Merge input tensors list with positional weights and op_call_args, according to correct order.
@@ -76,18 +76,18 @@ def _merge_inputs(_node: BaseNode, input_tensors: List, op_call_args: List, op_c
input_tensors: activation input tensors to node.
op_call_args: framework node call args.
op_call_kwargs: framework node call kwargs.
- tensor_input_indices: List of input indices to node.
+ tensor_input_allocs: List of input allocations (positional indices or keyword-argument names) to node.
Returns:
Combined list of input_tensors and op_call_args.
"""
- if isinstance(_node, FunctionalNode) and _node.tensor_input_indices:
+ if isinstance(_node, FunctionalNode) and _node.tensor_input_allocs:
_input_list = op_call_args.copy()
- if tensor_input_indices is None:
- tensor_input_indices = _node.tensor_input_indices
- assert len(tensor_input_indices) == len(input_tensors), \
- f'Mismatch between input tensors ({len(tensor_input_indices)}) and indices {len(input_tensors)}'
- for i, t in zip(tensor_input_indices, input_tensors):
+ if tensor_input_allocs is None:
+ tensor_input_allocs = _node.tensor_input_allocs
+ assert len(tensor_input_allocs) == len(input_tensors), \
+ f'Mismatch between input tensors ({len(input_tensors)}) and tensor_input_allocs ({len(tensor_input_allocs)})'
+ for i, t in zip(tensor_input_allocs, input_tensors):
if isinstance(i, str):
if i in op_call_kwargs:
a=1
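The rest of the loop body is collapsed in this hunk; a minimal sketch of the merging idea, assuming string allocations name keyword arguments and integer allocations are positional insertion points (the actual implementation may differ):

from typing import Dict, List, Union

def merge_inputs_sketch(input_tensors: List, op_call_args: List, op_call_kwargs: Dict,
                        tensor_input_allocs: List[Union[int, str]]) -> List:
    # Illustrative only: place each activation tensor either at its positional index
    # or under its keyword name, according to tensor_input_allocs.
    merged = list(op_call_args)
    for alloc, tensor in zip(tensor_input_allocs, input_tensors):
        if isinstance(alloc, str):
            op_call_kwargs[alloc] = tensor   # tensor is passed as a keyword argument
        else:
            merged.insert(alloc, tensor)     # tensor is inserted at its original position
    return merged

kwargs = {}
merged = merge_inputs_sketch(['t0', 't1'], [3], kwargs, [0, 'other'])
# merged == ['t0', 3]; kwargs == {'other': 't1'}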
@@ -133,15 +133,15 @@ def _run_operation(n: BaseNode,
# list separately, because in FX the tensors are FX objects and fail to_torch_tensor
input_tensors = [to_torch_tensor(t, numpy_type=t.dtype) if isinstance(t, np.ndarray) else t
for t in input_tensors]
- _tensor_input_indices = None
+ _tensor_input_allocs = None
else:
- _tensor_input_indices = [i for i in n.tensor_input_indices if i not in n.weights]
+ _tensor_input_allocs = [i for i in n.tensor_input_allocs if i not in n.weights]

if isinstance(n, FunctionalNode) and n.inputs_as_list:
out_tensors_of_n_float = op_func(input_tensors, *op_call_args, **functional_kwargs)
else:
merged_inputs, functional_kwargs = _merge_inputs(n, input_tensors, op_call_args, functional_kwargs.copy(),
- tensor_input_indices=_tensor_input_indices)
+ tensor_input_allocs=_tensor_input_allocs)
out_tensors_of_n_float = op_func(*merged_inputs, **functional_kwargs)

# Add a fake quant node if the node has an activation threshold.
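To illustrate the filter above with hypothetical values: allocations pointing at positional weights already stored on the node are dropped, so only activation allocations reach _merge_inputs:

import numpy as np

# Hypothetical node state: position 0 holds an activation input, 'mask' arrives as a kwarg,
# and position 1 is a constant weight already attached to the node.
tensor_input_allocs = [0, 1, 'mask']
weights = {1: np.zeros((3, 3))}

_tensor_input_allocs = [i for i in tensor_input_allocs if i not in weights]
# -> [0, 'mask']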
model_compression_toolkit/core/pytorch/constants.py (1 addition, 1 deletion)
@@ -40,7 +40,7 @@
OP_CALL_ARGS = 'op_call_args'
OP_CALL_KWARGS = 'op_call_kwargs'
INPUTS_AS_LIST = 'inputs_as_list'
- TENSOR_INPUT_INDICES = 'tensor_input_indices'
+ TENSOR_INPUT_ALLOCS = 'tensor_input_allocs'
INPLACE = 'inplace'
HARDTANH_MIN_VAL = 'min_val'
HARDTANH_MAX_VAL = 'max_val'
@@ -65,10 +65,10 @@ def substitute(self,

# When a "reshape" is called with multiple arguments (e.g. x.reshape(-1, channels, height, width)),
# this substitution converts it to x.reshape((-1, channels, height, width)), so we need to update the
- # tensor_input_indices attribute.
- # scalar argument's shape is [1] so remove those indices from tensor_input_indices
+ # tensor_input_allocs attribute.
+ # scalar argument's shape is [1] so remove those indices from tensor_input_allocs
# node.input_shape example: [[1, 32, 4, 32], [1], [1], [1]]
- node.tensor_input_indices = node.tensor_input_indices[:sum([i != [1] for i in node.input_shape])]
+ node.tensor_input_allocs = node.tensor_input_allocs[:sum([i != [1] for i in node.input_shape])]

# modify the node input info
node.input_shape = [node.input_shape[0]]
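A worked example of the trimming above, using the input_shape from the comment and a hypothetical allocation list:

# input_shape example from the comment: one tensor input followed by three scalar arguments.
input_shape = [[1, 32, 4, 32], [1], [1], [1]]
tensor_input_allocs = [0, 1, 2, 3]   # hypothetical allocations before the substitution

# Scalar arguments have shape [1]; keep only the allocations that belong to real tensors.
n_tensor_inputs = sum([i != [1] for i in input_shape])       # -> 1
tensor_input_allocs = tensor_input_allocs[:n_tensor_inputs]  # -> [0]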
model_compression_toolkit/core/pytorch/reader/graph_builders.py (5 additions, 5 deletions)
@@ -23,7 +23,7 @@
from model_compression_toolkit.core.common.graph.edge import Edge
from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
from model_compression_toolkit.core.pytorch.constants import OUTPUT, PLACEHOLDER, TENSOR_META, CALL_FUNCTION, TYPE, \
- CALL_METHOD, BIAS, FUNCTIONAL_OP, OP_CALL_KWARGS, OP_CALL_ARGS, INPUTS_AS_LIST, TENSOR_INPUT_INDICES, GET_ATTR
+ CALL_METHOD, BIAS, FUNCTIONAL_OP, OP_CALL_KWARGS, OP_CALL_ARGS, INPUTS_AS_LIST, TENSOR_INPUT_ALLOCS, GET_ATTR
from model_compression_toolkit.core.pytorch.reader.node_holders import DummyPlaceHolder
from model_compression_toolkit.logger import Logger

@@ -180,18 +180,18 @@ def nodes_builder(model: GraphModule,
[isinstance(n, torch.fx.node.Node) for n in node.args[0]])
inputs_as_list = inputs_as_list1 or (len(node.args) > 0 and isinstance(node.args[0], Node) and
node.args[0].op == PLACEHOLDER and node.args[0].meta[TYPE] in (list, tuple))
- tensor_input_index = []
+ tensor_input_alloc = []
op_call_args = list(node.args)
if inputs_as_list:
op_call_args.pop(0)
else:
for in_node in node.all_input_nodes:
for i, arg in enumerate(node.args):
if arg == in_node:
- tensor_input_index.append(i)
+ tensor_input_alloc.append(i)
for k, arg in framework_attr_nodes.items():
if arg == in_node:
- tensor_input_index.append(k)
+ tensor_input_alloc.append(k)

# remove torch.fx.node.Node from inputs to graph_node_type
op_call_args = [arg for arg in op_call_args if not isinstance(arg, Node)]
Expand All @@ -203,7 +203,7 @@ def nodes_builder(model: GraphModule,
OP_CALL_ARGS: op_call_args,
OP_CALL_KWARGS: node_kwargs,
INPUTS_AS_LIST: inputs_as_list,
- TENSOR_INPUT_INDICES: tensor_input_index}
+ TENSOR_INPUT_ALLOCS: tensor_input_alloc}
else:
graph_node_type = BaseNode
kwargs = {}
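A condensed sketch of how tensor_input_alloc is assembled above, using hypothetical stand-ins for the FX node's args, keyword arguments, and input producers: a producer matched among the positional args contributes its index, and one matched among the keyword args contributes its key:

# Hypothetical stand-ins for an FX call node's inputs.
args = ('t0_node', 3)                          # a tensor-producing node followed by a scalar
framework_attr_nodes = {'other': 't1_node'}    # a tensor-producing node passed by keyword
all_input_nodes = ['t0_node', 't1_node']

tensor_input_alloc = []
for in_node in all_input_nodes:
    for i, arg in enumerate(args):
        if arg == in_node:
            tensor_input_alloc.append(i)       # record the positional index
    for k, arg in framework_attr_nodes.items():
        if arg == in_node:
            tensor_input_alloc.append(k)       # record the keyword name
# -> [0, 'other']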
