Skip to content

Commit

Permalink
Add sequential MSE implementation (#2538)
Browse files Browse the repository at this point in the history
* Add sequential MSE implementation

Signed-off-by: Hitarth Mehta <quic_hitameht@quicinc.com>

* Refactor and reuse optimized activation sampling

Signed-off-by: Hitarth Mehta <quic_hitameht@quicinc.com>

---------

Signed-off-by: Hitarth Mehta <quic_hitameht@quicinc.com>
  • Loading branch information
quic-hitameht authored and quic-bharathr committed Sep 13, 2024
1 parent 8b78ad5 commit 4a4ce62
Show file tree
Hide file tree
Showing 4 changed files with 904 additions and 97 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,138 @@

""" Sample input to quantized wrapper module and output from original module for Adaround feature """

from typing import Tuple, Union, List, Callable, Any
from typing import Tuple, Union, List, Callable, Any, Dict
import torch
from torch.utils.data import Dataset

# Import AIMET specific modules
from aimet_common.utils import AimetLogger
from aimet_torch.utils import ModuleData
from aimet_torch.utils import CachedDataset, ModuleData, get_named_module, cache_intermediate_datasets,\
change_tensor_device_placement, in_eval_mode, save_to_cache
from aimet_torch.qc_quantize_op import QcQuantizeWrapper
from aimet_torch.quantsim import QuantizationSimModel

logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.Quant)


def create_modulelist_for_group_modules(model: torch.nn.Module, sim: QuantizationSimModel, grouped_modules: Dict)\
-> Tuple[List[torch.nn.ModuleList], List[torch.nn.ModuleList]]:
"""
Use torch.nn.ModuleList to group modules from a single block.
:param model: FP32 model
:param sim: QuantizationSimModel object
:param grouped_modules: Group modules
:return: List of modulelist for FP32 and quant models
"""
sub_fp_models = []
sub_sim_models = []
for _, modules in grouped_modules.items():
fp_modulelist = torch.nn.ModuleList()
quant_modulelist = torch.nn.ModuleList()
for name in modules:
fp_modulelist.append(get_named_module(model, name))
quant_modulelist.append(get_named_module(sim.model, name))
sub_fp_models.append(fp_modulelist)
sub_sim_models.append(quant_modulelist)

return sub_fp_models, sub_sim_models


def get_block_inputs(model: torch.nn.Module, sim: QuantizationSimModel,
breakpoint_module_name: str, cached_dataset: CachedDataset,
cache_on_cpu: bool, forward_fn: Callable, num_batches: int, working_dir: str)\
-> Union[Tuple[List, List], Tuple[CachedDataset, CachedDataset]]:
"""
Get inputs to block/module from FP32 and QuantizationSimModel models
:param model: FP32 model
:param sim: QuantizationSimModel object
:param breakpoint_module_name: Breakpoint block/module name
:param cached_dataset: Cached dataset
:param cache_on_cpu: Whether to cache intermediate data on CPU or store to disk
:param forward_fn: adapter function that performs forward pass given a model and inputs
yielded from the data loader. The function expects model as first argument and inputs to model
as second argument.
:param num_batches: Number of batches
:param working_dir: Working to directory to save block inputs data to disk
:return: Inputs to block from FP32 and QuantizationSimModel models
"""
# Cache input data to first block from both FP32 and quant models
if cache_on_cpu:
cached_fp_dataset = cache_intermediate_datasets(cached_dataset, cache_on_cpu, model,
breakpoint_module_name, forward_fn)
cached_quant_dataset = cache_intermediate_datasets(cached_dataset, cache_on_cpu,
sim.model, breakpoint_module_name, forward_fn)
else:
fp32_cache_path = working_dir + 'fp32/'
quant_cache_path = working_dir + 'quant/'
cache_intermediate_datasets(cached_dataset, cache_on_cpu, model, breakpoint_module_name,
forward_fn, fp32_cache_path)
cache_intermediate_datasets(cached_dataset, cache_on_cpu, sim.model, breakpoint_module_name,
forward_fn, quant_cache_path)
cached_fp_dataset = CachedDataset(None, num_batches, fp32_cache_path)
cached_quant_dataset = CachedDataset(None, num_batches, quant_cache_path)
return cached_fp_dataset, cached_quant_dataset


