diff --git a/README.md b/README.md index d59eba5..cdcb8f7 100644 --- a/README.md +++ b/README.md @@ -93,7 +93,7 @@ model.print_topics() ```python # Print highest ranking documents for topic 0 -model.print_highest_ranking_documents(0, corpus, document_topic_matrix) +model.print_representative_documents(0, corpus, document_topic_matrix) ```
diff --git a/docs/index.md b/docs/index.md index 748af1f..7caf87d 100644 --- a/docs/index.md +++ b/docs/index.md @@ -179,7 +179,7 @@ model.print_topics() ```python # Print highest ranking documents for topic 0 -model.print_highest_ranking_documents(0, corpus, document_topic_matrix) +model.print_representative_documents(0, corpus, document_topic_matrix) ```
@@ -217,7 +217,7 @@ csv_table: str = model.export_topic_distribution("something something", format=" latex_table: str = model.export_topics(format="latex") -md_table: str = model.export_highest_ranking_documents(0, corpus, document_topic_matrix, format="markdown") +md_table: str = model.export_representative_documents(0, corpus, document_topic_matrix, format="markdown") ``` ### Visualization diff --git a/pyproject.toml b/pyproject.toml index be2c2ed..56caacb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,9 +1,12 @@ [tool.black] line-length=79 +[tool.ruff] +line-length=79 + [tool.poetry] name = "turftopic" -version = "0.2.12" +version = "0.2.13" description = "Topic modeling with contextual representations from sentence transformers." authors = ["Márton Kardos "] license = "MIT" diff --git a/turftopic/base.py b/turftopic/base.py index 25b1ec7..4f722fe 100644 --- a/turftopic/base.py +++ b/turftopic/base.py @@ -29,9 +29,7 @@ def get_topics( """Returns high-level topic representations in form of the top K words in each topic. - Parameters - ---------- - top_k: int, default 10 + Parameters ---------- top_k: int, default 10 Number of top words to return for each topic. Returns @@ -62,22 +60,57 @@ def get_topics( return topics def _topics_table( - self, top_k: int = 10, show_scores: bool = False + self, + top_k: int = 10, + show_scores: bool = False, + show_negative: bool = False, ) -> list[list[str]]: - topics = self.get_topics(top_k) - columns = ["Topic ID", f"Top {top_k} Words"] + columns = ["Topic ID", "Highest Ranking"] + if show_negative: + columns.append("Lowest Ranking") rows = [] - for topic_id, terms in topics: + try: + classes = self.classes_ + except AttributeError: + classes = list(range(self.components_.shape[0])) + vocab = self.get_vocab() + for topic_id, component in zip(classes, self.components_): + highest = np.argpartition(-component, top_k)[:top_k] + highest = highest[np.argsort(-component[highest])] + lowest = np.argpartition(component, top_k)[:top_k] + lowest = lowest[np.argsort(component[lowest])] if show_scores: - concat_words = ", ".join( - [f"{word}({importance:.2f})" for word, importance in terms] + concat_positive = ", ".join( + [ + f"{word}({importance:.2f})" + for word, importance in zip( + vocab[highest], component[highest] + ) + ] + ) + concat_negative = ", ".join( + [ + f"{word}({importance:.2f})" + for word, importance in zip( + vocab[lowest], component[lowest] + ) + ] ) else: - concat_words = ", ".join([word for word, importance in terms]) - rows.append([f"{topic_id}", f"{concat_words}"]) + concat_positive = ", ".join([word for word in vocab[highest]]) + concat_negative = ", ".join([word for word in vocab[lowest]]) + row = [f"{topic_id}", f"{concat_positive}"] + if show_negative: + row.append(concat_negative) + rows.append(row) return [columns, *rows] - def print_topics(self, top_k: int = 10, show_scores: bool = False): + def print_topics( + self, + top_k: int = 10, + show_scores: bool = False, + show_negative: bool = False, + ): """Pretty prints topics in the model in a table. Parameters @@ -86,23 +119,36 @@ def print_topics(self, top_k: int = 10, show_scores: bool = False): Number of top words to return for each topic. show_scores: bool, default False Indicates whether to show importance scores for each word. + show_negative: bool, default False + Indicates whether the most negative terms should also be displayed. """ - columns, *rows = self._topics_table(top_k, show_scores) + columns, *rows = self._topics_table(top_k, show_scores, show_negative) table = Table(show_lines=True) - table.add_column(columns[0], style="blue", justify="right") + table.add_column("Topic ID", style="blue", justify="right") table.add_column( - columns[1], + "Highest Ranking", justify="left", style="magenta", max_width=100, ) + if show_negative: + table.add_column( + "Lowest Ranking", + justify="left", + style="red", + max_width=100, + ) for row in rows: table.add_row(*row) console = Console() console.print(table) def export_topics( - self, top_k: int = 10, show_scores: bool = False, format: str = "csv" + self, + top_k: int = 10, + show_scores: bool = False, + show_negative: bool = False, + format: str = "csv", ) -> str: """Exports top K words from topics in a table in a given format. Returns table as a pure string. @@ -113,15 +159,24 @@ def export_topics( Number of top words to return for each topic. show_scores: bool, default False Indicates whether to show importance scores for each word. + show_negative: bool, default False + Indicates whether the most negative terms should also be displayed. format: 'csv', 'latex' or 'markdown' Specifies which format should be used. 'csv', 'latex' and 'markdown' are supported. """ - table = self._topics_table(top_k, show_scores) + table = self._topics_table( + top_k, show_scores, show_negative=show_negative + ) return export_table(table, format=format) - def _highest_ranking_docs( - self, topic_id, raw_documents, document_topic_matrix=None, top_k=5 + def _representative_docs( + self, + topic_id, + raw_documents, + document_topic_matrix=None, + top_k=5, + show_negative: bool = False, ) -> list[list[str]]: if document_topic_matrix is None: try: @@ -154,10 +209,30 @@ def _highest_ranking_docs( if len(doc) > 300: doc = doc[:300] + "..." rows.append([doc, f"{score:.2f}"]) + if show_negative: + rows.append(["...", ""]) + lowest = np.argpartition(document_topic_matrix[:, topic_id], kth)[ + :kth + ] + lowest = lowest[ + np.argsort(document_topic_matrix[lowest, topic_id]) + ] + scores = document_topic_matrix[lowest, topic_id] + for document_id, score in zip(lowest, scores): + doc = raw_documents[document_id] + doc = remove_whitespace(doc) + if len(doc) > 300: + doc = doc[:300] + "..." + rows.append([doc, f"{score:.2f}"]) return [columns, *rows] - def print_highest_ranking_documents( - self, topic_id, raw_documents, document_topic_matrix=None, top_k=5 + def print_representative_documents( + self, + topic_id, + raw_documents, + document_topic_matrix=None, + top_k=5, + show_negative: bool = False, ): """Pretty prints the highest ranking documents in a topic. @@ -172,9 +247,15 @@ def print_highest_ranking_documents( as they cannot infer topics from text. top_k: int, default 5 Top K documents to show. + show_negative: bool, default False + Indicates whether lowest ranking documents should also be shown. """ - columns, *rows = self._highest_ranking_docs( - topic_id, raw_documents, document_topic_matrix, top_k + columns, *rows = self._representative_docs( + topic_id, + raw_documents, + document_topic_matrix, + top_k, + show_negative, ) table = Table(show_lines=True) table.add_column( @@ -186,12 +267,13 @@ def print_highest_ranking_documents( console = Console() console.print(table) - def export_highest_ranking_documents( + def export_representative_documents( self, topic_id, raw_documents, document_topic_matrix=None, top_k=5, + show_negative: bool = False, format: str = "csv", ): """Exports the highest ranking documents in a topic as a text table. @@ -207,12 +289,18 @@ def export_highest_ranking_documents( as they cannot infer topics from text. top_k: int, default 5 Top K documents to show. + show_negative: bool, default False + Indicates whether lowest ranking documents should also be shown. format: 'csv', 'latex' or 'markdown' Specifies which format should be used. 'csv', 'latex' and 'markdown' are supported. """ table = self._highest_ranking_docs( - topic_id, raw_documents, document_topic_matrix, top_k + topic_id, + raw_documents, + document_topic_matrix, + top_k, + show_negative, ) return export_table(table, format=format) diff --git a/turftopic/models/decomp.py b/turftopic/models/decomp.py index c727490..d29d119 100644 --- a/turftopic/models/decomp.py +++ b/turftopic/models/decomp.py @@ -108,3 +108,54 @@ def transform( if embeddings is None: embeddings = self.encoder_.encode(raw_documents) return self.decomposition.transform(embeddings) + + def print_topics( + self, + top_k: int = 5, + show_scores: bool = False, + show_negative: bool = True, + ): + super().print_topics(top_k, show_scores, show_negative) + + def export_topics( + self, + top_k: int = 5, + show_scores: bool = False, + show_negative: bool = True, + format: str = "csv", + ) -> str: + return super().export_topics(top_k, show_scores, show_negative, format) + + def print_representative_documents( + self, + topic_id, + raw_documents, + document_topic_matrix=None, + top_k=5, + show_negative: bool = True, + ): + super().print_representative_documents( + topic_id, + raw_documents, + document_topic_matrix, + top_k, + show_negative, + ) + + def export_representative_documents( + self, + topic_id, + raw_documents, + document_topic_matrix=None, + top_k=5, + show_negative: bool = True, + format: str = "csv", + ): + return super().export_representative_documents( + topic_id, + raw_documents, + document_topic_matrix, + top_k, + show_negative, + format, + )