Skip to content

Commit

Permalink
feat!: use enumerated source name types in search/ (#291)
Browse files Browse the repository at this point in the history
  • Loading branch information
jsstevenson authored Nov 7, 2023
1 parent cae10d7 commit 748a394
Show file tree
Hide file tree
Showing 7 changed files with 75 additions and 109 deletions.
54 changes: 30 additions & 24 deletions src/gene/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,15 @@

from fastapi import FastAPI, HTTPException, Query

from gene import __version__
from gene import SOURCES, __version__
from gene.database import create_db
from gene.query import InvalidParameterException, QueryHandler
from gene.schemas import NormalizeService, SearchService, UnmergedNormalizationService
from gene.query import QueryHandler
from gene.schemas import (
NormalizeService,
SearchService,
SourceName,
UnmergedNormalizationService,
)

db = create_db()
query_handler = QueryHandler(db)
Expand Down Expand Up @@ -42,14 +47,10 @@
read_query_summary = "Given query, provide best-matching source records."
response_description = "A response to a validly-formed query"
q_descr = "Gene to normalize."
incl_descr = """Optional. Comma-separated list of source names to include in
response. Will exclude all other sources. Returns HTTP status code
422: Unprocessable Entity if both 'incl' and 'excl' parameters
are given."""
excl_descr = """Optional. Comma-separated list of source names to exclude in
response. Will include all other sources. Returns HTTP status
code 422: Unprocessable Entity if both 'incl' and 'excl'
parameters are given."""
sources_descr = (
"Optional. Comma-separated list of source names to include in response, if given. "
"Will exclude all other sources."
)
search_description = (
"For each source, return strongest-match concepts "
"for query string provided by user"
Expand All @@ -66,24 +67,29 @@
)
def search(
q: str = Query(..., description=q_descr), # noqa: D103
incl: Optional[str] = Query(None, description=incl_descr),
excl: Optional[str] = Query(None, description=excl_descr),
sources: Optional[str] = Query(None, description=sources_descr),
) -> SearchService:
"""Return strongest match concepts to query string provided by user.
:param str q: gene search term
:param Optional[str] incl: comma-separated list of sources to include,
with all others excluded. Raises HTTPException if both `incl` and
`excl` are given.
:param Optional[str] excl: comma-separated list of sources to exclude, with
all others included. Raises HTTPException if both `incl` and `excl`
are given.
:param q: gene search term
:param sources: If given, search only for records from these sources.
Provide as string of source names separated by commas.
:return: JSON response with matched records and source metadata
"""
try:
resp = query_handler.search(html.unescape(q), incl=incl, excl=excl)
except InvalidParameterException as e:
raise HTTPException(status_code=422, detail=str(e))
parsed_sources = []
if sources:
for candidate_source in sources.split(","):
try:
parsed_source = SourceName[
SOURCES[candidate_source.strip().lower()].upper()
]
except KeyError:
raise HTTPException(
status_code=422,
detail=f"Unable to parse source name: {candidate_source}",
)
parsed_sources.append(parsed_source)
resp = query_handler.search(html.unescape(q), sources=parsed_sources)
return resp


Expand Down
59 changes: 8 additions & 51 deletions src/gene/query.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Provides methods for handling queries."""
import re
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TypeVar
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, TypeVar

from ga4gh.core import core_models, ga4gh_identify
from ga4gh.vrs import models
Expand Down Expand Up @@ -220,7 +220,7 @@ def _post_process_resp(self, resp: Dict) -> Dict:
records = sorted(records, key=lambda k: k.match_type, reverse=True)
return resp

def _get_search_response(self, query: str, sources: Set[str]) -> Dict:
def _get_search_response(self, query: str, sources: Iterable[SourceName]) -> Dict:
"""Return response as dict where key is source name and value is a list of
records.
Expand All @@ -231,7 +231,7 @@ def _get_search_response(self, query: str, sources: Set[str]) -> Dict:
resp = {
"query": query,
"warnings": self._emit_warnings(query),
"source_matches": {source: None for source in sources},
"source_matches": {source.value: None for source in sources},
}
if query == "":
return self._post_process_resp(resp)
Expand Down Expand Up @@ -283,9 +283,7 @@ def _get_service_meta() -> ServiceMeta:
def search(
self,
query_str: str,
incl: str = "",
excl: str = "",
**params,
sources: Optional[List[SourceName]] = None,
) -> SearchService:
"""Return highest match for each source.
Expand All @@ -297,57 +295,16 @@ def search(
'ncbigene:673'
:param query_str: query, a string, to search for
:param incl: str containing comma-separated names of sources to use. Will
exclude all other sources. Case-insensitive.
:param excl: str containing comma-separated names of sources to exclude. Will
include all other sources. Case-insensitive.
:param sources: If given, only return records from these sources
:return: SearchService class containing all matches found in sources.
:raise InvalidParameterException: if both `incl` and `excl` args are provided,
or if invalid source names are given
"""
possible_sources = {
name.value.lower(): name.value for name in SourceName.__members__.values()
}
sources = dict()
for k, v in possible_sources.items():
if self.db.get_source_metadata(v):
sources[k] = v

if not incl and not excl:
query_sources = set(sources.values())
elif incl and excl:
detail = "Cannot request both source inclusions and exclusions."
raise InvalidParameterException(detail)
elif incl:
req_sources = [n.strip() for n in incl.split(",")]
invalid_sources = []
query_sources = set()
for source in req_sources:
if source.lower() in sources.keys():
query_sources.add(sources[source.lower()])
else:
invalid_sources.append(source)
if invalid_sources:
detail = f"Invalid source name(s): {invalid_sources}"
raise InvalidParameterException(detail)
else:
req_exclusions = [n.strip() for n in excl.lower().split(",")]
req_excl_dict = {r.lower(): r for r in req_exclusions}
invalid_sources = []
query_sources = set()
for req_l, req in req_excl_dict.items():
if req_l not in sources.keys():
invalid_sources.append(req)
for src_l, src in sources.items():
if src_l not in req_excl_dict.keys():
query_sources.add(src)
if invalid_sources:
detail = f"Invalid source name(s): {invalid_sources}"
raise InvalidParameterException(detail)
if not sources:
sources = list(SourceName.__members__.values())

query_str = query_str.strip()

resp = self._get_search_response(query_str, query_sources)
resp = self._get_search_response(query_str, sources)

resp["service_meta_"] = self._get_service_meta()
return SearchService(**resp)
Expand Down
22 changes: 19 additions & 3 deletions tests/unit/test_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from fastapi.testclient import TestClient

from gene.main import app
from gene.schemas import SourceName


@pytest.fixture(scope="module")
Expand All @@ -16,16 +17,31 @@ def api_client():
return TestClient(app)


def test_search(api_client):
def test_search(api_client: TestClient):
"""Test /search endpoint."""
response = api_client.get("/gene/search?q=braf")
assert response.status_code == 200
response.raise_for_status()
assert (
response.json()["source_matches"]["HGNC"]["records"][0]["concept_id"]
== "hgnc:1097"
)
assert len(response.json()["source_matches"]) == 3

response = api_client.get("/gene/search?q=braf&sources=Hgnc")
response.raise_for_status()
assert SourceName.HGNC.value in response.json()["source_matches"]
assert len(response.json()["source_matches"]) == 1

response = api_client.get("/gene/search?q=braf&sources=EnsEMBL, NCBI")
response.raise_for_status()
assert SourceName.NCBI.value in response.json()["source_matches"]
assert SourceName.ENSEMBL.value in response.json()["source_matches"]
assert len(response.json()["source_matches"]) == 2

response = api_client.get("/gene/search?q=braf&sources=EnsEMBL, NCBzI")
assert response.status_code == 422

response = api_client.get("/gene/search?q=braf&incl=sdkl")
response = api_client.get("/gene/search?q=braf&sources=sdkl")
assert response.status_code == 422


Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_ensembl_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ class QueryGetter:
def __init__(self):
self.query_handler = QueryHandler(database)

def search(self, query_str, incl="ensembl"):
resp = self.query_handler.search(query_str, incl=incl)
def search(self, query_str):
resp = self.query_handler.search(query_str, sources=[SourceName.ENSEMBL])
return resp.source_matches[SourceName.ENSEMBL]

e = QueryGetter()
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_hgnc_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ class QueryGetter:
def __init__(self):
self.query_handler = QueryHandler(database)

def search(self, query_str, incl="hgnc"):
resp = self.query_handler.search(query_str, incl=incl)
def search(self, query_str):
resp = self.query_handler.search(query_str, sources=[SourceName.HGNC])
return resp.source_matches[SourceName.HGNC]

h = QueryGetter()
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_ncbi_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ class QueryGetter:
def __init__(self):
self.query_handler = QueryHandler(database)

def search(self, query_str, incl="ncbi"):
resp = self.query_handler.search(query_str, incl=incl)
def search(self, query_str):
resp = self.query_handler.search(query_str, sources=[SourceName.NCBI])
return resp.source_matches[SourceName.NCBI]

n = QueryGetter()
Expand Down
37 changes: 12 additions & 25 deletions tests/unit/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pytest
from ga4gh.core import core_models

from gene.query import InvalidParameterException, QueryHandler
from gene.query import QueryHandler
from gene.schemas import BaseGene, MatchType, SourceName


Expand All @@ -14,8 +14,8 @@ class QueryGetter:
def __init__(self):
self.query_handler = QueryHandler(database)

def search(self, query_str, incl="", excl=""):
return self.query_handler.search(query_str=query_str, incl=incl, excl=excl)
def search(self, query_str, sources=None):
return self.query_handler.search(query_str=query_str, sources=sources)

def normalize(self, query_str):
return self.query_handler.normalize(query_str)
Expand Down Expand Up @@ -1149,35 +1149,22 @@ def test_search_query(query_handler, num_sources):
assert len(matches) == num_sources


def test_search_query_inc_exc(query_handler, num_sources):
"""Test that query incl and excl work correctly."""
sources = "hgnc, ensembl, ncbi"
resp = query_handler.search("BRAF", excl=sources)
def test_search_query_source_filters(query_handler):
"""Test query source filtering."""
sources = [SourceName.HGNC, SourceName.NCBI]
resp = query_handler.search("BRAF", sources=sources)
matches = resp.source_matches
assert len(matches) == num_sources - len(sources.split())

sources = "Hgnc, NCBi"
resp = query_handler.search("BRAF", incl=sources)
matches = resp.source_matches
assert len(matches) == len(sources.split())
assert len(matches) == len(sources)
assert SourceName.HGNC in matches
assert SourceName.NCBI in matches

sources = "HGnC"
resp = query_handler.search("BRAF", excl=sources)
sources = [SourceName.HGNC, SourceName.NCBI, SourceName.ENSEMBL]
resp = query_handler.search("BRAF", sources=sources)
matches = resp.source_matches
assert len(matches) == num_sources - len(sources.split())
assert len(matches) == len(sources)
assert SourceName.ENSEMBL in matches
assert SourceName.NCBI in matches


def test_search_invalid_parameter_exception(query_handler):
"""Test that Invalid parameter exception works correctly."""
with pytest.raises(InvalidParameterException):
_ = query_handler.search("BRAF", incl="hgn") # noqa: F841, E501

with pytest.raises(InvalidParameterException):
resp = query_handler.search("BRAF", incl="hgnc", excl="hgnc") # noqa: F841
assert SourceName.HGNC in matches


def test_ache_query(query_handler, num_sources, normalized_ache, source_meta):
Expand Down

0 comments on commit 748a394

Please sign in to comment.