Merge pull request #43 from x-tabdeveloping/dynamic_keynmf
Dynamic keynmf
x-tabdeveloping authored Jun 9, 2024
2 parents 0ec180e + 0226673 commit 24c60ff
Showing 6 changed files with 97 additions and 32 deletions.
8 changes: 7 additions & 1 deletion docs/KeyNMF.md
@@ -6,7 +6,7 @@ while taking inspiration from classical matrix-decomposition approaches for extr
## The Model

<figure>
<img src="/images/keynmf.png" width="90%" style="margin-left: auto;margin-right: auto;">
<img src="../images/keynmf.png" width="90%" style="margin-left: auto;margin-right: auto;">
<figcaption>Schematic overview of KeyNMF</figcaption>
</figure>

@@ -30,6 +30,12 @@ Topics in this matrix are then discovered using Non-negative Matrix Factorizatio
Essentially the model tries to discover underlying dimensions/factors along which most of the variance in term importance
can be explained.
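
To make this concrete, the decomposition step is equivalent in spirit to the following minimal scikit-learn sketch (the random matrix is a placeholder; in KeyNMF the input is the document-keyword importance matrix):

```python
import numpy as np
from sklearn.decomposition import NMF

# Placeholder for the document-keyword matrix (n_documents x n_terms);
# in KeyNMF these values come from embedding similarities, clipped at zero.
X = np.random.default_rng(0).random((100, 500))

nmf = NMF(n_components=10)
document_topic_matrix = nmf.fit_transform(X)  # W: (100, 10)
topic_term_matrix = nmf.components_           # H: (10, 500)
```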

### _(Optional)_ 3. Dynamic Modeling

KeyNMF is also capable of modeling topics over time.
This is done by first fitting a KeyNMF model on the entire corpus, and then
fitting individual topic-term matrices with coordinate descent, based on the document-topic and document-term matrices within the given time slices.
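
The refitting step can be illustrated with a self-contained sketch on toy data (note that the library itself uses scikit-learn's coordinate descent solver rather than the multiplicative updates shown here):

```python
import numpy as np

rng = np.random.default_rng(42)

# Toy stand-ins: X is a document-term matrix and W a document-topic matrix
# obtained from a KeyNMF fit on the full corpus (both nonnegative).
X = rng.random((20, 50))
W = rng.random((20, 5))

def refit_topic_terms(X, W, n_iter=200):
    """Refit H in X ~ W @ H while W stays fixed (multiplicative updates)."""
    H = rng.random((W.shape[1], X.shape[1]))
    for _ in range(n_iter):
        H *= (W.T @ X) / (W.T @ W @ H + 1e-10)
    return H

# One topic-term matrix per time slice; documents are grouped by slice label.
time_labels = rng.integers(0, 3, size=20)
temporal_components = np.stack(
    [refit_topic_terms(X[time_labels == t], W[time_labels == t]) for t in range(3)]
)
print(temporal_components.shape)  # (3, 5, 50)
```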

## Considerations

### Strengths
27 changes: 8 additions & 19 deletions docs/dynamic.md
@@ -4,41 +4,30 @@ If you want to examine the evolution of topics over time, you will need a dynami

> Note that regular static models can also be used to study the evolution of topics and information dynamics, but they can't capture changes in the topics themselves.
-## Theory
+## Models

-A number of different conceptualizations can be used to study evolving topics in corpora, for instance:
-
-1. One can imagine topic representations to be governed by a Brownian Markov Process (random walk); in such a case the evolution is part of the model itself.
-In layman's terms, you describe the evolution of topics directly in your generative model by expecting the topic representations to be sampled from Gaussian noise around the last time step.
-Sometimes researchers will also refer to such models as _state-space_ approaches.
-This is the approach that the original [DTM paper](https://mimno.infosci.cornell.edu/info6150/readings/dynamic_topic_models.pdf) utilizes, along with [this paper](https://arxiv.org/pdf/1709.00025.pdf) on Dynamic NMF.
-2. You can fit one underlying statistical model over the entire corpus, and then do post-hoc term importance estimation per time slice.
-This is [what BERTopic does](https://maartengr.github.io/BERTopic/getting_started/topicsovertime/topicsovertime.html).
-3. You can fit one model per time slice, and then use some aggregation procedure to merge the models.
-This approach is used for Dynamic NMF in [this paper](https://www.cambridge.org/core/journals/political-analysis/article/exploring-the-political-agenda-of-the-european-parliament-using-a-dynamic-topic-modeling-approach/BBC7751778E4542C7C6C69E6BF954E4B).
-
-Developing such approaches takes a lot of time and effort, and we have plans to add dynamic modeling capabilities to all models in Turftopic.
-For now, only models of the second kind are on our list of things to do; dynamic topic modeling has been implemented for GMM and will soon be implemented for Clustering Topic Models.
-For more theoretical background, see the page on [GMM](GMM.md).
+In Turftopic you can currently use three different topic models for modeling topics over time:
+1. [ClusteringTopicModel](clustering.md): an overall model is fitted on the whole corpus, and term importances are then estimated over time slices.
+2. [GMM](GMM.md): similarly to clustering models, term importances are re-estimated per time slice.
+3. [KeyNMF](KeyNMF.md): an overall decomposition is done first, then topic-term matrices are recalculated with coordinate descent from the document-topic and document-term matrices in each time slice.

## Usage

Dynamic topic models in Turftopic have a unified interface.
To fit a dynamic topic model you will need a corpus that has been annotated with timestamps.
The timestamps need to be Python `datetime` objects, but pandas `Timestamp` objects are also supported.

-Models that have dynamic modeling capabilities (currently, `GMM` and `ClusteringTopicModel`) have a `fit_transform_dynamic()` method, that fits the model on the corpus over time.
+Models that have dynamic modeling capabilities (`KeyNMF`, `GMM` and `ClusteringTopicModel`) have a `fit_transform_dynamic()` method that fits the model on the corpus over time.

```python
from datetime import datetime

-from turftopic import GMM
+from turftopic import KeyNMF

corpus: list[str] = [...]
timestamps: list[datetime] = [...]

-model = GMM(5)
+model = KeyNMF(5)
document_topic_matrix = model.fit_transform_dynamic(corpus, timestamps=timestamps)
```
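
After fitting, the per-slice topic descriptions can be inspected through the attributes this change introduces, `temporal_components_` and `temporal_importance_`. A small sketch, assuming the model's default vectorizer is stored on `model.vectorizer`:

```python
import numpy as np

vocab = np.asarray(model.vectorizer.get_feature_names_out())

# temporal_components_ has shape (n_bins, n_topics, n_vocab);
# temporal_importance_ has shape (n_bins, n_topics).
for t, (components, importance) in enumerate(
    zip(model.temporal_components_, model.temporal_importance_)
):
    print(f"--- time bin {t} ---")
    for topic_id, term_weights in enumerate(components):
        top_terms = vocab[np.argsort(-term_weights)[:5]]
        print(f"Topic {topic_id} ({importance[topic_id]:.2f}): " + ", ".join(top_terms))
```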

Expand Down
2 changes: 1 addition & 1 deletion docs/model_overview.md
@@ -90,7 +90,7 @@ Here is an opinionated guide for common use cases:
### 1. When in doubt **use KeyNMF**.

When you can't make an informed decision about which model is optimal for your use case, or you just want to get your hands dirty with topic modeling,
-KeyNMF is the best option.
+KeyNMF is by far the best option.
It is very stable, gives high quality topics, and is incredibly robust to noise.
It is also the closest to classical topic models and thus conforms to your intuition about topic modeling.

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -6,7 +6,7 @@ line-length=79

[tool.poetry]
name = "turftopic"
version = "0.2.13"
version = "0.3.0"
description = "Topic modeling with contextual representations from sentence transformers."
authors = ["Márton Kardos <power.up1163@gmail.com>"]
license = "MIT"
14 changes: 5 additions & 9 deletions tests/test_integration.py
@@ -1,5 +1,5 @@
-from datetime import datetime
import tempfile
+from datetime import datetime
from pathlib import Path

import numpy as np
@@ -8,13 +8,8 @@
from sentence_transformers import SentenceTransformer
from sklearn.datasets import fetch_20newsgroups

-from turftopic import (
-    GMM,
-    AutoEncodingTopicModel,
-    ClusteringTopicModel,
-    KeyNMF,
-    SemanticSignalSeparation,
-)
+from turftopic import (GMM, AutoEncodingTopicModel, ClusteringTopicModel,
+                       KeyNMF, SemanticSignalSeparation)


def generate_dates(
@@ -75,8 +70,9 @@ def generate_dates(
        n_reduce_to=5,
        feature_importance="soft-c-tf-idf",
        encoder=trf,
-        reduction_method="smallest"
+        reduction_method="smallest",
    ),
+    KeyNMF(5, encoder=trf),
]


76 changes: 75 additions & 1 deletion turftopic/models/keynmf.py
@@ -1,22 +1,67 @@
import itertools
import json
import random
from datetime import datetime
from typing import Dict, Iterable, List, Optional, Union

import numpy as np
from rich.console import Console
from sentence_transformers import SentenceTransformer
from sklearn.decomposition import NMF, MiniBatchNMF
from sklearn.decomposition._nmf import (_initialize_nmf,
                                        _update_coordinate_descent)
from sklearn.exceptions import NotFittedError
from sklearn.feature_extraction import DictVectorizer
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.utils import check_array

from turftopic.base import ContextualModel, Encoder
from turftopic.data import TopicData
from turftopic.dynamic import DynamicTopicModel, bin_timestamps
from turftopic.vectorizer import default_vectorizer


def fit_timeslice(
    X,
    W,
    H,
    tol=1e-4,
    max_iter=200,
    l1_reg_W=0,
    l1_reg_H=0,
    l2_reg_W=0,
    l2_reg_H=0,
    verbose=0,
    shuffle=False,
    random_state=None,
):
    """Fits topic_term_matrix based on a precomputed document_topic_matrix.
    This is used to get temporal components in dynamic KeyNMF.
    """
    Ht = check_array(H.T, order="C")
    if random_state is None:
        rng = np.random.mtrand._rand
    else:
        rng = np.random.RandomState(random_state)
    for n_iter in range(1, max_iter + 1):
        violation = 0.0
        # Update only Ht (the transposed topic-term matrix) with coordinate
        # descent; the document-topic matrix W is held fixed.
        violation += _update_coordinate_descent(
            X.T, Ht, W, l1_reg_H, l2_reg_H, shuffle, rng
        )
        if n_iter == 1:
            violation_init = violation
        if violation_init == 0:
            break
        if verbose:
            print("violation:", violation / violation_init)
        # Stop once the relative decrease in violation falls below tol.
        if violation / violation_init <= tol:
            if verbose:
                print("Converged at iteration", n_iter + 1)
            break
    return W, Ht.T, n_iter


def batched(iterable, n: int) -> Iterable[List[str]]:
    "Batch data into tuples of length n. The last batch may be shorter."
    # batched('ABCDEFG', 3) --> ABC DEF G
@@ -48,7 +93,7 @@ def __iter__(self) -> Iterable[Dict[str, float]]:
            yield deserialize_keywords(line.strip())


-class KeyNMF(ContextualModel):
+class KeyNMF(ContextualModel, DynamicTopicModel):
    """Extracts keywords from documents based on semantic similarity of
    term encodings to document encodings.
    Topics are then extracted with non-negative matrix factorization from
@@ -305,3 +350,32 @@ def prepare_topic_data(
"topic_names": self.topic_names,
}
return res

    def fit_transform_dynamic(
        self,
        raw_documents,
        timestamps: list[datetime],
        embeddings: Optional[np.ndarray] = None,
        bins: Union[int, list[datetime]] = 10,
    ) -> np.ndarray:
        time_labels, self.time_bin_edges = bin_timestamps(timestamps, bins)
        # Fit the overall model on the entire corpus first.
        topic_data = self.prepare_topic_data(
            raw_documents, embeddings=embeddings
        )
        n_bins = len(self.time_bin_edges) + 1
        n_comp, n_vocab = self.components_.shape
        self.temporal_components_ = np.zeros((n_bins, n_comp, n_vocab))
        self.temporal_importance_ = np.zeros((n_bins, n_comp))
        for label in np.unique(time_labels):
            idx = np.nonzero(time_labels == label)
            X = topic_data["document_term_matrix"][idx]
            W = topic_data["document_topic_matrix"][idx]
            # Refit the topic-term matrix on this time slice while keeping
            # the document-topic matrix from the overall fit fixed.
            _, H = _initialize_nmf(
                X, self.components_.shape[0], random_state=self.random_state
            )
            _, H, _ = fit_timeslice(X, W, H, random_state=self.random_state)
            self.temporal_components_[label] = H
            # Topic importance in a slice is the normalized column sum of W.
            topic_importances = np.squeeze(np.asarray(W.sum(axis=0)))
            topic_importances = topic_importances / topic_importances.sum()
            self.temporal_importance_[label] = topic_importances
        return topic_data["document_topic_matrix"]
