Skip to content

Commit

Permalink
Gradual activation quantization in Keras (#1244)
Browse files Browse the repository at this point in the history
Add support for gradual activation quantization in Keras.
This mainly converts torch implementation of gradual activation quantization to a common implementation.

---------

Co-authored-by: reuvenp <reuvenp@altair-semi.com>
  • Loading branch information
reuvenperetz and reuvenp authored Oct 26, 2024
1 parent b253ebd commit 7ba396a
Show file tree
Hide file tree
Showing 25 changed files with 461 additions and 216 deletions.
9 changes: 3 additions & 6 deletions .github/workflows/run_keras_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,13 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install tensorflow==${{ inputs.tf-version }} sony-custom-layers
pip install pytest
pip install tensorflow==${{ inputs.tf-version }} sony-custom-layers pytest
- name: Run unittests
# Some tests are sensitive to memory because we use tf gradients on a multi-thread/process
# CPU environment (https://github.com/tensorflow/tensorflow/issues/41718).
# For this reason, if we run them in such an environment, we need to run them first non-parallel separately.
run: |
python -m unittest discover tests/keras_tests -v
- name: Run pytest
run: |
pytest tests_pytest/keras
5 changes: 5 additions & 0 deletions .github/workflows/run_tests_suite_coverage.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ jobs:
source tf_env/bin/activate
coverage run --parallel-mode -m --omit "*__init__.py" --include "model_compression_toolkit/**/*.py" pytest tests_pytest/keras
- name: Run keras pytest
run: |
source tf_env/bin/activate
coverage run --parallel-mode -m --omit "*__init__.py" --include "model_compression_toolkit/**/*.py" pytest tests_pytest/keras
- name: Set up Pytorch environment
run: |
python -m venv torch_env
Expand Down
9 changes: 8 additions & 1 deletion model_compression_toolkit/gptq/common/gptq_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,11 @@

# GPTQ config constant
QUANT_PARAM_LEARNING_STR = 'quantization_parameter_learning'
MAX_LSB_STR = 'max_lsbs_change_map'
MAX_LSB_STR = 'max_lsbs_change_map'

# GPTQ learning hyperparameters
LR_DEFAULT = 3e-2
LR_REST_DEFAULT = 1e-4
LR_BIAS_DEFAULT = 1e-3
LR_QUANTIZATION_PARAM_DEFAULT = 1e-3
GPTQ_MOMENTUM = 0.9
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,23 @@
# limitations under the License.
# ==============================================================================
from functools import partial
from typing import Callable
from typing import Callable, Any

from model_compression_toolkit.gptq import GradientPTQConfig, QFractionLinearAnnealingConfig
from model_compression_toolkit.trainable_infrastructure import BasePytorchTrainableQuantizer

from model_compression_toolkit.trainable_infrastructure.pytorch.annealing_schedulers import LinearAnnealingScheduler
from model_compression_toolkit.trainable_infrastructure.common.base_trainable_quantizer import BaseTrainableQuantizer


def get_gradual_activation_quantizer_wrapper_factory(gptq_config: GradientPTQConfig,
get_total_grad_steps_fn: Callable[[], int]) \
-> Callable[[BasePytorchTrainableQuantizer], 'GradualActivationQuantizerWrapper']:
get_total_grad_steps_fn: Callable[[], int],
fw_linear_annealing_scheduler: type) \
-> Callable[[Any], 'GradualActivationQuantizerWrapper']:
"""
Get a factory for 'GradualActivationQuantizerWrapper'.
Args:
gptq_config: GPTQ configuration.
get_total_grad_steps_fn: a callable to obtain the total expected number of gradient steps.
fw_linear_annealing_scheduler: LinearAnnealingScheduler implementation of the framework (tf/pytorch).
Returns:
A factory function to build 'GradualActivationQuantizerWrapper' from Quantizer.
Expand All @@ -40,9 +40,9 @@ def get_gradual_activation_quantizer_wrapper_factory(gptq_config: GradientPTQCon
annealing_cfg = gptq_config.gradual_activation_quantization_config.q_fraction_scheduler_policy
if isinstance(annealing_cfg, QFractionLinearAnnealingConfig):
t_end = annealing_cfg.end_step or get_total_grad_steps_fn()
factor_scheduler = LinearAnnealingScheduler(t_start=annealing_cfg.start_step, t_end=t_end,
initial_val=annealing_cfg.initial_q_fraction,
target_val=annealing_cfg.target_q_fraction)
factor_scheduler = fw_linear_annealing_scheduler(t_start=annealing_cfg.start_step, t_end=t_end,
initial_val=annealing_cfg.initial_q_fraction,
target_val=annealing_cfg.target_q_fraction)
else:
raise ValueError(f'Unknown annealing policy {annealing_cfg}')

Expand All @@ -64,7 +64,7 @@ class GradualActivationQuantizerWrapper:
quantizer: quantizer to wrap.
q_fraction_scheduler: a callable that accepts a gradient step and returns the corresponding quantized fraction.
"""
def __init__(self, quantizer: BasePytorchTrainableQuantizer, q_fraction_scheduler: Callable[[int], float]):
def __init__(self, quantizer: BaseTrainableQuantizer, q_fraction_scheduler: Callable[[int], float]):
self.quantizer = quantizer
self.q_fraction_scheduler = q_fraction_scheduler
self.step_cnt = 0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,33 +12,47 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import Callable

