Skip to content

Commit

Permalink
Python: Return boxes as arrow from RTree (#89)
Browse files Browse the repository at this point in the history
* Return boxes as arrow from RTree

* update comment
  • Loading branch information
kylebarron authored Dec 30, 2024
1 parent e5d7f4e commit c5300f4
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 34 deletions.
6 changes: 3 additions & 3 deletions benches/rtree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use geo_index::rtree::util::f64_box_to_f32;
use geo_index::rtree::{RTree, RTreeBuilder, RTreeIndex};
use geo_index::IndexableNum;
use rstar::primitives::{GeomWithData, Rectangle};
use rstar::{RTree, AABB};
use rstar::AABB;
use std::fs::read;

fn load_data() -> Vec<f64> {
Expand Down Expand Up @@ -52,8 +52,8 @@ fn construct_rtree_f32_with_cast(boxes_buf: &[f64]) -> RTree<f32> {

fn construct_rstar(
rect_vec: Vec<GeomWithData<Rectangle<(f64, f64)>, usize>>,
) -> RTree<GeomWithData<Rectangle<(f64, f64)>, usize>> {
RTree::bulk_load(rect_vec)
) -> rstar::RTree<GeomWithData<Rectangle<(f64, f64)>, usize>> {
rstar::RTree::bulk_load(rect_vec)
}

pub fn criterion_benchmark(c: &mut Criterion) {
Expand Down
7 changes: 7 additions & 0 deletions python/python/geoindex_rs/rtree.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,10 @@ class RTree(Buffer):
def num_levels(self) -> int: ...
@property
def num_bytes(self) -> int: ...
def boxes_at_level(self, level: int) -> Array:
"""
This is shared as a zero-copy view from Rust. Note that it will keep the entire
index memory alive until the returned array is garbage collected.
"""
2 changes: 2 additions & 0 deletions python/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#![deny(clippy::undocumented_unsafe_blocks)]

mod coord_type;
mod kdtree;
mod rtree;
Expand Down
67 changes: 38 additions & 29 deletions python/src/rtree/builder.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
use arrow_array::builder::UInt32Builder;
use arrow_array::cast::AsArray;
use arrow_array::types::{Float32Type, Float64Type};
use arrow_array::{ArrayRef, ArrowPrimitiveType, PrimitiveArray};
use arrow_buffer::alloc::Allocation;
use arrow_buffer::{ArrowNativeType, Buffer, ScalarBuffer};
use arrow_cast::cast;
use arrow_schema::DataType;
use geo_index::rtree::sort::{HilbertSort, STRSort};
use geo_index::rtree::util::f64_box_to_f32;
use geo_index::rtree::{RTree, RTreeBuilder, RTreeIndex, DEFAULT_RTREE_NODE_SIZE};
use numpy::{PyArray1, PyArrayMethods};
use pyo3::exceptions::{PyIndexError, PyValueError};
use pyo3::ffi;
use pyo3::prelude::*;
use pyo3::pybacked::PyBackedStr;
use pyo3_arrow::PyArray;
use std::os::raw::c_int;
use std::ptr::NonNull;
use std::sync::Arc;

use crate::coord_type::CoordType;
Expand Down Expand Up @@ -312,26 +315,26 @@ impl PyRTreeBuilder {
.take()
.ok_or(PyValueError::new_err("Cannot call finish multiple times."))?;
let out = match (inner, method) {
(PyRTreeBuilderInner::Float32(tree), RTreeMethod::Hilbert) => {
PyRTree(PyRTreeInner::Float32(tree.finish::<HilbertSort>()))
}
(PyRTreeBuilderInner::Float32(tree), RTreeMethod::Hilbert) => PyRTree(
PyRTreeInner::Float32(Arc::new(tree.finish::<HilbertSort>())),
),
(PyRTreeBuilderInner::Float32(tree), RTreeMethod::STR) => {
PyRTree(PyRTreeInner::Float32(tree.finish::<STRSort>()))
}
(PyRTreeBuilderInner::Float64(tree), RTreeMethod::Hilbert) => {
PyRTree(PyRTreeInner::Float64(tree.finish::<HilbertSort>()))
PyRTree(PyRTreeInner::Float32(Arc::new(tree.finish::<STRSort>())))
}
(PyRTreeBuilderInner::Float64(tree), RTreeMethod::Hilbert) => PyRTree(
PyRTreeInner::Float64(Arc::new(tree.finish::<HilbertSort>())),
),
(PyRTreeBuilderInner::Float64(tree), RTreeMethod::STR) => {
PyRTree(PyRTreeInner::Float64(tree.finish::<STRSort>()))
PyRTree(PyRTreeInner::Float64(Arc::new(tree.finish::<STRSort>())))
}
};
Ok(out)
}
}

enum PyRTreeInner {
Float32(RTree<f32>),
Float64(RTree<f64>),
Float32(Arc<RTree<f32>>),
Float64(Arc<RTree<f64>>),
}

#[pyclass(name = "RTree")]
Expand Down Expand Up @@ -367,16 +370,13 @@ impl PyRTreeInner {
}

fn num_bytes(&self) -> usize {
match self {
Self::Float32(index) => index.as_ref().len(),
Self::Float64(index) => index.as_ref().len(),
}
self.buffer().len()
}

fn buffer(&self) -> &[u8] {
match self {
Self::Float32(index) => index.as_ref(),
Self::Float64(index) => index.as_ref(),
Self::Float32(index) => index.as_ref().as_ref(),
Self::Float64(index) => index.as_ref().as_ref(),
}
}

Expand All @@ -386,23 +386,15 @@ impl PyRTreeInner {
let boxes = index
.boxes_at_level(level)
.map_err(|err| PyIndexError::new_err(err.to_string()))?;
let array = PyArray1::from_slice(py, boxes);
Ok(array
.reshape([boxes.len() / 4, 4])?
.into_pyobject(py)?
.into_any()
.unbind())
PyArray::from_array_ref(boxes_at_level::<Float32Type>(boxes, index.clone()))
.to_arro3(py)
}
Self::Float64(index) => {
let boxes = index
.boxes_at_level(level)
.map_err(|err| PyIndexError::new_err(err.to_string()))?;
let array = PyArray1::from_slice(py, boxes);
Ok(array
.reshape([boxes.len() / 4, 4])?
.into_pyobject(py)?
.into_any()
.unbind())
PyArray::from_array_ref(boxes_at_level::<Float64Type>(boxes, index.clone()))
.to_arro3(py)
}
}
}
Expand Down Expand Up @@ -481,3 +473,20 @@ impl PyRTree {
self.0.boxes_at_level(py, level)
}
}

