From b9524217db4b4d5be4ff2f08b6a5ca0be7f7c263 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 12 Oct 2023 11:29:28 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../applications/dagw_reddit/apply_filter.py | 7 +- .../applications/dagw_reddit/preprocess.py | 10 +- .../src/applications/danews/add_metadata.py | 4 +- archive_v1/src/applications/danews/dedupe.py | 8 +- .../src/applications/danews/quality_filter.py | 3 +- .../intercoder_reliability.py | 15 +- .../applications/dataset_validation/main.py | 7 +- .../applications/dataset_validation/mc4.py | 1 - .../applications/hopetwitter/add_metadata.py | 8 +- .../applications/hopetwitter/apply_filter.py | 7 +- .../src/applications/hopetwitter/dedupe.py | 6 +- .../hopetwitter/flatten_ndjson.py | 13 +- .../hopetwitter/quality_filter.py | 6 +- .../cleaning-v1/add_is_duplicate_column.py | 2 +- .../netarkivet/cleaning-v1/apply_filters.py | 12 +- .../content_filtering/DNS_filter.py | 29 +- .../content_filtering/count_domains.py | 12 +- .../content_filtering/filter_domains.py | 13 +- .../cleaning-v1/create_metadata_csv.py | 7 +- .../netarkivet/cleaning-v1/dedupe.py | 9 +- .../netarkivet/cleaning-v1/desc_stats.py | 7 +- .../cleaning-v1/extract_summary_statistics.py | 2 +- .../netarkivet/cleaning-v1/quality_filter.py | 5 +- .../train/run_mlm_pytorch_stream.py | 147 +++-- archive_v1/src/dfm/cleaning/clean_cli.py | 33 +- archive_v1/src/dfm/cleaning/dedupe_cli.py | 42 +- archive_v1/src/dfm/cleaning/deduper.py | 93 +-- archive_v1/src/dfm/cleaning/deduper_utils.py | 4 +- archive_v1/src/dfm/cleaning/quality_filter.py | 111 ++-- .../src/dfm/cleaning/sentence_filter.py | 48 +- archive_v1/src/dfm/data/load_datasets.py | 56 +- .../dataset_validation/rating_interface.py | 12 +- .../dfm/description/description_patterns.py | 619 +++++++++--------- .../dfm/description/generate_description.py | 12 +- .../src/dfm/description/match_counter.py | 36 +- .../src/dfm/dfm_tokenizers/train_tokenizer.py | 27 +- archive_v1/src/dfm/modelling/preprocess.py | 14 +- archive_v1/tests/cleaning/deduper_test.py | 53 +- archive_v1/tests/cleaning/quality_test.py | 61 +- .../tests/cleaning/sentence_filter_test.py | 60 +- .../tests/description/match_counter_test.py | 9 +- .../dfm_tokenizers/tokenizer_config_test.py | 38 +- .../dfm_tokenizers/train_tokenizer_test.py | 82 ++- 43 files changed, 934 insertions(+), 816 deletions(-) diff --git a/archive_v1/src/applications/dagw_reddit/apply_filter.py b/archive_v1/src/applications/dagw_reddit/apply_filter.py index ce9cffae..531fafc7 100644 --- a/archive_v1/src/applications/dagw_reddit/apply_filter.py +++ b/archive_v1/src/applications/dagw_reddit/apply_filter.py @@ -14,13 +14,16 @@ ds = load_from_disk(path) ds_filtered = ds.filter( - lambda example: example["is_13_gram_duplicate"] is False, num_proc=16 + lambda example: example["is_13_gram_duplicate"] is False, + num_proc=16, ) assert len(set(ds_filtered["is_13_gram_duplicate"])) == 1 # write dataset with added metadata save_path = os.path.join( - "/work", "dagw-clean", f"dagw_reddit_filtered_v{ds.version}.arrow" + "/work", + "dagw-clean", + f"dagw_reddit_filtered_v{ds.version}.arrow", ) msg.info(f"Saving to disk: {save_path}") ds_filtered.save_to_disk(save_path) diff --git a/archive_v1/src/applications/dagw_reddit/preprocess.py b/archive_v1/src/applications/dagw_reddit/preprocess.py index b72e882a..681a7270 100644 --- a/archive_v1/src/applications/dagw_reddit/preprocess.py +++ b/archive_v1/src/applications/dagw_reddit/preprocess.py @@ -7,9 +7,8 @@ from pathlib import Path from datasets import concatenate_datasets, load_dataset -from wasabi import msg - from dfm.cleaning import Deduper, QualityFilter +from wasabi import msg def q_filter(batch): @@ -75,7 +74,9 @@ def dedupe_batch(batch, deduper: Deduper): return batch -def filter_categories(examples, remove_cat={"danavis"}): +def filter_categories(examples, remove_cat=None): + if remove_cat is None: + remove_cat = {"danavis"} i = 0 while i < len(examples["source"]): s = examples["source"][i] @@ -99,7 +100,8 @@ def main( reddit_da = reddit_da.rename_columns({"id": "doc_id"}) reddit_da = reddit_da.add_column("LICENSE", ["MIT"] * len(reddit_da)) reddit_da = reddit_da.add_column( - "date_built", ["Wed Dec 15 00:00:00 2021 CEST +0200"] * len(reddit_da) + "date_built", + ["Wed Dec 15 00:00:00 2021 CEST +0200"] * len(reddit_da), ) reddit_da = reddit_da.add_column("source", ["reddit-da"] * len(reddit_da)) reddit_da = reddit_da.add_column("uri", ["NA"] * len(reddit_da)) diff --git a/archive_v1/src/applications/danews/add_metadata.py b/archive_v1/src/applications/danews/add_metadata.py index 371211a8..c2c213de 100644 --- a/archive_v1/src/applications/danews/add_metadata.py +++ b/archive_v1/src/applications/danews/add_metadata.py @@ -272,8 +272,8 @@ def word_count(batch): news_sub = news.remove_columns( [ c - for c in news.features.keys() + for c in news.features if c not in {"n_tokens", "is_duplicate", "passed_quality_filter", "Source"} - ] + ], ) news_sub.to_csv("news_meta.csv") diff --git a/archive_v1/src/applications/danews/dedupe.py b/archive_v1/src/applications/danews/dedupe.py index d07d5a89..bb2ffe4c 100644 --- a/archive_v1/src/applications/danews/dedupe.py +++ b/archive_v1/src/applications/danews/dedupe.py @@ -12,9 +12,8 @@ from functools import partial from datasets import load_from_disk -from wasabi import msg - from dfm.cleaning import Deduper +from wasabi import msg def filter_batch(batch, i): @@ -54,7 +53,6 @@ def __extract_is_duplicate(mask): def main( path, ) -> None: - deduper = Deduper() msg.info("Loading Dataset") @@ -63,7 +61,9 @@ def main( msg.info("Starting deduping") deduper = partial( - dedupe, deduper=deduper, dedupe_path=os.path.join(path, "deduplicated") + dedupe, + deduper=deduper, + dedupe_path=os.path.join(path, "deduplicated"), ) # dedupe dataset ds = ds.map( diff --git a/archive_v1/src/applications/danews/quality_filter.py b/archive_v1/src/applications/danews/quality_filter.py index 650db411..7c951bc2 100644 --- a/archive_v1/src/applications/danews/quality_filter.py +++ b/archive_v1/src/applications/danews/quality_filter.py @@ -14,9 +14,8 @@ from pathlib import Path from datasets import load_dataset -from wasabi import msg - from dfm.cleaning import QualityFilter +from wasabi import msg def filter_batch(batch, i): diff --git a/archive_v1/src/applications/dataset_validation/intercoder_reliability.py b/archive_v1/src/applications/dataset_validation/intercoder_reliability.py index dab8a124..c2acebd1 100644 --- a/archive_v1/src/applications/dataset_validation/intercoder_reliability.py +++ b/archive_v1/src/applications/dataset_validation/intercoder_reliability.py @@ -47,7 +47,7 @@ ## Text proportions ---- -""" +""", ) @@ -61,7 +61,7 @@ def get_proportions(taggers, md): md.add( f"- *Date: {date}*" + f"\n- *Sentences tagged: {len(df)}*" - + f"\n- *Documents tagged: {n_docs}*" + + f"\n- *Documents tagged: {n_docs}*", ) n_char = sum(len(t) for t in df["text"].values) @@ -100,20 +100,21 @@ def get_proportions(taggers, md): tagger1_name, _, session_n1, __, n_docs1, date1 = tagger1.split("_") tagger2_name, _, session_n2, __, n_docs2, date2 = tagger2.split("_") md.add( - f"**{tagger1_name.capitalize()}** (Session: {session_n1}) vs **{tagger2_name.capitalize()}** - (Session: {session_n2})" + f"**{tagger1_name.capitalize()}** (Session: {session_n1}) vs **{tagger2_name.capitalize()}** - (Session: {session_n2})", ) # merge df = pd.merge(taggers[pair[0]], taggers[pair[1]], on="text", suffixes=("_1", "_2")) kappa = cohen_kappa_score(df["category_1"], df["category_2"]) md.add( - f"- Cohen's Kappa (all categories): {kappa:.4f} (Overlap in sentences: {df.shape[0]})" + f"- Cohen's Kappa (all categories): {kappa:.4f} (Overlap in sentences: {df.shape[0]})", ) kappa = cohen_kappa_score( - df["category_1"] == "correct_language", df["category_2"] == "correct_language" + df["category_1"] == "correct_language", + df["category_2"] == "correct_language", ) md.add( - f"- Cohen's Kappa (correct_language vs not correct_language): {kappa:.4f} (Overlap in sentences: {df.shape[0]})" + f"- Cohen's Kappa (correct_language vs not correct_language): {kappa:.4f} (Overlap in sentences: {df.shape[0]})", ) @@ -131,7 +132,7 @@ def get_proportions(taggers, md): ``` While non-language texts in NAT was often menu bars, contact information, or navigation. -""" +""", ) diff --git a/archive_v1/src/applications/dataset_validation/main.py b/archive_v1/src/applications/dataset_validation/main.py index 5470e92b..2b07cca4 100644 --- a/archive_v1/src/applications/dataset_validation/main.py +++ b/archive_v1/src/applications/dataset_validation/main.py @@ -1,6 +1,7 @@ """Script for rating text quality of NAT.""" +from collections.abc import Iterable from datetime import date -from typing import Iterable, Optional +from typing import Optional from dfm.cleaning import QualityFilter, SentenceFilter from dfm.dataset_validation.rating_interface import ExampleRater @@ -10,7 +11,9 @@ def text_generator( - seed, n_texts: Optional[int], max_texts: Optional[int] + seed, + n_texts: Optional[int], + max_texts: Optional[int], ) -> Iterable[str]: """ Create text generator diff --git a/archive_v1/src/applications/dataset_validation/mc4.py b/archive_v1/src/applications/dataset_validation/mc4.py index 1fd1f94c..b11b919c 100644 --- a/archive_v1/src/applications/dataset_validation/mc4.py +++ b/archive_v1/src/applications/dataset_validation/mc4.py @@ -2,7 +2,6 @@ from datetime import date from datasets import load_dataset - from dfm.dataset_validation.rating_interface import ExampleRater diff --git a/archive_v1/src/applications/hopetwitter/add_metadata.py b/archive_v1/src/applications/hopetwitter/add_metadata.py index ea09f523..a66cdec4 100644 --- a/archive_v1/src/applications/hopetwitter/add_metadata.py +++ b/archive_v1/src/applications/hopetwitter/add_metadata.py @@ -15,10 +15,14 @@ from wasabi import msg path = os.path.join( - "/work", "twitter_cleaned", "twitter_da_stopwords_2019-01-01_2021-04-30" + "/work", + "twitter_cleaned", + "twitter_da_stopwords_2019-01-01_2021-04-30", ) write_path = os.path.join( - "/work", "twitter_cleaned", "twitter_da_stopwords_2019-01-01_2021-04-30.arrow" + "/work", + "twitter_cleaned", + "twitter_da_stopwords_2019-01-01_2021-04-30.arrow", ) ds = load_from_disk(path) diff --git a/archive_v1/src/applications/hopetwitter/apply_filter.py b/archive_v1/src/applications/hopetwitter/apply_filter.py index fdf598b6..d22abc47 100644 --- a/archive_v1/src/applications/hopetwitter/apply_filter.py +++ b/archive_v1/src/applications/hopetwitter/apply_filter.py @@ -14,14 +14,17 @@ if __name__ == "__main__": path = os.path.join( - "/work", "twitter_cleaned", "twitter_da_stopwords_2019-01-01_2021-04-30.arrow" + "/work", + "twitter_cleaned", + "twitter_da_stopwords_2019-01-01_2021-04-30.arrow", ) msg.info(f"loading: {path}") ds = load_from_disk(path) ds_filtered = ds.filter( - lambda example: example["is_duplicate"] is False, num_proc=16 + lambda example: example["is_duplicate"] is False, + num_proc=16, ) assert len(set(ds_filtered["is_duplicate"])) == 1 diff --git a/archive_v1/src/applications/hopetwitter/dedupe.py b/archive_v1/src/applications/hopetwitter/dedupe.py index 93433183..13c560fc 100644 --- a/archive_v1/src/applications/hopetwitter/dedupe.py +++ b/archive_v1/src/applications/hopetwitter/dedupe.py @@ -14,9 +14,8 @@ from pathlib import Path from datasets import load_dataset -from wasabi import msg - from dfm.cleaning import Deduper +from wasabi import msg def filter_batch(batch, i): @@ -67,7 +66,8 @@ def main( json_files = glob.glob(path, recursive=True) w_path = os.path.join( - write_path, "twitter_da_stopwords_2019-01-01_2021-04-30.arrow" + write_path, + "twitter_da_stopwords_2019-01-01_2021-04-30.arrow", ) deduper = Deduper(ngram_size=10) diff --git a/archive_v1/src/applications/hopetwitter/flatten_ndjson.py b/archive_v1/src/applications/hopetwitter/flatten_ndjson.py index a30d1670..ef0ee8eb 100644 --- a/archive_v1/src/applications/hopetwitter/flatten_ndjson.py +++ b/archive_v1/src/applications/hopetwitter/flatten_ndjson.py @@ -15,8 +15,17 @@ def flatten_post( post: dict, - keys_to_keep=["text", "id", "possibly_sensitive", "author_id", "source", "lang"], + keys_to_keep=None, ): + if keys_to_keep is None: + keys_to_keep = [ + "text", + "id", + "possibly_sensitive", + "author_id", + "source", + "lang", + ] return {k: post[k] for k in keys_to_keep} @@ -29,7 +38,7 @@ def flatten_ndjson(path: str, write_folder: str): print(f"Flattening: {path} to {write_path}") # stream in json from orgin to write_path - with open(path, "r") as f: + with open(path) as f: reader = ndjson.reader(f) with open(write_path, "w") as f: diff --git a/archive_v1/src/applications/hopetwitter/quality_filter.py b/archive_v1/src/applications/hopetwitter/quality_filter.py index d948a8ef..92d29d7d 100644 --- a/archive_v1/src/applications/hopetwitter/quality_filter.py +++ b/archive_v1/src/applications/hopetwitter/quality_filter.py @@ -14,9 +14,8 @@ from pathlib import Path from datasets import load_dataset -from wasabi import msg - from dfm.cleaning import QualityFilter +from wasabi import msg def filter_batch(batch, i): @@ -97,7 +96,8 @@ def main( json_files = glob.glob(path, recursive=True) w_path = os.path.join( - write_path, "twitter_da_stopwords_2019-01-01_2021-04-30.jsonl" + write_path, + "twitter_da_stopwords_2019-01-01_2021-04-30.jsonl", ) if os.path.exists(w_path): raise Exception(f"File {w_path} already exist") diff --git a/archive_v1/src/applications/netarkivet/cleaning-v1/add_is_duplicate_column.py b/archive_v1/src/applications/netarkivet/cleaning-v1/add_is_duplicate_column.py index d2474ac5..d95a7de0 100644 --- a/archive_v1/src/applications/netarkivet/cleaning-v1/add_is_duplicate_column.py +++ b/archive_v1/src/applications/netarkivet/cleaning-v1/add_is_duplicate_column.py @@ -42,7 +42,7 @@ def main(netarkivet_path=Path("/work/netarkivet-cleaned")): meta = load_dataset("csv", data_files=meta_path) meta = meta["train"] assert len(meta) == len( - ds + ds, ), "length of dataset and its metadata is not the same." ds = ds.add_column("is_duplicate", meta["is_duplicate"]) ds.to_json(jsonl_file) diff --git a/archive_v1/src/applications/netarkivet/cleaning-v1/apply_filters.py b/archive_v1/src/applications/netarkivet/cleaning-v1/apply_filters.py index 57d5237b..b495ab94 100644 --- a/archive_v1/src/applications/netarkivet/cleaning-v1/apply_filters.py +++ b/archive_v1/src/applications/netarkivet/cleaning-v1/apply_filters.py @@ -10,9 +10,10 @@ import glob import os import random +from collections.abc import Iterable from contextlib import ExitStack from pathlib import Path -from typing import Iterable, List, Optional, Union +from typing import Optional, Union import ndjson from wasabi import msg @@ -55,7 +56,7 @@ def shuffle_buffer(x: Iterable, buffer_size: int) -> Iterable: def jsonl_merge( # noqa C901 - jsonl_files: List[Union[Path, str]], + jsonl_files: list[Union[Path, str]], buffer_size: Optional[int] = None, sample: bool = True, ) -> Iterable[dict]: @@ -81,8 +82,7 @@ def __sample_yield(readers: list) -> Iterable: def __iterative_yield(readers: list) -> Iterable: for reader in readers: - for sample in reader: - yield sample + yield from reader yield_fn = __sample_yield if sample is True else __iterative_yield @@ -102,7 +102,9 @@ def __iterative_yield(readers: list) -> Iterable: yield sample -def apply_filter(dataset=Iterable[dict], columns_to_keep=["text"]) -> Iterable[dict]: +def apply_filter(dataset=Iterable[dict], columns_to_keep=None) -> Iterable[dict]: + if columns_to_keep is None: + columns_to_keep = ["text"] for sample in dataset: if sample["is_duplicate"] is False: yield {k: sample[k] for k in columns_to_keep} diff --git a/archive_v1/src/applications/netarkivet/cleaning-v1/content_filtering/DNS_filter.py b/archive_v1/src/applications/netarkivet/cleaning-v1/content_filtering/DNS_filter.py index 2ef87a91..7da7f9fc 100644 --- a/archive_v1/src/applications/netarkivet/cleaning-v1/content_filtering/DNS_filter.py +++ b/archive_v1/src/applications/netarkivet/cleaning-v1/content_filtering/DNS_filter.py @@ -12,11 +12,11 @@ import os import pickle import ssl -import sys import time +from collections.abc import Iterable from concurrent.futures import ThreadPoolExecutor from itertools import islice -from typing import Iterable, Optional, Tuple +from typing import Optional import aiohttp from aiohttp import ClientTimeout @@ -35,18 +35,17 @@ SSL_PROTOCOLS = (*SSL_PROTOCOLS, uvloop.loop.SSLProtocol) -def get_domains(limit=None, checked={}): +def get_domains(limit=None, checked=None): """Extract a list of domains""" - with open(path, "r") as f: + if checked is None: + checked = {} + with open(path) as f: ss_domains = json.load(f) - if limit: - domains = ss_domains["false"][:limit] - else: - domains = ss_domains["false"] + domains = ss_domains["false"][:limit] if limit else ss_domains["false"] return ["http://" + d for d in domains if "http://" + d not in checked] -async def check_sites(session, ssl_context, url) -> Tuple[str, int]: +async def check_sites(session, ssl_context, url) -> tuple[str, int]: """check if sites is fetchable return the url along with status code""" async with session.get( url, @@ -118,8 +117,7 @@ def ignore_aiohttp_ssl_eror(loop): or 3.8. """ - if sys.version_info >= (3, 7, 4): - return + return orig_handler = loop.get_exception_handler() @@ -189,7 +187,7 @@ def dns_filter( checked = set() for i, sites in enumerate( - batch(get_domains(limit=None, checked=checked), batch_size) + batch(get_domains(limit=None, checked=checked), batch_size), ): msg.info(f"Currently at batch {i} with a batch size of {batch_size}") start_time = time.time() @@ -210,7 +208,7 @@ def dns_filter( pickle.dump(output, f) msg.info( - "Comparing diff between checking Google public DNS and Cleanbrowsing adult DNS" + "Comparing diff between checking Google public DNS and Cleanbrowsing adult DNS", ) domains = { d: (output["google public DNS"][d], c) for d, c in output["cleanweb"].items() @@ -227,7 +225,7 @@ def dns_filter( msg.info( "Performing 10 additional checks to see to using to ensure the sites isn't" - + " fetchable using cleanBrowsing" + + " fetchable using cleanBrowsing", ) output["unsafe_sites_double_checked"] = output["unsafe_sites"] @@ -238,7 +236,7 @@ def dns_filter( ignore_aiohttp_ssl_eror(loop) loop.set_default_executor(ThreadPoolExecutor(n_threads)) t = loop.run_until_complete( - check_all_sites(output["unsafe_sites_double_checked"], ssl_context) + check_all_sites(output["unsafe_sites_double_checked"], ssl_context), ) output["unsafe_sites_double_checked"] = [ site @@ -254,7 +252,6 @@ def dns_filter( if __name__ == "__main__": - path = os.path.join("/work/netarkivet-cleaned/safe_search_domains.json") save_path = os.path.join("/work/netarkivet-cleaned/safe_search_domains_safe.pkl") diff --git a/archive_v1/src/applications/netarkivet/cleaning-v1/content_filtering/count_domains.py b/archive_v1/src/applications/netarkivet/cleaning-v1/content_filtering/count_domains.py index fd0a29db..b4a6d377 100644 --- a/archive_v1/src/applications/netarkivet/cleaning-v1/content_filtering/count_domains.py +++ b/archive_v1/src/applications/netarkivet/cleaning-v1/content_filtering/count_domains.py @@ -15,7 +15,7 @@ import shutil from collections import Counter from pathlib import Path -from typing import List, Tuple, Union +from typing import Union import pandas as pd from wasabi import msg @@ -26,7 +26,7 @@ def get_paths( nested: bool = True, folder_suffix=".parquet", file_suffix=".parquet", -) -> Union[List[str], dict]: +) -> Union[list[str], dict]: """extracts paths from netarkivet either in a nested format""" folders = [ os.path.join(path, f) for f in os.listdir(path) if f.endswith(folder_suffix) @@ -44,7 +44,7 @@ def get_paths( ] -def split_mult_extension(path: str) -> Tuple[str, str]: +def split_mult_extension(path: str) -> tuple[str, str]: """An extension of os.path.splitext which splits extentions until there is none left""" ext = "" @@ -56,8 +56,10 @@ def split_mult_extension(path: str) -> Tuple[str, str]: path = path_ -def process(path, lang_codes={"da"}): +def process(path, lang_codes=None): """process a single file path, calculating domain counts and timestamps""" + if lang_codes is None: + lang_codes = {"da"} df = pd.read_parquet(path, engine="pyarrow") # filter @@ -74,7 +76,7 @@ def main(n_process: int, read_path: str, write_path: str): """Applied process to each file obtained using get_paths (all of netarkivet). Write {script_name}__SUCCESS file when finished. Will skill already processed folders.""" paths = get_paths(read_path) - list(paths.keys())[0] + next(iter(paths.keys())) if n_process == -1: n_process = mp.cpu_count() diff --git a/archive_v1/src/applications/netarkivet/cleaning-v1/content_filtering/filter_domains.py b/archive_v1/src/applications/netarkivet/cleaning-v1/content_filtering/filter_domains.py index 221f47a2..a7b6e01e 100644 --- a/archive_v1/src/applications/netarkivet/cleaning-v1/content_filtering/filter_domains.py +++ b/archive_v1/src/applications/netarkivet/cleaning-v1/content_filtering/filter_domains.py @@ -12,7 +12,6 @@ import json import os from collections import Counter, defaultdict -from typing import List from pysafebrowsing import SafeBrowsing from wasabi import msg @@ -27,7 +26,7 @@ daily_api_calls = 10_000 -def sum_counters(counters: List[Counter]) -> Counter: +def sum_counters(counters: list[Counter]) -> Counter: """ Recursive counter with a O(log(n)) Complexity """ @@ -42,7 +41,7 @@ def sum_counters(counters: List[Counter]) -> Counter: def read_counter_json(path): """read jsons consisting of multiple counters in as one counter""" - with open(path, "r") as f: + with open(path) as f: c = json.load(f) return sum_counters([Counter(counts) for counts in c]) @@ -50,7 +49,7 @@ def read_counter_json(path): def main(): # load in existing lookups if any if os.path.isfile(previous_lookups): - with open(previous_lookups, "r") as f: + with open(previous_lookups) as f: prev = json.load(f) else: msg.warn("Did not find any previous lookups.") @@ -76,19 +75,19 @@ def main(): domains = [(dom, count) for dom, count in counter.most_common() if count > 1] # load api key - with open(API_key_path, "r") as f: + with open(API_key_path) as f: key = f.read() safebrowse = SafeBrowsing(key) n = len(domains) msg.info( f"A total of {len(counter)} unique domains and {n} unique domains with more" - + " than one entry." + + " than one entry.", ) msg.info( f"There was a total of {n_domains_entries} entries and" + f" {n_domains_entries - sum([c for d,c in domains])}" - + " were unique domain entries." + + " were unique domain entries.", ) # call safe search API diff --git a/archive_v1/src/applications/netarkivet/cleaning-v1/create_metadata_csv.py b/archive_v1/src/applications/netarkivet/cleaning-v1/create_metadata_csv.py index 68ad8313..105ebffd 100644 --- a/archive_v1/src/applications/netarkivet/cleaning-v1/create_metadata_csv.py +++ b/archive_v1/src/applications/netarkivet/cleaning-v1/create_metadata_csv.py @@ -30,7 +30,8 @@ def word_count(batch): y_path = path / str(year) / "*.jsonl" j_files = glob.glob(str(y_path)) j_files = sorted( - j_files, key=lambda path: int(os.path.splitext(path)[0].split("/")[-1]) + j_files, + key=lambda path: int(os.path.splitext(path)[0].split("/")[-1]), ) for f in j_files: f_, ext = os.path.splitext(f) @@ -45,7 +46,7 @@ def word_count(batch): ds_subset = ds.remove_columns( [ c - for c in ds.features.keys() + for c in ds.features if c not in { "text", @@ -56,7 +57,7 @@ def word_count(batch): "language", "domain_key", } - ] + ], ) ds_subset = ds_subset.map( word_count, diff --git a/archive_v1/src/applications/netarkivet/cleaning-v1/dedupe.py b/archive_v1/src/applications/netarkivet/cleaning-v1/dedupe.py index 980c3efb..949378e1 100644 --- a/archive_v1/src/applications/netarkivet/cleaning-v1/dedupe.py +++ b/archive_v1/src/applications/netarkivet/cleaning-v1/dedupe.py @@ -11,16 +11,15 @@ import glob import os +from collections.abc import Iterable from functools import partial -from typing import Iterable import ndjson import psutil +from dfm.cleaning import Deduper from psutil._common import bytes2human from wasabi import msg -from dfm.cleaning import Deduper - def filter_example(example, already_checked): """check whether sample i should be included""" @@ -36,8 +35,10 @@ def get_id_from_path(x): def create_paths( - years=[2006, 2007, 2008, 2009, 2010, 2011, 2012, 2013, 2014, 2015, 2016] + years=None, ): + if years is None: + years = [2006, 2007, 2008, 2009, 2010, 2011, 2012, 2013, 2014, 2015, 2016] for year in years: read_path = os.path.join("/work", "netarkivet-cleaned", f"{year}") path = os.path.join(read_path, "**", "*.jsonl") diff --git a/archive_v1/src/applications/netarkivet/cleaning-v1/desc_stats.py b/archive_v1/src/applications/netarkivet/cleaning-v1/desc_stats.py index ead564ff..edf0f61e 100644 --- a/archive_v1/src/applications/netarkivet/cleaning-v1/desc_stats.py +++ b/archive_v1/src/applications/netarkivet/cleaning-v1/desc_stats.py @@ -12,7 +12,6 @@ import os import sys from collections import Counter -from typing import List import pandas as pd @@ -20,12 +19,12 @@ sys.path.append(dfm_path) -from src.applications.netarkivet.content_filtering.count_domains_netarkivet import ( # noqa E402 +from src.applications.netarkivet.content_filtering.count_domains_netarkivet import ( # E402 get_paths, ) -def sum_counters(counters: List[Counter]) -> Counter: +def sum_counters(counters: list[Counter]) -> Counter: """ Recursive counter with a O(log(n)) Complexity """ @@ -40,7 +39,7 @@ def sum_counters(counters: List[Counter]) -> Counter: def read_counter_json(path): """read jsons consisting of multiple counters in as one counter""" - with open(path, "r") as f: + with open(path) as f: c = json.load(f) return sum_counters([Counter(counts) for counts in c]) diff --git a/archive_v1/src/applications/netarkivet/cleaning-v1/extract_summary_statistics.py b/archive_v1/src/applications/netarkivet/cleaning-v1/extract_summary_statistics.py index 2a3b96eb..a20c7a54 100644 --- a/archive_v1/src/applications/netarkivet/cleaning-v1/extract_summary_statistics.py +++ b/archive_v1/src/applications/netarkivet/cleaning-v1/extract_summary_statistics.py @@ -76,7 +76,7 @@ def split_mask(mask, year_file): domains.update(domains_) lang_sites += Counter(ds["language"]) n_passed_quality_filter_ = len( - [i for i in ds["passed_quality_filter"] if i is True] + [i for i in ds["passed_quality_filter"] if i is True], ) n_passed_quality_filter += n_passed_quality_filter_ n_tokens += sum(ds["n_tokens"]) diff --git a/archive_v1/src/applications/netarkivet/cleaning-v1/quality_filter.py b/archive_v1/src/applications/netarkivet/cleaning-v1/quality_filter.py index e4c07c7f..fbd8d6ad 100644 --- a/archive_v1/src/applications/netarkivet/cleaning-v1/quality_filter.py +++ b/archive_v1/src/applications/netarkivet/cleaning-v1/quality_filter.py @@ -13,11 +13,10 @@ from pathlib import Path from datasets import load_dataset -from tqdm import tqdm -from wasabi import msg - from dfm.cleaning import QualityFilter from dfm.utils import batch +from tqdm import tqdm +from wasabi import msg def filter_batch(batch, i): diff --git a/archive_v1/src/applications/train/run_mlm_pytorch_stream.py b/archive_v1/src/applications/train/run_mlm_pytorch_stream.py index 1360d5a5..cad6d47d 100644 --- a/archive_v1/src/applications/train/run_mlm_pytorch_stream.py +++ b/archive_v1/src/applications/train/run_mlm_pytorch_stream.py @@ -34,7 +34,7 @@ import sys from dataclasses import dataclass, field from itertools import chain -from typing import Optional, Tuple, Union +from typing import Optional, Union import datasets import transformers @@ -68,14 +68,14 @@ class ModelArguments: metadata={ "help": ( "The model checkpoint for weights initialization.Don't set if you want to train a model from scratch." - ) + ), }, ) model_type: Optional[str] = field( default=None, metadata={ "help": "If training from scratch, pass a model type from the list: " - + ", ".join(MODEL_TYPES) + + ", ".join(MODEL_TYPES), }, ) config_overrides: Optional[str] = field( @@ -84,37 +84,37 @@ class ModelArguments: "help": ( "Override some existing default config settings when a model is trained from scratch. Example: " "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index" - ) + ), }, ) config_name: Optional[str] = field( default=None, metadata={ - "help": "Pretrained config name or path if not the same as model_name" + "help": "Pretrained config name or path if not the same as model_name", }, ) tokenizer_name: Optional[str] = field( default=None, metadata={ - "help": "Pretrained tokenizer name or path if not the same as model_name" + "help": "Pretrained tokenizer name or path if not the same as model_name", }, ) cache_dir: Optional[str] = field( default=None, metadata={ - "help": "Where do you want to store the pretrained models downloaded from huggingface.co" + "help": "Where do you want to store the pretrained models downloaded from huggingface.co", }, ) use_fast_tokenizer: bool = field( default=True, metadata={ - "help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not." + "help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not.", }, ) model_revision: str = field( default="main", metadata={ - "help": "The specific model version to use (can be a branch name, tag name or commit id)." + "help": "The specific model version to use (can be a branch name, tag name or commit id).", }, ) use_auth_token: bool = field( @@ -123,7 +123,7 @@ class ModelArguments: "help": ( "Will use the token generated when running `transformers-cli login` (necessary to use this script " "with private models)." - ) + ), }, ) @@ -132,7 +132,7 @@ def __post_init__(self): self.config_name is not None or self.model_name_or_path is not None ): raise ValueError( - "--config_overrides can't be used in combination with --config_name or --model_name_or_path" + "--config_overrides can't be used in combination with --config_name or --model_name_or_path", ) @@ -147,7 +147,7 @@ class DataTrainingArguments: dataset_config_name: Optional[str] = field( default=None, metadata={ - "help": "The configuration name of the dataset to use (via the datasets library)." + "help": "The configuration name of the dataset to use (via the datasets library).", }, ) streaming: bool = field( @@ -155,12 +155,13 @@ class DataTrainingArguments: metadata={"help": "Whether to load the dataset using streaming"}, ) train_file: Optional[str] = field( - default=None, metadata={"help": "The input training data file (a text file)."} + default=None, + metadata={"help": "The input training data file (a text file)."}, ) validation_file: Optional[str] = field( default=None, metadata={ - "help": "An optional input evaluation data file to evaluate the perplexity on (a text file)." + "help": "An optional input evaluation data file to evaluate the perplexity on (a text file).", }, ) overwrite_cache: bool = field( @@ -171,7 +172,7 @@ class DataTrainingArguments: default=5, metadata={ "help": "The percentage of the train set used as validation set in case" - + "there's no validation split. If streaming is True then this will be the count" + + "there's no validation split. If streaming is True then this will be the count", }, ) max_seq_length: Optional[int] = field( @@ -180,7 +181,7 @@ class DataTrainingArguments: "help": ( "The maximum total input sequence length after tokenization. Sequences longer " "than this will be truncated." - ) + ), }, ) preprocessing_num_workers: Optional[int] = field( @@ -194,7 +195,7 @@ class DataTrainingArguments: line_by_line: bool = field( default=False, metadata={ - "help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences." + "help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences.", }, ) pad_to_max_length: bool = field( @@ -203,7 +204,7 @@ class DataTrainingArguments: "help": ( "Whether to pad all samples to `max_seq_length`. " "If False, will pad the samples dynamically when batching to the maximum length in the batch." - ) + ), }, ) max_train_samples: Optional[int] = field( @@ -212,7 +213,7 @@ class DataTrainingArguments: "help": ( "For debugging purposes or quicker training, truncate the number of training examples to this " "value if set." - ) + ), }, ) max_eval_samples: Optional[int] = field( @@ -221,7 +222,7 @@ class DataTrainingArguments: "help": ( "For debugging purposes or quicker training, truncate the number of evaluation examples to this " "value if set." - ) + ), }, ) @@ -232,25 +233,26 @@ def __post_init__(self): and self.validation_file is None ): raise ValueError( - "Need either a dataset name or a training/validation file." + "Need either a dataset name or a training/validation file.", ) else: if self.train_file is not None: extension = self.train_file.split(".")[-1] if extension not in ["csv", "json", "txt"]: raise ValueError( - "`train_file` should be a csv, a json or a txt file." + "`train_file` should be a csv, a json or a txt file.", ) if self.validation_file is not None: extension = self.validation_file.split(".")[-1] if extension not in ["csv", "json", "txt"]: raise ValueError( - "`validation_file` should be a csv, a json or a txt file." + "`validation_file` should be a csv, a json or a txt file.", ) def get_dataset( - model_args: ModelArguments, data_args: DataTrainingArguments + model_args: ModelArguments, + data_args: DataTrainingArguments, ) -> DatasetDict: """Get the datasets. @@ -284,7 +286,7 @@ def get_dataset( use_auth_token=True if model_args.use_auth_token else None, streaming=data_args.streaming, ) - if "validation" not in raw_datasets.keys(): + if "validation" not in raw_datasets: if data_args.streaming is False: raw_datasets["validation"] = load_dataset( data_args.dataset_name, @@ -304,10 +306,10 @@ def get_dataset( ) else: raw_datasets["validation"] = raw_datasets["train"].take( - data_args.validation_split + data_args.validation_split, ) raw_datasets["train"] = raw_datasets["train"].skip( - data_args.validation_split + data_args.validation_split, ) else: @@ -329,7 +331,7 @@ def get_dataset( ) # If no validation data is there, validation_split will be used to divide the dataset. - if "validation" not in raw_datasets.keys(): + if "validation" not in raw_datasets: if data_args.streaming is False: raw_datasets["validation"] = load_dataset( extension, @@ -349,10 +351,10 @@ def get_dataset( ) else: raw_datasets["validation"] = raw_datasets["train"].take( - data_args.validation_split + data_args.validation_split, ) raw_datasets["train"] = raw_datasets["train"].skip( - data_args.validation_split + data_args.validation_split, ) # See more about loading any type of standard or custom dataset (from files, python @@ -363,7 +365,7 @@ def get_dataset( def get_tokenizer_and_model( model_args: ModelArguments, -) -> Tuple[AutoTokenizer, AutoModelForMaskedLM]: +) -> tuple[AutoTokenizer, AutoModelForMaskedLM]: """Load pretrained model and tokenizer. Distributed training: @@ -386,7 +388,8 @@ def get_tokenizer_and_model( config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs) elif model_args.model_name_or_path: config = AutoConfig.from_pretrained( - model_args.model_name_or_path, **config_kwargs + model_args.model_name_or_path, + **config_kwargs, ) else: config = CONFIG_MAPPING[model_args.model_type]() @@ -404,16 +407,18 @@ def get_tokenizer_and_model( } if model_args.tokenizer_name: tokenizer = AutoTokenizer.from_pretrained( - model_args.tokenizer_name, **tokenizer_kwargs + model_args.tokenizer_name, + **tokenizer_kwargs, ) elif model_args.model_name_or_path: tokenizer = AutoTokenizer.from_pretrained( - model_args.model_name_or_path, **tokenizer_kwargs + model_args.model_name_or_path, + **tokenizer_kwargs, ) else: raise ValueError( "You are instantiating a new tokenizer from scratch. This is not supported by this script." - "You can do it from another script, save it, and load it from here, using --tokenizer_name." + "You can do it from another script, save it, and load it from here, using --tokenizer_name.", ) if model_args.model_name_or_path: @@ -433,7 +438,7 @@ def get_tokenizer_and_model( return tokenizer, model -def preprocess_dataset( # noqa: C901 +def preprocess_dataset( data_args: DataTrainingArguments, training_args: TrainingArguments, raw_datasets: DatasetDict, @@ -460,14 +465,14 @@ def preprocess_dataset( # noqa: C901 if max_seq_length > 1024: logger.warning( f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). " - "Picking 1024 instead. You can change that default value by passing --max_seq_length xxx." + "Picking 1024 instead. You can change that default value by passing --max_seq_length xxx.", ) max_seq_length = 1024 else: if data_args.max_seq_length > tokenizer.model_max_length: logger.warning( f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the" - f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." + f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}.", ) max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length) @@ -499,17 +504,18 @@ def tokenize_function(examples): # efficient when it receives the `special_tokens_mask`. def tokenize_function(examples): return tokenizer( - examples[text_column_name], return_special_tokens_mask=True + examples[text_column_name], + return_special_tokens_mask=True, ) desc = "Running tokenizer on every text in dataset" with training_args.main_process_first(desc="dataset map tokenization"): - _map_config = dict( - function=tokenize_function, - batched=True, - remove_columns=column_names, - ) + _map_config = { + "function": tokenize_function, + "batched": True, + "remove_columns": column_names, + } if data_args.streaming is False: _map_config["num_proc"] = data_args.preprocessing_num_workers _map_config["load_from_cache_file"] = not data_args.overwrite_cache @@ -521,10 +527,8 @@ def tokenize_function(examples): # max_seq_length. def group_texts(examples): # Concatenate all texts. - concatenated_examples = { - k: list(chain(*examples[k])) for k in examples.keys() - } - total_length = len(concatenated_examples[list(examples.keys())[0]]) + concatenated_examples = {k: list(chain(*examples[k])) for k in examples} + total_length = len(concatenated_examples[next(iter(examples.keys()))]) # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can # customize this part to your needs. if total_length >= max_seq_length: @@ -547,10 +551,10 @@ def group_texts(examples): # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map with training_args.main_process_first(desc="grouping texts together"): - _map_config = dict( - function=group_texts, - batched=True, - ) + _map_config = { + "function": group_texts, + "batched": True, + } if data_args.streaming is False: _map_config["num_proc"] = (data_args.preprocessing_num_workers,) _map_config["load_from_cache_file"] = (not data_args.overwrite_cache,) @@ -592,7 +596,7 @@ def train( max_train_samples = data_args.max_train_samples if max_train_samples is None: raise ValueError( - "When specifying --streaming, then you must also specify --max_train_samples" + "When specifying --streaming, then you must also specify --max_train_samples", ) metrics["train_samples"] = data_args.max_train_samples else: @@ -660,19 +664,19 @@ def evaluate( return kwargs -def main(): # noqa +def main(): # See all possible arguments in src/transformers/training_args.py # or by passing the --help flag to this script. # We now keep distinct sets of args, for a cleaner separation of concerns. parser = HfArgumentParser( - (ModelArguments, DataTrainingArguments, TrainingArguments) + (ModelArguments, DataTrainingArguments, TrainingArguments), ) if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # If we pass only one argument to the script and it's the path to a json file, # let's parse it to get our arguments. model_args, data_args, training_args = parser.parse_json_file( - json_file=os.path.abspath(sys.argv[1]) + json_file=os.path.abspath(sys.argv[1]), ) else: model_args, data_args, training_args = parser.parse_args_into_dataclasses() @@ -694,7 +698,7 @@ def main(): # noqa # Log on each process the small summary: logger.warning( f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" - + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" + + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}", ) # Set the verbosity to info of the Transformers logger (on main process only): logger.info(f"Training/evaluation parameters {training_args}") @@ -710,14 +714,14 @@ def main(): # noqa if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: raise ValueError( f"Output directory ({training_args.output_dir}) already exists and is not empty. " - "Use --overwrite_output_dir to overcome." + "Use --overwrite_output_dir to overcome.", ) elif ( last_checkpoint is not None and training_args.resume_from_checkpoint is None ): logger.info( f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " - "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." + "the `--output_dir` or add `--overwrite_output_dir` to train from scratch.", ) # Set seed before initializing model. @@ -728,7 +732,10 @@ def main(): # noqa tokenizer, model = get_tokenizer_and_model(model_args) tokenized_datasets = preprocess_dataset( - data_args, training_args, raw_datasets, tokenizer + data_args, + training_args, + raw_datasets, + tokenizer, ) if training_args.do_train: @@ -784,20 +791,20 @@ def compute_metrics(eval_preds): ) # Initialize our Trainer - _training_args = dict( - model=model, - args=training_args, - train_dataset=train_dataset if training_args.do_train else None, - eval_dataset=eval_dataset if training_args.do_eval else None, - tokenizer=tokenizer, - data_collator=data_collator, - compute_metrics=compute_metrics + _training_args = { + "model": model, + "args": training_args, + "train_dataset": train_dataset if training_args.do_train else None, + "eval_dataset": eval_dataset if training_args.do_eval else None, + "tokenizer": tokenizer, + "data_collator": data_collator, + "compute_metrics": compute_metrics if training_args.do_eval and not is_torch_tpu_available() else None, - preprocess_logits_for_metrics=preprocess_logits_for_metrics + "preprocess_logits_for_metrics": preprocess_logits_for_metrics if training_args.do_eval and not is_torch_tpu_available() else None, - ) + } if data_args.streaming: # convert to pytorch iterable dataset diff --git a/archive_v1/src/dfm/cleaning/clean_cli.py b/archive_v1/src/dfm/cleaning/clean_cli.py index 581294c5..d0c30318 100644 --- a/archive_v1/src/dfm/cleaning/clean_cli.py +++ b/archive_v1/src/dfm/cleaning/clean_cli.py @@ -27,12 +27,11 @@ import hydra from datasets import load_dataset from datasets.utils import disable_progress_bar +from dfm.cleaning import QualityFilter, SentenceFilter from omegaconf import DictConfig, OmegaConf from tqdm import tqdm from tqdm.contrib.logging import logging_redirect_tqdm -from dfm.cleaning import QualityFilter, SentenceFilter - CFG_PATH = Path(__file__).parent / "configs" VALID_SAVE_FORMATS = { "parquet": "parquet", @@ -120,7 +119,6 @@ def apply_quality_filter(batch: dict, cfg: DictConfig) -> dict: qf = create_quality_filter(cfg) if cfg.save_meta_data: - valid_langs = set(cfg.valid_languages) if valid_langs: @@ -189,7 +187,6 @@ def apply_sentence_filter(batch: dict, cfg: DictConfig) -> dict: sf = create_sentence_filter(cfg) if cfg.save_meta_data: - valid_langs = set(cfg.valid_languages) if valid_langs: @@ -210,9 +207,7 @@ def filter_lang(batch, i): batch[cfg.text_col] = sf(batch[cfg.text_col]) # create meta data columns - batch["passed_sentence_filter"] = [ - True if t else False for t in batch[cfg.text_col] - ] + batch["passed_sentence_filter"] = [bool(t) for t in batch[cfg.text_col]] return batch batch[cfg.text_col] = sf(batch[cfg.text_col]) @@ -278,13 +273,11 @@ def process_files(path: Path, cfg: DictConfig) -> None: logging.debug("The columns of the dataset is:\n %s", dataset.column_names) # filter languages: - if not cfg.save_meta_data: - if cfg.valid_languages and cfg.lang_col: - valid_langs = set(cfg.valid_languages) - dataset.filter(lambda example: example[cfg.lang_col] in valid_langs) + if not cfg.save_meta_data and cfg.valid_languages and cfg.lang_col: + valid_langs = set(cfg.valid_languages) + dataset.filter(lambda example: example[cfg.lang_col] in valid_langs) if cfg.apply_sentence_filter: - dataset = dataset.map( lambda batch: apply_sentence_filter(batch, cfg), batched=True, @@ -349,28 +342,24 @@ def main(cfg: DictConfig) -> None: with open(save_dir / "clean_config.yaml", "w", encoding="utf-8") as f: OmegaConf.save(cfg, f) - if cfg.num_proc == -1: - num_proc = mp.cpu_count() - 1 - else: - num_proc = cfg.num_proc + num_proc = mp.cpu_count() - 1 if cfg.num_proc == -1 else cfg.num_proc paths = glob(cfg.path) # check save path file extension if cfg.save_file_ext not in VALID_SAVE_FORMATS: raise ValueError( - f"Invalid save path file extension. Must be one of {VALID_SAVE_FORMATS}" + f"Invalid save path file extension. Must be one of {VALID_SAVE_FORMATS}", ) # group paths in batches files = tqdm(paths, desc="files") _process_files = partial(process_files, cfg=cfg) - with logging_redirect_tqdm(): - with mp.Pool(num_proc) as pool: - # process files in parallel - for _ in pool.imap_unordered(_process_files, files, chunksize=1): - pass + with logging_redirect_tqdm(), mp.Pool(num_proc) as pool: + # process files in parallel + for _ in pool.imap_unordered(_process_files, files, chunksize=1): + pass logging.info("Finished cleaning %s files", len(paths)) diff --git a/archive_v1/src/dfm/cleaning/dedupe_cli.py b/archive_v1/src/dfm/cleaning/dedupe_cli.py index a31be0b7..1801e142 100644 --- a/archive_v1/src/dfm/cleaning/dedupe_cli.py +++ b/archive_v1/src/dfm/cleaning/dedupe_cli.py @@ -20,21 +20,21 @@ import logging import multiprocessing as mp +from collections.abc import Generator, Iterable from glob import glob from pathlib import Path -from typing import Callable, Generator, Iterable, Union +from typing import Callable, Union import datasets import hydra import ndjson from datasets import Dataset, load_dataset from datasets.utils import disable_progress_bar +from dfm.cleaning import Deduper from omegaconf import DictConfig, OmegaConf from tqdm import tqdm from tqdm.contrib.logging import logging_redirect_tqdm -from dfm.cleaning import Deduper - CFG_PATH = Path(__file__).parent / "configs" VALID_SAVE_FORMATS = { "parquet": "parquet", @@ -86,7 +86,10 @@ def __iter__(self): def dataset_to_disk( - dataset: Union[Dataset, Iterable[dict]], path: Path, ext: str, streaming: bool + dataset: Union[Dataset, Iterable[dict]], + path: Path, + ext: str, + streaming: bool, ) -> None: """Save a dataset to disk @@ -104,7 +107,7 @@ def dataset_to_disk( if streaming and ext not in {"jsonl", "json"}: raise ValueError( f"Streaming is only supported for jsonl files not for {ext}. " - "Please use a different save format." + "Please use a different save format.", ) if streaming: # write each row to path as a jsonl file @@ -141,13 +144,12 @@ def create_dataset_generator(path: Union[Path, str]) -> Generator[dict, None, No if ext not in {"json", "jsonl"}: raise ValueError( "Only json and jsonl files are supported when use_huggingface_loader is" - + f" False, not {ext}. Please use a different save format." + + f" False, not {ext}. Please use a different save format.", ) - with open(path, "r") as file: # pylint: disable=unspecified-encoding + with open(path) as file: # pylint: disable=unspecified-encoding reader = ndjson.reader(file) - for post in reader: - yield post + yield from reader def process_path(path: Union[Path, str], deduper: Deduper, cfg: DictConfig) -> None: @@ -170,7 +172,10 @@ def process_path(path: Union[Path, str], deduper: Deduper, cfg: DictConfig) -> N ext = VALID_SAVE_FORMATS[file_ext[1:]] # remove the "." if cfg.use_huggingface_loader: dataset = load_dataset( - ext, data_files=str(path), split="train", streaming=cfg.streaming + ext, + data_files=str(path), + split="train", + streaming=cfg.streaming, ) else: dataset = multigen(create_dataset_generator)(path) @@ -197,7 +202,7 @@ def process_path(path: Union[Path, str], deduper: Deduper, cfg: DictConfig) -> N if cfg.use_huggingface_loader: dataset_filtered = dataset.filter( - lambda x: x["passed_quality_filter"] is True + lambda x: x["passed_quality_filter"] is True, ) else: dataset_filtered = ( @@ -255,7 +260,7 @@ def process_path(path: Union[Path, str], deduper: Deduper, cfg: DictConfig) -> N # filter out duplicates if cfg.use_huggingface_loader: dataset_deduplicated = dataset_dedup.filter( - lambda x: x["is_duplicate"] is False + lambda x: x["is_duplicate"] is False, ) else: dataset_deduplicated = ( @@ -266,7 +271,10 @@ def process_path(path: Union[Path, str], deduper: Deduper, cfg: DictConfig) -> N # save dataset with new file extension dataset_to_disk( - dataset_deduplicated, save_path, cfg.save_file_ext, streaming=cfg.streaming + dataset_deduplicated, + save_path, + cfg.save_file_ext, + streaming=cfg.streaming, ) @@ -294,17 +302,15 @@ def main(cfg: DictConfig) -> None: logging.basicConfig(filename=save_dir / "deduplication.log", level=logging.INFO) if cfg.verbosity_level == 2: logging.basicConfig( - filename=save_dir / "deduplication.log", level=logging.DEBUG + filename=save_dir / "deduplication.log", + level=logging.DEBUG, ) # save config to folder with open(save_dir / "deduplication_config.yaml", "w", encoding="utf-8") as file: OmegaConf.save(cfg, file) - if cfg.num_proc == -1: - num_proc = mp.cpu_count() - 1 - else: - num_proc = cfg.num_proc + num_proc = mp.cpu_count() - 1 if cfg.num_proc == -1 else cfg.num_proc paths = glob(cfg.path) paths = [ diff --git a/archive_v1/src/dfm/cleaning/deduper.py b/archive_v1/src/dfm/cleaning/deduper.py index c80ed7aa..5e6c60b6 100644 --- a/archive_v1/src/dfm/cleaning/deduper.py +++ b/archive_v1/src/dfm/cleaning/deduper.py @@ -18,9 +18,10 @@ import multiprocessing as mp import pickle import shutil +from collections.abc import Iterable from functools import partial from pathlib import Path -from typing import Callable, Dict, Iterable, Optional, Tuple, Union +from typing import Callable, Optional, Union import more_itertools as mit from datasets.arrow_dataset import Dataset @@ -111,17 +112,19 @@ def __init__( self.verbose = verbose self.save_mask = save_mask if save_mask: - self.mask = list() + self.mask = [] self.lsh_cache = MinHashLSH( - threshold=self.similarity_threshold, num_perm=self.num_minhashes + threshold=self.similarity_threshold, + num_perm=self.num_minhashes, ) def reset(self): """Reset the deduplicator, removing the mask and the LSH cache""" if self.save_mask: - self.mask = list() + self.mask = [] self.lsh_cache = MinHashLSH( - threshold=self.similarity_threshold, num_perm=self.num_minhashes + threshold=self.similarity_threshold, + num_perm=self.num_minhashes, ) return self @@ -158,7 +161,7 @@ def load_from_disk(cls, directory: Union[str, Path]) -> "Deduper": # Load the mask if it exists mask_path = directory / "mask.jsonl" if mask_path.exists(): - with open(mask_path, "r") as f: + with open(mask_path) as f: mask = [json.loads(line) for line in f] deduper.mask = mask @@ -175,18 +178,18 @@ def get_config(self) -> dict: Returns: dict: The configuration of the deduplicator. """ - config = dict( - split_method=self.split_method, - ngram_size=self.ngram_size, - ngram_stride=self.ngram_stride, - similarity_threshold=self.similarity_threshold, - num_minhashes=self.num_minhashes, - batch_size=self.batch_size, - n_jobs=self.n_jobs, - random_seed=self.random_seed, - normalization_func=self.normalization_func, - verbose=self.verbose, - ) + config = { + "split_method": self.split_method, + "ngram_size": self.ngram_size, + "ngram_stride": self.ngram_stride, + "similarity_threshold": self.similarity_threshold, + "num_minhashes": self.num_minhashes, + "batch_size": self.batch_size, + "n_jobs": self.n_jobs, + "random_seed": self.random_seed, + "normalization_func": self.normalization_func, + "verbose": self.verbose, + } return config def _store_document(self, output_path: Union[str, Path], **kwargs): @@ -211,8 +214,8 @@ def deduplicate( corpus: Union[ Dataset, IterableDataset, - Iterable[Tuple[Union[str, int], str]], - Iterable[Dict[str, Union[str, int]]], + Iterable[tuple[Union[str, int], str]], + Iterable[dict[str, Union[str, int]]], ], id_column: str = "id", text_column: str = "text", @@ -278,6 +281,7 @@ def deduplicate( else: for _ in iterable: pass + return None def save_to_disk( self, @@ -316,7 +320,7 @@ def save_to_disk( "the files. If you are loading an existing " "Deduper from the directory then the previous " "config, mask and LSH cache will still will " - "not be lost and will be stored in the directory." + "not be lost and will be stored in the directory.", ) elif output_dir.exists() and overwrite: # Delete the output directory @@ -344,13 +348,13 @@ def save_to_disk( with config_path.open("wb") as f: pickle.dump(config, f) - def _deduplicate( # noqa: C901 + def _deduplicate( self, corpus: Union[ Dataset, IterableDataset, - Iterable[Tuple[Union[str, int], str]], - Iterable[Dict[str, Union[str, int]]], + Iterable[tuple[Union[str, int], str]], + Iterable[dict[str, Union[str, int]]], ], id_column: str = "id", text_column: str = "text", @@ -404,13 +408,12 @@ def _deduplicate( # noqa: C901 # If the corpus is a Dataset or IterableDataset then convert it to an # iterable of tuples - if isinstance(corpus, Dataset) or isinstance(corpus, IterableDataset): + if isinstance(corpus, (Dataset, IterableDataset)): corpus = ((sample[id_column], sample[text_column]) for sample in corpus) # Otherwise we check if the corpus is an iterable of dictionaries, in # which case we also convert it to an iterable of tuples else: - # extract the first element of the corpus corpus = iter(corpus) sample = next(corpus) @@ -463,17 +466,15 @@ def _deduplicate( # noqa: C901 # Iterate over the corpus and store documents that are not duplicates duplicates = 0 num_processed = 0 - pbar_params = dict( - desc="Deduplicating", - total=num_docs, - disable=(not self.verbose), - leave=False, - ) + pbar_params = { + "desc": "Deduplicating", + "total": num_docs, + "disable": (not self.verbose), + "leave": False, + } with tqdm(batches, **pbar_params) as pbar: - # Initialise the multiprocessing with Parallel(n_jobs=self.n_jobs) as parallel: - # Define the function that will be called in parallel fn = delayed( partial( @@ -484,12 +485,11 @@ def _deduplicate( # noqa: C901 ngram_stride=self.ngram_stride, num_minhashes=self.num_minhashes, random_seed=self.random_seed, - ) + ), ) # Iterate over the batches for batch in pbar: - # Create a copy of the batch to ensure that we're not # modifying the original batch, batch_copy = it.tee(batch) @@ -501,9 +501,11 @@ def _deduplicate( # noqa: C901 batch_size = new_num_processed - num_processed # Define parameters used in batch progress bars - pbar_params = dict( - total=batch_size, leave=False, disable=(not self.verbose) - ) + pbar_params = { + "total": batch_size, + "leave": False, + "disable": (not self.verbose), + } # Compute the fingerprint for the document pbar_params["desc"] = "Computing minhashes" @@ -514,13 +516,11 @@ def _deduplicate( # noqa: C901 pbar_params["desc"] = "Deduplicating batch" with tqdm(batch_copy, **pbar_params) as batch_pbar: for (idx, doc), minhash in zip(batch_pbar, minhashes): - # If the document is not a near-duplicate candidate # then store in the LSH cache and append it to the # JSONL output file candidates = self.lsh_cache.query(minhash) if len(candidates) == 0: - # Insert the document into the LSH cache self.lsh_cache.insert(idx, minhash) @@ -528,11 +528,13 @@ def _deduplicate( # noqa: C901 # output if store_corpus_to_disk: self._store_document( - id=idx, text=doc, output_path=output_path + id=idx, + text=doc, + output_path=output_path, ) # Compute the mask for the document - mask_entry = dict(id=idx, duplicate=False) + mask_entry = {"id": idx, "duplicate": False} # Otherwise, increment the number of duplicate # documents @@ -540,7 +542,7 @@ def _deduplicate( # noqa: C901 duplicates += 1 # Compute the mask for the document - mask_entry = dict(id=idx, duplicate=True) + mask_entry = {"id": idx, "duplicate": True} # Add the mask to the mask attribute if self.save_mask: @@ -553,7 +555,8 @@ def _deduplicate( # noqa: C901 # Store the mask to disk if store_mask_to_disk: self._store_document( - output_path=mask_path, **mask_entry + output_path=mask_path, + **mask_entry, ) # Store the LSH cache to disk diff --git a/archive_v1/src/dfm/cleaning/deduper_utils.py b/archive_v1/src/dfm/cleaning/deduper_utils.py index 55ecef2c..0da7a935 100644 --- a/archive_v1/src/dfm/cleaning/deduper_utils.py +++ b/archive_v1/src/dfm/cleaning/deduper_utils.py @@ -5,7 +5,7 @@ """ import re -from typing import Callable, List +from typing import Callable from unicodedata import normalize from datasketch import LeanMinHash, MinHash @@ -17,7 +17,7 @@ def get_shingles( split_method: str, ngram_size: int, ngram_stride: int, -) -> List[str]: +) -> list[str]: """Extracts the shingles from a document. Args: diff --git a/archive_v1/src/dfm/cleaning/quality_filter.py b/archive_v1/src/dfm/cleaning/quality_filter.py index 91f10db6..dc2fa029 100644 --- a/archive_v1/src/dfm/cleaning/quality_filter.py +++ b/archive_v1/src/dfm/cleaning/quality_filter.py @@ -23,8 +23,9 @@ """ from collections import Counter, defaultdict +from collections.abc import Iterable, Sequence from functools import partial -from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Optional, Union import spacy from spacy.tokens import Doc @@ -52,7 +53,7 @@ def duplicate_chr_fraction_getter(doc: Doc, attr: str) -> float: return frac -def n_gram_counter(doc: Doc, ngram_range: Tuple[int, int]) -> Dict[str, Counter]: +def n_gram_counter(doc: Doc, ngram_range: tuple[int, int]) -> dict[str, Counter]: """Calculate the counts of n-grams in the specified range. Args: @@ -98,7 +99,10 @@ def duplicate_fraction_getter(doc: Doc, attr: str = "lines_counter") -> float: def set_dynamic_ext( - ext_name: str, func: Callable, dynamic_ext_prefix: str = "_", object=Doc + ext_name: str, + func: Callable, + dynamic_ext_prefix: str = "_", + object=Doc, ) -> None: """Sets a dynamic extension to reduce redundant computation. @@ -240,8 +244,8 @@ class QualityFilter: def __init__( self, min_stop_words: int = 2, - mean_word_length: Tuple[int, int] = (3, 10), - doc_length: Tuple[int, int] = (50, 100_000), + mean_word_length: tuple[int, int] = (3, 10), + doc_length: tuple[int, int] = (50, 100_000), alpha_ratio: float = 0.6, symbol_2_word_hashtag: float = 0.1, symbol_2_word_ellipsis: float = 0.1, @@ -251,34 +255,32 @@ def __init__( min_ellipsis: int = 2, duplicate_lines_chr_fraction: float = 0.2, duplicate_paragraph_chr_fraction: float = 0.2, - top_ngram_chr_fraction_thresholds: List[float] = [0.20, 0.18, 0.16], - top_ngram_chr_fraction_range: Tuple[int, int] = (2, 4), + top_ngram_chr_fraction_thresholds: Optional[list[float]] = None, + top_ngram_chr_fraction_range: tuple[int, int] = (2, 4), top_ngram_min_count: int = 3, - duplicate_n_gram_fraction_thresholds: List[float] = [ - 0.25, - 0.24, - 0.23, - 0.22, - 0.21, - 0.20, - ], - duplicate_n_gram_fraction_range: Tuple[int, int] = (5, 10), + duplicate_n_gram_fraction_thresholds: Optional[list[float]] = None, + duplicate_n_gram_fraction_range: tuple[int, int] = (5, 10), max_length: int = 5_000_000, string_filter: Optional[str] = None, - ignore_filters: List[str] = [], + ignore_filters: Optional[list[str]] = None, language_detection_tool: str = "luga", language_threshold: float = 0.90, languages: Sequence[str] = ["da"], short_long_sentence_length_split: int = 30, short_long_sentence_threshold: float = 0.5, ): - + if ignore_filters is None: + ignore_filters = [] + if duplicate_n_gram_fraction_thresholds is None: + duplicate_n_gram_fraction_thresholds = [0.25, 0.24, 0.23, 0.22, 0.21, 0.2] + if top_ngram_chr_fraction_thresholds is None: + top_ngram_chr_fraction_thresholds = [0.2, 0.18, 0.16] __available_language_detection_tools = ["langdetect", "luga"] if language_detection_tool not in __available_language_detection_tools: raise AttributeError( f"{language_detection_tool} is not a valid language detection " - f"packages - must be in {__available_language_detection_tools}" + f"packages - must be in {__available_language_detection_tools}", ) # Load Danish spaCy model @@ -291,15 +293,20 @@ def __init__( self.filters = { "doc_length": partial(self.doc_length, doc_length=doc_length), "mean_word_length": partial( - self.mean_word_length, mean_word_length=mean_word_length + self.mean_word_length, + mean_word_length=mean_word_length, ), "alpha_ratio": partial(self.alpha, ratio=alpha_ratio), "stop_word": partial(self.stop_word, n=min_stop_words), "symbol_2_word_hashtag": partial( - self.symbol_2_word, ratio=symbol_2_word_hashtag, symbol="#" + self.symbol_2_word, + ratio=symbol_2_word_hashtag, + symbol="#", ), "symbol_2_word_ellipsis": partial( - self.symbol_2_word, ratio=symbol_2_word_ellipsis, symbol="…" + self.symbol_2_word, + ratio=symbol_2_word_ellipsis, + symbol="…", ), "line_bullets_or_ellipsis": partial( self.line_bullets_or_ellipsis, @@ -309,7 +316,8 @@ def __init__( min_ellipsis=min_ellipsis, ), "duplicate_lines_chr_fraction": partial( - self.duplicate_lines_chr_filter, fraction=duplicate_lines_chr_fraction + self.duplicate_lines_chr_filter, + fraction=duplicate_lines_chr_fraction, ), "duplicate_paragraph_chr_fraction": partial( self.duplicate_paragraph_chr_fraction_filter, @@ -341,7 +349,8 @@ def __init__( if string_filter: self.filters["string_filter"] = partial( - self.string_filter, string=string_filter + self.string_filter, + string=string_filter, ) for f in ignore_filters: @@ -369,7 +378,8 @@ def __set_extensions(self) -> None: set_dynamic_ext("lines_counter", func=lambda doc: Counter(doc._.lines)) set_dynamic_ext( - "paragraphs_counter", func=lambda doc: Counter(doc._.paragraphs) + "paragraphs_counter", + func=lambda doc: Counter(doc._.paragraphs), ) set_dynamic_ext( @@ -394,8 +404,11 @@ def __set_extensions(self) -> None: set_dynamic_ext("chr_len", func=lambda doc: len(doc.text)) def filter_corpus( - self, texts: Iterable[str], as_tuples: bool = False, **kwargs - ) -> Union[Iterable[str], Iterable[Tuple[str, Union[Any, None]]]]: + self, + texts: Iterable[str], + as_tuples: bool = False, + **kwargs, + ) -> Union[Iterable[str], Iterable[tuple[str, Union[Any, None]]]]: """Applies quality filter. Args: @@ -441,8 +454,10 @@ def filter_corpus( break def __call__( - self, *args, **kwargs - ) -> Union[Iterable[str], Iterable[Tuple[str, Union[Any, None]]]]: + self, + *args, + **kwargs, + ) -> Union[Iterable[str], Iterable[tuple[str, Union[Any, None]]]]: """Applies quality filter. Args: @@ -474,6 +489,7 @@ def is_filtered(self, doc: Doc) -> Optional[str]: # log filtered documents self.filtered[filter] += 1 return filter + return None def describe_filter(self, texts: Iterable[tuple], **kwargs) -> Iterable[str]: """ @@ -563,7 +579,9 @@ def detect_language( """ def luga_detect( - doc: Doc, languages: Sequence[str], language_threshold: float + doc: Doc, + languages: Sequence[str], + language_threshold: float, ) -> bool: from luga import language @@ -572,17 +590,16 @@ def luga_detect( sentence.strip() for sentence in doc.text.split("\n") if len(sentence.strip()) > 0 - ] + ], ) detected = language(doc_joined) # type: ignore lang, score = detected.name, detected.score - if score >= language_threshold and lang in languages: - return True - else: - return False + return bool(score >= language_threshold and lang in languages) def langdetect_detect( - doc: Doc, languages: Sequence[str], language_threshold: float + doc: Doc, + languages: Sequence[str], + language_threshold: float, ) -> bool: from langdetect import detect_langs @@ -592,7 +609,7 @@ def langdetect_detect( return True return False - detectors: Dict[str, Callable[[Doc, Sequence[str], float], bool]] = { + detectors: dict[str, Callable[[Doc, Sequence[str], float], bool]] = { "luga": luga_detect, "langdetect": langdetect_detect, } @@ -600,7 +617,7 @@ def langdetect_detect( return detectors[language_detection_tool](doc, languages, language_threshold) @staticmethod - def doc_length(doc: Doc, doc_length: Tuple[int, int]) -> bool: + def doc_length(doc: Doc, doc_length: tuple[int, int]) -> bool: """ A filter that removes any document that does not contain between {doc_length[0]} and {doc_length[1]} words @@ -616,7 +633,7 @@ def doc_length(doc: Doc, doc_length: Tuple[int, int]) -> bool: return doc_length[0] <= doc._.n_words <= doc_length[1] @staticmethod - def mean_word_length(doc: Doc, mean_word_length: Tuple[int, int]) -> bool: + def mean_word_length(doc: Doc, mean_word_length: tuple[int, int]) -> bool: """ Filter document whose mean word length is outside the range of {mean_word_length[0]} to {mean_word_length[1]} characters @@ -652,10 +669,7 @@ def alpha(doc: Doc, ratio: float) -> bool: """ def contains_alpha_fn(token: str): - for c in token: - if c.isalpha(): - return True - return False + return any(c.isalpha() for c in token) # min number of word to satisfy the ratio min_alpha_token = int(doc._.n_words * ratio) @@ -770,8 +784,8 @@ def string_filter(doc: Doc, string: Optional[str] = None) -> bool: @staticmethod def duplicate_ngram_fraction_filter( doc: Doc, - ngram_range: Tuple[int, int], - thresholds: List[float], + ngram_range: tuple[int, int], + thresholds: list[float], ) -> bool: """calculates the character fraction of duplicate n-gram over the overall text, taking care not to count overlapping n-grams twice. @@ -802,7 +816,6 @@ def duplicate_ngram_fraction_filter( for i, _ in enumerate(doc): for ngram_size in range(lower, upper + 1): - min_, max_ = minmax[ngram_size] end = i + ngram_size @@ -840,8 +853,8 @@ def duplicate_ngram_fraction_filter( @staticmethod def top_ngram_chr_fraction_filter( doc: Doc, - ngram_range: Tuple[int, int], - thresholds: List[float], + ngram_range: tuple[int, int], + thresholds: list[float], min_count: int, ) -> bool: """Calculated whether the character fraction of the top n-grams is below the @@ -949,7 +962,7 @@ def duplicate_line_filter(doc: Doc, fraction: float) -> bool: t0 = time.time() # Filter the texts - for doc in filtered_docs: + for _doc in filtered_docs: pass # Record the time taken diff --git a/archive_v1/src/dfm/cleaning/sentence_filter.py b/archive_v1/src/dfm/cleaning/sentence_filter.py index d3f0a2f0..797753c5 100644 --- a/archive_v1/src/dfm/cleaning/sentence_filter.py +++ b/archive_v1/src/dfm/cleaning/sentence_filter.py @@ -12,7 +12,8 @@ import multiprocessing as mp from collections import Counter -from typing import Any, Callable, Dict, Iterable, Optional, Sequence, Tuple, Union +from collections.abc import Iterable, Sequence +from typing import Any, Callable, Optional, Union import emoji from joblib import Parallel, delayed @@ -72,7 +73,6 @@ def __init__( curly_brackets_threshold: int = 2, n_jobs: int = -1, ): - # Store arguments as attributes self.title_cased_words_threshold = title_cased_words_threshold self.min_num_words = min_num_words @@ -80,15 +80,15 @@ def __init__( self.n_jobs = n_jobs # Create a dictionary with all the sentence filters - _all_filters: Dict[str, Callable[[str], bool]] = dict( - ends_with_punctuation_or_emoji=self._ends_with_punctuation_or_emoji, - has_few_title_cased_words=self._has_few_title_cased_words, - has_enough_words=self._has_enough_words, - has_few_curly_brackets=self._has_few_curly_brackets, - ) + _all_filters: dict[str, Callable[[str], bool]] = { + "ends_with_punctuation_or_emoji": self._ends_with_punctuation_or_emoji, + "has_few_title_cased_words": self._has_few_title_cased_words, + "has_enough_words": self._has_enough_words, + "has_few_curly_brackets": self._has_few_curly_brackets, + } # Create variable storing the filters to be used - self.filters: Dict[str, Callable[[str], bool]] = dict() + self.filters: dict[str, Callable[[str], bool]] = {} if filter_names is None: self.filters = _all_filters else: @@ -99,12 +99,12 @@ def __init__( # Create a counter for keeping track of how many documents each filter removed self.filter_counts = Counter() - def filter_corpus( # noqa: C901 + def filter_corpus( self, - texts: Union[Iterable[str], Iterable[Tuple[str, Optional[Any]]]], + texts: Union[Iterable[str], Iterable[tuple[str, Optional[Any]]]], progress_bar: bool = True, total: Optional[int] = None, - ) -> Union[Iterable[str], Iterable[Tuple[str, Union[Any, None]]]]: + ) -> Union[Iterable[str], Iterable[tuple[str, Union[Any, None]]]]: """Filters a corpus using the sentence filters. Args: @@ -127,8 +127,8 @@ def filter_corpus( # noqa: C901 docs = iter(texts) def filter_sample( - sample: Union[str, Tuple[str, Optional[Any]]] - ) -> Union[str, Tuple[str, Optional[Any]]]: + sample: Union[str, tuple[str, Optional[Any]]], + ) -> Union[str, tuple[str, Optional[Any]]]: """Filter a sample. Args: @@ -148,7 +148,7 @@ def filter_sample( context = None else: raise TypeError( - f"Expected either a string or a tuple, got {type(sample)}." + f"Expected either a string or a tuple, got {type(sample)}.", ) # Split the document into sentences, splitting on newlines @@ -177,16 +177,12 @@ def filter_sample( return new_doc # Main filtering loop - if self.n_jobs == -1: - n_jobs = mp.cpu_count() - 1 - else: - n_jobs = self.n_jobs + n_jobs = mp.cpu_count() - 1 if self.n_jobs == -1 else self.n_jobs if n_jobs == 1: for doc in docs: yield filter_sample(doc) else: with Parallel(n_jobs=n_jobs, backend="threading") as parallel: - # Set up iterator, depending on whether we have a progress bar or not if progress_bar: itr = tqdm(docs, desc="Filtering corpus", total=total) @@ -202,8 +198,10 @@ def filter_sample( itr.close() def __call__( - self, *args, **kwargs - ) -> Union[Iterable[str], Iterable[Tuple[str, Union[Any, None]]]]: + self, + *args, + **kwargs, + ) -> Union[Iterable[str], Iterable[tuple[str, Union[Any, None]]]]: """Calls the `filter_corpus` method on the inputs. Args: @@ -233,7 +231,6 @@ def apply_filters(self, doc: str) -> Union[str, None]: """ # Iterate over all the filter functions for filter_name, filter_fn in self.filters.items(): - # Apply the filter function, which returns True if the document satisfied # the filter, and False if it didn't satisfied_filter = filter_fn(doc) @@ -245,6 +242,7 @@ def apply_filters(self, doc: str) -> Union[str, None]: if not satisfied_filter: self.filter_counts[filter_name] += 1 return filter_name + return None def _ends_with_punctuation_or_emoji(self, sentence: str) -> bool: """Checks if a sentence ends with punctuation or an emoji. @@ -258,7 +256,7 @@ def _ends_with_punctuation_or_emoji(self, sentence: str) -> bool: True if the sentence ends with punctuation, False otherwise. """ # Initialise the list of emojis - emojis = list() + emojis = [] # Add all unicode emojis as well as the codes for them, like :blush:. The codes # are the keys in the `emoji.UNICODE_EMOJI["en"]` dictionary, and the emojis @@ -371,7 +369,7 @@ def _has_few_curly_brackets(self, sentence: str) -> bool: t0 = time.time() # Filter the texts - for doc in filtered_docs: + for _doc in filtered_docs: pass # Record the time taken diff --git a/archive_v1/src/dfm/data/load_datasets.py b/archive_v1/src/dfm/data/load_datasets.py index ee571256..23807417 100644 --- a/archive_v1/src/dfm/data/load_datasets.py +++ b/archive_v1/src/dfm/data/load_datasets.py @@ -2,9 +2,10 @@ Datasets loaders for DFM datasets. """ +from collections.abc import Iterable from functools import partial from pathlib import Path -from typing import Dict, Iterable, List, Optional, Union +from typing import Optional, Union from datasets import DatasetDict, IterableDatasetDict, interleave_datasets, load_dataset @@ -20,7 +21,8 @@ def __add_column(example, value, column: str): def __select_columns( - dataset: Union[IterableDatasetDict, DatasetDict], columns: List[str] + dataset: Union[IterableDatasetDict, DatasetDict], + columns: list[str], ) -> Union[IterableDatasetDict, DatasetDict]: """Select columns in dataset to keep and removes the rest. @@ -34,16 +36,16 @@ def __select_columns( desired columns. """ # extract a sample from a subset (typically train) to get column names. - subset = dataset[list(dataset.keys())[0]] + subset = dataset[next(iter(dataset.keys()))] sample = next(iter(subset)) - col_to_remove = [c_name for c_name in sample.keys() if c_name not in columns] + col_to_remove = [c_name for c_name in sample if c_name not in columns] return dataset.remove_columns(col_to_remove) def load_hopetwitter( path_to_hopetwitter: Union[str, Path] = HOPETWITTER_PATH, - columns_to_keep: Optional[List[str]] = None, + columns_to_keep: Optional[list[str]] = None, n_training_repeats: int = 1, **kwargs, ) -> IterableDatasetDict: @@ -94,7 +96,7 @@ def load_hopetwitter( def load_dagw_dfm( path_to_dagw: Union[str, Path] = DAGW_DFM_PATH, - columns_to_keep: Optional[List[str]] = None, + columns_to_keep: Optional[list[str]] = None, n_training_repeats: int = 1, **kwargs, ) -> IterableDatasetDict: @@ -133,7 +135,7 @@ def load_dagw_dfm( def load_danews( path_to_danews: Union[str, Path] = DANEWS_PATH, - columns_to_keep: Optional[List[str]] = None, + columns_to_keep: Optional[list[str]] = None, n_training_repeats: int = 1, **kwargs, ) -> IterableDatasetDict: @@ -177,8 +179,8 @@ def load_danews( def load_nat( path_to_nat: Union[str, Path] = NAT_PATH, years: Iterable[int] = range(2006, 2017), - probabilities: Optional[List[float]] = None, - columns_to_keep: Optional[List[str]] = None, + probabilities: Optional[list[float]] = None, + columns_to_keep: Optional[list[str]] = None, n_training_repeats: int = 10, seed: Optional[int] = None, **kwargs, @@ -234,23 +236,13 @@ def load_nat( def load_dcc( version: str = "1.0.0", - probabilities: Dict[str, float] = { - "danews": 0.06, - "dagw_dfm": 0.06, - "hopetwitter": 0.03, - "nat": 0.85, - }, - n_training_repeats: Dict[str, int] = { - "danews": 1_000, - "dagw_dfm": 1_000, - "hopetwitter": 1_000, - "nat": 100, - }, + probabilities: Optional[dict[str, float]] = None, + n_training_repeats: Optional[dict[str, int]] = None, path_to_hopetwitter: Union[str, Path] = HOPETWITTER_PATH, path_to_dagw: Union[str, Path] = DAGW_DFM_PATH, path_to_danews: Union[str, Path] = DANEWS_PATH, path_to_nat: Union[str, Path] = NAT_PATH, - columns_to_keep: List[str] = ["text", "source"], + columns_to_keep: Optional[list[str]] = None, **kwargs, ): """ @@ -276,11 +268,27 @@ def load_dcc( Returns: IterableDatasetDict: A datasets IterableDatasetDict """ + if columns_to_keep is None: + columns_to_keep = ["text", "source"] + if n_training_repeats is None: + n_training_repeats = { + "danews": 1000, + "dagw_dfm": 1000, + "hopetwitter": 1000, + "nat": 100, + } + if probabilities is None: + probabilities = { + "danews": 0.06, + "dagw_dfm": 0.06, + "hopetwitter": 0.03, + "nat": 0.85, + } versions_options = ["1.0.0"] if version != "1.0.0": raise ValueError( "Version {version} is not available. Available versions" - + f": {versions_options}" + + f": {versions_options}", ) datasets = {} @@ -346,5 +354,5 @@ def load_dcc( "train": train, "validation": dcc_test_val["validation"], "test": dcc_test_val["test"], - } + }, ) diff --git a/archive_v1/src/dfm/dataset_validation/rating_interface.py b/archive_v1/src/dfm/dataset_validation/rating_interface.py index 2cede005..31f31a19 100644 --- a/archive_v1/src/dfm/dataset_validation/rating_interface.py +++ b/archive_v1/src/dfm/dataset_validation/rating_interface.py @@ -1,6 +1,6 @@ import curses import os -from typing import Iterable +from collections.abc import Iterable class ExampleRater: @@ -61,13 +61,13 @@ def main_window(self, win, example_str, is_porn=False, is_offensive=False): win.addstr("Tags (non-exclusive):\n") win.addstr(f"{left_spacing*2}[P]orn: {self.sign_from_bool(is_porn)} \n") win.addstr( - f"{left_spacing*2}[O]ffensive: {self.sign_from_bool(is_offensive)}" + f"{left_spacing*2}[O]ffensive: {self.sign_from_bool(is_offensive)}", ) win.addstr("\n" * 2) win.addstr("Category (exclusive):\n") win.addstr( - f"{left_spacing*2}[N]ot language | [W]rong language | [C]orrect language | [S]kip \n" + f"{left_spacing*2}[N]ot language | [W]rong language | [C]orrect language | [S]kip \n", ) win.addstr(f"{left_spacing*2}[U]ndo") win.addstr("\n" * 2) @@ -104,11 +104,11 @@ def undo_one_item(self, current_item_str, is_porn, is_offensive): "example_str": current_item_str, "is_porn": is_porn, "is_offensive": is_offensive, - } + }, ) # Remove the line from the .c - with open(self.output_path, "r") as f: + with open(self.output_path) as f: lines = f.readlines() lines_to_write = lines[:-1] with open(self.output_path, "w") as f: @@ -131,7 +131,7 @@ def process_example(self, category, is_porn, is_offensive, example_str): "example_str": example_str, "is_porn": is_porn, "is_offensive": is_offensive, - } + }, ) self.write_example_to_csv( diff --git a/archive_v1/src/dfm/description/description_patterns.py b/archive_v1/src/dfm/description/description_patterns.py index ed4107a9..f77ca6e3 100644 --- a/archive_v1/src/dfm/description/description_patterns.py +++ b/archive_v1/src/dfm/description/description_patterns.py @@ -1,5 +1,3 @@ -from typing import Dict, List - from .match_counter import MatchCounter # Terms for religions are: @@ -17,7 +15,7 @@ "kristendomme", "kristendommen", "kristendommene", - ] + ], }, { "jew": [ @@ -30,7 +28,7 @@ "jødedommen", "jødedomme", "jødedommene", - ] + ], }, { "buddhist": [ @@ -43,7 +41,7 @@ "buddhismen", "buddhismer", "buddhismerne", - ] + ], }, { "hindu": [ @@ -60,7 +58,7 @@ "hinduismen", "hinduismer", "hinduismerne", - ] + ], }, { "atheist": [ @@ -73,14 +71,15 @@ "ahteismen", "atheismer", "atheismerne", - ] + ], }, ] def get_religion_patterns(): return MatchCounter.list_of_labelled_term_lists_to_spacy_match_patterns( - religion_labelled_match_patterns, label_prefix="rel_" + religion_labelled_match_patterns, + label_prefix="rel_", ) @@ -99,7 +98,7 @@ def get_religion_patterns(): "administratoren", "administratorer", "administratorerne", - ] + ], }, { "analytiker": [ @@ -107,7 +106,7 @@ def get_religion_patterns(): "analytikeren", "analytikere", "analytikerne", - ] + ], }, { "arkitekt": [ @@ -115,7 +114,7 @@ def get_religion_patterns(): "arkitekten", "arkitekter", "arkitekterne", - ] + ], }, { "assistent": [ @@ -123,7 +122,7 @@ def get_religion_patterns(): "assistenten", "assistenter", "assistenterne", - ] + ], }, { "bager": [ @@ -131,7 +130,7 @@ def get_religion_patterns(): "bageren", "bagere", "bagerne", - ] + ], }, { "bartender": [ @@ -139,7 +138,7 @@ def get_religion_patterns(): "bartenderen", "bartendere", "bartenderne", - ] + ], }, { "ejendomsmægler": [ @@ -147,7 +146,7 @@ def get_religion_patterns(): "ejendomsmægleren", "ejendomsmæglere", "ejendomsmæglerne", - ] + ], }, { "tømrer": [ @@ -155,7 +154,7 @@ def get_religion_patterns(): "tømreren", "tømrere", "tømrerne", - ] + ], }, { "kassemedarbejder": [ @@ -163,7 +162,7 @@ def get_religion_patterns(): "kassemedearbejderen", "kassemedarbejdere", "kassemedarbejderne", - ] + ], }, { "kok": [ @@ -171,7 +170,7 @@ def get_religion_patterns(): "kokken", "kokke", "kokkene", - ] + ], }, { "kemiker": [ @@ -179,7 +178,7 @@ def get_religion_patterns(): "kemikeren", "kemikere", "kemikerne", - ] + ], }, { "chef": [ @@ -187,7 +186,7 @@ def get_religion_patterns(): "chefen", "chefer", "cheferne", - ] + ], }, { "rengøringshjælp": [ @@ -195,7 +194,7 @@ def get_religion_patterns(): "rengøringshjælpen", "rengøringshjælpere", "rengøringshjælperne", - ] + ], }, { "ekspedient": [ @@ -203,7 +202,7 @@ def get_religion_patterns(): "ekspedienten", "ekspedienter", "ekspedienterne", - ] + ], }, { "terapeut": [ @@ -211,7 +210,7 @@ def get_religion_patterns(): "terapeuten", "terapeuter", "terapeuterne", - ] + ], }, { "advokat": [ @@ -219,7 +218,7 @@ def get_religion_patterns(): "advokaten", "advokater", "advokaterne", - ] + ], }, { "diætist": [ @@ -227,7 +226,7 @@ def get_religion_patterns(): "diætisten", "diætister", "diætisterne", - ] + ], }, { "læge": [ @@ -235,7 +234,7 @@ def get_religion_patterns(): "lægen", "læger", "lægerne", - ] + ], }, { "chauffør": [ @@ -243,7 +242,7 @@ def get_religion_patterns(): "chaufføren", "chauffører", "chaufførerne", - ] + ], }, { "redaktør": [ @@ -251,7 +250,7 @@ def get_religion_patterns(): "redatøren", "redaktører", "redaktørerne", - ] + ], }, { "elektriker": [ @@ -259,7 +258,7 @@ def get_religion_patterns(): "elektrikeren", "elektrikere", "elektrikerne", - ] + ], }, { "ingeniør": [ @@ -267,7 +266,7 @@ def get_religion_patterns(): "ingeniøren", "ingeniører", "ingeniørerne", - ] + ], }, { "landmand": [ @@ -275,7 +274,7 @@ def get_religion_patterns(): "landmanden", "landmænd", "landmændene", - ] + ], }, { "brandmand": [ @@ -283,7 +282,7 @@ def get_religion_patterns(): "brandmanden", "brandmænd", "brandmændene", - ] + ], }, { "vagt": [ @@ -291,7 +290,7 @@ def get_religion_patterns(): "vagten", "vagter", "vagterne", - ] + ], }, { "frisør": [ @@ -299,7 +298,7 @@ def get_religion_patterns(): "frisøren", "frisører", "frisørerne", - ] + ], }, { "instruktør": [ @@ -307,7 +306,7 @@ def get_religion_patterns(): "instruktøren", "instruktører", "instruktørerne", - ] + ], }, { "efterforsker": [ @@ -315,7 +314,7 @@ def get_religion_patterns(): "efterforskeren", "efterforskere", "efterforskerne", - ] + ], }, { "pedel": [ @@ -323,7 +322,7 @@ def get_religion_patterns(): "pedellen", "pedeller", "pedellerne", - ] + ], }, { "advokat": [ @@ -331,7 +330,7 @@ def get_religion_patterns(): "advokaten", "advokater", "advokaterne", - ] + ], }, { "biliotekar": [ @@ -339,7 +338,7 @@ def get_religion_patterns(): "bibliotekaren", "bibliotekarer", "bibliotekarerne", - ] + ], }, { "mekaniker": [ @@ -347,7 +346,7 @@ def get_religion_patterns(): "makanikeren", "mekanikere", "mekanikerne", - ] + ], }, { "sygeplejerske": [ @@ -355,7 +354,7 @@ def get_religion_patterns(): "sygeplersken", "sygeplejersker", "sygeplejeskerne", - ] + ], }, { "politibetjent": [ @@ -363,7 +362,7 @@ def get_religion_patterns(): "politibetjenten", "politibetjente", "politibetjentene", - ] + ], }, { "maler": [ @@ -371,7 +370,7 @@ def get_religion_patterns(): "maleren", "malerne", "malere", - ] + ], }, { "ambulanceredder": [ @@ -379,7 +378,7 @@ def get_religion_patterns(): "ambulanceredderen", "ambulancereddere", "ambulanceredderne", - ] + ], }, { "ambulancebehandler": [ @@ -387,7 +386,7 @@ def get_religion_patterns(): "ambulancebehandleren", "ambulancebehandlere", "ambulancebehandlerne", - ] + ], }, { "patolog": [ @@ -395,7 +394,7 @@ def get_religion_patterns(): "patologen", "patologer", "patologerne", - ] + ], }, { "farmaceut": [ @@ -403,7 +402,7 @@ def get_religion_patterns(): "farmaceuten", "farmaceuter", "farmaceuterne", - ] + ], }, { "blikkenslager": [ @@ -411,7 +410,7 @@ def get_religion_patterns(): "blikkenslageren", "blikkenslagere", "blikkenslagerne", - ] + ], }, { "programmør": [ @@ -419,7 +418,7 @@ def get_religion_patterns(): "programmøren", "programmører", "programmørerne", - ] + ], }, { "psykolog": [ @@ -427,7 +426,7 @@ def get_religion_patterns(): "psykologen", "psykologer", "psykologerne", - ] + ], }, { "receptionist": [ @@ -435,7 +434,7 @@ def get_religion_patterns(): "receptionisten", "receptionister", "receptionisterne", - ] + ], }, { "sekretær": [ @@ -443,7 +442,7 @@ def get_religion_patterns(): "sekretæren", "sekretærer", "sekretærerne", - ] + ], }, { "kirurg": [ @@ -451,7 +450,7 @@ def get_religion_patterns(): "kirurgen", "kirurger", "kirurgerne", - ] + ], }, { "skrædder": [ @@ -459,7 +458,7 @@ def get_religion_patterns(): "skrædderen", "skræddere", "skrædderne", - ] + ], }, { "tekniker": [ @@ -467,7 +466,7 @@ def get_religion_patterns(): "teknikeren", "teknikere", "teknikerne", - ] + ], }, { "terapeut": [ @@ -475,7 +474,7 @@ def get_religion_patterns(): "terapeuten", "terapeuter", "terapeuterne", - ] + ], }, { "dyrlæge": [ @@ -483,7 +482,7 @@ def get_religion_patterns(): "dyrlægen", "dyrlæger", "dyrlægerne", - ] + ], }, { "forfatter": [ @@ -491,311 +490,309 @@ def get_religion_patterns(): "forfatteren", "forfattere", "forfatterne", - ] + ], }, ] def get_occupation_patterns(): return MatchCounter.list_of_labelled_term_lists_to_spacy_match_patterns( - occupation_labelled_match_patterns, label_prefix="occu_" + occupation_labelled_match_patterns, + label_prefix="occu_", ) # List is a partial translation of Rae et al. 2022, p. 95 -female_gendered_terms = set( - [ - "pige", - "pigen", - "piger", - "pigerne", - "søster", - "søsteren", - "søstere", - "søsterne", - "mor", - "moren", - "mødre", - "mødrene", - "kone", - "konen", - "koner", - "konerne", - "brud", - "bruden", - "brude", - "brudene", - "dame", - "damen", - "damer", - "damerne", - "datter", - "datteren", - "døtre", - "døtrene", - ] -) +female_gendered_terms = { + "pige", + "pigen", + "piger", + "pigerne", + "søster", + "søsteren", + "søstere", + "søsterne", + "mor", + "moren", + "mødre", + "mødrene", + "kone", + "konen", + "koner", + "konerne", + "brud", + "bruden", + "brude", + "brudene", + "dame", + "damen", + "damer", + "damerne", + "datter", + "datteren", + "døtre", + "døtrene", +} def get_female_gendered_patterns(): return MatchCounter.term_list_to_spacy_match_patterns( - female_gendered_terms, label="gender_female_terms" + female_gendered_terms, + label="gender_female_terms", ) -male_gendered_terms = set( - [ - "dreng", - "drengen", - "drenge", - "drengene", - "bror", - "broren", - "brødre", - "brødrene", - "far", - "faren", - "fædre", - "fædrene", - "mand", - "manden", - "mænd", - "mændene", - "brudgom", - "brudgommen", - "brudgomme", - "brudgommene", - "herre", - "herren", - "herrer", - "herrerne", - "søn", - "sønnen", - "sønner", - "sønnerne", - ] -) +male_gendered_terms = { + "dreng", + "drengen", + "drenge", + "drengene", + "bror", + "broren", + "brødre", + "brødrene", + "far", + "faren", + "fædre", + "fædrene", + "mand", + "manden", + "mænd", + "mændene", + "brudgom", + "brudgommen", + "brudgomme", + "brudgommene", + "herre", + "herren", + "herrer", + "herrerne", + "søn", + "sønnen", + "sønner", + "sønnerne", +} def get_male_gendered_patterns(): return MatchCounter.term_list_to_spacy_match_patterns( - male_gendered_terms, label="gender_male_terms" + male_gendered_terms, + label="gender_male_terms", ) -danish_adult_words = set( - [ - "amatør", - "anal", - "anus", - "babes", - "bdsm", - "begær", - "bestialitet", - "blodig", - "blowjob", - "bordel", - "bordeller", - "bryster", - "bøsse", - "bøssefilm", - "c-skål", - "damer", - "dating", - "dildo", - "dildoer", - "dildomaskine", - "dyrisk", - "ejakulation", - "ejakulere", - "ejakulerede", - "ejakulerer", - "elskerinde", - "endetarm", - "erotik", - "erotisk", - "erotiske", - "escort", - "escortpige", - "escortpiger", - "escortpigerne", - "fanden", - "fisse", - "fisser", - "fræk", - "frække", - "frækt", - "fucked", - "fucker", - "gangbang", - "gay", - "hardcore", - "hentai", - "homo", - "hore", - "intim", - "intime", - "kinky", - "klitoris", - "kneppe", - "kusse", - "kvinder", - "latex", - "latino", - "lesbisk", - "liderlig", - "liderlige", - "lort", - "lorte", - "luder", - "masochist", - "massage", - "massageescort", - "massageklinik", - "massagen", - "massagepige", - "massagepiger", - "massagepigerne", - "milf", - "nigger", - "niggere", - "nøgenbillede", - "nøgenbilleder", - "nøgenbillederne", - "nøgne", - "onanere", - "orgasme", - "orgasmer", - "patter", - "pecker", - "penis", - "piger", - "pigesex", - "pik", - "pis", - "pisse", - "pisser", - "pisses", - "porn", - "porno", - "porno-casting", - "pornofilm", - "pornografi", - "pornostar", - "pornostjerne", - "pornostjernen", - "pornostjerner", - "prostitueret", - "røv", - "røvhul", - "røvhuller", - "sadist", - "samleje", - "sex", - "sexcam", - "sexdating", - "sexdatingsites", - "sexfilm", - "sexfoto", - "sexhistorier", - "sexparadis", - "sexshop", - "sexstillinger", - "sexvideo", - "sexvideoer", - "sexvideoerne", - "sexvideoen", - "sexy", - "shemale", - "shemales", - "sjofelhed", - "sjofelheder", - "sjofelhederne", - "skamlæber", - "skider", - "sluger", - "sm", - "spanking", - "sprække", - "sprøjteorgasme", - "sprøjteorgasmer", - "sprøjteorgasmen", - "sprøjteorgasmerne", - "strip", - "svans", - "swinger", - "swingerdating", - "swingerklub", - "sæd", - "sædafgang", - "tantra", - "telefonsex", - "testikel", - "thai", - "thaimassage", - "thaipiger", - "tranny", - "tæve", - "tæver", - "tøs", - "tøser", - "urin", - "vagina", - "vaginaen", - "viagra", - "viagraen", - "voldtage", - "voldtager", - "voldtægt" "vulva", - "webcam", - "webcam-chat", - "x-bedømt", - "xxx", - ] -) - - -def get_muslim_name_patterns() -> List[Dict[str, list]]: +danish_adult_words = { + "amatør", + "anal", + "anus", + "babes", + "bdsm", + "begær", + "bestialitet", + "blodig", + "blowjob", + "bordel", + "bordeller", + "bryster", + "bøsse", + "bøssefilm", + "c-skål", + "damer", + "dating", + "dildo", + "dildoer", + "dildomaskine", + "dyrisk", + "ejakulation", + "ejakulere", + "ejakulerede", + "ejakulerer", + "elskerinde", + "endetarm", + "erotik", + "erotisk", + "erotiske", + "escort", + "escortpige", + "escortpiger", + "escortpigerne", + "fanden", + "fisse", + "fisser", + "fræk", + "frække", + "frækt", + "fucked", + "fucker", + "gangbang", + "gay", + "hardcore", + "hentai", + "homo", + "hore", + "intim", + "intime", + "kinky", + "klitoris", + "kneppe", + "kusse", + "kvinder", + "latex", + "latino", + "lesbisk", + "liderlig", + "liderlige", + "lort", + "lorte", + "luder", + "masochist", + "massage", + "massageescort", + "massageklinik", + "massagen", + "massagepige", + "massagepiger", + "massagepigerne", + "milf", + "nigger", + "niggere", + "nøgenbillede", + "nøgenbilleder", + "nøgenbillederne", + "nøgne", + "onanere", + "orgasme", + "orgasmer", + "patter", + "pecker", + "penis", + "piger", + "pigesex", + "pik", + "pis", + "pisse", + "pisser", + "pisses", + "porn", + "porno", + "porno-casting", + "pornofilm", + "pornografi", + "pornostar", + "pornostjerne", + "pornostjernen", + "pornostjerner", + "prostitueret", + "røv", + "røvhul", + "røvhuller", + "sadist", + "samleje", + "sex", + "sexcam", + "sexdating", + "sexdatingsites", + "sexfilm", + "sexfoto", + "sexhistorier", + "sexparadis", + "sexshop", + "sexstillinger", + "sexvideo", + "sexvideoer", + "sexvideoerne", + "sexvideoen", + "sexy", + "shemale", + "shemales", + "sjofelhed", + "sjofelheder", + "sjofelhederne", + "skamlæber", + "skider", + "sluger", + "sm", + "spanking", + "sprække", + "sprøjteorgasme", + "sprøjteorgasmer", + "sprøjteorgasmen", + "sprøjteorgasmerne", + "strip", + "svans", + "swinger", + "swingerdating", + "swingerklub", + "sæd", + "sædafgang", + "tantra", + "telefonsex", + "testikel", + "thai", + "thaimassage", + "thaipiger", + "tranny", + "tæve", + "tæver", + "tøs", + "tøser", + "urin", + "vagina", + "vaginaen", + "viagra", + "viagraen", + "voldtage", + "voldtager", + "voldtægt" "vulva", + "webcam", + "webcam-chat", + "x-bedømt", + "xxx", +} + + +def get_muslim_name_patterns() -> list[dict[str, list]]: """Gets a list of all muslim first names in Denmark from DaCy, and converts to a list of lowercase spacy patterns. Returns: List[Dict[str, list]]: list of lowercase spacy match patterns """ from dacy.datasets import muslim_names - from dfm.description.match_counter import MatchCounter muslim_names_list = [name.lower() for name in muslim_names()["first_name"]] return MatchCounter.term_list_to_spacy_match_patterns( - term_list=muslim_names_list, label="rel_muslim_names" + term_list=muslim_names_list, + label="rel_muslim_names", ) -def get_gender_name_patterns() -> List[Dict[str, list]]: +def get_gender_name_patterns() -> list[dict[str, list]]: """Gets a list of all gendered first names in Denmark from DaCy, and converts to a list of lowercase spacy patterns. Returns: List[Dict[str, list]]: list of lowercase spacy match patterns """ from dacy.datasets import female_names, male_names - from dfm.description.match_counter import MatchCounter female_names_list = [name.lower() for name in female_names()["first_name"]] female_names_patterns = MatchCounter.term_list_to_spacy_match_patterns( - female_names_list, label="gender_female_names" + female_names_list, + label="gender_female_names", ) male_names_list = [name.lower() for name in male_names()["first_name"]] male_name_patterns = MatchCounter.term_list_to_spacy_match_patterns( - male_names_list, label="gender_male_names" + male_names_list, + label="gender_male_names", ) return female_names_patterns + male_name_patterns -def get_positive_word_patterns() -> List[Dict[str, list]]: +def get_positive_word_patterns() -> list[dict[str, list]]: """Loads a list of word- and sentiment pairs from "da_lexicon_afinn_v1.txt", splits it by tabs, then sorts them into lists depending on whether the sentiment is positive or negative. Returns: @@ -807,7 +804,7 @@ def get_positive_word_patterns() -> List[Dict[str, list]]: path = pathlib.Path(__file__).parent / "da_lexicon_afinn_v1.txt" - with open(path, "r") as f: + with open(path) as f: lines = f.readlines() lines = [line.split("\t") for line in lines] @@ -815,13 +812,14 @@ def get_positive_word_patterns() -> List[Dict[str, list]]: positive_words = [line[0] for line in lines if int(line[1]) > 0] positive_patterns = MatchCounter.term_list_to_spacy_match_patterns( - positive_words, label="positive_words" + positive_words, + label="positive_words", ) return positive_patterns -def get_negative_word_patterns() -> List[Dict[str, list]]: +def get_negative_word_patterns() -> list[dict[str, list]]: """Loads a list of word- and sentiment pairs from "da_lexicon_afinn_v1.txt", splits it by tabs, then sorts them into lists depending on whether the sentiment is positive or negative. Returns: @@ -833,7 +831,7 @@ def get_negative_word_patterns() -> List[Dict[str, list]]: path = pathlib.Path(__file__).parent / "da_lexicon_afinn_v1.txt" - with open(path, "r") as f: + with open(path) as f: lines = f.readlines() lines = [line.split("\t") for line in lines] @@ -841,7 +839,8 @@ def get_negative_word_patterns() -> List[Dict[str, list]]: negative_words = [line[0] for line in lines if int(line[1]) < 0] negative_patterns = MatchCounter.term_list_to_spacy_match_patterns( - negative_words, label="negative_words" + negative_words, + label="negative_words", ) return negative_patterns diff --git a/archive_v1/src/dfm/description/generate_description.py b/archive_v1/src/dfm/description/generate_description.py index 4a69306d..4c3a28aa 100644 --- a/archive_v1/src/dfm/description/generate_description.py +++ b/archive_v1/src/dfm/description/generate_description.py @@ -1,10 +1,8 @@ import os import time -from typing import List import spacy from datasets import load_dataset - from dfm.description.description_patterns import ( danish_adult_words, get_female_gendered_patterns, @@ -19,7 +17,7 @@ from dfm.description.match_counter import MatchCounter -def create_patterns() -> List: +def create_patterns() -> list: """Generates the patterns we've selected for the present analyses. Returns: @@ -34,7 +32,8 @@ def create_patterns() -> List: # Adult words adult_patterns = MatchCounter.term_list_to_spacy_match_patterns( - danish_adult_words, label_prefix="porn_" + danish_adult_words, + label_prefix="porn_", ) return ( @@ -59,7 +58,8 @@ def create_patterns() -> List: ds = load_dataset("DDSC/partial-danish-gigaword-no-twitter") ds_sharded = ds.shuffle()["train"].shard( - num_shards=10000, index=0 + num_shards=10000, + index=0, ) # Work on 1/100th of DGW nlp = spacy.blank("da") @@ -83,7 +83,7 @@ def create_patterns() -> List: save_path.mkdir(parents=False, exist_ok=True) # only create if needed dgw_processed = dgw_processed.remove_columns( - ["text", "doc_id", "LICENSE", "uri", "date_built"] + ["text", "doc_id", "LICENSE", "uri", "date_built"], ) # Remove irrelevant columns dgw_processed.to_csv("csv/output_100.csv") diff --git a/archive_v1/src/dfm/description/match_counter.py b/archive_v1/src/dfm/description/match_counter.py index c4307130..b6ae5fb3 100644 --- a/archive_v1/src/dfm/description/match_counter.py +++ b/archive_v1/src/dfm/description/match_counter.py @@ -1,5 +1,6 @@ from collections import defaultdict -from typing import Dict, Iterable, List, Optional +from collections.abc import Iterable +from typing import Optional from spacy.language import Language from spacy.matcher import Matcher @@ -14,18 +15,18 @@ class MatchCounter: nlp (Language): The spacy language to use """ - def __init__(self, match_patterns: List[Dict[str, list]], nlp: Language): + def __init__(self, match_patterns: list[dict[str, list]], nlp: Language): self.nlp = nlp self.matcher_objects = self.create_matcher_object_from_pattern_list( - match_patterns + match_patterns, ) @staticmethod def list_of_labelled_term_lists_to_spacy_match_patterns( - list_of_labelled_term_lists: List[Dict[str, List[str]]], + list_of_labelled_term_lists: list[dict[str, list[str]]], label_prefix: Optional[str] = "", lowercase: bool = True, - ) -> List[str]: + ) -> list[str]: """Takes a list of strings and converts it to a list of spacy match patterns Args: @@ -43,7 +44,9 @@ def list_of_labelled_term_lists_to_spacy_match_patterns( for labelled_term_list in list_of_labelled_term_lists: for label, term_list in labelled_term_list.items(): match_patterns = MatchCounter.term_list_to_spacy_match_patterns( - term_list=term_list, label_prefix=label_prefix, label=label + term_list=term_list, + label_prefix=label_prefix, + label=label, ) out_list += match_patterns @@ -51,11 +54,11 @@ def list_of_labelled_term_lists_to_spacy_match_patterns( @staticmethod def term_list_to_spacy_match_patterns( - term_list: List[str], + term_list: list[str], label_prefix: Optional[str] = "", label: Optional[str] = None, lowercase: bool = True, - ) -> List[str]: + ) -> list[str]: """Takes a list of strings and converts it to a list of spacy match patterns Args: @@ -75,17 +78,15 @@ def term_list_to_spacy_match_patterns( attribute = "LOWER" if lowercase else "TEXT" for term in term_list: - if label is None: - cur_label = label_prefix + term - else: - cur_label = label_prefix + label + cur_label = label_prefix + term if label is None else label_prefix + label out_list.append({cur_label: [{attribute: term}]}) return out_list def create_matcher_object_from_pattern_list( - self, pattern_container_list: List[Dict[str, List]] + self, + pattern_container_list: list[dict[str, list]], ) -> Matcher: """ Generates a matcher object from a list of dictionaries with {matcher_label (str): pattern (list)} @@ -114,7 +115,7 @@ def create_matcher_object_from_pattern_list( return matcher_object - def count(self, texts: Iterable[str]) -> Dict[str, List[int]]: + def count(self, texts: Iterable[str]) -> dict[str, list[int]]: """Generates counts from the match patterns in the MatchCounter object. Args: @@ -130,10 +131,11 @@ def count(self, texts: Iterable[str]) -> Dict[str, List[int]]: for doc in docs: doc_match_counts = self._get_match_counts_from_doc( - doc, self.matcher_objects + doc, + self.matcher_objects, ) - for pattern_label in doc_match_counts.keys(): + for pattern_label in doc_match_counts: pattern_match_count = doc_match_counts.get(pattern_label, 0) aggregated_match_counts[pattern_label].append(pattern_match_count) @@ -160,7 +162,7 @@ def _get_match_counts_from_doc(self, doc: Doc, matcher_object: Matcher) -> dict: counts[pattern_label] = 0 - for match_id, start, end in matcher_object(doc): + for match_id, _start, _end in matcher_object(doc): counts[self.nlp.vocab.strings[match_id]] += 1 return dict(counts) diff --git a/archive_v1/src/dfm/dfm_tokenizers/train_tokenizer.py b/archive_v1/src/dfm/dfm_tokenizers/train_tokenizer.py index 65b8ebf2..5369424e 100644 --- a/archive_v1/src/dfm/dfm_tokenizers/train_tokenizer.py +++ b/archive_v1/src/dfm/dfm_tokenizers/train_tokenizer.py @@ -1,7 +1,8 @@ """Script to train tokenizers""" +from collections.abc import Iterable from pathlib import Path -from typing import Iterable, Union +from typing import Union from datasets.arrow_dataset import Dataset from datasets.iterable_dataset import IterableDataset @@ -52,16 +53,16 @@ def train_tokenizer( # noqa C901 config = TokenizerConfig(**config) # Convert corpus to an iterable of strings if a Dataset is given - if isinstance(corpus, Dataset) or isinstance(corpus, IterableDataset): + if isinstance(corpus, (Dataset, IterableDataset)): corpus = (sample["text"] for sample in corpus) # Instantiate the tokenizer model if config.tokenizer_type == "bpe": model = models.BPE(unk_token=config.unk_token) elif config.tokenizer_type == "wordpiece": - model = models.WordPiece(unk_token=config.unk_token) # noqa + model = models.WordPiece(unk_token=config.unk_token) elif config.tokenizer_type == "unigram": - model = models.Unigram() # noqa + model = models.Unigram() # Instantiate the tokenizer tokenizer = tokenizers.Tokenizer(model) @@ -77,19 +78,19 @@ def train_tokenizer( # noqa C901 tokenizer.add_special_tokens(special_tokens) # Initialise the normalizer and add it to the tokenizer - normalizer_list = list() + normalizer_list = [] if config.nfkc_normalization: normalizer_list.append(normalizers.NFKC()) if config.lower_case: normalizer_list.append(normalizers.Lowercase()) - normalizer = normalizers.Sequence(normalizer_list) # noqa + normalizer = normalizers.Sequence(normalizer_list) tokenizer.normalizer = normalizer # Shorthand for whether a prefix whitespace should be added to words pre_ws = config.add_prefix_space # Initialise the pre-tokenizer and add it to the tokenizer - pre_tok_list = list() + pre_tok_list = [] if config.byte_level: pre_tok_list.append(pre_tokenizers.ByteLevel(add_prefix_space=pre_ws)) if config.sentence_piece: @@ -101,12 +102,12 @@ def train_tokenizer( # noqa C901 # Initialise the post-processor if config.add_sep_and_cls_tokens: - params = dict( - cls=(config.bos_token, 1), - sep=(config.eos_token, 2), - trim_offsets=True, - add_prefix_space=pre_ws, - ) + params = { + "cls": (config.bos_token, 1), + "sep": (config.eos_token, 2), + "trim_offsets": True, + "add_prefix_space": pre_ws, + } tokenizer.post_processor = processors.RobertaProcessing(**params) elif config.byte_level: tokenizer.post_processor = processors.ByteLevel(trim_offsets=True) diff --git a/archive_v1/src/dfm/modelling/preprocess.py b/archive_v1/src/dfm/modelling/preprocess.py index 7ad64050..405e87bf 100644 --- a/archive_v1/src/dfm/modelling/preprocess.py +++ b/archive_v1/src/dfm/modelling/preprocess.py @@ -44,7 +44,7 @@ def preprocess_dataset( """ # Only use text columns - for key in dataset.keys(): + for key in dataset: cols = dataset[key].column_names cols.remove("text") dataset[key] = dataset[key].remove_columns(cols) @@ -52,7 +52,10 @@ def preprocess_dataset( # Tokenize texts tokenize_func_ = partial(tokenize_func, tokenizer=tokenizer) dataset = dataset.map( - tokenize_func_, batched=True, num_proc=num_proc, remove_columns=["text"] + tokenize_func_, + batched=True, + num_proc=num_proc, + remove_columns=["text"], ) # Group texts into blocks of `block_size`. @@ -72,7 +75,8 @@ def preprocess_dataset( def tokenize_func( - examples: dict, tokenizer: Union[PreTrainedTokenizerFast, PreTrainedTokenizerBase] + examples: dict, + tokenizer: Union[PreTrainedTokenizerFast, PreTrainedTokenizerBase], ) -> BatchEncoding: """Wrapper for tokenization. @@ -98,8 +102,8 @@ def group_texts(examples: dict, block_size: int) -> dict: """ # Concatenate all texts. - concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()} - total_length = len(concatenated_examples[list(examples.keys())[0]]) + concatenated_examples = {k: sum(examples[k], []) for k in examples} + total_length = len(concatenated_examples[next(iter(examples.keys()))]) # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can # customize this part to your needs. diff --git a/archive_v1/tests/cleaning/deduper_test.py b/archive_v1/tests/cleaning/deduper_test.py index 39ab5643..2b439736 100644 --- a/archive_v1/tests/cleaning/deduper_test.py +++ b/archive_v1/tests/cleaning/deduper_test.py @@ -42,25 +42,27 @@ def identity_fn(doc: str) -> str: class TestDeduper: @pytest.fixture(scope="class") def shingle_params(self): - yield dict(normalization_func=default_normalization, split_method="word_ngram") + return { + "normalization_func": default_normalization, + "split_method": "word_ngram", + } @pytest.fixture(scope="class") def minhash_params(self): - yield dict( - normalization_func=default_normalization, - split_method="paragraph", - ngram_size=1, - ngram_stride=1, - num_minhashes=128, - random_seed=42, - ) + return { + "normalization_func": default_normalization, + "split_method": "paragraph", + "ngram_size": 1, + "ngram_stride": 1, + "num_minhashes": 128, + "random_seed": 42, + } def deduper(self, **kwargs): - default_test_args = dict(ngram_size=1, random_seed=42, verbose=False) + default_test_args = {"ngram_size": 1, "random_seed": 42, "verbose": False} return Deduper(**dict(default_test_args, **kwargs)) def dedup(self, corpus, **kwargs): - # Add a document ID to the corpus, if it isn't there already if isinstance(corpus, list) and isinstance(corpus[0], str): corpus = list(enumerate(corpus)) @@ -80,20 +82,20 @@ def miss_percentage(self, corpus=None, iterations=100, **kwargs): "én, to, tre! én, to, tre!", ] misses = 0 - for i in range(0, iterations): + for i in range(iterations): if len(self.dedup(corpus, random_seed=i, **kwargs)) == 2: misses += 1 return (100.0 * misses) / iterations def test_stream(self): corpus = iter( - [(0, "hej med dig min ven"), (1, "hej med dig"), (2, "farvel du gamle")] + [(0, "hej med dig min ven"), (1, "hej med dig"), (2, "farvel du gamle")], ) self.dedup(corpus) == ["hej med dig min ven", "farvel du gamle"] def test_removes_exact_duplicates(self): assert self.dedup( - ["hej med dig min ven", "hej med dig min ven", "farvel du gamle"] + ["hej med dig min ven", "hej med dig min ven", "farvel du gamle"], ) == ["hej med dig min ven", "farvel du gamle"] def test_removes_near_duplicates(self): @@ -103,15 +105,16 @@ def test_removes_near_duplicates(self): "én, to, tre! én, to, tre!", "Da kom en soldat marcherende hen ad landevejen:\n " "én, to, tre! én, to, tre!", - ] + ], ) == [ "Der kom en soldat marcherende hen ad landevejen:\n " - "én, to, tre! én, to, tre!" + "én, to, tre! én, to, tre!", ] def test_document_shorter_than_shingles(self): assert self.dedup( - ["Hej med dig", "Hej med dig", "Gå din vej"], ngram_size=13 + ["Hej med dig", "Hej med dig", "Gå din vej"], + ngram_size=13, ) == ["Hej med dig", "Gå din vej"] def test_split_by_word_ngram(self): @@ -186,19 +189,28 @@ def test_256_minhashes(self): def test_2_ngram_shingles(self, shingle_params): shingles = get_shingles( - "Hej med dig Kim", ngram_size=2, ngram_stride=1, **shingle_params + "Hej med dig Kim", + ngram_size=2, + ngram_stride=1, + **shingle_params, ) assert shingles == ["Hej med", "med dig", "dig Kim"] def test_3_ngram_shingles(self, shingle_params): shingles = get_shingles( - "Hej med dig Kim", ngram_size=3, ngram_stride=1, **shingle_params + "Hej med dig Kim", + ngram_size=3, + ngram_stride=1, + **shingle_params, ) assert shingles == ["Hej med dig", "med dig Kim"] def test_double_stride_shingles(self, shingle_params): shingles = get_shingles( - "Hej med dig Kim", ngram_size=1, ngram_stride=2, **shingle_params + "Hej med dig Kim", + ngram_size=1, + ngram_stride=2, + **shingle_params, ) assert shingles == ["Hej", "dig"] @@ -212,7 +224,6 @@ def test_load_from_disk(self, minhash_params): corpus = ["hej med dig min ven", "hej med dig min ven", "farvel du gamle"] corpus = list(enumerate(corpus)) with tempfile.TemporaryDirectory() as temp: - # Create a deduper loaded from disk, and a different new one deduper = self.deduper(split_method="paragraph") deduper.deduplicate(corpus, output_dir=temp, overwrite=True) diff --git a/archive_v1/tests/cleaning/quality_test.py b/archive_v1/tests/cleaning/quality_test.py index 8948ade9..a954bb6c 100644 --- a/archive_v1/tests/cleaning/quality_test.py +++ b/archive_v1/tests/cleaning/quality_test.py @@ -1,6 +1,5 @@ """Test for the quality filter""" -from typing import List import pytest from pytest_lazyfixture import lazy_fixture @@ -65,14 +64,14 @@ def bullets_texts(self): @pytest.fixture(scope="class") def all_texts(self, tweet_texts, long_text, bullets_texts): - return tweet_texts + [long_text] + [bullets_texts] + return [*tweet_texts, long_text, bullets_texts] @pytest.fixture(scope="class") def quality_filter(self): return QualityFilter(top_ngram_min_count=1) @pytest.mark.parametrize( - "text,expected", + ("text", "expected"), [ ("jeg er glad", True), ("56789okd23456789098765sds", False), @@ -82,13 +81,13 @@ def test_stop_words(self, quality_filter, text: str, expected: bool): assert quality_filter.stop_word(quality_filter.nlp(text), n=2) is expected @pytest.mark.parametrize( - "texts,expected", + ("texts", "expected"), [ (lazy_fixture("bullets_texts"), False), (["56789okd23456789098765sds"], True), ], ) - def test_line_bullets(self, quality_filter, texts: List[str], expected: bool): + def test_line_bullets(self, quality_filter, texts: list[str], expected: bool): for t in texts: assert ( quality_filter.line_bullets_or_ellipsis( @@ -102,7 +101,7 @@ def test_line_bullets(self, quality_filter, texts: List[str], expected: bool): ) @pytest.mark.parametrize( - "text,expected", + ("text", "expected"), [("jeg er glad", True), ("jeg er glad...", False), ("jeg er glad…", False)], ) def test_line_ellipsis(self, quality_filter, text: str, expected: bool): @@ -118,24 +117,27 @@ def test_line_ellipsis(self, quality_filter, text: str, expected: bool): ) @pytest.mark.parametrize( - "text,expected", [("jeg er glad", True), ("67 54 13 B7", False)] + ("text", "expected"), + [("jeg er glad", True), ("67 54 13 B7", False)], ) def test_find_alpha(self, quality_filter, text: str, expected: bool): assert quality_filter.alpha(quality_filter.nlp(text), ratio=0.8) is expected @pytest.mark.parametrize( - "text,expected", [("jeg er glad", True), ("56789okd23456789098765sds", False)] + ("text", "expected"), + [("jeg er glad", True), ("56789okd23456789098765sds", False)], ) def test_mean_word_length(self, quality_filter, text: str, expected: bool): assert ( quality_filter.mean_word_length( - quality_filter.nlp(text), mean_word_length=(3, 10) + quality_filter.nlp(text), + mean_word_length=(3, 10), ) is expected ) @pytest.mark.parametrize( - "text,expected", + ("text", "expected"), [(pytest.lazy_fixture("long_text"), True), ("jeg er glad", False)], ) def test_doc_length(self, quality_filter, text: str, expected: bool): @@ -150,35 +152,41 @@ def test_quality_filter(self, quality_filter, all_texts): assert sum(quality_filter.filtered.values()) == (len(all_texts) - 1) @pytest.mark.parametrize( - "text, expected", + ("text", "expected"), [ ("dasdsadasdasdddddddddd\njeg er glad\n" * 2, False), ("jeg er glad\n\n" * 4, False), ], ) def test_duplicate_line_chr_fraction( - self, quality_filter, text: str, expected: bool + self, + quality_filter, + text: str, + expected: bool, ): filter_func = quality_filter.filters["duplicate_lines_chr_fraction"] nlp = quality_filter.nlp assert filter_func(nlp(text)) is expected @pytest.mark.parametrize( - "text, expected", + ("text", "expected"), [ ("dasdsadasdasdddddddddd\njeg er glad\n" * 2, True), ("jeg er glad\n\n" * 4, False), ], ) def test_duplicate_para_chr_fraction( - self, quality_filter, text: str, expected: bool + self, + quality_filter, + text: str, + expected: bool, ): filter_func = quality_filter.filters["duplicate_paragraph_chr_fraction"] nlp = quality_filter.nlp assert filter_func(nlp(text)) is expected @pytest.mark.parametrize( - "text, expected", + ("text", "expected"), [ ("Jeg er jeg er jeg, JeG ER, jeg er", False), ("jeg er glad, men også noglegange sur...", True), @@ -190,21 +198,24 @@ def test_top_ngram_chr_fraction(self, quality_filter, text: str, expected: bool) assert filter_func(nlp(text)) is expected @pytest.mark.parametrize( - "text,expected", + ("text", "expected"), [ ("jeg er glad, men også noglegange sur måske hvertfald." * 10, False), ("jeg er glad, men også noglegange sur...", True), ], ) def test_duplicate_ngram_chr_fraction( - self, quality_filter, text: str, expected: bool + self, + quality_filter, + text: str, + expected: bool, ): filter_func = quality_filter.filters["duplicate_ngram_chr_fraction"] nlp = quality_filter.nlp assert filter_func(nlp(text)) is expected @pytest.mark.parametrize( - "documents_language, correct_document_indicies", + ("documents_language", "correct_document_indicies"), [ ( [ @@ -244,11 +255,14 @@ def test_duplicate_ngram_chr_fraction( """, ], [1, 2], - ) + ), ], ) def test_language_detection_luga( - self, quality_filter, documents_language, correct_document_indicies + self, + quality_filter, + documents_language, + correct_document_indicies, ): filter = quality_filter.filters["detect_language"] nlp = quality_filter.nlp @@ -258,7 +272,7 @@ def test_language_detection_luga( assert passed_indicies == correct_document_indicies @pytest.mark.parametrize( - "documents_language, expected", + ("documents_language", "expected"), [ ( """Denne paragraf @@ -279,7 +293,10 @@ def test_language_detection_luga( ], ) def test_short_long_sentence( - self, quality_filter, documents_language: str, expected: bool + self, + quality_filter, + documents_language: str, + expected: bool, ): filter = quality_filter.filters["short_long_sentece"] nlp = quality_filter.nlp diff --git a/archive_v1/tests/cleaning/sentence_filter_test.py b/archive_v1/tests/cleaning/sentence_filter_test.py index 0b23eecb..37bbfed6 100644 --- a/archive_v1/tests/cleaning/sentence_filter_test.py +++ b/archive_v1/tests/cleaning/sentence_filter_test.py @@ -8,11 +8,11 @@ class TestEndsWithPunctuationOrEmoji: @pytest.fixture(scope="class") def sentence_filter(self): - yield SentenceFilter(filter_names=["ends_with_punctuation_or_emoji"]) + return SentenceFilter(filter_names=["ends_with_punctuation_or_emoji"]) @pytest.fixture(scope="class") def sentences(self): - yield [ + return [ "Det her er en sætning, som slutter med et punktum.", "Denne her sætning, skrevet af Hr. Mortensen, slutter med en smiley 🎉", "Denne her slutter ikke", @@ -22,18 +22,21 @@ def sentences(self): @pytest.fixture(scope="class") def clean_sentence_indices(self): - yield [0, 1, 3] + return [0, 1, 3] @pytest.fixture(scope="class") def document(self, sentences): - yield "\n".join(sentences) + return "\n".join(sentences) @pytest.fixture(scope="class") def cleaned_document(self, sentences, clean_sentence_indices): - yield "\n".join([sentences[i] for i in clean_sentence_indices]) + return "\n".join([sentences[i] for i in clean_sentence_indices]) def test_sentence_ends_with_punctuation_or_emoji( - self, sentences, sentence_filter, clean_sentence_indices + self, + sentences, + sentence_filter, + clean_sentence_indices, ) -> None: """Tests that the sentences are correctly.""" filter_outputs = [ @@ -54,11 +57,11 @@ def test_filter_corpus(self, sentence_filter, document, cleaned_document) -> Non class TestHasFewTitleCasedWords: @pytest.fixture(scope="class") def sentence_filter(self): - yield SentenceFilter(filter_names=["has_few_title_cased_words"]) + return SentenceFilter(filter_names=["has_few_title_cased_words"]) @pytest.fixture(scope="class") def sentences(self): - yield [ + return [ "Det her er en sætning, som kun har ét ord, der starter med stort bogstav.", "Om os Indkøbskurv Shop Find butik Kontakt", "Han hedder John Hansen, blev der sagt.", @@ -66,18 +69,21 @@ def sentences(self): @pytest.fixture(scope="class") def clean_sentence_indices(self): - yield [0, 2] + return [0, 2] @pytest.fixture(scope="class") def document(self, sentences): - yield "\n".join(sentences) + return "\n".join(sentences) @pytest.fixture(scope="class") def cleaned_document(self, sentences, clean_sentence_indices): - yield "\n".join([sentences[i] for i in clean_sentence_indices]) + return "\n".join([sentences[i] for i in clean_sentence_indices]) def test_has_few_title_cased_words( - self, sentences, sentence_filter, clean_sentence_indices + self, + sentences, + sentence_filter, + clean_sentence_indices, ) -> None: """Tests that the sentences are correctly filtered.""" filter_outputs = [ @@ -98,11 +104,11 @@ def test_filter_corpus(self, sentence_filter, document, cleaned_document) -> Non class TestHasEnoughWords: @pytest.fixture(scope="class") def sentence_filter(self): - yield SentenceFilter(filter_names=["has_enough_words"]) + return SentenceFilter(filter_names=["has_enough_words"]) @pytest.fixture(scope="class") def sentences(self): - yield [ + return [ "Det her er en sætning, som har nok ord.", "Få ord!", "Hej", @@ -112,18 +118,21 @@ def sentences(self): @pytest.fixture(scope="class") def clean_sentence_indices(self): - yield [0, 4] + return [0, 4] @pytest.fixture(scope="class") def document(self, sentences): - yield "\n".join(sentences) + return "\n".join(sentences) @pytest.fixture(scope="class") def cleaned_document(self, sentences, clean_sentence_indices): - yield "\n".join([sentences[i] for i in clean_sentence_indices]) + return "\n".join([sentences[i] for i in clean_sentence_indices]) def test_has_enough_words( - self, sentences, sentence_filter, clean_sentence_indices + self, + sentences, + sentence_filter, + clean_sentence_indices, ) -> None: """Tests that the sentences are correctly filtered.""" filter_outputs = [ @@ -143,11 +152,11 @@ def test_filter_corpus(self, sentence_filter, document, cleaned_document) -> Non class TestFewCurlyBrackets: @pytest.fixture(scope="class") def sentence_filter(self): - yield SentenceFilter(filter_names=["has_few_curly_brackets"]) + return SentenceFilter(filter_names=["has_few_curly_brackets"]) @pytest.fixture(scope="class") def sentences(self): - yield [ + return [ "Det her er bare en helt normal sætning.", "En sætning må gerne have nogle krølleparanteser :-}", "Men den må ikke have {nogle stykker}.", @@ -156,18 +165,21 @@ def sentences(self): @pytest.fixture(scope="class") def clean_sentence_indices(self): - yield [0, 1] + return [0, 1] @pytest.fixture(scope="class") def document(self, sentences): - yield "\n".join(sentences) + return "\n".join(sentences) @pytest.fixture(scope="class") def cleaned_document(self, sentences, clean_sentence_indices): - yield "\n".join([sentences[i] for i in clean_sentence_indices]) + return "\n".join([sentences[i] for i in clean_sentence_indices]) def test_has_few_curly_brackets( - self, sentences, sentence_filter, clean_sentence_indices + self, + sentences, + sentence_filter, + clean_sentence_indices, ) -> None: """Tests that the sentences are correctly filtered.""" filter_outputs = [ diff --git a/archive_v1/tests/description/match_counter_test.py b/archive_v1/tests/description/match_counter_test.py index b7d8f0d5..31b270ff 100644 --- a/archive_v1/tests/description/match_counter_test.py +++ b/archive_v1/tests/description/match_counter_test.py @@ -25,7 +25,7 @@ def regex_patterns(self): @pytest.fixture(scope="class") def term_pattern_list(self): return MatchCounter.term_list_to_spacy_match_patterns( - term_list=["heks", "soldat"] + term_list=["heks", "soldat"], ) @pytest.fixture(scope="class") @@ -39,7 +39,7 @@ def mc_basic(self, nlp, term_pattern_list): @pytest.fixture(scope="class") def pæn_matcher(self, nlp): pæn_match_patterns = MatchCounter.term_list_to_spacy_match_patterns( - term_list=["pæn"] + term_list=["pæn"], ) return MatchCounter(match_patterns=pæn_match_patterns, nlp=nlp) @@ -52,7 +52,7 @@ def test_term_list_pattern_generation(self, term_pattern_list): def test_matcher_object_generation(self, regex_patterns, mc_basic): matcher_objects = mc_basic.create_matcher_object_from_pattern_list( - pattern_container_list=regex_patterns + pattern_container_list=regex_patterns, ) assert len(matcher_objects) == 2 @@ -99,7 +99,8 @@ def test_labelled_term_list_generation(self): labelled_term_list = [{"christian": ["christian", "christianity"]}] output = MatchCounter.list_of_labelled_term_lists_to_spacy_match_patterns( - list_of_labelled_term_lists=labelled_term_list, lowercase=True + list_of_labelled_term_lists=labelled_term_list, + lowercase=True, ) assert output == [ diff --git a/archive_v1/tests/dfm_tokenizers/tokenizer_config_test.py b/archive_v1/tests/dfm_tokenizers/tokenizer_config_test.py index e3fd8b72..6ff37b25 100644 --- a/archive_v1/tests/dfm_tokenizers/tokenizer_config_test.py +++ b/archive_v1/tests/dfm_tokenizers/tokenizer_config_test.py @@ -14,24 +14,24 @@ @pytest.fixture(scope="module") def valid_config_dict(): - yield dict( - tokenizer_type="bpe", - vocab_size=1000, - lower_case=False, - sentence_piece=False, - add_prefix_space=False, - byte_level=False, - add_sep_and_cls_tokens=False, - padding=False, - truncation=False, - max_length=512, - nfkc_normalization=False, - pad_token="", - bos_token="", - eos_token="", - unk_token="", - mask_token="", - ) + return { + "tokenizer_type": "bpe", + "vocab_size": 1000, + "lower_case": False, + "sentence_piece": False, + "add_prefix_space": False, + "byte_level": False, + "add_sep_and_cls_tokens": False, + "padding": False, + "truncation": False, + "max_length": 512, + "nfkc_normalization": False, + "pad_token": "", + "bos_token": "", + "eos_token": "", + "unk_token": "", + "mask_token": "", + } class TestTokenizerConfig: @@ -39,7 +39,7 @@ class TestTokenizerConfig: @pytest.fixture(scope="class") def config_path(self): - yield Path("test_config.json") + return Path("test_config.json") def test_tokenizer_config_init(self, valid_config_dict): TokenizerConfig(**valid_config_dict) diff --git a/archive_v1/tests/dfm_tokenizers/train_tokenizer_test.py b/archive_v1/tests/dfm_tokenizers/train_tokenizer_test.py index 22c46eee..a4920560 100644 --- a/archive_v1/tests/dfm_tokenizers/train_tokenizer_test.py +++ b/archive_v1/tests/dfm_tokenizers/train_tokenizer_test.py @@ -14,24 +14,24 @@ @pytest.fixture(scope="module") def valid_config_dict(): - yield dict( - tokenizer_type="bpe", - vocab_size=1000, - lower_case=False, - sentence_piece=False, - add_prefix_space=False, - byte_level=False, - add_sep_and_cls_tokens=False, - padding=False, - truncation=False, - max_length=512, - nfkc_normalization=False, - pad_token="", - bos_token="", - eos_token="", - unk_token="", - mask_token="", - ) + return { + "tokenizer_type": "bpe", + "vocab_size": 1000, + "lower_case": False, + "sentence_piece": False, + "add_prefix_space": False, + "byte_level": False, + "add_sep_and_cls_tokens": False, + "padding": False, + "truncation": False, + "max_length": 512, + "nfkc_normalization": False, + "pad_token": "", + "bos_token": "", + "eos_token": "", + "unk_token": "", + "mask_token": "", + } class TestTrainTokenizer: @@ -39,26 +39,32 @@ class TestTrainTokenizer: @pytest.fixture(scope="class") def dataset(self): - yield load_dataset("DDSC/lcc", split="test") + return load_dataset("DDSC/lcc", split="test") @pytest.fixture(scope="class") def streamed_dataset(self): - yield load_dataset("DDSC/lcc", split="test", streaming=True) + return load_dataset("DDSC/lcc", split="test", streaming=True) @pytest.fixture(scope="class") def train_params(self): - yield dict(save_tokenizer=False, show_progress=False) + return {"save_tokenizer": False, "show_progress": False} @pytest.fixture(scope="class") def test_docs(self): - yield ["Dette er en dårlig .", "Test…"] + return ["Dette er en dårlig .", "Test…"] def test_iterable_of_texts( - self, dataset, valid_config_dict, train_params, test_docs + self, + dataset, + valid_config_dict, + train_params, + test_docs, ): config = TokenizerConfig(**valid_config_dict) tok = train_tokenizer( - corpus=(s["text"] for s in dataset), config=config, **train_params + corpus=(s["text"] for s in dataset), + config=config, + **train_params, ) tokens = ["Det", "te", "er", "en", "dår", "lig", "", "."] assert tok.encode(test_docs[0]).tokens == tokens @@ -66,7 +72,9 @@ def test_iterable_of_texts( def test_list_of_texts(self, dataset, valid_config_dict, train_params, test_docs): config = TokenizerConfig(**valid_config_dict) tok = train_tokenizer( - corpus=[s["text"] for s in dataset], config=config, **train_params + corpus=[s["text"] for s in dataset], + config=config, + **train_params, ) tokens = ["Det", "te", "er", "en", "dår", "lig", "", "."] assert tok.encode(test_docs[0]).tokens == tokens @@ -79,7 +87,11 @@ def test_series_of_texts(self, dataset, valid_config_dict, train_params, test_do assert tok.encode(test_docs[0]).tokens == tokens def test_streaming( - self, streamed_dataset, valid_config_dict, train_params, test_docs + self, + streamed_dataset, + valid_config_dict, + train_params, + test_docs, ): config = TokenizerConfig(**valid_config_dict) tok = train_tokenizer(corpus=streamed_dataset, config=config, **train_params) @@ -127,7 +139,11 @@ def test_sentence_piece(self, dataset, valid_config_dict, train_params, test_doc assert tok.encode(test_docs[0]).tokens == tokens def test_add_prefix_space( - self, dataset, valid_config_dict, train_params, test_docs + self, + dataset, + valid_config_dict, + train_params, + test_docs, ): config_dict = valid_config_dict.copy() config_dict["add_prefix_space"] = True @@ -145,7 +161,11 @@ def test_byte_level(self, dataset, valid_config_dict, train_params, test_docs): assert tok.encode(test_docs[0]).tokens == tokens def test_add_special_tokens( - self, dataset, valid_config_dict, train_params, test_docs + self, + dataset, + valid_config_dict, + train_params, + test_docs, ): config_dict = valid_config_dict.copy() config_dict["add_sep_and_cls_tokens"] = True @@ -183,7 +203,11 @@ def test_truncation(self, dataset, valid_config_dict, train_params, test_docs): assert tok.encode(test_docs[0]).tokens == tokens def test_nfkc_normalization( - self, dataset, valid_config_dict, train_params, test_docs + self, + dataset, + valid_config_dict, + train_params, + test_docs, ): config_dict = valid_config_dict.copy() config_dict["nfkc_normalization"] = True