eval_utils.py
# Evaluation utilities: loss and accuracy computation.
import torch
def compute_accuracy(model, data_iterator, loss_fn, no_print_idx, pad_value=-1,
                     show_example=False, only_nbatch=-1):
    """Compute loss and accuracies.

    :param model: model to evaluate; called as model(src) and expected to
        return logits of shape (B, len, vocab).
    :param data_iterator: iterator yielding (src, tgt) batches.
    :param loss_fn: loss function applied to the flattened logits and targets.
    :param int, no_print_idx: target index of the no-op (non-print) token.
    :param int, pad_value: padding value in the targets; padded positions are
        excluded from the character-level accuracies.
    :param bool, show_example: if True, print some decoding output examples.
    :param int, only_nbatch: Only use the given number of batches. If -1, use
        all data (default).
    Returns loss, sequence-level accuracy, no-op (char-level) accuracy, and
    print accuracy.
    """
    model.eval()
    total_loss = 0.0
    corr = 0
    corr_char = 0
    corr_print = 0
    step = 0
    total_num_seqs = 0
    total_char = 0
    total_print = 0
    for idx, batch in enumerate(data_iterator):
        step += 1
        src, tgt = batch
        logits = model(src)
        target = tgt  # (B, len)
        # to compute accuracy
        output = torch.argmax(logits, dim=-1).squeeze()
        # compute loss on the flattened (B * len, vocab) logits
        logits = logits.contiguous().view(-1, logits.shape[-1])
        labels = tgt.view(-1)
        loss = loss_fn(logits, labels)
        total_loss += loss
        # sequence-level accuracy: a sequence is correct if every
        # non-padded position matches.
        seq_match = (torch.eq(target, output) | (target == pad_value)
                     ).all(1).sum().item()
        corr += seq_match
        total_num_seqs += src.size()[0]
        # no-op (char-level) accuracy; the padded part should not be
        # counted as correct.
        char_match = torch.logical_and(
            torch.logical_and(torch.eq(target, output), target != pad_value),
            target == no_print_idx).sum().item()
        corr_char += char_match
        total_char += torch.logical_and(
            target != pad_value, target == no_print_idx).sum().item()
        # print accuracy: ignore non-print (no-op) outputs.
        print_match = torch.logical_and(
            torch.logical_and(torch.eq(target, output), target != pad_value),
            target != no_print_idx).sum().item()
        corr_print += print_match
        total_print += torch.logical_and(
            target != pad_value, target != no_print_idx).sum().item()
        if only_nbatch > 0:
            if idx > only_nbatch:
                break
        # if show_example:
        #     out_string = []
        #     for a in src[0]:
        #         out_string.append(module.source.vocab.itos[a.item()])
        #     print(f"example: {''.join(out_string)}")
        #     out_string = []
        #     for a in target[0]:
        #         out_string.append(module.source.vocab.itos[a.item()])
        #     print(f"correct answer: {''.join(out_string)}")
        #     out_string = []
        #     for a in output[0]:
        #         out_string.append(module.source.vocab.itos[a.item()])
        #     print(f"model prediction: {''.join(out_string)}")
    res_loss = total_loss.item() / float(step)
    acc = corr / float(total_num_seqs) * 100
    if total_char > 0:
        no_op_acc = corr_char / float(total_char) * 100
    else:
        no_op_acc = 0.0  # no no-op targets in this data
    print_acc = corr_print / float(total_print) * 100
    return res_loss, acc, no_op_acc, print_acc
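

# Example usage of compute_accuracy (a minimal sketch; `model`, `val_iter`,
# and `NO_OP_IDX` are placeholder names that are not defined in this file):
#
#     criterion = torch.nn.CrossEntropyLoss(ignore_index=-1)
#     val_loss, seq_acc, no_op_acc, print_acc = compute_accuracy(
#         model, val_iter, criterion, no_print_idx=NO_OP_IDX, pad_value=-1)
#     print(f"loss {val_loss:.4f} | seq acc {seq_acc:.2f}% | "
#           f"no-op acc {no_op_acc:.2f}% | print acc {print_acc:.2f}%")
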
def listops_compute_accuracy(model, data_iterator, loss_fn,
                             show_example=False, only_nbatch=-1):
    """Compute loss and accuracy for ListOps.

    :param model: model to evaluate; called as model(inputs) and expected to
        return logits of shape (B, len, num_classes).
    :param data_iterator: iterator yielding (seq_len, labels, inputs) batches.
    :param loss_fn: loss function applied to the per-sequence logits.
    :param bool, show_example: if True, print some decoding output examples.
    :param int, only_nbatch: Only use the given number of batches. If -1, use
        all data (default).
    Returns loss and sequence-level accuracy.
    """
    model.eval()
    total_loss = 0.0
    corr = 0
    step = 0
    total_num_seqs = 0
    for idx, batch in enumerate(data_iterator):
        step += 1
        seq_len, labels, inputs = batch
        logits = model(inputs)
        target = labels  # (B,)
        # gather the logits at the position given by seq_len for each sequence:
        # (B,) -> (B, 1, num_classes) index, then (B, num_classes) after gather.
        seq_len = seq_len.unsqueeze(1).unsqueeze(2).expand(
            -1, -1, logits.size(2))
        logits = logits.gather(1, seq_len).squeeze(1)
        # to compute accuracy
        output = torch.argmax(logits, dim=-1)
        # compute loss
        loss = loss_fn(logits, labels)
        total_loss += loss
        # accuracy
        seq_match = torch.eq(target, output).sum().item()
        corr += seq_match
        total_num_seqs += inputs.size()[0]
        if only_nbatch > 0:
            if idx > only_nbatch:
                break
    res_loss = total_loss.item() / float(step)
    acc = corr / float(total_num_seqs) * 100
    return res_loss, acc
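

# A minimal smoke test for listops_compute_accuracy (a sketch with random toy
# data; the dummy model and all sizes below are assumptions, not part of the
# original training setup).
if __name__ == "__main__":
    torch.manual_seed(0)
    vocab_size, num_classes, batch_size, max_len = 16, 10, 4, 12

    class DummyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.embed = torch.nn.Embedding(vocab_size, 32)
            self.proj = torch.nn.Linear(32, num_classes)

        def forward(self, x):  # (B, len) -> (B, len, num_classes)
            return self.proj(self.embed(x))

    model = DummyModel()
    # Each batch is (seq_len, labels, inputs), matching what
    # listops_compute_accuracy expects; seq_len picks the position whose
    # logits are used for classification.
    batches = [
        (torch.full((batch_size,), max_len - 1, dtype=torch.long),
         torch.randint(0, num_classes, (batch_size,)),
         torch.randint(0, vocab_size, (batch_size, max_len)))
        for _ in range(3)
    ]
    loss, acc = listops_compute_accuracy(
        model, iter(batches), torch.nn.CrossEntropyLoss())
    print(f"loss {loss:.4f} | accuracy {acc:.2f}%")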