Skip to content

Commit

Permalink
MNT: Fix issues raised by pyright. [skip ci]
Browse files Browse the repository at this point in the history
  • Loading branch information
Taher Chegini committed Jul 7, 2024
1 parent f37268a commit 5c7cd9c
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 130 deletions.
30 changes: 26 additions & 4 deletions src/pygridmet/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, TypeVar
from typing import TYPE_CHECKING, Literal, TypeVar

import click
import geopandas as gpd
Expand All @@ -19,6 +19,24 @@

if TYPE_CHECKING:
DFType = TypeVar("DFType", pd.DataFrame, gpd.GeoDataFrame)
VARS = Literal[
"pr",
"rmax",
"rmin",
"sph",
"srad",
"th",
"tmmn",
"tmmx",
"vs",
"bi",
"fm100",
"fm1000",
"erc",
"etr",
"pet",
"vpd",
]


def parse_snow(target_df: pd.DataFrame) -> pd.DataFrame:
Expand Down Expand Up @@ -90,7 +108,7 @@ def cli() -> None:
@ssl_opt
def coords(
fpath: Path,
variables: list[str] | str | None = None,
variables: list[VARS] | VARS | None = None,
save_dir: str | Path = "clm_gridmet",
disable_ssl: bool = False,
) -> None:
Expand Down Expand Up @@ -139,7 +157,11 @@ def coords(
if fname.exists():
continue
kwrgs = dict(zip(req_cols[1:], args))
clm = gridmet.get_bycoords(**kwrgs, variables=variables, ssl=not disable_ssl)
clm = gridmet.get_bycoords(
**kwrgs,
variables=variables,
ssl=not disable_ssl,
)
clm.to_csv(fname, index=False)
click.echo("Done.")

Expand All @@ -151,7 +173,7 @@ def coords(
@ssl_opt
def geometry(
fpath: Path,
variables: list[str] | str | None = None,
variables: list[VARS] | VARS | None = None,
save_dir: str | Path = "clm_gridmet",
disable_ssl: bool = False,
) -> None:
Expand Down
163 changes: 37 additions & 126 deletions src/pygridmet/pygridmet.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,24 @@
from shapely import MultiPolygon, Polygon

CRSTYPE = Union[int, str, pyproj.CRS]
VARS = Literal[
"pr",
"rmax",
"rmin",
"sph",
"srad",
"th",
"tmmn",
"tmmx",
"vs",
"bi",
"fm100",
"fm1000",
"erc",
"etr",
"pet",
"vpd",
]

DATE_FMT = "%Y-%m-%dT%H:%M:%SZ"
MAX_CONN = 4
Expand All @@ -38,26 +56,7 @@

def _coord_urls(
coord: tuple[float, float],
variables: Iterable[
Literal[
"pr",
"rmax",
"rmin",
"sph",
"srad",
"th",
"tmmn",
"tmmx",
"vs",
"bi",
"fm100",
"fm1000",
"erc",
"etr",
"pet",
"vpd",
]
],
variables: Iterable[VARS],
dates: list[tuple[pd.Timestamp, pd.Timestamp]],
long_names: dict[str, str],
) -> Generator[list[tuple[str, dict[str, dict[str, str]]]], None, None]:
Expand Down Expand Up @@ -129,7 +128,7 @@ def _by_coord(
) -> pd.DataFrame:
"""Get climate data for a coordinate and return as a DataFrame."""
coords = (lon, lat)
url_kwds = _coord_urls(coords, gridmet.variables, dates, gridmet.long_names)
url_kwds = _coord_urls(coords, gridmet.variables, dates, gridmet.long_names) # pyright: ignore[reportArgumentType]
retrieve = functools.partial(ar.retrieve_text, max_workers=MAX_CONN, ssl=ssl)

clm = pd.concat( # pyright: ignore[reportCallIssue]
Expand Down Expand Up @@ -167,45 +166,7 @@ def get_bycoords(
dates: tuple[str, str] | int | list[int],
coords_id: Sequence[str | int] | None = None,
crs: CRSTYPE = 4326,
variables: Iterable[
Literal[
"pr",
"rmax",
"rmin",
"sph",
"srad",
"th",
"tmmn",
"tmmx",
"vs",
"bi",
"fm100",
"fm1000",
"erc",
"etr",
"pet",
"vpd",
]
]
| Literal[
"pr",
"rmax",
"rmin",
"sph",
"srad",
"th",
"tmmn",
"tmmx",
"vs",
"bi",
"fm100",
"fm1000",
"erc",
"etr",
"pet",
"vpd",
]
| None = None,
variables: Iterable[VARS] | VARS | None = None,
snow: bool = False,
snow_params: dict[str, float] | None = None,
ssl: bool = True,
Expand Down Expand Up @@ -303,26 +264,7 @@ def get_bycoords(

def _gridded_urls(
bounds: tuple[float, float, float, float],
variables: Iterable[
Literal[
"pr",
"rmax",
"rmin",
"sph",
"srad",
"th",
"tmmn",
"tmmx",
"vs",
"bi",
"fm100",
"fm1000",
"erc",
"etr",
"pet",
"vpd",
]
],
variables: Iterable[VARS],
dates: list[tuple[pd.Timestamp, pd.Timestamp]],
long_names: dict[str, str],
) -> tuple[list[str], list[dict[str, dict[str, str]]]]:
Expand Down Expand Up @@ -413,27 +355,34 @@ def _check_nans(
def _download_urls(
urls: list[str],
kwds: list[dict[str, dict[str, str]]],
clm_files: list[Path],
clm_files: Sequence[Path],
ssl: bool,
long2abbr: dict[str, str],
) -> xr.Dataset:
"""Download the URLs and return the dataset."""
clm_files_full = clm_files.copy()
clm_files_full = list(clm_files)
clm_files_ = clm_files_full.copy()
clm = None
# Sometimes the server returns NaNs, so we must check for that, remove
# the files containing NaNs, and try again.
for _ in range(N_RETRIES):
clm_files = ogc.streaming_download(urls, kwds, clm_files, ssl=ssl, n_jobs=MAX_CONN)
clm_files = [f for f in clm_files if f is not None]
clm_files_ = ogc.streaming_download(
urls,
kwds,
clm_files_,
ssl=ssl,
n_jobs=MAX_CONN,
)
clm_files_ = [f for f in clm_files_ if f is not None]
try:
# open_mfdataset can run into too many open files error so we use merge
# https://docs.xarray.dev/en/stable/user-guide/io.html#reading-multi-file-datasets
clm = xr.merge(_open_dataset(f) for f in clm_files_full).astype("f4")
except ValueError:
_ = [f.unlink() for f in clm_files]
_ = [f.unlink() for f in clm_files_]
continue

has_nans, urls, kwds, clm_files = _check_nans(clm, urls, kwds, clm_files, long2abbr)
has_nans, urls, kwds, clm_files_ = _check_nans(clm, urls, kwds, clm_files_, long2abbr)
if has_nans:
clm = None
continue
Expand All @@ -454,45 +403,7 @@ def get_bygeom(
geometry: Polygon | MultiPolygon | tuple[float, float, float, float],
dates: tuple[str, str] | int | list[int],
crs: CRSTYPE = 4326,
variables: Iterable[
Literal[
"pr",
"rmax",
"rmin",
"sph",
"srad",
"th",
"tmmn",
"tmmx",
"vs",
"bi",
"fm100",
"fm1000",
"erc",
"etr",
"pet",
"vpd",
]
]
| Literal[
"pr",
"rmax",
"rmin",
"sph",
"srad",
"th",
"tmmn",
"tmmx",
"vs",
"bi",
"fm100",
"fm1000",
"erc",
"etr",
"pet",
"vpd",
]
| None = None,
variables: Iterable[VARS] | VARS | None = None,
snow: bool = False,
snow_params: dict[str, float] | None = None,
ssl: bool = True,
Expand Down Expand Up @@ -551,7 +462,7 @@ def get_bygeom(

urls, kwds = _gridded_urls(
_geometry.bounds, # pyright: ignore[reportGeneralTypeIssues]
gridmet.variables,
gridmet.variables, # pyright: ignore[reportArgumentType]
gridmet.date_iterator,
gridmet.long_names,
)
Expand Down

0 comments on commit 5c7cd9c

Please sign in to comment.