diff --git a/TrainingExtensions/torch/src/python/aimet_torch/elementwise_ops.py b/TrainingExtensions/torch/src/python/aimet_torch/elementwise_ops.py index d5dc3f55b7b..16ddb8502a9 100644 --- a/TrainingExtensions/torch/src/python/aimet_torch/elementwise_ops.py +++ b/TrainingExtensions/torch/src/python/aimet_torch/elementwise_ops.py @@ -37,7 +37,7 @@ """ Custom modules for functional operations defined under torch and torch.nn.functional packages """ -from typing import Callable, Any, Tuple, Union +from typing import Callable, Any, Tuple, Union, List import itertools import torchvision import torch @@ -278,6 +278,24 @@ def forward(data: torch.Tensor, indices: torch.Tensor, axis: int = 0) -> torch.T return torch.index_select(data, axis, indices.flatten()).reshape(target_shape) +class DepthToSpaceCRDMode(torch.nn.Module): + """ Depthtospace op implementation in CRD mode """ + + def __init__(self, block_size: List): + super().__init__() + self.block_size_h = block_size[0] + self.block_size_w = block_size[1] + + def forward(self, x: torch.Tensor) -> Any: + """ + Forward-pass routine for DepthToSpace op in CRD mode + """ + b, c, h, w = x.shape + tmp = torch.reshape(x, (b, c // (self.block_size_h * self.block_size_w), self.block_size_h, self.block_size_w, h, w)) + tmp = torch.permute(tmp, (0, 1, 4, 2, 5, 3)) + out = torch.reshape(tmp, (b, c // (self.block_size_h * self.block_size_w), h * self.block_size_h, w * self.block_size_w)) + return out + class DepthToSpaceDCRMode(torch.nn.Module): """ Depthtospace op implementation in DCR mode """