Skip to content
This repository has been archived by the owner on Jul 25, 2024. It is now read-only.

Commit

Permalink
Refer #25. Evaluate implemented, tests not really completed
Browse files Browse the repository at this point in the history
  • Loading branch information
Alex Hebing committed Jul 1, 2019
1 parent 15f61ea commit 1595d37
Show file tree
Hide file tree
Showing 2 changed files with 269 additions and 0 deletions.
205 changes: 205 additions & 0 deletions evaluate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
import os
import sys
import argparse
from sklearn.metrics import classification_report


# Input validation
def dir(path):
if not os.path.isdir(path):
raise argparse.ArgumentTypeError(
"Path '{}' is not a directory".format(path))
return path


def parseArguments(sysArgs):
parser = argparse.ArgumentParser(
description="""Evaluate BIO formatted entity labels against a Golden Standard (also BIO formatted).
Note that the script expects the Predicted entitites and the Golden Standard to be in
files of the same name, with extension .bio""")

parser.add_argument(
'--gold_dir',
dest='gold_dir', required=True, type=dir,
help="The directory that the Golden Standard is in")

parser.add_argument(
'--pred_dir',
dest='pred_dir', required=True, type=dir,
help="The directory that the predicted entities are in")


parsedArgs = parser.parse_args()

return parsedArgs

def evaluate_file(gold_file, pred_file):
gold_labels = extract_labels(gold_file)
pred_labels = extract_labels(pred_file)
return get_report(gold_labels, pred_labels)


def main(args):
args = parseArguments(args)

reports = []
gold_dir = args.gold_dir
pred_dir = args.pred_dir

for pred_file_name in os.listdir(pred_dir):
if pred_file_name.endswith(".bio"):
pred_file_path = os.path.join(pred_dir, pred_file_name)
gold_file_path = os.path.join(gold_dir, pred_file_name)

if os.path.exists(gold_file_path):
report = evaluate_file(gold_file_path, pred_file_path)
reports.append(report)

process_reports(reports)


def process_reports(reports):
ras = []

for report in reports:
for type, values in report.items():
if type in ['LOC', 'PER', 'ORG', 'OTH', 'O']:
existing_ra = next((ra for ra in ras if ra.type == type), None)

if not existing_ra:
ras.append(ReportAverager(type, values))
else:
existing_ra.add(values)

sorted_ras = sorted(ras)
pretty_print(sorted_ras)


def pretty_print(report_averagers):
name_width = max(len(cn.type) for cn in report_averagers)
width = max(name_width, 3, 3)
head_fmt = '{:>{width}s} ' + ' {:>9}' * 3
report = '\n\n'
report += head_fmt.format('', 'precision', 'recall',
'f1-score', width=width)
report += '\n\n'
row_fmt = '{:>{width}s} ' + ' {:>9.{digits}f}' * 3 + ' \n'

o_row = None
for ra in report_averagers:
if ra.type == "O":
o_row = row_fmt.format(ra.type, ra.get_precision(
), ra.get_recall(), ra.get_f1score(), width=width, digits=3)
else:
report += row_fmt.format(ra.type, ra.get_precision(),
ra.get_recall(), ra.get_f1score(), width=width, digits=3)
report += o_row
report += '\n'
print(report)


class ReportAverager:
def __init__(self, type, values):
self.type = type
self.reports = [values]

def __eq__(self, other):
return self.type == other.type

def __lt__(self, other):
return self.type<other.type

def add(self, report):
self.reports.append(report)

def get_precision(self):
total = self.get_total('precision')
if total == 0.00:
return total
return total / len(self.reports)

def get_recall(self):
total = self.get_total('recall')
if total == 0.00:
return total
return total / len(self.reports)

def get_f1score(self):
total = self.get_total('f1-score')
if total == 0.00:
return total
return total / len(self.reports)

def get_total(self, property_name):
total = 0.00
for report in self.reports:
total = total + report[property_name]
return total