def get_block_outputs(fp_block: torch.nn.ModuleList, quant_block: torch.nn.ModuleList, include_static_inputs: str,
cached_fp_dataset: List, cached_quant_dataset: List,
cache_on_cpu: bool, forward_fn: Callable, device: torch.device, working_dir: str):
"""
Get outputs from block/module from FP32 and QuantizationSimModel models and assign for next block/module.
NOTE: "static_inputs" (like attention_mask, position_ids) remains the same across different blocks.
So, if "include_static_inputs" is set to True, then such inputs are reused.
:param fp_block: ModuleList for fp32 modules
:param quant_block: ModuleList for quant modules
:param include_static_inputs: Flag to include "static_inputs" or not
:param cached_fp_dataset: Cached dataset for fp32 model
:param cached_quant_dataset: Cached dataset for quant model
:param cache_on_cpu: Whether to cache intermediate data on CPU or store to disk
:param forward_fn: Optional adapter function that performs forward pass given a model and inputs
yielded from the data loader. The function expects model as first argument and inputs to model as second argument.
:param device: torch device
:param working_dir: Working to directory to save block inputs data to disk
"""
# pylint: disable=too-many-locals, too-many-arguments
fp_block.to(device)
quant_block.to(device)

fp_iterator = iter(cached_fp_dataset)
quant_iterator = iter(cached_quant_dataset)
for idx in range(len(cached_fp_dataset)): # pylint: disable=consider-using-enumerate
fp_inputs = change_tensor_device_placement(next(fp_iterator), device)
quant_inputs = change_tensor_device_placement(next(quant_iterator), device)

with in_eval_mode(fp_block), in_eval_mode(quant_block), torch.no_grad():
fp_outputs = forward_fn(fp_block, fp_inputs)
fp_outputs = fp_outputs[0].cpu() if isinstance(fp_outputs, (tuple, list)) else fp_outputs.cpu()
quant_outputs = forward_fn(quant_block, quant_inputs)
quant_outputs = quant_outputs[0].cpu() if isinstance(quant_outputs, (tuple, list)) else quant_outputs.cpu()

# Check if the next ModuleList needs static inputs or not and assign
# the outputs (fp32/quant) from current block to be the input (fp32/quant) of next block
if include_static_inputs == "True":
fp_inputs[0], quant_inputs[0] = fp_outputs, quant_outputs
else:
fp_inputs, quant_inputs = [fp_outputs], [quant_outputs]

# Cache the outputs on CPU or disk
if cache_on_cpu:
cached_fp_dataset[idx] = fp_inputs
cached_quant_dataset[idx] = quant_inputs
else:
fp32_cache_path = working_dir + 'fp32/'
quant_cache_path = working_dir + 'quant/'
save_to_cache(fp_inputs, fp32_cache_path, idx)
save_to_cache(quant_inputs, quant_cache_path, idx)

fp_block.cpu()
quant_block.cpu()


class ActivationSampler:
"""
For a module in the original model and the corresponding module in the weight quantized QuantSim model,
Expand All @@ -72,7 +192,8 @@ def __init__(self, orig_module: torch.nn.Module, quant_module: QcQuantizeWrapper
self._orig_module_collector = ModuleData(orig_model, orig_module, forward_fn)
self._quant_module_collector = ModuleData(quant_model, quant_module, forward_fn)

def sample_and_place_all_acts_on_cpu(self, cached_dataset: Dataset, cached_quant_dataset: Dataset = None) -> Tuple[torch.Tensor, torch.Tensor]:
def sample_and_place_all_acts_on_cpu(self, cached_dataset: Dataset,
cached_quant_dataset: Dataset = None) -> Tuple[torch.Tensor, torch.Tensor]:
"""
From the original module, collect output activations and input activations
to corresponding quantized module.
Expand Down
Loading

0 comments on commit 4a4ce62

Please sign in to comment.