Skip to content

Commit

Permalink
Added option to use any decomposition method in S3
Browse files Browse the repository at this point in the history
  • Loading branch information
x-tabdeveloping committed Apr 4, 2024
1 parent 450184b commit 4a8487e
Showing 1 changed file with 14 additions and 4 deletions.
18 changes: 14 additions & 4 deletions turftopic/models/decomp.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import numpy as np
from rich.console import Console
from sentence_transformers import SentenceTransformer
from sklearn.decomposition import PCA, FastICA
from sklearn.base import TransformerMixin
from sklearn.decomposition import FastICA
from sklearn.feature_extraction.text import CountVectorizer

from turftopic.base import ContextualModel, Encoder
Expand Down Expand Up @@ -33,6 +34,11 @@ class SemanticSignalSeparation(ContextualModel):
vectorizer: CountVectorizer, default None
Vectorizer used for term extraction.
Can be used to prune or filter the vocabulary.
decomposition: TransformerMixin, default None
Custom decomposition method to use.
Can be an instance of FastICA or PCA, or basically any dimensionality
reduction method. Has to have `fit_transform` and `fit` methods.
If not specified, FastICA is used.
max_iter: int, default 200
Maximum number of iterations for ICA.
random_state: int, default None
Expand All @@ -46,6 +52,7 @@ def __init__(
Encoder, str
] = "sentence-transformers/all-MiniLM-L6-v2",
vectorizer: Optional[CountVectorizer] = None,
decomposition: Optional[TransformerMixin] = None,
max_iter: int = 200,
random_state: Optional[int] = None,
):
Expand All @@ -61,9 +68,12 @@ def __init__(
self.vectorizer = vectorizer
self.max_iter = max_iter
self.random_state = random_state
self.decomposition = FastICA(
n_components, max_iter=max_iter, random_state=random_state
)
if decomposition is None:
self.decomposition = FastICA(
n_components, max_iter=max_iter, random_state=random_state
)
else:
self.decomposition = decomposition

def fit_transform(
self, raw_documents, y=None, embeddings: Optional[np.ndarray] = None
Expand Down

0 comments on commit 4a8487e

Please sign in to comment.