Skip to content

Commit

Permalink
change split to unbind and cat to stack due to converter's supported …
Browse files Browse the repository at this point in the history
…dimensions
  • Loading branch information
itai-berman committed Jan 5, 2025
1 parent 40f038d commit 05f0021
Showing 1 changed file with 17 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def substitute(self,
# Stack and reshape all results - reshape if needed
# [(B, m, n)] * (D_1*...*D_N) --> (B, (D_1*...*D_N), m, n)
# (B, (D_1*...*D_N), m, n) --> (B, D_1, ..., D_N, m, n)
output_node = self._cat_matmul_outputs(
output_node = self._stack_matmul_outputs(
graph,
matmul_node,
split_matmul_nodes,
Expand Down Expand Up @@ -325,26 +325,26 @@ def _split_inputs(graph: Graph,
"""
input_split_node = FunctionalNode(
name=f'{matmul_node.name}_input_split',
framework_attr={DIM: 1},
framework_attr={},
input_shape=params.input_reshape_shape,
output_shape=params.input_split_shape,
weights={},
layer_class=torch.split,
layer_class=torch.unbind,
op_call_args=[1],
op_call_kwargs={DIM: 1},
functional_op=torch.split
op_call_kwargs={},
functional_op=torch.unbind
)

other_split_node = FunctionalNode(
name=f'{matmul_node.name}_other_split',
framework_attr={DIM: 1},
framework_attr={},
input_shape=params.other_reshape_shape,
output_shape=params.other_split_shape,
weights={},
layer_class=torch.split,
op_call_args=[1], # Should this be in kwargs or args
op_call_kwargs={DIM: 1},
functional_op=torch.split
layer_class=torch.unbind,
op_call_args=[1],
op_call_kwargs={},
functional_op=torch.unbind
)

if params.prev_input_node:
Expand Down Expand Up @@ -422,10 +422,10 @@ def _calc_single_matmul(graph: Graph,
return matmul_node

@staticmethod
def _cat_matmul_outputs(graph: Graph,
matmul_node: FunctionalNode,
split_matmul_nodes: List[FunctionalNode],
params: MatMulParams) -> FunctionalNode:
def _stack_matmul_outputs(graph: Graph,
matmul_node: FunctionalNode,
split_matmul_nodes: List[FunctionalNode],
params: MatMulParams) -> FunctionalNode:
"""
This method creates the node that concats all single matmuls together and then reshapes to the original output
shape.
Expand All @@ -441,15 +441,15 @@ def _cat_matmul_outputs(graph: Graph,
"""
# [(B, m, n)] * (D_1*...*D_N) --> (B, (D_1*...*D_N), m, n)
cat_node = FunctionalNode(
name=f'{matmul_node.name}_cat',
name=f'{matmul_node.name}_stack',
framework_attr={DIM: 1},
input_shape=[params.single_matmul_shape] * params.matmul_stack_shape[1],
output_shape=params.matmul_stack_shape,
weights={},
layer_class=torch.cat,
layer_class=torch.stack,
op_call_args=[],
op_call_kwargs={DIM: 1},
functional_op=torch.cat,
functional_op=torch.stack,
inputs_as_list=True
)
graph.add_node_with_in_edges(cat_node, split_matmul_nodes)
Expand Down

0 comments on commit 05f0021

Please sign in to comment.