
Update PyTorch AMP supported kernels logic
Signed-off-by: Kevin Hsieh <quic_klhsieh@quicinc.com>
quic-klhsieh committed Oct 12, 2024
1 parent 9d3a7f9 commit ccb56ab
Showing 5 changed files with 150 additions and 49 deletions.
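
Summary of the change: find_wrapper_module now returns (None, None) for functional ops instead of raising KeyError, and all callers branch on None rather than catching the exception. QuantizerGroup gains a supported_kernel_ops field recording the dotted names of the wrapped ops that consume each group's tensors; find_supported_candidates then builds one shared set of op types per group from that field instead of doing a per-quantizer module lookup, and model-output groups get an empty tuple so they can consider all default candidates.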
@@ -148,8 +148,8 @@ def rename_nodes(G: nx.DiGraph, module_name_to_module_dict: Dict, dotted_name2op
     for node in G_copy.nodes:
         if node not in ["input_ops", "output_ops"]:
             if ("input_ops", node) in G_copy.edges: # quantizer groups with an input quantizer
-                try:
-                    module_name, _ = find_wrapper_module(node, module_name_to_module_dict)
+                module_name, _ = find_wrapper_module(node, module_name_to_module_dict)
+                if module_name is not None:
                     mapping = {node: module_name + "_output"}
                     G = nx.relabel_nodes(G, mapping)
@@ -173,14 +173,14 @@ def rename_nodes(G: nx.DiGraph, module_name_to_module_dict: Dict, dotted_name2op
 
                     G.nodes[new_node_name]["tensor_dims"] = input_shape
                     G.nodes[new_node_name]["tensor_size"] = input_size
-                except:
+                else:
                     _logger.info("did not change node name: %s", node)
             else:
-                try:
-                    module_name, _ = find_wrapper_module(node, module_name_to_module_dict)
+                module_name, _ = find_wrapper_module(node, module_name_to_module_dict)
+                if module_name is not None:
                     mapping = {node: module_name + "_output"}
                     G = nx.relabel_nodes(G, mapping)
-                except:
+                else:
                     _logger.info("did not change node name: %s", node)
 
     G.remove_node("input_ops")
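
Note: both rename_nodes branches drop the bare try/except (which also silenced unrelated errors) in favor of an explicit None check against find_wrapper_module's new return contract; the unchanged node name is still logged, now in the else path.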
@@ -199,7 +199,6 @@ def get_supported_candidates_for_quantizers(quantizers: List,
                 # Store candidates for quantizer
                 store_candidates_for_quantizer(supported_kernels, op, amp_candidates_set, act_bw_set, act_and_param_set,
                                                act_only_set, null_intersection_ops)
-                break
 
         # Default candidate selected if op not found in supported kernels
        if not ops_found:
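
Note: dropping the break means the loop no longer stops at the first op found in the supported kernels; every matching op now contributes via store_candidates_for_quantizer. A toy sketch of the effect, assuming the intersection semantics suggested by the null_intersection_ops parameter (hypothetical data; the real accumulation lives inside store_candidates_for_quantizer):

# Toy sketch, not the library code: with the break removed, every op mapped to
# a quantizer contributes its candidate set instead of only the first match.
supported_kernels = {"Conv": {"int8", "int16"}, "Relu": {"int16"}}
ops_for_quantizer = ["Conv", "Relu"]

candidates = None
for op in ops_for_quantizer:
    if op in supported_kernels:
        op_candidates = supported_kernels[op]
        candidates = set(op_candidates) if candidates is None else candidates & op_candidates
        # previously: break  (only "Conv" would have been considered)

print(sorted(candidates))  # ['int16']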
@@ -37,7 +37,7 @@
 
 """ Find quantizer groups in a model """
 import itertools
-from typing import Dict, List, Tuple
+from typing import Dict, List, Tuple, Optional
 from collections import defaultdict
 from dataclasses import dataclass, field
 import torch
@@ -66,6 +66,7 @@ class QuantizerGroup(QuantizerGroupBase):
     input_quantizers: Tuple[str, ...] = field(default_factory=tuple)
     output_quantizers: Tuple[str, ...] = field(default_factory=tuple)
     parameter_quantizers: Tuple[str, ...] = field(default_factory=tuple)
+    supported_kernel_ops: Tuple[str, ...] = field(default_factory=tuple)
 
     def get_candidate(self, name_to_quantizer_dict: Dict) -> CANDIDATE_WITH_DTYPE:
         """
@@ -167,7 +168,8 @@ def get_input_quantizer_modules(self):
         return tuple(sorted(result))
 
 
-def find_wrapper_module(op_name: str, module_name_to_quantizer_dict: Dict) -> Tuple[str, torch.nn.Module]:
+def find_wrapper_module(op_name: str, module_name_to_quantizer_dict: Dict) -> \
+        Tuple[Optional[str], Optional[torch.nn.Module]]:
     """
     Finds quantization (wrapping) module corresponding to the wrapper module's dotted name
     :param op_name: Dotted name of op as represented in connected graph
@@ -179,7 +181,7 @@ def find_wrapper_module(op_name: str, module_name_to_quantizer_dict: Dict) -> Tu
         if module_name in module_name_to_quantizer_dict:
             return module_name, module_name_to_quantizer_dict[module_name]
     # Else it is a functional op
-    raise KeyError
+    return None, None
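
A minimal sketch of the new contract from the caller's side. The name derivation below is hypothetical (the real function resolves the wrapper's dotted name from the connected-graph op name, elided in this diff), but the None branch mirrors the change:

from typing import Dict, Optional, Tuple

import torch

def find_wrapper_module_sketch(op_name: str, module_name_to_quantizer_dict: Dict) \
        -> Tuple[Optional[str], Optional[torch.nn.Module]]:
    # Hypothetical name derivation: strip the trailing '_<suffix>' segment.
    module_name = op_name.rsplit('_', 1)[0]
    if module_name in module_name_to_quantizer_dict:
        return module_name, module_name_to_quantizer_dict[module_name]
    return None, None  # functional op: no wrapper to return

wrappers = {'conv1': torch.nn.Conv2d(3, 8, 3)}
name, module = find_wrapper_module_sketch('conv1_output', wrappers)
if name is not None:                   # callers now branch on None ...
    print(name, type(module).__name__)
name, module = find_wrapper_module_sketch('relu_0', wrappers)
assert (name, module) == (None, None)  # ... instead of catching KeyError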


def get_module_name_to_module_dict(sim: QuantizationSimModel) -> Dict:
Expand Down Expand Up @@ -320,11 +322,8 @@ def get_input_and_param_quantizers(
"""
input_quantizers = []
parameter_quantizers = []
try:
module_name, module = find_wrapper_module(child, module_name_to_module_dict)
except KeyError:
pass
else:
module_name, module = find_wrapper_module(child, module_name_to_module_dict)
if module_name is not None:
for idx, input_quantizer in enumerate(module.input_quantizers):
if input_quantizer.enabled:
input_quantizers.append(module_name + '_input_quantizer_idx_' + str(idx))
@@ -358,9 +357,14 @@ def find_quantizer_group(sim: QuantizationSimModel) -> Tuple[Dict, List[Quantize
             # Add one quantizer group for each input and it's weight param
             input_quantizers, parameter_quantizers = get_input_and_param_quantizers(child, module_name_to_module_dict)
             if input_quantizers or parameter_quantizers:
+                child_module_name, _ = find_wrapper_module(child, module_name_to_module_dict)
+                supported_kernel_ops = []
+                if child_module_name is not None:
+                    supported_kernel_ops.append(child_module_name)
                 quantizer_group = QuantizerGroup(
                     input_quantizers=input_quantizers,
-                    parameter_quantizers=parameter_quantizers
+                    parameter_quantizers=parameter_quantizers,
+                    supported_kernel_ops=tuple(supported_kernel_ops)
                 )
                 quantizer_groups.append(quantizer_group)
                 logger.debug('\n Quantizer Group Added: %s', quantizer_group)
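
Note: a single-child group now records its consuming op's wrapper name in supported_kernel_ops, which find_supported_candidates uses below; functional children (no wrapper) simply contribute nothing.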
@@ -375,42 +379,44 @@ def find_quantizer_group(sim: QuantizationSimModel) -> Tuple[Dict, List[Quantize
         if not isinstance(parents, tuple):
             parents = [parents]
         for parent in parents:
-            try:
-                module_name, module = find_wrapper_module(parent, module_name_to_module_dict)
-            except KeyError:
-                continue
+            module_name, module = find_wrapper_module(parent, module_name_to_module_dict)
+            if module is not None:
                 for output_quantizer in module.output_quantizers:
                     if output_quantizer.enabled:
                         output_quantizers += (module_name,)
 
+        supported_kernel_ops = []
         for child in children:
             input_q, param_q = get_input_and_param_quantizers(child, module_name_to_module_dict)
             input_quantizers += input_q
             parameter_quantizers += param_q
+            child_module_name, _ = find_wrapper_module(child, module_name_to_module_dict)
+            if child_module_name is not None:
+                supported_kernel_ops.append(child_module_name)
 
         # Don't add quantizer group if it is empty
         if input_quantizers or output_quantizers or parameter_quantizers:
             quantizer_group = QuantizerGroup(
                 input_quantizers=input_quantizers,
                 output_quantizers=output_quantizers,
-                parameter_quantizers=parameter_quantizers
+                parameter_quantizers=parameter_quantizers,
+                supported_kernel_ops=tuple(supported_kernel_ops)
             )
             quantizer_groups.append(quantizer_group)
             logger.debug('\n Quantizer Group added: %s', quantizer_group)
 
     if 'output_ops' in parent_child_op_groups:
         for parent in parent_child_op_groups['output_ops']:
             # Add one quantizer group for each input and it's weight param
-            try:
-                module_name, module = find_wrapper_module(parent, module_name_to_module_dict)
-            except KeyError:
-                continue
+            module_name, module = find_wrapper_module(parent, module_name_to_module_dict)
+            if module is not None:
                 for output_quantizer in module.output_quantizers:
                     if output_quantizer.enabled:
+                        # Using empty supported kernel ops so that model output quantizers are able to consider all
+                        # default candidates
                         quantizer_group = QuantizerGroup(
                             output_quantizers=(module_name,),
+                            supported_kernel_ops=tuple()
                         )
                         quantizer_groups.append(quantizer_group)
                         logger.debug('\n Quantizer Group added: %s', quantizer_group)
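
Note: model-output groups deliberately get supported_kernel_ops=tuple(). With no op constraining them, find_supported_candidates falls back to the full set of default candidates, per the inline comment added above.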
@@ -437,42 +443,42 @@ def find_supported_candidates(quantizer_groups: List[QuantizerGroup],
     quantizers_with_supported_candidates = defaultdict(list)
 
     # pylint: disable=too-many-nested-blocks
+    # pylint: disable=protected-access
     for quantizer_group in quantizer_groups:
         quantizers = sorted(set(itertools.chain(quantizer_group.get_input_quantizer_modules(),
                                                 quantizer_group.output_quantizers,
                                                 quantizer_group.parameter_quantizers)))
 
         # quantizers are now unique ops present in the given quantizer_group
         onnx_ops = defaultdict(list)
-        for quantizer in quantizers:
-            if quantizer not in module_name_to_module_dict:
-                raise RuntimeError('module_name_to_module_dict does not contain an entry for the quantizer:',
-                                   quantizer)
-
-            # pylint: disable=protected-access
-            module = module_name_to_module_dict[quantizer]._module_to_wrap
-
+        supported_kernel_types = set()
+        for supported_kernel_op in quantizer_group.supported_kernel_ops:
+            module = module_name_to_module_dict[supported_kernel_op]._module_to_wrap
             try:
                 backend_type = aimet_op_to_backend_op_name_map[module.__class__]
             except KeyError:
                 backend_type = aimet_op_to_backend_op_name_map.get(module.__class__.__name__)
 
             if backend_type in supported_kernels:
-                onnx_ops[quantizer] = [backend_type]
+                supported_kernel_types.add(backend_type)
             else:
                 onnx_types = onnx_utils.map_torch_types_to_onnx.get(
-                    type(module_name_to_module_dict[quantizer]._module_to_wrap), [])
+                    type(module_name_to_module_dict[supported_kernel_op]._module_to_wrap), [])
 
                 if not onnx_types:
                     logger.warning("No mapping found for %s in the torch to onnx op type mapping dictionary.",
-                                   str(type(module_name_to_module_dict[quantizer]._module_to_wrap)))
+                                   str(type(module_name_to_module_dict[supported_kernel_op]._module_to_wrap)))
 
+                supported_kernel_types.update(onnx_types)
 
-                onnx_ops[quantizer] = onnx_types
                 for onnx_type in onnx_types:
                     if onnx_type not in supported_kernels.keys():
                         if module in supported_kernels:
                             supported_kernels[onnx_type] = supported_kernels[module]
 
+        for quantizer in quantizers:
+            onnx_ops[quantizer] = list(supported_kernel_types)
+
         supported_kernels_for_quantizers = get_supported_candidates_for_quantizers(quantizers,
                                                                                    onnx_ops,
                                                                                    supported_kernels,
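
A condensed sketch (toy data, hypothetical op and quantizer names) of the new aggregation: the group's supported_kernel_ops determine one shared set of op types, and every quantizer in the group then uses that same set when candidates are looked up, instead of its own module's type:

from collections import defaultdict

# Hypothetical stand-ins for one quantizer group's state.
supported_kernel_ops = ['conv1', 'relu1']
op_to_onnx_types = {'conv1': ['Conv'], 'relu1': ['Relu', 'Clip']}

# Union the op types across the group's supported kernel ops ...
supported_kernel_types = set()
for op in supported_kernel_ops:
    supported_kernel_types.update(op_to_onnx_types.get(op, []))

# ... then give every quantizer in the group the same shared list.
quantizers = ['conv1_input_quantizer_idx_0', 'conv1']
onnx_ops = defaultdict(list)
for quantizer in quantizers:
    onnx_ops[quantizer] = list(supported_kernel_types)

print(dict(onnx_ops))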
10 changes: 4 additions & 6 deletions TrainingExtensions/torch/test/python/test_mixed_precision.py
@@ -1111,19 +1111,17 @@ def test_supported_candidates_2(
        # default_supported_kernels and conv_supported_kernels are the configurations added in the json file above.
        default_supported_kernels = [((16, QuantizationDataType.int), (16, QuantizationDataType.int)),
                                     ((16, QuantizationDataType.float), (16, QuantizationDataType.float)),
-                                    ((8, QuantizationDataType.float), (16, QuantizationDataType.float))]
+                                    ((8, QuantizationDataType.int), (16, QuantizationDataType.int))]
 
        conv_supported_kernels = [((16, QuantizationDataType.float), (16, QuantizationDataType.float)),
                                  ((8, QuantizationDataType.int), (16, QuantizationDataType.int))]
 
        for quantizer_group, quantizer_candidates in algo._supported_candidates_per_quantizer_group.items():
-           quantizers = sorted(itertools.chain(quantizer_group.get_input_quantizer_modules(),
-                                               quantizer_group.output_quantizers,
-                                               quantizer_group.parameter_quantizers))
+           supported_kernel_ops = quantizer_group.supported_kernel_ops
            onnx_types = []
-           for q in quantizers:
+           for op in supported_kernel_ops:
                onnx_types.append(
-                   onnx_utils.map_torch_types_to_onnx.get(type(algo._module_name_dict[q]._module_to_wrap)))
+                   onnx_utils.map_torch_types_to_onnx.get(type(algo._module_name_dict[op]._module_to_wrap)))
 
            # verify to make sure the candidates returned is always part of amp_candidates and they are part of
            # "Defaults" or "Conv" appropriately
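
Note: the test mirrors the new flow in two ways: the expected (8, 16) default kernel pair now uses int dtypes, and the per-group onnx types are derived from quantizer_group.supported_kernel_ops rather than from each quantizer's own wrapped module.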