From c5300f4bd6a359df1b4d6a704e67e65c7a7b9143 Mon Sep 17 00:00:00 2001 From: Kyle Barron Date: Mon, 30 Dec 2024 07:12:55 -0800 Subject: [PATCH] Python: Return boxes as arrow from RTree (#89) * Return boxes as arrow from RTree * update comment --- benches/rtree.rs | 6 +-- python/python/geoindex_rs/rtree.pyi | 7 +++ python/src/lib.rs | 2 + python/src/rtree/builder.rs | 67 ++++++++++++++++------------- python/tests/test_rtree.py | 24 ++++++++++- 5 files changed, 72 insertions(+), 34 deletions(-) diff --git a/benches/rtree.rs b/benches/rtree.rs index 5fa52bb..e8c8061 100644 --- a/benches/rtree.rs +++ b/benches/rtree.rs @@ -5,7 +5,7 @@ use geo_index::rtree::util::f64_box_to_f32; use geo_index::rtree::{RTree, RTreeBuilder, RTreeIndex}; use geo_index::IndexableNum; use rstar::primitives::{GeomWithData, Rectangle}; -use rstar::{RTree, AABB}; +use rstar::AABB; use std::fs::read; fn load_data() -> Vec { @@ -52,8 +52,8 @@ fn construct_rtree_f32_with_cast(boxes_buf: &[f64]) -> RTree { fn construct_rstar( rect_vec: Vec, usize>>, -) -> RTree, usize>> { - RTree::bulk_load(rect_vec) +) -> rstar::RTree, usize>> { + rstar::RTree::bulk_load(rect_vec) } pub fn criterion_benchmark(c: &mut Criterion) { diff --git a/python/python/geoindex_rs/rtree.pyi b/python/python/geoindex_rs/rtree.pyi index 950578f..709aab5 100644 --- a/python/python/geoindex_rs/rtree.pyi +++ b/python/python/geoindex_rs/rtree.pyi @@ -59,3 +59,10 @@ class RTree(Buffer): def num_levels(self) -> int: ... @property def num_bytes(self) -> int: ... + def boxes_at_level(self, level: int) -> Array: + """ + + This is shared as a zero-copy view from Rust. Note that it will keep the entire + index memory alive until the returned array is garbage collected. + + """ diff --git a/python/src/lib.rs b/python/src/lib.rs index 88f93dd..af835a3 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -1,3 +1,5 @@ +#![deny(clippy::undocumented_unsafe_blocks)] + mod coord_type; mod kdtree; mod rtree; diff --git a/python/src/rtree/builder.rs b/python/src/rtree/builder.rs index e144fb0..29f8f77 100644 --- a/python/src/rtree/builder.rs +++ b/python/src/rtree/builder.rs @@ -1,18 +1,21 @@ use arrow_array::builder::UInt32Builder; use arrow_array::cast::AsArray; use arrow_array::types::{Float32Type, Float64Type}; +use arrow_array::{ArrayRef, ArrowPrimitiveType, PrimitiveArray}; +use arrow_buffer::alloc::Allocation; +use arrow_buffer::{ArrowNativeType, Buffer, ScalarBuffer}; use arrow_cast::cast; use arrow_schema::DataType; use geo_index::rtree::sort::{HilbertSort, STRSort}; use geo_index::rtree::util::f64_box_to_f32; use geo_index::rtree::{RTree, RTreeBuilder, RTreeIndex, DEFAULT_RTREE_NODE_SIZE}; -use numpy::{PyArray1, PyArrayMethods}; use pyo3::exceptions::{PyIndexError, PyValueError}; use pyo3::ffi; use pyo3::prelude::*; use pyo3::pybacked::PyBackedStr; use pyo3_arrow::PyArray; use std::os::raw::c_int; +use std::ptr::NonNull; use std::sync::Arc; use crate::coord_type::CoordType; @@ -312,17 +315,17 @@ impl PyRTreeBuilder { .take() .ok_or(PyValueError::new_err("Cannot call finish multiple times."))?; let out = match (inner, method) { - (PyRTreeBuilderInner::Float32(tree), RTreeMethod::Hilbert) => { - PyRTree(PyRTreeInner::Float32(tree.finish::())) - } + (PyRTreeBuilderInner::Float32(tree), RTreeMethod::Hilbert) => PyRTree( + PyRTreeInner::Float32(Arc::new(tree.finish::())), + ), (PyRTreeBuilderInner::Float32(tree), RTreeMethod::STR) => { - PyRTree(PyRTreeInner::Float32(tree.finish::())) - } - (PyRTreeBuilderInner::Float64(tree), RTreeMethod::Hilbert) => { - PyRTree(PyRTreeInner::Float64(tree.finish::())) + PyRTree(PyRTreeInner::Float32(Arc::new(tree.finish::()))) } + (PyRTreeBuilderInner::Float64(tree), RTreeMethod::Hilbert) => PyRTree( + PyRTreeInner::Float64(Arc::new(tree.finish::())), + ), (PyRTreeBuilderInner::Float64(tree), RTreeMethod::STR) => { - PyRTree(PyRTreeInner::Float64(tree.finish::())) + PyRTree(PyRTreeInner::Float64(Arc::new(tree.finish::()))) } }; Ok(out) @@ -330,8 +333,8 @@ impl PyRTreeBuilder { } enum PyRTreeInner { - Float32(RTree), - Float64(RTree), + Float32(Arc>), + Float64(Arc>), } #[pyclass(name = "RTree")] @@ -367,16 +370,13 @@ impl PyRTreeInner { } fn num_bytes(&self) -> usize { - match self { - Self::Float32(index) => index.as_ref().len(), - Self::Float64(index) => index.as_ref().len(), - } + self.buffer().len() } fn buffer(&self) -> &[u8] { match self { - Self::Float32(index) => index.as_ref(), - Self::Float64(index) => index.as_ref(), + Self::Float32(index) => index.as_ref().as_ref(), + Self::Float64(index) => index.as_ref().as_ref(), } } @@ -386,23 +386,15 @@ impl PyRTreeInner { let boxes = index .boxes_at_level(level) .map_err(|err| PyIndexError::new_err(err.to_string()))?; - let array = PyArray1::from_slice(py, boxes); - Ok(array - .reshape([boxes.len() / 4, 4])? - .into_pyobject(py)? - .into_any() - .unbind()) + PyArray::from_array_ref(boxes_at_level::(boxes, index.clone())) + .to_arro3(py) } Self::Float64(index) => { let boxes = index .boxes_at_level(level) .map_err(|err| PyIndexError::new_err(err.to_string()))?; - let array = PyArray1::from_slice(py, boxes); - Ok(array - .reshape([boxes.len() / 4, 4])? - .into_pyobject(py)? - .into_any() - .unbind()) + PyArray::from_array_ref(boxes_at_level::(boxes, index.clone())) + .to_arro3(py) } } } @@ -481,3 +473,20 @@ impl PyRTree { self.0.boxes_at_level(py, level) } } + +fn boxes_at_level( + boxes: &[T::Native], + owner: Arc, +) -> ArrayRef { + let ptr = NonNull::new(boxes.as_ptr() as *mut _).unwrap(); + let len = boxes.len(); + let bytes_len = len * T::Native::get_byte_width(); + + // Safety: + // ptr is a non-null pointer owned by the RTree, which is passed in as the Allocation + let buffer = unsafe { Buffer::from_custom_allocation(ptr, bytes_len, owner) }; + Arc::new(PrimitiveArray::::new( + ScalarBuffer::new(buffer, 0, len), + None, + )) +} diff --git a/python/tests/test_rtree.py b/python/tests/test_rtree.py index 604e4ea..17d15c6 100644 --- a/python/tests/test_rtree.py +++ b/python/tests/test_rtree.py @@ -2,16 +2,36 @@ from geoindex_rs import rtree -def test_rtree(): +def create_index(): builder = rtree.RTreeBuilder(5) min_x = np.arange(5) min_y = np.arange(5) max_x = np.arange(5, 10) max_y = np.arange(5, 10) builder.add(min_x, min_y, max_x, max_y) - tree = builder.finish() + return builder.finish() + +def test_search(): + tree = create_index() result = rtree.search(tree, 0.5, 0.5, 1.5, 1.5) assert len(result) == 2 assert result[0].as_py() == 0 assert result[1].as_py() == 1 + + +def test_rtree(): + builder = rtree.RTreeBuilder(5) + min_x = np.arange(5) + min_y = np.arange(5) + max_x = np.arange(5, 10) + max_y = np.arange(5, 10) + builder.add(min_x, min_y, max_x, max_y) + tree = builder.finish() + + boxes = tree.boxes_at_level(0) + np_arr = np.asarray(boxes).reshape(-1, 4) + assert np.all(min_x == np_arr[:, 0]) + assert np.all(min_y == np_arr[:, 1]) + assert np.all(max_x == np_arr[:, 2]) + assert np.all(max_y == np_arr[:, 3])