From 1595d37a0ddd0e03b9cee10b7abfbd0564795d6e Mon Sep 17 00:00:00 2001 From: Alex Hebing Date: Mon, 1 Jul 2019 15:38:37 +0200 Subject: [PATCH] Refer #25. Evaluate implemented, tests not really completed --- evaluate.py | 205 +++++++++++++++++++++++++++++++++++++++++++++++ test_evaluate.py | 64 +++++++++++++++ 2 files changed, 269 insertions(+) create mode 100644 evaluate.py create mode 100644 test_evaluate.py diff --git a/evaluate.py b/evaluate.py new file mode 100644 index 0000000..a63919d --- /dev/null +++ b/evaluate.py @@ -0,0 +1,205 @@ +import os +import sys +import argparse +from sklearn.metrics import classification_report + + +# Input validation +def dir(path): + if not os.path.isdir(path): + raise argparse.ArgumentTypeError( + "Path '{}' is not a directory".format(path)) + return path + + +def parseArguments(sysArgs): + parser = argparse.ArgumentParser( + description="""Evaluate BIO formatted entity labels against a Golden Standard (also BIO formatted). + Note that the script expects the Predicted entitites and the Golden Standard to be in + files of the same name, with extension .bio""") + + parser.add_argument( + '--gold_dir', + dest='gold_dir', required=True, type=dir, + help="The directory that the Golden Standard is in") + + parser.add_argument( + '--pred_dir', + dest='pred_dir', required=True, type=dir, + help="The directory that the predicted entities are in") + + + parsedArgs = parser.parse_args() + + return parsedArgs + +def evaluate_file(gold_file, pred_file): + gold_labels = extract_labels(gold_file) + pred_labels = extract_labels(pred_file) + return get_report(gold_labels, pred_labels) + + +def main(args): + args = parseArguments(args) + + reports = [] + gold_dir = args.gold_dir + pred_dir = args.pred_dir + + for pred_file_name in os.listdir(pred_dir): + if pred_file_name.endswith(".bio"): + pred_file_path = os.path.join(pred_dir, pred_file_name) + gold_file_path = os.path.join(gold_dir, pred_file_name) + + if os.path.exists(gold_file_path): + report = evaluate_file(gold_file_path, pred_file_path) + reports.append(report) + + process_reports(reports) + + +def process_reports(reports): + ras = [] + + for report in reports: + for type, values in report.items(): + if type in ['LOC', 'PER', 'ORG', 'OTH', 'O']: + existing_ra = next((ra for ra in ras if ra.type == type), None) + + if not existing_ra: + ras.append(ReportAverager(type, values)) + else: + existing_ra.add(values) + + sorted_ras = sorted(ras) + pretty_print(sorted_ras) + + +def pretty_print(report_averagers): + name_width = max(len(cn.type) for cn in report_averagers) + width = max(name_width, 3, 3) + head_fmt = '{:>{width}s} ' + ' {:>9}' * 3 + report = '\n\n' + report += head_fmt.format('', 'precision', 'recall', + 'f1-score', width=width) + report += '\n\n' + row_fmt = '{:>{width}s} ' + ' {:>9.{digits}f}' * 3 + ' \n' + + o_row = None + for ra in report_averagers: + if ra.type == "O": + o_row = row_fmt.format(ra.type, ra.get_precision( + ), ra.get_recall(), ra.get_f1score(), width=width, digits=3) + else: + report += row_fmt.format(ra.type, ra.get_precision(), + ra.get_recall(), ra.get_f1score(), width=width, digits=3) + report += o_row + report += '\n' + print(report) + + +class ReportAverager: + def __init__(self, type, values): + self.type = type + self.reports = [values] + + def __eq__(self, other): + return self.type == other.type + + def __lt__(self, other): + return self.type