Skip to content

Commit

Permalink
pulled out knn base functions into cell_search_knn.py
Browse files Browse the repository at this point in the history
  • Loading branch information
tony-kuo committed Mar 22, 2024
1 parent aa66ea7 commit 0564894
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 60 deletions.
4 changes: 2 additions & 2 deletions src/scimilarity/cell_annotation.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import Optional, Union, List, Set, Tuple

from scimilarity.cell_embedding import CellEmbedding
from scimilarity.cell_search_knn import CellSearchKNN


class CellAnnotation(CellEmbedding):
class CellAnnotation(CellSearchKNN):
"""A class that annotates cells using a cell embedding and then kNN search."""

def __init__(
Expand Down
56 changes: 0 additions & 56 deletions src/scimilarity/cell_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ def __init__(

self.model_path = model_path
self.use_gpu = use_gpu
self.knn = None

if filenames is None:
filenames = {}
Expand Down Expand Up @@ -175,58 +174,3 @@ def get_embeddings(
raise RuntimeError(f"NaN detected in embeddings.")

return embedding

def load_knn_index(self, knn_file: str):
    """Load an hnswlib kNN index from disk into ``self.knn``.

    The index is created with cosine space and a dimensionality of
    ``self.model.latent_dim`` before the file contents are loaded.

    Parameters
    ----------
    knn_file: str
        Filename of the kNN index.

    Notes
    -----
    A missing index file is not treated as an error: a warning is printed
    and ``self.knn`` is set to None.
    """

    import os
    import hnswlib

    # Guard clause: without an index file there is nothing to load.
    if not os.path.isfile(knn_file):
        print(f"Warning: No KNN index found at {knn_file}")
        self.knn = None
        return

    self.knn = hnswlib.Index(space="cosine", dim=self.model.latent_dim)
    self.knn.load_index(knn_file)

def get_nearest_neighbors(
    self, embeddings: "numpy.ndarray", k: int = 50, ef: int = 100
) -> Tuple["numpy.ndarray", "numpy.ndarray"]:
    """Get nearest neighbors.

    Used by classes that inherit from CellEmbedding and have an instantiated kNN.

    Parameters
    ----------
    embeddings: numpy.ndarray
        Embeddings as a numpy array.
    k: int, default: 50
        The number of nearest neighbors.
    ef: int, default: 100
        The size of the dynamic list for the nearest neighbors.
        See https://github.com/nmslib/hnswlib/blob/master/ALGO_PARAMS.md

    Returns
    -------
    nn_idxs: numpy.ndarray
        A 2D numpy array of nearest neighbor indices [num_cells x k].
    nn_dists: numpy.ndarray
        A 2D numpy array of nearest neighbor distances [num_cells x k].

    Raises
    ------
    RuntimeError
        If ``self.knn`` is None, i.e. no kNN index has been loaded.

    Examples
    --------
    >>> from scimilarity.utils import align_dataset
    >>> ca = CellAnnotation(model_path="/opt/data/model")
    >>> embeddings = ca.get_embeddings(align_dataset(data, ca.gene_order).X)
    >>> nn_idxs, nn_dists = ca.get_nearest_neighbors(embeddings)
    """

    if self.knn is None:
        raise RuntimeError("kNN is not initialized.")

    # ef is set on the index before every query so each call can choose its own value.
    self.knn.set_ef(ef)
    return self.knn.knn_query(embeddings, k=k)
4 changes: 2 additions & 2 deletions src/scimilarity/cell_query.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import Dict, List, Optional, Tuple, Union, Set

from scimilarity.cell_embedding import CellEmbedding
from scimilarity.cell_search_knn import CellSearchKNN


class CellQuery(CellEmbedding):
class CellQuery(CellSearchKNN):
"""A class that searches for similar cells using a cell embedding and then a kNN search."""

def __init__(
Expand Down

0 comments on commit 0564894

Please sign in to comment.