diff --git a/TrainingExtensions/torch/src/python/aimet_torch/v2/experimental/lora.py b/TrainingExtensions/torch/src/python/aimet_torch/v2/experimental/lora.py
new file mode 100644
index 00000000000..4aed305aead
--- /dev/null
+++ b/TrainingExtensions/torch/src/python/aimet_torch/v2/experimental/lora.py
@@ -0,0 +1,162 @@
+#!/usr/bin/env python3
+# -*- mode: python -*-
+# =============================================================================
+# @@-COPYRIGHT-START-@@
+#
+# Copyright (c) 2024, Qualcomm Innovation Center, Inc. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice,
+#    this list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+#
+# 3. Neither the name of the copyright holder nor the names of its contributors
+#    may be used to endorse or promote products derived from this software
+#    without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# @@-COPYRIGHT-END-@@
+# =============================================================================
+""" Quantized LoRA layers """
+
+import torch
+from torch import nn
+import peft.tuners.lora.layer as lora
+from aimet_torch.v2.nn import QuantizationMixin, custom
+from aimet_torch.v2.nn.true_quant import _dispatch
+
+
+class _TensorDict(torch.nn.ParameterDict):  # pylint: disable=abstract-method
+    def __setitem__(self, key, value):
+        if not isinstance(value, torch.Tensor):
+            value = torch.tensor(value)
+
+        super().__setitem__(key, value.detach())
+
+    def __getitem__(self, key) -> torch.Tensor:
+        ret = super().__getitem__(key).detach()
+        setattr(ret, '_consumer', key)
+        return ret
+
+
+@QuantizationMixin.implements(lora.Linear)
+class QuantizedLinear(QuantizationMixin, lora.Linear):  # pylint: disable=too-many-ancestors
+    """
+    Quantized lora.Linear.
+    """
+
+    # NOTE: The implementation of this class depends tightly on the following assumptions:
+    #   1) The LoRA scale (``self.scaling``) is always multiplied with the output of the LoRA adapters.
+    #   2) The scaled output of the LoRA adapters is always added to the output of the base layer.
+
+    def __quant_init__(self):
+        super().__quant_init__()
+
+        # Quantized lora.Linear itself doesn't need input/output quantizers.
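+        # (The base layer, the lora_A/lora_B sub-layers, and the per-adapter
+        # mul/add operators below carry their own quantizers.)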
+        self.input_quantizers = nn.ModuleList([])
+        self.output_quantizers = nn.ModuleList([])
+
+        self.scaling = _TensorDict(self.scaling)
+        self.mul = nn.ModuleDict({
+            adapter_name: custom.QuantizedMultiply() for adapter_name in self.lora_A.keys()
+        })
+        self.add = nn.ModuleDict({
+            adapter_name: custom.QuantizedAdd() for adapter_name in self.lora_A.keys()
+        })
+
+    def _mul(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+        """
+        Implementation of elementwise multiplication which will be dispatched in place of
+        torch.Tensor.mul and torch.mul during forward.
+
+        This function will invoke self.mul (type: QuantizedMultiply) if either x or y
+        is an entry of self.scaling.
+        Otherwise, it will fall back to normal torch.Tensor.mul.
+        """
+        adapter_name = getattr(x, '_consumer', None)
+
+        if adapter_name is None:
+            adapter_name = getattr(y, '_consumer', None)
+
+        if adapter_name is not None:
+            # `x` or `y` is a scaling factor for adapter `adapter_name`.
+            # Dispatch self.mul[adapter_name] in place of regular torch.Tensor.mul
+            # so the scaling factor can be observed and quantized properly
+            out = self.mul[adapter_name](x, y)
+            setattr(out, '_producer', adapter_name)
+        else:
+            # `x` or `y` is NOT a scaling factor.
+            # Fall back to normal torch.Tensor.mul
+            out = x * y
+
+        return out
+
+    def _add(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+        """
+        Implementation of elementwise addition which will be dispatched in place of
+        torch.Tensor.add and torch.add during forward.
+
+        This function will invoke self.add (type: QuantizedAdd) if either x or y
+        is the output of a LoRA adapter scaled by self.scaling.
+        Otherwise, it will fall back to normal torch.Tensor.add.
+        """
+        adapter_name = getattr(x, '_producer', None)
+
+        if adapter_name is None:
+            adapter_name = getattr(y, '_producer', None)
+
+        if adapter_name is not None:
+            # `x` or `y` is an output of adapter `adapter_name`.
+            # Dispatch self.add[adapter_name] in place of regular torch.Tensor.add
+            # so the output can be observed and quantized properly
+            out = self.add[adapter_name](x, y)
+        else:
+            # `x` or `y` is NOT an output of any adapter.
+            # Fall back to normal torch.Tensor.add
+            out = x + y
+
+        return out
+
+    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:  # pylint: disable=arguments-differ
+        with _dispatch(torch.Tensor.mul, self._mul), _dispatch(torch.mul, self._mul),\
+                _dispatch(torch.Tensor.add, self._add), _dispatch(torch.add, self._add):
+            return super().forward(x, *args, **kwargs)
+
+    def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora, use_dora: bool = False):
+        raise NotImplementedError
+
+    def set_scale(self, adapter, scale):
+        raise NotImplementedError
+
+    def scale_layer(self, *args, **kwargs):
+        raise NotImplementedError
+
+    def unscale_layer(self, *args, **kwargs):
+        raise NotImplementedError
+
+    def merge(self, *args, **kwargs) -> None:
+        raise NotImplementedError
+
+    def unmerge(self, *args, **kwargs) -> None:
+        raise NotImplementedError
+
+    def get_delta_weight(self, adapter) -> torch.Tensor:
+        raise NotImplementedError
diff --git a/TrainingExtensions/torch/test/python/v2/nn/test_lora.py b/TrainingExtensions/torch/test/python/v2/nn/test_lora.py
new file mode 100644
index 00000000000..5cfd33cb434
--- /dev/null
+++ b/TrainingExtensions/torch/test/python/v2/nn/test_lora.py
@@ -0,0 +1,118 @@
+#!/usr/bin/env python3
+# -*- mode: python -*-
+# =============================================================================
+# @@-COPYRIGHT-START-@@
+#
+# Copyright (c) 2024, Qualcomm Innovation Center, Inc. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice,
+#    this list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+#
+# 3. Neither the name of the copyright holder nor the names of its contributors
+#    may be used to endorse or promote products derived from this software
+#    without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# @@-COPYRIGHT-END-@@
+# =============================================================================
+
+import pytest
+import torch
+from torch import nn
+import peft.tuners.lora.layer as lora
+
+import aimet_torch.v2 as aimet
+from aimet_torch.v2.quantization import affine
+from aimet_torch.v2.quantization.base import QuantizerBase
+from aimet_torch.v2.quantsim import QuantizationSimModel
+from aimet_torch.v2.experimental import lora as qlora
+
+
+class TestQuantizedLinear:
+    def test_quantsim_construction(self):
+        model = lora.Linear(nn.Linear(10, 10), adapter_name='adapter_0', r=1)
+        dummy_input = torch.randn(10, 10)
+        sim = QuantizationSimModel(model, dummy_input)
+
+        """
+        When: Create quantsim with lora.Linear
+        Then: 1) lora.Linear should be converted to QuantizedLinear
+              2) Mul and Add modules should have input and output quantizers as necessary
+              3) All LoRA adapters (lora_A, lora_B) and the base layer should be converted to aimet.nn.QuantizedLinear
+        """
+        assert isinstance(sim.model, qlora.QuantizedLinear)
+        assert isinstance(sim.model.mul['adapter_0'].input_quantizers[1], affine.QuantizeDequantize)
+        assert isinstance(sim.model.mul['adapter_0'].output_quantizers[0], affine.QuantizeDequantize)
+        assert isinstance(sim.model.add['adapter_0'].output_quantizers[0], affine.QuantizeDequantize)
+
+        lora_A = sim.model.lora_A["adapter_0"]
+        assert isinstance(lora_A, aimet.nn.QuantizedLinear)
+        assert isinstance(lora_A.param_quantizers['weight'], affine.QuantizeDequantize)
+        assert isinstance(lora_A.output_quantizers[0], affine.QuantizeDequantize)
+
+        lora_B = sim.model.lora_B["adapter_0"]
+        assert isinstance(lora_B, aimet.nn.QuantizedLinear)
+        assert isinstance(lora_B.param_quantizers['weight'], affine.QuantizeDequantize)
+        assert isinstance(lora_B.output_quantizers[0], affine.QuantizeDequantize)
+
+        base_layer = sim.model.base_layer
+        assert isinstance(base_layer, aimet.nn.QuantizedLinear)
+        assert isinstance(base_layer.param_quantizers['weight'], affine.QuantizeDequantize)
+        assert isinstance(base_layer.output_quantizers[0], affine.QuantizeDequantize)
+
+        """
+        When: compute_encodings
+        Then: All quantizers should be initialized
+        """
+        sim.compute_encodings(lambda model, _: model(dummy_input), None)
+
+        for qtzr in sim.model.modules():
+            if isinstance(qtzr, QuantizerBase):
+                assert qtzr.is_initialized()
+
+    @pytest.mark.skip(reason="To be discussed")
+    def test_update_layer(self):
+        """
+        When: Add a new LoRA adapter with the "update_layer" API
+        Then: The newly added adapters should be aimet.nn.QuantizedLinear with
+              param and output quantizers instantiated as necessary
+        """
+        model = lora.Linear(nn.Linear(10, 10), adapter_name='adapter_0', r=1)
+        dummy_input = torch.randn(10, 10)
+        sim = QuantizationSimModel(model, dummy_input)
+
+        sim.model.update_layer("new_adapter", ...)
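+        # NOTE: The remaining update_layer arguments above are intentionally
+        # elided ("...") pending API discussion; update_layer currently raises
+        # NotImplementedError in QuantizedLinear.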
+        new_lora_a = sim.model.lora_A["new_adapter"]
+        new_lora_b = sim.model.lora_B["new_adapter"]
+
+        assert isinstance(new_lora_a, aimet.nn.QuantizedLinear)
+        assert isinstance(new_lora_a.param_quantizers['weight'], affine.QuantizeDequantize)
+        assert isinstance(new_lora_a.output_quantizers[0], affine.QuantizeDequantize)
+
+        assert isinstance(new_lora_b, aimet.nn.QuantizedLinear)
+        assert isinstance(new_lora_b.param_quantizers['weight'], affine.QuantizeDequantize)
+        assert isinstance(new_lora_b.output_quantizers[0], affine.QuantizeDequantize)
+
+        assert isinstance(sim.model.mul['new_adapter'].input_quantizers[1], affine.QuantizeDequantize)
+        assert isinstance(sim.model.mul['new_adapter'].output_quantizers[0], affine.QuantizeDequantize)
+        assert isinstance(sim.model.add['new_adapter'].output_quantizers[0], affine.QuantizeDequantize)
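
For reference, a minimal usage sketch of how this experimental layer is expected to be exercised end to end, mirroring test_quantsim_construction above (all names come from the files in this diff; nothing beyond them is assumed):

```python
import torch
from torch import nn
import peft.tuners.lora.layer as lora

from aimet_torch.v2.quantsim import QuantizationSimModel

# Wrap a peft LoRA linear in QuantizationSimModel, which converts it to the
# experimental QuantizedLinear registered via @QuantizationMixin.implements.
model = lora.Linear(nn.Linear(10, 10), adapter_name='adapter_0', r=1)
dummy_input = torch.randn(10, 10)
sim = QuantizationSimModel(model, dummy_input)

# Calibrate all quantizers (base layer, lora_A/lora_B, and the per-adapter
# mul/add operators) with a representative forward pass.
sim.compute_encodings(lambda m, _: m(dummy_input), None)

# Quantized inference: forward() temporarily dispatches torch.mul/torch.add
# to the per-adapter QuantizedMultiply/QuantizedAdd modules so that the LoRA
# scaling and the residual add are observed and quantized.
out = sim.model(dummy_input)
```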