"""frechet_distance.py

Compute Frechet distances between real and fake feature statistics and write
the results to a CSV file (optionally also to a JSON file).
"""
import os
import argparse
from typing import Dict, List
import csv
import json
from tqdm import tqdm
import numpy as np
from pytorch_fid.fid_score import calculate_frechet_distance
from features import Statistics
from utilities import extract_from_filename, glob_filepaths
def parse_args(argv=None):
    """Parse the command-line options of the Frechet distance script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to the process arguments, so
            existing ``parse_args()`` callers behave exactly as before.

    Returns:
        argparse.Namespace with the attributes ``real_dir``, ``fake_dir``,
        ``out_path`` (positional strings), ``recursive`` (bool) and ``json``
        (``None`` when the flag is absent, ``''`` when the flag is given
        without a value, otherwise the supplied string).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('real_dir', type=str, help='Real features directory (load all npz files inside)')
    parser.add_argument('fake_dir', type=str, help='Fake features directory (load all npz files inside)')
    parser.add_argument('out_path', type=str, help='Path to output csv file')
    parser.add_argument('--recursive', '-r', action='store_true', help='Load features recursively from given directory')
    # nargs='?' with const='' lets a bare "--json" be told apart from "--json PATH".
    parser.add_argument('--json', '-j', type=str, nargs='?', default=None, const='', help='Output json file')
    return parser.parse_args(argv)
def _load_features(path):
    """Return the 'features' array stored in the file at ``path``."""
    # For .npz archives np.load returns an NpzFile that keeps the file open;
    # the context manager closes it as soon as the array has been read.
    with np.load(path) as archive:
        return archive['features']


def main(opts):
    """Compute Frechet distances between real and fake features and save them.

    For every classifier that has real features, the distance to each fake
    feature set of the same classifier is computed with
    ``calculate_frechet_distance``. Results are written as a CSV table
    (one row per classifier, one column per fake name) and, if requested,
    as a JSON list of ``{'classifier', 'name', 'value'}`` records.

    Args:
        opts: Namespace providing ``real_dir``, ``fake_dir``, ``out_path``,
            ``recursive`` and ``json`` (``None`` disables the JSON output,
            an empty string places it next to the CSV file with a ``.json``
            extension, any other string is used as the JSON path).
    """
    # NOTE(review): dictionary keys are whatever extract_from_filename returns
    # as the classifier; annotated as str on that assumption -- confirm.
    real_dir = os.path.abspath(opts.real_dir)
    real_files = glob_filepaths(real_dir, recursive=opts.recursive)
    real_stats_dict: Dict[str, Statistics] = {}
    print('Loading real features...')
    for f in tqdm(real_files):
        name, classifier = extract_from_filename(f)
        # A later file for the same classifier replaces the earlier one.
        real_stats_dict[classifier] = Statistics(name, classifier, _load_features(f))

    fake_dir = os.path.abspath(opts.fake_dir)
    fake_files = glob_filepaths(fake_dir, recursive=opts.recursive)
    fake_stats_list_dict: Dict[str, List[Statistics]] = {}
    fake_name_set = set()
    print('Loading fake features...')
    for f in tqdm(fake_files):
        name, classifier = extract_from_filename(f)
        stats = Statistics(name, classifier, _load_features(f))
        fake_stats_list_dict.setdefault(classifier, []).append(stats)
        fake_name_set.add(name)

    print('Evaluating Frechet Distances...')
    results = {}
    classifiers = tqdm(real_stats_dict.items())
    classifiers.set_description('Total progress')
    for classifier, real_stats in classifiers:
        distances = results[classifier] = {}
        # Classifiers without any fake features simply get an empty mapping.
        fake_stats_list = tqdm(fake_stats_list_dict.get(classifier, []))
        for fake_stats in fake_stats_list:
            fake_stats_list.set_description(f'Classifier={classifier}, Name={fake_stats.name}')
            distances[fake_stats.name] = calculate_frechet_distance(
                real_stats.mu, real_stats.sigma, fake_stats.mu, fake_stats.sigma)

    save_path = os.path.abspath(opts.out_path)
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    print(f'Writing results to {save_path}')
    fake_names = sorted(fake_name_set)
    sorted_classifiers = sorted(real_stats_dict)
    with open(save_path, mode='w', encoding='utf8', newline='') as f:
        writer = csv.writer(f, delimiter=',')
        writer.writerow([''] + fake_names)
        for classifier in sorted_classifiers:
            row = [classifier]
            res_cls = results.get(classifier)
            if res_cls:
                # Missing (classifier, name) pairs yield None, written as an empty cell.
                row.extend(res_cls.get(name) for name in fake_names)
            writer.writerow(row)

    if opts.json is not None:
        # A bare "--json" gives an empty string: put the file next to the CSV.
        json_path = opts.json if opts.json else os.path.splitext(save_path)[0] + '.json'
        json_dir = os.path.dirname(json_path)
        if json_dir:
            # Mirror the CSV handling so a not-yet-existing target directory works.
            os.makedirs(json_dir, exist_ok=True)
        print(f'Writing results to {json_path}')
        json_data = []
        for classifier in sorted_classifiers:
            res_cls = results.get(classifier)
            if not res_cls:
                continue
            for name in fake_names:
                value = res_cls.get(name)
                if value is None:
                    continue
                # float() guards against numpy scalar types the json module
                # cannot serialize (e.g. float32); float64 output is unchanged.
                json_data.append({'classifier': classifier, 'name': name, 'value': float(value)})
        with open(json_path, mode='w', encoding='utf8', newline='') as f:
            json.dump(json_data, f, indent=4)
if __name__ == '__main__':
    # Script entry point: read the command-line options, then run the evaluation.
    cli_options = parse_args()
    main(cli_options)