from model_compression_toolkit.gptq import RoundingType, GradientPTQConfig
from model_compression_toolkit.gptq.pytorch.quantizer.soft_rounding.soft_quantizer_reg import \
SoftQuantizerRegularization
from model_compression_toolkit.trainable_infrastructure.pytorch.annealing_schedulers import LinearAnnealingScheduler
from tqdm import tqdm
from typing import Callable, Type

from model_compression_toolkit.gptq import RoundingType, GradientPTQConfig

# Common warmup fraction
WARMUP_STEP_FRACTION = 0.2

def get_regularization(gptq_config: GradientPTQConfig, get_total_grad_steps_fn: Callable[[], int]) -> Callable:

def get_regularization(gptq_config: GradientPTQConfig,
get_total_grad_steps_fn: Callable[[], int],
SoftQuantizerRegularizationFWClass: Type,
LinearAnnealingSchedulerFWClass: Type) -> Callable:
"""
Returns a function that computes the regularization term for GPTQ training based on the given
rounding type in the GPTQ configuration.
Args:
gptq_config: A GPTQ configuration.
get_total_grad_steps_fn: a callable to obtain the total expected number of gradient steps.
SoftQuantizerRegularizationFWClass: The class to use for soft quantizer regularization (framework-specific).
LinearAnnealingSchedulerFWClass: The class to use for the annealing scheduler (framework-specific).
Returns: A function for computing the regularization. If there is no regularization function defined for the given
rounding type, then it returns a function that just returns 0.
Returns:
Callable: A function for computing the regularization. If there is no regularization function
defined for the given rounding type, then it returns a function that just returns 0.
"""
if gptq_config.rounding_type == RoundingType.SoftQuantizer:
total_gradient_steps = get_total_grad_steps_fn()
t_start = int(WARMUP_STEP_FRACTION * total_gradient_steps)
scheduler = LinearAnnealingScheduler(t_start=t_start, t_end=total_gradient_steps, initial_val=20, target_val=2)
return SoftQuantizerRegularization(scheduler)

# Directly initializing the scheduler within the method
scheduler = LinearAnnealingSchedulerFWClass(
t_start=t_start,
t_end=total_gradient_steps,
initial_val=20,
target_val=2
)

