# analyse.py
import json
from collections import defaultdict

from tqdm import tqdm

def evaluate(data):
    """Compute micro-averaged (F1, precision, recall) of predicted vs. gold triples."""
    # Small epsilons avoid division by zero when either set is empty.
    official_A, official_B, official_C = 1e-10, 1e-10, 1e-10
    for d in tqdm(data):
        R = set(tuple(i) for i in d['predict'])          # predicted triples
        official_T = set(tuple(i) for i in d['truth'])   # gold triples
        official_A += len(R & official_T)  # correct predictions
        official_B += len(R)               # total predicted
        official_C += len(official_T)      # total gold
    return 2 * official_A / (official_B + official_C), official_A / official_B, official_A / official_C
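
# A minimal sanity check (hypothetical toy data, not from the repo): one
# sentence with two gold triples, of which one is predicted correctly.
# toy = [{'predict': [['A', 'r1', 'B']],
#         'truth':   [['A', 'r1', 'B'], ['A', 'r2', 'C']]}]
# print(evaluate(toy))  # -> (~0.667, ~1.0, ~0.5), i.e. (F1, precision, recall)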

def split_file_by_overlapping_type(data):
    """Split examples into Normal / SEO (single entity overlap) / EPO (entity pair overlap)."""
    normal_results = []
    epo_results = []
    spo_results = []
    for d in tqdm(data):
        official_T = set(tuple(i) for i in d['truth'])
        head_dict = defaultdict(int)
        tail_dict = defaultdict(int)
        head_tail_dict = defaultdict(int)
        for spo in official_T:
            head_dict[spo[0]] += 1
            tail_dict[spo[2]] += 1
            head_tail_dict[(spo[0], spo[2])] += 1
        epo_flag = spo_flag = False
        # EPO: some (head, tail) pair participates in more than one relation.
        for head_tail in head_tail_dict:
            if head_tail_dict[head_tail] > 1:
                epo_flag = True
        # SEO: some head or tail entity is shared by more than one triple.
        for head in head_dict:
            if head_dict[head] > 1:
                spo_flag = True
        for tail in tail_dict:
            if tail_dict[tail] > 1:
                spo_flag = True
        if epo_flag:
            epo_results.append(d)
        elif spo_flag:
            spo_results.append(d)
        else:
            normal_results.append(d)
    return normal_results, spo_results, epo_results
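
# Illustration of the three categories (hypothetical triples):
#   EPO    - {('Paris', 'capital_of', 'France'), ('Paris', 'located_in', 'France')}
#            the same (head, tail) pair carries more than one relation; EPO
#            takes precedence over SEO in the elif chain above.
#   SEO    - {('Paris', 'located_in', 'France'), ('Lyon', 'located_in', 'France')}
#            one entity ('France') is shared across otherwise distinct pairs.
#   Normal - every head and tail entity appears in exactly one triple.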

def split_file_by_triplet_num(data):
    """Bucket examples by the number of gold triples: 1, 2, 3, 4, or more than 4."""
    one_results = []
    two_results = []
    three_results = []
    four_results = []
    gfour_results = []
    for d in tqdm(data):
        official_T = set(tuple(i) for i in d['truth'])
        if len(official_T) == 1:
            one_results.append(d)
        elif len(official_T) == 2:
            two_results.append(d)
        elif len(official_T) == 3:
            three_results.append(d)
        elif len(official_T) == 4:
            four_results.append(d)
        elif len(official_T) > 4:
            gfour_results.append(d)
    return one_results, two_results, three_results, four_results, gfour_results
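
# Example (hypothetical data): bucketing a two-sentence dataset by gold count.
# data = [{'truth': [['A', 'r', 'B']]},
#         {'truth': [['A', 'r', 'B'], ['C', 'r', 'D']]}]
# one, two, three, four, gfour = split_file_by_triplet_num(data)
# assert len(one) == 1 and len(two) == 1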

if __name__ == '__main__':
    # You can obtain this file after running the eval.py script.
    test_file_name = 'saved_models/multi/best_test_results.json'
    with open(test_file_name) as f:
        test_file = json.load(f)
    # Report (F1, precision, recall) on each subset from both splits.
    for res in split_file_by_overlapping_type(test_file):
        print(evaluate(res))
    for res in split_file_by_triplet_num(test_file):
        print(evaluate(res))