From 8b78ad56fc44c51560002186bddb1b3438b6646b Mon Sep 17 00:00:00 2001 From: Rishabh Thakur Date: Tue, 31 Oct 2023 10:12:07 +0530 Subject: [PATCH] Fix missing input encoding issue for custom strided slice Signed-off-by: Rishabh Thakur --- .../src/python/aimet_torch/meta/connectedgraph.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/TrainingExtensions/torch/src/python/aimet_torch/meta/connectedgraph.py b/TrainingExtensions/torch/src/python/aimet_torch/meta/connectedgraph.py index 6d947dc9684..5f84fcf82a6 100644 --- a/TrainingExtensions/torch/src/python/aimet_torch/meta/connectedgraph.py +++ b/TrainingExtensions/torch/src/python/aimet_torch/meta/connectedgraph.py @@ -82,6 +82,11 @@ MULTI_INPUT_OPS_TO_PARSE = [elementwise_ops.Add, elementwise_ops.Multiply, elementwise_ops.Subtract, elementwise_ops.Divide, elementwise_ops.Pow] +# We want to consider following operations as leaf nodes while creating op for connected graph. +SKIP_LIST_FOR_SUBGRAPH_TRACE = (elementwise_ops.StridedSlice, elementwise_ops.GatherNd, elementwise_ops.ScatterND, + elementwise_ops.CustomGather, elementwise_ops.DepthToSpaceDCRMode, + elementwise_ops.RoiAlign, elementwise_ops.ChannelShuffle) + # pylint: disable=too-many-lines # pylint: disable=protected-access class OpWithMultipleOutputs(Op): @@ -522,7 +527,10 @@ def _parse_callmethod_node(self, node: torch._C.Node, # 1st input is a reference on which the call method is being invoked. input_name: str = inputs[0].debugName() outputs = [output for output in node.outputs()] - if input_name in node_name_to_subgraph_model: + + # We don't want to further trace some custom implementation from elementwise_ops + if input_name in node_name_to_subgraph_model and \ + not isinstance(node_name_to_subgraph_model[input_name][0], SKIP_LIST_FOR_SUBGRAPH_TRACE): elementwise_info = None subgraph_model, getattr_node_info = node_name_to_subgraph_model[input_name] # For elementwise ops, we need to parse the callmethod interior, but want to retain information about the