diff --git a/TrainingExtensions/common/src/python/aimet_common/cross_layer_equalization.py b/TrainingExtensions/common/src/python/aimet_common/cross_layer_equalization.py index c681a90d4ea..e92227526e2 100644 --- a/TrainingExtensions/common/src/python/aimet_common/cross_layer_equalization.py +++ b/TrainingExtensions/common/src/python/aimet_common/cross_layer_equalization.py @@ -46,7 +46,6 @@ from typing import List, Union, Tuple, Dict from enum import Enum import numpy as np -from onnx import onnx_pb from aimet_common.utils import AimetLogger from aimet_common.connected_graph.connectedgraph_utils import get_all_input_ops @@ -76,7 +75,7 @@ class ClsSetLayerPairInfo: Models a pair of layers that were scaled using CLS. And related information. """ - def __init__(self, layer1: onnx_pb.NodeProto, layer2: onnx_pb.NodeProto, scale_factor: np.ndarray, + def __init__(self, layer1, layer2, scale_factor: np.ndarray, relu_activation_between_layers: bool): """ :param layer1: Layer whose bias is folded