Implement default quantization backend (#2576)
Signed-off-by: Seokjun An <quic_seokan@quicinc.com>
quic-seokan authored Nov 27, 2023
1 parent 34951b1 commit 3b9589d
Showing 2 changed files with 402 additions and 142 deletions.
@@ -2,7 +2,7 @@
# =============================================================================
# @@-COPYRIGHT-START-@@
#
-# Copyright (c) 2023-2023, Qualcomm Innovation Center, Inc. All rights reserved.
+# Copyright (c) 2023, Qualcomm Innovation Center, Inc. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
@@ -34,3 +34,159 @@
#
# @@-COPYRIGHT-END-@@
# =============================================================================
""" Default quantization backend for quantizing weights and activations """
from typing import Union
import torch

def _is_expandable(a: torch.Tensor, b: torch.Tensor) -> bool:
    """
    Returns true if tensor a is expandable to shape of tensor b
    """
    if len(a.shape) > len(b.shape):
        return False
    for dim_a, dim_b in zip(a.shape[::-1], b.shape[::-1]):
        if dim_a not in (1, dim_b):
            return False
    return True
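
# --- Illustrative note (a sketch, not part of the committed file) ---------------
# _is_expandable() mirrors right-aligned broadcasting: each trailing dimension of
# the scale must be 1 or match the tensor. Hypothetical per-channel shapes:
#
#   _is_expandable(torch.ones(3, 1), torch.ones(2, 3, 4))   # True  -> broadcastable
#   _is_expandable(torch.ones(5),    torch.ones(2, 3, 4))   # False -> 5 != 4 and 5 != 1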

def _validate_arguments(tensor: torch.Tensor, scale: torch.Tensor, offset: torch.Tensor, bitwidth: Union[torch.Tensor, int] = None):
    if not tensor.dtype == scale.dtype == offset.dtype:
        raise RuntimeError("Data type of tensor, scale, and offset should be the same")
    if bitwidth and torch.finfo(tensor.dtype).bits <= bitwidth:
        raise RuntimeError(f"Dtype {tensor.dtype} has insufficient bitwidth to perform {bitwidth} quantization")
    if not _is_expandable(scale, tensor):
        raise RuntimeError(f"Scale of shape {scale.shape} cannot be expanded like input tensor of shape {tensor.shape}")

def quantize(tensor: torch.Tensor, scale: torch.Tensor, offset: torch.Tensor, bitwidth: Union[torch.Tensor, int]) -> torch.Tensor:
    """
    Performs differentiable quantization on tensor using scale, offset, and bitwidth parameters.

    :param tensor: Tensor to quantize
    :param scale: Scale factor for quantization
    :param offset: Offset value for quantization
    :param bitwidth: Output bitwidth of quantized tensor
    :return: Resulting tensor
    """
    _validate_arguments(tensor, scale, offset, bitwidth)
    return QuantizeFunc.apply(tensor, scale, offset, bitwidth)
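
# --- Illustrative usage (a sketch, not part of the committed file) --------------
# A minimal example of calling quantize() for 8-bit asymmetric quantization. The
# scale/offset values below are made up; in practice they would come from a
# calibrated quantizer.
#
#   x = torch.randn(4, 4)
#   scale = torch.tensor(0.1)
#   offset = torch.tensor(-128.)                 # i.e. zero point 128
#   x_int = quantize(x, scale, offset, bitwidth=8)
#   # x_int = clamp(round(x / 0.1) + 128, 0, 255), still stored in floating point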

def quantize_dequantize(tensor: torch.Tensor, scale: torch.Tensor, offset: torch.Tensor, bitwidth: Union[torch.Tensor, int]) -> torch.Tensor:
    """
    Performs differentiable quantize-dequantize operation on tensor using scale, offset, and bitwidth parameters.

    :param tensor: Tensor to quantize-dequantize
    :param scale: Scale factor for quantization
    :param offset: Offset value for quantization
    :param bitwidth: Simulated quantization bitwidth
    :return: Resulting tensor
    """
    _validate_arguments(tensor, scale, offset, bitwidth)
    return QuantDequantFunc.apply(tensor, scale, offset, bitwidth)
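
# --- Illustrative usage (a sketch, not part of the committed file) --------------
# Fake quantization for QAT: quantize_dequantize() keeps the tensor in floating
# point but snaps values to the integer grid, and gradients flow back through the
# straight-through estimator defined in QuantDequantFunc below. Shapes and values
# here are hypothetical; a (16, 1) scale broadcasts per output channel.
#
#   w = torch.randn(16, 8, requires_grad=True)
#   scale = torch.full((16, 1), 0.05)
#   offset = torch.full((16, 1), -128.)
#   w_qdq = quantize_dequantize(w, scale, offset, bitwidth=8)
#   w_qdq.sum().backward()                       # w.grad is zero where w was clamped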

def dequantize(tensor: torch.Tensor, scale: torch.Tensor, offset: torch.Tensor) -> torch.Tensor:
    """
    Performs differentiable dequantize operation on tensor using scale and offset parameters.

    :param tensor: Tensor to dequantize
    :param scale: Scale factor for quantization
    :param offset: Offset value for quantization
    :return: Resulting tensor
    """
    _validate_arguments(tensor, scale, offset)
    return DequantizeFunc.apply(tensor, scale, offset)
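
# --- Illustrative note (a sketch, not part of the committed file) ---------------
# dequantize() maps grid values back to real values, so composing the standalone
# ops reproduces quantize_dequantize() for the same scale, offset, and bitwidth:
#
#   x_qdq = quantize_dequantize(x, scale, offset, bitwidth=8)
#   x_rt  = dequantize(quantize(x, scale, offset, bitwidth=8), scale, offset)
#   assert torch.allclose(x_qdq, x_rt)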


# pylint: disable=abstract-method
class QuantizeFunc(torch.autograd.Function):
    """
    Custom gradient function for quantization
    """
    # pylint: disable=arguments-differ
    @staticmethod
    def forward(ctx, tensor: torch.Tensor, scale: torch.Tensor, offset: torch.Tensor, bitwidth: Union[torch.Tensor, int]):
        x_round = torch.round(tensor / scale) - offset
        if tensor.requires_grad or scale.requires_grad or offset.requires_grad:
            mask = (x_round >= 0) * (x_round <= (2 ** bitwidth - 1))
        else:
            mask = None
        ctx.tensor_requires_grad = tensor.requires_grad
        ctx.scale_requires_grad = scale.requires_grad
        ctx.offset_requires_grad = offset.requires_grad
        ctx.save_for_backward(tensor, scale, mask)
        return torch.clamp(x_round, 0, 2 ** bitwidth - 1)

    # pylint: disable=arguments-differ
    @staticmethod
    def backward(ctx, grad):
        tensor, scale, mask = ctx.saved_tensors
        if ctx.tensor_requires_grad or ctx.scale_requires_grad or ctx.offset_requires_grad:
            masked_grad = grad * mask
        tensor_grad = masked_grad / scale if ctx.tensor_requires_grad else None
        scale_grad = -masked_grad * tensor / scale / scale if ctx.scale_requires_grad else None
        offset_grad = -masked_grad if ctx.offset_requires_grad else None
        return tensor_grad, scale_grad, offset_grad, None
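
# --- Explanatory note (added, not part of the committed file) -------------------
# QuantizeFunc uses a straight-through estimator: round() is treated as identity,
# and the mask zeroes gradients for values clamped out of range. Within range,
# q = x / s - o (ignoring rounding), so the backward pass above implements:
#
#   dq/dx = 1 / s          -> tensor_grad = masked_grad / scale
#   dq/ds = -x / s**2      -> scale_grad  = -masked_grad * tensor / scale / scale
#   dq/do = -1             -> offset_grad = -masked_grad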


# pylint: disable=abstract-method
class DequantizeFunc(torch.autograd.Function):
    """
    Custom gradient function for dequantization
    """
    # pylint: disable=arguments-differ
    @staticmethod
    def forward(ctx, tensor: torch.Tensor, scale: torch.Tensor, offset: torch.Tensor):
        x_dequant = (tensor + offset) * scale
        ctx.tensor_requires_grad = tensor.requires_grad
        ctx.scale_requires_grad = scale.requires_grad
        ctx.offset_requires_grad = offset.requires_grad
        ctx.save_for_backward(tensor, scale, offset)
        return x_dequant

    # pylint: disable=arguments-differ
    @staticmethod
    def backward(ctx, grad):
        tensor, scale, offset = ctx.saved_tensors
        if ctx.tensor_requires_grad or ctx.offset_requires_grad:
            tensor_and_offset_grad = grad * scale
        tensor_grad = tensor_and_offset_grad if ctx.tensor_requires_grad else None
        scale_grad = grad * (tensor + offset) if ctx.scale_requires_grad else None
        offset_grad = tensor_and_offset_grad if ctx.offset_requires_grad else None
        return tensor_grad, scale_grad, offset_grad
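
# --- Explanatory note (added, not part of the committed file) -------------------
# DequantizeFunc is a plain affine map y = (q + o) * s, so no mask is needed:
#
#   dy/dq = s              -> tensor_grad = grad * scale
#   dy/ds = q + o          -> scale_grad  = grad * (tensor + offset)
#   dy/do = s              -> offset_grad = grad * scale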


# pylint: disable=abstract-method
class QuantDequantFunc(torch.autograd.Function):
    """
    Custom gradient function for quant-dequant
    """
    # pylint: disable=arguments-differ
    @staticmethod
    def forward(ctx, tensor: torch.Tensor, scale: torch.Tensor, offset: torch.Tensor, bitwidth: Union[torch.Tensor, int]):
        x_round = torch.round(tensor / scale) - offset
        x_quant = torch.clamp(x_round, 0, 2 ** bitwidth - 1)
        if tensor.requires_grad or scale.requires_grad or offset.requires_grad:
            mask = (x_round >= 0) * (x_round <= (2 ** bitwidth - 1))
        else:
            mask = None
        x_dequant = (x_quant + offset) * scale

        # Downcast x_quant if bitwidth is less than or equal to 8 to reduce memory consumption
        if bitwidth <= 8:
            x_quant = x_quant.to(dtype=torch.uint8)

        ctx.tensor_requires_grad = tensor.requires_grad
        ctx.scale_requires_grad = scale.requires_grad
        ctx.offset_requires_grad = offset.requires_grad
        ctx.save_for_backward(tensor, scale, offset, mask, x_quant)
        return x_dequant

    # pylint: disable=arguments-differ
    @staticmethod
    def backward(ctx, grad):
        tensor, scale, offset, mask, x_quant = ctx.saved_tensors
        tensor_grad = grad * mask if ctx.tensor_requires_grad else None
        scale_grad = grad * (x_quant + offset - mask * tensor / scale) \
            if ctx.scale_requires_grad else None
        offset_grad = -grad * (mask * scale - scale) if ctx.offset_requires_grad else None
        return tensor_grad, scale_grad, offset_grad, None
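
# --- Explanatory note (added, not part of the committed file) -------------------
# QuantDequantFunc fuses both ops: y = (clamp(round(x/s) - o, 0, 2**b - 1) + o) * s.
# With the straight-through estimator and mask m marking in-range values:
#
#   dy/dx = m                        -> tensor_grad = grad * mask
#   dy/ds = x_quant + o - m * x / s  -> scale_grad  = grad * (x_quant + offset - mask * tensor / scale)
#   dy/do = s * (1 - m)              -> offset_grad = -grad * (mask * scale - scale)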