Skip to content

Commit

Permalink
Add code placeholder for custom quantized modules (#3425)
Browse files Browse the repository at this point in the history
Signed-off-by: Kyunggeun Lee <quic_kyunggeu@quicinc.com>
  • Loading branch information
quic-kyunggeu authored Oct 24, 2024
1 parent 67e8601 commit c51b284
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -566,3 +566,103 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
# class QuantizedGridSample(_DispatchMixin, QuantizationMixin, GridSample):
# """ Quantized GridSample """
# _builtin_torch_fn = torch.nn.functional.grid_sample
#
#
# @QuantizationMixin.implements(DynamicConv2d)
# class QuantizedDynamicConv2d(QuantizationMixin, DynamicConv2d):
# """ Quantized DynamicConv2d """
#
#
# @QuantizationMixin.implements(Pow)
# class QuantizedPow(QuantizationMixin, Pow):
# """ Quantized Pow """
#
#
# @QuantizationMixin.implements(CustomSiLU)
# class QuantizedCustomSiLU(QuantizationMixin, CustomSiLU):
# """ Quantized CustomSiLU """
#
#
# @QuantizationMixin.implements(StridedSlice)
# class QuantizedStridedSlice(QuantizationMixin, StridedSlice):
# """ Quantized StridedSlice """
#
#
# @QuantizationMixin.implements(ChannelShuffle)
# class QuantizedChannelShuffle(QuantizationMixin, ChannelShuffle):
# """ Quantized ChannelShuffle """
#
#
# @QuantizationMixin.implements(Cast)
# class QuantizedCast(QuantizationMixin, Cast):
# """ Quantized Cast """
#
#
# @QuantizationMixin.implements(CustomGather)
# class QuantizedCustomGather(QuantizationMixin, CustomGather):
# """ Quantized CustomGather """
#
#
# @QuantizationMixin.implements(DepthToSpaceCRDMode)
# class QuantizedDepthToSpaceCRDMode(QuantizationMixin, DepthToSpaceCRDMode):
# """ Quantized DepthToSpaceCRDMode """
#
#
# @QuantizationMixin.implements(DepthToSpaceDCRMode)
# class QuantizedDepthToSpaceDCRMode(QuantizationMixin, DepthToSpaceDCRMode):
# """ Quantized DepthToSpaceDCRMode """
#
#
# @QuantizationMixin.implements(CustomSparseConv3DLayer)
# class QuantizedCustomSparseConv3DLayer(QuantizationMixin, CustomSparseConv3DLayer):
# """ Quantized CustomSparseConv3DLayer """
#
#
# @QuantizationMixin.implements(SparseTensorWrapper)
# class QuantizedSparseTensorWrapper(QuantizationMixin, SparseTensorWrapper):
# """ Quantized SparseTensorWrapper """
#
#
# @QuantizationMixin.implements(ScatterDense)
# class QuantizedScatterDense(QuantizationMixin, ScatterDense):
# """ Quantized ScatterDense """
#
#
# @QuantizationMixin.implements(ScatterND)
# class QuantizedScatterND(QuantizationMixin, ScatterND):
# """ Quantized ScatterND """
#
#
# @QuantizationMixin.implements(RoiAlign)
# class QuantizedRoiAlign(QuantizationMixin, RoiAlign):
# """ Quantized RoiAlign """
#
#
# @QuantizationMixin.implements(NonMaxSuppression)
# class QuantizedNonMaxSuppression(QuantizationMixin, NonMaxSuppression):
# """ Quantized NonMaxSuppression """
#
#
# @QuantizationMixin.implements(GatherNd)
# class QuantizedGatherNd(QuantizationMixin, GatherNd):
# """ Quantized GatherNd """
#
#
# @QuantizationMixin.implements(ScatterElements)
# class QuantizedScatterElements(QuantizationMixin, ScatterElements):
# """ Quantized ScatterElements """
#
#
# @QuantizationMixin.implements(OneHot)
# class QuantizedOneHot(QuantizationMixin, OneHot):
# """ Quantized OneHot """
#
#
# @QuantizationMixin.implements(Expand)
# class QuantizedExpand(QuantizationMixin, Expand):
# """ Quantized Expand """
#
#
# @QuantizationMixin.implements(DynamicLinear)
# class QuantizedDynamicLinear(QuantizationMixin, DynamicLinear):
# """ Quantized DynamicLinear """
22 changes: 21 additions & 1 deletion TrainingExtensions/torch/test/python/v2/nn/test_true_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -954,7 +954,27 @@ def _create_quantized_module(module):
# (lambda: custom.Normalize(), lambda: ...),
# (lambda: custom.Pad(), lambda: ...),
# (lambda: custom.GridSample(), lambda: ...),
(lambda: custom.RmsNorm([5, 2, 3], [2], 1e-5), lambda: (randn(5, 2, 3)))
(lambda: custom.RmsNorm([5, 2, 3], [2], 1e-5), lambda: (randn(5, 2, 3))),
# (lambda: custom.DynamicConv2d(), lambda: ...),
# (lambda: custom.Pow(), lambda: ...),
# (lambda: custom.CustomSiLU(), lambda: ...),
# (lambda: custom.StridedSlice(), lambda: ...),
# (lambda: custom.ChannelShuffle(), lambda: ...),
# (lambda: custom.Cast(), lambda: ...),
# (lambda: custom.CustomGather(), lambda: ...),
# (lambda: custom.DepthToSpaceCRDMode(), lambda: ...),
# (lambda: custom.DepthToSpaceDCRMode(), lambda: ...),
# (lambda: custom.CustomSparseConv3DLayer(), lambda: ...),
# (lambda: custom.SparseTensorWrapper(), lambda: ...),
# (lambda: custom.ScatterDense(), lambda: ...),
# (lambda: custom.ScatterND(), lambda: ...),
# (lambda: custom.RoiAlign(), lambda: ...),
# (lambda: custom.NonMaxSuppression(), lambda: ...),
# (lambda: custom.GatherNd(), lambda: ...),
# (lambda: custom.ScatterElements(), lambda: ...),
# (lambda: custom.OneHot(), lambda: ...),
# (lambda: custom.Expand(), lambda: ...),
# (lambda: custom.DynamicLinear(), lambda: ...),
]))
def test_default_kernels(module_factory, input_factory):
module = module_factory()
Expand Down

0 comments on commit c51b284

Please sign in to comment.