Skip to content

Commit

Permalink
Fixed missing input encoding issue for custom strided slice
Browse files Browse the repository at this point in the history
Signed-off-by: Rishabh Thakur <quic_ristha@quicinc.com>
  • Loading branch information
quic-ristha committed Oct 31, 2023
1 parent a346c04 commit c85e0c0
Showing 1 changed file with 9 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,11 @@
MULTI_INPUT_OPS_TO_PARSE = [elementwise_ops.Add, elementwise_ops.Multiply, elementwise_ops.Subtract,
elementwise_ops.Divide, elementwise_ops.Pow]

# We want to consider following operations as leaf nodes while creating op for connected graph.
SKIP_LIST_FOR_SUBGRAPH_TRACE = (elementwise_ops.StridedSlice, elementwise_ops.GatherNd, elementwise_ops.ScatterND,
elementwise_ops.CustomGather, elementwise_ops.DepthToSpaceDCRMode,
elementwise_ops.RoiAlign, elementwise_ops.ChannelShuffle)

# pylint: disable=too-many-lines
# pylint: disable=protected-access
class OpWithMultipleOutputs(Op):
Expand Down Expand Up @@ -522,7 +527,10 @@ def _parse_callmethod_node(self, node: torch._C.Node,
# 1st input is a reference on which the call method is being invoked.
input_name: str = inputs[0].debugName()
outputs = [output for output in node.outputs()]
if input_name in node_name_to_subgraph_model:

# We don't want to further trace some custom implementation from elementwise_ops
if input_name in node_name_to_subgraph_model and \
not isinstance(node_name_to_subgraph_model[input_name][0], SKIP_LIST_FOR_SUBGRAPH_TRACE):
elementwise_info = None
subgraph_model, getattr_node_info = node_name_to_subgraph_model[input_name]
# For elementwise ops, we need to parse the callmethod interior, but want to retain information about the
Expand Down

0 comments on commit c85e0c0

Please sign in to comment.