diff --git a/README.md b/README.md
index d59eba5..cdcb8f7 100644
--- a/README.md
+++ b/README.md
@@ -93,7 +93,7 @@ model.print_topics()
```python
# Print highest ranking documents for topic 0
-model.print_highest_ranking_documents(0, corpus, document_topic_matrix)
+model.print_representative_documents(0, corpus, document_topic_matrix)
```
diff --git a/docs/index.md b/docs/index.md
index 748af1f..7caf87d 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -179,7 +179,7 @@ model.print_topics()
```python
# Print highest ranking documents for topic 0
-model.print_highest_ranking_documents(0, corpus, document_topic_matrix)
+model.print_representative_documents(0, corpus, document_topic_matrix)
```
@@ -217,7 +217,7 @@ csv_table: str = model.export_topic_distribution("something something", format="
latex_table: str = model.export_topics(format="latex")
-md_table: str = model.export_highest_ranking_documents(0, corpus, document_topic_matrix, format="markdown")
+md_table: str = model.export_representative_documents(0, corpus, document_topic_matrix, format="markdown")
```
### Visualization
diff --git a/pyproject.toml b/pyproject.toml
index be2c2ed..56caacb 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,9 +1,12 @@
[tool.black]
line-length=79
+[tool.ruff]
+line-length=79
+
[tool.poetry]
name = "turftopic"
-version = "0.2.12"
+version = "0.2.13"
description = "Topic modeling with contextual representations from sentence transformers."
authors = ["Márton Kardos "]
license = "MIT"
diff --git a/turftopic/base.py b/turftopic/base.py
index 25b1ec7..4f722fe 100644
--- a/turftopic/base.py
+++ b/turftopic/base.py
@@ -29,9 +29,7 @@ def get_topics(
"""Returns high-level topic representations in form of the top K words
in each topic.
- Parameters
- ----------
- top_k: int, default 10
+ Parameters ---------- top_k: int, default 10
Number of top words to return for each topic.
Returns
@@ -62,22 +60,57 @@ def get_topics(
return topics
def _topics_table(
- self, top_k: int = 10, show_scores: bool = False
+ self,
+ top_k: int = 10,
+ show_scores: bool = False,
+ show_negative: bool = False,
) -> list[list[str]]:
- topics = self.get_topics(top_k)
- columns = ["Topic ID", f"Top {top_k} Words"]
+ columns = ["Topic ID", "Highest Ranking"]
+ if show_negative:
+ columns.append("Lowest Ranking")
rows = []
- for topic_id, terms in topics:
+ try:
+ classes = self.classes_
+ except AttributeError:
+ classes = list(range(self.components_.shape[0]))
+ vocab = self.get_vocab()
+ for topic_id, component in zip(classes, self.components_):
+ highest = np.argpartition(-component, top_k)[:top_k]
+ highest = highest[np.argsort(-component[highest])]
+ lowest = np.argpartition(component, top_k)[:top_k]
+ lowest = lowest[np.argsort(component[lowest])]
if show_scores:
- concat_words = ", ".join(
- [f"{word}({importance:.2f})" for word, importance in terms]
+ concat_positive = ", ".join(
+ [
+ f"{word}({importance:.2f})"
+ for word, importance in zip(
+ vocab[highest], component[highest]
+ )
+ ]
+ )
+ concat_negative = ", ".join(
+ [
+ f"{word}({importance:.2f})"
+ for word, importance in zip(
+ vocab[lowest], component[lowest]
+ )
+ ]
)
else:
- concat_words = ", ".join([word for word, importance in terms])
- rows.append([f"{topic_id}", f"{concat_words}"])
+ concat_positive = ", ".join([word for word in vocab[highest]])
+ concat_negative = ", ".join([word for word in vocab[lowest]])
+ row = [f"{topic_id}", f"{concat_positive}"]
+ if show_negative:
+ row.append(concat_negative)
+ rows.append(row)
return [columns, *rows]
- def print_topics(self, top_k: int = 10, show_scores: bool = False):
+ def print_topics(
+ self,
+ top_k: int = 10,
+ show_scores: bool = False,
+ show_negative: bool = False,
+ ):
"""Pretty prints topics in the model in a table.
Parameters
@@ -86,23 +119,36 @@ def print_topics(self, top_k: int = 10, show_scores: bool = False):
Number of top words to return for each topic.
show_scores: bool, default False
Indicates whether to show importance scores for each word.
+ show_negative: bool, default False
+ Indicates whether the most negative terms should also be displayed.
"""
- columns, *rows = self._topics_table(top_k, show_scores)
+ columns, *rows = self._topics_table(top_k, show_scores, show_negative)
table = Table(show_lines=True)
- table.add_column(columns[0], style="blue", justify="right")
+ table.add_column("Topic ID", style="blue", justify="right")
table.add_column(
- columns[1],
+ "Highest Ranking",
justify="left",
style="magenta",
max_width=100,
)
+ if show_negative:
+ table.add_column(
+ "Lowest Ranking",
+ justify="left",
+ style="red",
+ max_width=100,
+ )
for row in rows:
table.add_row(*row)
console = Console()
console.print(table)
def export_topics(
- self, top_k: int = 10, show_scores: bool = False, format: str = "csv"
+ self,
+ top_k: int = 10,
+ show_scores: bool = False,
+ show_negative: bool = False,
+ format: str = "csv",
) -> str:
"""Exports top K words from topics in a table in a given format.
Returns table as a pure string.
@@ -113,15 +159,24 @@ def export_topics(
Number of top words to return for each topic.
show_scores: bool, default False
Indicates whether to show importance scores for each word.
+ show_negative: bool, default False
+ Indicates whether the most negative terms should also be displayed.
format: 'csv', 'latex' or 'markdown'
Specifies which format should be used.
'csv', 'latex' and 'markdown' are supported.
"""
- table = self._topics_table(top_k, show_scores)
+ table = self._topics_table(
+ top_k, show_scores, show_negative=show_negative
+ )
return export_table(table, format=format)
- def _highest_ranking_docs(
- self, topic_id, raw_documents, document_topic_matrix=None, top_k=5
+ def _representative_docs(
+ self,
+ topic_id,
+ raw_documents,
+ document_topic_matrix=None,
+ top_k=5,
+ show_negative: bool = False,
) -> list[list[str]]:
if document_topic_matrix is None:
try:
@@ -154,10 +209,30 @@ def _highest_ranking_docs(
if len(doc) > 300:
doc = doc[:300] + "..."
rows.append([doc, f"{score:.2f}"])
+ if show_negative:
+ rows.append(["...", ""])
+ lowest = np.argpartition(document_topic_matrix[:, topic_id], kth)[
+ :kth
+ ]
+ lowest = lowest[
+ np.argsort(document_topic_matrix[lowest, topic_id])
+ ]
+ scores = document_topic_matrix[lowest, topic_id]
+ for document_id, score in zip(lowest, scores):
+ doc = raw_documents[document_id]
+ doc = remove_whitespace(doc)
+ if len(doc) > 300:
+ doc = doc[:300] + "..."
+ rows.append([doc, f"{score:.2f}"])
return [columns, *rows]
- def print_highest_ranking_documents(
- self, topic_id, raw_documents, document_topic_matrix=None, top_k=5
+ def print_representative_documents(
+ self,
+ topic_id,
+ raw_documents,
+ document_topic_matrix=None,
+ top_k=5,
+ show_negative: bool = False,
):
"""Pretty prints the highest ranking documents in a topic.
@@ -172,9 +247,15 @@ def print_highest_ranking_documents(
as they cannot infer topics from text.
top_k: int, default 5
Top K documents to show.
+ show_negative: bool, default False
+ Indicates whether lowest ranking documents should also be shown.
"""
- columns, *rows = self._highest_ranking_docs(
- topic_id, raw_documents, document_topic_matrix, top_k
+ columns, *rows = self._representative_docs(
+ topic_id,
+ raw_documents,
+ document_topic_matrix,
+ top_k,
+ show_negative,
)
table = Table(show_lines=True)
table.add_column(
@@ -186,12 +267,13 @@ def print_highest_ranking_documents(
console = Console()
console.print(table)
- def export_highest_ranking_documents(
+ def export_representative_documents(
self,
topic_id,
raw_documents,
document_topic_matrix=None,
top_k=5,
+ show_negative: bool = False,
format: str = "csv",
):
"""Exports the highest ranking documents in a topic as a text table.
@@ -207,12 +289,18 @@ def export_highest_ranking_documents(
as they cannot infer topics from text.
top_k: int, default 5
Top K documents to show.
+ show_negative: bool, default False
+ Indicates whether lowest ranking documents should also be shown.
format: 'csv', 'latex' or 'markdown'
Specifies which format should be used.
'csv', 'latex' and 'markdown' are supported.
"""
table = self._highest_ranking_docs(
- topic_id, raw_documents, document_topic_matrix, top_k
+ topic_id,
+ raw_documents,
+ document_topic_matrix,
+ top_k,
+ show_negative,
)
return export_table(table, format=format)
diff --git a/turftopic/models/decomp.py b/turftopic/models/decomp.py
index c727490..d29d119 100644
--- a/turftopic/models/decomp.py
+++ b/turftopic/models/decomp.py
@@ -108,3 +108,54 @@ def transform(
if embeddings is None:
embeddings = self.encoder_.encode(raw_documents)
return self.decomposition.transform(embeddings)
+
+ def print_topics(
+ self,
+ top_k: int = 5,
+ show_scores: bool = False,
+ show_negative: bool = True,
+ ):
+ super().print_topics(top_k, show_scores, show_negative)
+
+ def export_topics(
+ self,
+ top_k: int = 5,
+ show_scores: bool = False,
+ show_negative: bool = True,
+ format: str = "csv",
+ ) -> str:
+ return super().export_topics(top_k, show_scores, show_negative, format)
+
+ def print_representative_documents(
+ self,
+ topic_id,
+ raw_documents,
+ document_topic_matrix=None,
+ top_k=5,
+ show_negative: bool = True,
+ ):
+ super().print_representative_documents(
+ topic_id,
+ raw_documents,
+ document_topic_matrix,
+ top_k,
+ show_negative,
+ )
+
+ def export_representative_documents(
+ self,
+ topic_id,
+ raw_documents,
+ document_topic_matrix=None,
+ top_k=5,
+ show_negative: bool = True,
+ format: str = "csv",
+ ):
+ return super().export_representative_documents(
+ topic_id,
+ raw_documents,
+ document_topic_matrix,
+ top_k,
+ show_negative,
+ format,
+ )