diff --git a/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/utils.py b/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/utils.py index 8b642afec1d..a2a643dd89c 100644 --- a/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/utils.py +++ b/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/utils.py @@ -118,7 +118,7 @@ def patch_attr(obj, attr_name, new_attr)-> _ContextManager: if attr_name in obj._parameters or attr_name in obj._buffers: # pylint: disable=protected-access return _patch_param_or_buffer(obj, attr_name, new_attr) - old_attr = getattr(obj, attr_name) + old_attr = getattr(obj, attr_name, None) action = lambda: setattr(obj, attr_name, new_attr) def cleanup(): @@ -127,7 +127,7 @@ def cleanup(): except AttributeError: pass - if not hasattr(obj, attr_name): + if not hasattr(obj, attr_name) and old_attr is not None: setattr(obj, attr_name, old_attr) return _ContextManager(action, cleanup)