Skip to content

Commit

Permalink
code review modification
Browse files Browse the repository at this point in the history
  • Loading branch information
eladc-git committed Oct 30, 2023
1 parent c1734f1 commit 1b88e58
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import tensorflow as tf
from typing import List

from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS, MIN_JACOBIANS_ITER, JACOBIANS_COMP_TOLERANCE
from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS, MIN_JACOBIANS_ITER, JACOBIANS_COMP_TOLERANCE, HESSIAN_EPS
from model_compression_toolkit.core.common import Graph
from model_compression_toolkit.core.common.hessian import TraceHessianRequest, HessianInfoGranularity
from model_compression_toolkit.core.keras.back2framework.float_model_builder import FloatKerasModelBuilder
Expand Down Expand Up @@ -124,7 +124,7 @@ def compute(self) -> np.ndarray:
# Compute new means and deltas
new_mean = tf.reduce_mean(tf.stack(approximation_per_iteration + approx), axis=0)
delta = new_mean - tf.reduce_mean(tf.stack(approximation_per_iteration), axis=0)
is_converged = np.all(np.abs(delta) / (np.abs(new_mean) + 1e-6) < JACOBIANS_COMP_TOLERANCE)
is_converged = np.all(np.abs(delta) / (np.abs(new_mean) + HESSIAN_EPS) < JACOBIANS_COMP_TOLERANCE)
if is_converged:
approximation_per_iteration.append(approx)
break
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,12 @@ def __init__(self,
def compute(self) -> np.ndarray:
"""
Compute the Hessian-based scores w.r.t target node's weights.
The computed scores are returned in a numpy array. The shape of the result differs
according to the requested granularity. If for example the node is Conv2D with a kernel
shape of (2, 3, 3, 3) (namely, 3 input channels, 2 output channels and kernel size of 3x3)
and the required granularity is HessianInfoGranularity.PER_TENSOR the result shape will be (1,),
for HessianInfoGranularity.PER_OUTPUT_CHANNEL the shape will be (2,) and for
HessianInfoGranularity.PER_ELEMENT a shape of (2, 3, 3, 3).
Returns:
The computed scores as numpy ndarray for target node's weights.
Expand Down Expand Up @@ -107,8 +113,8 @@ def compute(self) -> np.ndarray:
approx = torch.sum(approx, dim=shape_channel_axis)

if j > MIN_JACOBIANS_ITER:
new_mean = (torch.sum(torch.stack(approximation_per_iteration))+approx)/(j+1)
delta = new_mean - torch.mean(torch.stack(approximation_per_iteration))
new_mean = (torch.sum(torch.stack(approximation_per_iteration), dim=0) + approx)/(j+1)
delta = new_mean - torch.mean(torch.stack(approximation_per_iteration), dim=0)
converged_tensor = torch.abs(delta) / (torch.abs(new_mean) + HESSIAN_EPS) < JACOBIANS_COMP_TOLERANCE
if torch.all(converged_tensor):
break
Expand Down
3 changes: 2 additions & 1 deletion tests/pytorch_tests/function_tests/test_function_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from tests.pytorch_tests.function_tests.test_sensitivity_eval_output_replacement import \
TestSensitivityEvalWithArgmaxOutputReplacementNodes, TestSensitivityEvalWithSoftmaxOutputReplacementNodes
from tests.pytorch_tests.function_tests.test_hessian_info_weights import WeightsHessianTraceBasicModelTest, WeightsHessianTraceAdvanceModelTest, \
WeightsHessianTraceMultipleOutputsModelTest
WeightsHessianTraceMultipleOutputsModelTest, WeightsHessianTraceReuseModelTest


class FunctionTestRunner(unittest.TestCase):
Expand Down Expand Up @@ -121,6 +121,7 @@ def test_weights_hessian_trace(self):
WeightsHessianTraceBasicModelTest(self).run_test()
WeightsHessianTraceAdvanceModelTest(self).run_test()
WeightsHessianTraceMultipleOutputsModelTest(self).run_test()
WeightsHessianTraceReuseModelTest(self).run_test()

def test_layer_fusing(self):
"""
Expand Down
71 changes: 64 additions & 7 deletions tests/pytorch_tests/function_tests/test_hessian_info_weights.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -54,7 +54,7 @@ def __init__(self):
self.conv2 = Conv2d(3, 3, kernel_size=1, stride=1)
self.bn2 = BatchNorm2d(3)
self.relu2 = ReLU()
self.dense = Linear(32, 7)
self.dense = Linear(8, 7)

def forward(self, inp):
x = self.conv1(inp)
Expand All @@ -76,7 +76,7 @@ def __init__(self):
self.conv2 = Conv2d(3, 3, kernel_size=1, stride=1)
self.bn2 = BatchNorm2d(3)
self.hswish = Hardswish()
self.dense = Linear(32, 7)
self.dense = Linear(8, 7)

def forward(self, inp):
x = self.conv1(inp)
Expand All @@ -89,6 +89,25 @@ def forward(self, inp):
return x1, x2, x3


class reused_model(torch.nn.Module):
    """
    Small test model in which the same layers are applied more than once.

    ``conv1`` (and ``relu``) are called twice in ``forward``, which lets the
    weights-Hessian tests exercise a model containing a reused layer.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = Conv2d(3, 3, kernel_size=1, stride=1)
        self.bn1 = BatchNorm2d(3)
        self.relu = ReLU()

    def forward(self, inp):
        """
        Run the model on a single input tensor.

        Args:
            inp: Input tensor with 3 channels. Its last dimension must be
                larger than 4 so that the split below yields at least two chunks.

        Returns:
            The concatenation (along the last dimension) of the re-processed
            first chunk and the untouched second chunk.
        """
        x = self.conv1(inp)
        x1 = self.bn1(x)
        x1 = self.relu(x1)
        # Split the last dimension into chunks of size 4; only the first two chunks are used.
        x_split = torch.split(x1, split_size_or_sections=4, dim=-1)
        # Second application of conv1 (the reused layer), on the first chunk only.
        x1 = self.conv1(x_split[0])
        x2 = x_split[1]
        x1 = self.relu(x1)
        # torch.cat is the canonical name; torch.concat is only an alias in newer torch versions.
        y = torch.cat([x1, x2], dim=-1)
        return y


def generate_inputs(inputs_shape):
inputs = []
for in_shape in inputs_shape:
Expand Down Expand Up @@ -130,7 +149,7 @@ def __init__(self, unit_test):
self.val_batch_size = 1

def create_inputs_shape(self):
return [[self.val_batch_size, 3, 32, 32]]
return [[self.val_batch_size, 3, 8, 8]]

@staticmethod
def generate_inputs(input_shapes):
Expand Down Expand Up @@ -164,10 +183,10 @@ def run_test(self, seed=0):
class WeightsHessianTraceAdvanceModelTest(BasePytorchTest):
def __init__(self, unit_test):
super().__init__(unit_test)
self.val_batch_size = 4
self.val_batch_size = 2

def create_inputs_shape(self):
return [[self.val_batch_size, 3, 32, 32]]
return [[self.val_batch_size, 3, 8, 8]]

@staticmethod
def generate_inputs(input_shapes):
Expand Down Expand Up @@ -207,7 +226,7 @@ def __init__(self, unit_test):
self.val_batch_size = 1

def create_inputs_shape(self):
return [[self.val_batch_size, 3, 32, 32]]
return [[self.val_batch_size, 3, 8, 8]]

@staticmethod
def generate_inputs(input_shapes):
Expand Down Expand Up @@ -241,5 +260,43 @@ def run_test(self, seed=0):
granularity=hessian_common.HessianInfoGranularity.PER_ELEMENT)


class WeightsHessianTraceReuseModelTest(BasePytorchTest):
    """
    Weights-Hessian approximation test on ``reused_model``, a model in which
    the same layer is applied more than once.
    """

    def __init__(self, unit_test):
        super().__init__(unit_test)
        self.val_batch_size = 1

    def create_inputs_shape(self):
        """Return the list of input shapes (a single 3x8x8 input per sample)."""
        return [[self.val_batch_size, 3, 8, 8]]

    @staticmethod
    def generate_inputs(input_shapes):
        """Delegate to the module-level ``generate_inputs`` helper."""
        return generate_inputs(input_shapes)

    def representative_data_gen(self):
        """Yield a single batch of generated inputs."""
        input_shapes = self.create_inputs_shape()
        yield self.generate_inputs(input_shapes)

    def run_test(self, seed=0):
        """
        Build the graph for ``reused_model`` and check the weights-Hessian
        approximation of every node that holds weights, for each granularity.

        Args:
            seed: Kept for interface compatibility; not used by this method.
        """
        model_float = reused_model()
        # A single framework-implementation object is shared by graph preparation
        # and the Hessian service instead of constructing two separate instances.
        pytorch_impl = PytorchImplementation()
        graph = prepare_graph_with_configs(model_float, pytorch_impl, DEFAULT_PYTORCH_INFO,
                                           self.representative_data_gen, generate_pytorch_tpc)
        hessian_service = hessian_common.HessianInfoService(graph=graph,
                                                            representative_dataset=self.representative_data_gen,
                                                            fw_impl=pytorch_impl)
        ipts = [n for n in graph.get_topo_sorted_nodes() if len(n.weights) > 0]

        # (num_scores, granularity) pairs, checked in this order for each interest point.
        cases = (
            (1, hessian_common.HessianInfoGranularity.PER_OUTPUT_CHANNEL),
            (2, hessian_common.HessianInfoGranularity.PER_TENSOR),
            (3, hessian_common.HessianInfoGranularity.PER_ELEMENT),
        )
        for ipt in ipts:
            for num_scores, granularity in cases:
                test_weights_hessian_trace_approx(hessian_service,
                                                  interest_point=ipt,
                                                  num_scores=num_scores,
                                                  granularity=granularity)


0 comments on commit 1b88e58

Please sign in to comment.