evaluation.py
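"""Evaluate predicted strings against reference labels using character error rate (CER) and exact-match accuracy."""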
from dataclasses import dataclass
from typing import Sequence

import pandas as pd
from tabulate import tabulate
from torchmetrics.text import CharErrorRate


@dataclass
class EvaluationReport:
    """Store the evaluation report, CER, and accuracy."""

    report: pd.DataFrame
    cer: float
    accuracy: float

    def __repr__(self) -> str:
        """Return a string representation of the evaluation report for printing."""
        return (
            f"CER: {self.cer:.2f}, Accuracy: {self.accuracy:.2f}\n"
            f"{tabulate(self.report, headers='keys', tablefmt='simple_outline')}"
        )


def evaluate(predictions: Sequence[str], labels: Sequence[str]) -> EvaluationReport:
    """Compute the character error rate (CER) and exact-match accuracy of predictions against labels."""
    ####################
    # Input validation #
    ####################
    # If both predictions and labels are single strings, wrap them in lists
    if isinstance(predictions, str) and isinstance(labels, str):
        predictions = [predictions]
        labels = [labels]

    # Check that both predictions and labels are sequences
    if not isinstance(predictions, Sequence) or not isinstance(labels, Sequence):
        raise ValueError("Both predictions and labels must be sequences.")

    # Check that predictions and labels have the same length
    if len(predictions) != len(labels):
        raise ValueError("The length of predictions and labels must be the same.")

    ##############
    # Evaluation #
    ##############
    total_correct_prediction = 0
    cer = CharErrorRate()
    cer_result = []

    # Evaluate each prediction
    for prediction, label in zip(predictions, labels):
        # Handle empty strings: count two empty strings as a perfect match
        if prediction == "" and label == "":
            cer_result.append(0.0)
            total_correct_prediction += 1
            continue

        # Per-sample CER; calling the metric also accumulates state for the corpus-level CER
        value = cer(prediction, label).item()  # Convert tensor to float
        cer_result.append(value)

        # Accuracy
        if prediction == label:
            total_correct_prediction += 1

    # Create CER report
    cer_report = pd.DataFrame(
        {"label": labels, "prediction": predictions, "cer": cer_result}
    )

    # Compute total CER over all accumulated samples
    total_cer = cer.compute().item()  # Convert tensor to float

    # Compute accuracy
    total_accuracy = total_correct_prediction / len(predictions)

    return EvaluationReport(report=cer_report, cer=total_cer, accuracy=total_accuracy)


if __name__ == "__main__":
    # Run this file to test the evaluate function
    predictions = [
        "",
        "hallucination",
        "",
        "substltution",
        "deleion",
        "insertionn",
        "correct",
        "asdfasdfasdfasdfasdfasdfasdf",
    ]
    labels = ["", "", "text", "substitution", "deletion", "insertion", "correct", "12"]
    report = evaluate(predictions, labels)
    print(report)