Implement float quantizer
Signed-off-by: Kyunggeun Lee <quic_kyunggeu@quicinc.com>
quic-kyunggeu authored Jan 31, 2024
1 parent d92b7e5 commit 1554d42
Showing 9 changed files with 378 additions and 110 deletions.
@@ -43,7 +43,7 @@

import torch.nn as nn

from aimet_torch.experimental.v2.utils import patch_param
from aimet_torch.experimental.v2.utils import patch_attr


class BaseQuantizationMixin(abc.ABC):
@@ -91,7 +91,7 @@ def _patch_quantized_parameters(self):
if param_quantizer:
orig_param = getattr(self, param_name)
quantized_param = param_quantizer(orig_param)
ctx = patch_param(self, param_name, quantized_param)
ctx = patch_attr(self, param_name, quantized_param)
stack.enter_context(ctx)
yield

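For readers unfamiliar with the patching mechanism, here is a minimal sketch of what the ``patch_attr`` call above does at runtime. The ``nn.Linear`` module and the hand-rolled rounding that stands in for a parameter quantizer are illustrative assumptions, not part of this commit.

import torch
import torch.nn as nn
from aimet_torch.experimental.v2.utils import patch_attr

linear = nn.Linear(4, 4)
# Stand-in for param_quantizer(orig_param): crude rounding to a coarse grid.
quantized_weight = torch.round(linear.weight * 16) / 16

with patch_attr(linear, 'weight', quantized_weight):
    # Inside the context, getattr(linear, 'weight') resolves to the quantized tensor,
    # so a forward pass here would run with the quantized weight.
    assert torch.equal(linear.weight, quantized_weight)
# On exit the original nn.Parameter becomes visible again.
assert isinstance(linear.weight, nn.Parameter)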
@@ -38,3 +38,4 @@

from .base import *
from .affine import *
from .float import *
@@ -45,7 +45,7 @@
import torch
from torch import nn

from aimet_torch.experimental.v2.utils import patch_attr, patch_param, _is_expandable, StatisticsNotFoundError
from aimet_torch.experimental.v2.utils import patch_attr, _is_expandable, StatisticsNotFoundError
from aimet_torch.experimental.v2.quantization.encoding_analyzer import EncodingAnalyzer, MinMaxEncodingAnalyzer
from aimet_torch.experimental.v2.quantization.quantizers.base import QuantizerBase
from aimet_torch.experimental.v2.quantization.backends import get_backend
@@ -180,8 +180,8 @@ def forward_wrapper(input):
self.encoding_analyzer.compute_encodings_from_stats(batch_statistics,
self.bitwidth,
self.symmetric)
with patch_param(self, 'min', dynamic_min),\
patch_param(self, 'max', dynamic_max):
with patch_attr(self, 'min', dynamic_min),\
patch_attr(self, 'max', dynamic_max):
return original_forward(input)

try:
@@ -34,15 +34,178 @@
#
# @@-COPYRIGHT-END-@@
# =============================================================================
# pylint: disable=redefined-builtin
""" Float quantizers """

import contextlib
import functools
from typing import Optional, List, Dict

import torch
from aimet_torch.experimental.v2.quantization.encoding_analyzer import EncodingAnalyzer
from aimet_torch.experimental.v2.quantization.quantizers.base import QuantizerBase
from aimet_torch.experimental.v2.utils import StatisticsNotFoundError, patch_attr
from aimet_torch.fp_quantization import fake_cast_to_ieee_float


__all__ = ['FloatQuantizeDequantize']


def _ieee_float_max_representable_value(exponent_bits, mantissa_bits):
exponent_max = 2 ** exponent_bits - 1
exponent_bias = exponent_max // 2
return (2 - 2**-mantissa_bits) * 2 ** (exponent_max - exponent_bias - 1)


_IEEE_FLOAT16_EXPONENT_BITS = 5
_IEEE_FLOAT16_MANTISSA_BITS = 10
assert _ieee_float_max_representable_value(_IEEE_FLOAT16_EXPONENT_BITS, _IEEE_FLOAT16_MANTISSA_BITS) == \
torch.finfo(torch.float16).max

_BFLOAT16_EXPONENT_BITS = 8
_BFLOAT16_MANTISSA_BITS = 7
assert _ieee_float_max_representable_value(_BFLOAT16_EXPONENT_BITS, _BFLOAT16_MANTISSA_BITS) == \
torch.finfo(torch.bfloat16).max
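As a quick check on ``_ieee_float_max_representable_value`` (a worked example, not part of the commit), the formula reproduces the dtype limits asserted above:

For float16 ($E = 5$, $M = 10$): exponent_max $= 2^5 - 1 = 31$, bias $= \lfloor 31/2 \rfloor = 15$, so max $= (2 - 2^{-10}) \cdot 2^{31 - 15 - 1} = (2 - 2^{-10}) \cdot 2^{15} = 65504$, which equals torch.finfo(torch.float16).max.
For bfloat16 ($E = 8$, $M = 7$): exponent_max $= 255$, bias $= 127$, so max $= (2 - 2^{-7}) \cdot 2^{127} = 2^{128} - 2^{120} \approx 3.3895 \times 10^{38}$, which equals torch.finfo(torch.bfloat16).max.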


class FloatQuantizerBase(QuantizerBase):
class FloatQuantizeDequantize(QuantizerBase): # pylint: disable=abstract-method
"""
Base class for float quantization modules.
Float quantizer
:param exponent_bits: Number of exponent bits to simulate.
:param mantissa_bits: Number of mantissa bits to simulate.
:param dtype: torch.dtype to simulate. This argument is mutually exclusive with
exponent_bits and mantissa_bits.
:param encoding_analyzer: If specified, the maximum value to represent
will be determined dynamically based on the input statistics
for finer precision.
"""
def __init__(self, *args, **kwargs):
# pylint: disable=super-init-not-called
raise NotImplementedError
maxval: torch.Tensor

def __init__(self,
exponent_bits: int = None,
mantissa_bits: int = None,
dtype: torch.dtype = None,
encoding_analyzer: EncodingAnalyzer = None):
super().__init__()

if dtype is None:
if exponent_bits is None or mantissa_bits is None:
raise ValueError('Neither "dtype" nor "exponent/mantissa_bits" was specified.')

if dtype is not None:
if exponent_bits is not None or mantissa_bits is not None:
raise ValueError(
'Argument "dtype" is mutually exclusive with "exponent/mantissa_bits".')

if dtype not in (torch.half, torch.float16, torch.bfloat16):
raise ValueError(
f"Float quantizer only supports torch.float16 and torch.bfloat16. Got {dtype}.")

if dtype in (torch.half, torch.float16):
exponent_bits = _IEEE_FLOAT16_EXPONENT_BITS
mantissa_bits = _IEEE_FLOAT16_MANTISSA_BITS
else:
exponent_bits = _BFLOAT16_EXPONENT_BITS
mantissa_bits = _BFLOAT16_MANTISSA_BITS

self.exponent_bits = exponent_bits
self.mantissa_bits = mantissa_bits
self.encoding_analyzer = encoding_analyzer

if self.encoding_analyzer:
shape = self.encoding_analyzer.observer.shape
maxval = _ieee_float_max_representable_value(exponent_bits, mantissa_bits)
self.register_buffer('maxval', torch.full(shape, maxval))
else:
self.register_buffer('maxval', None)
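As a usage sketch of the constructor contract above (not part of the commit; the E4M3 layout and the per-tensor ``MinMaxEncodingAnalyzer`` argument are illustrative assumptions):

import torch
from aimet_torch.experimental.v2.quantization.encoding_analyzer import MinMaxEncodingAnalyzer
from aimet_torch.experimental.v2.quantization.quantizers.float import FloatQuantizeDequantize

# Plain IEEE float16 simulation: no encoding analyzer, so 'maxval' stays None.
q_fp16 = FloatQuantizeDequantize(dtype=torch.float16)

# An FP8-like format (E4M3 chosen for illustration) with a per-tensor analyzer,
# so 'maxval' is registered as a buffer and later calibrated via compute_encodings().
q_fp8 = FloatQuantizeDequantize(exponent_bits=4, mantissa_bits=3,
                                encoding_analyzer=MinMaxEncodingAnalyzer((1,)))  # shape argument assumed

# Passing both dtype and exponent/mantissa bits raises ValueError, as enforced above.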

@property
def bitwidth(self):
"""
Returns bitwidth of the quantizer
"""
return self.exponent_bits + self.mantissa_bits + 1

def is_float16(self):
"""
Returns true if current configuration simulates IEEE float16
"""
return self.exponent_bits == _IEEE_FLOAT16_EXPONENT_BITS and \
self.mantissa_bits == _IEEE_FLOAT16_MANTISSA_BITS

def is_bfloat16(self):
"""
Returns true if current configuration simulates bfloat16
"""
return self.exponent_bits == _BFLOAT16_EXPONENT_BITS and \
self.mantissa_bits == _BFLOAT16_MANTISSA_BITS

def get_encodings(self) -> Optional[List[Dict]]:
return [{'bitwidth': self.bitwidth, 'dtype': 'float'}]

@contextlib.contextmanager
def compute_encodings(self):
"""
Observe inputs and update quantization parameters based on the input statistics.
While ``compute_encodings`` is enabled, the quantizer forward pass performs
dynamic quantization using the batch statistics.
"""
if not self.encoding_analyzer:
yield
return

original_forward = self.forward

@functools.wraps(original_forward)
def forward_wrapper(input):
batch_statistics = self.encoding_analyzer.update_stats(input)
dynamic_min, dynamic_max =\
self.encoding_analyzer.compute_encodings_from_stats(batch_statistics,
self.bitwidth,
is_symmetric=False)
dynamic_absmax = torch.maximum(dynamic_min.abs(), dynamic_max.abs())
with patch_attr(self, 'maxval', dynamic_absmax):
return original_forward(input)

try:
with patch_attr(self, 'forward', forward_wrapper):
yield
except: # pylint: disable=try-except-raise
raise
else:
try:
min, max = self.encoding_analyzer.compute_encodings(self.bitwidth,
is_symmetric=False)
except StatisticsNotFoundError:
return

if min is None or max is None:
return

absmax = torch.maximum(min.abs(), max.abs()).expand_as(self.maxval)
with torch.no_grad():
self.maxval.copy_(absmax)

finally:
self.encoding_analyzer.reset_stats()
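Continuing the illustrative ``q_fp8`` instance from the sketch above, this is roughly how the context manager is driven during calibration (the data source is an assumption):

with q_fp8.compute_encodings():
    for batch in calibration_batches:  # calibration_batches is a placeholder data source
        # While the context is open, forward() runs with 'maxval' dynamically patched
        # to max(|min|, |max|) of the current batch statistics.
        q_fp8(batch)
# On exit, 'maxval' is overwritten in place with the absolute maximum derived from the
# accumulated statistics, and the encoding analyzer's statistics are reset.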

def forward(self, input: torch.Tensor):
"""
:param input: Input to quantize and dequantize
:return: Quantize-dequantized output
"""
maxval = self.maxval
exponent_bits = self.exponent_bits
mantissa_bits = self.mantissa_bits

if maxval is None:
if self.is_float16() or self.is_bfloat16():
# Fast forward using type casting
orig_dtype = input.dtype
dtype = torch.float16 if self.is_float16() else torch.bfloat16
return input.to(dtype).to(orig_dtype)

maxval = _ieee_float_max_representable_value(exponent_bits, mantissa_bits)

return fake_cast_to_ieee_float(input, maxval, exponent_bits, mantissa_bits)
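To make the two branches of ``forward`` concrete, a small sketch reusing the illustrative quantizers from above:

x = torch.randn(8, 16)

# No 'maxval' buffer and the configuration matches IEEE float16, so forward() simply
# round-trips through the dtype: equivalent to x.to(torch.float16).to(x.dtype).
y16 = q_fp16(x)

# With 'maxval' present (calibrated, or still at the format's theoretical maximum),
# forward() clips to [-maxval, maxval] and fake-casts via fake_cast_to_ieee_float,
# keeping the output in the input's original dtype.
y8 = q_fp8(x)
assert y8.dtype == x.dtype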
@@ -43,12 +43,14 @@

from aimet_common.defs import QuantScheme, QuantizationDataType, MAP_ROUND_MODE_TO_PYMO
from aimet_common.utils import AimetLogger, log_with_error_and_assert_if_false
from aimet_torch.experimental.v2.quantization.quantizers.float import FloatQuantizeDequantize
from aimet_torch.utils import get_v1_quant_scheme_for_initialization
from aimet_torch.qc_quantize_op import QcQuantizeOpMode, QcQuantizeWrapper, StaticGridQuantWrapper, tensor_quantizer_factory
from aimet_torch.tensor_quantizer import TensorQuantizer, StaticGridPerChannelQuantizer
from aimet_torch.experimental.v2.nn.fake_quant import FakeQuantizationMixin
from aimet_torch.experimental.v2.quantization.quantizers.affine import QuantizeDequantize
from aimet_torch.experimental.v2.quantization.encoding_analyzer import MinMaxEncodingAnalyzer, PercentileEncodingAnalyzer
import aimet_torch.fp_quantization as v1_fp_quantization


logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.Quant)
@@ -266,8 +268,6 @@ def _validate_quantizer_properties(self):
assert not self.use_unsigned_symmetric, "Unsigned symmetric is not supported in quantsim v1.5"
assert not self.is_unsigned_symmetric, "Unsigned symmetric is not supported in quantsim v1.5"

assert self.data_type == QuantizationDataType.int, "Only int quantization is supported in quantsim v1.5"

def _get_v2_encoding_analyzer(self, shape):
"""
Converts v1 quant scheme into v2 quant scheme.
@@ -305,8 +305,23 @@ def realize(self) -> Optional[QuantizeDequantize]:

encoding_analyzer = self._get_v2_encoding_analyzer(quantizer_param_shape)

return QuantizeDequantize(quantizer_param_shape, self.bitwidth,
self.use_symmetric_encodings, encoding_analyzer)
if self.data_type == QuantizationDataType.int:
quantizer = QuantizeDequantize(quantizer_param_shape, self.bitwidth,
self.use_symmetric_encodings, encoding_analyzer)
else:
if self.bitwidth == 16:
quantizer = FloatQuantizeDequantize(dtype=torch.float16)
else:
assert self.bitwidth == 8
mantissa_bits = v1_fp_quantization.NUM_MANTISSA_BITS
exponent_bits = 7 - mantissa_bits
quantizer = FloatQuantizeDequantize(exponent_bits, mantissa_bits,
encoding_analyzer=encoding_analyzer)
# Float quantizers are not trainable in V1 quantsim
for param in quantizer.parameters():
param.requires_grad = False

return quantizer
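A small arithmetic check on the FP8 branch above (illustrative only; the concrete value of NUM_MANTISSA_BITS lives in aimet_torch.fp_quantization and is not shown in this diff):

mantissa_bits = v1_fp_quantization.NUM_MANTISSA_BITS
exponent_bits = 7 - mantissa_bits
# 1 sign bit + exponent bits + mantissa bits must fill the 8-bit budget, and
# FloatQuantizeDequantize.bitwidth (exponent_bits + mantissa_bits + 1) then returns 8.
assert 1 + exponent_bits + mantissa_bits == 8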

def _set_internal_quantizer_properties(self, quantizer: TensorQuantizer):
"""
@@ -114,44 +114,58 @@ def patch_attr(obj, attr_name, new_attr)-> _ContextManager:
"""
Temporarily overwrite object attribute
"""
if isinstance(obj, torch.nn.Module):
if attr_name in obj._parameters or attr_name in obj._buffers: # pylint: disable=protected-access
return _patch_param_or_buffer(obj, attr_name, new_attr)

old_attr = getattr(obj, attr_name)
action = lambda: setattr(obj, attr_name, new_attr)
cleanup = lambda: setattr(obj, attr_name, old_attr)
return _ContextManager(action, cleanup)


def patch_param(module: torch.nn.Module, param_name: str, new_param: torch.Tensor) -> _ContextManager:
def _patch_param_or_buffer(module: torch.nn.Module,
param_or_buffer_name: str,
new_param_or_buffer: torch.Tensor):
"""
Temporarily substitute the reference to a parameter with the quantized parameter.
Under the scope of this function, ``getattr(module, param_name)`` will return
``new_param`` instead of the original parameter.
Under the scope of this function, ``getattr(module, param_or_buffer_name)`` will return
``new_param_or_buffer`` instead of the original parameter.
:param module: Module that owns the parameter
:param param_name: Name of the parameter
:param new_param: New parameter to replace the original parameter
:param param_or_buffer_name: Name of the parameter
:param new_param_or_buffer: New parameter to replace the original parameter
"""
original_param = getattr(module, param_name)
if original_param is not None:
assert _is_expandable(new_param.shape, original_param.shape)
new_param = new_param.expand_as(original_param)
# pylint: disable=protected-access

orig_param_or_buffer = getattr(module, param_or_buffer_name)
if orig_param_or_buffer is not None:
assert _is_expandable(new_param_or_buffer.shape, orig_param_or_buffer.shape)
new_param_or_buffer = new_param_or_buffer.expand_as(orig_param_or_buffer)

# Modify module.__dict__.
# module.__dict__ is the primary lookup table which has higher priority than __getattr__ method.
# Once we overwrite module.__dict__[param_name] with quantized_params,
# getattr(module, param_name) will return module.__dict__[param_name] directly
# Once we overwrite module.__dict__[param_or_buffer_name] with quantized_params,
# getattr(module, param_or_buffer_name) will return module.__dict__[param_or_buffer_name] directly
# without falling back to torch.nn.Module's __getattr__ method which returns
# the original parameter stored in module._parameters.
action = lambda: module.__dict__.update({param_name: new_param})
# the original parameter stored in module._parameters or module._buffers.
action = lambda: module.__dict__.update({param_or_buffer_name: new_param_or_buffer})

if param_name in module.__dict__:
if param_or_buffer_name in module.__dict__:
# Some non-standard modules (e.g. replicas of torch.nn.DataParallel) store their parameters
# directly to module.__dict__. In that case, the cleanup function should restore the dict
# so that module.__dict__[param_name] points back to the original parameter again.
assert module.__dict__[param_name] is original_param
cleanup = lambda: module.__dict__.update({param_name: original_param})
# so that module.__dict__[param_or_buffer_name] points back to the original parameter again.
assert module.__dict__[param_or_buffer_name] is orig_param_or_buffer
cleanup = lambda: module.__dict__.update({param_or_buffer_name: orig_param_or_buffer})
else:
assert module._parameters[param_name] is original_param # pylint: disable=protected-access
cleanup = lambda: module.__dict__.pop(param_name)
if param_or_buffer_name in module._parameters:
assert module._parameters[param_or_buffer_name] is orig_param_or_buffer
elif param_or_buffer_name in module._buffers:
assert module._buffers[param_or_buffer_name] is orig_param_or_buffer
else:
raise RuntimeError(f"'{param_or_buffer_name}' is not a valid name of a parameter or buffer of {type(module)}.")

cleanup = lambda: module.__dict__.pop(param_or_buffer_name)


return _ContextManager(action, cleanup)
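A minimal sketch of the buffer-patching path that this change enables (the ``Scaler`` module is an assumption, used only to demonstrate the behaviour):

import torch
from aimet_torch.experimental.v2.utils import patch_attr

class Scaler(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer('scale', torch.ones(()))

    def forward(self, x):
        return x * self.scale

m = Scaler()
with patch_attr(m, 'scale', torch.full((), 2.0)):
    # 'scale' is found in m._buffers, so patch_attr routes to _patch_param_or_buffer
    # and getattr(m, 'scale') temporarily resolves to the patched tensor.
    assert torch.equal(m(torch.ones(3)), torch.full((3,), 2.0))
# The original buffer is restored once the context exits.
assert torch.equal(m.scale, torch.ones(()))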
24 changes: 21 additions & 3 deletions TrainingExtensions/torch/src/python/aimet_torch/fp_quantization.py
@@ -165,21 +165,39 @@ def quantize_to_fp8(x_float: torch.Tensor,
new_shape[per_channel_axis] = -1
maxval = maxval.view(new_shape)

return fake_cast_to_ieee_float(x_float, maxval, exponent_bits, mantissa_bits)


def fake_cast_to_ieee_float(x_float, maxval, exponent_bits, mantissa_bits):
"""
Fake-cast to the given exponent and mantissa bits based on IEEE float representation.
The IEEE float representation satisfies the following equation:
maximum_representable_value = (2 - 2**-M) * 2 ** (2**E - bias - 2)
(E: exponent bits, M: mantissa bits)
This function derives the bias from exponent bits, mantissa bits, and
maximum representable value based on the above equation.
"""
def log2(x):
import numpy as np
if isinstance(x, torch.Tensor):
return torch.log2(x)
return np.log2(x)
# Math explanation of what happens here:
# Bias is computed from maxval: $B = 2^E - \log_2(\mathrm{maxval}) + \log_2(2 - 2^{-M}) - 1$.
# This follows from $\mathrm{maxval} = (2 - 2^{-M}) \cdot 2^{2^E - 1 - B}$.
bias = 2 ** exponent_bits - torch.log2(maxval) + torch.log2(2 - 2 ** (-mantissa_bits)) - 1
bias = 2 ** exponent_bits - log2(maxval) + log2(2 - 2 ** (-mantissa_bits)) - 1

# Ensure no values are greater than the maximum value represented by an 8 bit float system
# with M mantissa and E exponent bits. torch.min/torch.max are used to allow gradients to
# flow to maxval
x_clipped = torch.min(torch.max(x_float, -maxval), maxval)
x_clipped = x_float.clamp(-maxval, maxval)

# FP quantization scale is determined per-element, and is computed as
# \log_2 s = \left\lfloor \log_2 |x_c| + B \right\rfloor - M - B
# the addition of bias inside the floor and subtraction outside ensures that a
# tensor scaling $\alpha \neq 1$ is correctly incorporated
log_scales = torch.floor(torch.log2(torch.abs(x_clipped)) + bias).detach()
log_scales = torch.floor(log2(torch.abs(x_clipped)) + bias).detach()

# This ensures scales are never smaller than the subnormal scale
log_scales = torch.clamp(log_scales, 1.)
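Spelling out the ``bias`` line above as a derivation (consistent with the comment's convention; no new behaviour is introduced): starting from $\mathrm{maxval} = (2 - 2^{-M}) \cdot 2^{2^E - 1 - B}$, taking $\log_2$ of both sides gives $\log_2(\mathrm{maxval}) = \log_2(2 - 2^{-M}) + 2^E - 1 - B$, and solving for $B$ yields

$B = 2^E - \log_2(\mathrm{maxval}) + \log_2(2 - 2^{-M}) - 1$,

which is exactly the expression assigned to ``bias``. With ``maxval`` left at the format's full IEEE maximum the fake cast approximates an ordinary cast to that format, while a calibrated, smaller ``maxval`` trades representable range for finer resolution.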
