diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5594fa9a..421d9d84 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,7 +8,7 @@ repos: - id: trailing-whitespace - id: end-of-file-fixer - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.2.0 + rev: v0.5.0 hooks: - id: ruff - id: ruff-format diff --git a/Pipfile b/Pipfile index 9ea28096..b39cf793 100644 --- a/Pipfile +++ b/Pipfile @@ -9,7 +9,7 @@ fastapi = "*" uvicorn = "*" click = "*" boto3 = "*" -"ga4gh.vrs" = "~=2.0.0a1" +"ga4gh.vrs" = "~=2.0.0a8" [dev-packages] gene = {editable = true, path = "."} @@ -19,7 +19,7 @@ wags-tails = ">=0.1.1" psycopg = {version = "*", extras=["binary"]} pytest = "*" pre-commit = "*" -ruff = ">=0.1.2" +ruff = "==0.5.0" pytest-cov = "*" httpx = "*" mock = "*" diff --git a/docs/scripts/generate_normalize_figure.py b/docs/scripts/generate_normalize_figure.py index 1a39a085..5a5ec1ad 100644 --- a/docs/scripts/generate_normalize_figure.py +++ b/docs/scripts/generate_normalize_figure.py @@ -6,9 +6,9 @@ Embeddable HTML for the normalization figure should be deposited in the correct location, within docs/source/_static/html/. -""" +""" # noqa: INP001 + import json -from typing import Dict import gravis as gv @@ -24,7 +24,7 @@ ] -def create_gjgf(result: UnmergedNormalizationService) -> Dict: +def create_gjgf(result: UnmergedNormalizationService) -> dict: """Create gravis input. :param result: result from Unmerged Normalization search @@ -43,13 +43,13 @@ def create_gjgf(result: UnmergedNormalizationService) -> Dict: } } - for i, (source, matches) in enumerate(result.source_matches.items()): + for i, (_, matches) in enumerate(result.source_matches.items()): for match in matches.records: graph["graph"]["nodes"][match.concept_id] = { "metadata": { "color": COLORS[i], - "hover": f"{match.concept_id}\n{match.symbol}\n{match.label}", # noqa: E501 - "click": f"

{json.dumps(match.model_dump(), indent=2)}

", # noqa: E501 + "hover": f"{match.concept_id}\n{match.symbol}\n{match.label}", + "click": f"

{json.dumps(match.model_dump(), indent=2)}

