Skip to content

Commit

Permalink
Keep AIMET custom modules as a single op in connected graph
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Hsieh <quic_klhsieh@quicinc.com>
  • Loading branch information
quic-klhsieh authored Oct 11, 2024
1 parent 484c650 commit 1294ee6
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@

import copy
from collections import defaultdict
import inspect
from typing import Tuple, Union, List, Dict, Type, Optional
import torch

Expand Down Expand Up @@ -137,6 +138,7 @@ class ConnectedGraph(AimetCommonConnectedGraph):
either module or functional) as producers and consumers of tensors.
Note that the graph has two kinds of nodes: operations and products."""

# pylint: disable=too-many-instance-attributes
def __init__(self, model: torch.nn.Module, model_input: Union[torch.Tensor, Tuple]):
"""
Init function for connected graph.
Expand All @@ -157,7 +159,12 @@ def __init__(self, model: torch.nn.Module, model_input: Union[torch.Tensor, Tupl

self._generate_module_lookup_table(model)
with in_eval_mode(model), torch.no_grad():
self._aimet_defined_modules = \
tuple(classtype for _, classtype in inspect.getmembers(aimet_modules,
lambda m: inspect.isclass(m) and issubclass(m,
torch.nn.Module)))
self._construct_graph(model, model_input)
del self._aimet_defined_modules

# List of ops in the order they are traversed using the forward function
self.ordered_ops = self._get_ordered_ops()
Expand Down Expand Up @@ -1242,7 +1249,7 @@ def _is_recursive_parsing_needed(self, module: torch.nn.Module,
recursive_parsing_needed = True
if is_torch_nn_leaf_module(module) or \
is_custom_leaf_module(module, self._find_aten_nodes_in_forward_pass(trace)) or \
isinstance(module, tuple(aimet_torch.utils.modules_to_treat_as_leaf)):
isinstance(module, (self._aimet_defined_modules, tuple(aimet_torch.utils.modules_to_treat_as_leaf))):
recursive_parsing_needed = False

return recursive_parsing_needed
Expand Down
18 changes: 18 additions & 0 deletions TrainingExtensions/torch/test/python/test_connectedgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1015,3 +1015,21 @@ def __init__(self):
assert len(add_1.inputs) == 2
assert add_1.inputs == [p3, p4]
assert p5.name not in mcg._products.keys()

def test_custom_ops_to_treat_as_leaf_module(self):
class RMSNormModel(torch.nn.Module):
def __init__(self):
super(RMSNormModel, self).__init__()
self.rms_norm_0 = aimet_modules.RmsNorm([5, 2, 3], [2], 1e-5)

def forward(self, inp):
x = self.rms_norm_0(inp)
return x

model = RMSNormModel()
dummy_input = torch.randn(5, 2, 3)

cg = ConnectedGraph(model, dummy_input)
assert len(cg.get_all_ops()) == 1
assert cg.ordered_ops[0].inputs[0].is_model_input
assert cg.ordered_ops[0].get_module() == model.rms_norm_0

0 comments on commit 1294ee6

Please sign in to comment.