fn boxes_at_level<T: ArrowPrimitiveType>(
boxes: &[T::Native],
owner: Arc<dyn Allocation>,
) -> ArrayRef {
let ptr = NonNull::new(boxes.as_ptr() as *mut _).unwrap();
let len = boxes.len();
let bytes_len = len * T::Native::get_byte_width();

// Safety:
// ptr is a non-null pointer owned by the RTree, which is passed in as the Allocation
let buffer = unsafe { Buffer::from_custom_allocation(ptr, bytes_len, owner) };
Arc::new(PrimitiveArray::<T>::new(
ScalarBuffer::new(buffer, 0, len),
None,
))
}
24 changes: 22 additions & 2 deletions python/tests/test_rtree.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,36 @@
from geoindex_rs import rtree


def test_rtree():
def create_index():
builder = rtree.RTreeBuilder(5)
min_x = np.arange(5)
min_y = np.arange(5)
max_x = np.arange(5, 10)
max_y = np.arange(5, 10)
builder.add(min_x, min_y, max_x, max_y)
tree = builder.finish()
return builder.finish()


def test_search():
tree = create_index()
result = rtree.search(tree, 0.5, 0.5, 1.5, 1.5)
assert len(result) == 2
assert result[0].as_py() == 0
assert result[1].as_py() == 1


def test_rtree():
builder = rtree.RTreeBuilder(5)
min_x = np.arange(5)
min_y = np.arange(5)
max_x = np.arange(5, 10)
max_y = np.arange(5, 10)
builder.add(min_x, min_y, max_x, max_y)
tree = builder.finish()

boxes = tree.boxes_at_level(0)
np_arr = np.asarray(boxes).reshape(-1, 4)
assert np.all(min_x == np_arr[:, 0])
assert np.all(min_y == np_arr[:, 1])
assert np.all(max_x == np_arr[:, 2])
assert np.all(max_y == np_arr[:, 3])

0 comments on commit c5300f4

Please sign in to comment.