Extend gptq get config API to enable external regularization factor argument (sony#897)

---------

Co-authored-by: Ofir Gordon <Ofir.Gordon@altair-semi.com>
ofirgo and Ofir Gordon authored Dec 28, 2023
1 parent 4d4d3ac commit 5f98016
Showing 4 changed files with 20 additions and 6 deletions.
8 changes: 6 additions & 2 deletions model_compression_toolkit/gptq/keras/quantization_facade.py
@@ -17,6 +17,7 @@
 from packaging import version
 
 from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
+from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT
 from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.constants import TENSORFLOW, FOUND_TF
 from model_compression_toolkit.core.common.user_info import UserInformation
@@ -64,7 +65,8 @@ def get_keras_gptq_config(n_epochs: int,
                           optimizer_rest: OptimizerV2 = tf.keras.optimizers.Adam(learning_rate=LR_REST_DEFAULT),
                           loss: Callable = GPTQMultipleTensorsLoss(),
                           log_function: Callable = None,
-                          use_hessian_based_weights: bool = True) -> GradientPTQConfigV2:
+                          use_hessian_based_weights: bool = True,
+                          regularization_factor: float = REG_DEFAULT) -> GradientPTQConfigV2:
     """
     Create a GradientPTQConfigV2 instance for Keras models.
@@ -75,6 +77,7 @@ def get_keras_gptq_config(n_epochs: int,
         loss (Callable): loss to use during fine-tuning. should accept 4 lists of tensors. 1st list of quantized tensors, the 2nd list is the float tensors, the 3rd is a list of quantized weights and the 4th is a list of float weights.
         log_function (Callable): Function to log information about the gptq process.
         use_hessian_based_weights (bool): Whether to use Hessian-based weights for weighted average loss.
+        regularization_factor (float): A floating point number that defines the regularization factor.
     returns:
         a GradientPTQConfigV2 object to use when fine-tuning the quantized model using gptq.
@@ -106,7 +109,8 @@ def get_keras_gptq_config(n_epochs: int,
                                log_function=log_function,
                                train_bias=True,
                                optimizer_bias=bias_optimizer,
-                               use_hessian_based_weights=use_hessian_based_weights)
+                               use_hessian_based_weights=use_hessian_based_weights,
+                               regularization_factor=regularization_factor)
 
 
 def keras_gradient_post_training_quantization_experimental(in_model: Model,
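For quick reference, a minimal usage sketch of the extended Keras API; it mirrors the new test case added below, and the 0.001 value is purely illustrative, not a recommended setting:

import tensorflow as tf
from model_compression_toolkit.gptq.keras.quantization_facade import get_keras_gptq_config

# Build a GPTQ config, overriding the default regularization factor
# (REG_DEFAULT) with an explicit value.
gptq_config = get_keras_gptq_config(n_epochs=1,
                                    optimizer=tf.keras.optimizers.Adam(),
                                    regularization_factor=0.001)

The resulting GradientPTQConfigV2 is then consumed by keras_gradient_post_training_quantization_experimental as before; omitting the new argument keeps the REG_DEFAULT behavior.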
9 changes: 7 additions & 2 deletions model_compression_toolkit/gptq/pytorch/quantization_facade.py
@@ -16,6 +16,7 @@
 from model_compression_toolkit.core import common
 from model_compression_toolkit.constants import FOUND_TORCH
 from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
+from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT
 from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.constants import PYTORCH
 from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfigV2
@@ -55,7 +56,8 @@ def get_pytorch_gptq_config(n_epochs: int,
                            optimizer_rest: Optimizer = Adam([torch.Tensor([])], lr=LR_REST_DEFAULT),
                            loss: Callable = multiple_tensors_mse_loss,
                            log_function: Callable = None,
-                           use_hessian_based_weights: bool = True) -> GradientPTQConfigV2:
+                           use_hessian_based_weights: bool = True,
+                           regularization_factor: float = REG_DEFAULT) -> GradientPTQConfigV2:
     """
     Create a GradientPTQConfigV2 instance for Pytorch models.
@@ -66,6 +68,7 @@ def get_pytorch_gptq_config(n_epochs: int,
         loss (Callable): loss to use during fine-tuning. should accept 4 lists of tensors. 1st list of quantized tensors, the 2nd list is the float tensors, the 3rd is a list of quantized weights and the 4th is a list of float weights.
         log_function (Callable): Function to log information about the gptq process.
         use_hessian_based_weights (bool): Whether to use Hessian-based weights for weighted average loss.
+        regularization_factor (float): A floating point number that defines the regularization factor.
     returns:
         a GradientPTQConfigV2 object to use when fine-tuning the quantized model using gptq.
@@ -87,7 +90,9 @@ def get_pytorch_gptq_config(n_epochs: int,
"""
bias_optimizer = torch.optim.SGD([torch.Tensor([])], lr=LR_BIAS_DEFAULT, momentum=GPTQ_MOMENTUM)
return GradientPTQConfigV2(n_epochs, optimizer, optimizer_rest=optimizer_rest, loss=loss,
log_function=log_function, train_bias=True, optimizer_bias=bias_optimizer, use_hessian_based_weights=use_hessian_based_weights)
log_function=log_function, train_bias=True, optimizer_bias=bias_optimizer,
use_hessian_based_weights=use_hessian_based_weights,
regularization_factor=regularization_factor)


def pytorch_gradient_post_training_quantization_experimental(model: Module,
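And the matching sketch for the PyTorch API, again mirroring the updated test below; the Adam optimizer over a dummy tensor follows the pattern this facade already uses for its defaults:

import torch
from model_compression_toolkit.gptq.pytorch.quantization_facade import get_pytorch_gptq_config

# Same idea on the PyTorch side: pass the regularization factor
# explicitly instead of relying on REG_DEFAULT.
gptq_config = get_pytorch_gptq_config(n_epochs=1,
                                      optimizer=torch.optim.Adam([torch.Tensor([])], lr=1e-4),
                                      regularization_factor=0.001)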
6 changes: 5 additions & 1 deletion tests/keras_tests/function_tests/test_get_gptq_config.py
@@ -108,7 +108,11 @@ def setUp(self):
                               rounding_type=RoundingType.STE,
                               gptq_quantizer_params_override={MAX_LSB_STR: DefaultDict({}, 1)}),
         get_keras_gptq_config(n_epochs=1,
-                              optimizer=tf.keras.optimizers.Adam())]
+                              optimizer=tf.keras.optimizers.Adam()),
+        get_keras_gptq_config(n_epochs=1,
+                              optimizer=tf.keras.optimizers.Adam(),
+                              regularization_factor=0.001)]
+
 
         pot_tp = generate_test_tp_model({'weights_quantization_method': QuantizationMethod.POWER_OF_TWO})
         self.pot_weights_tpc = generate_keras_tpc(name="gptq_pot_config_test", tp_model=pot_tp)
3 changes: 2 additions & 1 deletion tests/pytorch_tests/function_tests/get_gptq_config_test.py
@@ -72,7 +72,8 @@ def run_test(self):
         cc = CoreConfig(quantization_config=qc)
 
         gptqv2_config = get_pytorch_gptq_config(n_epochs=1,
-                                                optimizer=torch.optim.Adam([torch.Tensor([])], lr=1e-4))
+                                                optimizer=torch.optim.Adam([torch.Tensor([])], lr=1e-4),
+                                                regularization_factor=0.001)
         gptqv2_config.rounding_type = self.rounding_type
         gptqv2_config.train_bias = self.train_bias
 
