Commit 79134da
chore: cleanup get_all_netzvertraege; split into stream and list part (#34)

hf-kklein authored Mar 12, 2024
1 parent 8073b71 commit 79134da
Showing 1 changed file with 82 additions and 54 deletions.
136 changes: 82 additions & 54 deletions src/tmdsclient/client/tmdsclient.py
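The commit splits the dual-mode get_all_netzvertraege into two private helpers while keeping the overloaded public API. A minimal usage sketch of that public API, assuming an already-constructed client; the import path is inferred from the file path above and the surrounding setup is not part of this commit:

from tmdsclient.client.tmdsclient import TmdsClient  # path inferred from the diff header


async def main(client: TmdsClient) -> None:
    # as_generator=False: download everything first, then return one list
    all_nvs = await client.get_all_netzvertraege(as_generator=False)
    print(f"got {len(all_nvs)} Netzvertraege")
    # as_generator=True: the awaited call returns an AsyncGenerator that
    # yields Netzvertraege chunk by chunk while the download is still running
    async for nv in await client.get_all_netzvertraege(as_generator=True):
        print(nv)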
@@ -19,6 +19,16 @@
 _DEFAULT_CHUNK_SIZE = 100
 
 
+def _log_chunk_success(chunk_size: int, total_size: int, chunk_idx: int, chunk_length: int) -> None:
+    _logger.debug(
+        "Downloaded Netzvertrag (%i/%i) / chunk %i/%i",
+        chunk_size * chunk_idx + chunk_length,
+        total_size,
+        chunk_idx + 1,
+        total_size // chunk_size + 1,
+    )
+
+
 class TmdsClient:
     """
     an async wrapper around the TMDS API
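For intuition, the helper's progress arithmetic with illustrative values (not taken from the commit): with chunk_size=100, total_size=250 and a full second chunk, it logs "Downloaded Netzvertrag (200/250) / chunk 2/3".

# illustrative values only; the assertions mirror _log_chunk_success's arguments
chunk_size, total_size, chunk_idx, chunk_length = 100, 250, 1, 100
assert chunk_size * chunk_idx + chunk_length == 200  # items downloaded so far
assert chunk_idx + 1 == 2  # 1-based index of the current chunk
assert total_size // chunk_size + 1 == 3  # chunk count; note this over-counts by
# one when total_size is an exact multiple of chunk_size (200 // 100 + 1 == 3)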
@@ -131,65 +141,59 @@ async def get_all_netzvertrag_ids(self) -> list[uuid.UUID]:
         _logger.info("There are %i Netzvertraege on server side", len(result))
         return result
 
-    @overload
-    async def get_all_netzvertraege(
-        self, as_generator: Literal[False], chunk_size: int = _DEFAULT_CHUNK_SIZE
-    ) -> list[Netzvertrag]: ...
-
-    @overload
-    async def get_all_netzvertraege(
-        self, as_generator: Literal[True], chunk_size: int = _DEFAULT_CHUNK_SIZE
-    ) -> AsyncGenerator[Netzvertrag, None]: ...
-
-    async def get_all_netzvertraege(
-        self, as_generator: bool, chunk_size: int = _DEFAULT_CHUNK_SIZE
-    ) -> list[Netzvertrag] | AsyncGenerator[Netzvertrag, None]:
+    def _get_all_netzvertraege_stream(
+        self, all_ids: list[uuid.UUID], chunk_size: int
+    ) -> AsyncGenerator[Netzvertrag, None]:
         """
         download all netzverträge from TMDS
         """
-        all_ids = await self.get_all_netzvertrag_ids()
-
-        def _log_chunk_success(chunk_idx: int, chunk_length: int) -> None:
-            _logger.debug(
-                "Downloaded Netzvertrag (%i/%i) / chunk %i/%i",
-                chunk_size * chunk_idx + chunk_length,
-                len(all_ids),
-                chunk_idx + 1,
-                len(all_ids) // chunk_size + 1,
-            )
-
-        if as_generator:
-
-            async def generator():
-                successfully_downloaded = 0
-                for chunk_index, id_chunk in enumerate(chunked(all_ids, chunk_size)):
-                    get_tasks = [self.get_netzvertrag_by_id(nv_id) for nv_id in id_chunk]
-                    try:
-                        result_chunk = await asyncio.gather(*get_tasks)
-                        for nv in result_chunk:
-                            yield nv
-                    except (asyncio.TimeoutError, ClientResponseError) as chunk_error:
-                        if isinstance(chunk_error, ClientResponseError) and chunk_error.status != 500:
-                            raise
-                        _logger.warning(
-                            "Failed to download chunk %i; Retrying one by one; %s", chunk_index, str(chunk_error)
-                        )
-                        for nv_id in id_chunk:
-                            # This is a bit dumb; If we had aiostream here, we could create multiple requests at once
-                            # and yield from a merged stream. This might be a future improvement... For now it's ok.
-                            # With a moderate sized chunk_size it should be fine as there are not that many 500 errors.
-                            try:
-                                yield await self.get_netzvertrag_by_id(nv_id)
-                                successfully_downloaded += 1
-                            except (asyncio.TimeoutError, ClientResponseError) as single_error:
-                                if isinstance(single_error, ClientResponseError) and single_error.status != 500:
-                                    raise
-                                _logger.exception("Failed to download Netzvertrag %s; skipping", nv_id)
-                                continue
-                _logger.info("Successfully downloaded %i Netzvertraege", successfully_downloaded)
-
-            return generator()  # This needs to be called to return an AsyncGenerator
+        async def generator():
+            successfully_downloaded = 0
+            for chunk_index, id_chunk in enumerate(chunked(all_ids, chunk_size)):
+                get_tasks = [self.get_netzvertrag_by_id(nv_id) for nv_id in id_chunk]
+                try:
+                    _result_chunk = await asyncio.gather(*get_tasks)
+                    for nv in _result_chunk:
+                        yield nv
+                    successfully_downloaded += len(_result_chunk)
+                    _log_chunk_success(
+                        chunk_size=chunk_size,
+                        total_size=len(all_ids),
+                        chunk_idx=chunk_index,
+                        chunk_length=len(_result_chunk),
+                    )
+                except (asyncio.TimeoutError, ClientResponseError) as chunk_error:
+                    if isinstance(chunk_error, ClientResponseError) and chunk_error.status != 500:
+                        raise
+                    _logger.warning(
+                        "Failed to download chunk %i; Retrying one by one; %s", chunk_index, str(chunk_error)
+                    )
+                    success_in_this_chunk = 0
+                    for _nv_id in id_chunk:
+                        # This is a bit dumb; If we had aiostream here, we could create multiple requests at once
+                        # and yield from a merged stream. This might be a future improvement... For now it's ok.
+                        # With a moderate sized chunk_size it should be fine as there are not that many 500 errors.
+                        try:
+                            yield await self.get_netzvertrag_by_id(_nv_id)
+                            successfully_downloaded += 1
+                            success_in_this_chunk += 1
+                        except (asyncio.TimeoutError, ClientResponseError) as single_error:
+                            if isinstance(single_error, ClientResponseError) and single_error.status != 500:
+                                raise
+                            _logger.exception("Failed to download Netzvertrag %s; skipping", _nv_id)
+                            continue
+                    _log_chunk_success(
+                        chunk_size=chunk_size,
+                        total_size=len(all_ids),
+                        chunk_idx=chunk_index,
+                        chunk_length=success_in_this_chunk,
+                    )
+            _logger.info("Successfully downloaded %i Netzvertraege", successfully_downloaded)
+
+        return generator()  # This needs to be called to return an AsyncGenerator
 
+    async def _get_all_netzvertraege_list(self, all_ids: list[uuid.UUID], chunk_size: int) -> list[Netzvertrag]:
         result: list[Netzvertrag] = []
         for chunk_index, id_chunk in enumerate(chunked(all_ids, chunk_size)):
             # we probably need to account for the fact that this leads to HTTP 500 errors, let's see
@@ -213,11 +217,35 @@ async def generator():
                 result_chunk.append(nv)
             if any(x is None for x in result_chunk):
                 raise ValueError("This must not happen.")
-            _log_chunk_success(chunk_index, len(result_chunk))
+            _log_chunk_success(
+                chunk_size=chunk_size, chunk_idx=chunk_index, total_size=len(all_ids), chunk_length=len(result_chunk)
+            )
             result.extend(result_chunk)  # type:ignore[arg-type]
         _logger.info("Successfully downloaded %i Netzvertraege", len(result))
         return result
 
+    @overload
+    async def get_all_netzvertraege(
+        self, as_generator: Literal[False], chunk_size: int = _DEFAULT_CHUNK_SIZE
+    ) -> list[Netzvertrag]: ...
+
+    @overload
+    async def get_all_netzvertraege(
+        self, as_generator: Literal[True], chunk_size: int = _DEFAULT_CHUNK_SIZE
+    ) -> AsyncGenerator[Netzvertrag, None]: ...
+
+    async def get_all_netzvertraege(
+        self, as_generator: bool, chunk_size: int = _DEFAULT_CHUNK_SIZE
+    ) -> list[Netzvertrag] | AsyncGenerator[Netzvertrag, None]:
+        """
+        download all netzverträge from TMDS
+        """
+        all_ids = await self.get_all_netzvertrag_ids()
+
+        if as_generator:
+            return self._get_all_netzvertraege_stream(all_ids, chunk_size)
+        return await self._get_all_netzvertraege_list(all_ids, chunk_size)
+
     async def update_netzvertrag(
         self, netzvertrag_id: uuid.UUID, changes: list[Callable[[Netzvertrag], None]]
     ) -> Netzvertrag:
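The download strategy shared by both helpers, gather a whole chunk concurrently and fall back to sequential single requests only when the chunk fails, reduces to the standalone sketch below. fetch() and the integer ids are placeholders; the broad Exception handling stands in for the commit's narrower TimeoutError/HTTP-500 checks, and chunked is presumed to be more_itertools.chunked as the client code suggests:

import asyncio
from typing import AsyncGenerator

from more_itertools import chunked


async def fetch(item_id: int) -> int:
    return item_id  # placeholder for one HTTP request


async def stream_all(ids: list[int], chunk_size: int) -> AsyncGenerator[int, None]:
    for id_chunk in chunked(ids, chunk_size):
        try:
            # happy path: one concurrent burst per chunk
            for result in await asyncio.gather(*(fetch(i) for i in id_chunk)):
                yield result
        except Exception:  # the commit only catches TimeoutError and retriable 500s
            # fallback: the same chunk again, one request at a time
            for i in id_chunk:
                try:
                    yield await fetch(i)
                except Exception:
                    continue  # skip items that keep failing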
