Skip to content

Commit

Permalink
Support grouped convolution when initializing Layer object
Browse files Browse the repository at this point in the history
Signed-off-by: Geunho Lee <quic_geunlee@quicinc.com>
  • Loading branch information
quic-geunlee authored Feb 2, 2024
1 parent 75131de commit 1a675b5
Show file tree
Hide file tree
Showing 5 changed files with 147 additions and 18 deletions.
17 changes: 10 additions & 7 deletions TrainingExtensions/torch/src/python/aimet_torch/layer_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# =============================================================================
# @@-COPYRIGHT-START-@@
#
# Copyright (c) 2018, Qualcomm Innovation Center, Inc. All rights reserved.
# Copyright (c) 2018-2024, Qualcomm Innovation Center, Inc. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
Expand Down Expand Up @@ -37,7 +37,7 @@

"""Stores and updates Layer Attributes"""
import copy
from typing import Tuple, Union
from typing import Tuple, Union, List
import torch

from aimet_torch import utils
Expand All @@ -56,7 +56,7 @@ def _set_type_specific_params(self, module):
params = aimet_common.layer_database.Conv2dTypeSpecificParams(module.stride, module.padding, module.groups)
self.type_specific_params = params

def __init__(self, module: torch.nn.Module, name, output_shape):
def __init__(self, module: torch.nn.Module, name: str, output_shape: Union[List, Tuple]):
"""
Constructor
:param module: Reference to the layer
Expand All @@ -65,10 +65,13 @@ def __init__(self, module: torch.nn.Module, name, output_shape):
"""
if isinstance(module, torch.nn.Conv2d):
if module.groups > 1:
assert module.groups == module.in_channels
assert module.in_channels == module.out_channels

weight_shape = (module.out_channels, 1, module.kernel_size[0], module.kernel_size[1])
if module.in_channels == module.groups: # Depthwise convolution
assert module.in_channels == module.out_channels
weight_shape = (module.out_channels, 1, module.kernel_size[0], module.kernel_size[1])
elif module.in_channels % module.groups == 0: # Grouped convolution
weight_shape = (module.out_channels, module.in_channels // module.groups, module.kernel_size[0], module.kernel_size[1])
else:
raise AssertionError("Conv2d with invalid in_channels and groups values")
else:
weight_shape = (module.out_channels, module.in_channels, module.kernel_size[0], module.kernel_size[1])

Expand Down
25 changes: 22 additions & 3 deletions TrainingExtensions/torch/test/python/models/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# =============================================================================
# @@-COPYRIGHT-START-@@
#
# Copyright (c) 2019-2023, Qualcomm Innovation Center, Inc. All rights reserved.
# Copyright (c) 2019-2024, Qualcomm Innovation Center, Inc. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
Expand Down Expand Up @@ -35,13 +35,11 @@
# @@-COPYRIGHT-END-@@
# =============================================================================
""" Models for use in unit testing """

# pylint: skip-file
from collections import namedtuple
from typing import Dict, List

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import nn as nn
from torchvision.ops import roi_align
Expand Down Expand Up @@ -1149,3 +1147,24 @@ def forward(self, *inputs):
x = self.conv3(x)
x = self.act3(x)
return x


class GroupedConvModel(nn.Module):
def __init__(self, in_channels: int, out_channels: int):
super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, groups=2, bias=False)

def forward(self, *inputs):
return self.conv(inputs[0])


class CustomGroupedConvModel(nn.Module):
def __init__(self, in_channels: int, out_channels: int):
super().__init__()
self.conv1 = nn.Conv2d(in_channels, out_channels, 3, bias=False)
self.conv2 = nn.Conv2d(in_channels, out_channels, 3, bias=False)

def forward(self, *inputs):
input1, input2 = inputs
output1, output2 = self.conv1(input1), self.conv2(input2)
return torch.cat([output1, output2], dim=1)
57 changes: 51 additions & 6 deletions TrainingExtensions/torch/test/python/test_cost_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# =============================================================================
# @@-COPYRIGHT-START-@@
#
# Copyright (c) 2017-2018, Qualcomm Innovation Center, Inc. All rights reserved.
# Copyright (c) 2017-2024, Qualcomm Innovation Center, Inc. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
Expand Down Expand Up @@ -34,21 +34,27 @@
#
# @@-COPYRIGHT-END-@@
# =============================================================================

import math
import unittest
from decimal import Decimal

import torch
import torch.nn as nn

from aimet_common import cost_calculator as cc
from aimet_common.defs import CostMetric, LayerCompRatioPair
from aimet_common.utils import AimetLogger

from aimet_torch.utils import create_rand_tensors_given_shapes, create_fake_data_loader, get_device
from aimet_torch.channel_pruning.channel_pruner import (
InputChannelPruner,
ChannelPruningCostCalculator,
)
from aimet_torch.layer_database import Layer, LayerDatabase
from aimet_torch.channel_pruning.channel_pruner import InputChannelPruner, ChannelPruningCostCalculator
from models import mnist_torch_model
from aimet_torch.utils import (
create_rand_tensors_given_shapes,
create_fake_data_loader,
get_device,
)
from models import mnist_torch_model, test_models

logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.Test)

Expand Down Expand Up @@ -122,6 +128,45 @@ def test_total_model_cost(self):
self.assertEqual(800 + 51200 + 3211264 + 10240, network_cost.memory)
self.assertEqual(627200 + 10035200 + 3211264 + 10240, network_cost.mac)

def test_mac_count_of_grouped_conv_net(self):
torch.manual_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