", } } for xref in match.xrefs: @@ -57,22 +57,25 @@ def create_gjgf(result: UnmergedNormalizationService) -> Dict: {"source": match.concept_id, "target": xref} ) - included_edges = [] - for edge in graph["graph"]["edges"]: + included_edges = [ + edge + for edge in graph["graph"]["edges"] if ( edge["target"] in graph["graph"]["nodes"] and edge["source"] in graph["graph"]["nodes"] - ): - included_edges.append(edge) + ) + ] + graph["graph"]["edges"] = included_edges included_nodes = {k["source"] for k in graph["graph"]["edges"]}.union( {k["target"] for k in graph["graph"]["edges"]} ) - new_nodes = {} - for key, value in graph["graph"]["nodes"].items(): - if key in included_nodes: - new_nodes[key] = value + new_nodes = { + key: value + for key, value in graph["graph"]["nodes"].items() + if key in included_nodes + } graph["graph"]["nodes"] = new_nodes return graph diff --git a/docs/source/conf.py b/docs/source/conf.py index f22f3d54..959356b5 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -79,7 +79,7 @@ def linkcode_resolve(domain, info): if not info["module"]: return None filename = info["module"].replace(".", "/") - return f"https://github.com/cancervariants/gene-normalization/blob/main/{filename}.py" # noqa: E501 + return f"https://github.com/cancervariants/gene-normalization/blob/main/{filename}.py" # -- code block style -------------------------------------------------------- diff --git a/pyproject.toml b/pyproject.toml index 97f9a0f8..eb227eda 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ dynamic = ["version"] pg = ["psycopg[binary]"] etl = ["gffutils", "biocommons.seqrepo", "wags-tails>=0.1.1"] test = ["pytest>=6.0", "pytest-cov", "mock", "httpx"] -dev = ["pre-commit", "ruff==0.2.0"] +dev = ["pre-commit", "ruff==0.5.0"] docs = [ "sphinx==6.1.3", "sphinx-autodoc-typehints==1.22.0", @@ -107,16 +107,22 @@ select = [ "DTZ", # https://docs.astral.sh/ruff/rules/#flake8-datetimez-dtz "T10", # https://docs.astral.sh/ruff/rules/#flake8-datetimez-dtz "EM", # https://docs.astral.sh/ruff/rules/#flake8-errmsg-em + "LOG", # https://docs.astral.sh/ruff/rules/#flake8-logging-log "G", # https://docs.astral.sh/ruff/rules/#flake8-logging-format-g + "INP", # https://docs.astral.sh/ruff/rules/#flake8-no-pep420-inp "PIE", # https://docs.astral.sh/ruff/rules/#flake8-pie-pie "T20", # https://docs.astral.sh/ruff/rules/#flake8-print-t20 "PT", # https://docs.astral.sh/ruff/rules/#flake8-pytest-style-pt "Q", # https://docs.astral.sh/ruff/rules/#flake8-quotes-q "RSE", # https://docs.astral.sh/ruff/rules/#flake8-raise-rse "RET", # https://docs.astral.sh/ruff/rules/#flake8-return-ret + "SLF", # https://docs.astral.sh/ruff/rules/#flake8-self-slf "SIM", # https://docs.astral.sh/ruff/rules/#flake8-simplify-sim + "ARG", # https://docs.astral.sh/ruff/rules/#flake8-unused-arguments-arg "PTH", # https://docs.astral.sh/ruff/rules/#flake8-use-pathlib-pth "PGH", # https://docs.astral.sh/ruff/rules/#pygrep-hooks-pgh + "PERF", # https://docs.astral.sh/ruff/rules/#perflint-perf + "FURB", # https://docs.astral.sh/ruff/rules/#refurb-furb "RUF", # https://docs.astral.sh/ruff/rules/#ruff-specific-rules-ruf ] fixable = [ @@ -127,15 +133,19 @@ fixable = [ "ANN", "B", "C4", + "LOG", "G", "PIE", "PT", "RSE", "SIM", + "PERF", + "FURB", "RUF" ] -# ANN101 - missing-type-self # ANN003 - missing-type-kwargs +# ANN101 - missing-type-self +# ANN102 - missing-type-cls # D203 - one-blank-line-before-class # D205 - blank-line-after-summary # D206 - indent-with-spaces* @@ -151,7 +161,7 @@ fixable = [ # S321 - suspicious-ftp-lib-usage # *ignored for compatibility with formatter ignore = [ - "ANN101", "ANN003", + "ANN003", "ANN101", "ANN102", "D203", "D205", "D206", "D213", "D300", "D400", "D415", "E111", "E114", "E117", "E501", "W191", @@ -169,7 +179,13 @@ ignore = [ # D100 - undocumented-public-module # D103 - undocumented-public-function # I001 - unsorted-imports -"tests/*" = ["ANN001", "ANN2", "ANN102", "S101", "B011"] +# INP001 - implicit-namespace-package +# ARG001 - unused-function-argument +# SLF001 - private-member-acces +"tests/*" = ["ANN001", "ANN2", "ANN102", "S101", "INP001", "SLF001", "ARG001"] "*__init__.py" = ["F401"] "gene/schemas.py" = ["ANN001", "ANN201", "N805"] "docs/source/conf.py" = ["D100", "I001", "D103", "ANN201", "ANN001"] + +[tool.ruff.format] +docstring-code-format = true diff --git a/src/gene/__init__.py b/src/gene/__init__.py index 30d1b8c3..4d2b71ef 100644 --- a/src/gene/__init__.py +++ b/src/gene/__init__.py @@ -1,8 +1,9 @@ """The VICC library for normalizing genes.""" + from os import environ from pathlib import Path -from .version import __version__ # noqa: F401 +from .version import __version__ APP_ROOT = Path(__file__).resolve().parent @@ -37,7 +38,7 @@ class DownloadException(Exception): # noqa: N818 PREFIX_LOOKUP = { v.value: SourceName[k].value for k, v in NamespacePrefix.__members__.items() - if k in SourceName.__members__.keys() + if k in SourceName.__members__ } # use to generate namespace prefix from source ID value diff --git a/src/gene/cli.py b/src/gene/cli.py index 50e458b8..c3286721 100644 --- a/src/gene/cli.py +++ b/src/gene/cli.py @@ -1,9 +1,10 @@ """Provides a CLI util to make updates to normalizer database.""" + import logging import os +from collections.abc import Collection from pathlib import Path from timeit import default_timer as timer -from typing import Collection, List, Optional, Set import click @@ -61,7 +62,7 @@ def check_db(db_url: str, verbose: bool = False) -> None: @click.command() @click.option("--data_url", help="URL to data dump") @click.option("--db_url", help="URL endpoint for the application database.") -def update_from_remote(data_url: Optional[str], db_url: str) -> None: +def update_from_remote(data_url: str | None, db_url: str) -> None: """Update data from remotely-hosted DB dump. By default, fetches from latest available dump on VICC S3 bucket; specific URLs can be provided instead by command line option or GENE_NORM_REMOTE_DB_URL environment variable. @@ -81,10 +82,10 @@ def update_from_remote(data_url: Optional[str], db_url: str) -> None: except NotImplementedError: click.echo( f"Error: Fetching remote data dump not supported for {db.__class__.__name__}" - ) # noqa: E501 + ) click.get_current_context().exit(1) except DatabaseException as e: - click.echo(f"Encountered exception during update: {str(e)}") + click.echo(f"Encountered exception during update: {e!s}") click.get_current_context().exit(1) _logger.info("Successfully loaded data from remote snapshot.") @@ -106,7 +107,7 @@ def dump_database(output_directory: Path, db_url: str) -> None: """ # noqa: D301 _configure_logging() if not output_directory: - output_directory = Path(".") + output_directory = Path() db = create_db(db_url, False) try: @@ -114,10 +115,10 @@ def dump_database(output_directory: Path, db_url: str) -> None: except NotImplementedError: click.echo( f"Error: Dumping data to file not supported for {db.__class__.__name__}" - ) # noqa: E501 + ) click.get_current_context().exit(1) except DatabaseException as e: - click.echo(f"Encountered exception during update: {str(e)}") + click.echo(f"Encountered exception during update: {e!s}") click.get_current_context().exit(1) _logger.info("Database dump successful.") @@ -137,7 +138,7 @@ def _update_normalizer( :param use_existing: if True, use most recent local version of source data instead of fetching from remote """ - processed_ids = list() + processed_ids = [] for n in sources: delete_time = _delete_source(n, db) _load_source(n, db, delete_time, processed_ids, use_existing) @@ -173,7 +174,7 @@ def _load_source( n: SourceName, db: AbstractDatabase, delete_time: float, - processed_ids: List[str], + processed_ids: list[str], use_existing: bool, ) -> None: """Load individual source data. @@ -199,7 +200,7 @@ def _load_source( f"Encountered ModuleNotFoundError attempting to import {e.name}. {_etl_dependency_help}" ) click.get_current_context().exit() - SourceClass = eval(n.value) # noqa: N806 + SourceClass = eval(n.value) # noqa: N806, S307 source = SourceClass(database=db, silent=False) try: @@ -234,7 +235,7 @@ def _delete_normalized_data(database: AbstractDatabase) -> None: click.echo(f"Deleted normalized records in {delete_time:.5f} seconds.") -def _load_merge(db: AbstractDatabase, processed_ids: Set[str]) -> None: +def _load_merge(db: AbstractDatabase, processed_ids: set[str]) -> None: """Load merged concepts :param db: database instance @@ -313,19 +314,21 @@ def update_normalizer_db( ctx = click.get_current_context() click.echo( "Must either enter 1 or more sources, or use `--update_all` parameter" - ) # noqa: E501 + ) click.echo(ctx.get_help()) ctx.exit() else: sources_split = sources.lower().split() if len(sources_split) == 0: - raise Exception("Must enter 1 or more source names to update") + err_msg = "Must enter 1 or more source names to update" + raise Exception(err_msg) non_sources = set(sources_split) - set(SOURCES) if len(non_sources) != 0: - raise Exception(f"Not valid source(s): {non_sources}") + err_msg = f"Not valid source(s): {non_sources}" + raise Exception(err_msg) parsed_source_names = {SourceName(SOURCES[s]) for s in sources_split} _update_normalizer(parsed_source_names, db, update_merged, use_existing) diff --git a/src/gene/database/__init__.py b/src/gene/database/__init__.py index 3a71e721..40d092bc 100644 --- a/src/gene/database/__init__.py +++ b/src/gene/database/__init__.py @@ -1,4 +1,5 @@ """Provide database clients.""" + from .database import ( AWS_ENV_VAR_NAME, AbstractDatabase, diff --git a/src/gene/database/database.py b/src/gene/database/database.py index 67bcafd6..3c90d6fc 100644 --- a/src/gene/database/database.py +++ b/src/gene/database/database.py @@ -1,10 +1,12 @@ """Provide abstract Database class and relevant tools for database initialization.""" + import abc import sys +from collections.abc import Generator from enum import Enum from os import environ from pathlib import Path -from typing import Any, Dict, Generator, List, Optional, Set, Union +from typing import Any import click @@ -34,7 +36,7 @@ class AbstractDatabase(abc.ABC): """ @abc.abstractmethod - def __init__(self, db_url: Optional[str] = None, **db_args) -> None: + def __init__(self, db_url: str | None = None, **db_args) -> None: """Initialize database instance. Generally, implementing classes should be able to construct a connection by @@ -47,7 +49,7 @@ def __init__(self, db_url: Optional[str] = None, **db_args) -> None: """ @abc.abstractmethod - def list_tables(self) -> List[str]: + def list_tables(self) -> list[str]: """Return names of tables in database. :return: Table names in database @@ -63,12 +65,11 @@ def _check_delete_okay() -> bool: """ if environ.get(AWS_ENV_VAR_NAME, "") == AwsEnvName.PRODUCTION: if environ.get(SKIP_AWS_DB_ENV_NAME, "") == "true": - raise DatabaseWriteException( - f"Must unset {SKIP_AWS_DB_ENV_NAME} env variable to enable drop_db()" # noqa: E501 - ) + err_msg = f"Must unset {SKIP_AWS_DB_ENV_NAME} env variable to enable drop_db()" + raise DatabaseWriteException(err_msg) return click.confirm("Are you sure you want to delete existing data?") - else: - return True + + return True @abc.abstractmethod def drop_db(self) -> None: @@ -107,7 +108,7 @@ def initialize_db(self) -> None: """ @abc.abstractmethod - def get_source_metadata(self, src_name: Union[str, SourceName]) -> Dict: + def get_source_metadata(self, src_name: str | SourceName) -> dict: """Get license, versioning, data lookup, etc information for a source. :param src_name: name of the source to get data for @@ -116,7 +117,7 @@ def get_source_metadata(self, src_name: Union[str, SourceName]) -> Dict: @abc.abstractmethod def get_record_by_id( self, concept_id: str, case_sensitive: bool = True, merge: bool = False - ) -> Optional[Dict]: + ) -> dict | None: """Fetch record corresponding to provided concept ID :param concept_id: concept ID for gene record @@ -128,7 +129,7 @@ def get_record_by_id( """ @abc.abstractmethod - def get_refs_by_type(self, search_term: str, ref_type: RefType) -> List[str]: + def get_refs_by_type(self, search_term: str, ref_type: RefType) -> list[str]: """Retrieve concept IDs for records matching the user's query. Other methods are responsible for actually retrieving full records. @@ -138,14 +139,14 @@ def get_refs_by_type(self, search_term: str, ref_type: RefType) -> List[str]: """ @abc.abstractmethod - def get_all_concept_ids(self) -> Set[str]: + def get_all_concept_ids(self) -> set[str]: """Retrieve all available concept IDs for use in generating normalized records. :return: List of concept IDs as strings. """ @abc.abstractmethod - def get_all_records(self, record_type: RecordType) -> Generator[Dict, None, None]: + def get_all_records(self, record_type: RecordType) -> Generator[dict, None, None]: """Retrieve all source or normalized records. Either return all source records, or all records that qualify as "normalized" (i.e., merged groups + source records that are otherwise ungrouped). @@ -172,7 +173,7 @@ def add_source_metadata(self, src_name: SourceName, data: SourceMeta) -> None: """ @abc.abstractmethod - def add_record(self, record: Dict, src_name: SourceName) -> None: + def add_record(self, record: dict, src_name: SourceName) -> None: """Add new record to database. :param record: record to upload @@ -180,7 +181,7 @@ def add_record(self, record: Dict, src_name: SourceName) -> None: """ @abc.abstractmethod - def add_merged_record(self, record: Dict) -> None: + def add_merged_record(self, record: dict) -> None: """Add merged record to database. :param record: merged record to add @@ -224,7 +225,7 @@ def close_connection(self) -> None: """Perform any manual connection closure procedures if necessary.""" @abc.abstractmethod - def load_from_remote(self, url: Optional[str] = None) -> None: + def load_from_remote(self, url: str | None = None) -> None: """Load DB from remote dump. Warning: Deletes all existing data. :param url: remote location to retrieve gzipped dump file from @@ -272,7 +273,7 @@ def confirm_aws_db_use(env_name: str) -> None: def create_db( - db_url: Optional[str] = None, aws_instance: bool = False + db_url: str | None = None, aws_instance: bool = False ) -> AbstractDatabase: """Database factory method. Checks environment variables and provided parameters and creates a DB instance. @@ -324,7 +325,7 @@ def create_db( else: if db_url: endpoint_url = db_url - elif "GENE_NORM_DB_URL" in environ.keys(): + elif "GENE_NORM_DB_URL" in environ: endpoint_url = environ["GENE_NORM_DB_URL"] else: endpoint_url = "http://localhost:8000" diff --git a/src/gene/database/dynamodb.py b/src/gene/database/dynamodb.py index 5baee7b7..37e05966 100644 --- a/src/gene/database/dynamodb.py +++ b/src/gene/database/dynamodb.py @@ -1,10 +1,12 @@ """Provide DynamoDB client.""" + import atexit import logging import sys +from collections.abc import Generator from os import environ from pathlib import Path -from typing import Any, Dict, Generator, List, Optional, Set, Union +from typing import Any import boto3 import click @@ -31,7 +33,7 @@ class DynamoDbDatabase(AbstractDatabase): """Database class employing DynamoDB.""" - def __init__(self, db_url: Optional[str] = None, **db_args) -> None: + def __init__(self, db_url: str | None = None, **db_args) -> None: """Initialize Database class. :param str db_url: URL endpoint for DynamoDB source @@ -44,20 +46,18 @@ def __init__(self, db_url: Optional[str] = None, **db_args) -> None: if AWS_ENV_VAR_NAME in environ: if "GENE_TEST" in environ: - raise DatabaseInitializationException( - f"Cannot have both GENE_TEST and {AWS_ENV_VAR_NAME} set." - ) # noqa: E501 + err_msg = f"Cannot have both GENE_TEST and {AWS_ENV_VAR_NAME} set." + raise DatabaseInitializationException(err_msg) aws_env = environ[AWS_ENV_VAR_NAME] if aws_env not in VALID_AWS_ENV_NAMES: - raise DatabaseInitializationException( - f"{AWS_ENV_VAR_NAME} must be one of {VALID_AWS_ENV_NAMES}" - ) # noqa: E501 + err_msg = f"{AWS_ENV_VAR_NAME} must be one of {VALID_AWS_ENV_NAMES}" + raise DatabaseInitializationException(err_msg) skip_confirmation = environ.get(SKIP_AWS_DB_ENV_NAME) if (not skip_confirmation) or ( skip_confirmation and skip_confirmation != "true" - ): # noqa: E501 + ): confirm_aws_db_use(environ[AWS_ENV_VAR_NAME]) boto_params = {"region_name": region_name} @@ -89,7 +89,7 @@ def __init__(self, db_url: Optional[str] = None, **db_args) -> None: self._cached_sources = {} atexit.register(self.close_connection) - def list_tables(self) -> List[str]: + def list_tables(self) -> list[str]: """Return names of tables in database. :return: Table names in DynamoDB @@ -156,7 +156,7 @@ def check_schema_initialized(self) -> bool: existing_tables = self.list_tables() exists = self.gene_table in existing_tables if not exists: - _logger.info(f"{self.gene_table} table is missing or unavailable.") + _logger.info("%s table is missing or unavailable.", self.gene_table) return exists def check_tables_populated(self) -> bool: @@ -201,7 +201,7 @@ def initialize_db(self) -> None: if not self.check_schema_initialized(): self._create_genes_table() - def get_source_metadata(self, src_name: Union[str, SourceName]) -> Dict: + def get_source_metadata(self, src_name: str | SourceName) -> dict: """Get license, versioning, data lookup, etc information for a source. :param src_name: name of the source to get data for @@ -210,22 +210,21 @@ def get_source_metadata(self, src_name: Union[str, SourceName]) -> Dict: src_name = src_name.value if src_name in self._cached_sources: return self._cached_sources[src_name] - else: - pk = f"{src_name.lower()}##source" - concept_id = f"source:{src_name.lower()}" - metadata = self.genes.get_item( - Key={"label_and_type": pk, "concept_id": concept_id} - ).get("Item") - if not metadata: - raise DatabaseReadException( - f"Unable to retrieve data for source {src_name}" - ) - self._cached_sources[src_name] = metadata - return metadata + + pk = f"{src_name.lower()}##source" + concept_id = f"source:{src_name.lower()}" + metadata = self.genes.get_item( + Key={"label_and_type": pk, "concept_id": concept_id} + ).get("Item") + if not metadata: + err_msg = f"Unable to retrieve data for source {src_name}" + raise DatabaseReadException(err_msg) + self._cached_sources[src_name] = metadata + return metadata def get_record_by_id( self, concept_id: str, case_sensitive: bool = True, merge: bool = False - ) -> Optional[Dict]: + ) -> dict | None: """Fetch record corresponding to provided concept ID :param str concept_id: concept ID for gene record @@ -246,23 +245,23 @@ def get_record_by_id( Key={"label_and_type": pk, "concept_id": concept_id} ) return match["Item"] - else: - exp = Key("label_and_type").eq(pk) - response = self.genes.query(KeyConditionExpression=exp) - record = response["Items"][0] - del record["label_and_type"] - return record + + exp = Key("label_and_type").eq(pk) + response = self.genes.query(KeyConditionExpression=exp) + record = response["Items"][0] + del record["label_and_type"] + return record except ClientError as e: _logger.error( - f"boto3 client error on get_records_by_id for " - f"search term {concept_id}: " - f"{e.response['Error']['Message']}" + "boto3 client error on get_records_by_id for search term %s: %s", + concept_id, + e.response["Error"]["Message"], ) return None except (KeyError, IndexError): # record doesn't exist return None - def get_refs_by_type(self, search_term: str, ref_type: RefType) -> List[str]: + def get_refs_by_type(self, search_term: str, ref_type: RefType) -> list[str]: """Retrieve concept IDs for records matching the user's query. Other methods are responsible for actually retrieving full records. @@ -277,13 +276,13 @@ def get_refs_by_type(self, search_term: str, ref_type: RefType) -> List[str]: return [m["concept_id"] for m in matches.get("Items", None)] except ClientError as e: _logger.error( - f"boto3 client error on get_refs_by_type for " - f"search term {search_term}: " - f"{e.response['Error']['Message']}" + "boto3 client error on get_refs_by_type for search term %s: %s", + search_term, + e.response["Error"]["Message"], ) return [] - def get_all_concept_ids(self) -> Set[str]: + def get_all_concept_ids(self) -> set[str]: """Retrieve concept IDs for use in generating normalized records. :return: List of concept IDs as strings. @@ -301,14 +300,13 @@ def get_all_concept_ids(self) -> Set[str]: else: response = self.genes.scan(**params) records = response["Items"] - for record in records: - concept_ids.append(record["concept_id"]) + concept_ids.extend(record["concept_id"] for record in records) last_evaluated_key = response.get("LastEvaluatedKey") if not last_evaluated_key: break return set(concept_ids) - def get_all_records(self, record_type: RecordType) -> Generator[Dict, None, None]: + def get_all_records(self, record_type: RecordType) -> Generator[dict, None, None]: """Retrieve all source or normalized records. Either return all source records, or all records that qualify as "normalized" (i.e., merged groups + source records that are otherwise ungrouped). @@ -341,7 +339,7 @@ def get_all_records(self, record_type: RecordType) -> Generator[Dict, None, None else: if ( incoming_record_type == RecordType.IDENTITY - and not record.get("merge_ref") # noqa: E501 + and not record.get("merge_ref") ) or incoming_record_type == RecordType.MERGER: yield record last_evaluated_key = response.get("LastEvaluatedKey") @@ -364,9 +362,9 @@ def add_source_metadata(self, src_name: SourceName, metadata: SourceMeta) -> Non try: self.genes.put_item(Item=metadata_item) except ClientError as e: - raise DatabaseWriteException(e) + raise DatabaseWriteException(e) from e - def add_record(self, record: Dict, src_name: SourceName) -> None: + def add_record(self, record: dict, src_name: SourceName) -> None: """Add new record to database. :param Dict record: record to upload @@ -381,8 +379,9 @@ def add_record(self, record: Dict, src_name: SourceName) -> None: self.batch.put_item(Item=record) except ClientError as e: _logger.error( - "boto3 client error on add_record for " - f"{concept_id}: {e.response['Error']['Message']}" + "boto3 client error on add_record for %s: %s", + concept_id, + e.response["Error"]["Message"], ) for attr_type, item_type in ITEM_TYPES.items(): if attr_type in record: @@ -398,7 +397,7 @@ def add_record(self, record: Dict, src_name: SourceName) -> None: item, record["concept_id"], item_type, src_name ) - def add_merged_record(self, record: Dict) -> None: + def add_merged_record(self, record: dict) -> None: """Add merged record to database. :param record: merged record to add @@ -413,8 +412,9 @@ def add_merged_record(self, record: Dict) -> None: self.batch.put_item(Item=record) except ClientError as e: _logger.error( - "boto3 client error on add_record for " - f"{concept_id}: {e.response['Error']['Message']}" + "boto3 client error on add_record for " "%s: %s", + concept_id, + e.response["Error"]["Message"], ) def _add_ref_record( @@ -439,9 +439,11 @@ def _add_ref_record( self.batch.put_item(Item=record) except ClientError as e: _logger.error( - f"boto3 client error adding reference {term} for " - f"{concept_id} with match type {ref_type}: " - f"{e.response['Error']['Message']}" + "boto3 client error adding reference %s for %s with match type %s: %s", + term, + concept_id, + ref_type, + e.response["Error"]["Message"], ) def update_merge_ref(self, concept_id: str, merge_ref: Any) -> None: # noqa: ANN401 @@ -466,14 +468,15 @@ def update_merge_ref(self, concept_id: str, merge_ref: Any) -> None: # noqa: AN except ClientError as e: code = e.response.get("Error", {}).get("Code") if code == "ConditionalCheckFailedException": - raise DatabaseWriteException( + err_msg = ( f"No such record exists for keys {label_and_type}, {concept_id}" ) - else: - _logger.error( - f"boto3 client error in `database.update_record()`: " - f"{e.response['Error']['Message']}" - ) + raise DatabaseWriteException(err_msg) from e + + _logger.error( + "boto3 client error in `database.update_record()`: %s", + e.response["Error"]["Message"], + ) def delete_normalized_concepts(self) -> None: """Remove merged records from the database. Use when performing a new update @@ -495,7 +498,7 @@ def delete_normalized_concepts(self) -> None: ), ) except ClientError as e: - raise DatabaseReadException(e) + raise DatabaseReadException(e) from e records = response["Items"] if not records: break @@ -522,23 +525,23 @@ def delete_source(self, src_name: SourceName) -> None: KeyConditionExpression=Key("src_name").eq(src_name.value), ) except ClientError as e: - raise DatabaseReadException(e) + raise DatabaseReadException(e) from e records = response["Items"] if not records: break with self.genes.batch_writer( overwrite_by_pkeys=["label_and_type", "concept_id"] ) as batch: - for record in records: - try: + try: + for record in records: batch.delete_item( Key={ "label_and_type": record["label_and_type"], "concept_id": record["concept_id"], } ) - except ClientError as e: - raise DatabaseWriteException(e) + except ClientError as e: + raise DatabaseWriteException(e) from e def complete_write_transaction(self) -> None: """Conclude transaction or batch writing if relevant.""" @@ -549,7 +552,7 @@ def close_connection(self) -> None: """Perform any manual connection closure procedures if necessary.""" self.batch.__exit__(*sys.exc_info()) - def load_from_remote(self, url: Optional[str] = None) -> None: + def load_from_remote(self, url: str | None = None) -> None: """Load DB from remote dump. Not available for DynamoDB database backend. :param url: remote location to retrieve gzipped dump file from diff --git a/src/gene/database/postgresql.py b/src/gene/database/postgresql.py index e181a2a1..59d4388c 100644 --- a/src/gene/database/postgresql.py +++ b/src/gene/database/postgresql.py @@ -1,13 +1,15 @@ """Provide PostgreSQL client.""" + import atexit +import datetime import json import logging import os import tarfile import tempfile -from datetime import datetime +from collections.abc import Generator from pathlib import Path -from typing import Any, Dict, Generator, List, Optional, Set, Tuple +from typing import Any, ClassVar import psycopg import requests @@ -35,7 +37,7 @@ class PostgresDatabase(AbstractDatabase): """Database class employing PostgreSQL.""" - def __init__(self, db_url: Optional[str] = None, **db_args) -> None: + def __init__(self, db_url: str | None = None, **db_args) -> None: """Initialize Postgres connection. >>> from gene.database.postgresql import PostgresDatabase @@ -79,7 +81,7 @@ def __init__(self, db_url: Optional[str] = None, **db_args) -> None: AND table_type = 'BASE TABLE'; """ - def list_tables(self) -> List[str]: + def list_tables(self) -> list[str]: """Return names of tables in database. :return: Table names in database @@ -267,7 +269,7 @@ def _create_tables(self) -> None: cur.execute(tables_query) self.conn.commit() - def get_source_metadata(self, src_name: SourceName) -> Dict: + def get_source_metadata(self, src_name: SourceName) -> dict: """Get license, versioning, data lookup, etc information for a source. :param src_name: name of the source to get data for @@ -283,7 +285,8 @@ def get_source_metadata(self, src_name: SourceName) -> Dict: cur.execute(metadata_query, [src_name]) metadata_result = cur.fetchone() if not metadata_result: - raise DatabaseReadException(f"{src_name} metadata lookup failed") + err_msg = f"{src_name} metadata lookup failed" + raise DatabaseReadException(err_msg) metadata = { "data_license": metadata_result[1], "data_license_url": metadata_result[2], @@ -301,10 +304,10 @@ def get_source_metadata(self, src_name: SourceName) -> Dict: return metadata _get_record_query = ( - b"SELECT * FROM record_lookup_view WHERE lower(concept_id) = %s;" # noqa: E501 + b"SELECT * FROM record_lookup_view WHERE lower(concept_id) = %s;" ) - def _format_source_record(self, source_row: Tuple) -> Dict: + def _format_source_record(self, source_row: tuple) -> dict: """Restructure row from gene_concepts table as source record result object. :param source_row: result tuple from psycopg @@ -329,13 +332,11 @@ def _format_source_record(self, source_row: Tuple) -> Dict: } return {k: v for k, v in gene_record.items() if v} - def _get_record(self, concept_id: str, case_sensitive: bool) -> Optional[Dict]: + def _get_record(self, concept_id: str) -> dict | None: """Retrieve non-merged record. The query is pretty different, so this method is broken out for PostgreSQL. :param concept_id: ID of concept to get - :param case_sensitive: record lookups are performed using a case-insensitive - index, so this parameter isn't used by Postgres :return: complete record object if successful """ concept_id_param = concept_id.lower() @@ -347,7 +348,7 @@ def _get_record(self, concept_id: str, case_sensitive: bool) -> Optional[Dict]: return None return self._format_source_record(result) - def _format_merged_record(self, merged_row: Tuple) -> Dict: + def _format_merged_record(self, merged_row: tuple) -> dict: """Restructure row from gene_merged table as normalized result object. :param merged_row: result tuple from psycopg @@ -375,17 +376,13 @@ def _format_merged_record(self, merged_row: Tuple) -> Dict: return {k: v for k, v in merged_record.items() if v} _get_merged_record_query = ( - b"SELECT * FROM gene_merged WHERE lower(concept_id) = %s;" # noqa: E501 + b"SELECT * FROM gene_merged WHERE lower(concept_id) = %s;" ) - def _get_merged_record( - self, concept_id: str, case_sensitive: bool - ) -> Optional[Dict]: + def _get_merged_record(self, concept_id: str) -> dict | None: """Retrieve normalized record from DB. :param concept_id: normalized ID for the merged record - :param case_sensitive: record lookups are performed using a case-insensitive - index, so this parameter isn't used by Postgres :return: normalized record if successful """ concept_id = concept_id.lower() @@ -397,29 +394,32 @@ def _get_merged_record( return self._format_merged_record(result) def get_record_by_id( - self, concept_id: str, case_sensitive: bool = True, merge: bool = False - ) -> Optional[Dict]: + self, + concept_id: str, + case_sensitive: bool = True, # noqa: ARG002 + merge: bool = False, + ) -> dict | None: """Fetch record corresponding to provided concept ID :param str concept_id: concept ID for gene record - :param bool case_sensitive: + :param bool case_sensitive: Not used by PostgreSQL instance. :param bool merge: if true, look for merged record; look for identity record otherwise. :return: complete gene record, if match is found; None otherwise """ if merge: - return self._get_merged_record(concept_id, case_sensitive) - else: - return self._get_record(concept_id, case_sensitive) + return self._get_merged_record(concept_id) - _ref_types_query = { - RefType.SYMBOL: b"SELECT concept_id FROM gene_symbols WHERE lower(symbol) = %s;", # noqa: E501 - RefType.PREVIOUS_SYMBOLS: b"SELECT concept_id FROM gene_previous_symbols WHERE lower(prev_symbol) = %s;", # noqa: E501 - RefType.ALIASES: b"SELECT concept_id FROM gene_aliases WHERE lower(alias) = %s;", # noqa: E501 + return self._get_record(concept_id) + + _ref_types_query: ClassVar[dict] = { + RefType.SYMBOL: b"SELECT concept_id FROM gene_symbols WHERE lower(symbol) = %s;", + RefType.PREVIOUS_SYMBOLS: b"SELECT concept_id FROM gene_previous_symbols WHERE lower(prev_symbol) = %s;", + RefType.ALIASES: b"SELECT concept_id FROM gene_aliases WHERE lower(alias) = %s;", RefType.XREFS: b"SELECT concept_id FROM gene_xrefs WHERE lower(xref) = %s;", - RefType.ASSOCIATED_WITH: b"SELECT concept_id FROM gene_associations WHERE lower(associated_with) = %s;", # noqa: E501 + RefType.ASSOCIATED_WITH: b"SELECT concept_id FROM gene_associations WHERE lower(associated_with) = %s;", } - def get_refs_by_type(self, search_term: str, ref_type: RefType) -> List[str]: + def get_refs_by_type(self, search_term: str, ref_type: RefType) -> list[str]: """Retrieve concept IDs for records matching the user's query. Other methods are responsible for actually retrieving full records. @@ -429,19 +429,20 @@ def get_refs_by_type(self, search_term: str, ref_type: RefType) -> List[str]: """ query = self._ref_types_query.get(ref_type) if not query: - raise ValueError("invalid reference type") + err_msg = "invalid reference type" + raise ValueError(err_msg) with self.conn.cursor() as cur: cur.execute(query, (search_term.lower(),)) concept_ids = cur.fetchall() if concept_ids: return [i[0] for i in concept_ids] - else: - return [] + + return [] _ids_query = b"SELECT concept_id FROM gene_concepts;" - def get_all_concept_ids(self) -> Set[str]: + def get_all_concept_ids(self) -> set[str]: """Retrieve concept IDs for use in generating normalized records. :return: Set of concept IDs as strings. @@ -453,11 +454,11 @@ def get_all_concept_ids(self) -> Set[str]: _get_all_normalized_records_query = b"SELECT * FROM gene_merged;" _get_all_unmerged_source_records_query = ( - b"SELECT * FROM record_lookup_view WHERE merge_ref IS NULL;" # noqa: E501 + b"SELECT * FROM record_lookup_view WHERE merge_ref IS NULL;" ) _get_all_source_records_query = b"SELECT * FROM record_lookup_view;" - def get_all_records(self, record_type: RecordType) -> Generator[Dict, None, None]: + def get_all_records(self, record_type: RecordType) -> Generator[dict, None, None]: """Retrieve all source or normalized records. Either return all source records, or all records that qualify as "normalized" (i.e., merged groups + source records that are otherwise ungrouped). @@ -557,7 +558,7 @@ def add_source_metadata(self, src_name: SourceName, meta: SourceMeta) -> None: b"INSERT INTO gene_associations (associated_with, concept_id) VALUES (%s, %s);" ) - def add_record(self, record: Dict, src_name: SourceName) -> None: + def add_record(self, record: dict, src_name: SourceName) -> None: # noqa: ARG002 """Add new record to database. :param record: record to upload @@ -594,7 +595,7 @@ def add_record(self, record: Dict, src_name: SourceName) -> None: cur.execute(self._ins_symbol_query, [record["symbol"], concept_id]) self.conn.commit() except UniqueViolation: - _logger.error(f"Record with ID {concept_id} already exists") + _logger.error("Record with ID %s already exists", concept_id) self.conn.rollback() _add_merged_record_query = b""" @@ -607,7 +608,7 @@ def add_record(self, record: Dict, src_name: SourceName) -> None: VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s); """ - def add_merged_record(self, record: Dict) -> None: + def add_merged_record(self, record: dict) -> None: """Add merged record to database. :param record: merged record to add @@ -668,9 +669,8 @@ def update_merge_ref(self, concept_id: str, merge_ref: Any) -> None: # noqa: AN # UPDATE will fail silently unless we check the # of affected rows if row_count < 1: - raise DatabaseWriteException( - f"No such record exists for primary key {concept_id}" - ) + err_msg = f"No such record exists for primary key {concept_id}" + raise DatabaseWriteException(err_msg) def delete_normalized_concepts(self) -> None: """Remove merged records from the database. Use when performing a new update @@ -775,7 +775,7 @@ def close_connection(self) -> None: self.conn.commit() self.conn.close() - def load_from_remote(self, url: Optional[str]) -> None: + def load_from_remote(self, url: str | None) -> None: """Load DB from remote dump. Warning: Deletes all existing data. If not passed as an argument, will try to grab latest release from VICC S3 bucket. @@ -784,35 +784,35 @@ def load_from_remote(self, url: Optional[str]) -> None: command fails """ if not url: - url = "https://vicc-normalizers.s3.us-east-2.amazonaws.com/gene_normalization/postgresql/gene_norm_latest.sql.tar.gz" # noqa: E501 + url = "https://vicc-normalizers.s3.us-east-2.amazonaws.com/gene_normalization/postgresql/gene_norm_latest.sql.tar.gz" with tempfile.TemporaryDirectory() as tempdir: tempdir_path = Path(tempdir) temp_tarfile = tempdir_path / "gene_norm_latest.tar.gz" - with requests.get(url, stream=True) as r: + with requests.get(url, stream=True, timeout=10) as r: try: r.raise_for_status() - except requests.HTTPError: - raise DatabaseException( - f"Unable to retrieve PostgreSQL dump file from {url}" - ) - with open(temp_tarfile, "wb") as h: + except requests.HTTPError as e: + err_msg = f"Unable to retrieve PostgreSQL dump file from {url}" + raise DatabaseException(err_msg) from e + with temp_tarfile.open("wb") as h: for chunk in r.iter_content(chunk_size=8192): if chunk: h.write(chunk) tar = tarfile.open(temp_tarfile, "r:gz") - tar_dump_file = [ + tar_dump_file = next( f for f in tar.getmembers() if f.name.startswith("gene_norm_") - ][0] - tar.extractall(path=tempdir_path, members=[tar_dump_file]) + ) + tar.extractall(path=tempdir_path, members=[tar_dump_file]) # noqa: S202 dump_file = tempdir_path / tar_dump_file.name self.drop_db() system_call = f"psql {self.conninfo} -f {dump_file.absolute()}" - result = os.system(system_call) + result = os.system(system_call) # noqa: S605 if result != 0: - raise DatabaseException( + err_msg = ( f"System call '{system_call}' returned failing exit code {result}." ) + raise DatabaseException(err_msg) def export_db(self, output_directory: Path) -> None: """Dump DB to specified location. @@ -824,14 +824,16 @@ def export_db(self, output_directory: Path) -> None: :raise DatabaseException: if psql call fails """ if not output_directory.is_dir() or not output_directory.exists(): - raise ValueError( + err_msg = ( f"Output location {output_directory} isn't a directory or doesn't exist" - ) # noqa: E501 - now = datetime.now().strftime("%Y%m%d%H%M%S") + ) + raise ValueError(err_msg) + now = datetime.datetime.now(tz=datetime.timezone.utc).strftime("%Y%m%d%H%M%S") output_location = output_directory / f"gene_norm_{now}.sql" system_call = f"pg_dump {self.conninfo} -E UTF8 -f {output_location}" - result = os.system(system_call) + result = os.system(system_call) # noqa: S605 if result != 0: - raise DatabaseException( + err_msg = ( f"System call '{system_call}' returned failing exit code {result}." ) + raise DatabaseException(err_msg) diff --git a/src/gene/etl/__init__.py b/src/gene/etl/__init__.py index 569df1d7..8c26a188 100644 --- a/src/gene/etl/__init__.py +++ b/src/gene/etl/__init__.py @@ -1,4 +1,5 @@ """Module to load and init namespace at package level.""" + from .ensembl import Ensembl from .exceptions import ( GeneFileVersionError, diff --git a/src/gene/etl/base.py b/src/gene/etl/base.py index d49b3b11..8e237fc9 100644 --- a/src/gene/etl/base.py +++ b/src/gene/etl/base.py @@ -1,9 +1,9 @@ """A base class for extraction, transformation, and loading of data.""" + import logging import re from abc import ABC, abstractmethod from pathlib import Path -from typing import Dict, List, Optional, Union import click import pydantic @@ -32,7 +32,7 @@ def __init__( self, database: AbstractDatabase, seqrepo_dir: Path = SEQREPO_ROOT_DIR, - data_path: Optional[Path] = None, + data_path: Path | None = None, silent: bool = True, ) -> None: """Instantiate Base class. @@ -47,11 +47,11 @@ def __init__( self._data_source = self._get_data_handler(data_path) self._database = database self.seqrepo = self.get_seqrepo(seqrepo_dir) - self._processed_ids = list() + self._processed_ids = [] def _get_data_handler( - self, data_path: Optional[Path] = None - ) -> Union[HgncData, EnsemblData, NcbiGeneData]: + self, data_path: Path | None = None + ) -> HgncData | EnsemblData | NcbiGeneData: """Construct data handler instance for source. Overwrite for edge-case sources. :param data_path: location of data storage @@ -59,7 +59,7 @@ def _get_data_handler( """ return DATA_DISPATCH[self._src_name](data_dir=data_path, silent=self._silent) - def perform_etl(self, use_existing: bool = False) -> List[str]: + def perform_etl(self, use_existing: bool = False) -> list[str]: """Public-facing method to begin ETL procedures on given data. Returned concept IDs can be passed to Merge method for computing merged concepts. @@ -100,7 +100,7 @@ def _add_meta(self) -> None: """Add source meta to database source info.""" raise NotImplementedError - def _load_gene(self, gene: Dict) -> None: + def _load_gene(self, gene: dict) -> None: """Load a gene record into database. This method takes responsibility for: * validating structure correctness * removing duplicates from list-like fields @@ -109,9 +109,9 @@ def _load_gene(self, gene: Dict) -> None: :param gene: Gene record """ try: - assert Gene(match_type=MatchType.NO_MATCH, **gene) + Gene(match_type=MatchType.NO_MATCH, **gene) except pydantic.ValidationError as e: - _logger.warning(f"Unable to load {gene} due to validation error: " f"{e}") + _logger.warning("Unable to load %s due to validation error: %s", gene, e) else: concept_id = gene["concept_id"] gene["label_and_type"] = f"{concept_id.lower()}##identity" @@ -137,20 +137,21 @@ def get_seqrepo(self, seqrepo_dir: Path) -> SeqRepo: :return: SeqRepo instance """ if not Path(seqrepo_dir).exists(): - raise NotADirectoryError(f"Could not find {seqrepo_dir}") + err_msg = f"Could not find {seqrepo_dir}" + raise NotADirectoryError(err_msg) return SeqRepo(seqrepo_dir) - def _set_cl_interval_range(self, loc: str, arm_ix: int, location: Dict) -> None: + def _set_cl_interval_range(self, loc: str, arm_ix: int, location: dict) -> None: """Set the Chromosome location interval range. :param loc: A gene location :param arm_ix: The index of the q or p arm for a given location :param location: VRS chromosome location. This will be mutated. """ - range_ix = re.search("-", loc).start() # type: ignore + range_ix = re.search("-", loc).start() start = loc[arm_ix:range_ix] - start_arm_ix = re.search("[pq]", start).start() # type: ignore + start_arm_ix = re.search("[pq]", start).start() start_arm = start[start_arm_ix] end = loc[range_ix + 1 :] @@ -161,7 +162,7 @@ def _set_cl_interval_range(self, loc: str, arm_ix: int, location: Dict) -> None: end = f"{start[0]}{end}" end_arm_match = re.search("[pq]", end) - end_arm_ix = end_arm_match.start() # type: ignore + end_arm_ix = end_arm_match.start() end_arm = end[end_arm_ix] if (start_arm == end_arm and start > end) or ( @@ -202,7 +203,7 @@ def _set_cl_interval_range(self, loc: str, arm_ix: int, location: Dict) -> None: # return chr_location # return None - def _get_seq_id_aliases(self, seq_id: str) -> List[str]: + def _get_seq_id_aliases(self, seq_id: str) -> list[str]: """Get GA4GH aliases for a sequence id :param seq_id: Sequence ID accession @@ -212,10 +213,10 @@ def _get_seq_id_aliases(self, seq_id: str) -> List[str]: try: aliases = self.seqrepo.translate_alias(seq_id, target_namespaces="ga4gh") except KeyError as e: - _logger.warning(f"SeqRepo raised KeyError: {e}") + _logger.warning("SeqRepo raised KeyError: %s", e) return aliases - def _get_sequence_location(self, seq_id: str, gene: Feature, params: Dict) -> Dict: + def _get_sequence_location(self, seq_id: str, gene: Feature, params: dict) -> dict: """Get a gene's GeneSequenceLocation. :param seq_id: The sequence ID. @@ -232,15 +233,17 @@ def _get_sequence_location(self, seq_id: str, gene: Feature, params: Dict) -> Di sequence = aliases[0] if gene.start != "." and gene.end != "." and sequence: - if 0 <= gene.start <= gene.end: # type: ignore + if 0 <= gene.start <= gene.end: location = GeneSequenceLocation( - start=gene.start - 1, # type: ignore - end=gene.end, # type: ignore + start=gene.start - 1, + end=gene.end, sequence_id=sequence, - ).model_dump() # type: ignore + ).model_dump() else: _logger.warning( - f"{params['concept_id']} has invalid interval:" - f"start={gene.start - 1} end={gene.end}" - ) # type: ignore + "%s has invalid interval: start=%i end=%i", + params["concept_id"], + gene.start - 1, + gene.end, + ) return location diff --git a/src/gene/etl/ensembl.py b/src/gene/etl/ensembl.py index e062a545..89953d99 100644 --- a/src/gene/etl/ensembl.py +++ b/src/gene/etl/ensembl.py @@ -1,7 +1,7 @@ """Defines the Ensembl ETL methods.""" + import logging import re -from typing import Dict import gffutils from gffutils.feature import Feature @@ -45,7 +45,7 @@ def _transform_data(self) -> None: ) # Get accession numbers - accession_numbers = dict() + accession_numbers = {} for item in db.features_of_type("scaffold"): accession_numbers[item[0]] = item[8]["Alias"][-1] for item in db.features_of_type("chromosome"): @@ -60,14 +60,14 @@ def _transform_data(self) -> None: self._load_gene(gene) _logger.info("Successfully transformed Ensembl.") - def _add_gene(self, f: Feature, accession_numbers: Dict) -> Dict: + def _add_gene(self, f: Feature, accession_numbers: dict) -> dict: """Create a transformed gene record. :param f: A gene from the data :param accession_numbers: Accession numbers for each chromosome and scaffold :return: A gene dictionary containing data if the ID attribute exists. """ - gene = dict() + gene = {} if f.strand == "-": gene["strand"] = Strand.REVERSE.value elif f.strand == "+": @@ -84,7 +84,7 @@ def _add_gene(self, f: Feature, accession_numbers: Dict) -> Dict: return gene - def _add_attributes(self, f: Feature, gene: Dict) -> None: + def _add_attributes(self, f: Feature, gene: dict) -> None: """Add concept_id, symbol, xrefs, and associated_with to a gene record. :param f: A gene from the data @@ -100,17 +100,13 @@ def _add_attributes(self, f: Feature, gene: Dict) -> None: for attribute in f.attributes.items(): key = attribute[0] - if key in attributes.keys(): + if key in attributes: val = attribute[1] if len(val) == 1: val = val[0] - if key == "ID": - if val.startswith("gene"): - val = ( - f"{NamespacePrefix.ENSEMBL.value}:" - f"{val.split(':')[1]}" - ) + if key == "ID" and val.startswith("gene"): + val = f"{NamespacePrefix.ENSEMBL.value}:" f"{val.split(':')[1]}" if key == "description": gene["label"] = val.split("[")[0].strip() @@ -133,7 +129,7 @@ def _add_attributes(self, f: Feature, gene: Dict) -> None: gene[attributes[key]] = val - def _add_location(self, f: Feature, gene: Dict, accession_numbers: Dict) -> Dict: + def _add_location(self, f: Feature, gene: dict, accession_numbers: dict) -> dict: """Add GA4GH SequenceLocation to a gene record. https://vr-spec.readthedocs.io/en/1.1/terms_and_model.html#sequencelocation @@ -144,14 +140,14 @@ def _add_location(self, f: Feature, gene: Dict, accession_numbers: Dict) -> Dict """ return self._get_sequence_location(accession_numbers[f.seqid], f, gene) - def _get_xref_associated_with(self, src_name: str, src_id: str) -> Dict: + def _get_xref_associated_with(self, src_name: str, src_id: str) -> dict: """Get xref or associated_with concept. :param src_name: Source name :param src_id: The source's accession number :return: A dict containing an other identifier or xref """ - source = dict() + source = {} if src_name.startswith("HGNC"): source["xrefs"] = [f"{NamespacePrefix.HGNC.value}:{src_id}"] elif src_name.startswith("NCBI"): @@ -170,9 +166,8 @@ def _add_meta(self) -> None: :raise GeneNormalizerEtlError: if requisite metadata is unset """ if not self._version or not self._assembly: - raise GeneNormalizerEtlError( - "Source metadata unavailable -- was data properly acquired before attempting to load DB?" - ) + err_msg = "Source metadata unavailable -- was data properly acquired before attempting to load DB?" + raise GeneNormalizerEtlError(err_msg) metadata = SourceMeta( data_license="custom", data_license_url="https://useast.ensembl.org/info/about" diff --git a/src/gene/etl/hgnc.py b/src/gene/etl/hgnc.py index 242923d7..fbb7f8ad 100644 --- a/src/gene/etl/hgnc.py +++ b/src/gene/etl/hgnc.py @@ -1,8 +1,8 @@ """Defines the HGNC ETL methods.""" + import json import logging import re -from typing import Dict from gene import PREFIX_LOOKUP from gene.etl.base import Base @@ -27,13 +27,13 @@ class HGNC(Base): def _transform_data(self) -> None: """Transform the HGNC source.""" _logger.info("Transforming HGNC...") - with open(self._data_file, "r") as f: # type: ignore + with self._data_file.open() as f: data = json.load(f) records = data["response"]["docs"] for r in records: - gene = dict() + gene = {} gene["concept_id"] = r["hgnc_id"].lower() gene["label_and_type"] = f"{gene['concept_id']}##identity" gene["item_type"] = "identity" @@ -59,14 +59,14 @@ def _transform_data(self) -> None: self._load_gene(gene) _logger.info("Successfully transformed HGNC.") - def _get_aliases(self, r: Dict, gene: Dict) -> None: + def _get_aliases(self, r: dict, gene: dict) -> None: """Store aliases in a gene record. :param r: A gene record in the HGNC data file :param gene: A transformed gene record """ - alias_symbol = list() - enzyme_id = list() + alias_symbol = [] + enzyme_id = [] if "alias_symbol" in r: alias_symbol = r["alias_symbol"] @@ -76,7 +76,7 @@ def _get_aliases(self, r: Dict, gene: Dict) -> None: if alias_symbol or enzyme_id: gene["aliases"] = list(set(alias_symbol + enzyme_id)) - def _get_previous_symbols(self, r: Dict, gene: Dict) -> None: + def _get_previous_symbols(self, r: dict, gene: dict) -> None: """Store previous symbols in a gene record. :param r: A gene record in the HGNC data file @@ -86,14 +86,14 @@ def _get_previous_symbols(self, r: Dict, gene: Dict) -> None: if prev_symbols: gene["previous_symbols"] = list(set(prev_symbols)) - def _get_xrefs_associated_with(self, r: Dict, gene: Dict) -> None: + def _get_xrefs_associated_with(self, r: dict, gene: dict) -> None: """Store xrefs and/or associated_with refs in a gene record. :param r: A gene record in the HGNC data file :param gene: A transformed gene record """ - xrefs = list() - associated_with = list() + xrefs = [] + associated_with = [] sources = [ "entrez_id", "ensembl_gene_id", @@ -133,12 +133,12 @@ def _get_xrefs_associated_with(self, r: Dict, gene: Dict) -> None: key = src if key.upper() in NamespacePrefix.__members__: - if NamespacePrefix[key.upper()].value in PREFIX_LOOKUP.keys(): + if NamespacePrefix[key.upper()].value in PREFIX_LOOKUP: self._get_xref_associated_with(key, src, r, xrefs) else: self._get_xref_associated_with(key, src, r, associated_with) else: - _logger.warning(f"{key} not in schemas.py") + _logger.warning("%s not in schemas.py", key) if xrefs: gene["xrefs"] = xrefs @@ -146,7 +146,7 @@ def _get_xrefs_associated_with(self, r: Dict, gene: Dict) -> None: gene["associated_with"] = associated_with def _get_xref_associated_with( - self, key: str, src: str, r: Dict, src_type: Dict + self, key: str, src: str, r: dict, src_type: dict ) -> None: """Add an xref or associated_with ref to a gene record. @@ -163,7 +163,7 @@ def _get_xref_associated_with( r[src] = r[src].split(":")[-1].strip() src_type.append(f"{NamespacePrefix[key.upper()].value}" f":{r[src]}") - def _get_location(self, r: Dict, gene: Dict) -> None: + def _get_location(self, r: dict, gene: dict) -> None: """Store GA4GH VRS ChromosomeLocation in a gene record. https://vr-spec.readthedocs.io/en/1.1/terms_and_model.html#chromosomelocation @@ -176,8 +176,8 @@ def _get_location(self, r: Dict, gene: Dict) -> None: else: locations = [r["location"]] - location_list = list() - gene["location_annotations"] = list() + location_list = [] + gene["location_annotations"] = [] for loc in locations: loc = loc.strip() loc = self._set_annotation(loc, gene) @@ -186,7 +186,7 @@ def _get_location(self, r: Dict, gene: Dict) -> None: if loc == "mitochondria": gene["location_annotations"].append(Chromosome.MITOCHONDRIA.value) else: - location = dict() + location = {} self._set_location(loc, location, gene) # chr_location = self._get_chromosome_location(location, gene) # if chr_location: @@ -197,7 +197,7 @@ def _get_location(self, r: Dict, gene: Dict) -> None: if not gene["location_annotations"]: del gene["location_annotations"] - def _set_annotation(self, loc: str, gene: Dict) -> None: + def _set_annotation(self, loc: str, gene: dict) -> None: """Set the annotations attribute if one is provided. Return `True` if a location is provided, `False` otherwise. @@ -216,7 +216,7 @@ def _set_annotation(self, loc: str, gene: Dict) -> None: return None return loc - def _set_location(self, loc: str, location: Dict, gene: Dict) -> None: + def _set_location(self, loc: str, location: dict, gene: dict) -> None: """Set a gene's location. :param loc: A gene location @@ -248,9 +248,8 @@ def _add_meta(self) -> None: :raise GeneNormalizerEtlError: if requisite metadata is unset """ if not self._version: - raise GeneNormalizerEtlError( - "Source metadata unavailable -- was data properly acquired before attempting to load DB?" - ) + err_msg = "Source metadata unavailable -- was data properly acquired before attempting to load DB?" + raise GeneNormalizerEtlError(err_msg) metadata = SourceMeta( data_license="CC0", data_license_url="https://www.genenames.org/about/license/", diff --git a/src/gene/etl/merge.py b/src/gene/etl/merge.py index 651ce8ae..e56c92d5 100644 --- a/src/gene/etl/merge.py +++ b/src/gene/etl/merge.py @@ -1,7 +1,7 @@ """Create concept groups and merged records.""" + import logging from timeit import default_timer as timer -from typing import Dict, Optional, Set, Tuple from gene.database import AbstractDatabase from gene.database.database import DatabaseWriteException @@ -21,7 +21,7 @@ def __init__(self, database: AbstractDatabase) -> None: self._database = database self._groups = {} # dict keying concept IDs to group Sets - def create_merged_concepts(self, record_ids: Set[str]) -> None: + def create_merged_concepts(self, record_ids: set[str]) -> None: """Create concept groups, generate merged concept records, and update database. :param record_ids: concept identifiers from which groups should be generated. @@ -35,7 +35,7 @@ def create_merged_concepts(self, record_ids: Set[str]) -> None: for concept_id in new_group: self._groups[concept_id] = new_group end = timer() - _logger.debug(f"Built record ID sets in {end - start} seconds") + _logger.debug("Built record ID sets in %f seconds", end - start) self._groups = {k: v for k, v in self._groups.items() if len(v) > 1} @@ -58,8 +58,9 @@ def create_merged_concepts(self, record_ids: Set[str]) -> None: except DatabaseWriteException as dw: if str(dw).startswith("No such record exists"): _logger.error( - f"Updating nonexistent record: {concept_id} " - f"for merge ref to {merge_ref}" + "Updating nonexistent record: %s for merge ref to %s", + concept_id, + merge_ref, ) else: _logger.error(str(dw)) @@ -67,11 +68,11 @@ def create_merged_concepts(self, record_ids: Set[str]) -> None: self._database.complete_write_transaction() _logger.info("Merged concept generation successful.") end = timer() - _logger.debug(f"Generated and added concepts in {end - start} seconds") + _logger.debug("Generated and added concepts in %f seconds", end - start) def _create_record_id_set( - self, record_id: str, observed_id_set: Optional[Set] = None - ) -> Set[str]: + self, record_id: str, observed_id_set: set | None = None + ) -> set[str]: """Recursively create concept ID group for an individual record ID. :param record_id: concept ID for record to build group from @@ -84,29 +85,27 @@ def _create_record_id_set( if record_id in self._groups: return self._groups[record_id] - else: - db_record = self._database.get_record_by_id(record_id) - if not db_record: - _logger.warning( - f"Record ID set creator could not resolve " - f"lookup for {record_id} in ID set: " - f"{observed_id_set}" - ) - return observed_id_set - {record_id} - - record_xrefs = db_record.get("xrefs") - if not record_xrefs: - return observed_id_set | {db_record["concept_id"]} - else: - local_id_set = set(record_xrefs) - merged_id_set = {record_id} | observed_id_set - for local_record_id in local_id_set - observed_id_set: - merged_id_set |= self._create_record_id_set( - local_record_id, merged_id_set - ) - return merged_id_set - def _generate_merged_record(self, record_id_set: Set[str]) -> Dict: + db_record = self._database.get_record_by_id(record_id) + if not db_record: + _logger.warning( + "Record ID set creator could not resolve lookup for %s in ID set: %s", + record_id, + observed_id_set, + ) + return observed_id_set - {record_id} + + record_xrefs = db_record.get("xrefs") + if not record_xrefs: + return observed_id_set | {db_record["concept_id"]} + + local_id_set = set(record_xrefs) + merged_id_set = {record_id} | observed_id_set + for local_record_id in local_id_set - observed_id_set: + merged_id_set |= self._create_record_id_set(local_record_id, merged_id_set) + return merged_id_set + + def _generate_merged_record(self, record_id_set: set[str]) -> dict: """Generate merged record from provided concept ID group. Where attributes are sets, they should be merged, and where they are scalars, assign from the highest-priority source where that attribute @@ -124,19 +123,21 @@ def _generate_merged_record(self, record_id_set: Set[str]) -> Dict: records.append(record) else: _logger.error( - f"Merge record generator could not retrieve " - f"record for {record_id} in {record_id_set}" + "Merge record generator could not retrieve record for %s in %s", + record_id, + record_id_set, ) - def record_order(record: Dict) -> Tuple: + def record_order(record: dict) -> tuple: """Provide priority values of concepts for sort function.""" src = record["src_name"].upper() if src in SourcePriority.__members__: source_rank = SourcePriority[src].value else: - raise Exception( + err_msg = ( f"Prohibited source: {src} in concept_id " f"{record['concept_id']}" ) + raise Exception(err_msg) return source_rank, record["concept_id"] records.sort(key=record_order) @@ -175,7 +176,8 @@ def record_order(record: Dict) -> Tuple: merged_field = GeneTypeFieldName[record["src_name"].upper()] merged_attrs[merged_field] |= {gene_type} - for field in set_fields + [ + for field in [ + *set_fields, "hgnc_locus_type", "ncbi_gene_type", "ensembl_biotype", @@ -192,7 +194,7 @@ def record_order(record: Dict) -> Tuple: if num_unique_strand_values > 1: del merged_attrs["strand"] elif num_unique_strand_values == 1: - merged_attrs["strand"] = list(unique_strand_values)[0] + merged_attrs["strand"] = next(iter(unique_strand_values)) merged_attrs["item_type"] = RecordType.MERGER.value return merged_attrs diff --git a/src/gene/etl/ncbi.py b/src/gene/etl/ncbi.py index 01681bb4..5826e820 100644 --- a/src/gene/etl/ncbi.py +++ b/src/gene/etl/ncbi.py @@ -1,9 +1,9 @@ """Defines ETL methods for the NCBI data source.""" + import csv import logging import re from pathlib import Path -from typing import Dict, List, Optional import gffutils from wags_tails import NcbiGenomeData @@ -34,7 +34,7 @@ def __init__( self, database: AbstractDatabase, seqrepo_dir: Path = SEQREPO_ROOT_DIR, - data_path: Optional[Path] = None, + data_path: Path | None = None, silent: bool = True, ) -> None: """Instantiate Base class. @@ -59,7 +59,7 @@ def _extract_data(self, use_existing: bool) -> None: gene_paths: NcbiGenePaths gene_paths, self._version = self._data_source.get_latest( from_local=use_existing - ) # type: ignore + ) self._info_src = gene_paths.gene_info self._history_src = gene_paths.gene_history self._gene_url = ( @@ -68,13 +68,13 @@ def _extract_data(self, use_existing: bool) -> None: self._history_url = "ftp.ncbi.nlm.nih.govgene/DATA/gene_history.gz" self._assembly_url = "ftp.ncbi.nlm.nih.govgenomes/refseq/vertebrate_mammalian/Homo_sapiens/latest_assembly_versions/" - def _get_prev_symbols(self) -> Dict[str, str]: + def _get_prev_symbols(self) -> dict[str, str]: """Store a gene's symbol history. :return: A dictionary of a gene's previous symbols """ # get symbol history - history_file = open(self._history_src, "r") + history_file = self._history_src.open() history = csv.reader(history_file, delimiter="\t") next(history) prev_symbols = {} @@ -83,7 +83,7 @@ def _get_prev_symbols(self) -> Dict[str, str]: if row[0] == "9606": if row[1] != "-": gene_id = row[1] - if gene_id in prev_symbols.keys(): + if gene_id in prev_symbols: prev_symbols[gene_id].append(row[3]) else: prev_symbols[gene_id] = [row[3]] @@ -98,7 +98,7 @@ def _get_prev_symbols(self) -> Dict[str, str]: history_file.close() return prev_symbols - def _add_xrefs_associated_with(self, val: List[str], params: Dict) -> None: + def _add_xrefs_associated_with(self, val: list[str], params: dict) -> None: """Add xrefs and associated_with refs to a transformed gene. :param val: A list of source ids for a given gene @@ -130,58 +130,59 @@ def _add_xrefs_associated_with(self, val: List[str], params: Dict) -> None: if prefix: params["associated_with"].append(f"{prefix}:{src_id}") else: - _logger.info(f"{src_name} is not in NameSpacePrefix.") + _logger.info("%s is not in NameSpacePrefix.", src_name) if not params["xrefs"]: del params["xrefs"] if not params["associated_with"]: del params["associated_with"] - def _get_gene_info(self, prev_symbols: Dict[str, str]) -> Dict[str, str]: + def _get_gene_info(self, prev_symbols: dict[str, str]) -> dict[str, str]: """Store genes from NCBI info file. :param prev_symbols: A dictionary of a gene's previous symbols :return: A dictionary of gene's from the NCBI info file. """ + info_genes = {} + # open info file, skip headers - info_file = open(self._info_src, "r") - info = csv.reader(info_file, delimiter="\t") - next(info) - - info_genes = dict() - for row in info: - params = dict() - params["concept_id"] = f"{NamespacePrefix.NCBI.value}:{row[1]}" - # get symbol - params["symbol"] = row[2] - # get aliases - if row[4] != "-": - params["aliases"] = row[4].split("|") - else: - params["aliases"] = [] - # get associated_with - if row[5] != "-": - associated_with = row[5].split("|") - self._add_xrefs_associated_with(associated_with, params) - # get chromosome location - vrs_chr_location = self._get_vrs_chr_location(row, params) - if "exclude" in vrs_chr_location: - # Exclude genes with multiple distinct locations (e.g. OMS) - continue - if not vrs_chr_location: - vrs_chr_location = [] - params["locations"] = vrs_chr_location - # get label - if row[8] != "-": - params["label"] = row[8] - # add prev symbols - if row[1] in prev_symbols.keys(): - params["previous_symbols"] = prev_symbols[row[1]] - info_genes[params["symbol"]] = params - # get type - params["gene_type"] = row[9] + with self._info_src.open() as info_file: + info = csv.reader(info_file, delimiter="\t") + next(info) + + for row in info: + params = {} + params["concept_id"] = f"{NamespacePrefix.NCBI.value}:{row[1]}" + # get symbol + params["symbol"] = row[2] + # get aliases + if row[4] != "-": + params["aliases"] = row[4].split("|") + else: + params["aliases"] = [] + # get associated_with + if row[5] != "-": + associated_with = row[5].split("|") + self._add_xrefs_associated_with(associated_with, params) + # get chromosome location + vrs_chr_location = self._get_vrs_chr_location(row, params) + if "exclude" in vrs_chr_location: + # Exclude genes with multiple distinct locations (e.g. OMS) + continue + if not vrs_chr_location: + vrs_chr_location = [] + params["locations"] = vrs_chr_location + # get label + if row[8] != "-": + params["label"] = row[8] + # add prev symbols + if row[1] in prev_symbols: + params["previous_symbols"] = prev_symbols[row[1]] + info_genes[params["symbol"]] = params + # get type + params["gene_type"] = row[9] return info_genes - def _get_gene_gff(self, db: gffutils.FeatureDB, info_genes: Dict) -> None: + def _get_gene_gff(self, db: gffutils.FeatureDB, info_genes: dict) -> None: """Store genes from NCBI gff file. :param db: GFF database @@ -197,7 +198,7 @@ def _get_gene_gff(self, db: gffutils.FeatureDB, info_genes: Dict) -> None: params = info_genes.get(symbol) vrs_sq_location = self._get_vrs_sq_location(db, params, f_id) if vrs_sq_location: - params["locations"].append(vrs_sq_location) # type: ignore + params["locations"].append(vrs_sq_location) else: # Need to add entire gene gene = self._add_gff_gene(db, f, f_id) @@ -205,7 +206,7 @@ def _get_gene_gff(self, db: gffutils.FeatureDB, info_genes: Dict) -> None: def _add_gff_gene( self, db: gffutils.FeatureDB, f: gffutils.Feature, f_id: str - ) -> Optional[Dict]: + ) -> dict | None: """Create a transformed gene recor from NCBI gff file. :param db: GFF database @@ -213,18 +214,18 @@ def _add_gff_gene( :param f_id: The feature's ID :return: A gene dictionary if the ID attribute exists. Else return None. """ - params = dict() + params = {} params["src_name"] = SourceName.NCBI.value self._add_attributes(f, params) sq_loc = self._get_vrs_sq_location(db, params, f_id) if sq_loc: params["locations"] = [sq_loc] else: - params["locations"] = list() + params["locations"] = [] params["label_and_type"] = f"{params['concept_id'].lower()}##identity" return params - def _add_attributes(self, f: gffutils.feature.Feature, gene: Dict) -> None: + def _add_attributes(self, f: gffutils.feature.Feature, gene: dict) -> None: """Add concept_id, symbol, and xrefs/associated_with to a gene record. :param gffutils.feature.Feature f: A gene from the data @@ -246,8 +247,8 @@ def _add_attributes(self, f: gffutils.feature.Feature, gene: Dict) -> None: gene["symbol"] = val def _get_vrs_sq_location( - self, db: gffutils.FeatureDB, params: Dict, f_id: str - ) -> Dict: + self, db: gffutils.FeatureDB, params: dict, f_id: str + ) -> dict: """Store GA4GH VRS SequenceLocation in a gene record. https://vr-spec.readthedocs.io/en/1.1/terms_and_model.html#sequencelocation @@ -260,14 +261,14 @@ def _get_vrs_sq_location( params["strand"] = gene.strand return self._get_sequence_location(gene.seqid, gene, params) - def _get_xref_associated_with(self, src_name: str, src_id: str) -> Dict: + def _get_xref_associated_with(self, src_name: str, src_id: str) -> dict: """Get xref or associated_with ref. :param src_name: Source name :param src_id: The source's accession number :return: A dict containing an xref or associated_with ref """ - source = dict() + source = {} if src_name.startswith("HGNC"): source["xrefs"] = [f"{NamespacePrefix.HGNC.value}:{src_id}"] elif src_name.startswith("NCBI"): @@ -280,7 +281,7 @@ def _get_xref_associated_with(self, src_name: str, src_id: str) -> Dict: source["associated_with"] = [f"{NamespacePrefix.RFAM.value}:{src_id}"] return source - def _get_vrs_chr_location(self, row: List[str], params: Dict) -> List: + def _get_vrs_chr_location(self, row: list[str], params: dict) -> list: """Store GA4GH VRS ChromosomeLocation in a gene record. https://vr-spec.readthedocs.io/en/1.1/terms_and_model.html#chromosomelocation @@ -288,14 +289,14 @@ def _get_vrs_chr_location(self, row: List[str], params: Dict) -> List: :param params: A transformed gene record :return: A list of GA4GH VRS ChromosomeLocations """ - params["location_annotations"] = list() + params["location_annotations"] = [] chromosomes_locations = self._set_chromsomes_locations(row, params) locations = chromosomes_locations["locations"] chromosomes = chromosomes_locations["chromosomes"] if chromosomes_locations["exclude"]: return ["exclude"] - location_list = list() + location_list = [] if chromosomes and not locations: for chromosome in chromosomes: if chromosome == "MT": @@ -303,12 +304,12 @@ def _get_vrs_chr_location(self, row: List[str], params: Dict) -> List: else: params["location_annotations"].append(chromosome.strip()) elif locations: - self._add_chromosome_location(locations, location_list, params) + self._add_chromosome_location(locations, params) if not params["location_annotations"]: del params["location_annotations"] return location_list - def _set_chromsomes_locations(self, row: List[str], params: Dict) -> Dict: + def _set_chromsomes_locations(self, row: list[str], params: dict) -> dict: """Set chromosomes and locations for a given gene record. :param row: A gene row in the NCBI data file @@ -317,18 +318,18 @@ def _set_chromsomes_locations(self, row: List[str], params: Dict) -> Dict: """ chromosomes = None if row[6] != "-": - if "|" in row[6]: - chromosomes = row[6].split("|") - else: - chromosomes = [row[6]] + chromosomes = row[6].split("|") if "|" in row[6] else [row[6]] - if len(chromosomes) >= 2: - if chromosomes and "X" not in chromosomes and "Y" not in chromosomes: - _logger.info( - f"{row[2]} contains multiple distinct " - f"chromosomes: {chromosomes}." - ) - chromosomes = None + if ( + len(chromosomes) >= 2 + and chromosomes + and "X" not in chromosomes + and "Y" not in chromosomes + ): + _logger.info( + "%s contains multiple distinct chromosomes: %s", row[2], chromosomes + ) + chromosomes = None locations = None exclude = False @@ -343,15 +344,14 @@ def _set_chromsomes_locations(self, row: List[str], params: Dict) -> Dict: locations = [row[7]] # Sometimes locations will store the same location twice - if len(locations) == 2: - if locations[0] == locations[1]: - locations = [locations[0]] + if len(locations) == 2 and locations[0] == locations[1]: + locations = [locations[0]] # Exclude genes where there are multiple distinct locations # i.e. OMS: '10q26.3', '19q13.42-q13.43', '3p25.3' if len(locations) > 2: _logger.info( - f"{row[2]} contains multiple distinct " f"locations: {locations}." + "%s contains multiple distinct locations: %s", row[2], locations ) locations = None exclude = True @@ -362,24 +362,21 @@ def _set_chromsomes_locations(self, row: List[str], params: Dict) -> Dict: loc = locations[i].strip() if not re.match("^([1-9][0-9]?|X[pq]?|Y[pq]?)", loc): _logger.info( - f"{row[2]} contains invalid map location:" f"{loc}." + "%s contains invalid map location: %s", row[2], loc ) params["location_annotations"].append(loc) del locations[i] return {"locations": locations, "chromosomes": chromosomes, "exclude": exclude} - def _add_chromosome_location( - self, locations: List, location_list: List, params: Dict - ) -> None: + def _add_chromosome_location(self, locations: list, params: dict) -> None: """Add a chromosome location to the location list. :param locations: NCBI map locations for a gene record. - :param location_list: A list to store chromosome locations. :param params: A transformed gene record """ for i in range(len(locations)): loc = locations[i].strip() - location = dict() + location = {} if Annotation.ALT_LOC.value in loc: loc = loc.split(f"{Annotation.ALT_LOC.value}")[0].strip() @@ -423,22 +420,22 @@ def _add_chromosome_location( # if chr_location: # location_list.append(chr_location) - def _set_centromere_location(self, loc: str, location: Dict) -> None: + def _set_centromere_location(self, loc: str, location: dict) -> None: """Set centromere location for a gene. :param loc: A gene location :param location: GA4GH location """ - centromere_ix = re.search("cen", loc).start() # type: ignore + centromere_ix = re.search("cen", loc).start() if "-" in loc: # Location gives both start and end - range_ix = re.search("-", loc).start() # type: ignore + range_ix = re.search("-", loc).start() if "q" in loc: location["chr"] = loc[:centromere_ix].strip() location["start"] = "cen" location["end"] = loc[range_ix + 1 :] elif "p" in loc: - p_ix = re.search("p", loc).start() # type: ignore + p_ix = re.search("p", loc).start() location["chr"] = loc[:p_ix].strip() location["end"] = "cen" location["start"] = loc[:range_ix] @@ -464,7 +461,7 @@ def _transform_data(self) -> None: self._get_gene_gff(db, info_genes) - for gene in info_genes.keys(): + for gene in info_genes: self._load_gene(info_genes[gene]) _logger.info("Successfully transformed NCBI.") @@ -482,9 +479,8 @@ def _add_meta(self) -> None: self._assembly, ] ): - raise GeneNormalizerEtlError( - "Source metadata unavailable -- was data properly acquired before attempting to load DB?" - ) + err_msg = "Source metadata unavailable -- was data properly acquired before attempting to load DB?" + raise GeneNormalizerEtlError(err_msg) metadata = SourceMeta( data_license="custom", data_license_url="https://www.ncbi.nlm.nih.gov/home/about/policies/", diff --git a/src/gene/main.py b/src/gene/main.py index 5cc8ba53..ec51c965 100644 --- a/src/gene/main.py +++ b/src/gene/main.py @@ -1,6 +1,6 @@ """Main application for FastAPI""" + import html -from typing import Optional from fastapi import FastAPI, HTTPException, Query @@ -27,7 +27,7 @@ contact={ "name": "Alex H. Wagner", "email": "Alex.Wagner@nationwidechildrens.org", - "url": "https://www.nationwidechildrens.org/specialties/institute-for-genomic-medicine/research-labs/wagner-lab", # noqa: E501 + "url": "https://www.nationwidechildrens.org/specialties/institute-for-genomic-medicine/research-labs/wagner-lab", }, license={ "name": "MIT", @@ -65,9 +65,9 @@ tags=["Query"], ) def search( - q: str = Query(..., description=q_descr), # noqa: D103 - incl: Optional[str] = Query(None, description=incl_descr), - excl: Optional[str] = Query(None, description=excl_descr), + q: str = Query(..., description=q_descr), + incl: str | None = Query(None, description=incl_descr), + excl: str | None = Query(None, description=excl_descr), ) -> SearchService: """Return strongest match concepts to query string provided by user. @@ -83,7 +83,7 @@ def search( try: resp = query_handler.search(html.unescape(q), incl=incl, excl=excl) except InvalidParameterException as e: - raise HTTPException(status_code=422, detail=str(e)) + raise HTTPException(status_code=422, detail=str(e)) from e return resp @@ -108,8 +108,7 @@ def normalize(q: str = Query(..., description=normalize_q_descr)) -> NormalizeSe :param str q: gene search term :return: JSON response with normalized gene concept """ - resp = query_handler.normalize(html.unescape(q)) - return resp + return query_handler.normalize(html.unescape(q)) unmerged_matches_summary = ( @@ -142,5 +141,4 @@ def normalize_unmerged( :param q: Gene search term :returns: JSON response with matching normalized record and source metadata """ - response = query_handler.normalize_unmerged(html.unescape(q)) - return response + return query_handler.normalize_unmerged(html.unescape(q)) diff --git a/src/gene/query.py b/src/gene/query.py index 93adafba..17ecd3d8 100644 --- a/src/gene/query.py +++ b/src/gene/query.py @@ -1,8 +1,10 @@ """Provides methods for handling queries.""" + +import datetime import logging import re -from datetime import datetime -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TypeVar +from collections.abc import Callable +from typing import Any, TypeVar from ga4gh.core import domain_models, entity_models, ga4gh_identify from ga4gh.vrs import models @@ -29,7 +31,6 @@ ) from gene.version import __version__ - _logger = logging.getLogger(__name__) NormService = TypeVar("NormService", bound=BaseNormalizationService) @@ -62,7 +63,7 @@ def __init__(self, database: AbstractDatabase) -> None: self.db = database @staticmethod - def _emit_warnings(query_str: str) -> List: + def _emit_warnings(query_str: str) -> list: """Emit warnings if query contains non breaking space characters. :param query_str: query string @@ -77,12 +78,12 @@ def _emit_warnings(query_str: str) -> List: } ] _logger.warning( - f"Query ({query_str}) contains non-breaking space characters." + "Query (%s) contains non-breaking space characters.", query_str ) return warnings @staticmethod - def _transform_sequence_location(loc: Dict) -> models.SequenceLocation: + def _transform_sequence_location(loc: dict) -> models.SequenceLocation: """Transform a sequence location to VRS sequence location :param loc: GeneSequenceLocation represented as a dict @@ -110,7 +111,7 @@ def _transform_sequence_location(loc: Dict) -> models.SequenceLocation: # end=loc["end"] # ) - def _transform_location(self, loc: Dict) -> Dict: + def _transform_location(self, loc: dict) -> dict: """Transform a sequence/chromosome location to VRS sequence/chromosome location :param loc: Sequence or Chromosome location @@ -125,17 +126,19 @@ def _transform_location(self, loc: Dict) -> Dict: transformed_loc.id = ga4gh_identify(transformed_loc) return transformed_loc.model_dump(exclude_none=True) - def _transform_locations(self, record: Dict) -> Dict: + def _transform_locations(self, record: dict) -> dict: """Transform gene locations to VRS Chromosome/Sequence Locations :param record: original record :return: record with transformed locations attributes, if applicable """ - record_locations = list() + record_locations = [] if "locations" in record: - for loc in record["locations"]: - if loc["type"] == "SequenceLocation": - record_locations.append(self._transform_location(loc)) + record_locations.extend( + self._transform_location(loc) + for loc in record["locations"] + if loc["type"] == "SequenceLocation" + ) record["locations"] = record_locations return record @@ -148,15 +151,18 @@ def _get_src_name(self, concept_id: str) -> SourceName: """ if concept_id.startswith(NamespacePrefix.ENSEMBL.value): return SourceName.ENSEMBL - elif concept_id.startswith(NamespacePrefix.NCBI.value): + + if concept_id.startswith(NamespacePrefix.NCBI.value): return SourceName.NCBI - elif concept_id.startswith(NamespacePrefix.HGNC.value): + + if concept_id.startswith(NamespacePrefix.HGNC.value): return SourceName.HGNC - else: - raise ValueError("Invalid or unrecognized concept ID provided") + + err_msg = "Invalid or unrecognized concept ID provided" + raise ValueError(err_msg) def _add_record( - self, response: Dict[str, Dict], item: Dict, match_type: MatchType + self, response: dict[str, dict], item: dict, match_type: MatchType ) -> None: """Add individual record (i.e. Item in DynamoDB) to response object @@ -170,7 +176,7 @@ def _add_record( src_name = item["src_name"] matches = response["source_matches"] - if src_name not in matches.keys(): + if src_name not in matches: pass elif matches[src_name] is None: matches[src_name] = { @@ -181,7 +187,7 @@ def _add_record( matches[src_name]["records"].append(gene) def _fetch_record( - self, response: Dict[str, Dict], concept_id: str, match_type: MatchType + self, response: dict[str, dict], concept_id: str, match_type: MatchType ) -> None: """Add fetched record to response @@ -193,17 +199,19 @@ def _fetch_record( match = self.db.get_record_by_id(concept_id, case_sensitive=False) except DatabaseReadException as e: _logger.error( - f"Encountered DatabaseReadException looking up {concept_id}: {e}" + "Encountered DatabaseReadException looking up %s: %s", concept_id, e ) else: if match: self._add_record(response, match, match_type) else: _logger.error( - f"Unable to find expected record for {concept_id} matching as {match_type}" - ) # noqa: E501 + "Unable to find expected record for %s matching as %s", + concept_id, + match_type, + ) - def _post_process_resp(self, resp: Dict) -> Dict: + def _post_process_resp(self, resp: dict) -> dict: """Fill all empty source_matches slots with NO_MATCH results and sort source records by descending `match_type`. @@ -211,7 +219,7 @@ def _post_process_resp(self, resp: Dict) -> Dict: :return: response object with empty source slots filled with NO_MATCH results and corresponding source metadata """ - for src_name in resp["source_matches"].keys(): + for src_name in resp["source_matches"]: if resp["source_matches"][src_name] is None: resp["source_matches"][src_name] = { "match_type": MatchType.NO_MATCH, @@ -224,7 +232,7 @@ def _post_process_resp(self, resp: Dict) -> Dict: records = sorted(records, key=lambda k: k.match_type, reverse=True) return resp - def _get_search_response(self, query: str, sources: Set[str]) -> Dict: + def _get_search_response(self, query: str, sources: set[str]) -> dict: """Return response as dict where key is source name and value is a list of records. @@ -241,18 +249,17 @@ def _get_search_response(self, query: str, sources: Set[str]) -> Dict: return self._post_process_resp(resp) query_l = query.lower() - queries = list() - if [p for p in PREFIX_LOOKUP.keys() if query_l.startswith(p)]: + queries = [] + if [p for p in PREFIX_LOOKUP if query_l.startswith(p)]: queries.append((query_l, RecordType.IDENTITY.value)) - for prefix in [p for p in NAMESPACE_LOOKUP.keys() if query_l.startswith(p)]: + for prefix in [p for p in NAMESPACE_LOOKUP if query_l.startswith(p)]: term = f"{NAMESPACE_LOOKUP[prefix].lower()}:{query_l}" queries.append((term, RecordType.IDENTITY.value)) - for match in ITEM_TYPES.values(): - queries.append((query_l, match)) + queries.extend((query_l, match) for match in ITEM_TYPES.values()) - matched_concept_ids = list() + matched_concept_ids = [] for term, item_type in queries: try: if item_type == RecordType.IDENTITY.value: @@ -266,10 +273,12 @@ def _get_search_response(self, query: str, sources: Set[str]) -> Dict: self._fetch_record(resp, ref, MatchType[item_type.upper()]) matched_concept_ids.append(ref) - except DatabaseReadException as e: + except DatabaseReadException as e: # noqa: PERF203 _logger.error( - f"Encountered DatabaseReadException looking up {item_type}" - f" {term}: {e}" + "Encountered DatabaseReadException looking up %s %s: ", + item_type, + term, + e, ) continue @@ -282,14 +291,16 @@ def _get_service_meta() -> ServiceMeta: :return: Service Meta """ - return ServiceMeta(version=__version__, response_datetime=str(datetime.now())) + return ServiceMeta( + version=__version__, + response_datetime=str(datetime.datetime.now(tz=datetime.timezone.utc)), + ) def search( self, query_str: str, incl: str = "", excl: str = "", - **params, ) -> SearchService: """Return highest match for each source. @@ -312,10 +323,9 @@ def search( possible_sources = { name.value.lower(): name.value for name in SourceName.__members__.values() } - sources = dict() - for k, v in possible_sources.items(): - if self.db.get_source_metadata(v): - sources[k] = v + sources = { + k: v for k, v in possible_sources.items() if self.db.get_source_metadata(v) + } if not incl and not excl: query_sources = set(sources.values()) @@ -327,7 +337,7 @@ def search( invalid_sources = [] query_sources = set() for source in req_sources: - if source.lower() in sources.keys(): + if source.lower() in sources: query_sources.add(sources[source.lower()]) else: invalid_sources.append(source) @@ -340,10 +350,10 @@ def search( invalid_sources = [] query_sources = set() for req_l, req in req_excl_dict.items(): - if req_l not in sources.keys(): + if req_l not in sources: invalid_sources.append(req) for src_l, src in sources.items(): - if src_l not in req_excl_dict.keys(): + if src_l not in req_excl_dict: query_sources.add(src) if invalid_sources: detail = f"Invalid source name(s): {invalid_sources}" @@ -371,7 +381,7 @@ def _add_merged_meta(self, response: NormalizeService) -> NormalizeService: for src in sources: try: src_name = PREFIX_LOOKUP[src] - except KeyError: + except KeyError: # noqa: PERF203 # not an imported source continue else: @@ -382,7 +392,7 @@ def _add_merged_meta(self, response: NormalizeService) -> NormalizeService: return response def _add_alt_matches( - self, response: NormService, record: Dict, possible_concepts: List[str] + self, response: NormService, record: dict, possible_concepts: list[str] ) -> NormService: """Add alternate matches warning to response object @@ -408,9 +418,9 @@ def _add_alt_matches( def _add_gene( self, response: NormalizeService, - record: Dict, + record: dict, match_type: MatchType, - possible_concepts: Optional[List[str]] = None, + possible_concepts: list[str] | None = None, ) -> NormalizeService: """Add core Gene object to response. @@ -444,7 +454,7 @@ def _add_gene( # aliases aliases = set() for key in ["previous_symbols", "aliases"]: - if key in record and record[key]: + if record.get(key): val = record[key] if isinstance(val, str): val = [val] @@ -462,7 +472,7 @@ def _add_gene( ("strand", "strand"), ] for ext_label, record_label in extension_and_record_labels: - if record_label in record and record[record_label]: + if record.get(record_label): extensions.append( entity_models.Extension(name=ext_label, value=record[record_label]) ) @@ -473,15 +483,16 @@ def _add_gene( if locs: record_locations[f"{record['src_name'].lower()}_locations"] = locs elif record["item_type"] == RecordType.MERGER: - for k, v in record.items(): - if k.endswith("locations") and v: - record_locations[k] = v + record_locations.update( + {k: v for k, v in record.items() if k.endswith("locations") and v} + ) for loc_name, locations in record_locations.items(): - transformed_locs = [] - for loc in locations: - if loc["type"] == "SequenceLocation": - transformed_locs.append(self._transform_location(loc)) + transformed_locs = [ + self._transform_location(loc) + for loc in locations + if loc["type"] == "SequenceLocation" + ] if transformed_locs: extensions.append( @@ -502,10 +513,10 @@ def _add_gene( for f in GeneTypeFieldName: field_name = f.value values = record.get(field_name, []) - for value in values: - extensions.append( - entity_models.Extension(name=field_name, value=value) - ) + extensions.extend( + entity_models.Extension(name=field_name, value=value) + for value in values + ) if extensions: gene_obj.extensions = extensions @@ -520,7 +531,7 @@ def _add_gene( return response @staticmethod - def _record_order(record: Dict) -> Tuple[int, str]: + def _record_order(record: dict) -> tuple[int, str]: """Construct priority order for matching. Only called by sort(). :param record: individual record item in iterable to sort @@ -531,7 +542,7 @@ def _record_order(record: Dict) -> Tuple[int, str]: return source_rank, record["concept_id"] @staticmethod - def _handle_failed_merge_ref(record: Dict, response: Dict, query: str) -> Dict: + def _handle_failed_merge_ref(record: dict, response: dict, query: str) -> dict: """Log + fill out response for a failed merge reference lookup. :param record: record containing failed merge_ref @@ -540,13 +551,15 @@ def _handle_failed_merge_ref(record: Dict, response: Dict, query: str) -> Dict: :return: response with no match """ _logger.error( - f"Merge ref lookup failed for ref {record['merge_ref']} " - f"in record {record['concept_id']} from query {query}" + "Merge ref lookup failed for ref %s in record %s from query %s", + record["merge_ref"], + record["concept_id"], + query, ) response["match_type"] = MatchType.NO_MATCH return response - def _prepare_normalized_response(self, query: str) -> Dict[str, Any]: + def _prepare_normalized_response(self, query: str) -> dict[str, Any]: """Provide base response object for normalize endpoints. :param query: user-provided query @@ -557,7 +570,8 @@ def _prepare_normalized_response(self, query: str) -> Dict[str, Any]: "match_type": MatchType.NO_MATCH, "warnings": self._emit_warnings(query), "service_meta_": ServiceMeta( - version=__version__, response_datetime=str(datetime.now()) + version=__version__, + response_datetime=str(datetime.datetime.now(tz=datetime.timezone.utc)), ), } @@ -584,10 +598,10 @@ def normalize(self, query: str) -> NormalizeService: def _resolve_merge( self, response: NormService, - record: Dict, + record: dict, match_type: MatchType, callback: Callable, - possible_concepts: Optional[List[str]] = None, + possible_concepts: list[str] | None = None, ) -> NormService: """Given a record, return the corresponding normalized record @@ -605,15 +619,17 @@ def _resolve_merge( if merge is None: query = response.query _logger.error( - f"Merge ref lookup failed for ref {record['merge_ref']} " - f"in record {record['concept_id']} from query `{query}`" + "Merge ref lookup failed for ref %s in record %s from query `%s`", + record["merge_ref"], + record["concept_id"], + query, ) return response - else: - return callback(response, merge, match_type, possible_concepts) - else: - # record is sole member of concept group - return callback(response, record, match_type, possible_concepts) + + return callback(response, merge, match_type, possible_concepts) + + # record is sole member of concept group + return callback(response, record, match_type, possible_concepts) def _perform_normalized_lookup( self, response: NormService, query: str, response_builder: Callable @@ -623,6 +639,7 @@ def _perform_normalized_lookup( :param response: in-progress response object :param query: user-provided query :param response_builder: response constructor callback method + :raises ValueError: If a matching record is null :return: completed service response object """ if query == "": @@ -647,16 +664,15 @@ def _perform_normalized_lookup( matching_records = [ self.db.get_record_by_id(ref, False) for ref in matching_refs ] - matching_records.sort(key=self._record_order) # type: ignore + matching_records.sort(key=self._record_order) - if len(matching_refs) > 1: - possible_concepts = [ref for ref in matching_refs] - else: - possible_concepts = None + possible_concepts = list(matching_refs) if len(matching_refs) > 1 else None # attempt merge ref resolution until successful for match in matching_records: - assert match is not None + if match is None: + err_msg = "Matching record must be nonnull" + raise ValueError(err_msg) record = self.db.get_record_by_id(match["concept_id"], False) if record: match_type_value = MatchType[match_type.value.upper()] @@ -672,9 +688,9 @@ def _perform_normalized_lookup( def _add_normalized_records( self, response: UnmergedNormalizationService, - normalized_record: Dict, + normalized_record: dict, match_type: MatchType, - possible_concepts: Optional[List[str]] = None, + possible_concepts: list[str] | None = None, ) -> UnmergedNormalizationService: """Add individual records to unmerged normalize response. @@ -692,12 +708,11 @@ def _add_normalized_records( meta = self.db.get_source_metadata(record_source.value) response.source_matches[record_source] = MatchesNormalized( records=[BaseGene(**self._transform_locations(normalized_record))], - source_meta_=meta, # type: ignore + source_meta_=meta, ) else: - concept_ids = [normalized_record["concept_id"]] + normalized_record.get( - "xrefs", [] - ) + xrefs = normalized_record.get("xrefs") or [] + concept_ids = [normalized_record["concept_id"], *xrefs] for concept_id in concept_ids: record = self.db.get_record_by_id(concept_id, case_sensitive=False) if not record: @@ -709,8 +724,7 @@ def _add_normalized_records( else: meta = self.db.get_source_metadata(record_source.value) response.source_matches[record_source] = MatchesNormalized( - records=[gene], - source_meta_=meta, # type: ignore + records=[gene], source_meta_=meta ) if possible_concepts: response = self._add_alt_matches( diff --git a/src/gene/schemas.py b/src/gene/schemas.py index d79afd7e..f2987697 100644 --- a/src/gene/schemas.py +++ b/src/gene/schemas.py @@ -1,6 +1,7 @@ """Contains data models for representing VICC normalized gene records.""" + from enum import Enum, IntEnum -from typing import Dict, List, Literal, Optional, Union +from typing import Literal from ga4gh.core import domain_models from ga4gh.vrs import models @@ -69,7 +70,7 @@ class GeneSequenceLocation(BaseModel): type: Literal["SequenceLocation"] = "SequenceLocation" start: StrictInt end: StrictInt - sequence_id: constr(pattern=r"^ga4gh:SQ.[0-9A-Za-z_\-]{32}$") # noqa: F722 + sequence_id: constr(pattern=r"^ga4gh:SQ.[0-9A-Za-z_\-]{32}$") # class GeneChromosomeLocation(BaseModel): @@ -89,20 +90,16 @@ class BaseGene(BaseModel): concept_id: CURIE symbol: StrictStr - symbol_status: Optional[SymbolStatus] = None - label: Optional[StrictStr] = None - strand: Optional[Strand] = None - location_annotations: List[StrictStr] = [] - locations: Union[ - List[models.SequenceLocation], List[GeneSequenceLocation] - # List[Union[SequenceLocation, ChromosomeLocation]], - # List[Union[GeneSequenceLocation, GeneChromosomeLocation]] # dynamodb - ] = [] - aliases: List[StrictStr] = [] - previous_symbols: List[StrictStr] = [] - xrefs: List[CURIE] = [] - associated_with: List[CURIE] = [] - gene_type: Optional[StrictStr] = None + symbol_status: SymbolStatus | None = None + label: StrictStr | None = None + strand: Strand | None = None + location_annotations: list[StrictStr] = [] + locations: list[models.SequenceLocation] | list[GeneSequenceLocation] = [] + aliases: list[StrictStr] = [] + previous_symbols: list[StrictStr] = [] + xrefs: list[CURIE] = [] + associated_with: list[CURIE] = [] + gene_type: StrictStr | None = None class Gene(BaseGene): @@ -136,7 +133,7 @@ class GeneGroup(Gene): description: StrictStr type_identifier: StrictStr - genes: List[Gene] = [] + genes: list[Gene] = [] class SourceName(Enum): @@ -160,7 +157,7 @@ class SourceIDAfterNamespace(Enum): HGNC = "" ENSEMBL = "ENSG" - NCBI = "" + NCBI = "" # noqa: PIE796 class NamespacePrefix(Enum): @@ -228,10 +225,10 @@ class SourceMeta(BaseModel): data_license: StrictStr data_license_url: StrictStr version: StrictStr - data_url: Dict[StrictStr, StrictStr] # TODO strictness necessary? - rdp_url: Optional[StrictStr] = None - data_license_attributes: Dict[StrictStr, StrictBool] - genome_assemblies: List[StrictStr] = [] + data_url: dict[StrictStr, StrictStr] # TODO strictness necessary? + rdp_url: StrictStr | None = None + data_license_attributes: dict[StrictStr, StrictBool] + genome_assemblies: list[StrictStr] = [] model_config = ConfigDict( json_schema_extra={ @@ -259,7 +256,7 @@ class SourceMeta(BaseModel): class SourceSearchMatches(BaseModel): """Container for matching information from an individual source.""" - records: List[Gene] = [] + records: list[Gene] = [] source_meta_: SourceMeta model_config = ConfigDict(json_schema_extra={"example": {}}) # TODO @@ -271,9 +268,9 @@ class ServiceMeta(BaseModel): name: Literal["gene-normalizer"] = "gene-normalizer" version: StrictStr response_datetime: StrictStr - url: Literal[ + url: Literal["https://github.com/cancervariants/gene-normalization"] = ( "https://github.com/cancervariants/gene-normalization" - ] = "https://github.com/cancervariants/gene-normalization" # noqa: E501 + ) model_config = ConfigDict( json_schema_extra={ @@ -291,8 +288,8 @@ class SearchService(BaseModel): """Define model for returning highest match typed concepts from sources.""" query: StrictStr - warnings: List[Dict] = [] - source_matches: Dict[SourceName, SourceSearchMatches] + warnings: list[dict] = [] + source_matches: dict[SourceName, SourceSearchMatches] service_meta_: ServiceMeta model_config = ConfigDict(json_schema_extra={}) # TODO @@ -312,7 +309,7 @@ class BaseNormalizationService(BaseModel): """Base method providing shared attributes to Normalization service classes.""" query: StrictStr - warnings: List[Dict] = [] + warnings: list[dict] = [] match_type: MatchType service_meta_: ServiceMeta @@ -320,9 +317,9 @@ class BaseNormalizationService(BaseModel): class NormalizeService(BaseNormalizationService): """Define model for returning normalized concept.""" - normalized_id: Optional[str] = None - gene: Optional[domain_models.Gene] = None - source_meta_: Dict[SourceName, SourceMeta] = {} + normalized_id: str | None = None + gene: domain_models.Gene | None = None + source_meta_: dict[SourceName, SourceMeta] = {} model_config = ConfigDict( json_schema_extra={ @@ -410,7 +407,7 @@ class NormalizeService(BaseNormalizationService): # { # "name": "chromosome_location", # "value": { - # "id": "ga4gh:CL.O6yCQ1cnThOrTfK9YUgMlTfM6HTqbrKw", # noqa: E501 + # "id": "ga4gh:CL.O6yCQ1cnThOrTfK9YUgMlTfM6HTqbrKw", # "type": "ChromosomeLocation", # "species_id": "taxonomy:9606", # "chr": "7", @@ -438,7 +435,7 @@ class NormalizeService(BaseNormalizationService): }, "Ensembl": { "data_license": "custom", - "data_license_url": "https://useast.ensembl.org/info/about/legal/disclaimer.html", # noqa: E501 + "data_license_url": "https://useast.ensembl.org/info/about/legal/disclaimer.html", "version": "104", "data_url": { "genome_annotations": "ftp://ftp.ensembl.org/pub/current_gff3/homo_sapiens/Homo_sapiens.GRCh38.110.gff3.gz" @@ -453,7 +450,7 @@ class NormalizeService(BaseNormalizationService): }, "NCBI": { "data_license": "custom", - "data_license_url": "https://www.ncbi.nlm.nih.gov/home/about/policies/", # noqa: E501 + "data_license_url": "https://www.ncbi.nlm.nih.gov/home/about/policies/", "version": "20210813", "data_url": { "info_file": "ftp.ncbi.nlm.nih.govgene/DATA/GENE_INFO/Mammalia/Homo_sapiens.gene_info.gz", @@ -483,7 +480,7 @@ class NormalizeService(BaseNormalizationService): class MatchesNormalized(BaseModel): """Matches associated with normalized concept from a single source.""" - records: List[BaseGene] = [] + records: list[BaseGene] = [] source_meta_: SourceMeta @@ -493,8 +490,8 @@ class UnmergedNormalizationService(BaseNormalizationService): attributes. """ - normalized_concept_id: Optional[CURIE] = None - source_matches: Dict[SourceName, MatchesNormalized] + normalized_concept_id: CURIE | None = None + source_matches: dict[SourceName, MatchesNormalized] model_config = ConfigDict( json_schema_extra={ @@ -516,13 +513,13 @@ class UnmergedNormalizationService(BaseNormalizationService): "concept_id": "hgnc:108", "symbol": "ACHE", "symbol_status": "approved", - "label": "acetylcholinesterase (Cartwright blood group)", # noqa: E501 + "label": "acetylcholinesterase (Cartwright blood group)", "strand": None, "location_annotations": [], "locations": [ # { # "type": "ChromosomeLocation", - # "id": "ga4gh:CL.VtdU_0lYXL_o95lXRUfhv-NDJVVpmKoD", # noqa: E501 + # "id": "ga4gh:CL.VtdU_0lYXL_o95lXRUfhv-NDJVVpmKoD", # "species_id": "taxonomy:9606", # "chr": "7", # "start": "q22.1", @@ -570,17 +567,17 @@ class UnmergedNormalizationService(BaseNormalizationService): "concept_id": "ensembl:ENSG00000087085", "symbol": "ACHE", "symbol_status": None, - "label": "acetylcholinesterase (Cartwright blood group)", # noqa: E501 + "label": "acetylcholinesterase (Cartwright blood group)", "strand": "-", "location_annotations": [], "locations": [ { - "id": "ga4gh:SL.4taOKYezIxUvFozs6c6OC0bJAQ2zwjxu", # noqa: E501 + "id": "ga4gh:SL.4taOKYezIxUvFozs6c6OC0bJAQ2zwjxu", "digest": "4taOKYezIxUvFozs6c6OC0bJAQ2zwjxu", "type": "SequenceLocation", "sequenceReference": { "type": "SequenceReference", - "refgetAccession": "SQ.F-LrLMe1SRpfUZHkQmvkVKFEGaoDeHul", # noqa: E501 + "refgetAccession": "SQ.F-LrLMe1SRpfUZHkQmvkVKFEGaoDeHul", }, "start": 100889993, "end": 100896974, @@ -595,7 +592,7 @@ class UnmergedNormalizationService(BaseNormalizationService): ], "source_meta_": { "data_license": "custom", - "data_license_url": "https://useast.ensembl.org/info/about/legal/disclaimer.html", # noqa: E501 + "data_license_url": "https://useast.ensembl.org/info/about/legal/disclaimer.html", "version": "104", "data_url": { "genome_annotations": "ftp://ftp.ensembl.org/pub/current_gff3/homo_sapiens/Homo_sapiens.GRCh38.110.gff3.gz" @@ -615,25 +612,25 @@ class UnmergedNormalizationService(BaseNormalizationService): "concept_id": "ncbigene:43", "symbol": "ACHE", "symbol_status": None, - "label": "acetylcholinesterase (Cartwright blood group)", # noqa: E501 + "label": "acetylcholinesterase (Cartwright blood group)", "strand": "-", "location_annotations": [], "locations": [ { # "type": "ChromosomeLocation", - # "id": "ga4gh:CL.VtdU_0lYXL_o95lXRUfhv-NDJVVpmKoD", # noqa: E501 + # "id": "ga4gh:CL.VtdU_0lYXL_o95lXRUfhv-NDJVVpmKoD", # "species_id": "taxonomy:9606", # "chr": "7", # "start": "q22.1", # "end": "q22.1" }, { - "id": "ga4gh:SL.OWr9DoyBhr2zpf4uLLcZSvsTSIDElU6R", # noqa: E501 + "id": "ga4gh:SL.OWr9DoyBhr2zpf4uLLcZSvsTSIDElU6R", "digest": "OWr9DoyBhr2zpf4uLLcZSvsTSIDElU6R", "type": "SequenceLocation", "sequenceReference": { "type": "SequenceReference", - "refgetAccession": "SQ.F-LrLMe1SRpfUZHkQmvkVKFEGaoDeHul", # noqa: E501 + "refgetAccession": "SQ.F-LrLMe1SRpfUZHkQmvkVKFEGaoDeHul", }, "start": 100889993, "end": 100896994, @@ -648,7 +645,7 @@ class UnmergedNormalizationService(BaseNormalizationService): ], "source_meta_": { "data_license": "custom", - "data_license_url": "https://www.ncbi.nlm.nih.gov/home/about/policies/", # noqa: E501 + "data_license_url": "https://www.ncbi.nlm.nih.gov/home/about/policies/", "version": "20220407", "data_url": { "info_file": "ftp.ncbi.nlm.nih.govgene/DATA/GENE_INFO/Mammalia/Homo_sapiens.gene_info.gz", diff --git a/src/gene/version.py b/src/gene/version.py index 3b2542cc..71445f07 100644 --- a/src/gene/version.py +++ b/src/gene/version.py @@ -1,2 +1,3 @@ """Gene normalizer version""" + __version__ = "0.3.2" diff --git a/tests/conftest.py b/tests/conftest.py index c9f8f204..c87f6646 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,5 @@ """Provide utilities for test cases.""" + import logging import pytest diff --git a/tests/unit/test_database_and_etl.py b/tests/unit/test_database_and_etl.py index 092cc6c3..47aad2ee 100644 --- a/tests/unit/test_database_and_etl.py +++ b/tests/unit/test_database_and_etl.py @@ -1,10 +1,11 @@ """Test DynamoDB and ETL methods.""" + from os import environ from pathlib import Path +from unittest.mock import patch import pytest from boto3.dynamodb.conditions import Key -from mock import patch from gene.database import AWS_ENV_VAR_NAME from gene.etl import HGNC, NCBI, Ensembl @@ -52,7 +53,7 @@ def __init__(self): @pytest.fixture(scope="module") def processed_ids(): """Create a test fixture to store processed ids for merged concepts.""" - return list() + return [] def _get_aliases(seqid): @@ -95,7 +96,7 @@ def test_ensembl_etl(test_get_seqrepo, processed_ids, db_fixture, etl_data_path) """Test that ensembl etl methods work correctly.""" test_get_seqrepo.return_value = None e = Ensembl(db_fixture.db, data_path=etl_data_path) - e._get_seq_id_aliases = _get_aliases # type: ignore + e._get_seq_id_aliases = _get_aliases ensembl_ids = e.perform_etl(use_existing=True) processed_ids += ensembl_ids @@ -116,7 +117,7 @@ def test_ncbi_etl(test_get_seqrepo, processed_ids, db_fixture, etl_data_path): """Test that ncbi etl methods work correctly.""" test_get_seqrepo.return_value = None n = NCBI(db_fixture.db, data_path=etl_data_path) - n._get_seq_id_aliases = _get_aliases # type: ignore + n._get_seq_id_aliases = _get_aliases ncbi_ids = n.perform_etl(use_existing=True) processed_ids += ncbi_ids diff --git a/tests/unit/test_emit_warnings.py b/tests/unit/test_emit_warnings.py index c8309aac..6013bb53 100644 --- a/tests/unit/test_emit_warnings.py +++ b/tests/unit/test_emit_warnings.py @@ -1,4 +1,5 @@ """Test the emit_warnings function.""" + from gene.database import create_db from gene.query import QueryHandler @@ -18,10 +19,10 @@ def test_emit_warnings(): assert actual_warnings == [] # Test emit warnings - actual_warnings = query_handler._emit_warnings("sp ry3") + actual_warnings = query_handler._emit_warnings("sp ry3") assert actual_warnings == actual_warnings - actual_warnings = query_handler._emit_warnings("sp\u00A0ry3") + actual_warnings = query_handler._emit_warnings("sp\u00a0ry3") assert expected_warnings == actual_warnings actual_warnings = query_handler._emit_warnings("sp ry3") diff --git a/tests/unit/test_endpoints.py b/tests/unit/test_endpoints.py index 0639e6a0..4cbea671 100644 --- a/tests/unit/test_endpoints.py +++ b/tests/unit/test_endpoints.py @@ -4,6 +4,7 @@ response objects -- here, we're checking for bad branch logic and for basic assurances that routes integrate correctly with query methods. """ + import pytest from fastapi.testclient import TestClient diff --git a/tests/unit/test_ensembl_source.py b/tests/unit/test_ensembl_source.py index c4c8d2fa..acc2c71f 100644 --- a/tests/unit/test_ensembl_source.py +++ b/tests/unit/test_ensembl_source.py @@ -1,4 +1,5 @@ """Test that the gene normalizer works as intended for the Ensembl source.""" + import pytest from gene.query import QueryHandler @@ -17,8 +18,7 @@ def search(self, query_str, incl="ensembl"): resp = self.query_handler.search(query_str, incl=incl) return resp.source_matches[SourceName.ENSEMBL] - e = QueryGetter() - return e + return QueryGetter() @pytest.fixture(scope="module") diff --git a/tests/unit/test_hgnc_source.py b/tests/unit/test_hgnc_source.py index 54d0aff0..a771efa1 100644 --- a/tests/unit/test_hgnc_source.py +++ b/tests/unit/test_hgnc_source.py @@ -1,5 +1,6 @@ """Test that the gene normalizer works as intended for the HGNC source.""" -from datetime import datetime + +import datetime import pytest @@ -19,8 +20,7 @@ def search(self, query_str, incl="hgnc"): resp = self.query_handler.search(query_str, incl=incl) return resp.source_matches[SourceName.HGNC] - h = QueryGetter() - return h + return QueryGetter() # Test Non Alt Loci Set @@ -816,7 +816,9 @@ def test_meta_info(hgnc): assert ( resp.source_meta_.data_license_url == "https://www.genenames.org/about/license/" ) - assert datetime.strptime(resp.source_meta_.version, "%Y%m%d") + assert datetime.datetime.now(tz=datetime.timezone.utc).strptime( + resp.source_meta_.version, "%Y%m%d" + ) assert resp.source_meta_.data_url == { "complete_set_archive": "ftp.ebi.ac.uk/pub/databases/genenames/hgnc/json/hgnc_complete_set.json" } diff --git a/tests/unit/test_ncbi_source.py b/tests/unit/test_ncbi_source.py index 9b245c5f..ab647130 100644 --- a/tests/unit/test_ncbi_source.py +++ b/tests/unit/test_ncbi_source.py @@ -1,5 +1,6 @@ """Test import of NCBI source data""" -from datetime import datetime + +import datetime import pytest @@ -37,8 +38,7 @@ def search(self, query_str, incl="ncbi"): resp = self.query_handler.search(query_str, incl=incl) return resp.source_matches[SourceName.NCBI] - n = QueryGetter() - return n + return QueryGetter() @pytest.fixture(scope="module") @@ -847,7 +847,9 @@ def test_no_match(ncbi, source_urls): response.source_meta_.data_license_url == "https://www.ncbi.nlm.nih.gov/home/about/policies/" ) - assert datetime.strptime(response.source_meta_.version, "%Y%m%d") + assert datetime.datetime.now(tz=datetime.timezone.utc).strptime( + response.source_meta_.version, "%Y%m%d" + ) assert response.source_meta_.data_url == source_urls assert response.source_meta_.rdp_url == "https://reusabledata.org/ncbi-gene.html" assert not response.source_meta_.data_license_attributes["non_commercial"] @@ -893,7 +895,9 @@ def test_meta(ncbi, source_urls): response.source_meta_.data_license_url == "https://www.ncbi.nlm.nih.gov/home/about/policies/" ) - assert datetime.strptime(response.source_meta_.version, "%Y%m%d") + assert datetime.datetime.now(tz=datetime.timezone.utc).strptime( + response.source_meta_.version, "%Y%m%d" + ) assert response.source_meta_.data_url == source_urls assert response.source_meta_.rdp_url == "https://reusabledata.org/ncbi-gene.html" assert response.source_meta_.genome_assemblies == ["GRCh38.p14"] diff --git a/tests/unit/test_query.py b/tests/unit/test_query.py index cedcb29f..a1c2b485 100644 --- a/tests/unit/test_query.py +++ b/tests/unit/test_query.py @@ -1,4 +1,5 @@ """Module to test the query module.""" + import pytest from ga4gh.core import domain_models @@ -645,7 +646,7 @@ def normalize_unmerged_loc_653303(): "concept_id": "ncbigene:653303", "symbol": "LOC653303", "symbol_status": None, - "label": "proprotein convertase subtilisin/kexin type 7 pseudogene", # noqa: E501 + "label": "proprotein convertase subtilisin/kexin type 7 pseudogene", "strand": "+", "location_annotations": [], "locations": [ @@ -661,7 +662,7 @@ def normalize_unmerged_loc_653303(): "type": "SequenceLocation", "sequenceReference": { "type": "SequenceReference", - "refgetAccession": "SQ.2NkFm8HK88MqeNkCgj78KidCAXgnsfV1", # noqa: E501 + "refgetAccession": "SQ.2NkFm8HK88MqeNkCgj78KidCAXgnsfV1", }, "start": 117135528, "end": 117138867, @@ -742,7 +743,7 @@ def normalize_unmerged_chaf1a(): "type": "SequenceLocation", "sequenceReference": { "type": "SequenceReference", - "refgetAccession": "SQ.IIB53T8CNeJJdUqzn9V_JnRtQadwWCbl", # noqa: E501 + "refgetAccession": "SQ.IIB53T8CNeJJdUqzn9V_JnRtQadwWCbl", }, "start": 4402639, "end": 4445018, @@ -778,7 +779,7 @@ def normalize_unmerged_chaf1a(): "type": "SequenceLocation", "sequenceReference": { "type": "SequenceReference", - "refgetAccession": "SQ.IIB53T8CNeJJdUqzn9V_JnRtQadwWCbl", # noqa: E501 + "refgetAccession": "SQ.IIB53T8CNeJJdUqzn9V_JnRtQadwWCbl", }, "start": 4402639, "end": 4450830, @@ -824,7 +825,7 @@ def normalize_unmerged_ache(): "type": "SequenceLocation", "sequenceReference": { "type": "SequenceReference", - "refgetAccession": "SQ.F-LrLMe1SRpfUZHkQmvkVKFEGaoDeHul", # noqa: E501 + "refgetAccession": "SQ.F-LrLMe1SRpfUZHkQmvkVKFEGaoDeHul", }, "start": 100889993, "end": 100896994, @@ -852,7 +853,7 @@ def normalize_unmerged_ache(): "type": "SequenceLocation", "sequenceReference": { "type": "SequenceReference", - "refgetAccession": "SQ.F-LrLMe1SRpfUZHkQmvkVKFEGaoDeHul", # noqa: E501 + "refgetAccession": "SQ.F-LrLMe1SRpfUZHkQmvkVKFEGaoDeHul", }, "start": 100889993, "end": 100896974, @@ -953,7 +954,7 @@ def normalized_ifnr(): @pytest.fixture(scope="module") def num_sources(): """Get the number of sources.""" - return len({s for s in SourceName}) + return len(set(SourceName)) @pytest.fixture(scope="module") @@ -999,7 +1000,7 @@ def compare_normalize_resp( resp_source_meta_keys = resp.source_meta_.keys() assert len(resp_source_meta_keys) == len( expected_source_meta - ), "source_meta_keys" # noqa: E501 + ), "source_meta_keys" for src in expected_source_meta: assert src in resp_source_meta_keys compare_service_meta(resp.service_meta_) @@ -1075,9 +1076,11 @@ def compare_gene(test, actual): assert no_matches == [], no_matches assert len(actual.mappings) == len(test.mappings) - assert set(actual.alternativeLabels) == set(test.alternativeLabels), "alternativeLabels" - extensions_present = "extensions" in test.model_fields.keys() - assert ("extensions" in actual.model_fields.keys()) == extensions_present + assert set(actual.alternativeLabels) == set( + test.alternativeLabels + ), "alternativeLabels" + extensions_present = "extensions" in test.model_fields + assert ("extensions" in actual.model_fields) == extensions_present if extensions_present: actual_ext_names = sorted([ext.name for ext in actual.extensions]) unique_actual_ext_names = sorted(set(actual_ext_names)) @@ -1093,8 +1096,15 @@ def compare_gene(test, actual): if test_ext.value: if isinstance(test_ext.value[0], dict): if test_ext.value[0].get("type") == "SequenceLocation": - actual_digest = actual_ext.value[0].pop("id").split("ga4gh:SL.")[-1] - assert actual_ext.value[0].pop("digest") == actual_digest + actual_digest = ( + actual_ext.value[0] + .pop("id") + .split("ga4gh:SL.")[-1] + ) + assert ( + actual_ext.value[0].pop("digest") + == actual_digest + ) assert actual_ext.value == test_ext.value else: assert set(actual_ext.value) == set( @@ -1142,7 +1152,7 @@ def test_search_query_inc_exc(query_handler, num_sources): def test_search_invalid_parameter_exception(query_handler): """Test that Invalid parameter exception works correctly.""" with pytest.raises(InvalidParameterException): - _ = query_handler.search("BRAF", incl="hgn") # noqa: F841, E501 + _ = query_handler.search("BRAF", incl="hgn") with pytest.raises(InvalidParameterException): resp = query_handler.search("BRAF", incl="hgnc", excl="hgnc") # noqa: F841 diff --git a/tests/unit/test_schemas.py b/tests/unit/test_schemas.py index 8c183cd6..a5ebd1a0 100644 --- a/tests/unit/test_schemas.py +++ b/tests/unit/test_schemas.py @@ -1,4 +1,5 @@ """Module to test validators in the schemas module.""" + import pydantic import pytest from ga4gh.vrs import models