Skip to content

Commit

Permalink
black format
Browse files Browse the repository at this point in the history
  • Loading branch information
peterbjorgensen committed Nov 8, 2023
1 parent d1520df commit ddf43fa
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 6 deletions.
7 changes: 5 additions & 2 deletions src/dfm/common/data_cleaning/dolma_taggers/language_scandi.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,10 @@ def predict(self, doc: Document) -> DocResult:
score = lang_scores.get(lang_code, 0)

positive_span = Span(
start=0, end=len(doc.text), type=lang_code, score=score
start=0,
end=len(doc.text),
type=lang_code,
score=score,
)
negative_span = Span(
start=0,
Expand Down Expand Up @@ -183,7 +186,7 @@ def add_global_language_score_from_slice_score(result: DocResult) -> DocResult:
# Composite tagger that provides both paragraph and doc scores
@TaggerRegistry.add("cld2_scandi_paragraph_with_doc_score")
class Cld2LanguageFilterParagraphWithDocScoreTaggerScandi(
Cld2LanguageFilterParagraphScandi
Cld2LanguageFilterParagraphScandi,
):
def predict(self, doc: Document) -> DocResult:
doc_result = super().predict(doc)
Expand Down
11 changes: 7 additions & 4 deletions src/dfm/common/data_cleaning/dolma_taggers/perplexity.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import hashlib
import logging
from pathlib import Path
from typing import Self, Any
from typing import Any, Self

import blingfire
import kenlm
Expand Down Expand Up @@ -90,14 +90,14 @@ def _get_ccnet_pretrained_lm(lang: str) -> Path:
sha256 = hashlib.sha256(response.content).hexdigest()
if sha256 != ccnet_sha256[filename]:
raise RuntimeError(
f"Checksum mismatch {sha256} != {ccnet_sha256[filename]}"
f"Checksum mismatch {sha256} != {ccnet_sha256[filename]}",
)
with Path.open(file_path, "wb") as file:
file.write(response.content)
logging.info(f"{lang} model downloaded and saved at {file_path}")
else:
raise RuntimeError(
f"Failed to download {lang} model. Status code: {response.status_code}"
f"Failed to download {lang} model. Status code: {response.status_code}",
)
else:
logging.info(f"{lang} model already exists at {file_path}")
Expand All @@ -114,12 +114,15 @@ class PerplexityBaseTagger(BaseTagger):
@property
def model(self: Self) -> kenlm.Model:
return self._model

@model.setter
def model(self: Self, model: kenlm.Model):
self._model = model


def create_ccnet_perplexity_tagger(lang: str) -> type[PerplexityBaseTagger]:
"""Dynamically create tagger class for a given language"""

def __init__(self: Any) -> None:
model_bin_path = _get_ccnet_pretrained_lm(lang)
self.model = kenlm.Model(str(model_bin_path))
Expand All @@ -133,7 +136,7 @@ def predict(self: PerplexityBaseTagger, doc: Document) -> DocResult:
# To get proper scores from the language model we need to normalize the text
# Do not remove accents as it removes æøå and others.
normalized_text = blingfire.normalize_spaces(
normalize(paragraph.text, accent=False)
normalize(paragraph.text, accent=False),
)
# The kenlm model expects end of sentence punctuation to be separated from words with spaces
# so we separate the words using blingfire.
Expand Down

0 comments on commit ddf43fa

Please sign in to comment.