From a310905898505f2c9f4f4d23024edff070b3b80f Mon Sep 17 00:00:00 2001 From: Hitarth Mehta Date: Tue, 31 Oct 2023 10:45:04 +0530 Subject: [PATCH] Add sequential MSE implementation (#2538) * Add sequential MSE implementation Signed-off-by: Hitarth Mehta * Refactor and reuse optimized activation sampling Signed-off-by: Hitarth Mehta --------- Signed-off-by: Hitarth Mehta --- .../adaround/activation_sampler.py | 127 ++++- .../aimet_torch/adaround/adaround_weight.py | 147 ++--- .../torch/src/python/aimet_torch/seq_mse.py | 504 ++++++++++++++++++ .../torch/test/python/test_seq_mse.py | 223 ++++++++ 4 files changed, 904 insertions(+), 97 deletions(-) create mode 100644 TrainingExtensions/torch/src/python/aimet_torch/seq_mse.py create mode 100644 TrainingExtensions/torch/test/python/test_seq_mse.py diff --git a/TrainingExtensions/torch/src/python/aimet_torch/adaround/activation_sampler.py b/TrainingExtensions/torch/src/python/aimet_torch/adaround/activation_sampler.py index a19b4d70a5b..484eb61be21 100644 --- a/TrainingExtensions/torch/src/python/aimet_torch/adaround/activation_sampler.py +++ b/TrainingExtensions/torch/src/python/aimet_torch/adaround/activation_sampler.py @@ -38,18 +38,138 @@ """ Sample input to quantized wrapper module and output from original module for Adaround feature """ -from typing import Tuple, Union, List, Callable, Any +from typing import Tuple, Union, List, Callable, Any, Dict import torch from torch.utils.data import Dataset # Import AIMET specific modules from aimet_common.utils import AimetLogger -from aimet_torch.utils import ModuleData +from aimet_torch.utils import CachedDataset, ModuleData, get_named_module, cache_intermediate_datasets,\ + change_tensor_device_placement, in_eval_mode, save_to_cache from aimet_torch.qc_quantize_op import QcQuantizeWrapper +from aimet_torch.quantsim import QuantizationSimModel logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.Quant) +def create_modulelist_for_group_modules(model: torch.nn.Module, sim: QuantizationSimModel, grouped_modules: Dict)\ + -> Tuple[List[torch.nn.ModuleList], List[torch.nn.ModuleList]]: + """ + Use torch.nn.ModuleList to group modules from a single block. + + :param model: FP32 model + :param sim: QuantizationSimModel object + :param grouped_modules: Group modules + :return: List of modulelist for FP32 and quant models + """ + sub_fp_models = [] + sub_sim_models = [] + for _, modules in grouped_modules.items(): + fp_modulelist = torch.nn.ModuleList() + quant_modulelist = torch.nn.ModuleList() + for name in modules: + fp_modulelist.append(get_named_module(model, name)) + quant_modulelist.append(get_named_module(sim.model, name)) + sub_fp_models.append(fp_modulelist) + sub_sim_models.append(quant_modulelist) + + return sub_fp_models, sub_sim_models + + +def get_block_inputs(model: torch.nn.Module, sim: QuantizationSimModel, + breakpoint_module_name: str, cached_dataset: CachedDataset, + cache_on_cpu: bool, forward_fn: Callable, num_batches: int, working_dir: str)\ + -> Union[Tuple[List, List], Tuple[CachedDataset, CachedDataset]]: + """ + Get inputs to block/module from FP32 and QuantizationSimModel models + + :param model: FP32 model + :param sim: QuantizationSimModel object + :param breakpoint_module_name: Breakpoint block/module name + :param cached_dataset: Cached dataset + :param cache_on_cpu: Whether to cache intermediate data on CPU or store to disk + :param forward_fn: adapter function that performs forward pass given a model and inputs + yielded from the data loader. The function expects model as first argument and inputs to model + as second argument. + :param num_batches: Number of batches + :param working_dir: Working to directory to save block inputs data to disk + :return: Inputs to block from FP32 and QuantizationSimModel models + """ + # Cache input data to first block from both FP32 and quant models + if cache_on_cpu: + cached_fp_dataset = cache_intermediate_datasets(cached_dataset, cache_on_cpu, model, + breakpoint_module_name, forward_fn) + cached_quant_dataset = cache_intermediate_datasets(cached_dataset, cache_on_cpu, + sim.model, breakpoint_module_name, forward_fn) + else: + fp32_cache_path = working_dir + 'fp32/' + quant_cache_path = working_dir + 'quant/' + cache_intermediate_datasets(cached_dataset, cache_on_cpu, model, breakpoint_module_name, + forward_fn, fp32_cache_path) + cache_intermediate_datasets(cached_dataset, cache_on_cpu, sim.model, breakpoint_module_name, + forward_fn, quant_cache_path) + cached_fp_dataset = CachedDataset(None, num_batches, fp32_cache_path) + cached_quant_dataset = CachedDataset(None, num_batches, quant_cache_path) + return cached_fp_dataset, cached_quant_dataset + + +def get_block_outputs(fp_block: torch.nn.ModuleList, quant_block: torch.nn.ModuleList, include_static_inputs: str, + cached_fp_dataset: List, cached_quant_dataset: List, + cache_on_cpu: bool, forward_fn: Callable, device: torch.device, working_dir: str): + """ + Get outputs from block/module from FP32 and QuantizationSimModel models and assign for next block/module. + + NOTE: "static_inputs" (like attention_mask, position_ids) remains the same across different blocks. + So, if "include_static_inputs" is set to True, then such inputs are reused. + + :param fp_block: ModuleList for fp32 modules + :param quant_block: ModuleList for quant modules + :param include_static_inputs: Flag to include "static_inputs" or not + :param cached_fp_dataset: Cached dataset for fp32 model + :param cached_quant_dataset: Cached dataset for quant model + :param cache_on_cpu: Whether to cache intermediate data on CPU or store to disk + :param forward_fn: Optional adapter function that performs forward pass given a model and inputs + yielded from the data loader. The function expects model as first argument and inputs to model as second argument. + :param device: torch device + :param working_dir: Working to directory to save block inputs data to disk + """ + # pylint: disable=too-many-locals, too-many-arguments + fp_block.to(device) + quant_block.to(device) + + fp_iterator = iter(cached_fp_dataset) + quant_iterator = iter(cached_quant_dataset) + for idx in range(len(cached_fp_dataset)): # pylint: disable=consider-using-enumerate + fp_inputs = change_tensor_device_placement(next(fp_iterator), device) + quant_inputs = change_tensor_device_placement(next(quant_iterator), device) + + with in_eval_mode(fp_block), in_eval_mode(quant_block), torch.no_grad(): + fp_outputs = forward_fn(fp_block, fp_inputs) + fp_outputs = fp_outputs[0].cpu() if isinstance(fp_outputs, (tuple, list)) else fp_outputs.cpu() + quant_outputs = forward_fn(quant_block, quant_inputs) + quant_outputs = quant_outputs[0].cpu() if isinstance(quant_outputs, (tuple, list)) else quant_outputs.cpu() + + # Check if the next ModuleList needs static inputs or not and assign + # the outputs (fp32/quant) from current block to be the input (fp32/quant) of next block + if include_static_inputs == "True": + fp_inputs[0], quant_inputs[0] = fp_outputs, quant_outputs + else: + fp_inputs, quant_inputs = [fp_outputs], [quant_outputs] + + # Cache the outputs on CPU or disk + if cache_on_cpu: + cached_fp_dataset[idx] = fp_inputs + cached_quant_dataset[idx] = quant_inputs + else: + fp32_cache_path = working_dir + 'fp32/' + quant_cache_path = working_dir + 'quant/' + save_to_cache(fp_inputs, fp32_cache_path, idx) + save_to_cache(quant_inputs, quant_cache_path, idx) + + fp_block.cpu() + quant_block.cpu() + + class ActivationSampler: """ For a module in the original model and the corresponding module in the weight quantized QuantSim model, @@ -73,7 +193,8 @@ def __init__(self, orig_module: torch.nn.Module, quant_module: QcQuantizeWrapper self._orig_module_collector = ModuleData(orig_model, orig_module, forward_fn) self._quant_module_collector = ModuleData(quant_model, quant_module, forward_fn) - def sample_and_place_all_acts_on_cpu(self, cached_dataset: Dataset, cached_quant_dataset: Dataset = None) -> Tuple[torch.Tensor, torch.Tensor]: + def sample_and_place_all_acts_on_cpu(self, cached_dataset: Dataset, + cached_quant_dataset: Dataset = None) -> Tuple[torch.Tensor, torch.Tensor]: """ From the original module, collect output activations and input activations to corresponding quantized module. diff --git a/TrainingExtensions/torch/src/python/aimet_torch/adaround/adaround_weight.py b/TrainingExtensions/torch/src/python/aimet_torch/adaround/adaround_weight.py index 60fa0ed2640..e9ae5e80536 100644 --- a/TrainingExtensions/torch/src/python/aimet_torch/adaround/adaround_weight.py +++ b/TrainingExtensions/torch/src/python/aimet_torch/adaround/adaround_weight.py @@ -43,7 +43,7 @@ import itertools import json import shutil -from typing import Tuple, Union, Dict, List, Callable, Any +from typing import Tuple, Union, Dict, List, Callable, Any, Optional import torch from torch.utils.data import DataLoader from tqdm import tqdm @@ -61,6 +61,8 @@ from aimet_torch.adaround.adaround_tensor_quantizer import AdaroundTensorQuantizer from aimet_torch.adaround.adaround_optimizer import AdaroundOptimizer from aimet_torch.adaround.adaround_loss import AdaroundHyperParameters +from aimet_torch.adaround.activation_sampler import create_modulelist_for_group_modules, get_block_inputs,\ + get_block_outputs logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.Quant) @@ -195,17 +197,22 @@ def _apply_adaround(cls, quant_sim: QuantizationSimModel, model: torch.nn.Module @classmethod def _adaround_model(cls, model: torch.nn.Module, quant_sim: QuantizationSimModel, module_act_func_pair: Dict, - params: AdaroundParameters, dummy_input: Union[torch.Tensor, Tuple], checkpoints_config: str = None): + params: AdaroundParameters, dummy_input: Union[torch.Tensor, Tuple], + checkpoints_config: str = None): """ Optimize weight rounding of every module (AdaroundSupportedModules) of model in sequential manner based on occurrence + + NOTE: When checkpoints_config file is provided, assumption is that the outputs from previous group modules (block) + should feed directly into next group modules (block) + :param model: Original fp32 model from which quant_sim was created. :param quant_sim: QuantizationSimModel object to optimize weight rounding. The activation quantizers are expected to have been disabled. :param module_act_func_pair: Dictionary of module to immediate following activation function :param params: Adaround parameters :param dummy_input: Dummy input to the model - :param checkpoints_config: Config files to split fp32/quant model by checkpoints + :param checkpoints_config: Config files to split fp32/quant model by checkpoints to speedup activations sampling """ # pylint: disable=too-many-locals, protected-access, too-many-branches, too-many-statements @@ -222,7 +229,6 @@ def _adaround_model(cls, model: torch.nn.Module, quant_sim: QuantizationSimModel num_iterations = 15000 else: num_iterations = 10000 - try: # Cache model input data to WORKING_DIR cached_dataset = utils.CachedDataset(params.data_loader, params.num_batches, WORKING_DIR) @@ -233,8 +239,28 @@ def _adaround_model(cls, model: torch.nn.Module, quant_sim: QuantizationSimModel # AdaRound must be applied to modules in the order of occurrence if checkpoints_config: + # Load the predefined json file for checkpoints info + ckpts_file = json.load(open(checkpoints_config)) + assert 'grouped_modules' in ckpts_file.keys(),\ + "Please provide a dictionary of grouped_modules in the file to define checkpoints" + assert 'include_static_inputs' in ckpts_file.keys(),\ + "Please provide a dictionary of include_static_inputs in the file to define checkpoints" + assert 'cache_on_cpu' in ckpts_file.keys(),\ + "Please define cache_on_cpu to determine whether to cache intermediate tensors on CPU" + + grouped_modules = ckpts_file['grouped_modules'] + breakpoint_module_name = ckpts_file['grouped_modules'][list(grouped_modules.keys())[0]][0] + include_static_inputs = ckpts_file['include_static_inputs'] + cache_on_cpu = ckpts_file['cache_on_cpu'] + cached_fp_dataset, cached_quant_dataset = get_block_inputs(model, quant_sim, + breakpoint_module_name, + cached_dataset, cache_on_cpu, + params.forward_fn, params.num_batches, + WORKING_DIR) # Get the device of model to latter be used to place input tensor on the same device device = utils.get_device(model) + model.cpu() + quant_sim.model.cpu() # Forward function for the ModuleList object def fwd_mod_ls(mod_ls, x): @@ -242,96 +268,25 @@ def fwd_mod_ls(mod_ls, x): x = params.forward_fn(mod, x) return x - fp32_cache_path = WORKING_DIR+'fp32/' - quant_cache_path = WORKING_DIR+'quant/' + sub_fp_models, sub_sim_models = create_modulelist_for_group_modules(model, quant_sim, grouped_modules) + for i, (fp_block, quant_sim_block, static_input) in enumerate(zip(sub_fp_models, + sub_sim_models, + include_static_inputs)): + modules = utils.get_ordered_list_of_modules(fp_block, cached_fp_dataset[0], fwd_mod_ls) + cls._run_adaround_model(modules, fp_block, quant_sim_block, + module_act_func_pair, opt_params, + fwd_mod_ls, + cached_fp_dataset, cached_quant_dataset) + + # Get the outputs from the current block and assign to be the inputs for next block + # except for the last block + if i < len(sub_fp_models) - 1: + get_block_outputs(fp_block, quant_sim_block, static_input, + cached_fp_dataset, cached_quant_dataset, cache_on_cpu, + fwd_mod_ls, device, WORKING_DIR) - # Load the predefined json file for checkpoints info - ckpts_file = json.load(open(checkpoints_config)) - assert 'grouped_modules' in ckpts_file.keys(), "Please provide a dictionary of grouped_modules in the file to define checkpoints" - assert 'include_static_inputs' in ckpts_file.keys(), "Please provide a dictionary of include_static_inputs in the file to define checkpoints" - assert 'cache_on_cpu' in ckpts_file.keys(), "Please define cache_on_cpu to determine whether to cache intermediate tensors on CPU" - - grouped_modules_dict = ckpts_file['grouped_modules'] - break_point = ckpts_file['grouped_modules'][list(grouped_modules_dict.keys())[0]][0] - include_static_inputs = ckpts_file['include_static_inputs'] - cache_on_cpu = ckpts_file['cache_on_cpu'] - - # Cache input data for both fp and quant model - if cache_on_cpu: - cached_fp_dataset = utils.cache_intermediate_datasets(cached_dataset, cache_on_cpu, model, - break_point, params.forward_fn) - cached_quant_dataset = utils.cache_intermediate_datasets(cached_dataset, cache_on_cpu, - quant_sim.model, break_point, - params.forward_fn) - else: - utils.cache_intermediate_datasets(cached_dataset, cache_on_cpu, model, break_point, - params.forward_fn, fp32_cache_path) - utils.cache_intermediate_datasets(cached_dataset, cache_on_cpu, quant_sim.model, break_point, - params.forward_fn, quant_cache_path) - cached_fp_dataset = utils.CachedDataset(None, params.num_batches, fp32_cache_path) - cached_quant_dataset = utils.CachedDataset(None, params.num_batches, quant_cache_path) - - # Place fp32/quant model to cpu to save the memory usage of GPU - model.cpu() - quant_sim.model.cpu() - - # Use torch.nn.ModuleList to group modules - sub_fp_models = [] - sub_sim_models = [] - for _, modules in grouped_modules_dict.items(): - fp_mod_ls = torch.nn.ModuleList() - quant_mod_ls = torch.nn.ModuleList() - for name in modules: - fp_mod_ls.append(utils.get_named_module(model, name)) - quant_mod_ls.append(utils.get_named_module(quant_sim.model, name)) - sub_fp_models.append(fp_mod_ls) - sub_sim_models.append(quant_mod_ls) - - for n, (fp_model, sim_model, include_static_input) in enumerate(zip(sub_fp_models, sub_sim_models, include_static_inputs)): - # Place sub fp32/quant model to the device - fp_model.to(device) - sim_model.to(device) - - modules = utils.get_ordered_list_of_modules(fp_model, cached_fp_dataset[0], fwd_mod_ls) - cls._run_adaround_model(modules, fp_model, sim_model, module_act_func_pair, opt_params, fwd_mod_ls, cached_fp_dataset, cached_quant_dataset) - - if n < len(sub_fp_models) - 1: - # Cache the outputs of current sub fp32/quant model to be the input of next sub fp32/quant model - fp_iterator = iter(cached_fp_dataset) - quant_iterator = iter(cached_quant_dataset) - # pylint: disable=consider-using-enumerate - for idx in range(len(cached_fp_dataset)): - # Place the input tensors on the same device as sub fp32/quant model - fp_data = utils.change_tensor_device_placement(next(fp_iterator), device) - quant_data = utils.change_tensor_device_placement(next(quant_iterator), device) - with utils.in_eval_mode(fp_model), utils.in_eval_mode(sim_model), torch.no_grad(): - fp_output = fwd_mod_ls(fp_model, fp_data) - fp_output = fp_output[0].cpu() if isinstance(fp_output, (tuple, list)) else fp_output.cpu() - quant_output = fwd_mod_ls(sim_model, quant_data) - quant_output = quant_output[0].cpu() if isinstance(quant_output, (tuple, list)) else quant_output.cpu() - - # Check if the next ModuleList needs static inputs or not - if include_static_input == "True": - fp_data[0] = fp_output - quant_data[0] = quant_output - else: - fp_data = [fp_output] - quant_data = [quant_output] - - # Cache the outputs on CPU or disk - if cache_on_cpu: - cached_fp_dataset[idx] = fp_data - cached_quant_dataset[idx] = quant_data - else: - utils.save_to_cache(fp_data, fp32_cache_path, idx) - utils.save_to_cache(quant_data, quant_cache_path, idx) - - # Place sub fp32/quant model to cpu - fp_model.cpu() - sim_model.cpu() # After finishing Adaround, placing the quant model back to its original device quant_sim.model.to(device) - else: modules = utils.get_ordered_list_of_modules(model, dummy_input) cls._run_adaround_model(modules, model, quant_sim.model, module_act_func_pair, opt_params, @@ -342,10 +297,14 @@ def fwd_mod_ls(mod_ls, x): shutil.rmtree(WORKING_DIR) @classmethod - def _run_adaround_model(cls, modules, model, quant_sim_model, module_act_func_pair, opt_params, forward_fn, - cached_dataset, cached_quant_dataset=None): + def _run_adaround_model(cls, modules: List, model: torch.nn.Module, quant_sim_model: torch.nn.Module, + module_act_func_pair: Dict, opt_params: AdaroundHyperParameters, forward_fn: Callable, + cached_dataset: utils.CachedDataset, + cached_quant_dataset: Optional[utils.CachedDataset] = None): """ - Iterate through all modules to find out Adaround supported modules and apply Adaround optimization to those modules + Iterate through all modules to find out Adaround supported modules and + apply Adaround optimization to those modules + :param modules: Candidate modules :param model: Original fp32 model :param quant_sim_model: QuantSim model diff --git a/TrainingExtensions/torch/src/python/aimet_torch/seq_mse.py b/TrainingExtensions/torch/src/python/aimet_torch/seq_mse.py new file mode 100644 index 00000000000..7483ef799e5 --- /dev/null +++ b/TrainingExtensions/torch/src/python/aimet_torch/seq_mse.py @@ -0,0 +1,504 @@ +# /usr/bin/env python +# -*- mode: python -*- +# ============================================================================= +# @@-COPYRIGHT-START-@@ +# +# Copyright (c) 2023, Qualcomm Innovation Center, Inc. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. +# +# SPDX-License-Identifier: BSD-3-Clause +# +# @@-COPYRIGHT-END-@@ +# ============================================================================= + +""" Sequential MSE implementation """ + +import json +import os +import tempfile +from dataclasses import dataclass +from typing import Optional, Union, Tuple, List, Callable +import torch +import torch.nn.functional as functional +from torch.utils.data import DataLoader + +from aimet_common.defs import QuantScheme +import aimet_common.libpymo as libpymo +from aimet_torch.utils import CachedDataset, get_ordered_list_of_modules, in_eval_mode, StopForwardException,\ + change_tensor_device_placement, get_device +from aimet_torch.adaround.activation_sampler import create_modulelist_for_group_modules,\ + get_block_inputs, get_block_outputs +from aimet_torch.qc_quantize_op import QcQuantizeWrapper, QcQuantizeOpMode +from aimet_torch.tensor_quantizer import TensorQuantizer, StaticGridPerTensorQuantizer, StaticGridPerChannelQuantizer +from aimet_torch.quantsim import QuantizationSimModel + +# The following modules with weights are supported +SUPPORTED_MODULES = (torch.nn.Linear, ) + + +def default_forward_fn(model, inputs): + """ + Default forward function. + :param model: pytorch model + :param inputs: model inputs + """ + if isinstance(inputs, torch.Tensor): + inputs = [inputs] + return model(*inputs) + + +@dataclass +class SeqMseParams: + """ + Sequential MSE parameters + + :param num_batches: Number of batches. + :param num_candidates: Number of candidates to perform grid search. Default 20. + :param inp_symmetry: Input symmetry. Default 'symqt'. + :param loss_fn: Loss function. Default 'mse'. + :param forward_fn: Optional adapter function that performs forward pass given a model and inputs + yielded from the data loader. The function expects model as first argument and inputs to model as second argument. + """ + num_batches: int + num_candidates: int = 20 + inp_symmetry: str = 'symqt' + loss_fn: str = 'mse' + forward_fn: Callable = default_forward_fn + + +def apply_seq_mse(model: torch.nn.Module, + sim: QuantizationSimModel, + data_loader: DataLoader, + params: SeqMseParams, + modules_to_exclude: Optional[List[torch.nn.Module]] = None, + module_classes_to_exclude: Optional[List[torch.nn.Module]] = None, + checkpoints_config: Optional[str] = None): + """ + Apply sequential MSE - find and freze optimal parameter encodings candidate + 1 Disable all input/output quantizers, param quantizers from exclusion list + 2 Find and feeze optimal parameter encodings candidate for remaining supported modules + 3 Re-enable disabled quantizers from step 1 + + NOTE: module reference(s) passed to module_to_exclude list should be from sim.model. + + :param model: Original fp32 model + :param sim: Corresponding QuantizationSimModel object + :param data_loader: Data loader + :param params: Sequential MSE parameters + :param modules_to_exclude: List of supported modules to exclude when applying Sequential MSE + :param module_classes_to_exclude: List of supported module classes to exclude when applying Sequential MSE + :param checkpoints_config: Config files to split fp32/quant model by checkpoints to speedup activations sampling + """ + # pylint: disable=protected-access + assert sim._quant_scheme == QuantScheme.post_training_tf, "Use TF quant-scheme with sequential MSE." + + # disable all input/output activation quantizers and + # parameter quantizers corresponding to modules from exclusion list + quantizers = get_quantizers_to_be_disabled(sim, modules_to_exclude, module_classes_to_exclude) + enable_disable_quantizers(quantizers, enabled=False) + + # Initialize all remaining parameters' encodings + compute_all_param_encodings(sim) + + # Find and freeze optimal parameter encodings candidate + with tempfile.TemporaryDirectory() as tempdir: + cached_dataset = CachedDataset(data_loader, params.num_batches, os.path.join(tempdir, 'cached_dataset')) + if checkpoints_config: + apply_seq_mse_using_opt_sampling(checkpoints_config, model, sim, cached_dataset, params, tempdir) + else: + dummy_input = change_tensor_device_placement(next(iter(data_loader)), get_device(model)) + fp32_modules = get_ordered_list_of_modules(model, dummy_input) + fp32_modules = [(name, module) for name, module in fp32_modules if isinstance(module, SUPPORTED_MODULES)] + run_seq_mse(fp32_modules, model, sim.model, params, params.forward_fn, + cached_dataset, None) + + # re-enable disabled quantizers + enable_disable_quantizers(quantizers, enabled=True) + + +def apply_seq_mse_using_opt_sampling(checkpoints_config: str, + model: torch.nn.Module, + sim: QuantizationSimModel, + cached_dataset: CachedDataset, + params: SeqMseParams, + tempdir: str): + """ + Apply sequential MSE using optimized sampling of intermediate data. When checkpoints_config file is provided, + intermediate activations from breakpoint are treated as model inputs for next blocks. + + NOTE: Assumption is that the outputs from the current block are fed directly to following block + and there are no funciotnal operations in-between. + + :param checkpoints_config: Config files to split fp32/quant model by checkpoints to speedup activations sampling + :param model: Original fp32 model + :param sim: Corresponding QuantizationSimModel object + :param cached_dataset: Cached dataset + :param params: Sequential MSE parameters + :param tempdir: temporary working directory + """ + # pylint: disable=too-many-locals + ckpts_file = json.load(open(checkpoints_config)) + assert 'grouped_modules' in ckpts_file.keys(), \ + "Please provide a dictionary of grouped_modules in the file to define checkpoints" + assert 'include_static_inputs' in ckpts_file.keys(), \ + "Please provide a dictionary of include_static_inputs in the file to define checkpoints" + assert 'cache_on_cpu' in ckpts_file.keys(), \ + "Please define cache_on_cpu to determine whether to cache intermediate tensors on CPU" + + grouped_modules = ckpts_file['grouped_modules'] + breakpoint_module_name = ckpts_file['grouped_modules'][list(grouped_modules.keys())[0]][0] + include_static_inputs = ckpts_file['include_static_inputs'] + cache_on_cpu = ckpts_file['cache_on_cpu'] + cached_fp_dataset, cached_quant_dataset = get_block_inputs(model, sim, + breakpoint_module_name, + cached_dataset, cache_on_cpu, + params.forward_fn, params.num_batches, + tempdir) + # Get the device of model to latter be used to place input tensor on the same device + device = get_device(model) + model.cpu() + sim.model.cpu() + + # Forward function for the ModuleList object + def fwd_fn_modulelist(modulelists, x): + for mod in modulelists: + x = mod(*x) if isinstance(x, (tuple, list)) else mod(x) + return x + + sub_fp_models, sub_sim_models = create_modulelist_for_group_modules(model, sim, grouped_modules) + for i, (fp_block, quant_sim_block, static_input) in enumerate(zip(sub_fp_models, + sub_sim_models, + include_static_inputs)): + fp32_modules = get_ordered_list_of_modules(fp_block, cached_fp_dataset[0], fwd_fn_modulelist) + fp32_modules = [(name, module) for name, module in fp32_modules if isinstance(module, SUPPORTED_MODULES)] + run_seq_mse(fp32_modules, fp_block, quant_sim_block, params, fwd_fn_modulelist, + cached_fp_dataset, cached_quant_dataset) + + # Get the outputs from the current block and assign to be the inputs for next block + # except for the last block + if i < len(sub_fp_models) - 1: + get_block_outputs(fp_block, quant_sim_block, static_input, + cached_fp_dataset, cached_quant_dataset, cache_on_cpu, + fwd_fn_modulelist, device, tempdir) + sim.model.to(device) + +def run_seq_mse(fp32_modules: List[Tuple[str, torch.nn.Module]], + model: torch.nn.Module, + quant_model: torch.nn.Module, + params: SeqMseParams, + forward_fn: Callable, + cached_fp_dataset: CachedDataset, + cached_quant_dataset: Optional[CachedDataset] = None, + ): + """ + Run Sequential MSE + + :param fp32_modules: List of FP32 candidate modules in order of occurence + :param model: FP32 model + :param quant_model: QuantizationSimModel object + :param params: Sequential MSE parameters + :param forward_fn: Optional adapter function that performs forward pass given a model and inputs + yielded from the data loader. The function expects model as first argument and inputs to model as second argument. + :param cached_fp_dataset: Cached dataset object + :param cached_quant_dataset: Cached dataset object + """ + name_to_quant_module = {} + for name, quant_module in quant_model.named_modules(): + name_to_quant_module[name] = quant_module + + if not cached_quant_dataset: + cached_quant_dataset = cached_fp_dataset + + for module_qualified_name, fp32_module in fp32_modules: + try: + quant_module = name_to_quant_module[module_qualified_name] + except KeyError: + continue + + print("Finding optimal parameter encodings candidate: ", module_qualified_name) + if params.inp_symmetry == "asym": + fp32_inp_acts = get_module_inp_acts(fp32_module, model, params, forward_fn, cached_fp_dataset) + quant_inp_acts = get_module_inp_acts(quant_module, quant_model, params, forward_fn, cached_quant_dataset) + optimize_module(quant_module, fp32_inp_acts, quant_inp_acts, params) + elif params.inp_symmetry == "symfp": + fp32_inp_acts = get_module_inp_acts(fp32_module, model, params, forward_fn, cached_fp_dataset) + optimize_module(quant_module, fp32_inp_acts, fp32_inp_acts, params) + elif params.inp_symmetry == "symqt": + quant_inp_acts = get_module_inp_acts(quant_module, quant_model, params, forward_fn, cached_quant_dataset) + optimize_module(quant_module, quant_inp_acts, quant_inp_acts, params) + else: + raise ValueError(f"Invalid inp_symmetry: {params.inp_symmetry}") + + +def get_module_inp_acts(module: torch.nn.Module, + model: torch.nn.Module, + params: SeqMseParams, + forward_fn: Callable, + cached_dataset: CachedDataset, + ) -> torch.Tensor: + """ + For given module, get inputs to the module. + + :param module: FP32/quant module + :param model: FP32/quant model + :param params: Sequential MSE parameters + :param forward_fn: Optional adapter function that performs forward pass given a model and inputs + yielded from the data loader. The function expects model as first argument and inputs to model as second argument. + :param cached_dataset: Cached dataset + :return: Concatenated inputs + """ + inp_acts = [] + def hook_fn(_, inp, __): + if isinstance(inp, tuple): + inp_acts.append(inp[0]) + raise StopForwardException + handle = module.register_forward_hook(hook_fn) + + iterator = iter(cached_dataset) + for _ in range(params.num_batches): + batch = change_tensor_device_placement(next(iterator), get_device(model)) + try: + with in_eval_mode(model), torch.no_grad(): + forward_fn(model, batch) + except StopForwardException: + pass + handle.remove() + + inp_acts = torch.stack(inp_acts) + return inp_acts + + +def get_quantizers_to_be_disabled(sim: QuantizationSimModel, + modules_to_exclude: Optional[List[torch.nn.Module]], + module_classes_to_exclude: Optional[List[torch.nn.Module]])\ + -> List[TensorQuantizer]: + """ + For given quantsim model, get all quantizers to be disabled before applying sequential MSE. + + :param sim: QuantizationSimModel object + :param modules_to_exclude: List of supported modules to exclude when applying Sequential MSE + :param module_classes_to_exclude: List of supported module classes to exclude when applying Sequential MSE + :return: List of quantizers to be disabled. + """ + # pylint: disable=protected-access + # pylint: disable=unidiomatic-typecheck + quantizers_to_be_disabled = [] + for _, quant_wrapper in sim.quant_wrappers(): + for quantizer in quant_wrapper.input_quantizers: + if quantizer.enabled: + quantizers_to_be_disabled.append(quantizer) + for quantizer in quant_wrapper.output_quantizers: + if quantizer.enabled: + quantizers_to_be_disabled.append(quantizer) + + for _, quant_wrapper in sim.quant_wrappers(): + if modules_to_exclude and quant_wrapper in modules_to_exclude: + for quantizer in quant_wrapper.param_quantizers.values(): + if quantizer.enabled: + quantizers_to_be_disabled.append(quantizer) + if module_classes_to_exclude and type(quant_wrapper._module_to_wrap) in module_classes_to_exclude: + for quantizer in quant_wrapper.param_quantizers.values(): + if quantizer.enabled: + quantizers_to_be_disabled.append(quantizer) + return quantizers_to_be_disabled + + +def enable_disable_quantizers(quantizers: List[TensorQuantizer], enabled: bool): + """ + For given list of quantizers, set (enable/disable) quantizer's 'enabled' attribute. + + :param quantizers: List of quantizers. + :param enabled: Enabled flag. + """ + for quantizer in quantizers: + quantizer.enabled = enabled + + +def compute_all_param_encodings(sim: QuantizationSimModel): + """ + Compute encodings for all parameters, needed for initializing Sequential MSE + + :param sim: Quant sim + """ + for _, quant_wrapper in sim.quant_wrappers(): + for name, quantizer in quant_wrapper.param_quantizers.items(): + quantizer.reset_encoding_stats() + quantizer.update_encoding_stats(getattr(quant_wrapper, name).data) + quantizer.compute_encoding() + + # Wrapper mode must be set to ACTIVE because the wrapper's quantize_dequantize_params() will only call + # into the param tensor quantizer's quantize_dequantize() if the mode isn't PASSTHROUGH. + quant_wrapper.set_mode(QcQuantizeOpMode.ACTIVE) + + +def get_candidates(num_candidates: int, + per_channel_max: torch.Tensor, + per_channel_min: Optional[torch.Tensor]) -> List[Tuple[torch.Tensor, torch.Tensor]]: + """ + Perform grid search. + + :param num_candidates: Number of candidates + :param per_channel_max: Per channel max values + :param per_channel_min: Per channel min values + :return: candidates + """ + candidates = [] + if per_channel_min is not None: + for cand in range(num_candidates): + cand_max = torch.tensor(per_channel_max / num_candidates * (cand + 1)) + cand_min = torch.tensor(per_channel_min / num_candidates * (cand + 1)) + candidates.append((cand_max, cand_min)) + else: + for cand in range(num_candidates): + cand_max = torch.tensor(per_channel_max / num_candidates * (cand + 1)) + cand_min = -cand_max + candidates.append((cand_max, cand_min)) + return candidates + + +def optimize_module(quant_module: QcQuantizeWrapper, + x: torch.Tensor, + xq: torch.Tensor, + params: SeqMseParams): + """ + Find and freeze optimal parameter encodings candidate for given module. + + :param quant_module: Quant module to be optimized + :param x: Inputs to module from FP32 model + :param xq: Inputs to module from QuantSim model + :param params: Sequenial MSE parameters + """ + # pylint: disable=too-many-locals + if quant_module.param_quantizers["weight"].use_symmetric_encodings: + per_channel_max = torch.max(quant_module.weight.abs(), dim=1)[0].detach() + per_channel_min = None + else: + per_channel_max = torch.max(quant_module.weight, dim=1)[0].detach() + per_channel_min = torch.min(quant_module.weight, dim=1)[0].detach() + candidates = get_candidates(params.num_candidates, per_channel_max, per_channel_min) + + total_loss = [] + for cand_max, cand_min in candidates: + compute_param_encodings(quant_module.param_quantizers['weight'], cand_min, cand_max) + w = quant_module.weight + wq = quant_module.param_quantizers['weight'].quantize_dequantize(w, libpymo.RoundingMode.ROUND_NEAREST) + loss = torch.zeros(len(cand_max), device=w.device) + with torch.no_grad(): + for batch_idx in range(params.num_batches): + xqwq, xw = compute_outputs(quant_module, x[batch_idx], xq[batch_idx], w, wq) + loss += compute_recon_loss(xqwq, xw, params) + total_loss.append(loss) + + best_indices = torch.stack(total_loss).min(0, keepdim=True)[1] + print(best_indices.squeeze(0)[:params.num_candidates]) + best_max = torch.stack([cand_max for cand_max, _ in candidates]).gather(0, best_indices)[0] + best_min = torch.stack([cand_min for _, cand_min in candidates]).gather(0, best_indices)[0] + + # Compute and freeze parameter encodings using best candidate + compute_param_encodings(quant_module.param_quantizers['weight'], best_min, best_max) + quant_module.param_quantizers['weight'].freeze_encoding() + + +def compute_param_encodings(quantizer: Union[StaticGridPerTensorQuantizer, StaticGridPerChannelQuantizer], + x_min: torch.Tensor, + x_max: torch.Tensor): + """ + Compute encodings for parameter quantizer using given x_min and x_max values. + + :param quantizer: Tensor quantizer + :param x_min: min values + :param x_max: max values + """ + tensor = torch.stack([x_min, x_max], dim=-1) + quantizer.reset_encoding_stats() + quantizer.update_encoding_stats(tensor) + quantizer.compute_encoding() + + +def compute_outputs(quant_module: QcQuantizeWrapper, + x: torch.Tensor, + xq: torch.Tensor, + w: torch.Tensor, + wq: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Compute X^W^ and XW output acitvations. + + :param quant_module: Wrapper module to be optimized + :param x: Inputs from FP32 model + :param xq: Inputs from QuantSim model + :param w: FP32 weights + :param wq: Quantized-dequantized weights + :return: xqwq, xw + """ + # pylint: disable=protected-access + module = quant_module._module_to_wrap + + if isinstance(module, torch.nn.Linear): + xqwq = functional.linear(xq, wq, module.bias) + xw = functional.linear(x, w, module.bias) + else: + raise ValueError('Unsupported module: ', module) + return xqwq, xw + + +def compute_recon_loss(xqwq: torch.Tensor, xw: torch.Tensor, params: SeqMseParams): + """ + Compute reconsturction loss + + :param xqwq: X^Q^ quantized-dequantized values + :param xw: XW FP32 values + :param params: Sequenial MSE parameters + :return: loss + """ + if params.loss_fn == "mse": + loss_fn = functional.mse_loss + elif params.loss_fn == "l1": + loss_fn = functional.l1_loss + else: + loss_fn = neg_sqnr + loss = loss_fn(xqwq, xw, reduction="none").sum((0, 1)) + return loss + + +def neg_sqnr(pred: torch.Tensor, target: torch.Tensor, eps=1e-10, reduction="none"): + """ + Loss function to minimize negative SQNR which is equivalent to maximizing SQNR. + + :param pred: X^Q^ quantized-dequantized values + :param target: XW FP32 values + :param eps: epsilon + :param reduction: unused arg + :return: Negative SQNR + """ + # pylint: disable=unused-argument + quant_error = target - pred + exp_noise = torch.mean(quant_error ** 2, (0, 1), keepdim=True) + eps + exp_signal = torch.mean(target ** 2, (0, 1), keepdim=True) + sqnr = exp_signal / exp_noise + sqnr_db = 10 * torch.log10(sqnr) + return -sqnr_db diff --git a/TrainingExtensions/torch/test/python/test_seq_mse.py b/TrainingExtensions/torch/test/python/test_seq_mse.py new file mode 100644 index 00000000000..af3791195e2 --- /dev/null +++ b/TrainingExtensions/torch/test/python/test_seq_mse.py @@ -0,0 +1,223 @@ +# /usr/bin/env python +# -*- mode: python -*- +# ============================================================================= +# @@-COPYRIGHT-START-@@ +# +# Copyright (c) 2023, Qualcomm Innovation Center, Inc. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. +# +# SPDX-License-Identifier: BSD-3-Clause +# +# @@-COPYRIGHT-END-@@ +# ============================================================================= + +import json +import pytest +import numpy +import torch +from torch.utils.data import Dataset, DataLoader + +from aimet_torch.utils import create_fake_data_loader +from aimet_torch.quantsim import QuantizationSimModel +from aimet_torch.qc_quantize_op import StaticGridQuantWrapper, QuantScheme +from aimet_torch.seq_mse import apply_seq_mse, get_candidates, optimize_module, SeqMseParams +from models.mnist_torch_model import Net + +@pytest.fixture(scope="session") +def dummy_input(): + return torch.randn((1, 1, 28, 28)) + + +@pytest.fixture(scope="session") +def unlabeled_data_loader(dummy_input): + class MyDataset(Dataset): + def __init__(self, data): + self.data = data + + def __getitem__(self, index): + return self.data[index] + + def __len__(self): + return len(self.data) + + dataset = MyDataset([dummy_input[0, :] for _ in range(32)]) + return DataLoader(dataset) + + +def save_config_file_for_checkpoints(): + checkpoints_config = { + "grouped_modules": { + "0": ["conv1", "bn1", "relu1", "maxpool"], + "1": ["conv2", "bn2", "relu2"], + "2": ["conv3", "relu3", "avgpool"], + "3": ["conv4", "flatten", "fc"], + }, + "include_static_inputs": [ + "False", + "False", + "False", + "False" + ], + "cache_on_cpu": "False" + } + + with open('./test_checkpoints.json', 'w') as f: + json.dump(checkpoints_config, f) + + +class SplittableModel(torch.nn.Module): + """ Use this model for unit testing purposes. Expect input shape (1, 3, 32, 32) """ + def __init__(self): + super(SplittableModel, self).__init__() + self.conv1 = torch.nn.Conv2d(3, 32, kernel_size=2, stride=2, padding=2, bias=False) + self.bn1 = torch.nn.BatchNorm2d(32) + self.relu1 = torch.nn.ReLU(inplace=True) + self.maxpool = torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=1) + self.conv2 = torch.nn.Conv2d(32, 16, kernel_size=2, stride=2, padding=2, bias=False) + self.bn2 = torch.nn.BatchNorm2d(16) + self.relu2 = torch.nn.ReLU(inplace=True) + self.conv3 = torch.nn.Conv2d(16, 8, kernel_size=2, stride=2, padding=2, bias=False) + self.relu3 = torch.nn.ReLU(inplace=True) + self.avgpool = torch.nn.AvgPool2d(3, stride=1) + self.conv4 = torch.nn.Conv2d(8, 4, kernel_size=2, stride=2, padding=2, bias=True) + self.flatten = torch.nn.Flatten() + self.fc = torch.nn.Linear(36, 12) + + def forward(self, *inputs): + x = self.conv1(inputs[0]) + x = self.bn1(x) + x = self.relu1(x) + x = self.maxpool(x) + x = self.conv2(x) + x = self.bn2(x) + x = self.relu2(x) + x = self.conv3(x) + x = self.relu3(x) + x = self.avgpool(x) + x = self.conv4(x) + x = self.flatten(x) + x = self.fc(x) + return x + + +class TestSeqMse: + + def test_seq_mse(self): + """ test get_candidates() """ + torch.manual_seed(0) + linear = torch.nn.Linear(2, 4) + x_max = torch.max(linear.weight.abs(), dim=1)[0] + x_min = None + candidates = get_candidates(20, x_max, x_min) + for cand_max, cand_min in candidates: + assert list(cand_max.size())[0] == linear.out_features + assert list(cand_min.size())[0] == linear.out_features + + @pytest.mark.parametrize("enable_pcq", [True, False]) + @pytest.mark.parametrize("param_bw", [2, 31]) + def test_optimize_module_linear(self, enable_pcq, param_bw): + """ test optimize module for linear """ + torch.manual_seed(0) + linear = torch.nn.Linear(64, 128) + wrapper = StaticGridQuantWrapper(linear, param_bw, 16, 'nearest', QuantScheme.post_training_tf) + wrapper.input_quantizers[0].enabled = False + wrapper.output_quantizers[0].enabled = False + if enable_pcq: + wrapper.enable_per_channel_quantization() + + xq = torch.randn(32, 4, 32, 64) + wrapper.param_quantizers['weight'].reset_encoding_stats() + wrapper.param_quantizers['weight'].update_encoding_stats(wrapper.weight.data) + wrapper.param_quantizers['weight'].compute_encoding() + before = wrapper.param_quantizers['weight'].encoding + params = SeqMseParams(num_batches=32) + optimize_module(wrapper, xq, xq, params) + after = wrapper.param_quantizers['weight'].encoding + + # If we use higher param_bw (for example 16, 31), then it should always choose larger candidates so + # before and after param encodings should be almost same. + if param_bw == 31: + if enable_pcq: + assert numpy.isclose(before[0].min, after[0].min) + assert numpy.isclose(before[0].max, after[0].max) + else: + assert numpy.isclose(before.min, after.min) + assert numpy.isclose(before.max, after.max) + else: + if enable_pcq: + assert not numpy.isclose(before[0].min, after[0].min) + assert not numpy.isclose(before[0].max, after[0].max) + else: + assert not numpy.isclose(before.min, after.min) + assert not numpy.isclose(before.max, after.max) + + @pytest.mark.cuda() + @pytest.mark.parametrize("inp_symmetry", ['asym', 'symfp', 'symqt']) + @pytest.mark.parametrize("loss_fn", ['mse', 'l1', 'aa']) + def test_apply_seq_mse(self, unlabeled_data_loader, inp_symmetry, loss_fn): + """ test apply_seq_mse end-to-end """ + torch.manual_seed(0) + model = Net().eval().cuda() + dummy_input = torch.randn(1, 1, 28, 28).cuda() + sim = QuantizationSimModel(model, dummy_input, default_param_bw=4, quant_scheme=QuantScheme.post_training_tf) + params = SeqMseParams(num_batches=2, inp_symmetry=inp_symmetry, loss_fn=loss_fn) + apply_seq_mse(model, sim, unlabeled_data_loader, params, modules_to_exclude=[sim.model.conv1]) + assert sim.model.fc1.param_quantizers['weight'].is_encoding_frozen + assert sim.model.fc2.param_quantizers['weight'].is_encoding_frozen + assert not sim.model.conv1.param_quantizers['weight'].encoding + assert sim.model.conv2.param_quantizers['weight'].encoding + + @pytest.mark.parametrize("inp_symmetry", ['asym', 'symfp', 'symqt']) + @pytest.mark.parametrize("loss_fn", ['mse', 'l1', 'aa']) + def test_seq_mse_with_and_without_checkpoints_config(self, inp_symmetry, loss_fn): + """ test apply_seq_mse end-to-end with and without checkpoints configs """ + torch.manual_seed(0) + + data_loader = create_fake_data_loader(dataset_size=2, batch_size=1, image_size=(3, 32, 32)) + model = SplittableModel().eval() + save_config_file_for_checkpoints() + dummy_input = torch.randn(1, 3, 32, 32) + sim_without = QuantizationSimModel(model, dummy_input, default_param_bw=4, + quant_scheme=QuantScheme.post_training_tf) + sim_with = QuantizationSimModel(model, dummy_input, default_param_bw=4, + quant_scheme=QuantScheme.post_training_tf) + params = SeqMseParams(num_batches=2, inp_symmetry=inp_symmetry, loss_fn=loss_fn) + + # Apply Sequential MSE without checkpoints config + apply_seq_mse(model, sim_without, data_loader, params) + without_checkpoints_enc = sim_without.model.fc.param_quantizers['weight'].encoding + + # Apply Sequential MSE with checkpoints config + apply_seq_mse(model, sim_with, data_loader, params, checkpoints_config="./test_checkpoints.json") + with_checkpoints_enc = sim_with.model.fc.param_quantizers['weight'].encoding + + # encodings should be bit-exact + assert without_checkpoints_enc.min == with_checkpoints_enc.min + assert without_checkpoints_enc.max == with_checkpoints_enc.max + assert without_checkpoints_enc.delta == with_checkpoints_enc.delta + assert without_checkpoints_enc.offset == with_checkpoints_enc.offset