# Return the framework-specific soft quantizer regularization
return SoftQuantizerRegularizationFWClass(scheduler)
else:
return lambda *args, **kwargs: 0
37 changes: 26 additions & 11 deletions model_compression_toolkit/gptq/keras/gptq_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,14 @@
from model_compression_toolkit.core.keras.back2framework.keras_model_builder import KerasModelBuilder
from model_compression_toolkit.core.keras.data_util import data_gen_to_dataloader
from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
from model_compression_toolkit.gptq.common.gradual_activation_quantization import \
get_gradual_activation_quantizer_wrapper_factory
from model_compression_toolkit.gptq.common.regularization_factory import get_regularization
from model_compression_toolkit.gptq.keras.quantizer.quantization_builder import quantization_builder
from model_compression_toolkit.logger import Logger
from mct_quantizers import KerasActivationQuantizationHolder
from model_compression_toolkit.trainable_infrastructure.common.util import get_total_grad_steps
from model_compression_toolkit.trainable_infrastructure.keras.annealing_schedulers import KerasLinearAnnealingScheduler

if version.parse(tf.__version__) >= version.parse("2.13"):
from keras.src.engine.base_layer import TensorFlowOpLayer
Expand All @@ -41,13 +46,12 @@
from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig
from model_compression_toolkit.core.common import Graph
from model_compression_toolkit.gptq.keras.graph_info import get_weights_for_loss, get_gptq_trainable_parameters
from model_compression_toolkit.gptq.keras.quantizer.regularization_factory import get_regularization
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
import numpy as np
import copy
from model_compression_toolkit.core.keras.constants import BIAS, USE_BIAS

from model_compression_toolkit.gptq.keras.quantizer.soft_rounding.soft_quantizer_reg import SoftQuantizerRegularization

class KerasGPTQTrainer(GPTQTrainer):
"""
Expand Down Expand Up @@ -78,6 +82,15 @@ def __init__(self,
hessian_info_service: HessianScoresService for fetching and computing Hessian's approximation scores.
"""

def _get_total_grad_steps():
return get_total_grad_steps(representative_data_gen) * gptq_config.n_epochs

# This must be set before the model building (as it is required for activation holder construction),
# which occurs in the base constructor.
self.gradual_act_quantizer_wrapper_factory = get_gradual_activation_quantizer_wrapper_factory(
gptq_config, _get_total_grad_steps, KerasLinearAnnealingScheduler)

super().__init__(graph_float,
graph_quant,
gptq_config,
Expand Down Expand Up @@ -119,7 +132,10 @@ def __init__(self,

self.weights_for_average_loss = self._get_compare_points_loss_weights()

self.reg_func = get_regularization(self.gptq_config, representative_data_gen)
self.reg_func = get_regularization(self.gptq_config,
_get_total_grad_steps,
SoftQuantizerRegularization,
KerasLinearAnnealingScheduler)

def _get_compare_points_loss_weights(self):
""" Get compare points weights for the distillation loss. """
Expand Down Expand Up @@ -185,14 +201,13 @@ def get_activation_quantizer_holder(self, n: common.BaseNode) -> Callable:
_, activation_quantizers = quantization_builder(n, self.gptq_config) # TODO: split quantizers building into two functions: for weights and activations

# Holder by definition uses a single quantizer for the activation quantization
# thus we make sure this is the only possible case (unless it's a node with no activation
# quantization, which in this case has an empty list).
if len(activation_quantizers) == 1:
return KerasActivationQuantizationHolder(activation_quantizers[0])

Logger.critical(f"'KerasActivationQuantizationHolder' is designed to support a single quantizer, "
f"but {len(activation_quantizers)} quantizers were found for node '{n}'. "
f"Ensure only one quantizer is configured for each node's activation.")
# thus we make sure this is the only possible case.
if len(activation_quantizers) != 1:
Logger.critical(f"'KerasActivationQuantizationHolder' is designed to support a single quantizer, "
f"but {len(activation_quantizers)} quantizers were found for node '{n}'. "
f"Ensure only one quantizer is configured for each node's activation.")
quantizer = self.gradual_act_quantizer_wrapper_factory(activation_quantizers[0])
return KerasActivationQuantizationHolder(quantizer)

def build_gptq_model(self) -> Tuple[Model, UserInformation]:
"""
Expand Down
Loading

0 comments on commit 7ba396a

Please sign in to comment.