Skip to content

Commit

Permalink
Fix support for const as input to linear layers in PyTorch
Browse files Browse the repository at this point in the history
  • Loading branch information
lapid92 committed May 22, 2024
1 parent f822b9d commit 313cf0a
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 29 deletions.
6 changes: 4 additions & 2 deletions model_compression_toolkit/core/common/graph/base_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,8 +238,10 @@ def insert_positional_weights_to_input_list(self, input_tensors: List) -> List:
"""
for pos, weight in sorted((pos, weight) for pos, weight in self.weights.items()
if isinstance(pos, int)):
assert pos <= len(input_tensors), 'Positional weight index mismatch'
# Insert only positional weights that are not subject to quantization.
if pos > len(input_tensors):
Logger.critical("The positional weight index cannot exceed the number of input tensors to the node.") # pragma: no cover
# Insert only positional weights that are not subject to quantization. If the positional weight is
# subject to quantization, the quantization wrapper inserts the positional weight into the node.
if not self.is_weights_quantization_enabled(pos):
input_tensors.insert(pos, weight)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,16 +47,17 @@ def forward(self, x):

class ConstRepresentationTest(BasePytorchFeatureNetworkTest):

def __init__(self, unit_test, func, const, input_reverse_order=False):
def __init__(self, unit_test, func, const, input_reverse_order=False, enable_weights_quantization=False):
super().__init__(unit_test=unit_test, input_shape=(16, 32, 32))
self.func = func
self.const = const
self.input_reverse_order = input_reverse_order
self.enable_weights_quantization = enable_weights_quantization

def get_tpc(self):
tp = generate_test_tp_model({'weights_n_bits': 32,
'activation_n_bits': 32,
'enable_weights_quantization': True,
'enable_weights_quantization': self.enable_weights_quantization,
'enable_activation_quantization': False})
return generate_pytorch_tpc(name="linear_collapsing_test", tp_model=tp)

Expand Down Expand Up @@ -97,8 +98,9 @@ def forward(self, x):

class ConstRepresentationMultiInputTest(ConstRepresentationTest):

def __init__(self, unit_test):
super().__init__(unit_test=unit_test, func=None, const=None, input_reverse_order=False)
def __init__(self, unit_test, enable_weights_quantization):
super().__init__(unit_test=unit_test, func=None, const=None, input_reverse_order=False,
enable_weights_quantization=enable_weights_quantization)

def create_networks(self):
return ConstRepresentationMultiInputNet()
Expand All @@ -118,8 +120,9 @@ def forward(self, x):

class ConstRepresentationLinearLayerTest(ConstRepresentationTest):

def __init__(self, unit_test, func, const):
super().__init__(unit_test=unit_test, func=func, const=const, input_reverse_order=False)
def __init__(self, unit_test, func, const, enable_weights_quantization):
super().__init__(unit_test=unit_test, func=func, const=const, input_reverse_order=False,
enable_weights_quantization=enable_weights_quantization)

def create_networks(self):
return ConstRepresentationLinearLayerNet(self.func, self.const)
Expand All @@ -139,8 +142,9 @@ def forward(self, x):

class ConstRepresentationGetIndexTest(ConstRepresentationTest):

def __init__(self, unit_test, func, const, indices):
super().__init__(unit_test=unit_test, func=func, const=const, input_reverse_order=False)
def __init__(self, unit_test, func, const, indices, enable_weights_quantization):
super().__init__(unit_test=unit_test, func=func, const=const, input_reverse_order=False,
enable_weights_quantization=enable_weights_quantization)
self.func = func
self.const = const
self.indices = indices
Expand Down
46 changes: 27 additions & 19 deletions tests/pytorch_tests/model_tests/test_feature_models_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,25 +243,33 @@ def test_residual_collapsing(self):
# AdvancedConstQuantizationTest(self).run_test()

def test_const_representation(self):
for const_dtype in [np.float32, np.int64, np.int32]:
c = (np.ones((32,)) + np.random.random((32,))).astype(const_dtype)
for func in [torch.add, torch.sub, torch.mul, torch.div]:
ConstRepresentationTest(self, func, c).run_test()
ConstRepresentationTest(self, func, c, input_reverse_order=True).run_test()
ConstRepresentationTest(self, func, 2.45).run_test()
ConstRepresentationTest(self, func, 5, input_reverse_order=True).run_test()

c = (np.ones((64,)) + np.random.random((64,))).astype(const_dtype)
indices = np.random.randint(64, size=32)
for func in [torch.add, torch.sub, torch.mul, torch.div]:
ConstRepresentationGetIndexTest(self, func, c, indices).run_test()

ConstRepresentationMultiInputTest(self).run_test()

c = (np.ones((1, 16, 32, 32)) + np.random.random((1, 16, 32, 32))).astype(np.float32)
ConstRepresentationLinearLayerTest(self, func=nn.Linear(32, 32), const=c).run_test()
ConstRepresentationLinearLayerTest(self, func=nn.Conv2d(16, 16, 1), const=c).run_test()
ConstRepresentationLinearLayerTest(self, func=nn.ConvTranspose2d(16, 16, 1), const=c).run_test()
for enable_weights_quantization in [False, True]:
for const_dtype in [np.float32, np.int64, np.int32]:
c = (np.ones((32,)) + np.random.random((32,))).astype(const_dtype)
c_64 = (np.ones((64,)) + np.random.random((64,))).astype(const_dtype)
indices = np.random.randint(64, size=32)
for func in [torch.add, torch.sub, torch.mul, torch.div]:
ConstRepresentationTest(self, func, c,
enable_weights_quantization=enable_weights_quantization).run_test()
ConstRepresentationTest(self, func, c, input_reverse_order=True,
enable_weights_quantization=enable_weights_quantization).run_test()
ConstRepresentationTest(self, func, 2.45,
enable_weights_quantization=enable_weights_quantization).run_test()
ConstRepresentationTest(self, func, 5, input_reverse_order=True,
enable_weights_quantization=enable_weights_quantization).run_test()
ConstRepresentationGetIndexTest(self, func, c_64, indices,
enable_weights_quantization=enable_weights_quantization).run_test()

ConstRepresentationMultiInputTest(self, enable_weights_quantization=enable_weights_quantization).run_test()

c_img = (np.ones((1, 16, 32, 32)) + np.random.random((1, 16, 32, 32))).astype(np.float32)
ConstRepresentationLinearLayerTest(self, func=nn.Linear(32, 32), const=c_img,
enable_weights_quantization=enable_weights_quantization).run_test()
ConstRepresentationLinearLayerTest(self, func=nn.Conv2d(16, 16, 1),
const=c_img,
enable_weights_quantization=enable_weights_quantization).run_test()
ConstRepresentationLinearLayerTest(self, func=nn.ConvTranspose2d(16, 16, 1),
const=c_img, enable_weights_quantization=enable_weights_quantization).run_test()

def test_permute_substitution(self):
"""
Expand Down

0 comments on commit 313cf0a

Please sign in to comment.