From 4e3f2291c72b6285927cff2af22148d072fa2cc9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Peter=20Bj=C3=B8rn=20J=C3=B8rgensen?= Date: Tue, 31 Oct 2023 16:09:14 +0100 Subject: [PATCH 01/21] Add Danish version of dolma taggers These taggers can be used with the dolma tool. Should be rewritten to provide labels for all target languages at once. --- pyproject.toml | 5 +- .../dolma_taggers/language_da.py | 196 ++++++++++++++++++ 2 files changed, 200 insertions(+), 1 deletion(-) create mode 100644 src/dfm/common/data_cleaning/dolma_taggers/language_da.py diff --git a/pyproject.toml b/pyproject.toml index af570248..16f25303 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,10 @@ classifiers = [ ] requires-python = ">=3.10" -dependencies = ["pydantic==1.8.2"] +dependencies = [ + "pydantic==1.8.2", + "dolma>=0.9.1", +] [project.optional-dependencies] dev = ["black==23.9.1", "ruff==0.1.0", "pyright==1.1.331", "pre-commit==3.5.0"] diff --git a/src/dfm/common/data_cleaning/dolma_taggers/language_da.py b/src/dfm/common/data_cleaning/dolma_taggers/language_da.py new file mode 100644 index 00000000..eae0acc4 --- /dev/null +++ b/src/dfm/common/data_cleaning/dolma_taggers/language_da.py @@ -0,0 +1,196 @@ +""" + +Filters. + +""" +from collections.abc import Iterable +from typing import TYPE_CHECKING + +import necessary +import pycld2 as cld2 +import regex +from anyascii import anyascii +from dolma.core.data_types import DocResult, Document, Span, TextSlice +from dolma.core.ft_tagger import BaseFastTextTagger, Prediction +from dolma.core.registry import TaggerRegistry +from dolma.core.taggers import BaseTagger +from dolma.core.utils import split_paragraphs + +with necessary.necessary("cld3", soft=True) as CLD3_AVAILABLE: + if CLD3_AVAILABLE or TYPE_CHECKING: + import cld3 # pyright:ignore pylint:disable=import-error + + +@TaggerRegistry.add("cld3_da_doc_v2") +class Cld3LanguageTaggerDa(BaseTagger): + def __init__(self) -> None: + if not CLD3_AVAILABLE: + raise ImportError( + f"cld3 is not install, cannot instantiate {self.__class__.__name__}" + ) + + def _predict_text(self, text: str) -> tuple[str, float]: + pred = cld3.get_language(text) # pyright: ignore + score = pred.probability if pred.language == "da" else 0.0 + return "da", score + + def predict(self, doc: Document) -> DocResult: + lang, score = self._predict_text(doc.text) + positive_span = Span(start=0, end=len(doc.text), type=lang, score=score) + negative_span = Span( + start=0, end=len(doc.text), type=f"not_{lang}", score=1.0 - score + ) + return DocResult(doc=doc, spans=[positive_span, negative_span]) + + +@TaggerRegistry.add("cld3_da_paragraph_v2") +class Cld3LanguageTaggerParagraphDa(Cld3LanguageTaggerDa): + def predict(self, doc: Document) -> DocResult: + paragraphs = split_paragraphs(doc.text) + spans: list[Span] = [] + for paragraph in paragraphs: + lang, score = self._predict_text(paragraph.text) # pyright: ignore + positive_span = Span( + start=paragraph.start, end=paragraph.end, type=lang, score=score + ) + negative_span = Span( + start=paragraph.start, + end=paragraph.end, + type=f"not_{lang}", + score=1.0 - score, + ) + spans.extend((positive_span, negative_span)) + return DocResult(doc=doc, spans=spans) + + +@TaggerRegistry.add("cld2_da_doc_v2") +class Cld2LanguageFilterDa(BaseTagger): + RE_BAD_CHARS = regex.compile(r"[\p{Cc}\p{Cs}]+") + + def _sanitize_input(self, text: str) -> str: + return self.RE_BAD_CHARS.sub("", text) + + def _to_ascii_input(self, text: str) -> str: + return anyascii(text) + + def _identity_fn(self, 
text: str) -> str: + return text + + def _predict_text(self, text: str) -> tuple[str, float]: + details = [] + is_reliable = False + for fn in (self._identity_fn, self._to_ascii_input, self._sanitize_input): + try: + is_reliable, _, details = cld2.detect(fn(text)) + break + except cld2.error: + ... + + score = max([d[2] for d in details if d[0] == "DANISH" and is_reliable] or [0]) + return "da", score / 100.0 + + def predict(self, doc: Document) -> DocResult: + lang, score = self._predict_text(doc.text) + positive_span = Span(start=0, end=len(doc.text), type=lang, score=score) + negative_span = Span( + start=0, end=len(doc.text), type=f"not_{lang}", score=1.0 - score + ) + return DocResult(doc=doc, spans=[positive_span, negative_span]) + + +@TaggerRegistry.add("cld2_da_paragraph_v2") +class Cld2LanguageFilterParagraphDa(Cld2LanguageFilterDa): + def predict(self, doc: Document) -> DocResult: + paragraphs = split_paragraphs(doc.text) + spans: list[Span] = [] + for paragraph in paragraphs: + lang, score = self._predict_text(paragraph.text) # pyright: ignore + positive_span = Span( + start=paragraph.start, end=paragraph.end, type=lang, score=score + ) + negative_span = Span( + start=paragraph.start, + end=paragraph.end, + type=f"not_{lang}", + score=1.0 - score, + ) + spans.extend((positive_span, negative_span)) + return DocResult(doc=doc, spans=spans) + + +@TaggerRegistry.add("ft_lang_id_da_doc_v2") +class FastTextDanishLanguageDocumentTagger(BaseFastTextTagger): + MODEL_PATH = "https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.bin" + + def __init__(self): + super().__init__( + model_path=self.MODEL_PATH, model_mode=self.DOCUMENT_LEVEL_TAGGER + ) + + def predict_slice(self, text_slice: TextSlice) -> Iterable[Prediction]: + pred = self.classifier.predict( + text_slice.text.lower().replace("\n", " ").strip(), k=-1 + ) + for label, score in zip(*pred): + if label == "__label__da": + return Prediction(label="da", score=score), Prediction( + label="not_da", score=1.0 - score + ) + return Prediction(label="da", score=0.0), Prediction(label="not_da", score=1.0) + + +@TaggerRegistry.add("ft_lang_id_da_paragraph_v2") +class FastTextDanishLanguageParagraphTagger(FastTextDanishLanguageDocumentTagger): + def __init__(self): + BaseFastTextTagger.__init__( + self, model_path=self.MODEL_PATH, model_mode=self.PARAGRAPH_LEVEL_TAGGER + ) + + +def add_global_language_score_from_slice_score_da(result: DocResult) -> DocResult: + # the total document score is # of characters in each "danish" span multiplied by the likelihood + # of said span being danish + try: + doc_da_score = sum( + (s.end - s.start) * s.score for s in result.spans if s.type == "da" + ) / len( + result.doc.text, + ) + doc_not_da_score = 1 - doc_da_score + except ZeroDivisionError: + doc_da_score = doc_not_da_score = 0.0 + + doc_level = ( + Span(start=0, end=len(result.doc.text), type="doc_da", score=doc_da_score), + Span( + start=0, end=len(result.doc.text), type="doc_not_da", score=doc_not_da_score + ), + ) + result.spans.extend(doc_level) + return result + + +@TaggerRegistry.add("cld2_da_paragraph_with_doc_score_v2") +class Cld2LanguageFilterParagraphWithDocScoreTaggerDa(Cld2LanguageFilterParagraphDa): + def predict(self, doc: Document) -> DocResult: + doc_result = super().predict(doc) + doc_result = add_global_language_score_from_slice_score_da(doc_result) + return doc_result + + +@TaggerRegistry.add("cld3_da_paragraph_with_doc_score_v2") +class Cld3LanguageFilterParagraphWithDocScoreTaggerDa(Cld3LanguageTaggerParagraphDa): + 
def predict(self, doc: Document) -> DocResult: + doc_result = super().predict(doc) + doc_result = add_global_language_score_from_slice_score_da(doc_result) + return doc_result + + +@TaggerRegistry.add("ft_lang_id_da_paragraph_with_doc_score_v2") +class FastTextDanishLanguageParagraphWithDocScoreTagger( + FastTextDanishLanguageParagraphTagger +): + def predict(self, doc: Document) -> DocResult: + doc_result = super().predict(doc) + doc_result = add_global_language_score_from_slice_score_da(doc_result) + return doc_result From cd64cdbad5ba966941d88fe6642f928ca3a1c71e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Peter=20Bj=C3=B8rn=20J=C3=B8rgensen?= Date: Wed, 1 Nov 2023 14:31:04 +0100 Subject: [PATCH 02/21] update dependencies --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 16f25303..3e5e33ac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,8 +17,8 @@ classifiers = [ requires-python = ">=3.10" dependencies = [ - "pydantic==1.8.2", - "dolma>=0.9.1", + "pydantic>=2.4.2", # dolma does not work with very old versions of pydantic + "dolma@git+https://github.com/peterbjorgensen/dolma.git@extendable_tagger_cli", # Install from fork until pull request makes it into upstream dolma ] [project.optional-dependencies] From 38cb33c58d8cc69661505db8d0486639ec5a3ec8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Peter=20Bj=C3=B8rn=20J=C3=B8rgensen?= Date: Thu, 2 Nov 2023 15:42:37 +0100 Subject: [PATCH 03/21] make language taggers handle scandinavian languages as opposed to only Danish --- .../dolma_taggers/language_da.py | 196 ------------------ .../dolma_taggers/language_scandi.py | 181 ++++++++++++++++ 2 files changed, 181 insertions(+), 196 deletions(-) delete mode 100644 src/dfm/common/data_cleaning/dolma_taggers/language_da.py create mode 100644 src/dfm/common/data_cleaning/dolma_taggers/language_scandi.py diff --git a/src/dfm/common/data_cleaning/dolma_taggers/language_da.py b/src/dfm/common/data_cleaning/dolma_taggers/language_da.py deleted file mode 100644 index eae0acc4..00000000 --- a/src/dfm/common/data_cleaning/dolma_taggers/language_da.py +++ /dev/null @@ -1,196 +0,0 @@ -""" - -Filters. 
- -""" -from collections.abc import Iterable -from typing import TYPE_CHECKING - -import necessary -import pycld2 as cld2 -import regex -from anyascii import anyascii -from dolma.core.data_types import DocResult, Document, Span, TextSlice -from dolma.core.ft_tagger import BaseFastTextTagger, Prediction -from dolma.core.registry import TaggerRegistry -from dolma.core.taggers import BaseTagger -from dolma.core.utils import split_paragraphs - -with necessary.necessary("cld3", soft=True) as CLD3_AVAILABLE: - if CLD3_AVAILABLE or TYPE_CHECKING: - import cld3 # pyright:ignore pylint:disable=import-error - - -@TaggerRegistry.add("cld3_da_doc_v2") -class Cld3LanguageTaggerDa(BaseTagger): - def __init__(self) -> None: - if not CLD3_AVAILABLE: - raise ImportError( - f"cld3 is not install, cannot instantiate {self.__class__.__name__}" - ) - - def _predict_text(self, text: str) -> tuple[str, float]: - pred = cld3.get_language(text) # pyright: ignore - score = pred.probability if pred.language == "da" else 0.0 - return "da", score - - def predict(self, doc: Document) -> DocResult: - lang, score = self._predict_text(doc.text) - positive_span = Span(start=0, end=len(doc.text), type=lang, score=score) - negative_span = Span( - start=0, end=len(doc.text), type=f"not_{lang}", score=1.0 - score - ) - return DocResult(doc=doc, spans=[positive_span, negative_span]) - - -@TaggerRegistry.add("cld3_da_paragraph_v2") -class Cld3LanguageTaggerParagraphDa(Cld3LanguageTaggerDa): - def predict(self, doc: Document) -> DocResult: - paragraphs = split_paragraphs(doc.text) - spans: list[Span] = [] - for paragraph in paragraphs: - lang, score = self._predict_text(paragraph.text) # pyright: ignore - positive_span = Span( - start=paragraph.start, end=paragraph.end, type=lang, score=score - ) - negative_span = Span( - start=paragraph.start, - end=paragraph.end, - type=f"not_{lang}", - score=1.0 - score, - ) - spans.extend((positive_span, negative_span)) - return DocResult(doc=doc, spans=spans) - - -@TaggerRegistry.add("cld2_da_doc_v2") -class Cld2LanguageFilterDa(BaseTagger): - RE_BAD_CHARS = regex.compile(r"[\p{Cc}\p{Cs}]+") - - def _sanitize_input(self, text: str) -> str: - return self.RE_BAD_CHARS.sub("", text) - - def _to_ascii_input(self, text: str) -> str: - return anyascii(text) - - def _identity_fn(self, text: str) -> str: - return text - - def _predict_text(self, text: str) -> tuple[str, float]: - details = [] - is_reliable = False - for fn in (self._identity_fn, self._to_ascii_input, self._sanitize_input): - try: - is_reliable, _, details = cld2.detect(fn(text)) - break - except cld2.error: - ... 
- - score = max([d[2] for d in details if d[0] == "DANISH" and is_reliable] or [0]) - return "da", score / 100.0 - - def predict(self, doc: Document) -> DocResult: - lang, score = self._predict_text(doc.text) - positive_span = Span(start=0, end=len(doc.text), type=lang, score=score) - negative_span = Span( - start=0, end=len(doc.text), type=f"not_{lang}", score=1.0 - score - ) - return DocResult(doc=doc, spans=[positive_span, negative_span]) - - -@TaggerRegistry.add("cld2_da_paragraph_v2") -class Cld2LanguageFilterParagraphDa(Cld2LanguageFilterDa): - def predict(self, doc: Document) -> DocResult: - paragraphs = split_paragraphs(doc.text) - spans: list[Span] = [] - for paragraph in paragraphs: - lang, score = self._predict_text(paragraph.text) # pyright: ignore - positive_span = Span( - start=paragraph.start, end=paragraph.end, type=lang, score=score - ) - negative_span = Span( - start=paragraph.start, - end=paragraph.end, - type=f"not_{lang}", - score=1.0 - score, - ) - spans.extend((positive_span, negative_span)) - return DocResult(doc=doc, spans=spans) - - -@TaggerRegistry.add("ft_lang_id_da_doc_v2") -class FastTextDanishLanguageDocumentTagger(BaseFastTextTagger): - MODEL_PATH = "https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.bin" - - def __init__(self): - super().__init__( - model_path=self.MODEL_PATH, model_mode=self.DOCUMENT_LEVEL_TAGGER - ) - - def predict_slice(self, text_slice: TextSlice) -> Iterable[Prediction]: - pred = self.classifier.predict( - text_slice.text.lower().replace("\n", " ").strip(), k=-1 - ) - for label, score in zip(*pred): - if label == "__label__da": - return Prediction(label="da", score=score), Prediction( - label="not_da", score=1.0 - score - ) - return Prediction(label="da", score=0.0), Prediction(label="not_da", score=1.0) - - -@TaggerRegistry.add("ft_lang_id_da_paragraph_v2") -class FastTextDanishLanguageParagraphTagger(FastTextDanishLanguageDocumentTagger): - def __init__(self): - BaseFastTextTagger.__init__( - self, model_path=self.MODEL_PATH, model_mode=self.PARAGRAPH_LEVEL_TAGGER - ) - - -def add_global_language_score_from_slice_score_da(result: DocResult) -> DocResult: - # the total document score is # of characters in each "danish" span multiplied by the likelihood - # of said span being danish - try: - doc_da_score = sum( - (s.end - s.start) * s.score for s in result.spans if s.type == "da" - ) / len( - result.doc.text, - ) - doc_not_da_score = 1 - doc_da_score - except ZeroDivisionError: - doc_da_score = doc_not_da_score = 0.0 - - doc_level = ( - Span(start=0, end=len(result.doc.text), type="doc_da", score=doc_da_score), - Span( - start=0, end=len(result.doc.text), type="doc_not_da", score=doc_not_da_score - ), - ) - result.spans.extend(doc_level) - return result - - -@TaggerRegistry.add("cld2_da_paragraph_with_doc_score_v2") -class Cld2LanguageFilterParagraphWithDocScoreTaggerDa(Cld2LanguageFilterParagraphDa): - def predict(self, doc: Document) -> DocResult: - doc_result = super().predict(doc) - doc_result = add_global_language_score_from_slice_score_da(doc_result) - return doc_result - - -@TaggerRegistry.add("cld3_da_paragraph_with_doc_score_v2") -class Cld3LanguageFilterParagraphWithDocScoreTaggerDa(Cld3LanguageTaggerParagraphDa): - def predict(self, doc: Document) -> DocResult: - doc_result = super().predict(doc) - doc_result = add_global_language_score_from_slice_score_da(doc_result) - return doc_result - - -@TaggerRegistry.add("ft_lang_id_da_paragraph_with_doc_score_v2") -class 
FastTextDanishLanguageParagraphWithDocScoreTagger( - FastTextDanishLanguageParagraphTagger -): - def predict(self, doc: Document) -> DocResult: - doc_result = super().predict(doc) - doc_result = add_global_language_score_from_slice_score_da(doc_result) - return doc_result diff --git a/src/dfm/common/data_cleaning/dolma_taggers/language_scandi.py b/src/dfm/common/data_cleaning/dolma_taggers/language_scandi.py new file mode 100644 index 00000000..0c2b4261 --- /dev/null +++ b/src/dfm/common/data_cleaning/dolma_taggers/language_scandi.py @@ -0,0 +1,181 @@ +""" + +Filters. + +""" +from collections.abc import Iterable + +import pycld2 as cld2 +import regex +from anyascii import anyascii +from dolma.core.data_types import DocResult, Document, Span, TextSlice +from dolma.core.ft_tagger import BaseFastTextTagger, Prediction +from dolma.core.registry import TaggerRegistry +from dolma.core.taggers import BaseTagger +from dolma.core.utils import split_paragraphs + +LANGS = { + "ENGLISH": "en", + "DANISH": "da", + "SWEDISH": "sv", + "NORWEGIAN": "no", + "ICELANDIC": "is", + "FAROESE": "fo", # Note that FAROESE is not supported by cld2 +} + +@TaggerRegistry.add("cld2_scandi_doc") +class Cld2LanguageFilterScandi(BaseTagger): + RE_BAD_CHARS = regex.compile(r"[\p{Cc}\p{Cs}]+") + + + def _sanitize_input(self, text: str) -> str: + return self.RE_BAD_CHARS.sub("", text) + + def _to_ascii_input(self, text: str) -> str: + return anyascii(text) + + def _identity_fn(self, text: str) -> str: + return text + + def _predict_text(self, text: str) -> dict[str, float]: + details = [] + is_reliable = False + for fn in (self._identity_fn, self._to_ascii_input, self._sanitize_input): + try: + is_reliable, _, details = cld2.detect(fn(text)) + break + except cld2.error: + ... + + scores: dict[str, float] = {} + if is_reliable: + for lang, lang_code, score, _ in details: + if lang in LANGS: + scores[LANGS[lang]] = score + + return scores + + def predict(self, doc: Document) -> DocResult: + lang_scores = self._predict_text(doc.text) + spans: list[Span] = [] + for lang, lang_code in LANGS.items(): + if lang_code in lang_scores: + score = lang_scores[lang_code] + else: + score = 0 + + positive_span = Span(start=0, end=len(doc.text), type=lang_code, score=score) + negative_span = Span( + start=0, end=len(doc.text), type=f"not_{lang_code}", score=1.0 - score + ) + spans.append(positive_span) + spans.append(negative_span) + return DocResult(doc=doc, spans=spans) + +@TaggerRegistry.add("cld2_scandi_paragraph") +class Cld2LanguageFilterParagraphScandi(Cld2LanguageFilterScandi): + def predict(self, doc: Document) -> DocResult: + paragraphs = split_paragraphs(doc.text) + spans: list[Span] = [] + for paragraph in paragraphs: + lang_scores = self._predict_text(paragraph.text) + for lang_code in LANGS.values(): + if lang_code in lang_scores: + score = lang_scores[lang_code] + else: + score = 0.0 + + positive_span = Span( + start=paragraph.start, end=paragraph.end, type=lang_code, score=score + ) + negative_span = Span( + start=paragraph.start, + end=paragraph.end, + type=f"not_{lang_code}", + score=1.0 - score, + ) + spans.extend((positive_span, negative_span)) + return DocResult(doc=doc, spans=spans) + + +@TaggerRegistry.add("ft_lang_id_scandi_doc") +class FastTextScandiLanguageDocumentTagger(BaseFastTextTagger): + MODEL_PATH = "https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.bin" + + def __init__(self): + super().__init__( + model_path=self.MODEL_PATH, model_mode=self.DOCUMENT_LEVEL_TAGGER + ) + + def 
predict_slice(self, text_slice: TextSlice) -> Iterable[Prediction]: + pred = self.classifier.predict( + text_slice.text.lower().replace("\n", " ").strip(), k=-1 + ) + # Initialize scores to 0 + scores = {k: 0.0 for k in LANGS.values()} + + for label, score in zip(*pred): + # label is of the form __label__[code] + label_code = label[-2:] + if label_code in scores: + scores[label_code] = score + if label == "__label__da": + return Prediction(label="da", score=score), Prediction( + label="not_da", score=1.0 - score + ) + + predictions_positive = [Prediction(label=k, score=v) for k,v in scores.items()] + predictions_negative = [Prediction(label=k, score=1.0 - v) for k,v in scores.items()] + + return predictions_positive + predictions_negative + +@TaggerRegistry.add("ft_lang_id_scandi_paragraph") +class FastTextScandiLanguageParagraphTagger(FastTextScandiLanguageDocumentTagger): + def __init__(self): + BaseFastTextTagger.__init__( + self, model_path=self.MODEL_PATH, model_mode=self.PARAGRAPH_LEVEL_TAGGER + ) + + +def add_global_language_score_from_slice_score(result: DocResult) -> DocResult: + # the total document score is # of characters in each "lang" span multiplied by the likelihood + # of said span being lang + for lang in LANGS.values(): + try: + doc_lang_score = sum( + (s.end - s.start) * s.score for s in result.spans if s.type == lang + ) / len( + result.doc.text, + ) + doc_not_lang_score = 1 - doc_lang_score + except ZeroDivisionError: + doc_lang_score = doc_not_lang_score = 0.0 + + doc_level = ( + Span(start=0, end=len(result.doc.text), type=f"doc_{lang}", score=doc_lang_score), + Span( + start=0, end=len(result.doc.text), type=f"doc_not_{lang}", score=doc_not_lang_score + ), + ) + result.spans.extend(doc_level) + return result + + +# Composite tagger that provides both paragraph and doc scores +@TaggerRegistry.add("cld2_scandi_paragraph_with_doc_score") +class Cld2LanguageFilterParagraphWithDocScoreTaggerScandi(Cld2LanguageFilterParagraphScandi): + def predict(self, doc: Document) -> DocResult: + doc_result = super().predict(doc) + doc_result = add_global_language_score_from_slice_score(doc_result) + return doc_result + + +# Composite tagger that provides both paragraph and doc scores +@TaggerRegistry.add("ft_lang_id_scandi_paragraph_with_doc_score") +class FastTextScandiLanguageParagraphWithDocScoreTagger( + FastTextScandiLanguageParagraphTagger +): + def predict(self, doc: Document) -> DocResult: + doc_result = super().predict(doc) + doc_result = add_global_language_score_from_slice_score(doc_result) + return doc_result From 73a7ffcb5592da47e69cdbbb9c4869760faa0b09 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Peter=20Bj=C3=B8rn=20J=C3=B8rgensen?= Date: Thu, 2 Nov 2023 16:03:46 +0100 Subject: [PATCH 04/21] fix missing normalization --- src/dfm/common/data_cleaning/dolma_taggers/language_scandi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dfm/common/data_cleaning/dolma_taggers/language_scandi.py b/src/dfm/common/data_cleaning/dolma_taggers/language_scandi.py index 0c2b4261..17307cce 100644 --- a/src/dfm/common/data_cleaning/dolma_taggers/language_scandi.py +++ b/src/dfm/common/data_cleaning/dolma_taggers/language_scandi.py @@ -51,7 +51,7 @@ def _predict_text(self, text: str) -> dict[str, float]: if is_reliable: for lang, lang_code, score, _ in details: if lang in LANGS: - scores[LANGS[lang]] = score + scores[LANGS[lang]] = score / 100.0 return scores From 6fe9d3407236a23c04af96e6e4f6570514d69e37 Mon Sep 17 00:00:00 2001 From: 
=?UTF-8?q?Peter=20Bj=C3=B8rn=20J=C3=B8rgensen?= Date: Mon, 6 Nov 2023 17:05:38 +0100 Subject: [PATCH 05/21] add first version of ccnet based perplexity tagger --- pyproject.toml | 1 + .../data_cleaning/dolma_taggers/perplexity.py | 127 ++++++++++++++++++ 2 files changed, 128 insertions(+) create mode 100644 src/dfm/common/data_cleaning/dolma_taggers/perplexity.py diff --git a/pyproject.toml b/pyproject.toml index 3e5e33ac..64938b81 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,7 @@ requires-python = ">=3.10" dependencies = [ "pydantic>=2.4.2", # dolma does not work with very old versions of pydantic "dolma@git+https://github.com/peterbjorgensen/dolma.git@extendable_tagger_cli", # Install from fork until pull request makes it into upstream dolma + "kenlm", # Used for dolma perplexity tagging ] [project.optional-dependencies] diff --git a/src/dfm/common/data_cleaning/dolma_taggers/perplexity.py b/src/dfm/common/data_cleaning/dolma_taggers/perplexity.py new file mode 100644 index 00000000..e04723b8 --- /dev/null +++ b/src/dfm/common/data_cleaning/dolma_taggers/perplexity.py @@ -0,0 +1,127 @@ +""" +Perplexity taggers + +This module contain taggers based on language models +""" +import hashlib +import logging +import os +import requests + +import kenlm + +from dolma.core.data_types import DocResult, Document, Span +from dolma.core.registry import TaggerRegistry +from dolma.core.taggers import BaseTagger +from dolma.core.utils import split_paragraphs + +ccnet_sha256 = { +"af.arpa.bin":"7278e70cb22e29e94942b103c0ba49f406a9369c2949199fdf8d4bee4b0ce48e", +"ar.arpa.bin":"85739ba1e022a4abd9eb260e6c67e8a4e7646f0717e2800d8dde1ec039b7f5e2", +"az.arpa.bin":"247fd2355db94b4357d19c78c8ac38ce16299d1dac237745edeea8005d7771ba", +"be.arpa.bin":"b23a70aa0cec41555932e6b4aaa5a361c95d091fbd6d4c21e6a48c866b9cd1e8", +"bg.arpa.bin":"1edb68d25238d692cb9cc6b2e4f9fce0e99b49b421020c8e89d0781507dbcd38", +"bn.arpa.bin":"f21c8187eb77d2d7d17892b61dc3446dab79a61d3d0af4f0c90660f9df500cb2", +"ca.arpa.bin":"1e4e84639fd9a35cbfa47709ca2cd9eefc84dcee7ab7d91df11e5f89f88312d4", +"cs.arpa.bin":"4f89f980c12cae596b19fccd9aebea4be5be86c6f81a8b42fc975922ea656bb1", +"da.arpa.bin":"b7f754b56421944ada2c979d0b11e8eada8308e179cb60fbc1acc4318b03695b", +"de.arpa.bin":"a5bc18a9741dc57593d7cce469350d5d2db8ce1e87be6c2ec450850316e586ba", +"el.arpa.bin":"8a53a69835d0a8e88c720fc052180c54973d2b6ac3ed2ff83c666d432a0d3686", +"en.arpa.bin":"e90c9b25af01dcaa2667ed45d012d891269760fc6eccfe8dbbd161eb20e01d7d", +"es.arpa.bin":"00121ab8c31f275132fc67c292392a33ff81b8eae1015103e8a86f9df2e642d4", +"et.arpa.bin":"7c4b98dc3f7fff73611afdd0dc1379437cb0b3dd3addc0abadb65864cabb937f", +"fa.arpa.bin":"05d00d4fdb31e00295a63e4df4187954d43850a8bd7b61c717f809b19fc94cfe", +"fi.arpa.bin":"56aa4a6890c4152be3d594e7f7dc353e78881500803f36586c1c01d88f906618", +"fr.arpa.bin":"4a52387916be57551013df3f9052ee031c042445940a4d0e69b066597586c6aa", +"gu.arpa.bin":"4ad5be86ef47f3105eb9d7d178520a0cede5d02e4ca61a3aa2d32c8322ca5bd1", +"he.arpa.bin":"69d1ab538beb6c8aa646b7c611b701ad2d1a19dcce00d6690072fa9453ad2f00", +"hi.arpa.bin":"b7173df087ff5b24d759fdbf8d07d8e21a31c1b54c978c7c5c71f05b24e12f47", +"hr.arpa.bin":"3ba8caf473415c4d12be594c36892f1454a71a08441ad796bf105ebe4e957a8f", +"hu.arpa.bin":"ce82ceb8a1e808fc441d985c4249c08c67d527937d26e3e524404185803723cf", +"hy.arpa.bin":"3c5c3511a82538ab198536e54df4e770c40d78bf5929a7143ab42695641a0031", +"id.arpa.bin":"8e871368fb386180df09d1dfb45f0319dba7a1955b9d209e498c49d96d07b3dd", 
+"is.arpa.bin":"287f6f7bd8130d50df8966169427b236e9aa79ff2b4250c5bdfdc2c9a0c19f52", +"it.arpa.bin":"784efb647bd699041809d59dd309193f78a47ea347d13b0c93c3bd74f437a53b", +"ja.arpa.bin":"efa96d229e2a84be705f81bc4ea1c6da79505e5c7f001f92586e16481e5b586a", +"ka.arpa.bin":"07477bd9166bc2c748532f1c3af65aad42740231c0dc1f8a4410764e0d626199", +"kk.arpa.bin":"3cec2b6c9b3ae34919dd23ff59148e81b76593d7ec17feefcd5e2829cd1643c0", +"km.arpa.bin":"84a09db4e1e7a70e1cd7c347d9729339e3eaa993f42b4bba4ba91fe0a84ff763", +"kn.arpa.bin":"f1e0e469c8c78ac4e3b62d348e966e658cf7b8f683aafa4a2b4d55ca1e7d756c", +"ko.arpa.bin":"7e345046786a1ac6dbb0d3d0fdd65d2ff0e8a848395dbc84c6152acee1987f5f", +"lt.arpa.bin":"ecc1703e098477503035d980f6be841b5359f8f5f55cc4f78087232c7da15398", +"lv.arpa.bin":"5f6212551d5de115309674eed8ea595f1375973832917dd285942a0ef8d6c7e7", +"mk.arpa.bin":"0915b0c452f5bc6dd254c4145fd09f1252ea5e17f13f48991c72cb98fa2ed804", +"ml.arpa.bin":"3f0cfbf0bdc6935229d6903df8cb60b4ed2b9ed2cb9d4c253266b13bd3211297", +"mn.arpa.bin":"c8e57fcf604d178d45fbe3b1650c04e715c41cb8151bf8b115dc88c52ebfba56", +"mr.arpa.bin":"e00986484585cd67deba5902c7da78566452e3c40fc9aa285218152563d33303", +"my.arpa.bin":"ac3496e2981ea3ad85673ca52e04f5aa8e7be68d1d94c2e73ce26436864ae217", +"ne.arpa.bin":"7ef6c2d3e4e1858fb207e6c200e422833ccf072157a6a0148b408db3e760d22e", +"nl.arpa.bin":"aa017d97061e84f51d7f74b83a6a43aef246974fc9a502436043f6f0e9e12bbb", +"no.arpa.bin":"0ec663c264d6580beebe7e0e80a939dbe7082af55af3875f292ebd11ea5800de", +"pl.arpa.bin":"b97634bca2b28d95716b951ceadca3de4a170ff07639bcdc3c73fc0961362e98", +"pt.arpa.bin":"f5a10774d7b7125c6e887b62c56fea2d348adebc81ab1708d34f68de722090e0", +"ro.arpa.bin":"619b9a2d4d53bdb368bfdf2cc770e1e9549d52b22d1fd3afc0ee8a022543ed56", +"ru.arpa.bin":"588da7d3e160f61f7e821804bc4d518460687e1c4832c339bb3a28c03417ab53", +"uk.arpa.bin":"bfd09bdfe669a9fd5f8f8d9be519bdce3fb678214bc6afd5ccce499930b7d311", +"zh.arpa.bin":"f157d94cb2828bbb44b5dddf38e7eb7f62a47d317917646a73fe2af50a3dad68", +} + +def _get_ccnet_pretrained_lm(lang: str): + # Download pretrained model and save to the data folder + url = f"http://dl.fbaipublicfiles.com/cc_net/lm/{lang}.arpa.bin" + data_folder = "data_lm" + + if not os.path.exists(data_folder): + os.makedirs(data_folder) + + filename = f"{lang}.arpa.bin" + file_path = os.path.join(data_folder, filename) + + # Check if the file already exists + if not os.path.exists(file_path): + # If the file does not exist, download it + logging.info(f"Downloading {lang} model...") + response = requests.get(url) + if response.status_code == requests.codes.ok: + sha256 = hashlib.sha256(response.content).hexdigest() + if sha256 != ccnet_sha256[filename]: + raise RuntimeError(f"Checksum mismatch {sha256} != {ccnet_sha256[filename]}") + with open(file_path, 'wb') as file: + file.write(response.content) + logging.info(f"{lang} model downloaded and saved at {file_path}") + else: + raise RuntimeError(f"Failed to download {lang} model. 
Status code: {response.status_code}") + else: + logging.info(f"{lang} model already exists at {file_path}") + + return file_path + +def pp(log_score: float, length: float): + """Convert total log-probability to perplexity""" + return 10.0 ** (-log_score / length) + +@TaggerRegistry.add("ccnet_paragraph_w_doc_da") +class CCNetDa(BaseTagger): + def __init__(self): + model_bin_path = _get_ccnet_pretrained_lm("da") + self.model = kenlm.Model(model_bin_path) + + def predict(self, doc: Document) -> DocResult: + paragraphs = split_paragraphs(doc.text) + spans: list[Span] = [] + doc_log_prob: float = 0.0 + doc_length: float = 0.0 + for paragraph in paragraphs: + log_prob = self.model.score(paragraph.text) + length = len(paragraph.text.split()) + 1 + doc_log_prob += log_prob + doc_length += length + paragraph_span = Span( + start=paragraph.start, end=paragraph.end, type="perplexity", score=pp(log_prob, length) + ) + spans.append(paragraph_span) + + paragraph_span = Span( + start=0, end=len(doc.text), type="doc_perplexity", score=pp(doc_log_prob, doc_length) + ) + return DocResult(doc=doc, spans=spans) From ae5e021e0d960b7c1ea6d5028bb6f528481285c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Peter=20Bj=C3=B8rn=20J=C3=B8rgensen?= Date: Tue, 7 Nov 2023 14:49:14 +0100 Subject: [PATCH 06/21] add perplexity taggers for Scandinavian languages --- pyproject.toml | 6 +- src/dfm/__init__.py | 4 + src/dfm/common/__init__.py | 0 src/dfm/common/data_cleaning/__init__.py | 0 .../data_cleaning/dolma_taggers/__init__.py | 0 .../data_cleaning/dolma_taggers/perplexity.py | 43 +++- .../common/data_cleaning/text_normalizer.py | 189 ++++++++++++++++++ 7 files changed, 231 insertions(+), 11 deletions(-) create mode 100644 src/dfm/__init__.py create mode 100644 src/dfm/common/__init__.py create mode 100644 src/dfm/common/data_cleaning/__init__.py create mode 100644 src/dfm/common/data_cleaning/dolma_taggers/__init__.py create mode 100644 src/dfm/common/data_cleaning/text_normalizer.py diff --git a/pyproject.toml b/pyproject.toml index 64938b81..fdce88dd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,8 +18,10 @@ requires-python = ">=3.10" dependencies = [ "pydantic>=2.4.2", # dolma does not work with very old versions of pydantic - "dolma@git+https://github.com/peterbjorgensen/dolma.git@extendable_tagger_cli", # Install from fork until pull request makes it into upstream dolma - "kenlm", # Used for dolma perplexity tagging + "dolma@git+https://github.com/allenai/dolma.git", # Install from git until a 0.9.2 package is released + "kenlm", # Used for perplexity tagging + "blingfire", # Used for perplexity tagging + "requests", ] [project.optional-dependencies] diff --git a/src/dfm/__init__.py b/src/dfm/__init__.py new file mode 100644 index 00000000..d01c3abd --- /dev/null +++ b/src/dfm/__init__.py @@ -0,0 +1,4 @@ +import importlib.metadata + +# Fetches the version of the package as defined in pyproject.toml +__version__ = importlib.metadata.version(__package__) diff --git a/src/dfm/common/__init__.py b/src/dfm/common/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/dfm/common/data_cleaning/__init__.py b/src/dfm/common/data_cleaning/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/dfm/common/data_cleaning/dolma_taggers/__init__.py b/src/dfm/common/data_cleaning/dolma_taggers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/dfm/common/data_cleaning/dolma_taggers/perplexity.py b/src/dfm/common/data_cleaning/dolma_taggers/perplexity.py index 
e04723b8..9731a0f7 100644 --- a/src/dfm/common/data_cleaning/dolma_taggers/perplexity.py +++ b/src/dfm/common/data_cleaning/dolma_taggers/perplexity.py @@ -6,15 +6,19 @@ import hashlib import logging import os -import requests +from typing import Type, TypeVar +import blingfire import kenlm +import requests from dolma.core.data_types import DocResult, Document, Span from dolma.core.registry import TaggerRegistry from dolma.core.taggers import BaseTagger from dolma.core.utils import split_paragraphs +from dfm.common.data_cleaning.text_normalizer import normalize + ccnet_sha256 = { "af.arpa.bin":"7278e70cb22e29e94942b103c0ba49f406a9369c2949199fdf8d4bee4b0ce48e", "ar.arpa.bin":"85739ba1e022a4abd9eb260e6c67e8a4e7646f0717e2800d8dde1ec039b7f5e2", @@ -96,24 +100,32 @@ def _get_ccnet_pretrained_lm(lang: str): return file_path -def pp(log_score: float, length: float): +def pp(log_score: float, length: float) -> float: """Convert total log-probability to perplexity""" return 10.0 ** (-log_score / length) -@TaggerRegistry.add("ccnet_paragraph_w_doc_da") -class CCNetDa(BaseTagger): - def __init__(self): - model_bin_path = _get_ccnet_pretrained_lm("da") +def create_ccnet_perplexity_tagger(lang: str) -> Type[BaseTagger]: + """Dynamically create tagger class for a given language""" + T = TypeVar("T") + def __init__(self: T) -> T: + model_bin_path = _get_ccnet_pretrained_lm(lang) self.model = kenlm.Model(model_bin_path) + return self - def predict(self, doc: Document) -> DocResult: + def predict(self: BaseTagger, doc: Document) -> DocResult: paragraphs = split_paragraphs(doc.text) spans: list[Span] = [] doc_log_prob: float = 0.0 doc_length: float = 0.0 for paragraph in paragraphs: - log_prob = self.model.score(paragraph.text) - length = len(paragraph.text.split()) + 1 + # To get proper scores from the language model we need to normalize the text + # Do not remove accents as it removes æøå and others. + normalized_text = blingfire.normalize_spaces(normalize(paragraph.text, accent=False)) + # The kenlm model expects end of sentence punctuation to be separated from words with spaces + # so we separate the words using blingfire. + normalized_words = blingfire.text_to_words(normalized_text) + log_prob = self.model.score(normalized_words) + length = len(normalized_words.split()) + 1 doc_log_prob += log_prob doc_length += length paragraph_span = Span( @@ -125,3 +137,16 @@ def predict(self, doc: Document) -> DocResult: start=0, end=len(doc.text), type="doc_perplexity", score=pp(doc_log_prob, doc_length) ) return DocResult(doc=doc, spans=spans) + + cls = type( + f"CCNetPerplexity{lang}", (BaseTagger, ), + { + "__init__": __init__, + "predict": predict + } + ) + cls = TaggerRegistry.add(f"ccnet_perplexity_paragraph_w_doc_{lang}")(cls) + return cls + +for lang in ["da", "en", "is", "no", "sv"]: + create_ccnet_perplexity_tagger(lang) diff --git a/src/dfm/common/data_cleaning/text_normalizer.py b/src/dfm/common/data_cleaning/text_normalizer.py new file mode 100644 index 00000000..cc2bebc4 --- /dev/null +++ b/src/dfm/common/data_cleaning/text_normalizer.py @@ -0,0 +1,189 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+# + +import re +import unicodedata + +UNICODE_PUNCT = { + ",": ",", + "。": ".", + "、": ",", + "„": '"', + "”": '"', + "“": '"', + "«": '"', + "»": '"', + "1": '"', + "」": '"', + "「": '"', + "《": '"', + "》": '"', + "´": "'", + "∶": ":", + ":": ":", + "?": "?", + "!": "!", + "(": "(", + ")": ")", + ";": ";", + "–": "-", + "—": " - ", + ".": ". ", + "~": "~", + "’": "'", + "…": "...", + "━": "-", + "〈": "<", + "〉": ">", + "【": "[", + "】": "]", + "%": "%", + "►": "-", +} + +UNICODE_PUNCT_RE = re.compile(f"[{''.join(UNICODE_PUNCT.keys())}]") + + +def replace_unicode_punct(text: str) -> str: + return "".join((UNICODE_PUNCT.get(c, c) for c in text)) + + +def remove_unicode_punct(text: str) -> str: + """More aggressive version of replace_unicode_punct but also faster.""" + return UNICODE_PUNCT_RE.sub("", text) + + +def strip_accents(line: str) -> str: + """Strips accents from a piece of text.""" + nfd = unicodedata.normalize("NFD", line) + output = [c for c in nfd if unicodedata.category(c) != "Mn"] + if len(output) == line: + return line + return "".join(output) + + +# Build a regex matching all control characters. +NON_PRINTING_CHARS_RE = re.compile( + f"[{''.join(map(chr, list(range(0,32)) + list(range(127,160))))}]" +) +DIGIT_RE = re.compile(r"\d") +PUNCT_OR_NON_PRINTING_CHARS_RE = re.compile( + (UNICODE_PUNCT_RE.pattern + NON_PRINTING_CHARS_RE.pattern).replace("][", "") +) + + +def remove_non_printing_char(text: str) -> str: + return NON_PRINTING_CHARS_RE.sub("", text) + + +def normalize_spacing_for_tok(text: str, language: str = "en") -> str: + res = ( + text.replace("\r", "") + # remove extra spaces + .replace("(", " (") + .replace(")", ") ") + .replace(" +", " ") + ) + res = re.sub(r"\) ([\.\!\:\?\;\,])", r"\)\1", res) + res = res.replace("( ", "(").replace(" )", ")") + res = re.sub(r"(\d) \%", r"\1\%", res) + res = res.replace(" :", ":").replace(" ;", ";") + res = res.replace("`", "'").replace("''", ' " ') + + res = ( + res.replace("„", '"') + .replace("“", '"') + .replace("”", '"') + .replace("–", "-") + .replace("—", " - ") + .replace(" +", " ") + .replace("´", "'") + .replace("([a-z])‘([a-z])", r"\1'\2/") + .replace("([a-z])’([a-z])", r"\1'\2/") + .replace("‘", '"') + .replace("‚", '"') + .replace("’", '"') + .replace("''", '"') + .replace("´´", '"') + .replace("…", "...") + # French quotes + .replace(" « ", ' "') + .replace("« ", '"') + .replace("«", '"') + .replace(" » ", '" ') + .replace(" »", '"') + .replace("»", '"') + # handle pseudo-spaces + .replace(" %", "%") + .replace("nº ", "nº ") + .replace(" :", ":") + .replace(" ºC", " ºC") + .replace(" cm", " cm") + .replace(" ?", "?") + .replace(" !", "!") + .replace(" ;", ";") + .replace(", ", ", ") + .replace(" +", " ") + .replace(".", ". 
") + ) + # English "quotation," followed by comma, style + if language == "en": + res = re.sub(r"\"([,\.]+)", r"\1\"", res) + # Czech is confused + elif language == "cs" or language == "cz": + pass + # German/Spanish/French "quotation", followed by comma, style + else: + res = res.replace(',"', '",') + res = re.sub( + r"(\.+)\"(\s*[^<])", r"\"\1\2", res + ) # don't fix period at end of sentence + + if ( + language == "de" + or language == "es" + or language == "cz" + or language == "cs" + or language == "fr" + ): + res = re.sub(r"(\d) (\d)", r"\1,\2", res) + else: + res = re.sub(r"(\d) (\d)", r"\1.\2", res) + return res + + +def normalize(line: str, accent=True, case=True, numbers=True, punct=1) -> str: + line = line.strip() + if not line: + return line + if case: + line = line.lower() + if accent: + line = strip_accents(line) + if numbers: + line = DIGIT_RE.sub("0", line) + if punct == 1: + line = replace_unicode_punct(line) + elif punct == 2: + line = remove_unicode_punct(line) + line = remove_non_printing_char(line) + return line + + +def slow_normalize_for_dedup(line: str) -> str: + return normalize(line, accent=False, case=True, numbers=True, punct=2) + + +def normalize_for_dedup(line: str) -> str: + line = line.strip() + if not line: + return line + # case + line = line.lower() + # numbers + line = DIGIT_RE.sub("0", line) + line = PUNCT_OR_NON_PRINTING_CHARS_RE.sub("", line) + return line From 3a18f4a52e2ca1038a3108d8293c4e0ff6c9b2c3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Peter=20Bj=C3=B8rn=20J=C3=B8rgensen?= Date: Tue, 7 Nov 2023 15:19:58 +0100 Subject: [PATCH 07/21] add cmake and rust dependencies to Dockerfile --- Dockerfile.dev | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/Dockerfile.dev b/Dockerfile.dev index 2d19d84f..8bd57cd0 100644 --- a/Dockerfile.dev +++ b/Dockerfile.dev @@ -1,5 +1,23 @@ FROM python:3.11-bullseye +# Update default packages +RUN apt-get -qq update + +# Get Ubuntu packages +RUN apt-get install -y -q \ + build-essential \ + curl \ + cmake + +# NOTE: no need to run update again at this point +# RUN apt-get update + +# Get Rust; NOTE: using sh for better compatibility with other base images +RUN curl https://sh.rustup.rs -sSf | sh -s -- -y + +# Add .cargo/bin to PATH +ENV PATH="/root/.cargo/bin:${PATH}" + # Set the working directory to /app WORKDIR /app @@ -10,4 +28,4 @@ RUN make install # Install the app COPY . /app -RUN pip install -e . \ No newline at end of file +RUN pip install -e . 
From 801854453f1150e0925089520f3c8ffb616829c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Peter=20Bj=C3=B8rn=20J=C3=B8rgensen?= Date: Tue, 7 Nov 2023 16:28:34 +0100 Subject: [PATCH 08/21] fix some formatting --- .../dolma_taggers/language_scandi.py | 16 ++++++++-------- .../data_cleaning/dolma_taggers/perplexity.py | 13 ++++++------- src/dfm/common/data_cleaning/text_normalizer.py | 8 ++++---- 3 files changed, 18 insertions(+), 19 deletions(-) diff --git a/src/dfm/common/data_cleaning/dolma_taggers/language_scandi.py b/src/dfm/common/data_cleaning/dolma_taggers/language_scandi.py index 17307cce..46f38a64 100644 --- a/src/dfm/common/data_cleaning/dolma_taggers/language_scandi.py +++ b/src/dfm/common/data_cleaning/dolma_taggers/language_scandi.py @@ -66,7 +66,7 @@ def predict(self, doc: Document) -> DocResult: positive_span = Span(start=0, end=len(doc.text), type=lang_code, score=score) negative_span = Span( - start=0, end=len(doc.text), type=f"not_{lang_code}", score=1.0 - score + start=0, end=len(doc.text), type=f"not_{lang_code}", score=1.0 - score, ) spans.append(positive_span) spans.append(negative_span) @@ -86,7 +86,7 @@ def predict(self, doc: Document) -> DocResult: score = 0.0 positive_span = Span( - start=paragraph.start, end=paragraph.end, type=lang_code, score=score + start=paragraph.start, end=paragraph.end, type=lang_code, score=score, ) negative_span = Span( start=paragraph.start, @@ -104,12 +104,12 @@ class FastTextScandiLanguageDocumentTagger(BaseFastTextTagger): def __init__(self): super().__init__( - model_path=self.MODEL_PATH, model_mode=self.DOCUMENT_LEVEL_TAGGER + model_path=self.MODEL_PATH, model_mode=self.DOCUMENT_LEVEL_TAGGER, ) def predict_slice(self, text_slice: TextSlice) -> Iterable[Prediction]: pred = self.classifier.predict( - text_slice.text.lower().replace("\n", " ").strip(), k=-1 + text_slice.text.lower().replace("\n", " ").strip(), k=-1, ) # Initialize scores to 0 scores = {k: 0.0 for k in LANGS.values()} @@ -121,7 +121,7 @@ def predict_slice(self, text_slice: TextSlice) -> Iterable[Prediction]: scores[label_code] = score if label == "__label__da": return Prediction(label="da", score=score), Prediction( - label="not_da", score=1.0 - score + label="not_da", score=1.0 - score, ) predictions_positive = [Prediction(label=k, score=v) for k,v in scores.items()] @@ -133,7 +133,7 @@ def predict_slice(self, text_slice: TextSlice) -> Iterable[Prediction]: class FastTextScandiLanguageParagraphTagger(FastTextScandiLanguageDocumentTagger): def __init__(self): BaseFastTextTagger.__init__( - self, model_path=self.MODEL_PATH, model_mode=self.PARAGRAPH_LEVEL_TAGGER + self, model_path=self.MODEL_PATH, model_mode=self.PARAGRAPH_LEVEL_TAGGER, ) @@ -154,7 +154,7 @@ def add_global_language_score_from_slice_score(result: DocResult) -> DocResult: doc_level = ( Span(start=0, end=len(result.doc.text), type=f"doc_{lang}", score=doc_lang_score), Span( - start=0, end=len(result.doc.text), type=f"doc_not_{lang}", score=doc_not_lang_score + start=0, end=len(result.doc.text), type=f"doc_not_{lang}", score=doc_not_lang_score, ), ) result.spans.extend(doc_level) @@ -173,7 +173,7 @@ def predict(self, doc: Document) -> DocResult: # Composite tagger that provides both paragraph and doc scores @TaggerRegistry.add("ft_lang_id_scandi_paragraph_with_doc_score") class FastTextScandiLanguageParagraphWithDocScoreTagger( - FastTextScandiLanguageParagraphTagger + FastTextScandiLanguageParagraphTagger, ): def predict(self, doc: Document) -> DocResult: doc_result = super().predict(doc) diff --git 
a/src/dfm/common/data_cleaning/dolma_taggers/perplexity.py b/src/dfm/common/data_cleaning/dolma_taggers/perplexity.py index 9731a0f7..4a648bdb 100644 --- a/src/dfm/common/data_cleaning/dolma_taggers/perplexity.py +++ b/src/dfm/common/data_cleaning/dolma_taggers/perplexity.py @@ -11,7 +11,6 @@ import blingfire import kenlm import requests - from dolma.core.data_types import DocResult, Document, Span from dolma.core.registry import TaggerRegistry from dolma.core.taggers import BaseTagger @@ -90,7 +89,7 @@ def _get_ccnet_pretrained_lm(lang: str): sha256 = hashlib.sha256(response.content).hexdigest() if sha256 != ccnet_sha256[filename]: raise RuntimeError(f"Checksum mismatch {sha256} != {ccnet_sha256[filename]}") - with open(file_path, 'wb') as file: + with open(file_path, "wb") as file: file.write(response.content) logging.info(f"{lang} model downloaded and saved at {file_path}") else: @@ -104,7 +103,7 @@ def pp(log_score: float, length: float) -> float: """Convert total log-probability to perplexity""" return 10.0 ** (-log_score / length) -def create_ccnet_perplexity_tagger(lang: str) -> Type[BaseTagger]: +def create_ccnet_perplexity_tagger(lang: str) -> type[BaseTagger]: """Dynamically create tagger class for a given language""" T = TypeVar("T") def __init__(self: T) -> T: @@ -129,12 +128,12 @@ def predict(self: BaseTagger, doc: Document) -> DocResult: doc_log_prob += log_prob doc_length += length paragraph_span = Span( - start=paragraph.start, end=paragraph.end, type="perplexity", score=pp(log_prob, length) + start=paragraph.start, end=paragraph.end, type="perplexity", score=pp(log_prob, length), ) spans.append(paragraph_span) paragraph_span = Span( - start=0, end=len(doc.text), type="doc_perplexity", score=pp(doc_log_prob, doc_length) + start=0, end=len(doc.text), type="doc_perplexity", score=pp(doc_log_prob, doc_length), ) return DocResult(doc=doc, spans=spans) @@ -142,8 +141,8 @@ def predict(self: BaseTagger, doc: Document) -> DocResult: f"CCNetPerplexity{lang}", (BaseTagger, ), { "__init__": __init__, - "predict": predict - } + "predict": predict, + }, ) cls = TaggerRegistry.add(f"ccnet_perplexity_paragraph_w_doc_{lang}")(cls) return cls diff --git a/src/dfm/common/data_cleaning/text_normalizer.py b/src/dfm/common/data_cleaning/text_normalizer.py index cc2bebc4..057a248e 100644 --- a/src/dfm/common/data_cleaning/text_normalizer.py +++ b/src/dfm/common/data_cleaning/text_normalizer.py @@ -48,7 +48,7 @@ def replace_unicode_punct(text: str) -> str: - return "".join((UNICODE_PUNCT.get(c, c) for c in text)) + return "".join(UNICODE_PUNCT.get(c, c) for c in text) def remove_unicode_punct(text: str) -> str: @@ -67,11 +67,11 @@ def strip_accents(line: str) -> str: # Build a regex matching all control characters. 
NON_PRINTING_CHARS_RE = re.compile( - f"[{''.join(map(chr, list(range(0,32)) + list(range(127,160))))}]" + f"[{''.join(map(chr, list(range(32)) + list(range(127,160))))}]", ) DIGIT_RE = re.compile(r"\d") PUNCT_OR_NON_PRINTING_CHARS_RE = re.compile( - (UNICODE_PUNCT_RE.pattern + NON_PRINTING_CHARS_RE.pattern).replace("][", "") + (UNICODE_PUNCT_RE.pattern + NON_PRINTING_CHARS_RE.pattern).replace("][", ""), ) @@ -139,7 +139,7 @@ def normalize_spacing_for_tok(text: str, language: str = "en") -> str: else: res = res.replace(',"', '",') res = re.sub( - r"(\.+)\"(\s*[^<])", r"\"\1\2", res + r"(\.+)\"(\s*[^<])", r"\"\1\2", res, ) # don't fix period at end of sentence if ( From 3a409daeb5c0174fea0ad608c488723576bdfae9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Peter=20Bj=C3=B8rn=20J=C3=B8rgensen?= Date: Tue, 7 Nov 2023 16:45:10 +0100 Subject: [PATCH 09/21] make linters happy --- .../dolma_taggers/language_scandi.py | 67 +++++--- .../data_cleaning/dolma_taggers/perplexity.py | 143 ++++++++++-------- .../common/data_cleaning/text_normalizer.py | 17 ++- 3 files changed, 140 insertions(+), 87 deletions(-) diff --git a/src/dfm/common/data_cleaning/dolma_taggers/language_scandi.py b/src/dfm/common/data_cleaning/dolma_taggers/language_scandi.py index 46f38a64..a5d14fce 100644 --- a/src/dfm/common/data_cleaning/dolma_taggers/language_scandi.py +++ b/src/dfm/common/data_cleaning/dolma_taggers/language_scandi.py @@ -20,14 +20,14 @@ "SWEDISH": "sv", "NORWEGIAN": "no", "ICELANDIC": "is", - "FAROESE": "fo", # Note that FAROESE is not supported by cld2 + "FAROESE": "fo", # Note that FAROESE is not supported by cld2 } + @TaggerRegistry.add("cld2_scandi_doc") class Cld2LanguageFilterScandi(BaseTagger): RE_BAD_CHARS = regex.compile(r"[\p{Cc}\p{Cs}]+") - def _sanitize_input(self, text: str) -> str: return self.RE_BAD_CHARS.sub("", text) @@ -58,20 +58,23 @@ def _predict_text(self, text: str) -> dict[str, float]: def predict(self, doc: Document) -> DocResult: lang_scores = self._predict_text(doc.text) spans: list[Span] = [] - for lang, lang_code in LANGS.items(): - if lang_code in lang_scores: - score = lang_scores[lang_code] - else: - score = 0 + for lang_code in LANGS.values(): + score = lang_scores.get(lang_code, 0) - positive_span = Span(start=0, end=len(doc.text), type=lang_code, score=score) + positive_span = Span( + start=0, end=len(doc.text), type=lang_code, score=score + ) negative_span = Span( - start=0, end=len(doc.text), type=f"not_{lang_code}", score=1.0 - score, + start=0, + end=len(doc.text), + type=f"not_{lang_code}", + score=1.0 - score, ) spans.append(positive_span) spans.append(negative_span) return DocResult(doc=doc, spans=spans) + @TaggerRegistry.add("cld2_scandi_paragraph") class Cld2LanguageFilterParagraphScandi(Cld2LanguageFilterScandi): def predict(self, doc: Document) -> DocResult: @@ -80,13 +83,13 @@ def predict(self, doc: Document) -> DocResult: for paragraph in paragraphs: lang_scores = self._predict_text(paragraph.text) for lang_code in LANGS.values(): - if lang_code in lang_scores: - score = lang_scores[lang_code] - else: - score = 0.0 + score = lang_scores.get(lang_code, 0.0) positive_span = Span( - start=paragraph.start, end=paragraph.end, type=lang_code, score=score, + start=paragraph.start, + end=paragraph.end, + type=lang_code, + score=score, ) negative_span = Span( start=paragraph.start, @@ -104,12 +107,14 @@ class FastTextScandiLanguageDocumentTagger(BaseFastTextTagger): def __init__(self): super().__init__( - model_path=self.MODEL_PATH, model_mode=self.DOCUMENT_LEVEL_TAGGER, + 
model_path=self.MODEL_PATH, + model_mode=self.DOCUMENT_LEVEL_TAGGER, ) def predict_slice(self, text_slice: TextSlice) -> Iterable[Prediction]: pred = self.classifier.predict( - text_slice.text.lower().replace("\n", " ").strip(), k=-1, + text_slice.text.lower().replace("\n", " ").strip(), + k=-1, ) # Initialize scores to 0 scores = {k: 0.0 for k in LANGS.values()} @@ -121,19 +126,25 @@ def predict_slice(self, text_slice: TextSlice) -> Iterable[Prediction]: scores[label_code] = score if label == "__label__da": return Prediction(label="da", score=score), Prediction( - label="not_da", score=1.0 - score, + label="not_da", + score=1.0 - score, ) - predictions_positive = [Prediction(label=k, score=v) for k,v in scores.items()] - predictions_negative = [Prediction(label=k, score=1.0 - v) for k,v in scores.items()] + predictions_positive = [Prediction(label=k, score=v) for k, v in scores.items()] + predictions_negative = [ + Prediction(label=k, score=1.0 - v) for k, v in scores.items() + ] return predictions_positive + predictions_negative + @TaggerRegistry.add("ft_lang_id_scandi_paragraph") class FastTextScandiLanguageParagraphTagger(FastTextScandiLanguageDocumentTagger): def __init__(self): BaseFastTextTagger.__init__( - self, model_path=self.MODEL_PATH, model_mode=self.PARAGRAPH_LEVEL_TAGGER, + self, + model_path=self.MODEL_PATH, + model_mode=self.PARAGRAPH_LEVEL_TAGGER, ) @@ -152,9 +163,17 @@ def add_global_language_score_from_slice_score(result: DocResult) -> DocResult: doc_lang_score = doc_not_lang_score = 0.0 doc_level = ( - Span(start=0, end=len(result.doc.text), type=f"doc_{lang}", score=doc_lang_score), Span( - start=0, end=len(result.doc.text), type=f"doc_not_{lang}", score=doc_not_lang_score, + start=0, + end=len(result.doc.text), + type=f"doc_{lang}", + score=doc_lang_score, + ), + Span( + start=0, + end=len(result.doc.text), + type=f"doc_not_{lang}", + score=doc_not_lang_score, ), ) result.spans.extend(doc_level) @@ -163,7 +182,9 @@ def add_global_language_score_from_slice_score(result: DocResult) -> DocResult: # Composite tagger that provides both paragraph and doc scores @TaggerRegistry.add("cld2_scandi_paragraph_with_doc_score") -class Cld2LanguageFilterParagraphWithDocScoreTaggerScandi(Cld2LanguageFilterParagraphScandi): +class Cld2LanguageFilterParagraphWithDocScoreTaggerScandi( + Cld2LanguageFilterParagraphScandi +): def predict(self, doc: Document) -> DocResult: doc_result = super().predict(doc) doc_result = add_global_language_score_from_slice_score(doc_result) diff --git a/src/dfm/common/data_cleaning/dolma_taggers/perplexity.py b/src/dfm/common/data_cleaning/dolma_taggers/perplexity.py index 4a648bdb..bf8d56a1 100644 --- a/src/dfm/common/data_cleaning/dolma_taggers/perplexity.py +++ b/src/dfm/common/data_cleaning/dolma_taggers/perplexity.py @@ -6,7 +6,8 @@ import hashlib import logging import os -from typing import Type, TypeVar +from pathlib import Path +from typing import TypeVar import blingfire import kenlm @@ -19,93 +20,101 @@ from dfm.common.data_cleaning.text_normalizer import normalize ccnet_sha256 = { -"af.arpa.bin":"7278e70cb22e29e94942b103c0ba49f406a9369c2949199fdf8d4bee4b0ce48e", -"ar.arpa.bin":"85739ba1e022a4abd9eb260e6c67e8a4e7646f0717e2800d8dde1ec039b7f5e2", -"az.arpa.bin":"247fd2355db94b4357d19c78c8ac38ce16299d1dac237745edeea8005d7771ba", -"be.arpa.bin":"b23a70aa0cec41555932e6b4aaa5a361c95d091fbd6d4c21e6a48c866b9cd1e8", -"bg.arpa.bin":"1edb68d25238d692cb9cc6b2e4f9fce0e99b49b421020c8e89d0781507dbcd38", 
-"bn.arpa.bin":"f21c8187eb77d2d7d17892b61dc3446dab79a61d3d0af4f0c90660f9df500cb2", -"ca.arpa.bin":"1e4e84639fd9a35cbfa47709ca2cd9eefc84dcee7ab7d91df11e5f89f88312d4", -"cs.arpa.bin":"4f89f980c12cae596b19fccd9aebea4be5be86c6f81a8b42fc975922ea656bb1", -"da.arpa.bin":"b7f754b56421944ada2c979d0b11e8eada8308e179cb60fbc1acc4318b03695b", -"de.arpa.bin":"a5bc18a9741dc57593d7cce469350d5d2db8ce1e87be6c2ec450850316e586ba", -"el.arpa.bin":"8a53a69835d0a8e88c720fc052180c54973d2b6ac3ed2ff83c666d432a0d3686", -"en.arpa.bin":"e90c9b25af01dcaa2667ed45d012d891269760fc6eccfe8dbbd161eb20e01d7d", -"es.arpa.bin":"00121ab8c31f275132fc67c292392a33ff81b8eae1015103e8a86f9df2e642d4", -"et.arpa.bin":"7c4b98dc3f7fff73611afdd0dc1379437cb0b3dd3addc0abadb65864cabb937f", -"fa.arpa.bin":"05d00d4fdb31e00295a63e4df4187954d43850a8bd7b61c717f809b19fc94cfe", -"fi.arpa.bin":"56aa4a6890c4152be3d594e7f7dc353e78881500803f36586c1c01d88f906618", -"fr.arpa.bin":"4a52387916be57551013df3f9052ee031c042445940a4d0e69b066597586c6aa", -"gu.arpa.bin":"4ad5be86ef47f3105eb9d7d178520a0cede5d02e4ca61a3aa2d32c8322ca5bd1", -"he.arpa.bin":"69d1ab538beb6c8aa646b7c611b701ad2d1a19dcce00d6690072fa9453ad2f00", -"hi.arpa.bin":"b7173df087ff5b24d759fdbf8d07d8e21a31c1b54c978c7c5c71f05b24e12f47", -"hr.arpa.bin":"3ba8caf473415c4d12be594c36892f1454a71a08441ad796bf105ebe4e957a8f", -"hu.arpa.bin":"ce82ceb8a1e808fc441d985c4249c08c67d527937d26e3e524404185803723cf", -"hy.arpa.bin":"3c5c3511a82538ab198536e54df4e770c40d78bf5929a7143ab42695641a0031", -"id.arpa.bin":"8e871368fb386180df09d1dfb45f0319dba7a1955b9d209e498c49d96d07b3dd", -"is.arpa.bin":"287f6f7bd8130d50df8966169427b236e9aa79ff2b4250c5bdfdc2c9a0c19f52", -"it.arpa.bin":"784efb647bd699041809d59dd309193f78a47ea347d13b0c93c3bd74f437a53b", -"ja.arpa.bin":"efa96d229e2a84be705f81bc4ea1c6da79505e5c7f001f92586e16481e5b586a", -"ka.arpa.bin":"07477bd9166bc2c748532f1c3af65aad42740231c0dc1f8a4410764e0d626199", -"kk.arpa.bin":"3cec2b6c9b3ae34919dd23ff59148e81b76593d7ec17feefcd5e2829cd1643c0", -"km.arpa.bin":"84a09db4e1e7a70e1cd7c347d9729339e3eaa993f42b4bba4ba91fe0a84ff763", -"kn.arpa.bin":"f1e0e469c8c78ac4e3b62d348e966e658cf7b8f683aafa4a2b4d55ca1e7d756c", -"ko.arpa.bin":"7e345046786a1ac6dbb0d3d0fdd65d2ff0e8a848395dbc84c6152acee1987f5f", -"lt.arpa.bin":"ecc1703e098477503035d980f6be841b5359f8f5f55cc4f78087232c7da15398", -"lv.arpa.bin":"5f6212551d5de115309674eed8ea595f1375973832917dd285942a0ef8d6c7e7", -"mk.arpa.bin":"0915b0c452f5bc6dd254c4145fd09f1252ea5e17f13f48991c72cb98fa2ed804", -"ml.arpa.bin":"3f0cfbf0bdc6935229d6903df8cb60b4ed2b9ed2cb9d4c253266b13bd3211297", -"mn.arpa.bin":"c8e57fcf604d178d45fbe3b1650c04e715c41cb8151bf8b115dc88c52ebfba56", -"mr.arpa.bin":"e00986484585cd67deba5902c7da78566452e3c40fc9aa285218152563d33303", -"my.arpa.bin":"ac3496e2981ea3ad85673ca52e04f5aa8e7be68d1d94c2e73ce26436864ae217", -"ne.arpa.bin":"7ef6c2d3e4e1858fb207e6c200e422833ccf072157a6a0148b408db3e760d22e", -"nl.arpa.bin":"aa017d97061e84f51d7f74b83a6a43aef246974fc9a502436043f6f0e9e12bbb", -"no.arpa.bin":"0ec663c264d6580beebe7e0e80a939dbe7082af55af3875f292ebd11ea5800de", -"pl.arpa.bin":"b97634bca2b28d95716b951ceadca3de4a170ff07639bcdc3c73fc0961362e98", -"pt.arpa.bin":"f5a10774d7b7125c6e887b62c56fea2d348adebc81ab1708d34f68de722090e0", -"ro.arpa.bin":"619b9a2d4d53bdb368bfdf2cc770e1e9549d52b22d1fd3afc0ee8a022543ed56", -"ru.arpa.bin":"588da7d3e160f61f7e821804bc4d518460687e1c4832c339bb3a28c03417ab53", -"uk.arpa.bin":"bfd09bdfe669a9fd5f8f8d9be519bdce3fb678214bc6afd5ccce499930b7d311", 
-"zh.arpa.bin":"f157d94cb2828bbb44b5dddf38e7eb7f62a47d317917646a73fe2af50a3dad68", + "af.arpa.bin": "7278e70cb22e29e94942b103c0ba49f406a9369c2949199fdf8d4bee4b0ce48e", + "ar.arpa.bin": "85739ba1e022a4abd9eb260e6c67e8a4e7646f0717e2800d8dde1ec039b7f5e2", + "az.arpa.bin": "247fd2355db94b4357d19c78c8ac38ce16299d1dac237745edeea8005d7771ba", + "be.arpa.bin": "b23a70aa0cec41555932e6b4aaa5a361c95d091fbd6d4c21e6a48c866b9cd1e8", + "bg.arpa.bin": "1edb68d25238d692cb9cc6b2e4f9fce0e99b49b421020c8e89d0781507dbcd38", + "bn.arpa.bin": "f21c8187eb77d2d7d17892b61dc3446dab79a61d3d0af4f0c90660f9df500cb2", + "ca.arpa.bin": "1e4e84639fd9a35cbfa47709ca2cd9eefc84dcee7ab7d91df11e5f89f88312d4", + "cs.arpa.bin": "4f89f980c12cae596b19fccd9aebea4be5be86c6f81a8b42fc975922ea656bb1", + "da.arpa.bin": "b7f754b56421944ada2c979d0b11e8eada8308e179cb60fbc1acc4318b03695b", + "de.arpa.bin": "a5bc18a9741dc57593d7cce469350d5d2db8ce1e87be6c2ec450850316e586ba", + "el.arpa.bin": "8a53a69835d0a8e88c720fc052180c54973d2b6ac3ed2ff83c666d432a0d3686", + "en.arpa.bin": "e90c9b25af01dcaa2667ed45d012d891269760fc6eccfe8dbbd161eb20e01d7d", + "es.arpa.bin": "00121ab8c31f275132fc67c292392a33ff81b8eae1015103e8a86f9df2e642d4", + "et.arpa.bin": "7c4b98dc3f7fff73611afdd0dc1379437cb0b3dd3addc0abadb65864cabb937f", + "fa.arpa.bin": "05d00d4fdb31e00295a63e4df4187954d43850a8bd7b61c717f809b19fc94cfe", + "fi.arpa.bin": "56aa4a6890c4152be3d594e7f7dc353e78881500803f36586c1c01d88f906618", + "fr.arpa.bin": "4a52387916be57551013df3f9052ee031c042445940a4d0e69b066597586c6aa", + "gu.arpa.bin": "4ad5be86ef47f3105eb9d7d178520a0cede5d02e4ca61a3aa2d32c8322ca5bd1", + "he.arpa.bin": "69d1ab538beb6c8aa646b7c611b701ad2d1a19dcce00d6690072fa9453ad2f00", + "hi.arpa.bin": "b7173df087ff5b24d759fdbf8d07d8e21a31c1b54c978c7c5c71f05b24e12f47", + "hr.arpa.bin": "3ba8caf473415c4d12be594c36892f1454a71a08441ad796bf105ebe4e957a8f", + "hu.arpa.bin": "ce82ceb8a1e808fc441d985c4249c08c67d527937d26e3e524404185803723cf", + "hy.arpa.bin": "3c5c3511a82538ab198536e54df4e770c40d78bf5929a7143ab42695641a0031", + "id.arpa.bin": "8e871368fb386180df09d1dfb45f0319dba7a1955b9d209e498c49d96d07b3dd", + "is.arpa.bin": "287f6f7bd8130d50df8966169427b236e9aa79ff2b4250c5bdfdc2c9a0c19f52", + "it.arpa.bin": "784efb647bd699041809d59dd309193f78a47ea347d13b0c93c3bd74f437a53b", + "ja.arpa.bin": "efa96d229e2a84be705f81bc4ea1c6da79505e5c7f001f92586e16481e5b586a", + "ka.arpa.bin": "07477bd9166bc2c748532f1c3af65aad42740231c0dc1f8a4410764e0d626199", + "kk.arpa.bin": "3cec2b6c9b3ae34919dd23ff59148e81b76593d7ec17feefcd5e2829cd1643c0", + "km.arpa.bin": "84a09db4e1e7a70e1cd7c347d9729339e3eaa993f42b4bba4ba91fe0a84ff763", + "kn.arpa.bin": "f1e0e469c8c78ac4e3b62d348e966e658cf7b8f683aafa4a2b4d55ca1e7d756c", + "ko.arpa.bin": "7e345046786a1ac6dbb0d3d0fdd65d2ff0e8a848395dbc84c6152acee1987f5f", + "lt.arpa.bin": "ecc1703e098477503035d980f6be841b5359f8f5f55cc4f78087232c7da15398", + "lv.arpa.bin": "5f6212551d5de115309674eed8ea595f1375973832917dd285942a0ef8d6c7e7", + "mk.arpa.bin": "0915b0c452f5bc6dd254c4145fd09f1252ea5e17f13f48991c72cb98fa2ed804", + "ml.arpa.bin": "3f0cfbf0bdc6935229d6903df8cb60b4ed2b9ed2cb9d4c253266b13bd3211297", + "mn.arpa.bin": "c8e57fcf604d178d45fbe3b1650c04e715c41cb8151bf8b115dc88c52ebfba56", + "mr.arpa.bin": "e00986484585cd67deba5902c7da78566452e3c40fc9aa285218152563d33303", + "my.arpa.bin": "ac3496e2981ea3ad85673ca52e04f5aa8e7be68d1d94c2e73ce26436864ae217", + "ne.arpa.bin": "7ef6c2d3e4e1858fb207e6c200e422833ccf072157a6a0148b408db3e760d22e", + "nl.arpa.bin": 
"aa017d97061e84f51d7f74b83a6a43aef246974fc9a502436043f6f0e9e12bbb", + "no.arpa.bin": "0ec663c264d6580beebe7e0e80a939dbe7082af55af3875f292ebd11ea5800de", + "pl.arpa.bin": "b97634bca2b28d95716b951ceadca3de4a170ff07639bcdc3c73fc0961362e98", + "pt.arpa.bin": "f5a10774d7b7125c6e887b62c56fea2d348adebc81ab1708d34f68de722090e0", + "ro.arpa.bin": "619b9a2d4d53bdb368bfdf2cc770e1e9549d52b22d1fd3afc0ee8a022543ed56", + "ru.arpa.bin": "588da7d3e160f61f7e821804bc4d518460687e1c4832c339bb3a28c03417ab53", + "uk.arpa.bin": "bfd09bdfe669a9fd5f8f8d9be519bdce3fb678214bc6afd5ccce499930b7d311", + "zh.arpa.bin": "f157d94cb2828bbb44b5dddf38e7eb7f62a47d317917646a73fe2af50a3dad68", } -def _get_ccnet_pretrained_lm(lang: str): + +def _get_ccnet_pretrained_lm(lang: str) -> Path: # Download pretrained model and save to the data folder url = f"http://dl.fbaipublicfiles.com/cc_net/lm/{lang}.arpa.bin" - data_folder = "data_lm" + data_folder = Path("data_lm") - if not os.path.exists(data_folder): - os.makedirs(data_folder) + if not Path.exists(data_folder): + Path.mkdir(data_folder, parents=True) filename = f"{lang}.arpa.bin" - file_path = os.path.join(data_folder, filename) + file_path = data_folder / filename # Check if the file already exists - if not os.path.exists(file_path): + if not Path.exists(file_path): # If the file does not exist, download it logging.info(f"Downloading {lang} model...") response = requests.get(url) if response.status_code == requests.codes.ok: sha256 = hashlib.sha256(response.content).hexdigest() if sha256 != ccnet_sha256[filename]: - raise RuntimeError(f"Checksum mismatch {sha256} != {ccnet_sha256[filename]}") - with open(file_path, "wb") as file: + raise RuntimeError( + f"Checksum mismatch {sha256} != {ccnet_sha256[filename]}" + ) + with Path.open(file_path, "wb") as file: file.write(response.content) logging.info(f"{lang} model downloaded and saved at {file_path}") else: - raise RuntimeError(f"Failed to download {lang} model. Status code: {response.status_code}") + raise RuntimeError( + f"Failed to download {lang} model. Status code: {response.status_code}" + ) else: logging.info(f"{lang} model already exists at {file_path}") return file_path + def pp(log_score: float, length: float) -> float: """Convert total log-probability to perplexity""" return 10.0 ** (-log_score / length) + def create_ccnet_perplexity_tagger(lang: str) -> type[BaseTagger]: """Dynamically create tagger class for a given language""" T = TypeVar("T") + def __init__(self: T) -> T: model_bin_path = _get_ccnet_pretrained_lm(lang) self.model = kenlm.Model(model_bin_path) @@ -119,7 +128,9 @@ def predict(self: BaseTagger, doc: Document) -> DocResult: for paragraph in paragraphs: # To get proper scores from the language model we need to normalize the text # Do not remove accents as it removes æøå and others. - normalized_text = blingfire.normalize_spaces(normalize(paragraph.text, accent=False)) + normalized_text = blingfire.normalize_spaces( + normalize(paragraph.text, accent=False) + ) # The kenlm model expects end of sentence punctuation to be separated from words with spaces # so we separate the words using blingfire. 
normalized_words = blingfire.text_to_words(normalized_text) @@ -128,17 +139,24 @@ def predict(self: BaseTagger, doc: Document) -> DocResult: doc_log_prob += log_prob doc_length += length paragraph_span = Span( - start=paragraph.start, end=paragraph.end, type="perplexity", score=pp(log_prob, length), + start=paragraph.start, + end=paragraph.end, + type="perplexity", + score=pp(log_prob, length), ) spans.append(paragraph_span) paragraph_span = Span( - start=0, end=len(doc.text), type="doc_perplexity", score=pp(doc_log_prob, doc_length), + start=0, + end=len(doc.text), + type="doc_perplexity", + score=pp(doc_log_prob, doc_length), ) return DocResult(doc=doc, spans=spans) cls = type( - f"CCNetPerplexity{lang}", (BaseTagger, ), + f"CCNetPerplexity{lang}", + (BaseTagger,), { "__init__": __init__, "predict": predict, @@ -147,5 +165,6 @@ def predict(self: BaseTagger, doc: Document) -> DocResult: cls = TaggerRegistry.add(f"ccnet_perplexity_paragraph_w_doc_{lang}")(cls) return cls + for lang in ["da", "en", "is", "no", "sv"]: create_ccnet_perplexity_tagger(lang) diff --git a/src/dfm/common/data_cleaning/text_normalizer.py b/src/dfm/common/data_cleaning/text_normalizer.py index 057a248e..ed936725 100644 --- a/src/dfm/common/data_cleaning/text_normalizer.py +++ b/src/dfm/common/data_cleaning/text_normalizer.py @@ -1,11 +1,16 @@ +# This file has initially been copied from the ccnet repository from Facebook. +# # Copyright (c) Facebook, Inc. and its affiliates. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. # +# This file is full of ambigous characters, so disable ruff check for those. +# ruff: noqa: RUF001 import re import unicodedata +from typing import Literal UNICODE_PUNCT = { ",": ",", @@ -139,7 +144,9 @@ def normalize_spacing_for_tok(text: str, language: str = "en") -> str: else: res = res.replace(',"', '",') res = re.sub( - r"(\.+)\"(\s*[^<])", r"\"\1\2", res, + r"(\.+)\"(\s*[^<])", + r"\"\1\2", + res, ) # don't fix period at end of sentence if ( @@ -155,7 +162,13 @@ def normalize_spacing_for_tok(text: str, language: str = "en") -> str: return res -def normalize(line: str, accent=True, case=True, numbers=True, punct=1) -> str: +def normalize( + line: str, + accent: bool = True, + case: bool = True, + numbers: bool = True, + punct: Literal[1, 2] = 1, +) -> str: line = line.strip() if not line: return line From 34f0d9206ac9e6029e81a4de5ee1be68d5cd1d48 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Peter=20Bj=C3=B8rn=20J=C3=B8rgensen?= Date: Tue, 7 Nov 2023 16:55:38 +0100 Subject: [PATCH 10/21] minor fixes --- src/dfm/common/data_cleaning/dolma_taggers/perplexity.py | 1 - src/dfm/common/data_cleaning/text_normalizer.py | 2 -- 2 files changed, 3 deletions(-) diff --git a/src/dfm/common/data_cleaning/dolma_taggers/perplexity.py b/src/dfm/common/data_cleaning/dolma_taggers/perplexity.py index bf8d56a1..b2fcab2d 100644 --- a/src/dfm/common/data_cleaning/dolma_taggers/perplexity.py +++ b/src/dfm/common/data_cleaning/dolma_taggers/perplexity.py @@ -5,7 +5,6 @@ """ import hashlib import logging -import os from pathlib import Path from typing import TypeVar diff --git a/src/dfm/common/data_cleaning/text_normalizer.py b/src/dfm/common/data_cleaning/text_normalizer.py index ed936725..601e06c0 100644 --- a/src/dfm/common/data_cleaning/text_normalizer.py +++ b/src/dfm/common/data_cleaning/text_normalizer.py @@ -65,8 +65,6 @@ def strip_accents(line: str) -> str: """Strips accents from a piece of text.""" nfd = unicodedata.normalize("NFD", 
line) output = [c for c in nfd if unicodedata.category(c) != "Mn"] - if len(output) == line: - return line return "".join(output) From aaa5663f9881802947d497d25f499a5d71c00ee1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Peter=20Bj=C3=B8rn=20J=C3=B8rgensen?= Date: Wed, 8 Nov 2023 10:43:09 +0100 Subject: [PATCH 11/21] minor cleanup --- .../common/data_cleaning/dolma_taggers/language_scandi.py | 2 +- src/dfm/common/data_cleaning/dolma_taggers/perplexity.py | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/dfm/common/data_cleaning/dolma_taggers/language_scandi.py b/src/dfm/common/data_cleaning/dolma_taggers/language_scandi.py index a5d14fce..eff23c1c 100644 --- a/src/dfm/common/data_cleaning/dolma_taggers/language_scandi.py +++ b/src/dfm/common/data_cleaning/dolma_taggers/language_scandi.py @@ -49,7 +49,7 @@ def _predict_text(self, text: str) -> dict[str, float]: scores: dict[str, float] = {} if is_reliable: - for lang, lang_code, score, _ in details: + for lang, _, score, _ in details: if lang in LANGS: scores[LANGS[lang]] = score / 100.0 diff --git a/src/dfm/common/data_cleaning/dolma_taggers/perplexity.py b/src/dfm/common/data_cleaning/dolma_taggers/perplexity.py index b2fcab2d..232aeea7 100644 --- a/src/dfm/common/data_cleaning/dolma_taggers/perplexity.py +++ b/src/dfm/common/data_cleaning/dolma_taggers/perplexity.py @@ -112,7 +112,7 @@ def pp(log_score: float, length: float) -> float: def create_ccnet_perplexity_tagger(lang: str) -> type[BaseTagger]: """Dynamically create tagger class for a given language""" - T = TypeVar("T") + T = TypeVar("T", bound=BaseTagger) def __init__(self: T) -> T: model_bin_path = _get_ccnet_pretrained_lm(lang) @@ -153,6 +153,8 @@ def predict(self: BaseTagger, doc: Document) -> DocResult: ) return DocResult(doc=doc, spans=spans) + # Build the class dynamiccaly from base class + # and methods. 
cls = type( f"CCNetPerplexity{lang}", (BaseTagger,), @@ -161,6 +163,7 @@ def predict(self: BaseTagger, doc: Document) -> DocResult: "predict": predict, }, ) + # Add the class decorator explicitly to add the tagger to the registry cls = TaggerRegistry.add(f"ccnet_perplexity_paragraph_w_doc_{lang}")(cls) return cls From 6ba0fc818e7f81ed96483c3556d55bee979fc241 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Peter=20Bj=C3=B8rn=20J=C3=B8rgensen?= Date: Wed, 8 Nov 2023 11:27:20 +0100 Subject: [PATCH 12/21] convert Posix path to string for kenlm.Model and improve typing --- .../data_cleaning/dolma_taggers/perplexity.py | 23 +++++++++++-------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/src/dfm/common/data_cleaning/dolma_taggers/perplexity.py b/src/dfm/common/data_cleaning/dolma_taggers/perplexity.py index 232aeea7..b7ae4d49 100644 --- a/src/dfm/common/data_cleaning/dolma_taggers/perplexity.py +++ b/src/dfm/common/data_cleaning/dolma_taggers/perplexity.py @@ -6,7 +6,7 @@ import hashlib import logging from pathlib import Path -from typing import TypeVar +from typing import Self, Any import blingfire import kenlm @@ -110,16 +110,21 @@ def pp(log_score: float, length: float) -> float: return 10.0 ** (-log_score / length) -def create_ccnet_perplexity_tagger(lang: str) -> type[BaseTagger]: - """Dynamically create tagger class for a given language""" - T = TypeVar("T", bound=BaseTagger) +class PerplexityBaseTagger(BaseTagger): + @property + def model(self: Self) -> kenlm.Model: + return self._model + @model.setter + def model(self: Self, model: kenlm.Model): + self._model = model - def __init__(self: T) -> T: +def create_ccnet_perplexity_tagger(lang: str) -> type[PerplexityBaseTagger]: + """Dynamically create tagger class for a given language""" + def __init__(self: Any) -> None: model_bin_path = _get_ccnet_pretrained_lm(lang) - self.model = kenlm.Model(model_bin_path) - return self + self.model = kenlm.Model(str(model_bin_path)) - def predict(self: BaseTagger, doc: Document) -> DocResult: + def predict(self: PerplexityBaseTagger, doc: Document) -> DocResult: paragraphs = split_paragraphs(doc.text) spans: list[Span] = [] doc_log_prob: float = 0.0 @@ -157,7 +162,7 @@ def predict(self: BaseTagger, doc: Document) -> DocResult: # and methods. 
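# Note: type(name, bases, namespace) builds the class at runtime, so each generated
# tagger subclasses PerplexityBaseTagger and receives the closures defined above
# (which capture `lang`) as its __init__ and predict methods.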
cls = type( f"CCNetPerplexity{lang}", - (BaseTagger,), + (PerplexityBaseTagger,), { "__init__": __init__, "predict": predict, From ddf43fa9d15ed460febbfe5b23004511255c1e24 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Peter=20Bj=C3=B8rn=20J=C3=B8rgensen?= Date: Wed, 8 Nov 2023 13:51:38 +0100 Subject: [PATCH 13/21] black format --- .../data_cleaning/dolma_taggers/language_scandi.py | 7 +++++-- .../common/data_cleaning/dolma_taggers/perplexity.py | 11 +++++++---- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/src/dfm/common/data_cleaning/dolma_taggers/language_scandi.py b/src/dfm/common/data_cleaning/dolma_taggers/language_scandi.py index eff23c1c..cfc124aa 100644 --- a/src/dfm/common/data_cleaning/dolma_taggers/language_scandi.py +++ b/src/dfm/common/data_cleaning/dolma_taggers/language_scandi.py @@ -62,7 +62,10 @@ def predict(self, doc: Document) -> DocResult: score = lang_scores.get(lang_code, 0) positive_span = Span( - start=0, end=len(doc.text), type=lang_code, score=score + start=0, + end=len(doc.text), + type=lang_code, + score=score, ) negative_span = Span( start=0, @@ -183,7 +186,7 @@ def add_global_language_score_from_slice_score(result: DocResult) -> DocResult: # Composite tagger that provides both paragraph and doc scores @TaggerRegistry.add("cld2_scandi_paragraph_with_doc_score") class Cld2LanguageFilterParagraphWithDocScoreTaggerScandi( - Cld2LanguageFilterParagraphScandi + Cld2LanguageFilterParagraphScandi, ): def predict(self, doc: Document) -> DocResult: doc_result = super().predict(doc) diff --git a/src/dfm/common/data_cleaning/dolma_taggers/perplexity.py b/src/dfm/common/data_cleaning/dolma_taggers/perplexity.py index b7ae4d49..12a76d54 100644 --- a/src/dfm/common/data_cleaning/dolma_taggers/perplexity.py +++ b/src/dfm/common/data_cleaning/dolma_taggers/perplexity.py @@ -6,7 +6,7 @@ import hashlib import logging from pathlib import Path -from typing import Self, Any +from typing import Any, Self import blingfire import kenlm @@ -90,14 +90,14 @@ def _get_ccnet_pretrained_lm(lang: str) -> Path: sha256 = hashlib.sha256(response.content).hexdigest() if sha256 != ccnet_sha256[filename]: raise RuntimeError( - f"Checksum mismatch {sha256} != {ccnet_sha256[filename]}" + f"Checksum mismatch {sha256} != {ccnet_sha256[filename]}", ) with Path.open(file_path, "wb") as file: file.write(response.content) logging.info(f"{lang} model downloaded and saved at {file_path}") else: raise RuntimeError( - f"Failed to download {lang} model. Status code: {response.status_code}" + f"Failed to download {lang} model. Status code: {response.status_code}", ) else: logging.info(f"{lang} model already exists at {file_path}") @@ -114,12 +114,15 @@ class PerplexityBaseTagger(BaseTagger): @property def model(self: Self) -> kenlm.Model: return self._model + @model.setter def model(self: Self, model: kenlm.Model): self._model = model + def create_ccnet_perplexity_tagger(lang: str) -> type[PerplexityBaseTagger]: """Dynamically create tagger class for a given language""" + def __init__(self: Any) -> None: model_bin_path = _get_ccnet_pretrained_lm(lang) self.model = kenlm.Model(str(model_bin_path)) @@ -133,7 +136,7 @@ def predict(self: PerplexityBaseTagger, doc: Document) -> DocResult: # To get proper scores from the language model we need to normalize the text # Do not remove accents as it removes æøå and others. 
normalized_text = blingfire.normalize_spaces( - normalize(paragraph.text, accent=False) + normalize(paragraph.text, accent=False), ) # The kenlm model expects end of sentence punctuation to be separated from words with spaces # so we separate the words using blingfire. From 530305f5f8d349f1e2f8b3c452983f6d64adc776 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Peter=20Bj=C3=B8rn=20J=C3=B8rgensen?= Date: Mon, 13 Nov 2023 15:45:08 +0100 Subject: [PATCH 14/21] Add type stubs for external libraries (only partially typed for the functions we use) --- .../dolma_taggers/language_scandi.py | 6 +- typings/blingfire/__init__.pyi | 54 ++++ typings/fasttext/FastText.pyi | 264 ++++++++++++++++++ typings/fasttext/__init__.pyi | 17 ++ typings/fasttext/tests/__init__.pyi | 7 + .../fasttext/tests/test_configurations.pyi | 26 ++ typings/fasttext/tests/test_script.pyi | 82 ++++++ typings/fasttext/util/__init__.pyi | 6 + typings/fasttext/util/util.pyi | 39 +++ typings/kenlm/__init__.pyi | 7 + typings/pycld2/__init__.pyi | 21 ++ 11 files changed, 527 insertions(+), 2 deletions(-) create mode 100644 typings/blingfire/__init__.pyi create mode 100644 typings/fasttext/FastText.pyi create mode 100644 typings/fasttext/__init__.pyi create mode 100644 typings/fasttext/tests/__init__.pyi create mode 100644 typings/fasttext/tests/test_configurations.pyi create mode 100644 typings/fasttext/tests/test_script.pyi create mode 100644 typings/fasttext/util/__init__.pyi create mode 100644 typings/fasttext/util/util.pyi create mode 100644 typings/kenlm/__init__.pyi create mode 100644 typings/pycld2/__init__.pyi diff --git a/src/dfm/common/data_cleaning/dolma_taggers/language_scandi.py b/src/dfm/common/data_cleaning/dolma_taggers/language_scandi.py index cfc124aa..11915793 100644 --- a/src/dfm/common/data_cleaning/dolma_taggers/language_scandi.py +++ b/src/dfm/common/data_cleaning/dolma_taggers/language_scandi.py @@ -38,11 +38,13 @@ def _identity_fn(self, text: str) -> str: return text def _predict_text(self, text: str) -> dict[str, float]: - details = [] is_reliable = False + details: Iterable[tuple[str, str, int, float]] = [] for fn in (self._identity_fn, self._to_ascii_input, self._sanitize_input): try: - is_reliable, _, details = cld2.detect(fn(text)) + retvals = cld2.detect(fn(text)) + assert len(retvals) == 3 + is_reliable, _, details = retvals break except cld2.error: ... diff --git a/typings/blingfire/__init__.pyi b/typings/blingfire/__init__.pyi new file mode 100644 index 00000000..20da73f7 --- /dev/null +++ b/typings/blingfire/__init__.pyi @@ -0,0 +1,54 @@ +""" +This type stub file was generated by pyright. +""" +# def text_to_sentences(s: str): # -> Any | Literal['']: +# ... +# +# def text_to_sentences_with_model(h, s): # -> Any | Literal['']: +# ... + +def normalize_spaces(s: str, uSpace: int = 0x20) -> str: # -> Any | Literal['']: + ... + +def text_to_words(s: str) -> str: # -> Any | Literal['']: + ... + +# Uncomment lines that are used in project +# def text_to_words_with_model(h, s): # -> Any | Literal['']: +# ... +# +# def word_hyphenation_with_model(h, s, uHy=...): # -> Any | Literal['']: +# ... +# +# def get_blingfiretok_version(): # -> Any: +# ... +# +# def text_to_hashes(s, word_n_grams, bucketSize): # -> NDArray[Any] | None: +# ... +# +# def text_to_token_with_offsets(s, text_to_token_f, split_byte): # -> tuple[Literal[''], list[Unknown]] | tuple[Any, list[tuple[Unknown, Unknown]]]: +# ... 
+# +# def text_to_words_with_offsets(s): # -> tuple[Literal[''], list[Unknown]] | tuple[Any, list[tuple[Unknown, Unknown]]]: +# ... +# +# def text_to_sentences_and_offsets(s): # -> tuple[Literal[''], list[Unknown]] | tuple[Any, list[tuple[Unknown, Unknown]]]: +# ... +# +# def load_model(file_name): # -> Any: +# ... +# +# def free_model(h): # -> None: +# ... +# +# def text_to_ids(h, s, max_len, unk=..., no_padding=...): # -> NDArray[Any]: +# ... +# +# def ids_to_text(h, ids, skip_special_tokens=..., output_buffer_size=...): # -> Any | Literal['']: +# ... +# +# def utf8text_to_ids_with_offsets(h, s_bytes, max_len, unk=..., no_padding=...): # -> tuple[NDArray[Any], NDArray[Any], NDArray[Any]]: +# ... +# +# def change_settings_dummy_prefix(h, add_prefix): # -> None: +# ... diff --git a/typings/fasttext/FastText.pyi b/typings/fasttext/FastText.pyi new file mode 100644 index 00000000..4dde0e7b --- /dev/null +++ b/typings/fasttext/FastText.pyi @@ -0,0 +1,264 @@ +""" +This type stub file was initially generated by pyright +""" +from typing import Iterable + +loss_name = ... +model_name = ... +EOS = ... +BOW = ... +EOW = ... +displayed_errors = ... + +def eprint(*args, **kwargs): # -> None: + ... + +class _Meter: + def __init__(self, fasttext_model, meter) -> None: ... + def score_vs_true(self, label): # -> tuple[NDArray[Unknown], NDArray[Any]]: + """Return scores and the gold of each sample for a specific label""" + ... + def precision_recall_curve( + self, label=... + ): # -> tuple[NDArray[Unknown], NDArray[Any]]: + """Return precision/recall curve""" + ... + def precision_at_recall(self, recall, label=...): + """Return precision for a given recall""" + ... + def recall_at_precision(self, precision, label=...): + """Return recall for a given precision""" + ... + +class _FastText: + """ + This class defines the API to inspect models and should not be used to + create objects. It will be returned by functions such as load_model or + train. + + In general this API assumes to be given only unicode for Python2 and the + Python3 equvalent called str for any string-like arguments. All unicode + strings are then encoded as UTF-8 and fed to the fastText C++ API. + """ + + def __init__(self, model_path=..., args=...) -> None: ... + def set_args(self, args=...): # -> None: + ... + def is_quantized(self): ... + def get_dimension(self): + """Get the dimension (size) of a lookup vector (hidden layer).""" + ... + def get_word_vector(self, word): # -> NDArray[Unknown]: + """Get the vector representation of word.""" + ... + def get_sentence_vector(self, text): # -> NDArray[Unknown]: + """ + Given a string, get a single vector represenation. This function + assumes to be given a single line of text. We split words on + whitespace (space, newline, tab, vertical tab) and the control + characters carriage return, formfeed and the null character. + """ + ... + def get_nearest_neighbors(self, word, k=..., on_unicode_error=...): ... + def get_analogies(self, wordA, wordB, wordC, k=..., on_unicode_error=...): ... + def get_word_id(self, word): + """ + Given a word, get the word id within the dictionary. + Returns -1 if word is not in the dictionary. + """ + ... + def get_label_id(self, label): + """ + Given a label, get the label id within the dictionary. + Returns -1 if label is not in the dictionary. + """ + ... + def get_subword_id(self, subword): + """ + Given a subword, return the index (within input matrix) it hashes to. + """ + ... + def get_subwords( + self, word, on_unicode_error=... 
+ ): # -> tuple[Unknown, NDArray[Unknown]]: + """ + Given a word, get the subwords and their indicies. + """ + ... + def get_input_vector(self, ind): # -> NDArray[Unknown]: + """ + Given an index, get the corresponding vector of the Input Matrix. + """ + ... + def predict( + self, + text: str, + k: int = ..., + threshold: float = ..., + on_unicode_error: str = ..., + ) -> Iterable[ + tuple[str, float] + ]: # -> tuple[Unknown, Unknown] | tuple[Any | tuple[()], NDArray[Unknown]]: + """ + Given a string, get a list of labels and a list of + corresponding probabilities. k controls the number + of returned labels. A choice of 5, will return the 5 + most probable labels. By default this returns only + the most likely label and probability. threshold filters + the returned labels by a threshold on probability. A + choice of 0.5 will return labels with at least 0.5 + probability. k and threshold will be applied together to + determine the returned labels. + + This function assumes to be given + a single line of text. We split words on whitespace (space, + newline, tab, vertical tab) and the control characters carriage + return, formfeed and the null character. + + If the model is not supervised, this function will throw a ValueError. + + If given a list of strings, it will return a list of results as usually + received for a single line of text. + """ + ... + def get_input_matrix(self): # -> NDArray[Unknown]: + """ + Get a reference to the full input matrix of a Model. This only + works if the model is not quantized. + """ + ... + def get_output_matrix(self): # -> NDArray[Unknown]: + """ + Get a reference to the full output matrix of a Model. This only + works if the model is not quantized. + """ + ... + def get_words( + self, include_freq=..., on_unicode_error=... + ): # -> tuple[Unknown, NDArray[Unknown]]: + """ + Get the entire list of words of the dictionary optionally + including the frequency of the individual words. This + does not include any subwords. For that please consult + the function get_subwords. + """ + ... + def get_labels( + self, include_freq=..., on_unicode_error=... + ): # -> tuple[Unknown, NDArray[Unknown]]: + """ + Get the entire list of labels of the dictionary optionally + including the frequency of the individual labels. Unsupervised + models use words as labels, which is why get_labels + will call and return get_words for this type of + model. + """ + ... + def get_line(self, text, on_unicode_error=...): + """ + Split a line of text into words and labels. Labels must start with + the prefix used to create the model (__label__ by default). + """ + ... + def save_model(self, path): # -> None: + """Save the model to the given path""" + ... + def test(self, path, k=..., threshold=...): + """Evaluate supervised model using file given by path""" + ... + def test_label(self, path, k=..., threshold=...): + """ + Return the precision and recall score for each label. + + The returned value is a dictionary, where the key is the label. + For example: + f.test_label(...) + {'__label__italian-cuisine' : {'precision' : 0.7, 'recall' : 0.74}} + """ + ... + def get_meter(self, path, k=...): # -> _Meter: + ... + def quantize( + self, + input=..., + qout=..., + cutoff=..., + retrain=..., + epoch=..., + lr=..., + thread=..., + verbose=..., + dsub=..., + qnorm=..., + ): # -> None: + """ + Quantize the model reducing the size of the model and + it's memory footprint. + """ + ... + def set_matrices(self, input_matrix, output_matrix): # -> None: + """ + Set input and output matrices. 
This function assumes you know what you + are doing. + """ + ... + @property + def words(self): # -> tuple[Unknown, NDArray[Unknown]]: + ... + @property + def labels(self): # -> tuple[Unknown, NDArray[Unknown]]: + ... + def __getitem__(self, word): # -> NDArray[Unknown]: + ... + def __contains__(self, word): # -> bool: + ... + +def tokenize(text): + """Given a string of text, tokenize it and return a list of tokens""" + ... + +def load_model(path): # -> _FastText: + """Load a model given a filepath and return a model object.""" + ... + +unsupervised_default = ... + +def read_args( + arg_list, arg_dict, arg_names, default_values +): # -> tuple[dict[Unknown, Unknown], set[Unknown]]: + ... + +def train_supervised(*kargs, **kwargs): # -> _FastText: + """ + Train a supervised model and return a model object. + + input must be a filepath. The input text does not need to be tokenized + as per the tokenize function, but it must be preprocessed and encoded + as UTF-8. You might want to consult standard preprocessing scripts such + as tokenizer.perl mentioned here: http://www.statmt.org/wmt07/baseline.html + + The input file must must contain at least one label per line. For an + example consult the example datasets which are part of the fastText + repository such as the dataset pulled by classification-example.sh. + """ + ... + +def train_unsupervised(*kargs, **kwargs): # -> _FastText: + """ + Train an unsupervised model and return a model object. + + input must be a filepath. The input text does not need to be tokenized + as per the tokenize function, but it must be preprocessed and encoded + as UTF-8. You might want to consult standard preprocessing scripts such + as tokenizer.perl mentioned here: http://www.statmt.org/wmt07/baseline.html + + The input field must not contain any labels or use the specified label prefix + unless it is ok for those words to be ignored. For an example consult the + dataset pulled by the example script word-vector-example.sh, which is + part of the fastText repository. + """ + ... + +def cbow(*kargs, **kwargs): ... +def skipgram(*kargs, **kwargs): ... +def supervised(*kargs, **kwargs): ... diff --git a/typings/fasttext/__init__.pyi b/typings/fasttext/__init__.pyi new file mode 100644 index 00000000..cbf98ace --- /dev/null +++ b/typings/fasttext/__init__.pyi @@ -0,0 +1,17 @@ +""" +This type stub file was generated by pyright. +""" + +from __future__ import absolute_import, division, print_function, unicode_literals +from .FastText import ( + BOW, + EOS, + EOW, + cbow, + load_model, + skipgram, + supervised, + tokenize, + train_supervised, + train_unsupervised, +) diff --git a/typings/fasttext/tests/__init__.pyi b/typings/fasttext/tests/__init__.pyi new file mode 100644 index 00000000..8cec2d4a --- /dev/null +++ b/typings/fasttext/tests/__init__.pyi @@ -0,0 +1,7 @@ +""" +This type stub file was generated by pyright. +""" + +from __future__ import absolute_import, division, print_function, unicode_literals +from .test_configurations import get_supervised_models +from .test_script import gen_tests, gen_unit_tests diff --git a/typings/fasttext/tests/test_configurations.pyi b/typings/fasttext/tests/test_configurations.pyi new file mode 100644 index 00000000..6930c565 --- /dev/null +++ b/typings/fasttext/tests/test_configurations.pyi @@ -0,0 +1,26 @@ +""" +This type stub file was generated by pyright. +""" + +def max_thread(): # -> int: + ... + +def check_supervised_configuration(configuration, verbose=...): ... 
+def check_supervised_configurations(configurations, verbose=...): ... +def flickr_job(thread=...): # -> dict[Unknown, Unknown]: + ... + +def langid_job1(thread=...): # -> dict[Unknown, Unknown]: + ... + +def langid_job2(thread=...): # -> dict[Unknown, Unknown]: + ... + +def cooking_job1(thread=...): # -> dict[Unknown, Unknown]: + ... + +def cooking_job2(thread=...): # -> dict[Unknown, Unknown]: + ... + +def get_supervised_models(thread=..., verbose=...): # -> list[Unknown]: + ... diff --git a/typings/fasttext/tests/test_script.pyi b/typings/fasttext/tests/test_script.pyi new file mode 100644 index 00000000..632d8d85 --- /dev/null +++ b/typings/fasttext/tests/test_script.pyi @@ -0,0 +1,82 @@ +""" +This type stub file was generated by pyright. +""" + +import unittest + +def eprint(cls, *args, **kwargs): # -> None: + ... + +def get_random_unicode(length): # -> str: + ... + +def get_random_words(N, a=..., b=..., unique=...): # -> list[Unknown]: + ... + +def get_random_data( + num_lines=..., + max_vocab_size=..., + min_words_line=..., + max_words_line=..., + min_len_word=..., + max_len_word=..., + unique_words=..., +): # -> list[Unknown]: + ... + +def default_kwargs(kwargs): ... +def build_unsupervised_model(data, kwargs): # -> _FastText: + ... + +def build_supervised_model(data, kwargs): # -> _FastText: + ... + +def read_labels(data_file): # -> tuple[list[Unknown], list[Unknown]]: + ... + +class TestFastTextUnitPy(unittest.TestCase): + def gen_test_get_vector(self, kwargs): # -> None: + ... + def gen_test_multi_get_line(self, kwargs): # -> None: + ... + def gen_test_supervised_util_test(self, kwargs): # -> None: + ... + def gen_test_supervised_predict(self, kwargs): # -> None: + ... + def gen_test_supervised_multiline_predict(self, kwargs): # -> None: + ... + def gen_test_vocab(self, kwargs): # -> None: + ... + def gen_test_subwords(self, kwargs): # -> None: + ... + def gen_test_tokenize(self, kwargs): # -> None: + ... + def gen_test_unsupervised_dimension(self, kwargs): # -> None: + ... + def gen_test_supervised_dimension(self, kwargs): # -> None: + ... + def gen_test_subword_vector(self, kwargs): # -> None: + ... + def gen_test_unsupervised_get_words(self, kwargs): # -> None: + ... + def gen_test_supervised_get_words(self, kwargs): # -> None: + ... + def gen_test_unsupervised_get_labels(self, kwargs): # -> None: + ... + def gen_test_supervised_get_labels(self, kwargs): # -> None: + ... + def gen_test_unsupervised_exercise_is_quant(self, kwargs): # -> None: + ... + def gen_test_supervised_exercise_is_quant(self, kwargs): # -> None: + ... + def gen_test_newline_predict_sentence(self, kwargs): # -> None: + ... + +def gen_sup_test(configuration, data_dir): # -> (self: Unknown) -> None: + ... + +def gen_unit_tests(verbose=...): # -> type[TestFastTextUnitPy]: + ... + +def gen_tests(data_dir, verbose=...): # -> type[TestFastTextPy]: + class TestFastTextPy(unittest.TestCase): ... diff --git a/typings/fasttext/util/__init__.pyi b/typings/fasttext/util/__init__.pyi new file mode 100644 index 00000000..87465a71 --- /dev/null +++ b/typings/fasttext/util/__init__.pyi @@ -0,0 +1,6 @@ +""" +This type stub file was generated by pyright. 
+""" + +from __future__ import absolute_import, division, print_function, unicode_literals +from .util import download_model, find_nearest_neighbor, reduce_model, test diff --git a/typings/fasttext/util/util.pyi b/typings/fasttext/util/util.pyi new file mode 100644 index 00000000..16d2c81a --- /dev/null +++ b/typings/fasttext/util/util.pyi @@ -0,0 +1,39 @@ +""" +This type stub file was generated by pyright. +""" + +valid_lang_ids = ... + +def test(predictions, labels, k=...): # -> tuple[float, float]: + """ + Return precision and recall modeled after fasttext's test + """ + ... + +def find_nearest_neighbor(query, vectors, ban_set, cossims=...): # -> Any: + """ + query is a 1d numpy array corresponding to the vector to which you want to + find the closest vector + vectors is a 2d numpy array corresponding to the vectors you want to consider + ban_set is a set of indicies within vectors you want to ignore for nearest match + cossims is a 1d numpy array of size len(vectors), which can be passed for efficiency + + returns the index of the closest match to query within vectors + + """ + ... + +def reduce_model(ft_model, target_dim): + """ + ft_model is an instance of `_FastText` class + This function computes the PCA of the input and the output matrices + and sets the reduced ones. + """ + ... + +def download_model(lang_id, if_exists=..., dimension=...): # -> None: + """ + Download pre-trained common-crawl vectors from fastText's website + https://fasttext.cc/docs/en/crawl-vectors.html + """ + ... diff --git a/typings/kenlm/__init__.pyi b/typings/kenlm/__init__.pyi new file mode 100644 index 00000000..40f566db --- /dev/null +++ b/typings/kenlm/__init__.pyi @@ -0,0 +1,7 @@ +""" +Type stub for kenlm +""" + +class Model: + def __init__(self, model_bin_path: str) -> None: ... + def score(self, sentence: str) -> float: ... diff --git a/typings/pycld2/__init__.pyi b/typings/pycld2/__init__.pyi new file mode 100644 index 00000000..5b478a63 --- /dev/null +++ b/typings/pycld2/__init__.pyi @@ -0,0 +1,21 @@ +""" +Type stub file for pycld2 +""" + +from typing import Union, TypeAlias + +from pycld2 import DETECTED_LANGUAGES, ENCODINGS, LANGUAGES, VERSION, __version__, error + +IsReliable: TypeAlias = bool +TextBytesFound: TypeAlias = int +DetectDetails: TypeAlias = tuple[tuple[str, str, int, float], ...] +Vectors: TypeAlias = tuple[tuple[int, int, str, str], ...] + +def detect( + text: str, returnVectors: bool = False +) -> Union[ + tuple[IsReliable, TextBytesFound, DetectDetails], + tuple[IsReliable, TextBytesFound, DetectDetails, Vectors], +]: ... 
+ +__all__ = ("DETECTED_LANGUAGES", "ENCODINGS", "LANGUAGES", "VERSION", "detect", "error") From 48fcd12ec378c5ab4d841687be00050de514d4a7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Peter=20Bj=C3=B8rn=20J=C3=B8rgensen?= Date: Mon, 13 Nov 2023 16:52:15 +0100 Subject: [PATCH 15/21] fix bug in fasttext Scandinavian tagger --- .../common/data_cleaning/dolma_taggers/language_scandi.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/dfm/common/data_cleaning/dolma_taggers/language_scandi.py b/src/dfm/common/data_cleaning/dolma_taggers/language_scandi.py index 11915793..1e8e2b8f 100644 --- a/src/dfm/common/data_cleaning/dolma_taggers/language_scandi.py +++ b/src/dfm/common/data_cleaning/dolma_taggers/language_scandi.py @@ -129,15 +129,10 @@ def predict_slice(self, text_slice: TextSlice) -> Iterable[Prediction]: label_code = label[-2:] if label_code in scores: scores[label_code] = score - if label == "__label__da": - return Prediction(label="da", score=score), Prediction( - label="not_da", - score=1.0 - score, - ) predictions_positive = [Prediction(label=k, score=v) for k, v in scores.items()] predictions_negative = [ - Prediction(label=k, score=1.0 - v) for k, v in scores.items() + Prediction(label=f"not_{k}", score=1.0 - v) for k, v in scores.items() ] return predictions_positive + predictions_negative From cb13d62dc8c3f7ef2525c24f2f94b09593425d96 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Peter=20Bj=C3=B8rn=20J=C3=B8rgensen?= Date: Tue, 14 Nov 2023 11:18:02 +0100 Subject: [PATCH 16/21] add version numbers to dependencies --- pyproject.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index fdce88dd..aacb6838 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,9 +19,9 @@ requires-python = ">=3.10" dependencies = [ "pydantic>=2.4.2", # dolma does not work with very old versions of pydantic "dolma@git+https://github.com/allenai/dolma.git", # Install from git until a 0.9.2 package is released - "kenlm", # Used for perplexity tagging - "blingfire", # Used for perplexity tagging - "requests", + "kenlm>=0.2.0", # Used for perplexity tagging + "blingfire>=0.1.8", # Used for perplexity tagging + "requests>=2.31.0", ] [project.optional-dependencies] From a10c5d2570e7f8e19a856cdeb32c9f188ad89e64 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Peter=20Bj=C3=B8rn=20J=C3=B8rgensen?= Date: Tue, 14 Nov 2023 16:02:54 +0100 Subject: [PATCH 17/21] add docstrings for language and perplexity taggers --- .../dolma_taggers/language_scandi.py | 45 ++++++++++++++++++- .../data_cleaning/dolma_taggers/perplexity.py | 19 +++++++- 2 files changed, 61 insertions(+), 3 deletions(-) diff --git a/src/dfm/common/data_cleaning/dolma_taggers/language_scandi.py b/src/dfm/common/data_cleaning/dolma_taggers/language_scandi.py index 1e8e2b8f..c39df461 100644 --- a/src/dfm/common/data_cleaning/dolma_taggers/language_scandi.py +++ b/src/dfm/common/data_cleaning/dolma_taggers/language_scandi.py @@ -20,12 +20,16 @@ "SWEDISH": "sv", "NORWEGIAN": "no", "ICELANDIC": "is", - "FAROESE": "fo", # Note that FAROESE is not supported by cld2 + "FAROESE": "fo", # Note that FAROESE is not supported by cld2 or fasttext } @TaggerRegistry.add("cld2_scandi_doc") class Cld2LanguageFilterScandi(BaseTagger): + """This tagger runs the Compact Language Detect 2 model on a full document + and will return a score between 0 and 1 for each language in LANGS. 
+ It uses the pretrained model from the pycld2 package.""" + RE_BAD_CHARS = regex.compile(r"[\p{Cc}\p{Cs}]+") def _sanitize_input(self, text: str) -> str: @@ -82,6 +86,9 @@ def predict(self, doc: Document) -> DocResult: @TaggerRegistry.add("cld2_scandi_paragraph") class Cld2LanguageFilterParagraphScandi(Cld2LanguageFilterScandi): + """This tagger runs the Compact Language Detect 2 model on each paragraph, + and will save a score between 0 and 1 for each language in LANGS""" + def predict(self, doc: Document) -> DocResult: paragraphs = split_paragraphs(doc.text) spans: list[Span] = [] @@ -108,6 +115,29 @@ def predict(self, doc: Document) -> DocResult: @TaggerRegistry.add("ft_lang_id_scandi_doc") class FastTextScandiLanguageDocumentTagger(BaseFastTextTagger): + """This tagger runs the FastText language detection model on each document. + The score is between 0 and 1 and provided for each language in LANGS. + + The method is described in the following papers: + + @article{joulin2016bag, + title={Bag of Tricks for Efficient Text Classification}, + author={Joulin, Armand and Grave, Edouard and Bojanowski, Piotr and Mikolov, Tomas}, + journal={arXiv preprint arXiv:1607.01759}, + year={2016} + } + @article{joulin2016fasttext, + title={FastText.zip: Compressing text classification models}, + author={Joulin, Armand and Grave, Edouard and Bojanowski, Piotr and Douze, Matthijs and J{\'e}gou, H{\'e}rve and Mikolov, Tomas}, + journal={arXiv preprint arXiv:1612.03651}, + year={2016} + } + + The pretrained model is automatically downloaded (link publically available at): + https://fasttext.cc/docs/en/language-identification.html + + """ + MODEL_PATH = "https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.bin" def __init__(self): @@ -140,6 +170,10 @@ def predict_slice(self, text_slice: TextSlice) -> Iterable[Prediction]: @TaggerRegistry.add("ft_lang_id_scandi_paragraph") class FastTextScandiLanguageParagraphTagger(FastTextScandiLanguageDocumentTagger): + """This tagger runs the FastText language detection model on each paragraph. + The score is between 0 and 1 and provided for each language in LANGS. + """ + def __init__(self): BaseFastTextTagger.__init__( self, @@ -185,6 +219,10 @@ def add_global_language_score_from_slice_score(result: DocResult) -> DocResult: class Cld2LanguageFilterParagraphWithDocScoreTaggerScandi( Cld2LanguageFilterParagraphScandi, ): + """This tagger runs the Compact Language Detect 2 model on each paragraph + and will also provide a total score for each document. + The score is between 0 and 1 and provided for each language in LANGS.""" + def predict(self, doc: Document) -> DocResult: doc_result = super().predict(doc) doc_result = add_global_language_score_from_slice_score(doc_result) @@ -196,6 +234,11 @@ def predict(self, doc: Document) -> DocResult: class FastTextScandiLanguageParagraphWithDocScoreTagger( FastTextScandiLanguageParagraphTagger, ): + """This tagger runs the FastText language detection model on each paragraph, + and will also provide a total score for each document. + The score is between 0 and 1 and provided for each language in LANGS. 
+ """ + def predict(self, doc: Document) -> DocResult: doc_result = super().predict(doc) doc_result = add_global_language_score_from_slice_score(doc_result) diff --git a/src/dfm/common/data_cleaning/dolma_taggers/perplexity.py b/src/dfm/common/data_cleaning/dolma_taggers/perplexity.py index 12a76d54..09cb56f2 100644 --- a/src/dfm/common/data_cleaning/dolma_taggers/perplexity.py +++ b/src/dfm/common/data_cleaning/dolma_taggers/perplexity.py @@ -111,6 +111,8 @@ def pp(log_score: float, length: float) -> float: class PerplexityBaseTagger(BaseTagger): + """Base class for CCNet based perplexity tagger""" + @property def model(self: Self) -> kenlm.Model: return self._model @@ -121,7 +123,20 @@ def model(self: Self, model: kenlm.Model): def create_ccnet_perplexity_tagger(lang: str) -> type[PerplexityBaseTagger]: - """Dynamically create tagger class for a given language""" + """Dynamically create perplexity tagger class for a given language. + The class for each language is based on a CCNet pretrained model [1]. + The pretrained models are available throught the Github project page https://github.com/facebookresearch/cc_net. + The models are small language models trained on the Wikipedia of the corresponding language. + + [1] + @inproceedings{wenzek2020ccnet, + title={CCNet: Extracting High Quality Monolingual Datasets from Web Crawl Data}, + author={Wenzek, Guillaume and Lachaux, Marie-Anne and Conneau, Alexis and Chaudhary, Vishrav and Guzm{\'a}n, Francisco and Joulin, Armand and Grave, {\'E}douard}, + booktitle={Proceedings of The 12th Language Resources and Evaluation Conference}, + pages={4003--4012}, + year={2020} + } + """ def __init__(self: Any) -> None: model_bin_path = _get_ccnet_pretrained_lm(lang) @@ -161,7 +176,7 @@ def predict(self: PerplexityBaseTagger, doc: Document) -> DocResult: ) return DocResult(doc=doc, spans=spans) - # Build the class dynamiccaly from base class + # Build the class dynamically from base class # and methods. cls = type( f"CCNetPerplexity{lang}", From b63dc0375c61b8e1571e63152e8957c45b8d36e3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Peter=20Bj=C3=B8rn=20J=C3=B8rgensen?= Date: Tue, 14 Nov 2023 16:20:17 +0100 Subject: [PATCH 18/21] Use commit hash for dolma git version. 
Co-authored-by: Martin Bernstorff --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index aacb6838..6b70cd2d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,7 @@ requires-python = ">=3.10" dependencies = [ "pydantic>=2.4.2", # dolma does not work with very old versions of pydantic - "dolma@git+https://github.com/allenai/dolma.git", # Install from git until a 0.9.2 package is released + "dolma@git+https://github.com/allenai/dolma.git@5a010a2685914b1db7744426abfb4b9ece52da95", # Install from git until a 0.9.2 package is released "kenlm>=0.2.0", # Used for perplexity tagging "blingfire>=0.1.8", # Used for perplexity tagging "requests>=2.31.0", From a9a17e6178e32ccc137cebacf79fa7aaf56b8b2f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Peter=20Bj=C3=B8rn=20J=C3=B8rgensen?= Date: Tue, 14 Nov 2023 16:24:49 +0100 Subject: [PATCH 19/21] Simplify mkdir Co-authored-by: Martin Bernstorff --- src/dfm/common/data_cleaning/dolma_taggers/perplexity.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/dfm/common/data_cleaning/dolma_taggers/perplexity.py b/src/dfm/common/data_cleaning/dolma_taggers/perplexity.py index 09cb56f2..399d14db 100644 --- a/src/dfm/common/data_cleaning/dolma_taggers/perplexity.py +++ b/src/dfm/common/data_cleaning/dolma_taggers/perplexity.py @@ -75,8 +75,7 @@ def _get_ccnet_pretrained_lm(lang: str) -> Path: url = f"http://dl.fbaipublicfiles.com/cc_net/lm/{lang}.arpa.bin" data_folder = Path("data_lm") - if not Path.exists(data_folder): - Path.mkdir(data_folder, parents=True) + Path.mkdir(data_folder, parents=True, exist_ok=True) filename = f"{lang}.arpa.bin" file_path = data_folder / filename From 7b1a929e839badb17705841d706ed843330885c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Peter=20Bj=C3=B8rn=20J=C3=B8rgensen?= Date: Wed, 15 Nov 2023 13:42:19 +0100 Subject: [PATCH 20/21] rename text normalizer to ccnet text normalizer --- .../{text_normalizer.py => ccnet_text_normalizer.py} | 3 +++ src/dfm/common/data_cleaning/dolma_taggers/perplexity.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) rename src/dfm/common/data_cleaning/{text_normalizer.py => ccnet_text_normalizer.py} (95%) diff --git a/src/dfm/common/data_cleaning/text_normalizer.py b/src/dfm/common/data_cleaning/ccnet_text_normalizer.py similarity index 95% rename from src/dfm/common/data_cleaning/text_normalizer.py rename to src/dfm/common/data_cleaning/ccnet_text_normalizer.py index 601e06c0..dab25c41 100644 --- a/src/dfm/common/data_cleaning/text_normalizer.py +++ b/src/dfm/common/data_cleaning/ccnet_text_normalizer.py @@ -1,4 +1,7 @@ # This file has initially been copied from the ccnet repository from Facebook. +# https://github.com/facebookresearch/cc_net/blob/main/cc_net/text_normalizer.py +# The utility functions can be used to normalize text before processing it +# with ccnet models, but might not be the best general purpose implementation. # # Copyright (c) Facebook, Inc. and its affiliates. 
# diff --git a/src/dfm/common/data_cleaning/dolma_taggers/perplexity.py b/src/dfm/common/data_cleaning/dolma_taggers/perplexity.py index 399d14db..baaebd1b 100644 --- a/src/dfm/common/data_cleaning/dolma_taggers/perplexity.py +++ b/src/dfm/common/data_cleaning/dolma_taggers/perplexity.py @@ -16,7 +16,7 @@ from dolma.core.taggers import BaseTagger from dolma.core.utils import split_paragraphs -from dfm.common.data_cleaning.text_normalizer import normalize +from dfm.common.data_cleaning.ccnet_text_normalizer import normalize ccnet_sha256 = { "af.arpa.bin": "7278e70cb22e29e94942b103c0ba49f406a9369c2949199fdf8d4bee4b0ce48e", From 4ff1eee0f03eed6a137bd4b6f157c39f63d36d8d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Peter=20Bj=C3=B8rn=20J=C3=B8rgensen?= Date: Wed, 15 Nov 2023 15:04:02 +0100 Subject: [PATCH 21/21] Rename tagger classes for better consistency and add some more comments --- .../dolma_taggers/language_scandi.py | 18 +++++++++++++----- .../data_cleaning/dolma_taggers/perplexity.py | 2 +- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/src/dfm/common/data_cleaning/dolma_taggers/language_scandi.py b/src/dfm/common/data_cleaning/dolma_taggers/language_scandi.py index c39df461..1108abff 100644 --- a/src/dfm/common/data_cleaning/dolma_taggers/language_scandi.py +++ b/src/dfm/common/data_cleaning/dolma_taggers/language_scandi.py @@ -1,6 +1,6 @@ """ -Filters. +Dolma taggers for Scandinavian language detection. """ from collections.abc import Iterable @@ -25,7 +25,7 @@ @TaggerRegistry.add("cld2_scandi_doc") -class Cld2LanguageFilterScandi(BaseTagger): +class Cld2ScandiLanguageTagger(BaseTagger): """This tagger runs the Compact Language Detect 2 model on a full document and will return a score between 0 and 1 for each language in LANGS. It uses the pretrained model from the pycld2 package.""" @@ -42,6 +42,7 @@ def _identity_fn(self, text: str) -> str: return text def _predict_text(self, text: str) -> dict[str, float]: + """Predict the language of a string and return the detected languages in a dictionary.""" is_reliable = False details: Iterable[tuple[str, str, int, float]] = [] for fn in (self._identity_fn, self._to_ascii_input, self._sanitize_input): @@ -49,6 +50,11 @@ def _predict_text(self, text: str) -> dict[str, float]: retvals = cld2.detect(fn(text)) assert len(retvals) == 3 is_reliable, _, details = retvals + # is_reliable is True if the detection is "high confidence" + # details is a Tuple of up to three detected languages, where each is + # tuple is (languageName, languageCode, percent, score). percent is + # what percentage of the original text was detected as this language + # and score is the confidence score for that language. break except cld2.error: ... 
@@ -65,6 +71,8 @@ def predict(self, doc: Document) -> DocResult: lang_scores = self._predict_text(doc.text) spans: list[Span] = [] for lang_code in LANGS.values(): + # If the language was not detected we will still tag + # the sentence with a score of 0 score = lang_scores.get(lang_code, 0) positive_span = Span( @@ -85,7 +93,7 @@ def predict(self, doc: Document) -> DocResult: @TaggerRegistry.add("cld2_scandi_paragraph") -class Cld2LanguageFilterParagraphScandi(Cld2LanguageFilterScandi): +class Cld2ScandiLanguageParagraphTagger(Cld2ScandiLanguageTagger): """This tagger runs the Compact Language Detect 2 model on each paragraph, and will save a score between 0 and 1 for each language in LANGS""" @@ -216,8 +224,8 @@ def add_global_language_score_from_slice_score(result: DocResult) -> DocResult: # Composite tagger that provides both paragraph and doc scores @TaggerRegistry.add("cld2_scandi_paragraph_with_doc_score") -class Cld2LanguageFilterParagraphWithDocScoreTaggerScandi( - Cld2LanguageFilterParagraphScandi, +class Cld2ScandiLanguageParagraphWithDocScoreTagger( + Cld2ScandiLanguageParagraphTagger, ): """This tagger runs the Compact Language Detect 2 model on each paragraph and will also provide a total score for each document. diff --git a/src/dfm/common/data_cleaning/dolma_taggers/perplexity.py b/src/dfm/common/data_cleaning/dolma_taggers/perplexity.py index baaebd1b..1825c5c7 100644 --- a/src/dfm/common/data_cleaning/dolma_taggers/perplexity.py +++ b/src/dfm/common/data_cleaning/dolma_taggers/perplexity.py @@ -178,7 +178,7 @@ def predict(self: PerplexityBaseTagger, doc: Document) -> DocResult: # Build the class dynamically from base class # and methods. cls = type( - f"CCNetPerplexity{lang}", + f"CCNetPerplexity{lang}Tagger", (PerplexityBaseTagger,), { "__init__": __init__,