filter_data.py
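"""Filter Fakepedia datasets by discarding paragraphs flagged as bad.

The script first scans the unfiltered base Fakepedia file for paragraphs
that contradict or hedge the counterfactual object, then removes every
entry whose paragraph was flagged from each "unfiltered_" JSON file found
under the data directory, saving the result to the same path without the
"unfiltered_" prefix.
"""
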
import argparse
import os
import re

from tqdm import tqdm

from src.fact import Fact, fact_from_dict
from src.utils.io import read_json, save_json
from src.utils.logger import freeze_args, get_logger


def get_args():
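    """Parse command-line arguments for the filtering script."""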
    parser = argparse.ArgumentParser()
    parser.add_argument("--unfiltered_base_fakepedia_path", type=str)
    parser.add_argument("--data_dir", default="./data/", type=str)
    return parser.parse_args()


def split_text(text):
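    """Split text into clause-level chunks on sentence enders and commas.

    Illustrative example (a trailing ender with no following whitespace is
    kept attached to the last chunk):
        >>> split_text("Paris is a city, not a country. It is in France!")
        ['Paris is a city', 'not a country', 'It is in France!']
    """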
    # Split on sentence enders (., ?, !) followed by whitespace, or on commas
    pattern = r"[.?!]\s+|,\s+"
    parts = re.split(pattern, text)
    # Filter out any empty strings in the result
    parts = [part.strip() for part in parts if part.strip()]
    return parts


def is_paragraph_bad(fact: Fact):
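    """Return True if the paragraph should be discarded.

    A paragraph is bad if it never mentions the false object, if a sentence
    mentioning the false object also contains a hedging or negating
    expression, or if a sentence that omits the false object states the true
    object without negation (unless the true object is part of the subject).
    """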
    bad_expressions = [
        "often mistaken",
        "often misunderstood",
        "common misconception",
        "false",
        "is not",
        "was not",
        "does not",
        "did not",
    ]
    if fact.get_intermediate_paragraph() is not None:
        paragraph = fact.get_intermediate_paragraph().lower()
    else:
        paragraph = fact.get_paragraph().lower()
    false_object = fact.get_object().lower()
    true_object = fact.get_parent().get_object().lower()
    subject = fact.get_subject().lower()
    # Split paragraph into clause-level chunks on punctuation
    sentences = split_text(paragraph)
    # Discard the paragraph if the false object is never mentioned
    if false_object not in paragraph:
        return True
    for sentence in sentences:
        if false_object in sentence:
            # Discard if the sentence pairs the false object with a hedging
            # or negating expression
            for expression in bad_expressions:
                if expression in sentence:
                    return True
        # Discard if a sentence states the true object without negation.
        # We do this only for facts where the true object is not part of the subject.
        elif true_object not in subject and true_object in sentence and "not" not in sentence:
            return True
    return False


def generate_dataset(args):
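    """Collect bad paragraphs from the base Fakepedia file, then filter every
    "unfiltered_" dataset found under args.data_dir accordingly."""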
    logger = get_logger()

    # Collect the paragraphs to discard from the unfiltered base dataset
    bad_paragraphs = set()
    unfiltered_dataset = read_json(args.unfiltered_base_fakepedia_path)

    bad_paragraphs_count = 0
    for entry in tqdm(unfiltered_dataset, desc="Finding bad paragraphs"):
        fact = fact_from_dict(entry)
        if is_paragraph_bad(fact):
            bad_paragraphs.add(fact.get_paragraph())
            bad_paragraphs_count += 1

    # Counts are per entry, not per unique paragraph
    logger.info(
        "Found {} bad entries out of {} total entries".format(bad_paragraphs_count, len(unfiltered_dataset))
    )

    # Find the files with "unfiltered_" in their path; each output path is the
    # same path without the "unfiltered_" prefix
    files_to_filter = [
        os.path.join(dirpath, filename)
        for dirpath, _, filenames in os.walk(args.data_dir)
        for filename in filenames
    ]
    files_to_filter = [file for file in files_to_filter if "unfiltered_" in file]

    for file_path in files_to_filter:
        output_path = file_path.replace("unfiltered_", "")

        logger.info("Loading '{}'...".format(file_path))
        unfiltered_dataset = read_json(file_path)

        logger.info("Filtering entries...")
        dataset = []
        for entry in tqdm(unfiltered_dataset, desc="Filtering entries"):
            fact = fact_from_dict(entry["fact"] if "fact" in entry else entry)
            if fact.get_intermediate_paragraph() is not None:
                paragraph = fact.get_intermediate_paragraph()
            else:
                paragraph = fact.get_paragraph()
            if paragraph not in bad_paragraphs:
                dataset.append(entry)

        logger.info("Filtered dataset has {} entries".format(len(dataset)))
        logger.info("Saving filtered dataset to '{}'...".format(output_path))
        save_json(dataset, output_path)


def main():
    args = get_args()
    freeze_args(args)
    generate_dataset(args)


if __name__ == "__main__":
    main()
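
# Example invocation (the paths below are illustrative, not taken from the repo):
#   python filter_data.py \
#       --unfiltered_base_fakepedia_path ./data/unfiltered_base_fakepedia.json \
#       --data_dir ./data/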