diff --git a/TrainingExtensions/torch/src/python/aimet_torch/elementwise_ops.py b/TrainingExtensions/torch/src/python/aimet_torch/elementwise_ops.py index c11d5ea403a..1f73bcb6244 100644 --- a/TrainingExtensions/torch/src/python/aimet_torch/elementwise_ops.py +++ b/TrainingExtensions/torch/src/python/aimet_torch/elementwise_ops.py @@ -83,8 +83,10 @@ def create_wrapper_module(class_name: str, functional: Callable) -> Callable: Sqrt = create_wrapper_module('Sqrt', torch.sqrt) Maximum = create_wrapper_module('Maximum', torch.maximum) Max = create_wrapper_module('Max', torch.max) # NOTE: Not elementwise +AMax = create_wrapper_module('AMax', torch.amax) Minimum = create_wrapper_module('Minimum', torch.minimum) Min = create_wrapper_module('Min', torch.min) # NOTE: Not elementwise +AMin = create_wrapper_module('AMin', torch.amin) Where = create_wrapper_module('Where', torch.where) Greater = create_wrapper_module('Greater', torch.gt) Less = create_wrapper_module('Less', torch.lt)