This repository has been archived by the owner on Dec 30, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathevaluate.py
72 lines (61 loc) · 2.41 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from rouge import Rouge
from cider import Cider
from meteor import Meteor
# Caption-scoring class for the DeepFashion dataset.
class DeepFashionEvalCap:
    """Score candidate captions against reference captions (METEOR, ROUGE, CIDEr).

    After :meth:`evaluate` has run:

    * ``self.eval`` maps each metric name (``'METEOR'``, ``'ROUGE'``,
      ``'CIDER'``) to the corpus-level score returned by its scorer.
    * ``self.imgToEval`` maps each image id to a dict holding ``'image_id'``
      plus one per-image score per metric.
    * ``self.evalImgs`` is a list of the ``self.imgToEval`` values.
    """

    def __init__(self, df, dfRes):
        """
        :param df: reference sentences; dict mapping image id -> list of sentences.
        :param dfRes: candidate (test) sentences; dict mapping image id -> list
            of sentences. Must contain every image id that appears in ``df``.
        """
        self.evalImgs = []
        self.eval = {}
        self.imgToEval = {}
        self.df = df
        self.dfRes = dfRes
        # Every reference image is evaluated. This is a live view of df's keys,
        # not a snapshot.
        self.params = {'image_id': df.keys()}

    def evaluate(self):
        """Run every scorer over all images in ``self.params['image_id']``.

        Populates ``self.eval``, ``self.imgToEval`` and ``self.evalImgs``, and
        prints each metric's corpus-level score.

        :raises KeyError: if ``self.dfRes`` has no captions for an image id
            that appears in ``self.df``.
        """
        gts = {}
        res = {}
        for image_id in self.params['image_id']:
            gts[image_id] = self.df[image_id]
            try:
                res[image_id] = self.dfRes[image_id]
            except KeyError:
                # Fail with a message that says which side is missing the id.
                raise KeyError(
                    'no candidate captions in dfRes for image id %r' % (image_id,)
                ) from None

        # Only shallow-copies the caption lists; no text normalization happens.
        gts = self.tokenize(gts)
        res = self.tokenize(res)

        print('Setting up scorers...')
        scorers = [
            (Meteor(), 'METEOR'),
            (Rouge(), 'ROUGE'),
            (Cider(), 'CIDER'),
        ]

        for scorer, method in scorers:
            print('Calculating %s score...' % method)
            score, scores = scorer.compute_score(gts, res)
            # A scorer registered under a list of names returns one corpus score
            # and one per-image score sequence per name. Normalize the
            # single-name case to the same (score, scores, name) shape so both
            # are recorded by the same loop.
            if isinstance(method, list):
                results = zip(score, scores, method)
            else:
                results = [(score, scores, method)]
            for metric_score, per_image_scores, metric in results:
                self.set_eval(metric_score, metric)
                # NOTE(review): assumes each scorer returns per-image scores in
                # the iteration order of gts - confirm against the scorers.
                self.set_img_to_eval_imgs(per_image_scores, gts.keys(), metric)
                # %-formatting is kept so numpy scalar / 0-d array scores print.
                print('%s: %0.3f' % (metric, metric_score))
        self.set_eval_imgs()

    def set_eval(self, score, method):
        """Record the corpus-level ``score`` for metric ``method``."""
        self.eval[method] = score

    def set_img_to_eval_imgs(self, scores, image_ids, method):
        """Record per-image ``scores`` for metric ``method``.

        ``image_ids`` and ``scores`` are paired positionally; pairing stops at
        the shorter of the two.
        """
        for image_id, score in zip(image_ids, scores):
            entry = self.imgToEval.setdefault(image_id, {})
            entry['image_id'] = image_id
            entry[method] = score

    def set_eval_imgs(self):
        """Rebuild ``self.evalImgs`` as a list of the per-image result dicts."""
        self.evalImgs = list(self.imgToEval.values())

    @staticmethod
    def tokenize(annotations):
        """Return a copy of ``annotations`` with each caption list shallow-copied.

        Despite the name, no tokenization or text normalization is performed:
        captions reach the scorers exactly as given.
        NOTE(review): confirm the scorers are meant to receive raw captions.

        :param annotations: dict mapping image id -> iterable of captions.
        :return: new dict mapping each image id -> new list of the same captions.
        """
        return {image_id: list(anns) for image_id, anns in annotations.items()}