def get_report(gold_labels, pred_labels):
'''
Get a report much like this:
{
'LOC': {
'precision': 1.0,
'recall': 0.5,
'f1-score': 0.6666666666666666,
'support': 2
},
'PER': {
'precision': 1.0,
'recall': 1.0,
'f1-score': 1.0,
'support': 1
},
'O': {
'precision': 0.6666666666666666,
'recall': 1.0,
'f1-score': 0.8,
'support': 2
},
'accuracy': 0.8333333333333334,
'macro avg': {
'precision': 0.9166666666666666,
'recall': 0.875,
'f1-score': 0.8666666666666667,
'support': 6
},
'weighted avg': {
'precision': 0.8888888888888888,
'recall': 0.8333333333333334,
'f1-score': 0.8222222222222223,
'support': 6
}
}
'''
return classification_report(gold_labels, pred_labels, output_dict=True)


def extract_labels(file):
labels = []

with open(file, 'r') as fh:
for line in fh.readlines():
if 'LOC' in line:
labels.append('LOC')
continue
if 'PER' in line:
labels.append('PER')
continue
if 'ORG' in line:
labels.append('ORG')
continue
if 'OTHER' in line or 'OTH' in line:
labels.append('OTH')
continue

labels.append('O')

return labels


if __name__ == '__main__':
sys.exit(main(sys.argv))
64 changes: 64 additions & 0 deletions test_evaluate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@

from sklearn.metrics import f1_score
from sklearn.metrics import classification_report

def test_f1_score_base():
gold = ['DE O', 'EERSTE B-LOC', 'DE O', 'BESTE B-PER', 'DUS B-LOC']
pred = ['DE O', 'EERSTE B-LOC', 'DE O', 'BESTE B-PER', 'DUS B-LOC']
macro = f1_score(gold, pred, average='macro')
assert macro == 1.0

micro = f1_score(gold, pred, average='micro')
assert micro == 1.0

weighted = f1_score(gold, pred, average='weighted')
assert weighted == 1.0

none = f1_score(gold, pred, average=None)
print(none)
assert 0 #none == [1. 1. 1. 1.]

# print('weighted: ', f1_score(gold, pred, average='weighted'))

def test_f1_score():
gold = ['O', 'LOC', 'O', 'PER', 'LOC', 'LOC', 'O']
pred = ['O', 'O', 'O', 'PER', 'LOC', 'LOC', 'LOC']

gold = ['LOC', 'LOC', 'O']
pred = ['O', 'LOC', 'LOC']

# macro = f1_score(gold, pred, average='macro')
# assert macro == 0.6

# micro = f1_score(gold, pred, average='micro')
# assert micro == 0.8000000000000002

# micro = f1_score(gold, pred, average='micro', labels=['B-LOC'])
# assert micro == 0.8000000000000002

# print('macro: ', f1_score(gold, pred, average='macro'))
# print('micro: ', f1_score(gold, pred, average='micro'))
# print('weighted: ', f1_score(gold, pred, average='weighted'))

print(classification_report(gold, pred))
assert 0

def test_f1_real_world_data():
gold = ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'PER', 'PER', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'PER', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'PER', 'PER', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'PER', 'PER', 'O', 'O', 'O', 'O', 'O', 'ORG', 'ORG', 'ORG', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
pred = ['O', 'ORG', 'ORG', 'ORG', 'O', 'LOC', 'O', 'O', 'ORG', 'O', 'O', 'O', 'O', 'PER', 'PER', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'LOC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'LOC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'LOC', 'LOC', 'LOC', 'O', 'O', 'LOC', 'LOC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'PER', 'PER', 'O', 'PER', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'LOC', 'LOC', 'LOC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'LOC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'PER', 'PER', 'O', 'O', 'O', 'O', 'O', 'ORG', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']

# macro = f1_score(gold, pred, average='macro')
# assert macro == 0.6

# micro = f1_score(gold, pred, average='micro')
# assert micro == 0.8000000000000002

# micro = f1_score(gold, pred, average='micro', labels=['B-LOC'])
# assert micro == 0.8000000000000002

# print('macro: ', f1_score(gold, pred, average='macro'))
# print('micro: ', f1_score(gold, pred, average='micro'))
# print('weighted: ', f1_score(gold, pred, average='weighted'))

print(classification_report(gold, pred))
assert 0

0 comments on commit 1595d37

Please sign in to comment.