Logit-based A/B test #2701

Merged (4 commits) on Feb 2, 2024
@@ -36,6 +36,8 @@
 # =============================================================================
 """Fake-quantized modules"""

+import contextlib
+import itertools
 from collections import OrderedDict
 from typing import Type, Optional, Tuple, List, Dict

@@ -46,6 +48,7 @@

 from aimet_torch.experimental.v2.nn.quant_base import BaseQuantizationMixin
 from aimet_torch.experimental.v2.quantization.quantizers import QuantizerBase
+from aimet_torch.experimental.v2.utils import patch_attr
 import aimet_torch.elementwise_ops as aimet_ops


@@ -72,6 +75,22 @@ class FakeQuantizationMixin(BaseQuantizationMixin): # pylint: disable=abstract-method
     cls_to_qcls = OrderedDict()  # original class -> quantized class
     qcls_to_cls = OrderedDict()  # quantized class -> original class

+    @contextlib.contextmanager
+    def compute_encodings(self):
+        def no_op(input: Tensor): # pylint: disable=redefined-builtin
+            return input
+
+        with contextlib.ExitStack() as stack:
+            for quantizer in itertools.chain(self.input_quantizers, self.output_quantizers):
+                if not quantizer:
+                    continue
+                # Set input/output quantizers into pass-through mode during compute_encodings
+                # NOTE: This behavior is for backward compatibility with V1 quantsim.
+                stack.enter_context(patch_attr(quantizer, 'forward', no_op))
+
+            with super().compute_encodings():
+                yield
+
     def export_input_encodings(self) -> List[List[Dict]]:
         """
         Returns a list of input encodings, each represented as a List of Dicts
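For context: patch_attr is imported above but not shown in this diff. Assuming it is a context manager that swaps out an attribute and restores it on exit (a common utility pattern), a minimal sketch of the pass-through behavior this gives compute_encodings:

    import contextlib

    @contextlib.contextmanager
    def patch_attr(obj, attr, new_value):
        # Hypothetical stand-in for aimet_torch.experimental.v2.utils.patch_attr:
        # temporarily replace obj.<attr>, restoring the original on exit.
        original = getattr(obj, attr)
        setattr(obj, attr, new_value)
        try:
            yield
        finally:
            setattr(obj, attr, original)

    class ToyQuantizer:
        def forward(self, x):
            return round(x * 4) / 4  # toy "quantization" to a coarse grid

    q = ToyQuantizer()
    with patch_attr(q, 'forward', lambda x: x):
        assert q.forward(0.3) == 0.3   # pass-through while encodings are computed
    assert q.forward(0.3) == 0.25      # quantization behavior restored afterwards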
@@ -110,8 +110,6 @@ def _compute_param_encodings(self, overwrite: bool):
     def compute_encodings(self):
         """
         Observe inputs and update quantization parameters based on the input statistics.
-        During ``compute_encodings`` is enabled, the input/output quantizers will forward perform
-        dynamic quantization using the batch statistics.
         """
         self._compute_param_encodings(overwrite=True)

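(The sentence removed here no longer holds after the change above: during compute_encodings, the input and output quantizers now run in pass-through mode instead of performing dynamic quantization.)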
7 changes: 6 additions & 1 deletion TrainingExtensions/torch/src/python/aimet_torch/quantsim.py
@@ -765,7 +765,12 @@ def replace_wrappers_for_quantize_dequantize(self):
         """
         if self._quant_scheme == QuantScheme.training_range_learning_with_tf_init or self._quant_scheme == \
                 QuantScheme.training_range_learning_with_tf_enhanced_init:
-            device = utils.get_device(self.model)
+            try:
+                device = utils.get_device(self.model)
+            except StopIteration:
+                # Model doesn't have any parameters.
+                # Set device to cpu by default.
+                device = torch.device('cpu')

             self._replace_quantization_wrapper(self.model, device)
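utils.get_device is not shown in this diff. Assuming it resolves the device from the model's first parameter (a common idiom), a minimal sketch of the failure mode the new fallback guards against:

    import torch

    def get_device(model: torch.nn.Module) -> torch.device:
        # Hypothetical stand-in for aimet_torch.utils.get_device:
        # resolve the device from the model's first parameter.
        return next(model.parameters()).device

    model = torch.nn.Flatten()  # a module with no parameters at all
    try:
        device = get_device(model)
    except StopIteration:
        # next() on an empty parameter iterator raises StopIteration,
        # which is the case the fallback above now handles.
        device = torch.device('cpu')
    assert device == torch.device('cpu')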
36 changes: 36 additions & 0 deletions TrainingExtensions/torch/test/python/experimental/v2/__init__.py
@@ -0,0 +1,36 @@
# -*- mode: python -*-
# =============================================================================
# @@-COPYRIGHT-START-@@
#
# Copyright (c) 2024, Qualcomm Innovation Center, Inc. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
#
# SPDX-License-Identifier: BSD-3-Clause
#
# @@-COPYRIGHT-END-@@
# =============================================================================
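(These empty __init__.py files turn the new v2 test directories into packages, which is presumably what enables the relative imports used in the test files below, e.g. from ..models_ import models_to_test.)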
(Another new file, 36 additions: the same 36-line BSD-3-Clause copyright header as above; text omitted.)
@@ -57,7 +57,7 @@
 from aimet_torch.experimental.v2.quantization.quantizers.affine import QuantizeDequantize
 from aimet_torch.experimental.v2.quantization.quantizers.float import FloatQuantizeDequantize

-from models_.models_to_test import SingleResidual, QuantSimTinyModel, MultiInput, SingleResidualWithModuleAdd, \
+from ..models_.models_to_test import SingleResidual, QuantSimTinyModel, MultiInput, SingleResidualWithModuleAdd, \
     SingleResidualWithAvgPool, ModelWithBertCustomLayerNormGelu

@@ -41,7 +41,6 @@
 import os
 import json

-# from aimet_torch.experimental.v2.quantization.wrappers.quantization_mixin import _QuantizationMixin
 import aimet_torch.experimental.v2.nn as aimet_nn
 from aimet_torch.experimental.v2.nn.fake_quant import FakeQuantizationMixin
 from aimet_torch.experimental.v2.quantization.quantizers.affine import QuantizeDequantize
@@ -50,11 +49,10 @@
 from aimet_torch import onnx_utils
 from aimet_torch.quantsim import QuantizationSimModel, OnnxExportApiArgs

-from models_.models_to_test import (
+from ..models_.models_to_test import (
     SimpleConditional,
     ModelWithTwoInputs,
     ModelWith5Output,
-    ModuleWith5Output,
     SoftMaxAvgPoolModel,
 )

@@ -247,22 +245,6 @@ def test_multi_output_onnx_op(self):
         model = ModelWith5Output()
         dummy_input = torch.randn(1, 3, 224, 224)
         sim_model = copy.deepcopy(model)
-
-        @FakeQuantizationMixin.implements(ModuleWith5Output)
-        class FakeQuantizationMixinWithDisabledOutput(FakeQuantizationMixin, ModuleWith5Output):
-            def __quant_init__(self):
-                super().__quant_init__()
-                self.output_quantizers = torch.nn.ModuleList([None, None, None, None, None])
-
-            def quantized_forward(self, input):
-                if self.input_quantizers[0]:
-                    input = self.input_quantizers[0](input)
-                outputs = super().forward(input)
-                return tuple(
-                    quantizer(out) if quantizer else out
-                    for out, quantizer in zip(outputs, self.output_quantizers)
-                )
-
         sim_model.cust = FakeQuantizationMixin.from_module(sim_model.cust)
         sim_model.cust.input_quantizers[0] = QuantizeDequantize((1,),
                                                                 bitwidth=8,
@@ -41,8 +41,10 @@
 import tempfile
 import pytest
 import torch
+import random
+import numpy as np

-from models_ import models_to_test
+from ..models_ import models_to_test

 from aimet_common.defs import QuantScheme

@@ -189,12 +191,21 @@ def config_path(request):
         yield temp_config_path


-@pytest.mark.skip("Skip tests until v2 implementation is done")
+def set_seed(seed):
+    torch.manual_seed(seed)
+    np.random.seed(seed)
+    random.seed(seed)
+
+
 @pytest.mark.parametrize('quant_scheme', [QuantScheme.post_training_tf,
-                                          QuantScheme.post_training_percentile,
-                                          QuantScheme.training_range_learning_with_tf_init])
-class TestCompareV1QuantsimAndV2Quantsim:
+                                          QuantScheme.training_range_learning_with_tf_init,
+                                          # QuantScheme.post_training_percentile, # TODO: not implemented
+                                          # QuantScheme.training_range_learning_with_tf_init, # TODO: not implemented
+                                          ])
+@pytest.mark.parametrize('seed', range(3))
+class TestQuantsimLogits:
     @staticmethod
+    @torch.no_grad()
     def check_qsim_logit_consistency(config, quant_scheme, model, dummy_input):
         with tempfile.TemporaryDirectory() as temp_dir:
             config_path = os.path.join(temp_dir, "quantsim_config.json")
@@ -211,43 +222,48 @@ def check_qsim_logit_consistency(config, quant_scheme, model, dummy_input):
                                            default_output_bw=16,
                                            config_file=config_path)

-            v1_sim.compute_encodings(lambda sim_model, _: sim_model(dummy_input),
+            if isinstance(dummy_input, torch.Tensor):
+                dummy_input = (dummy_input,)
+
+            v1_sim.compute_encodings(lambda sim_model, _: sim_model(*dummy_input),
                                      forward_pass_callback_args=None)

-            v2_sim.compute_encodings(lambda sim_model, _: sim_model(dummy_input),
+            v2_sim.compute_encodings(lambda sim_model, _: sim_model(*dummy_input),
                                      forward_pass_callback_args=None)

-            v1_logits = v1_sim.model(dummy_input)
-            v2_logits = v2_sim.model(dummy_input)
+            v1_logits = v1_sim.model(*dummy_input)
+            v2_logits = v2_sim.model(*dummy_input)

             if isinstance(v1_logits, list):
                 assert len(v1_logits) == len(v2_logits)
-                for idx in range(len(v1_logits)):
-                    assert torch.allclose(v1_logits[idx], v2_logits[idx])
+                for v1_logit, v2_logit in zip(v1_logits, v2_logits):
+                    tick = (v1_logit.max() - v1_logit.min()) / (2**16 - 1) # Tolerate off-by-one precision error
+                    assert torch.allclose(v1_logit, v2_logit, rtol=1e-3, atol=tick)
             else:
-                assert torch.allclose(v1_logits, v2_logits)
+                tick = (v1_logits.max() - v1_logits.min()) / (2**16 - 1) # Tolerate off-by-one precision error
+                assert torch.allclose(v1_logits, v2_logits, rtol=1e-3, atol=tick)

-    @pytest.mark.parametrize('model_and_input_shape', [(models_to_test.SingleResidual, (1, 3, 32, 32)),
+    @pytest.mark.parametrize('model_cls,input_shape', [(models_to_test.SingleResidual, (1, 3, 32, 32)),
                                                        (models_to_test.SoftMaxAvgPoolModel, (1, 4, 256, 512)),
                                                        (models_to_test.QuantSimTinyModel, (1, 3, 32, 32))])
-    def test_default_config(self, model_and_input_shape, quant_scheme):
-        model_cls, input_shape = model_and_input_shape
+    def test_default_config(self, model_cls, input_shape, quant_scheme, seed):
+        set_seed(seed)
         model = model_cls()
         dummy_input = torch.randn(input_shape)
         self.check_qsim_logit_consistency(CONFIG_DEFAULT, quant_scheme, model, dummy_input)

-    @pytest.mark.parametrize('model_and_input_shape', [(models_to_test.SingleResidual, (1, 3, 32, 32)),
+    @pytest.mark.parametrize('model_cls,input_shape', [(models_to_test.SingleResidual, (1, 3, 32, 32)),
                                                        (models_to_test.QuantSimTinyModel, (1, 3, 32, 32))])
-    def test_param_quant(self, model_and_input_shape, quant_scheme):
-        model_cls, input_shape = model_and_input_shape
+    def test_param_quant(self, model_cls, input_shape, quant_scheme, seed):
+        set_seed(seed)
         model = model_cls()
         dummy_input = torch.randn(input_shape)
         self.check_qsim_logit_consistency(CONFIG_PARAM_QUANT, quant_scheme, model, dummy_input)

-    @pytest.mark.parametrize('model_and_input_shape', [(models_to_test.SingleResidual, (1, 3, 32, 32)),
+    @pytest.mark.parametrize('model_cls,input_shape', [(models_to_test.SingleResidual, (1, 3, 32, 32)),
                                                        (models_to_test.QuantSimTinyModel, (1, 3, 32, 32))])
-    def test_op_specific_quant(self, model_and_input_shape, quant_scheme):
-        model_cls, input_shape = model_and_input_shape
+    def test_op_specific_quant(self, model_cls, input_shape, quant_scheme, seed):
+        set_seed(seed)
         model = model_cls()
         dummy_input = torch.randn(input_shape)
         # Check per-tensor quantization for conv op
@@ -256,18 +272,20 @@ def test_op_specific_quant(self, model_and_input_shape, quant_scheme):
         # Check per-channel quantization for conv op
         self.check_qsim_logit_consistency(CONFIG_OP_SPECIFIC_QUANT_PER_CHANNEL, quant_scheme, model, dummy_input)

-    def test_supergroup(self, quant_scheme):
+    def test_supergroup(self, quant_scheme, seed):
+        set_seed(seed)
         model = models_to_test.QuantSimTinyModel()
         dummy_input = torch.randn(1, 3, 32, 32)
         self.check_qsim_logit_consistency(CONFIG_SUPERGROUP, quant_scheme, model, dummy_input)

-    def test_multi_input(self, quant_scheme):
+    def test_multi_input(self, quant_scheme, seed):
+        set_seed(seed)
         model = models_to_test.MultiInput()
         dummy_input = (torch.rand(1, 3, 32, 32), torch.rand(1, 3, 20, 20))
         self.check_qsim_logit_consistency(CONFIG_DEFAULT, quant_scheme, model, dummy_input)

-    def test_multi_output(self, quant_scheme):
+    def test_multi_output(self, quant_scheme, seed):
+        set_seed(seed)
         model = models_to_test.ModelWith5Output()
         dummy_input = torch.randn(1, 3, 224, 224)
         self.check_qsim_logit_consistency(CONFIG_DEFAULT, quant_scheme, model, dummy_input)
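A note on the tolerance in check_qsim_logit_consistency above: with 16-bit output quantization over an observed range [min, max], adjacent points on the quantization grid differ by one step, tick = (max - min) / (2^16 - 1). An off-by-one difference between the two simulators' quantized integers therefore shows up as at most one tick after dequantization. A minimal sketch (hypothetical values) of what the assertion tolerates:

    import torch

    torch.manual_seed(0)
    v1_logit = torch.randn(10)
    # One step of a 16-bit affine quantization grid spanning the observed range
    tick = ((v1_logit.max() - v1_logit.min()) / (2 ** 16 - 1)).item()
    # A result that differs by exactly one grid step still counts as consistent
    v2_logit = v1_logit + tick
    assert torch.allclose(v1_logit, v2_logit, rtol=1e-3, atol=tick)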

(Another new file, 36 additions: the same 36-line BSD-3-Clause copyright header as above; text omitted.)
@@ -40,6 +40,7 @@
 from torch import nn

 from aimet_torch import elementwise_ops
+from aimet_torch.experimental.v2.nn.fake_quant import FakeQuantizationMixin


 class SimpleConditional(torch.nn.Module):
@@ -395,3 +396,19 @@ def forward(self, *inputs):
         x = x.view(x.size(0), -1)
         x = self.fc(x)
         return x
+
+
+@FakeQuantizationMixin.implements(ModuleWith5Output)
+class FakeQuantizationModuleWith5Output(FakeQuantizationMixin, ModuleWith5Output):
+    def __quant_init__(self):
+        super().__quant_init__()
+        self.output_quantizers = torch.nn.ModuleList([None, None, None, None, None])
+
+    def quantized_forward(self, input):
+        if self.input_quantizers[0]:
+            input = self.input_quantizers[0](input)
+        outputs = super().forward(input)
+        return tuple(
+            quantizer(out) if quantizer else out
+            for out, quantizer in zip(outputs, self.output_quantizers)
+        )
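As a usage sketch, mirroring the ONNX export test above:

    # e.g., inside a test (assuming from_module returns an instance of the
    # class registered above for ModuleWith5Output):
    module = ModuleWith5Output()
    quant_module = FakeQuantizationMixin.from_module(module)
    assert isinstance(quant_module, FakeQuantizationModuleWith5Output)
    # All five output quantizers start out disabled (None); callers can still
    # assign quantizers afterwards, as the ONNX export test above does for
    # input_quantizers[0].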