diff --git a/timdex_dataset_api/dataset.py b/timdex_dataset_api/dataset.py index aa27285..50ebdad 100644 --- a/timdex_dataset_api/dataset.py +++ b/timdex_dataset_api/dataset.py @@ -54,6 +54,7 @@ def __init__(self, location: str | list[str]): self.dataset: ds.Dataset = None # type: ignore[assignment] self.schema = TIMDEX_DATASET_SCHEMA self.partition_columns = TIMDEX_DATASET_PARTITION_COLUMNS + self._written_files: list[ds.WrittenFile] = None # type: ignore[assignment] @classmethod def load(cls, location: str) -> "TIMDEXDataset": @@ -197,6 +198,7 @@ def write( - use_threads: boolean if threads should be used for writing """ start_time = time.perf_counter() + self._written_files = [] if isinstance(self.source, list): raise TypeError( @@ -209,14 +211,13 @@ def write( batch_size=batch_size, ) - written_files = [] ds.write_dataset( record_batches_iter, base_dir=self.source, basename_template="%s-{i}.parquet" % (str(uuid.uuid4())), # noqa: UP031 existing_data_behavior="delete_matching", filesystem=self.filesystem, - file_visitor=lambda written_file: written_files.append(written_file), + file_visitor=lambda written_file: self._written_files.append(written_file), # type: ignore[arg-type] format="parquet", max_open_files=500, max_rows_per_file=MAX_ROWS_PER_FILE, @@ -227,8 +228,8 @@ def write( use_threads=use_threads, ) - logger.info(f"write elapsed: {round(time.perf_counter()-start_time, 2)}s") - return written_files # type: ignore[return-value] + self.log_write_statistics(start_time) + return self._written_files # type: ignore[return-value] def get_dataset_record_batches( self, @@ -266,3 +267,24 @@ def get_dataset_record_batches( f"elapsed: {round(time.perf_counter()-batch_start_time, 6)}s" ) yield batch + + def log_write_statistics(self, start_time: float) -> None: + """Parse written files from write and log statistics.""" + total_time = round(time.perf_counter() - start_time, 2) + total_files = len(self._written_files) + total_rows = sum( + [ + wf.metadata.num_rows # type: ignore[attr-defined] + for wf in self._written_files + ] + ) + total_size = sum( + [wf.size for wf in self._written_files] # type: ignore[attr-defined] + ) + logger.info( + f"Dataset write complete - elapsed: " + f"{total_time}s, " + f"total files: {total_files}, " + f"total rows: {total_rows}, " + f"total size: {total_size}" + )