Fixes for failing models (#2553)
Signed-off-by: Harshita Mangal <quic_mangal@quicinc.com>
quic-mangal authored and quic-bharathr committed Sep 13, 2024
1 parent 8ba882b commit 01f4ced
Showing 3 changed files with 18 additions and 3 deletions.
Changed file 1 of 3
@@ -149,6 +149,11 @@ def _optimize_rounding(cls, module: ModuleInfo, quantized_input_name,
         use_cache_acts_data = TorchAdaroundOptimizer._can_cache_acts_data(len(cached_dataset), inp_data_torch.shape,
                                                                           out_data_torch.shape)
 
+        attributes = read_attributes_for_op(module)
+        if len(attributes['pads']) > 2:
+            logger.info("Skipping the Convolution layer because asymmetric padding is not supported for optimization")
+            return
+
         if use_cache_acts_data and AdaroundOptimizer.enable_caching_acts_data():
             logger.debug("Caching intermediate activations data for optimization.")
             all_inp_data, all_orig_out_data = act_sampler.sample_and_place_all_acts_on_cpu(cached_dataset)
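
For context: `read_attributes_for_op` (third file in this diff) now deduplicates the ONNX `pads` attribute, so a deduped list longer than two entries signals asymmetric padding, which the torch functional calls further down cannot express. A minimal standalone sketch of that condition (plain Python, not AIMET code; the sample pad values are illustrative):

```python
# Illustrative pad values; ONNX 2-D 'pads' are ordered
# [top, left, bottom, right] (all begins, then all ends).
examples = [
    ([1, 1, 1, 1], "symmetric"),
    ([2, 3, 2, 3], "symmetric"),
    ([0, 1, 0, 2], "asymmetric"),
]
for pads, kind in examples:
    deduped = list(dict.fromkeys(pads))  # order-preserving dedup, as in read_attributes_for_op
    skipped = len(deduped) > 2           # the guard added above
    print(pads, "->", deduped, "| skipped:", skipped, f"({kind})")
```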
@@ -266,15 +271,15 @@ def _compute_output_with_adarounded_weights(weights: torch.Tensor, quant_module,
             if 'bias' in quant_module.params:
                 bias = torch.from_numpy(numpy_helper.to_array(quant_module.params['bias'].tensor)).to(device)
             out_data = functional.conv2d(inp_data, adarounded_weights, bias=bias, stride=attributes['strides'],
-                                         dilation=attributes['dilations'], padding=attributes['pads'][0],
+                                         dilation=attributes['dilations'], padding=attributes['pads'],
                                          groups=attributes['group'])
         elif quant_module.type == 'ConvTranspose':
             attributes = read_attributes_for_op(quant_module)
             bias = None
             if 'bias' in quant_module.params:
                 bias = torch.from_numpy(numpy_helper.to_array(quant_module.params['bias'].tensor)).to(device)
             out_data = functional.conv_transpose2d(inp_data, adarounded_weights, bias=bias, stride=attributes['strides'],
-                                                   dilation=attributes['dilations'], padding=attributes['pads'][0],
+                                                   dilation=attributes['dilations'], padding=attributes['pads'],
                                                    groups=attributes['group'])
         elif quant_module.type in ['Gemm', 'MatMul']:
             bias = torch.from_numpy(numpy_helper.to_array(quant_module.params['bias'].tensor)).to(device)
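
Note on the change above: the old code forwarded only `pads[0]`, silently applying the first pad value to both spatial dimensions; the deduplicated list now reaches torch intact, and `functional.conv2d` accepts either a one-element sequence (expanded to both dims) or a `(padH, padW)` pair. A hedged sketch with made-up tensor shapes:

```python
import torch
import torch.nn.functional as functional

inp = torch.randn(1, 3, 8, 8)      # NCHW; shapes are arbitrary
weight = torch.randn(4, 3, 3, 3)   # out_ch, in_ch, kH, kW

# ONNX pads [1, 1, 1, 1] dedups to [1]; torch expands it to both dims.
out_a = functional.conv2d(inp, weight, padding=[1])
# ONNX pads [2, 3, 2, 3] dedups to [2, 3], i.e. (padH, padW).
out_b = functional.conv2d(inp, weight, padding=[2, 3])
print(out_a.shape)  # torch.Size([1, 4, 8, 8])
print(out_b.shape)  # torch.Size([1, 4, 10, 12])
```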
Changed file 2 of 3
@@ -43,6 +43,7 @@
 import json
 from typing import Tuple, Dict, List, Callable
 from onnx import onnx_pb
+from onnxruntime.quantization.onnx_quantizer import ONNXModel
 from tqdm import tqdm
 
 # Import AIMET specific modules
@@ -139,6 +140,8 @@ def apply_adaround(cls, model: onnx_pb.ModelProto, params: AdaroundParameters,
         """
         # pylint: disable=too-many-arguments
         # Create Quant sim with given parameters
+        if not isinstance(model, ONNXModel):
+            model = ONNXModel(model)
         quant_sim = QuantizationSimModel(copy.deepcopy(model), quant_scheme=default_quant_scheme,
                                          default_param_bw=default_param_bw,
                                          config_file=default_config_file)
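
The new `isinstance` guard means `apply_adaround` now takes either a raw onnx `ModelProto` or an already-wrapped `ONNXModel`. A small sketch of the idempotent wrapping (the model path is a placeholder and `ensure_wrapped` is a hypothetical helper mirroring the guard, not AIMET API):

```python
import onnx
from onnxruntime.quantization.onnx_quantizer import ONNXModel

def ensure_wrapped(model):
    """Mirror of the new guard: wrap a raw ModelProto exactly once."""
    if not isinstance(model, ONNXModel):
        model = ONNXModel(model)
    return model

proto = onnx.load('model.onnx')      # path is a placeholder
wrapped = ensure_wrapped(proto)      # ModelProto -> ONNXModel
rewrapped = ensure_wrapped(wrapped)  # already wrapped: no-op
assert wrapped is rewrapped
```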
Changed file 3 of 3
@@ -105,7 +105,14 @@ def read_attributes_for_op(module_info: ModuleInfo) -> Dict:
         if attribute.name == 'dilations':
             attributes['dilations'] = list(attribute.ints)
         elif attribute.name == 'pads':
-            attributes['pads'] = list(attribute.ints)
+            padding = list(attribute.ints)
+            unique_vals = set()
+            new_padding = []
+            for val in padding:
+                if val not in unique_vals:
+                    unique_vals.add(val)
+                    new_padding.append(val)
+            attributes['pads'] = new_padding
         elif attribute.name == 'strides':
             attributes['strides'] = list(attribute.ints)
         elif attribute.name == 'group':
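
The dedup above keeps the first occurrence of each pad value in order, so symmetric ONNX pads collapse to one or two entries. The same logic as a standalone function (a sketch, not library code; `list(dict.fromkeys(...))` is an equivalent one-liner):

```python
def dedup_preserving_order(padding):
    """Drop repeated pad values while keeping first-seen order."""
    unique_vals = set()
    new_padding = []
    for val in padding:
        if val not in unique_vals:
            unique_vals.add(val)
            new_padding.append(val)
    return new_padding

print(dedup_preserving_order([1, 1, 1, 1]))  # [1]
print(dedup_preserving_order([2, 3, 2, 3]))  # [2, 3]
print(dedup_preserving_order([0, 1, 0, 2]))  # [0, 1, 2]
```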
