Skip to content

Commit

Permalink
Usa callable signatures and separate output references
Browse files Browse the repository at this point in the history
Signed-off-by: Joaquin Anton Guirao <janton@nvidia.com>
  • Loading branch information
jantonguirao committed Jan 9, 2025
1 parent 692d107 commit 428d776
Show file tree
Hide file tree
Showing 4 changed files with 799 additions and 334 deletions.
25 changes: 18 additions & 7 deletions dali/python/nvidia/dali/external_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,27 @@
)


def _get_shape(data):
if hasattr(data, "shape"):
return data.shape() if callable(data.shape) else data.shape
elif hasattr(data, "__array_interface__"):
return data.__array_interface__["shape"]
elif hasattr(data, "__cuda_array_interface__"):
return data.__cuda_array_interface__["shape"]
elif hasattr(data, "__array__"):
return data.__array__().shape
else:
raise RuntimeError(f"Don't know how to extract the shape out of {type(data)}")


def _get_batch_shape(data):
if isinstance(data, (list, tuple, _tensors.TensorListCPU, _tensors.TensorListGPU)):
if len(data) == 0:
return [], True
if callable(data[0].shape):
return [x.shape() for x in data], False
else:
return [x.shape for x in data], False
return [_get_shape(x) for x in data], False
else:
shape = data.shape
if callable(shape):
shape = data.shape()
shape = _get_shape(data)
return [shape[1:]] * shape[0], True


Expand Down Expand Up @@ -68,6 +77,8 @@ def to_numpy(x):
return x.asnumpy()
elif _types._is_torch_tensor(x):
return x.numpy()
elif hasattr(x, "__array__"):
return x.__array__()
else:
return x

Expand All @@ -79,7 +90,7 @@ def to_numpy(x):
if layout is not None:
_check_data_batch(data, batch_size, layout)
data = type(data)(data, layout)
elif isinstance(data, list):
elif isinstance(data, (list, tuple)):
inputs = []
checked = False
for datum in data:
Expand Down
7 changes: 7 additions & 0 deletions dali/python/nvidia/dali/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,12 @@ def is_restored_from_checkpoint(self):
"""If True, this pipeline was restored from checkpoint."""
return self._is_restored_from_checkpoint

@property
def num_outputs(self):
"""Number of pipeline outputs."""
self.build()
return self._num_outputs

def output_dtype(self) -> list:
"""Data types expected at the outputs."""
self.build()
Expand Down Expand Up @@ -854,6 +860,7 @@ def contains_nested_datanode(nested):
self._require_no_foreign_ops("The pipeline does not support checkpointing")

self._graph_outputs = outputs
self._num_outputs = len(self._graph_outputs)
self._setup_input_callbacks()
self._disable_pruned_external_source_instances()
self._py_graph_built = True
Expand Down
Loading

0 comments on commit 428d776

Please sign in to comment.