in_channels, out_channels = 6, 12
standard_grouped_conv_model = test_models.GroupedConvModel(
in_channels, out_channels
).to(device).eval()

custom_grouped_conv_model = test_models.CustomGroupedConvModel(
in_channels // 2, out_channels // 2
).to(device).eval()
with torch.no_grad():
custom_grouped_conv_model.conv1.weight.copy_(
standard_grouped_conv_model.conv.weight[: out_channels // 2]
)
custom_grouped_conv_model.conv2.weight.copy_(
standard_grouped_conv_model.conv.weight[out_channels // 2 :]
)

standard_module_inputs = torch.randn(1, 6, 10, 10, device=device)
custom_module_inputs = (
standard_module_inputs[:, : in_channels // 2, :, :],
standard_module_inputs[:, in_channels // 2 :, :, :],
)

with torch.inference_mode():
standard_module_outputs = standard_grouped_conv_model(standard_module_inputs)
custom_module_outputs = custom_grouped_conv_model(*custom_module_inputs)
self.assertTrue(torch.allclose(standard_module_outputs, custom_module_outputs))

standard_grouped_conv_db = LayerDatabase(standard_grouped_conv_model, standard_module_inputs)
custom_grouped_conv_db = LayerDatabase(custom_grouped_conv_model, custom_module_inputs)
cost_calc = cc.CostCalculator()

standard_grouped_conv_cost = cost_calc.compute_model_cost(standard_grouped_conv_db)
custom_grouped_conv_cost = cost_calc.compute_model_cost(custom_grouped_conv_db)
self.assertEqual(standard_grouped_conv_cost.mac, custom_grouped_conv_cost.mac)


class TestTrainingExtensionsSpatialSvdCostCalculator(unittest.TestCase):

Expand Down
11 changes: 10 additions & 1 deletion TrainingExtensions/torch/test/python/test_layer_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# =============================================================================
# @@-COPYRIGHT-START-@@
#
# Copyright (c) 2017-2018, Qualcomm Innovation Center, Inc. All rights reserved.
# Copyright (c) 2017-2024, Qualcomm Innovation Center, Inc. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
Expand Down Expand Up @@ -38,6 +38,7 @@
import unittest
from unittest.mock import MagicMock

import torch
import torch.nn as nn
from torch.nn import Conv2d, Linear

Expand Down Expand Up @@ -109,3 +110,11 @@ def test_select_all_conv_and_fc_layers(self):

layer_selector.select(layer_db, [layer2.module])
layer_db.mark_picked_layers.assert_called_once_with([layer1, layer3])

def test_grouped_convolution_support(self):
dummy_input = torch.randn(1, 4, 8, 8)
grouped_convolution = nn.Conv2d(4, 16, kernel_size=3, groups=2)

output_shape = grouped_convolution(dummy_input)
layer = Layer(grouped_convolution, "grouped_conv", output_shape)
assert layer.type_specific_params.groups == 2
55 changes: 54 additions & 1 deletion TrainingExtensions/torch/test/python/test_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# =============================================================================
# @@-COPYRIGHT-START-@@
#
# Copyright (c) 2017-2023, Qualcomm Innovation Center, Inc. All rights reserved.
# Copyright (c) 2017-2024, Qualcomm Innovation Center, Inc. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
Expand Down Expand Up @@ -3011,6 +3011,59 @@ def assert_param_quantizers(param_quantizer, module_name, param_name):

os.remove("./temp_partial_torch_encodings.encodings")

def test_logits_of_grouped_conv_net(self):
torch.manual_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

in_channels, out_channels = 6, 12
standard_grouped_conv_model = test_models.GroupedConvModel(
in_channels, out_channels
).to(device).eval()

custom_grouped_conv_model = test_models.CustomGroupedConvModel(
in_channels // 2, out_channels // 2
).to(device).eval()
with torch.no_grad():
custom_grouped_conv_model.conv1.weight.copy_(
standard_grouped_conv_model.conv.weight[: out_channels // 2]
)
custom_grouped_conv_model.conv2.weight.copy_(
standard_grouped_conv_model.conv.weight[out_channels // 2 :]
)

standard_module_inputs = torch.randn(1, 6, 10, 10, device=device)
custom_module_inputs = (
standard_module_inputs[:, : in_channels // 2, :, :],
standard_module_inputs[:, in_channels // 2 :, :, :],
)

pcq_config_path = get_path_for_per_channel_config()
sim_from_standard = QuantizationSimModel(
standard_grouped_conv_model, standard_module_inputs, config_file=pcq_config_path
)
sim_from_custom = QuantizationSimModel(
custom_grouped_conv_model, custom_module_inputs, config_file=pcq_config_path
)

# Disable activation quantizers to measure impact of grouped conv weight
def _disable_activation_quantizers(sim):
for _, wrapper in sim.quant_wrappers():
for q in wrapper.input_quantizers:
q.enabled = False

for q in wrapper.output_quantizers:
q.enabled = False

_disable_activation_quantizers(sim_from_standard)
_disable_activation_quantizers(sim_from_custom)

sim_from_standard.compute_encodings(lambda m, _: m(standard_module_inputs), None)
sim_from_custom.compute_encodings(lambda m, _: m(*custom_module_inputs), None)
with torch.inference_mode():
standard_module_outputs = sim_from_standard.model(standard_module_inputs)
custom_module_outputs = sim_from_custom.model(*custom_module_inputs)
assert torch.allclose(standard_module_outputs, custom_module_outputs)


class TestQuantizationSimLearnedGrid:

Expand Down

0 comments on commit 1a675b5

Please sign in to comment.