-
Notifications
You must be signed in to change notification settings - Fork 3
/
inference.py
45 lines (36 loc) · 1.88 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import argparse
import os
import sys
from glob import glob

import yaml
from tqdm import tqdm

from utils.inference_helper import generate_inference
from utils.model_loader import load_custom_model
def _exit_with_error(*messages):
    """Print each of *messages* to stderr and terminate with exit status 1."""
    for message in messages:
        print(message, file=sys.stderr)
    sys.exit(1)


def _load_config(yaml_path):
    """Parse the YAML config file at *yaml_path*.

    Returns whatever ``yaml.safe_load`` produces (``None`` for an empty file).
    Exits with status 1 if the file cannot be opened/read or is not valid YAML.
    """
    try:
        with open(yaml_path, 'r', encoding='utf-8') as fh:
            return yaml.safe_load(fh)
    except OSError as err:
        # Missing/unreadable file. Exit non-zero: the previous bare ``exit()``
        # reported success to the calling shell.
        _exit_with_error('Error reading the config file {}'.format(yaml_path), err)
    except yaml.YAMLError as err:
        _exit_with_error('Error reading the config file {}'.format(yaml_path), err)


def _config_value(config, key):
    """Return ``config['pipeline_config'][key]``, exiting cleanly if it is absent."""
    try:
        return config['pipeline_config'][key]
    except (KeyError, TypeError):
        # TypeError covers an empty YAML file (config is None) or a
        # ``pipeline_config`` section that is not a mapping.
        _exit_with_error(
            f"Config is missing 'pipeline_config.{key}' and no "
            "command-line override was given")


def execute_inference(argv=None):
    """Run the exported detection model over every matching image in a folder.

    Settings come from the command line. Any option left unset falls back to
    the matching ``pipeline_config`` entry of the YAML config
    (``--yaml``, default ``config/parameters.yaml``). The model is loaded
    with ``load_custom_model``. Each image matching ``*.<extension>`` in the
    input folder is passed to ``generate_inference`` together with the model,
    the label-map path and the output folder, which is created if needed.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. ``None``
            (the default) keeps the normal command-line behaviour.

    Exits with status 1 if the config file cannot be read or parsed, or if
    a needed setting is neither given on the command line nor in the config.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-y', '--yaml', default='config/parameters.yaml', help='Config file YAML format')
    parser.add_argument('-c', '--checkpoint_dir', help='Folder to load exported checkpoint.')
    parser.add_argument('-i', '--image_path', help='Input jpg images folder')
    parser.add_argument('-e', '--image_extension', help='Image file extension: jpg or jpeg')
    parser.add_argument('-l', '--label_map', help='Path to pbtxt file')
    parser.add_argument('-o', '--output_dir', default='./output_inference', help='Output folder')
    args = parser.parse_args(argv)

    config = _load_config(args.yaml)

    # Command-line values win; the config is consulted only for unset options.
    model_path = args.checkpoint_dir or os.path.join(
        _config_value(config, 'checkpoint_save_path'), 'exported')
    label_map = args.label_map or _config_value(config, 'labelmap_path')
    image_dir = args.image_path or _config_value(config, 'input_test_img_folder')
    img_extension = args.image_extension or _config_value(config, 'image_extension')

    print("Loading model...")
    detection_model = load_custom_model(model_path)
    image_pattern = os.path.join(image_dir, f'*.{img_extension}')
    os.makedirs(args.output_dir, exist_ok=True)

    # glob order depends on the filesystem; sort for a reproducible run order.
    image_files = sorted(glob(image_pattern))
    if not image_files:
        print(f'No images found matching {image_pattern}', file=sys.stderr)

    print("Executing inference...")
    for image_file in tqdm(image_files):
        generate_inference(detection_model, label_map, image_file, args.output_dir)
if __name__ == '__main__':
    # Only run inference when executed directly as a script, not on import.
    execute_inference()