Skip to content

Commit

Permalink
Merge pull request #31 from x-tabdeveloping/positive_negative
Browse files Browse the repository at this point in the history
WIP: Positive/negative highest ranking terms
  • Loading branch information
x-tabdeveloping authored Mar 22, 2024
2 parents 1da5425 + 99e9744 commit 1639f28
Show file tree
Hide file tree
Showing 5 changed files with 171 additions and 29 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ model.print_topics()

```python
# Print highest ranking documents for topic 0
model.print_highest_ranking_documents(0, corpus, document_topic_matrix)
model.print_representative_documents(0, corpus, document_topic_matrix)
```

<center>
Expand Down
4 changes: 2 additions & 2 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ model.print_topics()

```python
# Print highest ranking documents for topic 0
model.print_highest_ranking_documents(0, corpus, document_topic_matrix)
model.print_representative_documents(0, corpus, document_topic_matrix)
```

<center>
Expand Down Expand Up @@ -217,7 +217,7 @@ csv_table: str = model.export_topic_distribution("something something", format="

latex_table: str = model.export_topics(format="latex")

md_table: str = model.export_highest_ranking_documents(0, corpus, document_topic_matrix, format="markdown")
md_table: str = model.export_representative_documents(0, corpus, document_topic_matrix, format="markdown")
```

### Visualization
Expand Down
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
[tool.black]
line-length=79

[tool.ruff]
line-length=79

[tool.poetry]
name = "turftopic"
version = "0.2.12"
version = "0.2.13"
description = "Topic modeling with contextual representations from sentence transformers."
authors = ["Márton Kardos <power.up1163@gmail.com>"]
license = "MIT"
Expand Down
138 changes: 113 additions & 25 deletions turftopic/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,7 @@ def get_topics(
"""Returns high-level topic representations in form of the top K words
in each topic.
Parameters
----------
top_k: int, default 10
Number of top words to return for each topic.
Returns
Expand Down Expand Up @@ -62,22 +60,57 @@ def get_topics(
return topics

def _topics_table(
    self,
    top_k: int = 10,
    show_scores: bool = False,
    show_negative: bool = False,
) -> list[list[str]]:
    """Builds the topic description table used for printing and export.

    Parameters
    ----------
    top_k: int, default 10
        Number of terms to list for each topic. When it is not smaller
        than the vocabulary size, every term is listed.
    show_scores: bool, default False
        Indicates whether to show importance scores for each word.
    show_negative: bool, default False
        Indicates whether the most negative terms should also be displayed.

    Returns
    -------
    list[list[str]]
        The header row followed by one row per topic.
    """

    def ranked(values, k):
        # np.argpartition raises ValueError when kth is not smaller than
        # the array length, so fall back to ordering every entry.
        if k < values.shape[0]:
            candidates = np.argpartition(values, k)[:k]
        else:
            candidates = np.arange(values.shape[0])
        return candidates[np.argsort(values[candidates])]

    def describe(indices, component):
        words = vocab[indices]
        if show_scores:
            return ", ".join(
                f"{word}({importance:.2f})"
                for word, importance in zip(words, component[indices])
            )
        return ", ".join(words)

    columns = ["Topic ID", "Highest Ranking"]
    if show_negative:
        columns.append("Lowest Ranking")
    try:
        classes = self.classes_
    except AttributeError:
        # Models without explicit class labels get positional topic IDs.
        classes = list(range(self.components_.shape[0]))
    vocab = self.get_vocab()
    rows = []
    for topic_id, component in zip(classes, self.components_):
        # Negating the scores turns "smallest first" into "largest first".
        row = [f"{topic_id}", describe(ranked(-component, top_k), component)]
        if show_negative:
            # Only rank the lowest terms when they are actually displayed.
            row.append(describe(ranked(component, top_k), component))
        rows.append(row)
    return [columns, *rows]

def print_topics(self, top_k: int = 10, show_scores: bool = False):
def print_topics(
self,
top_k: int = 10,
show_scores: bool = False,
show_negative: bool = False,
):
"""Pretty prints topics in the model in a table.
Parameters
Expand All @@ -86,23 +119,36 @@ def print_topics(self, top_k: int = 10, show_scores: bool = False):
Number of top words to return for each topic.
show_scores: bool, default False
Indicates whether to show importance scores for each word.
show_negative: bool, default False
Indicates whether the most negative terms should also be displayed.
"""
columns, *rows = self._topics_table(top_k, show_scores)
columns, *rows = self._topics_table(top_k, show_scores, show_negative)
table = Table(show_lines=True)
table.add_column(columns[0], style="blue", justify="right")
table.add_column("Topic ID", style="blue", justify="right")
table.add_column(
columns[1],
"Highest Ranking",
justify="left",
style="magenta",
max_width=100,
)
if show_negative:
table.add_column(
"Lowest Ranking",
justify="left",
style="red",
max_width=100,
)
for row in rows:
table.add_row(*row)
console = Console()
console.print(table)

def export_topics(
self, top_k: int = 10, show_scores: bool = False, format: str = "csv"
self,
top_k: int = 10,
show_scores: bool = False,
show_negative: bool = False,
format: str = "csv",
) -> str:
"""Exports top K words from topics in a table in a given format.
Returns table as a pure string.
Expand All @@ -113,15 +159,24 @@ def export_topics(
Number of top words to return for each topic.
show_scores: bool, default False
Indicates whether to show importance scores for each word.
show_negative: bool, default False
Indicates whether the most negative terms should also be displayed.
format: 'csv', 'latex' or 'markdown'
Specifies which format should be used.
'csv', 'latex' and 'markdown' are supported.
"""
table = self._topics_table(top_k, show_scores)
table = self._topics_table(
top_k, show_scores, show_negative=show_negative
)
return export_table(table, format=format)

def _highest_ranking_docs(
self, topic_id, raw_documents, document_topic_matrix=None, top_k=5
def _representative_docs(
self,
topic_id,
raw_documents,
document_topic_matrix=None,
top_k=5,
show_negative: bool = False,
) -> list[list[str]]:
if document_topic_matrix is None:
try:
Expand Down Expand Up @@ -154,10 +209,30 @@ def _highest_ranking_docs(
if len(doc) > 300:
doc = doc[:300] + "..."
rows.append([doc, f"{score:.2f}"])
if show_negative:
rows.append(["...", ""])
lowest = np.argpartition(document_topic_matrix[:, topic_id], kth)[
:kth
]
lowest = lowest[
np.argsort(document_topic_matrix[lowest, topic_id])
]
scores = document_topic_matrix[lowest, topic_id]
for document_id, score in zip(lowest, scores):
doc = raw_documents[document_id]
doc = remove_whitespace(doc)
if len(doc) > 300:
doc = doc[:300] + "..."
rows.append([doc, f"{score:.2f}"])
return [columns, *rows]

def print_highest_ranking_documents(
self, topic_id, raw_documents, document_topic_matrix=None, top_k=5
def print_representative_documents(
self,
topic_id,
raw_documents,
document_topic_matrix=None,
top_k=5,
show_negative: bool = False,
):
"""Pretty prints the highest ranking documents in a topic.
Expand All @@ -172,9 +247,15 @@ def print_highest_ranking_documents(
as they cannot infer topics from text.
top_k: int, default 5
Top K documents to show.
show_negative: bool, default False
Indicates whether lowest ranking documents should also be shown.
"""
columns, *rows = self._highest_ranking_docs(
topic_id, raw_documents, document_topic_matrix, top_k
columns, *rows = self._representative_docs(
topic_id,
raw_documents,
document_topic_matrix,
top_k,
show_negative,
)
table = Table(show_lines=True)
table.add_column(
Expand All @@ -186,12 +267,13 @@ def print_highest_ranking_documents(
console = Console()
console.print(table)

def export_highest_ranking_documents(
def export_representative_documents(
self,
topic_id,
raw_documents,
document_topic_matrix=None,
top_k=5,
show_negative: bool = False,
format: str = "csv",
):
"""Exports the highest ranking documents in a topic as a text table.
Expand All @@ -207,12 +289,18 @@ def export_highest_ranking_documents(
as they cannot infer topics from text.
top_k: int, default 5
Top K documents to show.
show_negative: bool, default False
Indicates whether lowest ranking documents should also be shown.
format: 'csv', 'latex' or 'markdown'
Specifies which format should be used.
'csv', 'latex' and 'markdown' are supported.
"""
table = self._highest_ranking_docs(
topic_id, raw_documents, document_topic_matrix, top_k
topic_id,
raw_documents,
document_topic_matrix,
top_k,
show_negative,
)
return export_table(table, format=format)

Expand Down
51 changes: 51 additions & 0 deletions turftopic/models/decomp.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,3 +108,54 @@ def transform(
if embeddings is None:
embeddings = self.encoder_.encode(raw_documents)
return self.decomposition.transform(embeddings)

def print_topics(
    self,
    top_k: int = 5,
    show_scores: bool = False,
    show_negative: bool = True,
):
    """Pretty prints topics in a table by delegating to the parent class.

    This override only changes the defaults: five terms per topic, with
    the lowest ranking terms displayed as well.

    Parameters
    ----------
    top_k: int, default 5
        Number of top words to show for each topic.
    show_scores: bool, default False
        Indicates whether to show importance scores for each word.
    show_negative: bool, default True
        Indicates whether the most negative terms should also be displayed.
    """
    super().print_topics(
        top_k=top_k,
        show_scores=show_scores,
        show_negative=show_negative,
    )

def export_topics(
    self,
    top_k: int = 5,
    show_scores: bool = False,
    show_negative: bool = True,
    format: str = "csv",
) -> str:
    """Exports the topic table as a string via the parent implementation.

    This override only changes the defaults: five terms per topic, with
    the lowest ranking terms included as well.

    Parameters
    ----------
    top_k: int, default 5
        Number of top words to export for each topic.
    show_scores: bool, default False
        Indicates whether to show importance scores for each word.
    show_negative: bool, default True
        Indicates whether the most negative terms should also be displayed.
    format: 'csv', 'latex' or 'markdown'
        Specifies which format should be used.
    """
    exported = super().export_topics(
        top_k=top_k,
        show_scores=show_scores,
        show_negative=show_negative,
        format=format,
    )
    return exported

def print_representative_documents(
    self,
    topic_id,
    raw_documents,
    document_topic_matrix=None,
    top_k=5,
    show_negative: bool = True,
):
    """Pretty prints representative documents via the parent class.

    This override only changes one default: lowest ranking documents are
    shown unless `show_negative` is switched off.

    Parameters
    ----------
    topic_id
        Topic to show documents for; forwarded unchanged.
    raw_documents
        Corpus the documents are taken from; forwarded unchanged.
    document_topic_matrix, default None
        Forwarded unchanged to the parent implementation.
    top_k: int, default 5
        Top K documents to show.
    show_negative: bool, default True
        Indicates whether lowest ranking documents should also be shown.
    """
    super().print_representative_documents(
        topic_id=topic_id,
        raw_documents=raw_documents,
        document_topic_matrix=document_topic_matrix,
        top_k=top_k,
        show_negative=show_negative,
    )

def export_representative_documents(
    self,
    topic_id,
    raw_documents,
    document_topic_matrix=None,
    top_k=5,
    show_negative: bool = True,
    format: str = "csv",
):
    """Exports representative documents as a text table via the parent class.

    This override only changes one default: lowest ranking documents are
    included unless `show_negative` is switched off.

    Parameters
    ----------
    topic_id
        Topic to export documents for; forwarded unchanged.
    raw_documents
        Corpus the documents are taken from; forwarded unchanged.
    document_topic_matrix, default None
        Forwarded unchanged to the parent implementation.
    top_k: int, default 5
        Top K documents to show.
    show_negative: bool, default True
        Indicates whether lowest ranking documents should also be shown.
    format: 'csv', 'latex' or 'markdown'
        Specifies which format should be used.
    """
    exported = super().export_representative_documents(
        topic_id=topic_id,
        raw_documents=raw_documents,
        document_topic_matrix=document_topic_matrix,
        top_k=top_k,
        show_negative=show_negative,
        format=format,
    )
    return exported

0 comments on commit 1639f28

Please sign in to comment.