diff --git a/model_compression_toolkit/core/common/visualization/tensorboard_writer.py b/model_compression_toolkit/core/common/visualization/tensorboard_writer.py
index 126682c5b..447503dd6 100644
--- a/model_compression_toolkit/core/common/visualization/tensorboard_writer.py
+++ b/model_compression_toolkit/core/common/visualization/tensorboard_writer.py
@@ -16,6 +16,7 @@
 from copy import deepcopy
 
 import io
+import os
 import numpy as np
 from PIL import Image
 from matplotlib.figure import Figure
@@ -34,6 +35,9 @@
 from model_compression_toolkit.core import FrameworkInfo
 from model_compression_toolkit.core.common import Graph, BaseNode
 from model_compression_toolkit.core.common.collectors.statistics_collector import BaseStatsCollector
+from model_compression_toolkit.logger import Logger
+from model_compression_toolkit.core.common.visualization.final_config_visualizer import \
+    WeightsFinalBitwidthConfigVisualizer, ActivationFinalBitwidthConfigVisualizer
 
 
 DEVICE_STEP_STATS = "/device:CPU:0"
@@ -486,3 +490,45 @@ def add_figure(self,
         er = self.__get_event_writer_by_tag_name(main_tag_name)
         er.add_event(event)
         er.flush()
+
+
+def init_tensorboard_writer(fw_info: FrameworkInfo) -> TensorboardWriter:
+    """
+    Create a TensorboardWriter object initialized with the logger's directory path if one was set,
+    or return None otherwise.
+
+    Args:
+        fw_info: FrameworkInfo object.
+
+    Returns:
+        A TensorboardWriter object, or None if no logger path was set.
+    """
+    tb_w = None
+    if Logger.LOG_PATH is not None:
+        tb_log_dir = os.path.join(os.getcwd(), Logger.LOG_PATH, 'tensorboard_logs')
+        Logger.info(f'To use Tensorboard, please run: tensorboard --logdir {tb_log_dir}')
+        tb_w = TensorboardWriter(tb_log_dir, fw_info)
+    return tb_w
+
+
+def finalize_bitwidth_in_tb(tb_w: TensorboardWriter,
+                            weights_conf_nodes_bitwidth: List,
+                            activation_conf_nodes_bitwidth: List):
+    """
+    Log the final bit-width configuration of the quantized model to the provided TensorboardWriter.
+
+    Args:
+        tb_w: A TensorboardWriter object.
+        weights_conf_nodes_bitwidth: Final weights bit-width configuration.
+        activation_conf_nodes_bitwidth: Final activation bit-width configuration.
+ + """ + + if len(weights_conf_nodes_bitwidth) > 0: + visual = WeightsFinalBitwidthConfigVisualizer(weights_conf_nodes_bitwidth) + figure = visual.plot_config_bitwidth() + tb_w.add_figure(figure, f'Weights final bit-width config') + if len(activation_conf_nodes_bitwidth) > 0: + visual = ActivationFinalBitwidthConfigVisualizer(activation_conf_nodes_bitwidth) + figure = visual.plot_config_bitwidth() + tb_w.add_figure(figure, f'Activation final bit-width config') diff --git a/model_compression_toolkit/core/runner.py b/model_compression_toolkit/core/runner.py index be7305b20..6fcd77e53 100644 --- a/model_compression_toolkit/core/runner.py +++ b/model_compression_toolkit/core/runner.py @@ -14,11 +14,9 @@ # ============================================================================== -import os from typing import Callable, Tuple, Any, List, Dict import numpy as np -from tqdm import tqdm from model_compression_toolkit.core.common import FrameworkInfo from model_compression_toolkit.core.common.hessian.hessian_info_service import HessianInfoService @@ -33,20 +31,14 @@ from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi_functions_mapping import kpi_functions_mapping from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi_methods import MpKpiMetric from model_compression_toolkit.core.common.mixed_precision.mixed_precision_search_facade import search_bit_width -from model_compression_toolkit.core.common.model_collector import ModelCollector from model_compression_toolkit.core.common.network_editors.edit_network import edit_network_graph from model_compression_toolkit.core.common.quantization.core_config import CoreConfig -from model_compression_toolkit.core.common.quantization.quantization_analyzer import analyzer_graph -from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_computation import \ - calculate_quantization_params -from model_compression_toolkit.core.common.statistics_correction.statistics_correction import \ - statistics_correction_runner -from model_compression_toolkit.core.common.substitutions.apply_substitutions import substitute from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities from model_compression_toolkit.core.common.visualization.final_config_visualizer import \ WeightsFinalBitwidthConfigVisualizer, \ ActivationFinalBitwidthConfigVisualizer -from model_compression_toolkit.core.common.visualization.tensorboard_writer import TensorboardWriter +from model_compression_toolkit.core.common.visualization.tensorboard_writer import TensorboardWriter, \ + finalize_bitwidth_in_tb def core_runner(in_model: Any, @@ -150,37 +142,11 @@ def core_runner(in_model: Any, f'Final activation bit-width configuration: {[node_b[1] for node_b in activation_conf_nodes_bitwidth]}') if tb_w is not None: - if len(weights_conf_nodes_bitwidth) > 0: - visual = WeightsFinalBitwidthConfigVisualizer(weights_conf_nodes_bitwidth) - figure = visual.plot_config_bitwidth() - tb_w.add_figure(figure, f'Weights final bit-width config') - if len(activation_conf_nodes_bitwidth) > 0: - visual = ActivationFinalBitwidthConfigVisualizer(activation_conf_nodes_bitwidth) - figure = visual.plot_config_bitwidth() - tb_w.add_figure(figure, f'Activation final bit-width config') + finalize_bitwidth_in_tb(tb_w, weights_conf_nodes_bitwidth, activation_conf_nodes_bitwidth) return tg, bit_widths_config, hessian_info_service -def _init_tensorboard_writer(fw_info: FrameworkInfo) 
-    """
-    Create a TensorBoardWriter object initialized with the logger dir path if it was set,
-    or None otherwise.
-
-    Args:
-        fw_info: FrameworkInfo object.
-
-    Returns:
-        A TensorBoardWriter object.
-    """
-    tb_w = None
-    if Logger.LOG_PATH is not None:
-        tb_log_dir = os.path.join(os.getcwd(), Logger.LOG_PATH, 'tensorboard_logs')
-        Logger.info(f'To use Tensorboard, please run: tensorboard --logdir {tb_log_dir}')
-        tb_w = TensorboardWriter(tb_log_dir, fw_info)
-    return tb_w
-
-
 def _set_final_kpi(graph: Graph,
                    final_bit_widths_config: List[int],
                    kpi_functions_dict: Dict[KPITarget, Tuple[MpKpiMetric, MpKpiAggregation]],
diff --git a/model_compression_toolkit/gptq/keras/quantization_facade.py b/model_compression_toolkit/gptq/keras/quantization_facade.py
index e9c98f3db..1e5f72413 100644
--- a/model_compression_toolkit/gptq/keras/quantization_facade.py
+++ b/model_compression_toolkit/gptq/keras/quantization_facade.py
@@ -16,6 +16,7 @@
 from typing import Callable, Tuple
 
 from packaging import version
+from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
 from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.constants import TENSORFLOW, FOUND_TF
 from model_compression_toolkit.core.common.user_info import UserInformation
@@ -24,7 +25,7 @@
 from model_compression_toolkit.core.common.framework_info import FrameworkInfo
 from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import MixedPrecisionQuantizationConfigV2
 from model_compression_toolkit.core import CoreConfig
-from model_compression_toolkit.core.runner import core_runner, _init_tensorboard_writer
+from model_compression_toolkit.core.runner import core_runner
 from model_compression_toolkit.gptq.runner import gptq_runner
 from model_compression_toolkit.core.exporter import export_model
 from model_compression_toolkit.core.analyzer import analyzer_model_quantization
@@ -202,7 +203,7 @@
         Logger.info("Using experimental mixed-precision quantization. "
" "If you encounter an issue please file a bug.") - tb_w = _init_tensorboard_writer(fw_info) + tb_w = init_tensorboard_writer(fw_info) fw_impl = GPTQKerasImplemantation() diff --git a/model_compression_toolkit/gptq/pytorch/quantization_facade.py b/model_compression_toolkit/gptq/pytorch/quantization_facade.py index a20c3842c..57f06d1af 100644 --- a/model_compression_toolkit/gptq/pytorch/quantization_facade.py +++ b/model_compression_toolkit/gptq/pytorch/quantization_facade.py @@ -15,12 +15,13 @@ from typing import Callable from model_compression_toolkit.core import common from model_compression_toolkit.constants import FOUND_TORCH +from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer from model_compression_toolkit.logger import Logger from model_compression_toolkit.constants import PYTORCH from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfigV2 from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI -from model_compression_toolkit.core.runner import core_runner, _init_tensorboard_writer +from model_compression_toolkit.core.runner import core_runner from model_compression_toolkit.gptq.keras.quantization_facade import GPTQ_MOMENTUM from model_compression_toolkit.gptq.runner import gptq_runner from model_compression_toolkit.core.exporter import export_model @@ -161,7 +162,7 @@ def pytorch_gradient_post_training_quantization_experimental(model: Module, Logger.info("Using experimental mixed-precision quantization. " "If you encounter an issue please file a bug.") - tb_w = _init_tensorboard_writer(DEFAULT_PYTORCH_INFO) + tb_w = init_tensorboard_writer(DEFAULT_PYTORCH_INFO) fw_impl = GPTQPytorchImplemantation() diff --git a/model_compression_toolkit/legacy/keras_quantization_facade.py b/model_compression_toolkit/legacy/keras_quantization_facade.py index 52505409d..57a8c1532 100644 --- a/model_compression_toolkit/legacy/keras_quantization_facade.py +++ b/model_compression_toolkit/legacy/keras_quantization_facade.py @@ -15,6 +15,7 @@ from typing import Callable, List, Tuple +from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer from model_compression_toolkit.logger import Logger from model_compression_toolkit.constants import TENSORFLOW from model_compression_toolkit.core.common.user_info import UserInformation @@ -28,7 +29,7 @@ from model_compression_toolkit.core.common.quantization.core_config import CoreConfig from model_compression_toolkit.core.common.quantization.debug_config import DebugConfig from model_compression_toolkit.core.common.quantization.quantization_config import DEFAULTCONFIG -from model_compression_toolkit.core.runner import core_runner, _init_tensorboard_writer +from model_compression_toolkit.core.runner import core_runner from model_compression_toolkit.gptq.runner import gptq_runner from model_compression_toolkit.ptq.runner import ptq_runner from model_compression_toolkit.core.exporter import export_model @@ -114,7 +115,7 @@ def keras_post_training_quantization(in_model: Model, network_editor=network_editor) ) - tb_w = _init_tensorboard_writer(fw_info) + tb_w = init_tensorboard_writer(fw_info) fw_impl = KerasImplementation() @@ -249,7 +250,7 @@ def keras_post_training_quantization_mixed_precision(in_model: Model, network_editor=network_editor) ) - tb_w = _init_tensorboard_writer(fw_info) + tb_w = 
 
     fw_impl = KerasImplementation()
 
diff --git a/model_compression_toolkit/legacy/pytorch_quantization_facade.py b/model_compression_toolkit/legacy/pytorch_quantization_facade.py
index 630aa93c0..a28b9d7d7 100644
--- a/model_compression_toolkit/legacy/pytorch_quantization_facade.py
+++ b/model_compression_toolkit/legacy/pytorch_quantization_facade.py
@@ -14,6 +14,7 @@
 # ==============================================================================
 from typing import Callable, List, Tuple
 
+from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
 from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.constants import PYTORCH
 from model_compression_toolkit.core.common.user_info import UserInformation
@@ -28,7 +29,7 @@
     MixedPrecisionQuantizationConfig, DEFAULT_MIXEDPRECISION_CONFIG
 from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig
 from model_compression_toolkit.core.common.quantization.quantization_config import DEFAULTCONFIG
-from model_compression_toolkit.core.runner import core_runner, _init_tensorboard_writer
+from model_compression_toolkit.core.runner import core_runner
 from model_compression_toolkit.gptq.runner import gptq_runner
 from model_compression_toolkit.ptq.runner import ptq_runner
 from model_compression_toolkit.core.exporter import export_model
@@ -106,7 +107,7 @@
                              debug_config=DebugConfig(analyze_similarity=analyze_similarity,
                                                       network_editor=network_editor))
 
-    tb_w = _init_tensorboard_writer(fw_info)
+    tb_w = init_tensorboard_writer(fw_info)
 
     fw_impl = PytorchImplementation()
 
@@ -235,7 +236,7 @@
                              debug_config=DebugConfig(analyze_similarity=analyze_similarity,
                                                       network_editor=network_editor))
 
-    tb_w = _init_tensorboard_writer(fw_info)
+    tb_w = init_tensorboard_writer(fw_info)
 
     fw_impl = PytorchImplementation()
 
diff --git a/model_compression_toolkit/ptq/keras/quantization_facade.py b/model_compression_toolkit/ptq/keras/quantization_facade.py
index 829628e64..b0bdce3ad 100644
--- a/model_compression_toolkit/ptq/keras/quantization_facade.py
+++ b/model_compression_toolkit/ptq/keras/quantization_facade.py
@@ -17,6 +17,7 @@
 
 from model_compression_toolkit.core import CoreConfig
 from model_compression_toolkit.core.analyzer import analyzer_model_quantization
+from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
 from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.constants import TENSORFLOW, FOUND_TF
 from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
@@ -24,7 +25,7 @@
     MixedPrecisionQuantizationConfigV2
 from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
 from model_compression_toolkit.core.exporter import export_model
-from model_compression_toolkit.core.runner import core_runner, _init_tensorboard_writer
+from model_compression_toolkit.core.runner import core_runner
 from model_compression_toolkit.ptq.runner import ptq_runner
 
 if FOUND_TF:
@@ -130,7 +131,7 @@
         Logger.info("Using experimental mixed-precision quantization. "
" "If you encounter an issue please file a bug.") - tb_w = _init_tensorboard_writer(fw_info) + tb_w = init_tensorboard_writer(fw_info) fw_impl = KerasImplementation() diff --git a/model_compression_toolkit/ptq/pytorch/quantization_facade.py b/model_compression_toolkit/ptq/pytorch/quantization_facade.py index 9a946d471..a53051c86 100644 --- a/model_compression_toolkit/ptq/pytorch/quantization_facade.py +++ b/model_compression_toolkit/ptq/pytorch/quantization_facade.py @@ -15,6 +15,7 @@ from typing import Callable from model_compression_toolkit.core import common +from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer from model_compression_toolkit.logger import Logger from model_compression_toolkit.constants import PYTORCH, FOUND_TORCH from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities @@ -22,7 +23,7 @@ from model_compression_toolkit.core import CoreConfig from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \ MixedPrecisionQuantizationConfigV2 -from model_compression_toolkit.core.runner import core_runner, _init_tensorboard_writer +from model_compression_toolkit.core.runner import core_runner from model_compression_toolkit.ptq.runner import ptq_runner from model_compression_toolkit.core.exporter import export_model from model_compression_toolkit.core.analyzer import analyzer_model_quantization @@ -102,7 +103,7 @@ def pytorch_post_training_quantization_experimental(in_module: Module, Logger.info("Using experimental mixed-precision quantization. " "If you encounter an issue please file a bug.") - tb_w = _init_tensorboard_writer(DEFAULT_PYTORCH_INFO) + tb_w = init_tensorboard_writer(DEFAULT_PYTORCH_INFO) fw_impl = PytorchImplementation() diff --git a/model_compression_toolkit/qat/keras/quantization_facade.py b/model_compression_toolkit/qat/keras/quantization_facade.py index 457b83e1f..bc43e91c1 100644 --- a/model_compression_toolkit/qat/keras/quantization_facade.py +++ b/model_compression_toolkit/qat/keras/quantization_facade.py @@ -17,6 +17,7 @@ from functools import partial from model_compression_toolkit.core import CoreConfig +from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer from model_compression_toolkit.logger import Logger from model_compression_toolkit.constants import FOUND_TF from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI @@ -25,7 +26,7 @@ from mct_quantizers import KerasActivationQuantizationHolder from model_compression_toolkit.trainable_infrastructure import KerasTrainableQuantizationWrapper from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities -from model_compression_toolkit.core.runner import core_runner, _init_tensorboard_writer +from model_compression_toolkit.core.runner import core_runner from model_compression_toolkit.ptq.runner import ptq_runner if FOUND_TF: @@ -177,7 +178,7 @@ def keras_quantization_aware_training_init(in_model: Model, Logger.info("Using experimental mixed-precision quantization. 
" "If you encounter an issue please file a bug.") - tb_w = _init_tensorboard_writer(fw_info) + tb_w = init_tensorboard_writer(fw_info) fw_impl = KerasImplementation() diff --git a/model_compression_toolkit/qat/pytorch/quantization_facade.py b/model_compression_toolkit/qat/pytorch/quantization_facade.py index 181716b91..28df342e5 100644 --- a/model_compression_toolkit/qat/pytorch/quantization_facade.py +++ b/model_compression_toolkit/qat/pytorch/quantization_facade.py @@ -20,6 +20,7 @@ from model_compression_toolkit.core import CoreConfig from model_compression_toolkit.core import common +from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer from model_compression_toolkit.logger import Logger from model_compression_toolkit.core.common.framework_info import FrameworkInfo from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI @@ -27,7 +28,7 @@ MixedPrecisionQuantizationConfigV2 from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import \ TargetPlatformCapabilities -from model_compression_toolkit.core.runner import core_runner, _init_tensorboard_writer +from model_compression_toolkit.core.runner import core_runner from model_compression_toolkit.ptq.runner import ptq_runner if FOUND_TORCH: @@ -145,7 +146,7 @@ def pytorch_quantization_aware_training_init(in_model: Module, Logger.info("Using experimental mixed-precision quantization. " "If you encounter an issue please file a bug.") - tb_w = _init_tensorboard_writer(fw_info) + tb_w = init_tensorboard_writer(fw_info) fw_impl = PytorchImplementation() diff --git a/tests/common_tests/helpers/prep_graph_for_func_test.py b/tests/common_tests/helpers/prep_graph_for_func_test.py index 7ed3c8f68..3c59beb2a 100644 --- a/tests/common_tests/helpers/prep_graph_for_func_test.py +++ b/tests/common_tests/helpers/prep_graph_for_func_test.py @@ -21,9 +21,9 @@ from model_compression_toolkit.core.common.quantization.quantization_analyzer import analyzer_graph from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_computation import \ calculate_quantization_params +from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer from model_compression_toolkit.core.graph_prep_runner import graph_preparation_runner from model_compression_toolkit.core.quantization_prep_runner import quantization_preparation_runner -from model_compression_toolkit.core.runner import _init_tensorboard_writer from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import generate_tp_model, \ get_op_quantization_configs @@ -114,7 +114,7 @@ def prepare_graph_set_bit_widths(in_model, debug_config=DebugConfig(analyze_similarity=analyze_similarity, network_editor=network_editor)) - tb_w = _init_tensorboard_writer(fw_info) + tb_w = init_tensorboard_writer(fw_info) # convert old representative dataset generation to a generator def _representative_data_gen(): diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/second_moment_correction_test.py b/tests/keras_tests/feature_networks_tests/feature_networks/second_moment_correction_test.py index 74cce220b..03f5ed40d 100644 --- a/tests/keras_tests/feature_networks_tests/feature_networks/second_moment_correction_test.py +++ b/tests/keras_tests/feature_networks_tests/feature_networks/second_moment_correction_test.py @@ -27,13 +27,14 @@ from model_compression_toolkit.core.common.network_editors 
 from model_compression_toolkit.core.common.statistics_correction.apply_second_moment_correction_to_graph import \
     quantized_model_builder_for_second_moment_correction
+from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
 from model_compression_toolkit.core.keras.constants import EPSILON_VAL, GAMMA, BETA, MOVING_MEAN, MOVING_VARIANCE
 from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
 from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
 from model_compression_toolkit.core.keras.keras_model_validation import KerasModelValidation
 from model_compression_toolkit.core.keras.statistics_correction.apply_second_moment_correction import \
     keras_apply_second_moment_correction
-from model_compression_toolkit.core.runner import _init_tensorboard_writer, core_runner
+from model_compression_toolkit.core.runner import core_runner
 from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
 from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
 from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
@@ -277,7 +278,7 @@
                                                    network_editor=network_editor)
                                  )
 
-        tb_w = _init_tensorboard_writer(fw_info)
+        tb_w = init_tensorboard_writer(fw_info)
 
         fw_impl = KerasImplementation()
 
diff --git a/tests/pytorch_tests/model_tests/feature_models/second_moment_correction_test.py b/tests/pytorch_tests/model_tests/feature_models/second_moment_correction_test.py
index 58b62512e..2d3a61b95 100644
--- a/tests/pytorch_tests/model_tests/feature_models/second_moment_correction_test.py
+++ b/tests/pytorch_tests/model_tests/feature_models/second_moment_correction_test.py
@@ -24,6 +24,7 @@
 from model_compression_toolkit.core.common import Graph
 from model_compression_toolkit.core.common.statistics_correction.apply_second_moment_correction_to_graph import \
     quantized_model_builder_for_second_moment_correction
+from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
 from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
 from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
 from model_compression_toolkit.core.pytorch.constants import EPSILON_VAL, GAMMA, BETA, MOVING_MEAN, MOVING_VARIANCE
@@ -32,7 +33,7 @@
 from model_compression_toolkit.core.pytorch.statistics_correction.apply_second_moment_correction import \
     pytorch_apply_second_moment_correction
 from model_compression_toolkit.core.pytorch.utils import to_torch_tensor, set_model
-from model_compression_toolkit.core.runner import _init_tensorboard_writer, core_runner
+from model_compression_toolkit.core.runner import core_runner
 from tests.common_tests.helpers.generate_test_tp_model import generate_test_tp_model
 from tests.pytorch_tests.model_tests.base_pytorch_test import BasePytorchTest
 from tests.pytorch_tests.tpc_pytorch import get_pytorch_test_tpc_dict
@@ -346,7 +347,7 @@
                       target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_PYTORCH_INFO) -> \
             Tuple[Graph, Graph]:
 
-        tb_w = _init_tensorboard_writer(fw_info)
+        tb_w = init_tensorboard_writer(fw_info)
 
         fw_impl = PytorchImplementation()
 
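Usage sketch (outside the diff, for review only): a minimal illustration of how the now-public helper is picked up after this refactor. The log folder path is an arbitrary placeholder, and the snippet assumes TensorFlow is installed so that DEFAULT_KERAS_INFO (already imported elsewhere in this patch) is available.

import model_compression_toolkit as mct
from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO

# Configure a log folder first; init_tensorboard_writer returns None when
# Logger.LOG_PATH was never set.
mct.set_log_folder('/tmp/mct_logs')  # placeholder path

# Writer rooted at <log folder>/tensorboard_logs; the matching
# "tensorboard --logdir ..." command is printed by the logger.
tb_w = init_tensorboard_writer(DEFAULT_KERAS_INFO)

core_runner then guards the second helper exactly as the inlined code did before the move: when tb_w is not None, it calls finalize_bitwidth_in_tb(tb_w, weights_conf_nodes_bitwidth, activation_conf_nodes_bitwidth) with the final per-node bit-width lists.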