Skip to content

Commit

Permalink
Fix minor bug in has_hooks (#3460)
Browse files Browse the repository at this point in the history
Signed-off-by: Kyunggeun Lee <quic_kyunggeu@quicinc.com>
  • Loading branch information
quic-kyunggeu authored Nov 4, 2024
1 parent 38bdee5 commit e457cde
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 9 deletions.
23 changes: 15 additions & 8 deletions TrainingExtensions/torch/src/python/aimet_torch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,12 @@
import torch.nn
import torch
from torch.utils.data import DataLoader, Dataset
from torch.nn.modules.module import (
_global_backward_pre_hooks,
_global_backward_hooks,
_global_forward_pre_hooks,
_global_forward_hooks,
)
from torchvision import datasets, transforms

from aimet_common.defs import QuantScheme, QuantizationDataType, MAP_QUANT_SCHEME_TO_PYMO
Expand Down Expand Up @@ -543,14 +549,15 @@ def get_input_shape_batch_size(data_loader):

def has_hooks(module: torch.nn.Module):
""" Returns True if the module uses hooks. """

for hooks in (module._forward_pre_hooks, # pylint: disable=protected-access
module._forward_hooks, module._backward_hooks): # pylint: disable=protected-access
if hooks:
logger.warning("The specified model has registered hooks which might break winnowing")
return True
return False

# pylint: disable=protected-access
return module._backward_hooks or\
module._backward_pre_hooks or\
module._forward_hooks or\
module._forward_pre_hooks or\
_global_backward_pre_hooks or\
_global_backward_hooks or\
_global_forward_hooks or\
_global_forward_pre_hooks

def get_one_positions_in_binary_mask(mask):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,9 @@ def __init__(self, model: torch.nn.Module, input_shape: Tuple,
"""

super().__init__(list_of_modules_to_winnow, reshape, in_place, verbose)
model.apply(has_hooks)

if any(has_hooks(module) for module in model.modules()):
logger.warning("The specified model has registered hooks which might break winnowing")

debug_level = logger.getEffectiveLevel()
logger.debug("Current log level: %s", debug_level)
Expand Down

0 comments on commit e457cde

Please sign in to comment.