diff --git a/.github/workflows/run_keras_tests.yml b/.github/workflows/run_keras_tests.yml index 95a220819..32cb60a67 100644 --- a/.github/workflows/run_keras_tests.yml +++ b/.github/workflows/run_keras_tests.yml @@ -25,6 +25,10 @@ jobs: pip install -r requirements.txt pip install tensorflow==${{ inputs.tf-version }} - name: Run unittests + # Some tests are sensitive to memory because we use tf gradients on a multi-thread/process + # CPU environment (https://github.com/tensorflow/tensorflow/issues/41718). + # For this reason, if we run them in such an environment, we need to run them first non-parallel separately. run: | - for script in tests/keras_tests/exporter_tests tests/keras_tests/feature_networks_tests tests/keras_tests/function_tests tests/keras_tests/graph_tests tests/keras_tests/layer_tests; do python -m unittest discover $script -v & pids+=($!); done; for pid in ${pids[@]}; do wait $pid || exit 1; done + python -m unittest discover tests/keras_tests/non_parallel_tests -v + for script in tests/keras_tests/exporter_tests tests/keras_tests/feature_networks_tests tests/keras_tests/graph_tests tests/keras_tests/layer_tests; do python -m unittest discover $script -v & pids+=($!); done; for pid in ${pids[@]}; do wait $pid || exit 1; done diff --git a/model_compression_toolkit/core/common/hessian/hessian_info_service.py b/model_compression_toolkit/core/common/hessian/hessian_info_service.py index c33e5a33b..5dfc57544 100644 --- a/model_compression_toolkit/core/common/hessian/hessian_info_service.py +++ b/model_compression_toolkit/core/common/hessian/hessian_info_service.py @@ -148,12 +148,6 @@ def fetch_hessian(self, The inner list length dependent on the granularity (1 for per-tensor, OC for per-output-channel when the requested node has OC output-channels, etc.) """ - num_keys = len(self.trace_hessian_request_to_score_list) - num_values = sum([len(list(v)) for v in self.trace_hessian_request_to_score_list.values()]) - print(f"########### Keys: {num_keys}") - print(f"########### Values: {num_values}") - - if required_size==0: return [] diff --git a/model_compression_toolkit/core/keras/hessian/activation_trace_hessian_calculator_keras.py b/model_compression_toolkit/core/keras/hessian/activation_trace_hessian_calculator_keras.py index b805065b7..599b71148 100644 --- a/model_compression_toolkit/core/keras/hessian/activation_trace_hessian_calculator_keras.py +++ b/model_compression_toolkit/core/keras/hessian/activation_trace_hessian_calculator_keras.py @@ -100,6 +100,10 @@ def compute(self) -> List[float]: grad = tf.reshape(grad, [grad.shape[0], -1]) score_approx_per_output.append(tf.reduce_mean(tf.reduce_sum(tf.pow(grad, 2.0)))) + # Free gradients + del grad + del gradients + # If the change to the mean approximation is insignificant (to all outputs) # we stop the calculation. if j > MIN_JACOBIANS_ITER: @@ -133,7 +137,11 @@ def compute(self) -> List[float]: trace_approx_by_node = tf.reduce_mean([trace_approx_by_node], axis=0) # Just to get one tensor instead of list of tensors with single element - return trace_approx_by_node.numpy().tolist() + # Free gradient tape + del g + + return trace_approx_by_node.numpy().tolist() + else: Logger.error(f"{self.hessian_request.granularity} is not supported for Keras activation hessian's trace approx calculator") diff --git a/model_compression_toolkit/core/keras/hessian/weights_trace_hessian_calculator_keras.py b/model_compression_toolkit/core/keras/hessian/weights_trace_hessian_calculator_keras.py index 95ef58abc..aaa1544bf 100644 --- a/model_compression_toolkit/core/keras/hessian/weights_trace_hessian_calculator_keras.py +++ b/model_compression_toolkit/core/keras/hessian/weights_trace_hessian_calculator_keras.py @@ -15,7 +15,6 @@ import numpy as np import tensorflow as tf -from keras.layers import Conv2D, Dense, Conv2DTranspose, DepthwiseConv2D from typing import List from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS, MIN_JACOBIANS_ITER, JACOBIANS_COMP_TOLERANCE @@ -116,6 +115,9 @@ def compute(self) -> np.ndarray: num_of_scores) approx = tf.reduce_sum(tf.pow(gradients, 2.0), axis=1) + # Free gradients + del gradients + # If the change to the mean approximation is insignificant (to all outputs) # we stop the calculation. if j > MIN_JACOBIANS_ITER: @@ -132,6 +134,9 @@ def compute(self) -> np.ndarray: # Compute the mean of the approximations final_approx = tf.reduce_mean(tf.stack(approximation_per_iteration), axis=0) + # Free gradient tape + del tape + if self.hessian_request.granularity == HessianInfoGranularity.PER_TENSOR: if final_approx.shape != (1,): Logger.error(f"In HessianInfoGranularity.PER_TENSOR the score shape is expected" diff --git a/model_compression_toolkit/gptq/keras/quantization_facade.py b/model_compression_toolkit/gptq/keras/quantization_facade.py index de2df6609..e9c98f3db 100644 --- a/model_compression_toolkit/gptq/keras/quantization_facade.py +++ b/model_compression_toolkit/gptq/keras/quantization_facade.py @@ -225,6 +225,8 @@ def keras_gradient_post_training_quantization_experimental(in_model: Model, tb_w, hessian_info_service=hessian_info_service) + del hessian_info_service + if core_config.debug_config.analyze_similarity: analyzer_model_quantization(representative_data_gen, tb_w, tg_gptq, fw_impl, fw_info) diff --git a/tests/keras_tests/function_tests/test_hessian_info_calculator_weights.py b/tests/keras_tests/function_tests/test_hessian_info_calculator_weights.py index f092862a4..968fd227c 100644 --- a/tests/keras_tests/function_tests/test_hessian_info_calculator_weights.py +++ b/tests/keras_tests/function_tests/test_hessian_info_calculator_weights.py @@ -129,6 +129,7 @@ def test_conv2d_granularity(self): interest_points[1], granularity=hessian_common.HessianInfoGranularity.PER_ELEMENT, expected_shape=(3, 3, 3, 2)) + del hessian_service def test_dense_granularity(self): input_shape = (1, 8) @@ -160,6 +161,7 @@ def test_dense_granularity(self): interest_points[1], granularity=hessian_common.HessianInfoGranularity.PER_ELEMENT, expected_shape=(8, 2)) + del hessian_service def test_conv2dtranspose_granularity(self): input_shape = (1, 8, 8, 3) @@ -191,6 +193,7 @@ def test_conv2dtranspose_granularity(self): interest_points[1], granularity=hessian_common.HessianInfoGranularity.PER_ELEMENT, expected_shape=(3, 3, 2, 3)) + del hessian_service def test_depthwiseconv2d_granularity(self): input_shape = (1, 8, 8, 3) @@ -222,6 +225,7 @@ def test_depthwiseconv2d_granularity(self): interest_points[1], granularity=hessian_common.HessianInfoGranularity.PER_ELEMENT, expected_shape=(3, 3, 3, 1)) + del hessian_service def test_reused_layer(self): input_shape = (1, 8, 8, 3) @@ -268,6 +272,7 @@ def test_reused_layer(self): granularity=hessian_common.HessianInfoGranularity.PER_ELEMENT)) self.assertTrue(node2_count == 1) self.assertTrue(len(hessian_service.trace_hessian_request_to_score_list)==1) + del hessian_service ######################################################### # The following part checks different possible graph @@ -308,6 +313,8 @@ def _test_advanced_graph(self, float_model, _repr_dataset): granularity=hessian_common.HessianInfoGranularity.PER_ELEMENT, expected_shape=(3, 3, 3, 2)) + del hessian_service + def test_multiple_inputs(self): input_shape = (1, 8, 8, 3) diff --git a/tests/keras_tests/function_tests/test_hessian_service.py b/tests/keras_tests/function_tests/test_hessian_service.py index ba4883cb7..0935f350b 100644 --- a/tests/keras_tests/function_tests/test_hessian_service.py +++ b/tests/keras_tests/function_tests/test_hessian_service.py @@ -44,6 +44,8 @@ def representative_dataset(num_of_inputs=1): class TestHessianService(unittest.TestCase): + def tearDown(self) -> None: + del self.hessian_service def setUp(self): diff --git a/tests/keras_tests/function_tests/test_model_gradients.py b/tests/keras_tests/function_tests/test_model_gradients.py index b92ea3bbe..05a7eb1f2 100644 --- a/tests/keras_tests/function_tests/test_model_gradients.py +++ b/tests/keras_tests/function_tests/test_model_gradients.py @@ -119,6 +119,7 @@ def _get_normalized_hessian_trace_approx(graph, interest_points, keras_impl, alp assert len(hessian_data_per_image) == 1 x.append(hessian_data_per_image[0]) x = hessian_common.hessian_utils.normalize_weights(x, alpha=alpha, outputs_indices=[len(interest_points) - 1]) + del hessian_service return x diff --git a/tests/keras_tests/non_parallel_tests/__init__.py b/tests/keras_tests/non_parallel_tests/__init__.py new file mode 100644 index 000000000..6fbcf2bf5 --- /dev/null +++ b/tests/keras_tests/non_parallel_tests/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +# Some tests are sensitive to memory because we use tf gradients on a multi-thread/process +# CPU environment (https://github.com/tensorflow/tensorflow/issues/41718). +# For this reason, if we run them in such an environment, we need to run them first non-parallel separately. + diff --git a/tests/keras_tests/function_tests/test_keras_tp_model.py b/tests/keras_tests/non_parallel_tests/test_keras_tp_model.py similarity index 96% rename from tests/keras_tests/function_tests/test_keras_tp_model.py rename to tests/keras_tests/non_parallel_tests/test_keras_tp_model.py index 70f8a757a..a4537e9db 100644 --- a/tests/keras_tests/function_tests/test_keras_tp_model.py +++ b/tests/keras_tests/non_parallel_tests/test_keras_tp_model.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +import keras import unittest from functools import partial @@ -24,10 +25,10 @@ from model_compression_toolkit.core.common import BaseNode if version.parse(tf.__version__) >= version.parse("2.13"): - from keras.src.layers import Conv2D, Conv2DTranspose, ReLU, Activation + from keras.src.layers import Conv2D, Conv2DTranspose, ReLU, Activation, BatchNormalization from keras.src import Input else: - from keras.layers import Conv2D, Conv2DTranspose, ReLU, Activation + from keras.layers import Conv2D, Conv2DTranspose, ReLU, Activation, BatchNormalization from keras import Input import model_compression_toolkit as mct @@ -229,10 +230,15 @@ def test_keras_fusing_patterns(self): class TestGetKerasTPC(unittest.TestCase): def test_get_keras_tpc(self): tpc = mct.get_target_platform_capabilities(TENSORFLOW, DEFAULT_TP_MODEL) - model = MobileNetV2() + input_shape = (1, 8, 8, 3) + input_tensor = Input(shape=input_shape[1:]) + conv = Conv2D(3, 3)(input_tensor) + bn = BatchNormalization()(conv) + relu = ReLU()(bn) + model = keras.Model(inputs=input_tensor, outputs=relu) def rep_data(): - yield [np.random.randn(1, 224, 224, 3)] + yield [np.random.randn(*input_shape)] quantized_model, _ = mct.ptq.keras_post_training_quantization_experimental(model, rep_data, @@ -240,7 +246,8 @@ def rep_data(): new_experimental_exporter=True) core_config = mct.core.CoreConfig( - mixed_precision_config=mct.core.MixedPrecisionQuantizationConfigV2(num_of_images=1)) + mixed_precision_config=mct.core.MixedPrecisionQuantizationConfigV2(num_of_images=1, + use_grad_based_weights=False)) quantized_model, _ = mct.ptq.keras_post_training_quantization_experimental(model, rep_data, core_config=core_config, diff --git a/tests/keras_tests/function_tests/test_lp_search_bitwidth.py b/tests/keras_tests/non_parallel_tests/test_lp_search_bitwidth.py similarity index 94% rename from tests/keras_tests/function_tests/test_lp_search_bitwidth.py rename to tests/keras_tests/non_parallel_tests/test_lp_search_bitwidth.py index b4c80208b..ee081ae4c 100644 --- a/tests/keras_tests/function_tests/test_lp_search_bitwidth.py +++ b/tests/keras_tests/non_parallel_tests/test_lp_search_bitwidth.py @@ -14,29 +14,30 @@ # ============================================================================== import numpy as np import unittest -from tensorflow.keras.applications.mobilenet_v2 import MobileNetV2 +import keras +from model_compression_toolkit.core import DEFAULTCONFIG from model_compression_toolkit.core.common.mixed_precision.distance_weighting import get_average_weights, \ get_last_layer_weights from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI, KPITarget from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \ MixedPrecisionQuantizationConfigV2 -from model_compression_toolkit.core.common.quantization.core_config import CoreConfig from model_compression_toolkit.core.common.mixed_precision.mixed_precision_search_facade import search_bit_width, \ BitWidthSearchMethod from model_compression_toolkit.core.common.mixed_precision.search_methods.linear_programming import \ mp_integer_programming_search +from model_compression_toolkit.core.common.model_collector import ModelCollector +from model_compression_toolkit.core.common.quantization.core_config import CoreConfig from model_compression_toolkit.core.common.quantization.quantization_analyzer import analyzer_graph from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_computation import \ calculate_quantization_params from model_compression_toolkit.core.common.quantization.set_node_quantization_config import \ set_quantization_configuration_to_graph -from model_compression_toolkit.core.common.model_collector import ModelCollector -from model_compression_toolkit.core import DEFAULTCONFIG from model_compression_toolkit.core.common.similarity_analyzer import compute_mse -from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import get_op_quantization_configs, generate_keras_tpc from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation +from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import \ + get_op_quantization_configs from tests.keras_tests.tpc_keras import get_weights_only_mp_tpc_keras @@ -204,7 +205,12 @@ def run_search_bitwidth_config_test(self, core_config): name="bitwidth_cfg_test") fw_info = DEFAULT_KERAS_INFO - in_model = MobileNetV2() + input_shape = (1, 8, 8, 3) + input_tensor = keras.layers.Input(shape=input_shape[1:]) + conv = keras.layers.Conv2D(3, 3)(input_tensor) + bn = keras.layers.BatchNormalization()(conv) + relu = keras.layers.ReLU()(bn) + in_model = keras.Model(inputs=input_tensor, outputs=relu) keras_impl = KerasImplementation() def dummy_representative_dataset(): @@ -230,19 +236,19 @@ def dummy_representative_dataset(): fw_info=DEFAULT_KERAS_INFO, fw_impl=keras_impl) - for i in range(10): - mi.infer([np.random.randn(1, 224, 224, 3)]) + for i in range(1): + mi.infer([np.random.randn(*input_shape)]) def representative_data_gen(): - yield [np.random.random((1, 224, 224, 3))] + yield [np.random.random(input_shape)] calculate_quantization_params(graph, fw_info, fw_impl=keras_impl) - keras_sens_eval = keras_impl.get_sensitivity_evaluator(graph, - core_config.mixed_precision_config, - representative_data_gen, - fw_info=fw_info) + keras_impl.get_sensitivity_evaluator(graph, + core_config.mixed_precision_config, + representative_data_gen, + fw_info=fw_info) cfg = search_bit_width(graph_to_search_cfg=graph, fw_info=DEFAULT_KERAS_INFO, diff --git a/tests/keras_tests/function_tests/test_tensorboard_writer.py b/tests/keras_tests/non_parallel_tests/test_tensorboard_writer.py similarity index 98% rename from tests/keras_tests/function_tests/test_tensorboard_writer.py rename to tests/keras_tests/non_parallel_tests/test_tensorboard_writer.py index f9d91a2f3..2bbd72fa7 100644 --- a/tests/keras_tests/function_tests/test_tensorboard_writer.py +++ b/tests/keras_tests/non_parallel_tests/test_tensorboard_writer.py @@ -139,7 +139,8 @@ def test_steps_by_order(self): def rep_data(): yield [np.random.randn(1, 8, 8, 3)] - mp_qc = mct.core.MixedPrecisionQuantizationConfigV2(num_of_images=1) + mp_qc = mct.core.MixedPrecisionQuantizationConfigV2(num_of_images=1, + use_grad_based_weights=False) core_config = mct.core.CoreConfig(mixed_precision_config=mp_qc) quantized_model, _ = mct.ptq.keras_post_training_quantization_experimental(self.model, rep_data, diff --git a/tests/test_suite.py b/tests/test_suite.py index 3e7f69006..277569340 100644 --- a/tests/test_suite.py +++ b/tests/test_suite.py @@ -18,8 +18,6 @@ import importlib import unittest -from packaging import version - from tests.common_tests.function_tests.test_collectors_manipulation import TestCollectorsManipulations from tests.common_tests.function_tests.test_folder_image_loader import TestFolderLoader # ---------------- Individual test suites @@ -38,16 +36,15 @@ "torchvision") is not None if found_tf: - import tensorflow as tf from tests.keras_tests.function_tests.test_hessian_info_calculator_weights import TestHessianInfoCalculatorWeights from tests.keras_tests.function_tests.test_hessian_service import TestHessianService from tests.keras_tests.feature_networks_tests.test_features_runner import FeatureNetworkTest from tests.keras_tests.function_tests.test_quantization_configurations import TestQuantizationConfigurations - from tests.keras_tests.function_tests.test_tensorboard_writer import TestFileLogger + from tests.keras_tests.non_parallel_tests.test_tensorboard_writer import TestFileLogger from tests.keras_tests.function_tests.test_lut_quanitzer_params import TestLUTQuantizerParams from tests.keras_tests.function_tests.test_lut_activation_quanitzer_params import TestLUTActivationsQuantizerParams from tests.keras_tests.function_tests.test_lut_activation_quanitzer_fake_quant import TestLUTQuantizerFakeQuant - from tests.keras_tests.function_tests.test_lp_search_bitwidth import TestLpSearchBitwidth, \ + from tests.keras_tests.non_parallel_tests.test_lp_search_bitwidth import TestLpSearchBitwidth, \ TestSearchBitwidthConfiguration from tests.keras_tests.function_tests.test_bn_info_collection import TestBNInfoCollection from tests.keras_tests.graph_tests.test_graph_reading import TestGraphReading @@ -57,7 +54,7 @@ TestSymmetricThresholdSelectionWeights from tests.keras_tests.function_tests.test_uniform_quantize_tensor import TestUniformQuantizeTensor from tests.keras_tests.function_tests.test_uniform_range_selection_weights import TestUniformRangeSelectionWeights - from tests.keras_tests.function_tests.test_keras_tp_model import TestKerasTPModel + from tests.keras_tests.non_parallel_tests.test_keras_tp_model import TestKerasTPModel from tests.keras_tests.function_tests.test_sensitivity_metric_interest_points import \ TestSensitivityMetricInterestPoints from tests.keras_tests.function_tests.test_weights_activation_split_substitution import TestWeightsActivationSplit