Skip to content

Commit

Permalink
[Bindings] Implement alloc + copy to local host when map is unavailab…
Browse files Browse the repository at this point in the history
  • Loading branch information
raikonenfnu authored Oct 19, 2023
1 parent 8b1af38 commit add9417
Show file tree
Hide file tree
Showing 4 changed files with 188 additions and 1 deletion.
83 changes: 83 additions & 0 deletions runtime/bindings/python/hal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,18 @@ static const char kHalDeviceQueueExecute[] =
signal_semaphores: Semaphores/Fence to signal.
)";

static const char kHalDeviceQueueCopy[] =
R"(Copy data from a source buffer to destination buffer.
Args:
source_buffer: `HalBuffer` that holds src data.
target_buffer: `HalBuffer` that will receive data.
wait_semaphores: `List[Tuple[HalSemaphore, int]]` of semaphore values or
a HalFence. The allocation will be made once these semaphores are
satisfied.
signal_semaphores: Semaphores/Fence to signal.
)";

static const char kHalFenceWait[] =
R"(Waits until the fence is signalled or errored.
Expand Down Expand Up @@ -524,6 +536,69 @@ void HalDevice::QueueExecute(py::handle command_buffers,
"executing command buffers");
}

void HalDevice::QueueCopy(HalBuffer& source_buffer, HalBuffer& target_buffer,
py::handle wait_semaphores,
py::handle signal_semaphores) {
iree_hal_semaphore_list_t wait_list;
iree_hal_semaphore_list_t signal_list;

// Wait list.
if (py::isinstance<HalFence>(wait_semaphores)) {
wait_list = iree_hal_fence_semaphore_list(
py::cast<HalFence*>(wait_semaphores)->raw_ptr());
} else {
size_t wait_count = py::len(wait_semaphores);
wait_list = {
wait_count,
/*semaphores=*/
static_cast<iree_hal_semaphore_t**>(
alloca(sizeof(iree_hal_semaphore_t*) * wait_count)),
/*payload_values=*/
static_cast<uint64_t*>(alloca(sizeof(uint64_t) * wait_count)),
};
for (size_t i = 0; i < wait_count; ++i) {
py::tuple pair = wait_semaphores[i];
wait_list.semaphores[i] = py::cast<HalSemaphore*>(pair[0])->raw_ptr();
wait_list.payload_values[i] = py::cast<uint64_t>(pair[1]);
}
}

// Signal list.
if (py::isinstance<HalFence>(signal_semaphores)) {
signal_list = iree_hal_fence_semaphore_list(
py::cast<HalFence*>(signal_semaphores)->raw_ptr());
} else {
size_t signal_count = py::len(signal_semaphores);
signal_list = {
signal_count,
/*semaphores=*/
static_cast<iree_hal_semaphore_t**>(
alloca(sizeof(iree_hal_semaphore_t*) * signal_count)),
/*payload_values=*/
static_cast<uint64_t*>(alloca(sizeof(uint64_t) * signal_count)),
};
for (size_t i = 0; i < signal_count; ++i) {
py::tuple pair = signal_semaphores[i];
signal_list.semaphores[i] = py::cast<HalSemaphore*>(pair[0])->raw_ptr();
signal_list.payload_values[i] = py::cast<uint64_t>(pair[1]);
}
}

// TODO: Accept params for src_offset and target_offset.
iree_device_size_t source_length =
iree_hal_buffer_byte_length(source_buffer.raw_ptr());
if (source_length != iree_hal_buffer_byte_length(target_buffer.raw_ptr())) {
throw std::invalid_argument(
"Source and target buffer length must match and it does not. Please "
"check allocations");
}
CheckApiStatus(iree_hal_device_queue_copy(
raw_ptr(), IREE_HAL_QUEUE_AFFINITY_ANY, wait_list,
signal_list, source_buffer.raw_ptr(), 0,
target_buffer.raw_ptr(), 0, source_length),
"Copying buffer on queue");
}

