Skip to content

Commit

Permalink
Add quantized definition of custom numerical operations (#3481)
Browse files Browse the repository at this point in the history
* Adding definitions for custom elementwise numerical modules

Signed-off-by: Priyanka Dangi <quic_pdangi@quicinc.com>
  • Loading branch information
quic-pdangi authored Nov 14, 2024
1 parent 9ad3254 commit 1c41530
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -195,24 +195,24 @@ def cat(tensors, dim=0, *, out=None):
class QuantizedNorm(_DispatchMixin, QuantizationMixin, Norm):
    """Quantized definition of the custom :class:`Norm` elementwise module.

    Combines ``_DispatchMixin`` and ``QuantizationMixin`` with the float
    ``Norm`` module; the corresponding builtin torch function is
    ``torch.norm``.
    """
    # Builtin torch function backing this module (presumably used by
    # _DispatchMixin to intercept/dispatch the quantized call — confirm
    # against the mixin implementation, which is outside this view).
    _builtin_torch_fn = torch.norm
#
#
# @QuantizationMixin.implements(Exponential)
# class QuantizedExponential(_DispatchMixin, QuantizationMixin, Exponential):
# """ Quantized Exponential """
# _builtin_torch_fn = torch.exp
#
#
# @QuantizationMixin.implements(Erf)
# class QuantizedErf(_DispatchMixin, QuantizationMixin, Erf):
# """ Quantized Erf """
# _builtin_torch_fn = torch.erf
#
#
# @QuantizationMixin.implements(Sqrt)
# class QuantizedSqrt(_DispatchMixin, QuantizationMixin, Sqrt):
# """ Quantized Sqrt """
# _builtin_torch_fn = torch.sqrt


@QuantizationMixin.implements(Exponential)
class QuantizedExponential(_DispatchMixin, QuantizationMixin, Exponential):
    """Quantized definition of the custom :class:`Exponential` module.

    Registered as the quantized counterpart of ``Exponential`` via
    ``QuantizationMixin.implements``; the corresponding builtin torch
    function is ``torch.exp``.
    """
    # Builtin torch function backing this module (presumably used by
    # _DispatchMixin for dispatch — confirm against the mixin).
    _builtin_torch_fn = torch.exp


@QuantizationMixin.implements(Erf)
class QuantizedErf(_DispatchMixin, QuantizationMixin, Erf):
    """Quantized definition of the custom :class:`Erf` module.

    Registered as the quantized counterpart of ``Erf`` via
    ``QuantizationMixin.implements``; the corresponding builtin torch
    function is ``torch.erf``.
    """
    # Builtin torch function backing this module (presumably used by
    # _DispatchMixin for dispatch — confirm against the mixin).
    _builtin_torch_fn = torch.erf


@QuantizationMixin.implements(Sqrt)
class QuantizedSqrt(_DispatchMixin, QuantizationMixin, Sqrt):
    """Quantized definition of the custom :class:`Sqrt` module.

    Registered as the quantized counterpart of ``Sqrt`` via
    ``QuantizationMixin.implements``; the corresponding builtin torch
    function is ``torch.sqrt``.  Note the companion test feeds it
    non-negative input (``randn(100).abs()``), consistent with
    ``torch.sqrt`` returning NaN for negative values.
    """
    # Builtin torch function backing this module (presumably used by
    # _DispatchMixin for dispatch — confirm against the mixin).
    _builtin_torch_fn = torch.sqrt
#
#
# @QuantizationMixin.implements(Maximum)
Expand Down Expand Up @@ -328,22 +328,22 @@ class QuantizedCumSum(_DispatchMixin, QuantizationMixin, CumSum):
# _builtin_torch_fn = torch.prod
#
#
# @QuantizationMixin.implements(Log)
# class QuantizedLog(_DispatchMixin, QuantizationMixin, Log):
# """ Quantized Log """
# _builtin_torch_fn = torch.log
#
#
# @QuantizationMixin.implements(Abs)
# class QuantizedAbs(_DispatchMixin, QuantizationMixin, Abs):
# """ Quantized Abs """
# _builtin_torch_fn = torch.abs
#
#
# @QuantizationMixin.implements(Neg)
# class QuantizedNeg(_DispatchMixin, QuantizationMixin, Neg):
# """ Quantized Neg """
# _builtin_torch_fn = torch.neg
@QuantizationMixin.implements(Log)
class QuantizedLog(_DispatchMixin, QuantizationMixin, Log):
    """Quantized definition of the custom :class:`Log` module.

    Registered as the quantized counterpart of ``Log`` via
    ``QuantizationMixin.implements``; the corresponding builtin torch
    function is ``torch.log``.  Note the companion test feeds it strictly
    positive input (``randint(1, 1000, ...)``), consistent with the log
    domain.
    """
    # Builtin torch function backing this module (presumably used by
    # _DispatchMixin for dispatch — confirm against the mixin).
    _builtin_torch_fn = torch.log


@QuantizationMixin.implements(Abs)
class QuantizedAbs(_DispatchMixin, QuantizationMixin, Abs):
    """Quantized definition of the custom :class:`Abs` module.

    Registered as the quantized counterpart of ``Abs`` via
    ``QuantizationMixin.implements``; the corresponding builtin torch
    function is ``torch.abs``.
    """
    # Builtin torch function backing this module (presumably used by
    # _DispatchMixin for dispatch — confirm against the mixin).
    _builtin_torch_fn = torch.abs


@QuantizationMixin.implements(Neg)
class QuantizedNeg(_DispatchMixin, QuantizationMixin, Neg):
    """Quantized definition of the custom :class:`Neg` module.

    Registered as the quantized counterpart of ``Neg`` via
    ``QuantizationMixin.implements``; the corresponding builtin torch
    function is ``torch.neg``.
    """
    # Builtin torch function backing this module (presumably used by
    # _DispatchMixin for dispatch — confirm against the mixin).
    _builtin_torch_fn = torch.neg
#
#
# @QuantizationMixin.implements(Argmin)
Expand Down Expand Up @@ -464,8 +464,7 @@ class QuantizedCumSum(_DispatchMixin, QuantizationMixin, CumSum):
# class QuantizedTile(_DispatchMixin, QuantizationMixin, Tile):
# """ Quantized Tile """
# _builtin_torch_fn = torch.tile
#
#

# @QuantizationMixin.implements(ElementwiseUnarySign)
# class QuantizedElementwiseUnarySign(_DispatchMixin, QuantizationMixin, ElementwiseUnarySign):
# """ Quantized ElementwiseUnarySign """
Expand Down
16 changes: 8 additions & 8 deletions TrainingExtensions/torch/test/python/v2/nn/test_true_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -917,10 +917,10 @@ def _create_quantized_module(module):
(lambda: custom.Divide(), lambda: (randn(100), randn(100))),
(lambda: custom.Concat(), lambda: (randn(1, 100), randn(3, 100))),
# (lambda: custom.FloorDivide(), lambda: ...),
(lambda: custom.Norm(), lambda: randn(100)),
# (lambda: custom.Exponential(), lambda: ...),
# (lambda: custom.Erf(), lambda: ...),
# (lambda: custom.Sqrt(), lambda: ...),
(lambda: custom.Norm(), lambda: randn(100)),
(lambda: custom.Exponential(), lambda: randn(100)),
(lambda: custom.Erf(), lambda: randn(100)),
(lambda: custom.Sqrt(), lambda: randn(100).abs()),
# (lambda: custom.Maximum(), lambda: ...),
# (lambda: custom.Max(), lambda: ...),
# (lambda: custom.AMax(), lambda: ...),
Expand All @@ -940,9 +940,9 @@ def _create_quantized_module(module):
# (lambda: custom.Mean(), lambda: ...),
# (lambda: custom.Sum(), lambda: ...),
# (lambda: custom.Prod(), lambda: ...),
# (lambda: custom.Log(), lambda: ...),
# (lambda: custom.Abs(), lambda: ...),
# (lambda: custom.Neg(), lambda: ...),
(lambda: custom.Log(), lambda: randint(1, 1000, (10,10))),
(lambda: custom.Abs(), lambda: randn(100)),
(lambda: custom.Neg(), lambda: randn(100)),
# (lambda: custom.Argmin(), lambda: ...),
# (lambda: custom.Argmax(), lambda: ...),
# (lambda: custom.ElementwiseCeil(), lambda: ...),
Expand Down Expand Up @@ -1026,7 +1026,7 @@ def test_default_kernels(module_factory, input_factory):


with qmodule.compute_encodings():
torch.manual_seed(0);
torch.manual_seed(0)
_ = qmodule(*inputs)

torch.manual_seed(0)
Expand Down

0 comments on commit 1c41530

Please sign in to comment.