Skip to content

Commit

Permalink
Move TensorboardWriter function from core runner (#876)
Browse files Browse the repository at this point in the history
Move TensorBoard writer functions to tensorboard_writer.py

---------

Co-authored-by: Ofir Gordon <Ofir.Gordon@altair-semi.com>
  • Loading branch information
ofirgo and Ofir Gordon authored Dec 3, 2023
1 parent a505a88 commit 753debd
Show file tree
Hide file tree
Showing 13 changed files with 83 additions and 61 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from copy import deepcopy

import io
import os
import numpy as np
from PIL import Image
from matplotlib.figure import Figure
Expand All @@ -34,6 +35,9 @@
from model_compression_toolkit.core import FrameworkInfo
from model_compression_toolkit.core.common import Graph, BaseNode
from model_compression_toolkit.core.common.collectors.statistics_collector import BaseStatsCollector
from model_compression_toolkit.logger import Logger
from model_compression_toolkit.core.common.visualization.final_config_visualizer import \
WeightsFinalBitwidthConfigVisualizer, ActivationFinalBitwidthConfigVisualizer

DEVICE_STEP_STATS = "/device:CPU:0"

Expand Down Expand Up @@ -486,3 +490,45 @@ def add_figure(self,
er = self.__get_event_writer_by_tag_name(main_tag_name)
er.add_event(event)
er.flush()


def init_tensorboard_writer(fw_info: FrameworkInfo) -> TensorboardWriter:
    """
    Build a TensorboardWriter rooted at the logger's directory, if one is configured.

    Args:
        fw_info: FrameworkInfo object.

    Returns:
        A TensorboardWriter writing to '<cwd>/<Logger.LOG_PATH>/tensorboard_logs',
        or None when no logger path was set.
    """
    # No logger directory configured -> TensorBoard logging is disabled.
    if Logger.LOG_PATH is None:
        return None
    tb_log_dir = os.path.join(os.getcwd(), Logger.LOG_PATH, 'tensorboard_logs')
    Logger.info(f'To use Tensorboard, please run: tensorboard --logdir {tb_log_dir}')
    return TensorboardWriter(tb_log_dir, fw_info)


def finalize_bitwidth_in_tb(tb_w: TensorboardWriter,
                            weights_conf_nodes_bitwidth: List,
                            activation_conf_nodes_bitwidth: List):
    """
    Set the final bit-width configuration of the quantized model in the provided TensorBoard object.

    Args:
        tb_w: A TensorboardWriter object to add the figures to.
        weights_conf_nodes_bitwidth: Final weights bit-width configuration (per configurable node).
        activation_conf_nodes_bitwidth: Final activation bit-width configuration (per configurable node).
    """
    # Emit a figure only when there is at least one configurable node of that kind.
    if weights_conf_nodes_bitwidth:
        visual = WeightsFinalBitwidthConfigVisualizer(weights_conf_nodes_bitwidth)
        figure = visual.plot_config_bitwidth()
        # Plain literals: the originals were f-strings with no placeholders (lint F541).
        tb_w.add_figure(figure, 'Weights final bit-width config')
    if activation_conf_nodes_bitwidth:
        visual = ActivationFinalBitwidthConfigVisualizer(activation_conf_nodes_bitwidth)
        figure = visual.plot_config_bitwidth()
        tb_w.add_figure(figure, 'Activation final bit-width config')
40 changes: 3 additions & 37 deletions model_compression_toolkit/core/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,9 @@
# ==============================================================================


import os
from typing import Callable, Tuple, Any, List, Dict

import numpy as np
from tqdm import tqdm

from model_compression_toolkit.core.common import FrameworkInfo
from model_compression_toolkit.core.common.hessian.hessian_info_service import HessianInfoService
Expand All @@ -33,20 +31,14 @@
from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi_functions_mapping import kpi_functions_mapping
from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi_methods import MpKpiMetric
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_search_facade import search_bit_width
from model_compression_toolkit.core.common.model_collector import ModelCollector
from model_compression_toolkit.core.common.network_editors.edit_network import edit_network_graph
from model_compression_toolkit.core.common.quantization.core_config import CoreConfig
from model_compression_toolkit.core.common.quantization.quantization_analyzer import analyzer_graph
from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_computation import \
calculate_quantization_params
from model_compression_toolkit.core.common.statistics_correction.statistics_correction import \
statistics_correction_runner
from model_compression_toolkit.core.common.substitutions.apply_substitutions import substitute
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
from model_compression_toolkit.core.common.visualization.final_config_visualizer import \
WeightsFinalBitwidthConfigVisualizer, \
ActivationFinalBitwidthConfigVisualizer
from model_compression_toolkit.core.common.visualization.tensorboard_writer import TensorboardWriter
from model_compression_toolkit.core.common.visualization.tensorboard_writer import TensorboardWriter, \
finalize_bitwidth_in_tb


def core_runner(in_model: Any,
Expand Down Expand Up @@ -150,37 +142,11 @@ def core_runner(in_model: Any,
f'Final activation bit-width configuration: {[node_b[1] for node_b in activation_conf_nodes_bitwidth]}')

if tb_w is not None:
if len(weights_conf_nodes_bitwidth) > 0:
visual = WeightsFinalBitwidthConfigVisualizer(weights_conf_nodes_bitwidth)
figure = visual.plot_config_bitwidth()
tb_w.add_figure(figure, f'Weights final bit-width config')
if len(activation_conf_nodes_bitwidth) > 0:
visual = ActivationFinalBitwidthConfigVisualizer(activation_conf_nodes_bitwidth)
figure = visual.plot_config_bitwidth()
tb_w.add_figure(figure, f'Activation final bit-width config')
finalize_bitwidth_in_tb(tb_w, weights_conf_nodes_bitwidth, activation_conf_nodes_bitwidth)

return tg, bit_widths_config, hessian_info_service


def _init_tensorboard_writer(fw_info: FrameworkInfo) -> TensorboardWriter:
    """
    Create a TensorboardWriter bound to the configured logger directory.

    Args:
        fw_info: FrameworkInfo object.

    Returns:
        A TensorboardWriter logging under a 'tensorboard_logs' sub-directory of
        the logger path, or None if no logger path is set.
    """
    # Without a configured log path there is nothing to write to.
    if Logger.LOG_PATH is None:
        return None
    tb_log_dir = os.path.join(os.getcwd(), Logger.LOG_PATH, 'tensorboard_logs')
    Logger.info(f'To use Tensorboard, please run: tensorboard --logdir {tb_log_dir}')
    return TensorboardWriter(tb_log_dir, fw_info)


def _set_final_kpi(graph: Graph,
final_bit_widths_config: List[int],
kpi_functions_dict: Dict[KPITarget, Tuple[MpKpiMetric, MpKpiAggregation]],
Expand Down
5 changes: 3 additions & 2 deletions model_compression_toolkit/gptq/keras/quantization_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from typing import Callable, Tuple
from packaging import version

from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
from model_compression_toolkit.logger import Logger
from model_compression_toolkit.constants import TENSORFLOW, FOUND_TF
from model_compression_toolkit.core.common.user_info import UserInformation
Expand All @@ -24,7 +25,7 @@
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import MixedPrecisionQuantizationConfigV2
from model_compression_toolkit.core import CoreConfig
from model_compression_toolkit.core.runner import core_runner, _init_tensorboard_writer
from model_compression_toolkit.core.runner import core_runner
from model_compression_toolkit.gptq.runner import gptq_runner
from model_compression_toolkit.core.exporter import export_model
from model_compression_toolkit.core.analyzer import analyzer_model_quantization
Expand Down Expand Up @@ -202,7 +203,7 @@ def keras_gradient_post_training_quantization_experimental(in_model: Model,
Logger.info("Using experimental mixed-precision quantization. "
"If you encounter an issue please file a bug.")

tb_w = _init_tensorboard_writer(fw_info)
tb_w = init_tensorboard_writer(fw_info)

fw_impl = GPTQKerasImplemantation()

Expand Down
5 changes: 3 additions & 2 deletions model_compression_toolkit/gptq/pytorch/quantization_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@
from typing import Callable
from model_compression_toolkit.core import common
from model_compression_toolkit.constants import FOUND_TORCH
from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
from model_compression_toolkit.logger import Logger
from model_compression_toolkit.constants import PYTORCH
from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfigV2
from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
from model_compression_toolkit.core.runner import core_runner, _init_tensorboard_writer
from model_compression_toolkit.core.runner import core_runner
from model_compression_toolkit.gptq.keras.quantization_facade import GPTQ_MOMENTUM
from model_compression_toolkit.gptq.runner import gptq_runner
from model_compression_toolkit.core.exporter import export_model
Expand Down Expand Up @@ -161,7 +162,7 @@ def pytorch_gradient_post_training_quantization_experimental(model: Module,
Logger.info("Using experimental mixed-precision quantization. "
"If you encounter an issue please file a bug.")

tb_w = _init_tensorboard_writer(DEFAULT_PYTORCH_INFO)
tb_w = init_tensorboard_writer(DEFAULT_PYTORCH_INFO)

fw_impl = GPTQPytorchImplemantation()

Expand Down
7 changes: 4 additions & 3 deletions model_compression_toolkit/legacy/keras_quantization_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from typing import Callable, List, Tuple

from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
from model_compression_toolkit.logger import Logger
from model_compression_toolkit.constants import TENSORFLOW
from model_compression_toolkit.core.common.user_info import UserInformation
Expand All @@ -28,7 +29,7 @@
from model_compression_toolkit.core.common.quantization.core_config import CoreConfig
from model_compression_toolkit.core.common.quantization.debug_config import DebugConfig
from model_compression_toolkit.core.common.quantization.quantization_config import DEFAULTCONFIG
from model_compression_toolkit.core.runner import core_runner, _init_tensorboard_writer
from model_compression_toolkit.core.runner import core_runner
from model_compression_toolkit.gptq.runner import gptq_runner
from model_compression_toolkit.ptq.runner import ptq_runner
from model_compression_toolkit.core.exporter import export_model
Expand Down Expand Up @@ -114,7 +115,7 @@ def keras_post_training_quantization(in_model: Model,
network_editor=network_editor)
)

tb_w = _init_tensorboard_writer(fw_info)
tb_w = init_tensorboard_writer(fw_info)

fw_impl = KerasImplementation()

Expand Down Expand Up @@ -249,7 +250,7 @@ def keras_post_training_quantization_mixed_precision(in_model: Model,
network_editor=network_editor)
)

tb_w = _init_tensorboard_writer(fw_info)
tb_w = init_tensorboard_writer(fw_info)

fw_impl = KerasImplementation()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# ==============================================================================
from typing import Callable, List, Tuple

from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
from model_compression_toolkit.logger import Logger
from model_compression_toolkit.constants import PYTORCH
from model_compression_toolkit.core.common.user_info import UserInformation
Expand All @@ -28,7 +29,7 @@
MixedPrecisionQuantizationConfig, DEFAULT_MIXEDPRECISION_CONFIG
from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig
from model_compression_toolkit.core.common.quantization.quantization_config import DEFAULTCONFIG
from model_compression_toolkit.core.runner import core_runner, _init_tensorboard_writer
from model_compression_toolkit.core.runner import core_runner
from model_compression_toolkit.gptq.runner import gptq_runner
from model_compression_toolkit.ptq.runner import ptq_runner
from model_compression_toolkit.core.exporter import export_model
Expand Down Expand Up @@ -106,7 +107,7 @@ def pytorch_post_training_quantization(in_module: Module,
debug_config=DebugConfig(analyze_similarity=analyze_similarity,
network_editor=network_editor))

tb_w = _init_tensorboard_writer(fw_info)
tb_w = init_tensorboard_writer(fw_info)

fw_impl = PytorchImplementation()

Expand Down Expand Up @@ -235,7 +236,7 @@ def pytorch_post_training_quantization_mixed_precision(in_model: Module,
debug_config=DebugConfig(analyze_similarity=analyze_similarity,
network_editor=network_editor))

tb_w = _init_tensorboard_writer(fw_info)
tb_w = init_tensorboard_writer(fw_info)

fw_impl = PytorchImplementation()

Expand Down
5 changes: 3 additions & 2 deletions model_compression_toolkit/ptq/keras/quantization_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,15 @@

from model_compression_toolkit.core import CoreConfig
from model_compression_toolkit.core.analyzer import analyzer_model_quantization
from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
from model_compression_toolkit.logger import Logger
from model_compression_toolkit.constants import TENSORFLOW, FOUND_TF
from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
MixedPrecisionQuantizationConfigV2
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
from model_compression_toolkit.core.exporter import export_model
from model_compression_toolkit.core.runner import core_runner, _init_tensorboard_writer
from model_compression_toolkit.core.runner import core_runner
from model_compression_toolkit.ptq.runner import ptq_runner

if FOUND_TF:
Expand Down Expand Up @@ -130,7 +131,7 @@ def keras_post_training_quantization_experimental(in_model: Model,
Logger.info("Using experimental mixed-precision quantization. "
"If you encounter an issue please file a bug.")

tb_w = _init_tensorboard_writer(fw_info)
tb_w = init_tensorboard_writer(fw_info)

fw_impl = KerasImplementation()

Expand Down
5 changes: 3 additions & 2 deletions model_compression_toolkit/ptq/pytorch/quantization_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@
from typing import Callable

from model_compression_toolkit.core import common
from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
from model_compression_toolkit.logger import Logger
from model_compression_toolkit.constants import PYTORCH, FOUND_TORCH
from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
from model_compression_toolkit.core import CoreConfig
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
MixedPrecisionQuantizationConfigV2
from model_compression_toolkit.core.runner import core_runner, _init_tensorboard_writer
from model_compression_toolkit.core.runner import core_runner
from model_compression_toolkit.ptq.runner import ptq_runner
from model_compression_toolkit.core.exporter import export_model
from model_compression_toolkit.core.analyzer import analyzer_model_quantization
Expand Down Expand Up @@ -102,7 +103,7 @@ def pytorch_post_training_quantization_experimental(in_module: Module,
Logger.info("Using experimental mixed-precision quantization. "
"If you encounter an issue please file a bug.")

tb_w = _init_tensorboard_writer(DEFAULT_PYTORCH_INFO)
tb_w = init_tensorboard_writer(DEFAULT_PYTORCH_INFO)

fw_impl = PytorchImplementation()

Expand Down
5 changes: 3 additions & 2 deletions model_compression_toolkit/qat/keras/quantization_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from functools import partial

from model_compression_toolkit.core import CoreConfig
from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
from model_compression_toolkit.logger import Logger
from model_compression_toolkit.constants import FOUND_TF
from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
Expand All @@ -25,7 +26,7 @@
from mct_quantizers import KerasActivationQuantizationHolder
from model_compression_toolkit.trainable_infrastructure import KerasTrainableQuantizationWrapper
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
from model_compression_toolkit.core.runner import core_runner, _init_tensorboard_writer
from model_compression_toolkit.core.runner import core_runner
from model_compression_toolkit.ptq.runner import ptq_runner

if FOUND_TF:
Expand Down Expand Up @@ -177,7 +178,7 @@ def keras_quantization_aware_training_init(in_model: Model,
Logger.info("Using experimental mixed-precision quantization. "
"If you encounter an issue please file a bug.")

tb_w = _init_tensorboard_writer(fw_info)
tb_w = init_tensorboard_writer(fw_info)

fw_impl = KerasImplementation()

Expand Down
5 changes: 3 additions & 2 deletions model_compression_toolkit/qat/pytorch/quantization_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,15 @@

from model_compression_toolkit.core import CoreConfig
from model_compression_toolkit.core import common
from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
from model_compression_toolkit.logger import Logger
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
MixedPrecisionQuantizationConfigV2
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import \
TargetPlatformCapabilities
from model_compression_toolkit.core.runner import core_runner, _init_tensorboard_writer
from model_compression_toolkit.core.runner import core_runner
from model_compression_toolkit.ptq.runner import ptq_runner

if FOUND_TORCH:
Expand Down Expand Up @@ -145,7 +146,7 @@ def pytorch_quantization_aware_training_init(in_model: Module,
Logger.info("Using experimental mixed-precision quantization. "
"If you encounter an issue please file a bug.")

tb_w = _init_tensorboard_writer(fw_info)
tb_w = init_tensorboard_writer(fw_info)

fw_impl = PytorchImplementation()

Expand Down
4 changes: 2 additions & 2 deletions tests/common_tests/helpers/prep_graph_for_func_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@
from model_compression_toolkit.core.common.quantization.quantization_analyzer import analyzer_graph
from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_computation import \
calculate_quantization_params
from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
from model_compression_toolkit.core.graph_prep_runner import graph_preparation_runner
from model_compression_toolkit.core.quantization_prep_runner import quantization_preparation_runner
from model_compression_toolkit.core.runner import _init_tensorboard_writer

from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import generate_tp_model, \
get_op_quantization_configs
Expand Down Expand Up @@ -114,7 +114,7 @@ def prepare_graph_set_bit_widths(in_model,
debug_config=DebugConfig(analyze_similarity=analyze_similarity,
network_editor=network_editor))

tb_w = _init_tensorboard_writer(fw_info)
tb_w = init_tensorboard_writer(fw_info)

# convert old representative dataset generation to a generator
def _representative_data_gen():
Expand Down
Loading

0 comments on commit 753debd

Please sign in to comment.