//------------------------------------------------------------------------------
// HalDriver
//------------------------------------------------------------------------------
Expand Down Expand Up @@ -861,6 +936,9 @@ void SetupHalBindings(nanobind::module_ m) {
.def("queue_execute", &HalDevice::QueueExecute,
py::arg("command_buffers"), py::arg("wait_semaphores"),
py::arg("signal_semaphores"), kHalDeviceQueueExecute)
.def("queue_copy", &HalDevice::QueueCopy, py::arg("source_buffer"),
py::arg("target_buffer"), py::arg("wait_semaphores"),
py::arg("signal_semaphores"), kHalDeviceQueueCopy)
.def("__repr__", [](HalDevice& self) {
auto id_sv = iree_hal_device_id(self.raw_ptr());
return std::string(id_sv.data, id_sv.size);
Expand Down Expand Up @@ -963,6 +1041,9 @@ void SetupHalBindings(nanobind::module_ m) {
py::class_<HalBuffer>(m, "HalBuffer")
.def("fill_zero", &HalBuffer::FillZero, py::arg("byte_offset"),
py::arg("byte_length"))
.def("byte_length", &HalBuffer::byte_length)
.def("memory_type", &HalBuffer::memory_type)
.def("allowed_usage", &HalBuffer::allowed_usage)
.def("create_view", &HalBuffer::CreateView, py::arg("shape"),
py::arg("element_size"), py::keep_alive<0, 1>())
.def("map", HalMappedMemory::CreateFromBuffer, py::keep_alive<0, 1>())
Expand Down Expand Up @@ -994,6 +1075,8 @@ void SetupHalBindings(nanobind::module_ m) {
py::arg("buffer"), py::arg("shape"), py::arg("element_type"));
hal_buffer_view
.def("map", HalMappedMemory::CreateFromBufferView, py::keep_alive<0, 1>())
.def("get_buffer", HalBuffer::CreateFromBufferView,
py::keep_alive<0, 1>())
.def_prop_ro("shape",
[](HalBufferView& self) {
iree_host_size_t rank =
Expand Down
11 changes: 11 additions & 0 deletions runtime/bindings/python/hal.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ class HalDevice : public ApiRefCounted<HalDevice, iree_hal_device_t> {
py::handle signal_semaphores);
void QueueExecute(py::handle command_buffers, py::handle wait_semaphores,
py::handle signal_semaphores);
void QueueCopy(HalBuffer& src_buffer, HalBuffer& dst_buffer,
py::handle wait_semaphores, py::handle signal_semaphores);
};

class HalDriver : public ApiRefCounted<HalDriver, iree_hal_driver_t> {
Expand Down Expand Up @@ -176,6 +178,10 @@ class HalBuffer : public ApiRefCounted<HalBuffer, iree_hal_buffer_t> {
return iree_hal_buffer_byte_length(raw_ptr());
}

int memory_type() const { return iree_hal_buffer_memory_type(raw_ptr()); }

int allowed_usage() const { return iree_hal_buffer_allowed_usage(raw_ptr()); }

void FillZero(iree_device_size_t byte_offset,
iree_device_size_t byte_length) {
CheckApiStatus(
Expand All @@ -197,6 +203,11 @@ class HalBuffer : public ApiRefCounted<HalBuffer, iree_hal_buffer_t> {
return HalBufferView::StealFromRawPtr(bv);
}

static HalBuffer CreateFromBufferView(HalBufferView& bv) {
return HalBuffer::BorrowFromRawPtr(
iree_hal_buffer_view_buffer(bv.raw_ptr()));
}

py::str Repr();
};

Expand Down
49 changes: 48 additions & 1 deletion runtime/bindings/python/iree/runtime/array_interop.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
HalElementType,
MappedMemory,
MemoryType,
HalFence,
)

__all__ = [
Expand Down Expand Up @@ -106,6 +107,20 @@ def to_host(self) -> np.ndarray:
self._transfer_to_host(False)
return self._host_array

def _is_mappable(self) -> bool:
buffer = self._buffer_view.get_buffer()
if (
buffer.memory_type() & int(MemoryType.HOST_VISIBLE)
!= MemoryType.HOST_VISIBLE
):
return False
if (
buffer.allowed_usage() & int(BufferUsage.MAPPING_SCOPED)
!= BufferUsage.MAPPING_SCOPED
):
return False
return True

def _transfer_to_host(self, implicit):
if self._host_array is not None:
return
Expand All @@ -114,7 +129,10 @@ def _transfer_to_host(self, implicit):
"DeviceArray cannot be implicitly transferred to the host: "
"if necessary, do an explicit transfer via .to_host()"
)
self._mapped_memory, self._host_array = self._map_to_host()
if self._is_mappable():
self._mapped_memory, self._host_array = self._map_to_host()
else:
self._host_array = self._copy_to_host()

def _map_to_host(self) -> Tuple[MappedMemory, np.ndarray]:
# TODO: When synchronization is enabled, need to block here.
Expand All @@ -129,6 +147,35 @@ def _map_to_host(self) -> Tuple[MappedMemory, np.ndarray]:
host_array = host_array.astype(self._override_dtype)
return mapped_memory, host_array

def _copy_to_host(self) -> np.ndarray:
# TODO: When synchronization is enabled, need to block here.
source_buffer = self._buffer_view.get_buffer()
host_buffer = self._device.allocator.allocate_buffer(
memory_type=(MemoryType.HOST_LOCAL | MemoryType.DEVICE_VISIBLE),
allowed_usage=(BufferUsage.TRANSFER_TARGET | BufferUsage.MAPPING_SCOPED),
allocation_size=source_buffer.byte_length(),
)
# Copy and wait for buffer to be copied from source buffer.
sem = self._device.create_semaphore(0)
self._device.queue_copy(
source_buffer,
host_buffer,
wait_semaphores=HalFence.create_at(sem, 0),
signal_semaphores=HalFence.create_at(sem, 1),
)
HalFence.create_at(sem, 1).wait()
# Map and reformat buffer as np.array.
raw_dtype = self._get_raw_dtype()
mapped_memory = host_buffer.map()
host_array = mapped_memory.asarray(self._buffer_view.shape, raw_dtype)
# Detect if we need to force an explicit conversion. This happens when
# we were requested to pretend that the array is in a specific dtype,
# even if that is not representable on the device. You guessed it:
# this is to support bools.
if self._override_dtype is not None and self._override_dtype != raw_dtype:
host_array = host_array.astype(self._override_dtype)
return host_array

def _get_raw_dtype(self):
return HalElementType.map_to_dtype(self._buffer_view.element_type)

Expand Down
46 changes: 46 additions & 0 deletions runtime/bindings/python/tests/hal_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,52 @@ def testFenceExtend(self):
fence.extend(iree.runtime.HalFence.create_at(sem2, 2))
self.assertEqual(fence.timepoint_count, 2)

def testRoundTripQueueCopy(self):
original_ary = np.zeros([3, 4], dtype=np.int32) + 2
source_bv = self.allocator.allocate_buffer_copy(
memory_type=iree.runtime.MemoryType.DEVICE_LOCAL,
allowed_usage=iree.runtime.BufferUsage.DEFAULT,
device=self.device,
buffer=original_ary,
element_type=iree.runtime.HalElementType.SINT_32,
)
source_buffer = source_bv.get_buffer()
target_buffer = self.allocator.allocate_buffer(
memory_type=iree.runtime.MemoryType.DEVICE_LOCAL,
allowed_usage=iree.runtime.BufferUsage.DEFAULT,
allocation_size=source_buffer.byte_length(),
)
sem = self.device.create_semaphore(0)
self.device.queue_copy(
source_buffer,
target_buffer,
wait_semaphores=iree.runtime.HalFence.create_at(sem, 0),
signal_semaphores=iree.runtime.HalFence.create_at(sem, 1),
)
iree.runtime.HalFence.create_at(sem, 1).wait()
copy_ary = target_buffer.map().asarray(original_ary.shape, original_ary.dtype)
np.testing.assert_array_equal(original_ary, copy_ary)

def testDifferentSizeQueueCopy(self):
source_buffer = self.allocator.allocate_buffer(
memory_type=iree.runtime.MemoryType.DEVICE_LOCAL,
allowed_usage=iree.runtime.BufferUsage.DEFAULT,
allocation_size=12,
)
target_buffer = self.allocator.allocate_buffer(
memory_type=iree.runtime.MemoryType.DEVICE_LOCAL,
allowed_usage=iree.runtime.BufferUsage.DEFAULT,
allocation_size=13,
)
sem = self.device.create_semaphore(0)
with self.assertRaisesRegex(ValueError, "length must match"):
self.device.queue_copy(
source_buffer,
target_buffer,
wait_semaphores=iree.runtime.HalFence.create_at(sem, 0),
signal_semaphores=iree.runtime.HalFence.create_at(sem, 1),
)

def testCommandBufferStartsByDefault(self):
cb = iree.runtime.HalCommandBuffer(self.device)
with self.assertRaisesRegex(RuntimeError, "FAILED_PRECONDITION"):
Expand Down

0 comments on commit add9417

Please sign in to comment.