diff --git a/docs/scripts/generate_normalize_figure.py b/docs/scripts/generate_normalize_figure.py index 1a39a085..bb05c7c1 100644 --- a/docs/scripts/generate_normalize_figure.py +++ b/docs/scripts/generate_normalize_figure.py @@ -12,8 +12,8 @@ import gravis as gv -from gene import APP_ROOT from gene.database import create_db +from gene.etl.base import APP_ROOT from gene.query import QueryHandler from gene.schemas import UnmergedNormalizationService diff --git a/src/gene/__init__.py b/src/gene/__init__.py index 2c569554..d3dab1ac 100644 --- a/src/gene/__init__.py +++ b/src/gene/__init__.py @@ -1,65 +1,4 @@ """The VICC library for normalizing genes.""" -import logging -from os import environ -from pathlib import Path +from .version import __version__ -from .version import __version__ # noqa: F401 - -APP_ROOT = Path(__file__).resolve().parent - -logging.basicConfig( - filename="gene.log", format="[%(asctime)s] - %(name)s - %(levelname)s : %(message)s" -) -logger = logging.getLogger("gene") -logger.setLevel(logging.DEBUG) -logger.handlers = [] - -logging.getLogger("boto3").setLevel(logging.INFO) -logging.getLogger("botocore").setLevel(logging.INFO) -logging.getLogger("urllib3").setLevel(logging.INFO) -logging.getLogger("python_jsonschema_objects").setLevel(logging.INFO) -logging.getLogger("biocommons.seqrepo.seqaliasdb.seqaliasdb").setLevel(logging.INFO) -logging.getLogger("biocommons.seqrepo.fastadir.fastadir").setLevel(logging.INFO) - - -SEQREPO_ROOT_DIR = Path( - environ.get("SEQREPO_ROOT_DIR", "/usr/local/share/seqrepo/latest") -) - - -class DownloadException(Exception): # noqa: N818 - """Exception for failures relating to source file downloads.""" - - -from gene.schemas import ( # noqa: E402 - NamespacePrefix, - RefType, - SourceIDAfterNamespace, - SourceName, -) - -ITEM_TYPES = {k.lower(): v.value for k, v in RefType.__members__.items()} - -# Sources we import directly (HGNC, Ensembl, NCBI) -SOURCES = { - source.value.lower(): source.value for source in SourceName.__members__.values() -} - -# Set of sources we import directly -XREF_SOURCES = {src.lower() for src in SourceName.__members__} - -# use to fetch source name from schema based on concept id namespace -# e.g. {"hgnc": "HGNC"} -PREFIX_LOOKUP = { - v.value: SourceName[k].value - for k, v in NamespacePrefix.__members__.items() - if k in SourceName.__members__.keys() -} - -# use to generate namespace prefix from source ID value -# e.g. {"ensg": "ensembl"} -NAMESPACE_LOOKUP = { - v.value.lower(): NamespacePrefix[k].value - for k, v in SourceIDAfterNamespace.__members__.items() - if v.value != "" -} +__all__ = ["__version__"] diff --git a/src/gene/database/dynamodb.py b/src/gene/database/dynamodb.py index 5df9e0d0..161972e1 100644 --- a/src/gene/database/dynamodb.py +++ b/src/gene/database/dynamodb.py @@ -11,7 +11,6 @@ from boto3.dynamodb.conditions import Key from botocore.exceptions import ClientError -from gene import ITEM_TYPES, PREFIX_LOOKUP from gene.database.database import ( AWS_ENV_VAR_NAME, SKIP_AWS_DB_ENV_NAME, @@ -23,7 +22,14 @@ DatabaseWriteException, confirm_aws_db_use, ) -from gene.schemas import RecordType, RefType, SourceMeta, SourceName +from gene.schemas import ( + ITEM_TYPES, + PREFIX_LOOKUP, + RecordType, + RefType, + SourceMeta, + SourceName, +) logger = logging.getLogger(__name__) diff --git a/src/gene/etl/base.py b/src/gene/etl/base.py index 020cd3ed..f46f0044 100644 --- a/src/gene/etl/base.py +++ b/src/gene/etl/base.py @@ -6,7 +6,7 @@ import shutil from abc import ABC, abstractmethod from ftplib import FTP -from os import remove +from os import environ, remove from pathlib import Path from typing import Callable, Dict, List, Optional @@ -15,14 +15,19 @@ from dateutil import parser from gffutils.feature import Feature -from gene import ITEM_TYPES, SEQREPO_ROOT_DIR from gene.database import AbstractDatabase -from gene.schemas import Gene, GeneSequenceLocation, MatchType, SourceName +from gene.schemas import ITEM_TYPES, Gene, GeneSequenceLocation, MatchType, SourceName logger = logging.getLogger("gene") logger.setLevel(logging.DEBUG) +APP_ROOT = Path(__file__).resolve().parent +SEQREPO_ROOT_DIR = Path( + environ.get("SEQREPO_ROOT_DIR", "/usr/local/share/seqrepo/latest") +) + + class Base(ABC): """The ETL base class.""" diff --git a/src/gene/etl/ensembl.py b/src/gene/etl/ensembl.py index 8ff78f23..f75ed034 100644 --- a/src/gene/etl/ensembl.py +++ b/src/gene/etl/ensembl.py @@ -10,9 +10,8 @@ import requests from gffutils.feature import Feature -from gene import APP_ROOT from gene.database import AbstractDatabase -from gene.etl.base import Base +from gene.etl.base import APP_ROOT, Base from gene.etl.exceptions import ( GeneFileVersionError, GeneNormalizerEtlError, diff --git a/src/gene/etl/hgnc.py b/src/gene/etl/hgnc.py index c78ce294..bfd07a7e 100644 --- a/src/gene/etl/hgnc.py +++ b/src/gene/etl/hgnc.py @@ -10,15 +10,15 @@ from dateutil import parser -from gene import APP_ROOT, PREFIX_LOOKUP from gene.database import AbstractDatabase -from gene.etl.base import Base +from gene.etl.base import APP_ROOT, Base from gene.etl.exceptions import ( GeneFileVersionError, GeneNormalizerEtlError, GeneSourceFetchError, ) from gene.schemas import ( + PREFIX_LOOKUP, Annotation, Chromosome, NamespacePrefix, diff --git a/src/gene/etl/ncbi.py b/src/gene/etl/ncbi.py index d57bc614..45d3c308 100644 --- a/src/gene/etl/ncbi.py +++ b/src/gene/etl/ncbi.py @@ -9,15 +9,15 @@ import gffutils -from gene import APP_ROOT, PREFIX_LOOKUP from gene.database import AbstractDatabase -from gene.etl.base import Base +from gene.etl.base import APP_ROOT, Base from gene.etl.exceptions import ( GeneFileVersionError, GeneNormalizerEtlError, GeneSourceFetchError, ) from gene.schemas import ( + PREFIX_LOOKUP, Annotation, Chromosome, NamespacePrefix, diff --git a/src/gene/main.py b/src/gene/main.py index e6c87223..8135da4e 100644 --- a/src/gene/main.py +++ b/src/gene/main.py @@ -6,8 +6,14 @@ from gene import __version__ from gene.database import create_db -from gene.query import InvalidParameterException, QueryHandler -from gene.schemas import NormalizeService, SearchService, UnmergedNormalizationService +from gene.query import QueryHandler +from gene.schemas import ( + SOURCES, + NormalizeService, + SearchService, + SourceName, + UnmergedNormalizationService, +) db = create_db() query_handler = QueryHandler(db) @@ -42,14 +48,10 @@ read_query_summary = "Given query, provide best-matching source records." response_description = "A response to a validly-formed query" q_descr = "Gene to normalize." -incl_descr = """Optional. Comma-separated list of source names to include in - response. Will exclude all other sources. Returns HTTP status code - 422: Unprocessable Entity if both 'incl' and 'excl' parameters - are given.""" -excl_descr = """Optional. Comma-separated list of source names to exclude in - response. Will include all other sources. Returns HTTP status - code 422: Unprocessable Entity if both 'incl' and 'excl' - parameters are given.""" +sources_descr = ( + "Optional. Comma-separated list of source names to include in response, if given. " + "Will exclude all other sources." +) search_description = ( "For each source, return strongest-match concepts " "for query string provided by user" @@ -66,24 +68,29 @@ ) def search( q: str = Query(..., description=q_descr), # noqa: D103 - incl: Optional[str] = Query(None, description=incl_descr), - excl: Optional[str] = Query(None, description=excl_descr), + sources: Optional[str] = Query(None, description=sources_descr), ) -> SearchService: """Return strongest match concepts to query string provided by user. - :param str q: gene search term - :param Optional[str] incl: comma-separated list of sources to include, - with all others excluded. Raises HTTPException if both `incl` and - `excl` are given. - :param Optional[str] excl: comma-separated list of sources exclude, with - all others included. Raises HTTPException if both `incl` and `excl` - are given. + :param q: gene search term + :param sources: If given, search only for records from these sources. + Provide as string of source names separated by commas. :return: JSON response with matched records and source metadata """ - try: - resp = query_handler.search(html.unescape(q), incl=incl, excl=excl) - except InvalidParameterException as e: - raise HTTPException(status_code=422, detail=str(e)) + parsed_sources = [] + if sources: + for candidate_source in sources.split(","): + try: + parsed_source = SourceName[ + SOURCES[candidate_source.strip().lower()].upper() + ] + except KeyError: + raise HTTPException( + status_code=422, + detail=f"Unable to parse source name: {candidate_source}", + ) + parsed_sources.append(parsed_source) + resp = query_handler.search(html.unescape(q), sources=parsed_sources) return resp diff --git a/src/gene/query.py b/src/gene/query.py index e30a79d8..d402390d 100644 --- a/src/gene/query.py +++ b/src/gene/query.py @@ -1,14 +1,17 @@ """Provides methods for handling queries.""" +import logging import re from datetime import datetime -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TypeVar +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, TypeVar from ga4gh.core import core_models, ga4gh_identify from ga4gh.vrs import models -from gene import ITEM_TYPES, NAMESPACE_LOOKUP, PREFIX_LOOKUP, logger from gene.database import AbstractDatabase, DatabaseReadException from gene.schemas import ( + ITEM_TYPES, + NAMESPACE_LOOKUP, + PREFIX_LOOKUP, BaseGene, BaseNormalizationService, Gene, @@ -28,6 +31,7 @@ ) from gene.version import __version__ +_logger = logging.getLogger(__name__) NormService = TypeVar("NormService", bound=BaseNormalizationService) @@ -72,7 +76,7 @@ def _emit_warnings(query_str: str) -> List: "non_breaking_space_characters": "Query contains non-breaking space characters" } ] - logger.warning( + _logger.warning( f"Query ({query_str}) contains non-breaking space characters." ) return warnings @@ -188,14 +192,14 @@ def _fetch_record( try: match = self.db.get_record_by_id(concept_id, case_sensitive=False) except DatabaseReadException as e: - logger.error( + _logger.error( f"Encountered DatabaseReadException looking up {concept_id}: {e}" ) else: if match: self._add_record(response, match, match_type) else: - logger.error( + _logger.error( f"Unable to find expected record for {concept_id} matching as {match_type}" ) # noqa: E501 @@ -220,7 +224,7 @@ def _post_process_resp(self, resp: Dict) -> Dict: records = sorted(records, key=lambda k: k.match_type, reverse=True) return resp - def _get_search_response(self, query: str, sources: Set[str]) -> Dict: + def _get_search_response(self, query: str, sources: Iterable[SourceName]) -> Dict: """Return response as dict where key is source name and value is a list of records. @@ -231,7 +235,7 @@ def _get_search_response(self, query: str, sources: Set[str]) -> Dict: resp = { "query": query, "warnings": self._emit_warnings(query), - "source_matches": {source: None for source in sources}, + "source_matches": {source.value: None for source in sources}, } if query == "": return self._post_process_resp(resp) @@ -263,7 +267,7 @@ def _get_search_response(self, query: str, sources: Set[str]) -> Dict: matched_concept_ids.append(ref) except DatabaseReadException as e: - logger.error( + _logger.error( f"Encountered DatabaseReadException looking up {item_type}" f" {term}: {e}" ) @@ -283,9 +287,7 @@ def _get_service_meta() -> ServiceMeta: def search( self, query_str: str, - incl: str = "", - excl: str = "", - **params, + sources: Optional[List[SourceName]] = None, ) -> SearchService: """Return highest match for each source. @@ -297,57 +299,16 @@ def search( 'ncbigene:673' :param query_str: query, a string, to search for - :param incl: str containing comma-separated names of sources to use. Will - exclude all other sources. Case-insensitive. - :param excl: str containing comma-separated names of source to exclude. Will - include all other source. Case-insensitive. + :param sources: If given, only return records from these sources :return: SearchService class containing all matches found in sources. :raise InvalidParameterException: if both `incl` and `excl` args are provided, or if invalid source names are given """ - possible_sources = { - name.value.lower(): name.value for name in SourceName.__members__.values() - } - sources = dict() - for k, v in possible_sources.items(): - if self.db.get_source_metadata(v): - sources[k] = v - - if not incl and not excl: - query_sources = set(sources.values()) - elif incl and excl: - detail = "Cannot request both source inclusions and exclusions." - raise InvalidParameterException(detail) - elif incl: - req_sources = [n.strip() for n in incl.split(",")] - invalid_sources = [] - query_sources = set() - for source in req_sources: - if source.lower() in sources.keys(): - query_sources.add(sources[source.lower()]) - else: - invalid_sources.append(source) - if invalid_sources: - detail = f"Invalid source name(s): {invalid_sources}" - raise InvalidParameterException(detail) - else: - req_exclusions = [n.strip() for n in excl.lower().split(",")] - req_excl_dict = {r.lower(): r for r in req_exclusions} - invalid_sources = [] - query_sources = set() - for req_l, req in req_excl_dict.items(): - if req_l not in sources.keys(): - invalid_sources.append(req) - for src_l, src in sources.items(): - if src_l not in req_excl_dict.keys(): - query_sources.add(src) - if invalid_sources: - detail = f"Invalid source name(s): {invalid_sources}" - raise InvalidParameterException(detail) + if not sources: + sources = list(SourceName.__members__.values()) query_str = query_str.strip() - - resp = self._get_search_response(query_str, query_sources) + resp = self._get_search_response(query_str, sources) resp["service_meta_"] = self._get_service_meta() return SearchService(**resp) @@ -535,7 +496,7 @@ def _handle_failed_merge_ref(record: Dict, response: Dict, query: str) -> Dict: :param query: original query value :return: response with no match """ - logger.error( + _logger.error( f"Merge ref lookup failed for ref {record['merge_ref']} " f"in record {record['concept_id']} from query {query}" ) @@ -600,7 +561,7 @@ def _resolve_merge( merge = self.db.get_record_by_id(merge_ref, False, True) if merge is None: query = response.query - logger.error( + _logger.error( f"Merge ref lookup failed for ref {record['merge_ref']} " f"in record {record['concept_id']} from query `{query}`" ) diff --git a/src/gene/schemas.py b/src/gene/schemas.py index 6f85b1bc..e6cb5183 100644 --- a/src/gene/schemas.py +++ b/src/gene/schemas.py @@ -147,6 +147,12 @@ class SourceName(Enum): NCBI = "NCBI" +# lowercase imported source name to correctly-cased name, e.g. {"ensembl": "Ensembl"} +SOURCES = { + source.value.lower(): source.value for source in SourceName.__members__.values() +} + + class SourcePriority(IntEnum): """Define priorities for sources when building merged concepts.""" @@ -196,6 +202,23 @@ class NamespacePrefix(Enum): RFAM = "rfam" +# use to fetch source name from schema based on concept id namespace +# e.g. {"hgnc": "HGNC"} +PREFIX_LOOKUP = { + v.value: SourceName[k].value + for k, v in NamespacePrefix.__members__.items() + if k in SourceName.__members__.keys() +} + +# use to generate namespace prefix from source ID value +# e.g. {"ensg": "ensembl"} +NAMESPACE_LOOKUP = { + v.value.lower(): NamespacePrefix[k].value + for k, v in SourceIDAfterNamespace.__members__.items() + if v.value != "" +} + + class DataLicenseAttributes(BaseModel): """Define constraints for data license attributes.""" @@ -222,6 +245,10 @@ class RefType(str, Enum): ASSOCIATED_WITH = "associated_with" +# collective name to singular name, e.g. {"previous_symbols": "prev_symbol"} +ITEM_TYPES = {k.lower(): v.value for k, v in RefType.__members__.items()} + + class SourceMeta(BaseModel): """Metadata for a given source to return in response object.""" diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..f58158c6 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +"""Provide tests.""" diff --git a/tests/conftest.py b/tests/conftest.py index 30c2e9dd..ad1a14a2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,6 @@ """Provide utilities for test cases.""" +import logging + import pytest from gene.database import AbstractDatabase, create_db @@ -10,6 +12,26 @@ def database() -> AbstractDatabase: return create_db() +def pytest_addoption(parser): + """Add custom commands to pytest invocation. + See https://docs.pytest.org/en/7.1.x/reference/reference.html#parser + """ + parser.addoption( + "--verbose-logs", + action="store_true", + default=False, + help="show noisy module logs", + ) + + +def pytest_configure(config): + """Configure pytest setup.""" + if not config.getoption("--verbose-logs"): + logging.getLogger("botocore").setLevel(logging.ERROR) + logging.getLogger("boto3").setLevel(logging.ERROR) + logging.getLogger("urllib3.connectionpool").setLevel(logging.ERROR) + + def _compare_records(normalized_gene, test_gene, match_type): """Check that normalized_gene and test_gene are the same.""" assert normalized_gene.match_type == match_type diff --git a/tests/unit/test_endpoints.py b/tests/unit/test_endpoints.py index 0639e6a0..58073e69 100644 --- a/tests/unit/test_endpoints.py +++ b/tests/unit/test_endpoints.py @@ -8,6 +8,7 @@ from fastapi.testclient import TestClient from gene.main import app +from gene.schemas import SourceName @pytest.fixture(scope="module") @@ -16,16 +17,31 @@ def api_client(): return TestClient(app) -def test_search(api_client): +def test_search(api_client: TestClient): """Test /search endpoint.""" response = api_client.get("/gene/search?q=braf") - assert response.status_code == 200 + response.raise_for_status() assert ( response.json()["source_matches"]["HGNC"]["records"][0]["concept_id"] == "hgnc:1097" ) + assert len(response.json()["source_matches"]) == 3 + + response = api_client.get("/gene/search?q=braf&sources=Hgnc") + response.raise_for_status() + assert SourceName.HGNC.value in response.json()["source_matches"] + assert len(response.json()["source_matches"]) == 1 + + response = api_client.get("/gene/search?q=braf&sources=EnsEMBL, NCBI") + response.raise_for_status() + assert SourceName.NCBI.value in response.json()["source_matches"] + assert SourceName.ENSEMBL.value in response.json()["source_matches"] + assert len(response.json()["source_matches"]) == 2 + + response = api_client.get("/gene/search?q=braf&sources=EnsEMBL, NCBzI") + assert response.status_code == 422 - response = api_client.get("/gene/search?q=braf&incl=sdkl") + response = api_client.get("/gene/search?q=braf&sources=sdkl") assert response.status_code == 422 diff --git a/tests/unit/test_ensembl_source.py b/tests/unit/test_ensembl_source.py index a3d32183..9b959a06 100644 --- a/tests/unit/test_ensembl_source.py +++ b/tests/unit/test_ensembl_source.py @@ -13,8 +13,8 @@ class QueryGetter: def __init__(self): self.query_handler = QueryHandler(database) - def search(self, query_str, incl="ensembl"): - resp = self.query_handler.search(query_str, incl=incl) + def search(self, query_str): + resp = self.query_handler.search(query_str, sources=[SourceName.ENSEMBL]) return resp.source_matches[SourceName.ENSEMBL] e = QueryGetter() diff --git a/tests/unit/test_hgnc_source.py b/tests/unit/test_hgnc_source.py index f99e4c41..10db9462 100644 --- a/tests/unit/test_hgnc_source.py +++ b/tests/unit/test_hgnc_source.py @@ -15,8 +15,8 @@ class QueryGetter: def __init__(self): self.query_handler = QueryHandler(database) - def search(self, query_str, incl="hgnc"): - resp = self.query_handler.search(query_str, incl=incl) + def search(self, query_str): + resp = self.query_handler.search(query_str, sources=[SourceName.HGNC]) return resp.source_matches[SourceName.HGNC] h = QueryGetter() diff --git a/tests/unit/test_ncbi_source.py b/tests/unit/test_ncbi_source.py index d0083a43..c90c353d 100644 --- a/tests/unit/test_ncbi_source.py +++ b/tests/unit/test_ncbi_source.py @@ -33,8 +33,8 @@ class QueryGetter: def __init__(self): self.query_handler = QueryHandler(database) - def search(self, query_str, incl="ncbi"): - resp = self.query_handler.search(query_str, incl=incl) + def search(self, query_str): + resp = self.query_handler.search(query_str, sources=[SourceName.NCBI]) return resp.source_matches[SourceName.NCBI] n = QueryGetter() diff --git a/tests/unit/test_query.py b/tests/unit/test_query.py index f767ced1..bfb11460 100644 --- a/tests/unit/test_query.py +++ b/tests/unit/test_query.py @@ -2,7 +2,7 @@ import pytest from ga4gh.core import core_models -from gene.query import InvalidParameterException, QueryHandler +from gene.query import QueryHandler from gene.schemas import BaseGene, MatchType, SourceName @@ -14,8 +14,8 @@ class QueryGetter: def __init__(self): self.query_handler = QueryHandler(database) - def search(self, query_str, incl="", excl=""): - return self.query_handler.search(query_str=query_str, incl=incl, excl=excl) + def search(self, query_str, sources=None): + return self.query_handler.search(query_str=query_str, sources=sources) def normalize(self, query_str): return self.query_handler.normalize(query_str) @@ -1149,35 +1149,22 @@ def test_search_query(query_handler, num_sources): assert len(matches) == num_sources -def test_search_query_inc_exc(query_handler, num_sources): - """Test that query incl and excl work correctly.""" - sources = "hgnc, ensembl, ncbi" - resp = query_handler.search("BRAF", excl=sources) +def test_search_query_source_filters(query_handler): + """Test query source filtering.""" + sources = [SourceName.HGNC, SourceName.NCBI] + resp = query_handler.search("BRAF", sources=sources) matches = resp.source_matches - assert len(matches) == num_sources - len(sources.split()) - - sources = "Hgnc, NCBi" - resp = query_handler.search("BRAF", incl=sources) - matches = resp.source_matches - assert len(matches) == len(sources.split()) + assert len(matches) == len(sources) assert SourceName.HGNC in matches assert SourceName.NCBI in matches - sources = "HGnC" - resp = query_handler.search("BRAF", excl=sources) + sources = [SourceName.HGNC, SourceName.NCBI, SourceName.ENSEMBL] + resp = query_handler.search("BRAF", sources=sources) matches = resp.source_matches - assert len(matches) == num_sources - len(sources.split()) + assert len(matches) == len(sources) assert SourceName.ENSEMBL in matches assert SourceName.NCBI in matches - - -def test_search_invalid_parameter_exception(query_handler): - """Test that Invalid parameter exception works correctly.""" - with pytest.raises(InvalidParameterException): - _ = query_handler.search("BRAF", incl="hgn") # noqa: F841, E501 - - with pytest.raises(InvalidParameterException): - resp = query_handler.search("BRAF", incl="hgnc", excl="hgnc") # noqa: F841 + assert SourceName.HGNC in matches def test_ache_query(query_handler, num_sources, normalized_ache, source_meta):