Skip to content

Commit

Permalink
Save written files to TIMDEXDataset object and log
Browse files Browse the repository at this point in the history
  • Loading branch information
ghukill committed Dec 5, 2024
1 parent d30c9e2 commit d7a1b29
Showing 1 changed file with 26 additions and 4 deletions.
30 changes: 26 additions & 4 deletions timdex_dataset_api/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def __init__(self, location: str | list[str]):
self.dataset: ds.Dataset = None # type: ignore[assignment]
self.schema = TIMDEX_DATASET_SCHEMA
self.partition_columns = TIMDEX_DATASET_PARTITION_COLUMNS
self._written_files: list[ds.WrittenFile] = None # type: ignore[assignment]

@classmethod
def load(cls, location: str) -> "TIMDEXDataset":
Expand Down Expand Up @@ -197,6 +198,7 @@ def write(
- use_threads: boolean if threads should be used for writing
"""
start_time = time.perf_counter()
self._written_files = []

if isinstance(self.source, list):
raise TypeError(
Expand All @@ -209,14 +211,13 @@ def write(
batch_size=batch_size,
)

written_files = []
ds.write_dataset(
record_batches_iter,
base_dir=self.source,
basename_template="%s-{i}.parquet" % (str(uuid.uuid4())), # noqa: UP031
existing_data_behavior="delete_matching",
filesystem=self.filesystem,
file_visitor=lambda written_file: written_files.append(written_file),
file_visitor=lambda written_file: self._written_files.append(written_file), # type: ignore[arg-type]
format="parquet",
max_open_files=500,
max_rows_per_file=MAX_ROWS_PER_FILE,
Expand All @@ -227,8 +228,8 @@ def write(
use_threads=use_threads,
)

logger.info(f"write elapsed: {round(time.perf_counter()-start_time, 2)}s")
return written_files # type: ignore[return-value]
self.log_write_statistics(start_time)
return self._written_files # type: ignore[return-value]

def get_dataset_record_batches(
self,
Expand Down Expand Up @@ -266,3 +267,24 @@ def get_dataset_record_batches(
f"elapsed: {round(time.perf_counter()-batch_start_time, 6)}s"
)
yield batch

def log_write_statistics(self, start_time: float) -> None:
    """Aggregate statistics from the most recent write and log a summary.

    Reads the pyarrow WrittenFile objects collected in self._written_files
    by write()'s file_visitor callback and logs elapsed time, file count,
    total row count, and total byte size.

    Args:
        start_time: time.perf_counter() value captured at the start of
            write(); used to compute total elapsed seconds.
    """
    total_time = round(time.perf_counter() - start_time, 2)
    total_files = len(self._written_files)
    # Sum via generator expressions (no intermediate lists).  WrittenFile
    # exposes parquet metadata and byte size at runtime, but pyarrow's type
    # stubs do not declare these attributes, hence the ignores.
    total_rows = sum(
        wf.metadata.num_rows  # type: ignore[attr-defined]
        for wf in self._written_files
    )
    total_size = sum(
        wf.size  # type: ignore[attr-defined]
        for wf in self._written_files
    )
    logger.info(
        f"Dataset write complete - elapsed: "
        f"{total_time}s, "
        f"total files: {total_files}, "
        f"total rows: {total_rows}, "
        f"total size: {total_size}"
    )

0 comments on commit d7a1b29

Please sign in to comment.