Skip to content

Commit

Permalink
Fixed documentation and eval script for Torch EfficientNet-lite0 mode…
Browse files Browse the repository at this point in the history
…l, added eval script for Tensorflow Resnet50 model

Signed-off-by: Bharath Ramaswamy <quic_bharathr@quicinc.com>
  • Loading branch information
quic-bharathr committed Dec 31, 2020
1 parent 154f372 commit 722a1ba
Show file tree
Hide file tree
Showing 5 changed files with 239 additions and 41 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ An original FP32 source model is quantized either using post-training quantizati
<td><a href="https://github.com/rwightman/gen-efficientnet-pytorch">GitHub Repo</a></td>
<td><a href="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_lite0_ra-37913777.pth">Pretrained Model</a></td>
<td><a href="zoo_torch/examples/eval_efficientnetlite0.py">See Example</a></td>
<td>(ImageNet) Top-1 Accuracy <br> FP32: 75.42%<br> INT8: 74.49%</td>
<td>(ImageNet) Top-1 Accuracy <br> FP32: 75.42%<br> INT8: 74.44%</td>
<td><a href="zoo_torch/Docs/EfficientNet-lite0.md">EfficientNet-lite0.md</a></td>
</tr>
<tr>
Expand Down
8 changes: 4 additions & 4 deletions zoo_tensorflow/Docs/ResNet50.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ Please [install and setup AIMET](../../README.md#install-aimet) before proceedin

## Obtaining model checkpoint and dataset

- The optimized ResNet 50 checkpoint can be downloaded from [Releases](/../../releases).
- The original ResNet 50 checkpoint can be downloaded from [TensorFlow Models repo](http://download.tensorflow.org/models/resnet_v1_50_2016_08_28.tar.gz).

- ImageNet can be downloaded here:
- http://www.image-net.org/
Expand All @@ -35,17 +35,17 @@ Please [install and setup AIMET](../../README.md#install-aimet) before proceedin
- To run evaluation with QuantSim in AIMET, use the following

```bash
python resnet_v1_50.py \
python resnet_v1_50_quanteval.py \
--model-name=resnet_v1_50 \
--checkpoint-path=<path to resnet_v1_50 checkpoint> \
--dataset-dir=<path to imagenet validation TFRecords> \
--quantsim-config-file=<path to config file with symmetric weights>
```

- If you are using a model checkpoint which has Batch Norms already folded (such as the optimized model checkpoint), please specify the `--ckpt-bn-folded` flag:
- If you are using a model checkpoint which has Batch Norms already folded, please specify the `--ckpt-bn-folded` flag:

```bash
python resnet_v1_50.py \
python resnet_v1_50_quanteval.py \
--model-name=resnet_v1_50 \
--checkpoint-path=<path to resnet_v1_50 checkpoint> \
--dataset-dir=<path to imagenet validation TFRecords> \
Expand Down
224 changes: 224 additions & 0 deletions zoo_tensorflow/examples/resnet_v1_50_quanteval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
#! /usr/bin/env python3.6
# -*- mode: python -*-
# =============================================================================
# @@-COPYRIGHT-START-@@
#
# Copyright (c) 2019-2020, Qualcomm Innovation Center, Inc. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
#
# SPDX-License-Identifier: BSD-3-Clause
#
# @@-COPYRIGHT-END-@@
# =============================================================================

import os
import sys
import json
import argparse
from tqdm import tqdm
from glob import glob

import numpy as np
import tensorflow as tf

import aimet_common.defs
from aimet_tensorflow import quantsim
from aimet_tensorflow.cross_layer_equalization import GraphSearchUtils, equalize_model
from aimet_tensorflow.bias_correction import BiasCorrectionParams, BiasCorrection, QuantParams
from aimet_tensorflow.quantsim import save_checkpoint, QuantizationSimModel
from aimet_tensorflow.batch_norm_fold import fold_all_batch_norms

from nets import nets_factory
from preprocessing import preprocessing_factory
from deployment import model_deploy
from datasets import dataset_factory

def wrap_preprocessing(preprocessing, height, width, num_classes, labels_offset):
'''Wrap preprocessing function to do parsing of TFrecords.
'''
def parse(serialized_example):
features = tf.parse_single_example(serialized_example, features={
'image/class/label': tf.FixedLenFeature([], tf.int64),
'image/encoded': tf.FixedLenFeature([], tf.string)
})

image_data = features['image/encoded']
image = tf.image.decode_jpeg(image_data, channels=3)
label = tf.cast(features['image/class/label'], tf.int32)
label = label - labels_offset

labels = tf.one_hot(indices=label, depth=num_classes)
image = preprocessing(image, height, width)
return image, labels
return parse

def run_evaluation(args):
# Build graph definition
with tf.Graph().as_default():
# Create iterator
tf_records = glob(args.dataset_dir + '/validation*')
preprocessing_fn = preprocessing_factory.get_preprocessing(args.model_name, is_training=False)
parse_function = wrap_preprocessing(preprocessing_fn, height=args.image_size, width=args.image_size, num_classes=(1001 - args.labels_offset), labels_offset=args.labels_offset)

dataset = tf.data.TFRecordDataset(tf_records).repeat(1)
dataset = dataset.map(parse_function, num_parallel_calls=1).apply(tf.contrib.data.batch_and_drop_remainder(args.batch_size))
iterator = dataset.make_initializable_iterator()
images, labels = iterator.get_next()

network_fn = nets_factory.get_network_fn(args.model_name, num_classes=(1001 - args.labels_offset), is_training=False)
with tf.device('/cpu:0'):
images = tf.placeholder_with_default(images,
shape=(None, args.image_size, args.image_size, 3),
name='input')
labels = tf.placeholder_with_default(labels,
shape=(None, 1001 - args.labels_offset),
name='labels')
logits, end_points = network_fn(images)
confidences = tf.nn.softmax(logits, axis=1, name='confidences')
categorical_preds = tf.argmax(confidences, axis=1, name='categorical_preds')
categorical_labels = tf.argmax(labels, axis=1, name='categorical_labels')
correct_predictions = tf.equal(categorical_labels, categorical_preds)
top1_acc = tf.reduce_mean(tf.cast(correct_predictions, tf.float32), name='top1-acc')
top5_acc = tf.reduce_mean(tf.cast(tf.nn.in_top_k(predictions=confidences,
targets=tf.cast(categorical_labels, tf.int32),
k=5), tf.float32), name='top5-acc')

saver = tf.train.Saver()
sess = tf.Session()

# Load model from checkpoint
if not args.ckpt_bn_folded:
saver.restore(sess, args.checkpoint_path)
else:
sess.run(tf.global_variables_initializer())

# Fold all BatchNorms before QuantSim
sess, folded_pairs = fold_all_batch_norms(sess, ['IteratorGetNext'], [logits.name[:-2]])

if args.ckpt_bn_folded:
with sess.graph.as_default():
saver = tf.train.Saver()
saver.restore(sess, args.checkpoint_path)
else:
# Do Cross Layer Equalization and Bias Correction if not loading from a batchnorm folded checkpoint
sess = equalize_model(sess, ['input'], [logits.op.name])
conv_bn_dict = BiasCorrection.find_all_convs_bn_with_activation(sess, ['input'], [logits.op.name])
quant_params = QuantParams(quant_mode=args.quant_scheme)
bias_correction_dataset = tf.data.TFRecordDataset(tf_records).repeat(1)
bias_correction_dataset = bias_correction_dataset.map(lambda x: parse_function(x)[0], num_parallel_calls=1).apply(tf.contrib.data.batch_and_drop_remainder(args.batch_size))
bias_correction_params = BiasCorrectionParams(batch_size=args.batch_size,
num_quant_samples=10,
num_bias_correct_samples=512,
input_op_names=['input'],
output_op_names=[logits.op.name])


sess = BiasCorrection.correct_bias(reference_model=sess,
bias_correct_params=bias_correction_params,
quant_params=quant_params,
data_set=bias_correction_dataset,
conv_bn_dict=conv_bn_dict,
perform_only_empirical_bias_corr=True)


# Define eval_func to use for compute encodings in QuantSim
def eval_func(session, iterations):
cnt = 0
avg_acc_top1 = 0
session.run('MakeIterator')
while cnt < iterations or iterations == -1:
try:
avg_acc_top1 += session.run('top1-acc:0')
cnt += 1
except:
return avg_acc_top1 / cnt

return avg_acc_top1 / cnt

# Select the right quant_scheme
if args.quant_scheme == 'range_learning_tf':
quant_scheme = aimet_common.defs.QuantScheme.training_range_learning_with_tf_init
elif args.quant_scheme == 'range_learning_tf_enhanced':
quant_scheme = aimet_common.defs.QuantScheme.training_range_learning_with_tf_enhanced_init
elif args.quant_scheme == 'tf':
quant_scheme = aimet_common.defs.QuantScheme.post_training_tf
elif args.quant_scheme == 'tf_enhanced':
quant_scheme = aimet_common.defs.QuantScheme.post_training_tf_enhanced
else:
raise ValueError("Got unrecognized quant_scheme: " + args.quant_scheme)

# Create QuantizationSimModel
sim = QuantizationSimModel(
session=sess,
starting_op_names=['IteratorGetNext'],
output_op_names=[logits.name[:-2]],
quant_scheme=quant_scheme,
rounding_mode=args.round_mode,
default_output_bw=args.default_output_bw,
default_param_bw=args.default_param_bw,
config_file=args.quantsim_config_file,
)

# Run compute_encodings
sim.compute_encodings(eval_func, forward_pass_callback_args=args.encodings_iterations)

# Run final evaluation
sess = sim.session

top1_acc = eval_func(sess, -1)
print('Avg accuracy Top 1: {}'.format(top1_acc))


def parse_args(args):
""" Parse the arguments.
"""
parser = argparse.ArgumentParser(description='Evaluation script for an Resnet 50 network.')

parser.add_argument('--model-name', help='Name of model to eval.', default='resnet_v1_50')
parser.add_argument('--checkpoint-path', help='Path to checkpoint to load from.')
parser.add_argument('--dataset-dir', help='Imagenet eval dataset directory.')
parser.add_argument('--labels-offset', help='Offset for whether to ignore background label', type=int, default=0)
parser.add_argument('--image-size', help='Image size.', type=int, default=224)
parser.add_argument('--batch-size', help='Batch size.', type=int, default=32)

parser.add_argument('--ckpt-bn-folded', help='Use this flag to specify whether checkpoint has batchnorms folded already or not.', action='store_true')
parser.add_argument('--quant-scheme', help='Quant scheme to use for quantization (tf, tf_enhanced, range_learning_tf, range_learning_tf_enhanced).', default='tf')
parser.add_argument('--round-mode', help='Round mode for quantization.', default='nearest')
parser.add_argument('--default-output-bw', help='Default output bitwidth for quantization.', type=int, default=8)
parser.add_argument('--default-param-bw', help='Default parameter bitwidth for quantization.', type=int, default=8)
parser.add_argument('--quantsim-config-file', help='Quantsim configuration file.', default=None)
parser.add_argument('--encodings-iterations', help='Number of iterations to use for compute encodings during quantization.', default=500)

return parser.parse_args(args)

def main(args=None):
args = parse_args(args)
run_evaluation(args)

if __name__ == '__main__':
main()
8 changes: 4 additions & 4 deletions zoo_torch/Docs/EfficientNet-lite0.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,20 @@ sudo -H pip install geffnet

- The original EfficientNet-lite0 checkpoint can be downloaded from here:
- https://github.com/rwightman/gen-efficientnet-pytorch
- Optimized EfficientNet-lite0 checkpoint can be downloaded from the [Releases](/../../releases) page.
- ImageNet can be downloaded from here:
- http://www.image-net.org/

## Usage
- To run evaluation with QuantSim in AIMET, use the following
```bash
python eval_efficientnetlite0.py \
--checkpoint <path to optimiezd checkpoint> \
--images-dir <path to imagenet root directory> \
--quant-scheme <quantization schme to run> \
--quant-tricks <preprocessing steps prior to Quantization> \
--default-output-bw <bitwidth for activation quantization> \
--default-param-bw <bitwidth for weight quantization> \
--num-iterations <Number of iterations used for adaround optimization If adaround is used> \
--num-batches <Number of batches used for adaround optimization If adaround is used> \
--default-param-bw <bitwidth for weight quantization>
```

## Quantization Configuration
Expand All @@ -34,5 +34,5 @@ python eval_efficientnetlite0.py \
- Activation quantization: 8 bits, asymmetric quantization
- Model inputs are not quantized
- TF_enhanced was used as quantization scheme
- Batch norm folding and Adaround has been applied on efficientnet-lite in the eval script
- Batch norm folding and Adaround has been applied on optimized efficientnet-lite checkpoint
- [Conv - Relu6] layers has been fused as one operation via manual configurations
38 changes: 6 additions & 32 deletions zoo_torch/examples/eval_efficientnetlite0.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from aimet_torch import batch_norm_fold
from aimet_common.defs import QuantScheme
from aimet_torch.pro.quantsim import QuantizationSimModel
from aimet_torch.adaround.adaround_weight import Adaround, AdaroundParameters
from aimet_torch.onnx_utils import onnx_pytorch_conn_graph_type_pairs
from aimet_common.utils import AimetLogger
import logging
Expand Down Expand Up @@ -107,36 +106,15 @@ def run_pytorch_cross_layer_equalization(config, model):
cross_layer_equalization.equalize_model(model.cpu(), config.input_shape)
return model

def run_pytorch_adaround(config, model, data_loaders):
if hasattr(config, 'quant_scheme'):
if config.quant_scheme == 'range_learning_tf':
quant_scheme = QuantScheme.post_training_tf
elif config.quant_scheme == 'range_learning_tfe':
quant_scheme = QuantScheme.post_training_tf_enhanced
elif config.quant_scheme == 'tf':
quant_scheme = QuantScheme.post_training_tf
elif config.quant_scheme == 'tf_enhanced':
quant_scheme = QuantScheme.post_training_tf_enhanced
else:
raise ValueError("Got unrecognized quant_scheme: " + config.quant_scheme)

params = AdaroundParameters(data_loader = data_loaders, num_batches = config.num_batches, default_num_iterations = config.num_iterations,
default_reg_param = 0.01, default_beta_range = (20, 2))
ada_model = Adaround.apply_adaround(model.cuda(), params, default_param_bw= config.default_param_bw,
default_quant_scheme = quant_scheme,
default_config_file = config.config_file
)
return ada_model


def arguments():
parser = argparse.ArgumentParser(description='Evaluation script for PyTorch EfficientNet-lite0 networks.')

parser.add_argument('--checkpoint', help='Path to optimized checkpoint', default=None, type=str)
parser.add_argument('--images-dir', help='Imagenet eval image', default='./ILSVRC2012_PyTorch/', type=str)
parser.add_argument('--input-shape', help='Model to an input image shape, (ex : [batch, channel, width, height]', default=(1,3,224,224))
parser.add_argument('--seed', help='Seed number for reproducibility', default=0)

parser.add_argument('--quant-tricks', help='Preprocessing prior to Quantization', choices=['BNfold', 'CLE', 'adaround'], nargs = "+")
parser.add_argument('--quant-tricks', help='Preprocessing prior to Quantization', default=[], choices=['BNfold', 'CLE'], nargs = "+")
parser.add_argument('--quant-scheme', help='Quant scheme to use for quantization (tf, tf_enhanced, range_learning_tf, range_learning_tf_enhanced).', default='tf', choices = ['tf', 'tf_enhanced', 'range_learning_tf', 'range_learning_tf_enhanced'])
parser.add_argument('--round-mode', help='Round mode for quantization.', default='nearest')
parser.add_argument('--default-output-bw', help='Default output bitwidth for quantization.', default=8)
Expand All @@ -147,18 +125,17 @@ def arguments():
parser.add_argument('--batch-size', help='Data batch size for a model', default=64)
parser.add_argument('--num-workers', help='Number of workers to run data loader in parallel', default=16)

parser.add_argument('--num-iterations', help='Number of iterations used for adaround optimization', default=10000, type = int)
parser.add_argument('--num-batches', help='Number of batches used for adaround optimization', default=16, type = int)

args = parser.parse_args()
return args


def main():
args = arguments()
seed(args)

model = load_model()
if args.checkpoint:
model = torch.load(args.checkpoint)
else:
model = load_model()
model.eval()

image_size = args.input_shape[-1]
Expand All @@ -183,9 +160,6 @@ def main():
if 'CLE' in args.quant_tricks:
print("CLE")
model = run_pytorch_cross_layer_equalization(args, model)
print(model)
if 'adaround' in args.quant_tricks:
model = run_pytorch_adaround(args, model, val_dataloader)

if hasattr(args, 'quant_scheme'):
if args.quant_scheme == 'range_learning_tf':
Expand Down

0 comments on commit 722a1ba

Please sign in to comment.