diff --git a/Cargo.lock b/Cargo.lock index dbb040d3f..3e2f11542 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2196,7 +2196,7 @@ dependencies = [ [[package]] name = "hawk-pack" version = "0.1.0" -source = "git+https://github.com/Inversed-Tech/hawk-pack.git?rev=4e6de24#4e6de24f7422923f8cccd8571ef03407e8dbbb99" +source = "git+https://github.com/Inversed-Tech/hawk-pack.git?rev=29e888ed#29e888edfe19cd69e5925fa676ca07d1f64214da" dependencies = [ "aes-prng 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)", "criterion", @@ -2787,6 +2787,17 @@ dependencies = [ "uuid", ] +[[package]] +name = "iris-mpc-py" +version = "0.1.0" +dependencies = [ + "hawk-pack", + "iris-mpc-common", + "iris-mpc-cpu", + "pyo3", + "rand", +] + [[package]] name = "iris-mpc-store" version = "0.1.0" @@ -3109,6 +3120,15 @@ version = "2.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" +[[package]] +name = "memoffset" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a" +dependencies = [ + "autocfg", +] + [[package]] name = "metrics" version = "0.22.3" @@ -3987,6 +4007,69 @@ dependencies = [ "prost", ] +[[package]] +name = "pyo3" +version = "0.22.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f402062616ab18202ae8319da13fa4279883a2b8a9d9f83f20dbade813ce1884" +dependencies = [ + "cfg-if", + "indoc", + "libc", + "memoffset", + "once_cell", + "portable-atomic", + "pyo3-build-config", + "pyo3-ffi", + "pyo3-macros", + "unindent", +] + +[[package]] +name = "pyo3-build-config" +version = "0.22.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b14b5775b5ff446dd1056212d778012cbe8a0fbffd368029fd9e25b514479c38" +dependencies = [ + "once_cell", + "target-lexicon", +] + +[[package]] +name = "pyo3-ffi" +version = "0.22.6" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ab5bcf04a2cdcbb50c7d6105de943f543f9ed92af55818fd17b660390fc8636" +dependencies = [ + "libc", + "pyo3-build-config", +] + +[[package]] +name = "pyo3-macros" +version = "0.22.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fd24d897903a9e6d80b968368a34e1525aeb719d568dba8b3d4bfa5dc67d453" +dependencies = [ + "proc-macro2", + "pyo3-macros-backend", + "quote", + "syn 2.0.85", +] + +[[package]] +name = "pyo3-macros-backend" +version = "0.22.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "36c011a03ba1e50152b4b394b479826cad97e7a21eb52df179cd91ac411cbfbe" +dependencies = [ + "heck 0.5.0", + "proc-macro2", + "pyo3-build-config", + "quote", + "syn 2.0.85", +] + [[package]] name = "quanta" version = "0.12.3" @@ -5243,6 +5326,12 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" +[[package]] +name = "target-lexicon" +version = "0.12.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" + [[package]] name = "telemetry-batteries" version = "0.1.0" @@ -5859,6 +5948,12 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e" +[[package]] +name = "unindent" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce" + [[package]] name = "untrusted" version = "0.9.0" diff --git a/Cargo.toml b/Cargo.toml index 7416fa873..843cb4908 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,6 +6,7 @@ members = [ "iris-mpc-common", "iris-mpc-upgrade", "iris-mpc-store", + "iris-mpc-py", ] resolver = "2" @@ -28,11 +29,14 @@ bytemuck = { version = "1.17", 
features = ["derive"] } dotenvy = "0.15" eyre = "0.6" futures = "0.3.30" +hawk-pack = { git = "https://github.com/Inversed-Tech/hawk-pack.git", rev = "29e888ed" } hex = "0.4.3" itertools = "0.13" num-traits = "0.2" serde = { version = "1.0", features = ["derive"] } +serde-big-array = "0.5.1" serde_json = "1" +bincode = "1.3.3" sqlx = { version = "0.8", features = ["runtime-tokio-native-tls", "postgres"] } tracing = "0.1.40" tracing-subscriber = { version = "0.3.15", features = ["env-filter"] } diff --git a/iris-mpc-common/Cargo.toml b/iris-mpc-common/Cargo.toml index a658dba54..d9a287689 100644 --- a/iris-mpc-common/Cargo.toml +++ b/iris-mpc-common/Cargo.toml @@ -45,8 +45,8 @@ wiremock = "0.6.1" digest = "0.10.7" ring = "0.17.8" data-encoding = "2.6.0" -bincode = "1.3.3" -serde-big-array = "0.5.1" +bincode.workspace = true +serde-big-array.workspace = true [dev-dependencies] float_eq = "1" diff --git a/iris-mpc-common/src/iris_db/iris.rs b/iris-mpc-common/src/iris_db/iris.rs index b8acc9e88..1176d5a0b 100644 --- a/iris-mpc-common/src/iris_db/iris.rs +++ b/iris-mpc-common/src/iris_db/iris.rs @@ -4,12 +4,14 @@ use rand::{ distributions::{Bernoulli, Distribution}, Rng, }; +use serde::{Deserialize, Serialize}; +use serde_big_array::BigArray; pub const MATCH_THRESHOLD_RATIO: f64 = 0.375; #[repr(transparent)] -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub struct IrisCodeArray(pub [u64; Self::IRIS_CODE_SIZE_U64]); +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +pub struct IrisCodeArray(#[serde(with = "BigArray")] pub [u64; Self::IRIS_CODE_SIZE_U64]); impl Default for IrisCodeArray { fn default() -> Self { Self::ZERO @@ -141,7 +143,7 @@ impl std::ops::BitXor for IrisCodeArray { } } -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] pub struct IrisCode { pub code: IrisCodeArray, pub mask: IrisCodeArray, diff --git a/iris-mpc-cpu/Cargo.toml b/iris-mpc-cpu/Cargo.toml index 99bfe2a4b..0b67d1a13 100644 --- 
a/iris-mpc-cpu/Cargo.toml +++ b/iris-mpc-cpu/Cargo.toml @@ -13,13 +13,13 @@ async-channel = "2.3.1" async-stream = "0.2" async-trait = "~0.1" backoff = {version="0.4.0", features = ["tokio"]} -bincode = "1.3.3" +bincode.workspace = true bytes = "1.7" bytemuck.workspace = true dashmap = "6.1.0" eyre.workspace = true futures.workspace = true -hawk-pack = { git = "https://github.com/Inversed-Tech/hawk-pack.git", rev = "4e6de24" } +hawk-pack.workspace = true iris-mpc-common = { path = "../iris-mpc-common" } itertools.workspace = true num-traits.workspace = true @@ -47,4 +47,4 @@ name = "hnsw" harness = false [[example]] -name = "hnsw-ex" \ No newline at end of file +name = "hnsw-ex" diff --git a/iris-mpc-cpu/src/hawkers/plaintext_store.rs b/iris-mpc-cpu/src/hawkers/plaintext_store.rs index 6c69e6355..2d0ebd062 100644 --- a/iris-mpc-cpu/src/hawkers/plaintext_store.rs +++ b/iris-mpc-cpu/src/hawkers/plaintext_store.rs @@ -5,14 +5,10 @@ use iris_mpc_common::iris_db::{ iris::{IrisCode, MATCH_THRESHOLD_RATIO}, }; use rand::{CryptoRng, RngCore, SeedableRng}; +use serde::{Deserialize, Serialize}; use std::ops::{Index, IndexMut}; -#[derive(Default, Debug, Clone)] -pub struct PlaintextStore { - pub points: Vec, -} - -#[derive(Default, Debug, Clone)] +#[derive(Default, Debug, Clone, Serialize, Deserialize)] pub struct PlaintextIris(pub IrisCode); impl PlaintextIris { @@ -47,17 +43,19 @@ impl PlaintextIris { } } -#[derive(Clone, Default, Debug)] +// TODO refactor away is_persistent flag; should probably be stored in a +// separate buffer instead whenever working with non-persistent iris codes +#[derive(Clone, Default, Debug, Serialize, Deserialize)] pub struct PlaintextPoint { /// Whatever encoding of a vector. - data: PlaintextIris, + pub data: PlaintextIris, /// Distinguish between queries that are pending, and those that were /// ultimately accepted into the vector store. 
- is_persistent: bool, + pub is_persistent: bool, } -#[derive(Copy, Debug, Clone, PartialEq, Eq, Hash, serde::Deserialize, serde::Serialize)] -pub struct PointId(u32); +#[derive(Copy, Default, Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub struct PointId(pub u32); impl Index for Vec { type Output = T; @@ -85,6 +83,11 @@ impl From for PointId { } } +#[derive(Default, Debug, Clone, Serialize, Deserialize)] +pub struct PlaintextStore { + pub points: Vec, +} + impl PlaintextStore { pub fn prepare_query(&mut self, raw_query: IrisCode) -> ::QueryRef { self.points.push(PlaintextPoint { diff --git a/iris-mpc-cpu/src/lib.rs b/iris-mpc-cpu/src/lib.rs index 1a74801f0..bf4a96011 100644 --- a/iris-mpc-cpu/src/lib.rs +++ b/iris-mpc-cpu/src/lib.rs @@ -5,4 +5,5 @@ pub(crate) mod network; #[rustfmt::skip] pub(crate) mod proto_generated; pub mod protocol; +pub mod py_bindings; pub(crate) mod shares; diff --git a/iris-mpc-cpu/src/py_bindings/hnsw.rs b/iris-mpc-cpu/src/py_bindings/hnsw.rs new file mode 100644 index 000000000..e57c85ff0 --- /dev/null +++ b/iris-mpc-cpu/src/py_bindings/hnsw.rs @@ -0,0 +1,126 @@ +use super::plaintext_store::Base64IrisCode; +use crate::hawkers::plaintext_store::{PlaintextStore, PointId}; +use hawk_pack::{graph_store::GraphMem, hnsw_db::HawkSearcher, VectorStore}; +use iris_mpc_common::iris_db::iris::IrisCode; +use rand::rngs::ThreadRng; +use serde_json::{self, Deserializer}; +use std::{fs::File, io::BufReader}; + +pub fn search( + query: IrisCode, + searcher: &HawkSearcher, + vector: &mut PlaintextStore, + graph: &mut GraphMem, +) -> (PointId, f64) { + let rt = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .unwrap(); + + rt.block_on(async move { + let query = vector.prepare_query(query); + let neighbors = searcher.search_to_insert(vector, graph, &query).await; + let (nearest, (dist_num, dist_denom)) = neighbors[0].get_nearest().unwrap(); + (*nearest, (*dist_num as f64) / (*dist_denom as f64)) + }) +} + +// TODO 
could instead take iterator of IrisCodes to make more flexible +pub fn insert( + iris: IrisCode, + searcher: &HawkSearcher, + vector: &mut PlaintextStore, + graph: &mut GraphMem, +) -> PointId { + let rt = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .unwrap(); + + rt.block_on(async move { + let mut rng = ThreadRng::default(); + + let query = vector.prepare_query(iris); + let neighbors = searcher.search_to_insert(vector, graph, &query).await; + let inserted = vector.insert(&query).await; + searcher + .insert_from_search_results(vector, graph, &mut rng, inserted, neighbors) + .await; + inserted + }) +} + +pub fn insert_uniform_random( + searcher: &HawkSearcher, + vector: &mut PlaintextStore, + graph: &mut GraphMem, +) -> PointId { + let mut rng = ThreadRng::default(); + let raw_query = IrisCode::random_rng(&mut rng); + + insert(raw_query, searcher, vector, graph) +} + +pub fn fill_uniform_random( + num: usize, + searcher: &HawkSearcher, + vector: &mut PlaintextStore, + graph: &mut GraphMem, +) { + let rt = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .unwrap(); + + rt.block_on(async move { + let mut rng = ThreadRng::default(); + + for idx in 0..num { + let raw_query = IrisCode::random_rng(&mut rng); + let query = vector.prepare_query(raw_query.clone()); + let neighbors = searcher.search_to_insert(vector, graph, &query).await; + let inserted = vector.insert(&query).await; + searcher + .insert_from_search_results(vector, graph, &mut rng, inserted, neighbors) + .await; + if idx % 100 == 99 { + println!("{}", idx + 1); + } + } + }) +} + +pub fn fill_from_ndjson_file( + filename: &str, + limit: Option, + searcher: &HawkSearcher, + vector: &mut PlaintextStore, + graph: &mut GraphMem, +) { + let rt = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .unwrap(); + + rt.block_on(async move { + let mut rng = ThreadRng::default(); + + let file = File::open(filename).unwrap(); + let reader = 
BufReader::new(file); + + // Create an iterator over deserialized objects + let stream = Deserializer::from_reader(reader).into_iter::(); + let stream = super::limited_iterator(stream, limit); + + // Iterate over each deserialized object + for json_pt in stream { + let raw_query = (&json_pt.unwrap()).into(); + let query = vector.prepare_query(raw_query); + let neighbors = searcher.search_to_insert(vector, graph, &query).await; + let inserted = vector.insert(&query).await; + searcher + .insert_from_search_results(vector, graph, &mut rng, inserted, neighbors) + .await; + } + }) +} diff --git a/iris-mpc-cpu/src/py_bindings/io.rs b/iris-mpc-cpu/src/py_bindings/io.rs new file mode 100644 index 000000000..77f2c5b6f --- /dev/null +++ b/iris-mpc-cpu/src/py_bindings/io.rs @@ -0,0 +1,36 @@ +use bincode; +use eyre::Result; +use serde::{de::DeserializeOwned, Serialize}; +use serde_json; +use std::{ + fs::File, + io::{BufReader, BufWriter}, +}; + +pub fn write_bin(data: &T, filename: &str) -> Result<()> { + let file = File::create(filename)?; + let writer = BufWriter::new(file); + bincode::serialize_into(writer, data)?; + Ok(()) +} + +pub fn read_bin(filename: &str) -> Result { + let file = File::open(filename)?; + let reader = BufReader::new(file); + let data: T = bincode::deserialize_from(reader)?; + Ok(data) +} + +pub fn write_json(data: &T, filename: &str) -> Result<()> { + let file = File::create(filename)?; + let writer = BufWriter::new(file); + serde_json::to_writer(writer, &data)?; + Ok(()) +} + +pub fn read_json(filename: &str) -> Result { + let file = File::open(filename)?; + let reader = BufReader::new(file); + let data: T = serde_json::from_reader(reader)?; + Ok(data) +} diff --git a/iris-mpc-cpu/src/py_bindings/mod.rs b/iris-mpc-cpu/src/py_bindings/mod.rs new file mode 100644 index 000000000..b655e05f2 --- /dev/null +++ b/iris-mpc-cpu/src/py_bindings/mod.rs @@ -0,0 +1,13 @@ +pub mod hnsw; +pub mod io; +pub mod plaintext_store; + +pub fn limited_iterator(iter: I, 
limit: Option) -> Box> +where + I: Iterator + 'static, +{ + match limit { + Some(num) => Box::new(iter.take(num)), + None => Box::new(iter), + } +} diff --git a/iris-mpc-cpu/src/py_bindings/plaintext_store.rs b/iris-mpc-cpu/src/py_bindings/plaintext_store.rs new file mode 100644 index 000000000..7340454e8 --- /dev/null +++ b/iris-mpc-cpu/src/py_bindings/plaintext_store.rs @@ -0,0 +1,79 @@ +use crate::hawkers::plaintext_store::{PlaintextIris, PlaintextPoint, PlaintextStore}; +use iris_mpc_common::iris_db::iris::{IrisCode, IrisCodeArray}; +use serde::{Deserialize, Serialize}; +use std::{ + fs::File, + io::{self, BufReader, BufWriter, Write}, +}; + +/// Iris code representation using base64 encoding compatible with Open IRIS +#[derive(Serialize, Deserialize)] +pub struct Base64IrisCode { + iris_codes: String, + mask_codes: String, +} + +impl From<&IrisCode> for Base64IrisCode { + fn from(value: &IrisCode) -> Self { + Self { + iris_codes: value.code.to_base64().unwrap(), + mask_codes: value.mask.to_base64().unwrap(), + } + } +} + +impl From<&Base64IrisCode> for IrisCode { + fn from(value: &Base64IrisCode) -> Self { + Self { + code: IrisCodeArray::from_base64(&value.iris_codes).unwrap(), + mask: IrisCodeArray::from_base64(&value.mask_codes).unwrap(), + } + } +} + +pub fn from_ndjson_file(filename: &str, len: Option) -> io::Result { + let file = File::open(filename)?; + let reader = BufReader::new(file); + + // Create an iterator over deserialized objects + let stream = serde_json::Deserializer::from_reader(reader).into_iter::(); + let stream = super::limited_iterator(stream, len); + + // Iterate over each deserialized object + let mut vector = PlaintextStore::default(); + for json_pt in stream { + let json_pt = json_pt?; + vector.points.push(PlaintextPoint { + data: PlaintextIris((&json_pt).into()), + is_persistent: true, + }); + } + + if let Some(num) = len { + if vector.points.len() != num { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!( + "File 
{} contains too few entries; number read: {}", + filename, + vector.points.len() + ), + )); + } + } + + Ok(vector) +} + +pub fn to_ndjson_file(vector: &PlaintextStore, filename: &str) -> std::io::Result<()> { + // Serialize the objects to the file + let file = File::create(filename)?; + let mut writer = BufWriter::new(file); + for pt in &vector.points { + let json_pt: Base64IrisCode = (&pt.data.0).into(); + serde_json::to_writer(&mut writer, &json_pt)?; + writer.write_all(b"\n")?; // Write a newline after each JSON object + } + writer.flush()?; + Ok(()) +} diff --git a/iris-mpc-py/.gitignore b/iris-mpc-py/.gitignore new file mode 100644 index 000000000..c8f044299 --- /dev/null +++ b/iris-mpc-py/.gitignore @@ -0,0 +1,72 @@ +/target + +# Byte-compiled / optimized / DLL files +__pycache__/ +.pytest_cache/ +*.py[cod] + +# C extensions +*.so + +# Distribution / packaging +.Python +.venv/ +env/ +bin/ +build/ +develop-eggs/ +dist/ +eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +include/ +man/ +venv/ +*.egg-info/ +.installed.cfg +*.egg + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt +pip-selfcheck.json + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.cache +nosetests.xml +coverage.xml + +# Translations +*.mo + +# Mr Developer +.mr.developer.cfg +.project +.pydevproject + +# Rope +.ropeproject + +# Django stuff: +*.log +*.pot + +.DS_Store + +# Sphinx documentation +docs/_build/ + +# PyCharm +.idea/ + +# VSCode +.vscode/ + +# Pyenv +.python-version diff --git a/iris-mpc-py/Cargo.toml b/iris-mpc-py/Cargo.toml new file mode 100644 index 000000000..d3d325935 --- /dev/null +++ b/iris-mpc-py/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "iris-mpc-py" +version = "0.1.0" +publish = false + +edition.workspace = true +license.workspace = true +repository.workspace = true + +[lib] +name = "iris_mpc_py" +crate-type = ["cdylib"] + +[dependencies] +iris-mpc-common = { path = "../iris-mpc-common" } +iris-mpc-cpu = { path = "../iris-mpc-cpu" } 
+hawk-pack.workspace = true +pyo3 = { version = "0.22.0", features = ["extension-module"] } +rand.workspace = true diff --git a/iris-mpc-py/README.md b/iris-mpc-py/README.md new file mode 100644 index 000000000..e80736956 --- /dev/null +++ b/iris-mpc-py/README.md @@ -0,0 +1,93 @@ +# Python Bindings + +This package provides Python bindings for some functionalities in the `iris-mpc` workspace, currently focused on execution of the HNSW k-nearest neighbors graph search algorithm over plaintext iris codes for testing and data analysis. For compatibility, compilation of this crate is disabled from the workspace root, but enabled from within the crate subdirectory via the Cargo default feature flag `enable`. + +## Installation + +Installation of Python bindings from the PyO3 library code can be accomplished using the Maturin Python package as follows: + +- Install Maturin in the target Python environment, e.g. the venv used for data analysis, using `pip install maturin` + +- Optionally install `patchelf` library with `pip install patchelf` for support for patching wheel files that link other shared libraries + +- Build and install current bindings as a module in the current Python environment by navigating to the `iris-mpc-py` directory and running `maturin develop --release` + +- Build a wheel file suitable for installation using `pip install` by instead running `maturin build --release`; the `.whl` file is specific to the building architecture and Python version, and can be found in `iris_mpc/target/wheels` directory + +See the [Maturin User Guide Tutorial](https://www.maturin.rs/tutorial#build-and-install-the-module-with-maturin-develop) for additional details. + +## Usage + +Once successfully installed, the native rust module `iris_mpc_py` can be imported in your Python environment as usual with `import iris_mpc_py`. 
Example usage: + +```python +from iris_mpc_py import PyHawkSearcher, PyPlaintextStore, PyGraphStore, PyIrisCode + +hnsw = PyHawkSearcher.new_uniform(32, 32) # M, ef +vector = PyPlaintextStore() +graph = PyGraphStore() + +hnsw.fill_uniform_random(1000, vector, graph) + +iris = PyIrisCode.uniform_random() +iris_id = hnsw.insert(iris, vector, graph) +print("Inserted iris id:", iris_id) + +nearest_id, nearest_dist = hnsw.search(iris, vector, graph) +print("Nearest iris id:", nearest_id) # should be iris_id +print("Nearest iris distance:", nearest_dist) # should be 0.0 +``` + +To write the HNSW vector and graph indices to file and read them back: + +```python +hnsw.write_to_json("searcher.json") +vector.write_to_ndjson("vector.ndjson") +graph.write_to_bin("graph.dat") + +hnsw2 = PyHawkSearcher.read_from_json("searcher.json") +vector2 = PyPlaintextStore.read_from_ndjson("vector.ndjson") +graph2 = PyGraphStore.read_from_bin("graph.dat") +``` + +As an efficiency feature, the data from the vector store is read in a streamed fashion. This means that for a large database of iris codes, the first `num` can be read from file without loading the entire database into memory. 
This can be used in two ways; first, a vector store can be initialized from the large database file for use with a previously generated HNSW index: + +```python +# Serialized HNSW graph constructed from the first 10k entries of database file +vector = PyPlaintextStore.read_from_ndjson("large_vector_database.ndjson", 10000) +graph = PyGraphStore.read_from_bin("graph.dat") +``` + +Second, to construct an HNSW index dynamically from streamed database entries: + +```python +hnsw = PyHawkSearcher.new_uniform(32, 32) +vector = PyPlaintextStore() +graph = PyGraphStore() +hnsw.fill_from_ndjson_file("large_vector_database.ndjson", vector, graph, 10000) +``` + +To generate a vector database directly for use in this way: + +```python +# Generate 100k uniform random iris codes +vector_init = PyPlaintextStore() +for i in range(100000): + vector_init.insert(PyIrisCode.uniform_random()) +vector_init.write_to_ndjson("vector.ndjson") +``` + +Basic interoperability with Open IRIS iris templates is provided by way of a common base64 encoding scheme, provided by the `iris.io.dataclasses.IrisTemplate` methods `serialize` and `deserialize`. These methods use a base64 encoding of iris code and mask code arrays represented as a Python `dict` with base64-encoded fields `iris_codes`, `mask_codes`, and a version string `iris_code_version` to check for compatibility. The `PyIrisCode` class interacts with this representation as follows: + +```python +serialized_iris_code = { + "iris_codes": "...", + "mask_codes": "...", + "iris_code_version": "1.0", +} + +iris = PyIrisCode.from_open_iris_template_dict(serialized_iris_code) +reserialized_iris_code = iris.to_open_iris_template_dict("1.0") +``` + +Note that the `to_open_iris_template_dict` method takes an optional argument which fills the `iris_code_version` field of the resulting Python `dict` since the `PyIrisCode` object does not preserve this data. 
diff --git a/iris-mpc-py/examples-py/test_integration.py b/iris-mpc-py/examples-py/test_integration.py new file mode 100644 index 000000000..945069af4 --- /dev/null +++ b/iris-mpc-py/examples-py/test_integration.py @@ -0,0 +1,37 @@ +from iris_mpc_py import PyIrisCode, PyPlaintextStore, PyGraphStore, PyHawkSearcher + +print("Generating 100k uniform random iris codes...") +vector_init = PyPlaintextStore() +iris0 = PyIrisCode.uniform_random() +iris_id = vector_init.insert(iris0) +for i in range(1,100000): + vector_init.insert(PyIrisCode.uniform_random()) + +# write vector store to file +print("Writing vector store to file...") +vector_init.write_to_ndjson("vector.ndjson") + +print("Generating HNSW graphs for 10k imported iris codes...") +hnsw = PyHawkSearcher.new_uniform(32, 32) +vector1 = PyPlaintextStore() +graph1 = PyGraphStore() +hnsw.fill_from_ndjson_file("vector.ndjson", vector1, graph1, 10000) + +print("Imported length:", vector1.len()) + +retrieved_iris = vector1.get(iris_id) +print("Retrieved iris0 base64 == original iris0 base64:", iris0.code.to_base64() == retrieved_iris.code.to_base64() and iris0.mask.to_base64() == retrieved_iris.mask.to_base64()) + +query = PyIrisCode.uniform_random() +print("Search for random query iris code:", hnsw.search(query, vector1, graph1)) + +# write graph store to file +print("Writing graph store to file...") +graph1.write_to_bin("graph1.dat") + +# read HNSW graphs from disk +print("Reading vector and graph stores from file...") +vector2 = PyPlaintextStore.read_from_ndjson("vector.ndjson", 10000) +graph2 = PyGraphStore.read_from_bin("graph1.dat") + +print("Search for random query iris code:", hnsw.search(query, vector2, graph2)) diff --git a/iris-mpc-py/pyproject.toml b/iris-mpc-py/pyproject.toml new file mode 100644 index 000000000..8b731d0c3 --- /dev/null +++ b/iris-mpc-py/pyproject.toml @@ -0,0 +1,16 @@ +[build-system] +requires = ["maturin>=1.7,<2.0"] +build-backend = "maturin" + +[project] +name = "iris-mpc-py" 
+requires-python = ">=3.8" +classifiers = [ + "Programming Language :: Rust", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", +] +dynamic = ["version"] +[tool.maturin] +features = ["pyo3/extension-module"] +module-name = "iris_mpc_py" \ No newline at end of file diff --git a/iris-mpc-py/src/lib.rs b/iris-mpc-py/src/lib.rs new file mode 100644 index 000000000..d8301516c --- /dev/null +++ b/iris-mpc-py/src/lib.rs @@ -0,0 +1 @@ +pub mod py_hnsw; diff --git a/iris-mpc-py/src/py_hnsw/mod.rs b/iris-mpc-py/src/py_hnsw/mod.rs new file mode 100644 index 000000000..d5fe0536c --- /dev/null +++ b/iris-mpc-py/src/py_hnsw/mod.rs @@ -0,0 +1,2 @@ +pub mod pyclasses; +pub mod pymodule; diff --git a/iris-mpc-py/src/py_hnsw/pyclasses/graph_store.rs b/iris-mpc-py/src/py_hnsw/pyclasses/graph_store.rs new file mode 100644 index 000000000..fc6768f3d --- /dev/null +++ b/iris-mpc-py/src/py_hnsw/pyclasses/graph_store.rs @@ -0,0 +1,27 @@ +use hawk_pack::graph_store::GraphMem; +use iris_mpc_cpu::{hawkers::plaintext_store::PlaintextStore, py_bindings}; +use pyo3::{exceptions::PyIOError, prelude::*}; + +#[pyclass] +#[derive(Clone, Default)] +pub struct PyGraphStore(pub GraphMem); + +#[pymethods] +impl PyGraphStore { + #[new] + pub fn new() -> Self { + Self::default() + } + + #[staticmethod] + pub fn read_from_bin(filename: String) -> PyResult { + let result = py_bindings::io::read_bin(&filename) + .map_err(|_| PyIOError::new_err("Unable to read from file"))?; + Ok(Self(result)) + } + + pub fn write_to_bin(&self, filename: String) -> PyResult<()> { + py_bindings::io::write_bin(&self.0, &filename) + .map_err(|_| PyIOError::new_err("Unable to write to file")) + } +} diff --git a/iris-mpc-py/src/py_hnsw/pyclasses/hawk_searcher.rs b/iris-mpc-py/src/py_hnsw/pyclasses/hawk_searcher.rs new file mode 100644 index 000000000..05fb346ee --- /dev/null +++ b/iris-mpc-py/src/py_hnsw/pyclasses/hawk_searcher.rs @@ -0,0 +1,98 @@ +use 
super::{graph_store::PyGraphStore, iris_code::PyIrisCode, plaintext_store::PyPlaintextStore}; +use hawk_pack::hnsw_db::{HawkSearcher, Params}; +use iris_mpc_cpu::py_bindings; +use pyo3::{exceptions::PyIOError, prelude::*}; + +#[pyclass] +#[derive(Clone, Default)] +pub struct PyHawkSearcher(pub HawkSearcher); + +#[pymethods] +#[allow(non_snake_case)] +impl PyHawkSearcher { + #[new] + pub fn new(M: usize, ef_constr: usize, ef_search: usize) -> Self { + Self::new_standard(ef_constr, ef_search, M) + } + + #[staticmethod] + pub fn new_standard(M: usize, ef_constr: usize, ef_search: usize) -> Self { + let params = Params::new_standard(ef_constr, ef_search, M); + Self(HawkSearcher { params }) + } + + #[staticmethod] + pub fn new_uniform(M: usize, ef: usize) -> Self { + let params = Params::new_uniform(ef, M); + Self(HawkSearcher { params }) + } + + pub fn insert( + &self, + iris: PyIrisCode, + vector: &mut PyPlaintextStore, + graph: &mut PyGraphStore, + ) -> u32 { + let id = py_bindings::hnsw::insert(iris.0, &self.0, &mut vector.0, &mut graph.0); + id.0 + } + + pub fn insert_uniform_random( + &self, + vector: &mut PyPlaintextStore, + graph: &mut PyGraphStore, + ) -> u32 { + let id = py_bindings::hnsw::insert_uniform_random(&self.0, &mut vector.0, &mut graph.0); + id.0 + } + + pub fn fill_uniform_random( + &self, + num: usize, + vector: &mut PyPlaintextStore, + graph: &mut PyGraphStore, + ) { + py_bindings::hnsw::fill_uniform_random(num, &self.0, &mut vector.0, &mut graph.0); + } + + #[pyo3(signature = (filename, vector, graph, limit=None))] + pub fn fill_from_ndjson_file( + &self, + filename: String, + vector: &mut PyPlaintextStore, + graph: &mut PyGraphStore, + limit: Option, + ) { + py_bindings::hnsw::fill_from_ndjson_file( + &filename, + limit, + &self.0, + &mut vector.0, + &mut graph.0, + ); + } + + /// Search HNSW index and return nearest ID and its distance from query + pub fn search( + &mut self, + query: &PyIrisCode, + vector: &mut PyPlaintextStore, + graph: &mut 
PyGraphStore, + ) -> (u32, f64) { + let (id, dist) = + py_bindings::hnsw::search(query.0.clone(), &self.0, &mut vector.0, &mut graph.0); + (id.0, dist) + } + + #[staticmethod] + pub fn read_from_json(filename: String) -> PyResult { + let result = py_bindings::io::read_json(&filename) + .map_err(|_| PyIOError::new_err("Unable to read from file"))?; + Ok(Self(result)) + } + + pub fn write_to_json(&self, filename: String) -> PyResult<()> { + py_bindings::io::write_json(&self.0, &filename) + .map_err(|_| PyIOError::new_err("Unable to write to file")) + } +} diff --git a/iris-mpc-py/src/py_hnsw/pyclasses/iris_code.rs b/iris-mpc-py/src/py_hnsw/pyclasses/iris_code.rs new file mode 100644 index 000000000..c004344ee --- /dev/null +++ b/iris-mpc-py/src/py_hnsw/pyclasses/iris_code.rs @@ -0,0 +1,73 @@ +use super::iris_code_array::PyIrisCodeArray; +use iris_mpc_common::iris_db::iris::IrisCode; +use pyo3::{prelude::*, types::PyDict}; +use rand::rngs::ThreadRng; + +#[pyclass] +#[derive(Clone, Default)] +pub struct PyIrisCode(pub IrisCode); + +#[pymethods] +impl PyIrisCode { + #[new] + pub fn new(code: &PyIrisCodeArray, mask: &PyIrisCodeArray) -> Self { + Self(IrisCode { + code: code.0, + mask: mask.0, + }) + } + + #[getter] + pub fn code(&self) -> PyIrisCodeArray { + PyIrisCodeArray(self.0.code) + } + + #[getter] + pub fn mask(&self) -> PyIrisCodeArray { + PyIrisCodeArray(self.0.mask) + } + + #[staticmethod] + pub fn uniform_random() -> Self { + let mut rng = ThreadRng::default(); + Self(IrisCode::random_rng(&mut rng)) + } + + #[pyo3(signature = (version=None))] + pub fn to_open_iris_template_dict<'py>( + &self, + py: Python<'py>, + version: Option, + ) -> PyResult> { + let dict = PyDict::new_bound(py); + + dict.set_item("iris_codes", self.0.code.to_base64().unwrap())?; + dict.set_item("mask_codes", self.0.mask.to_base64().unwrap())?; + dict.set_item("iris_code_version", version)?; + + Ok(dict) + } + + #[staticmethod] + pub fn from_open_iris_template_dict(dict_obj: &Bound) -> 
PyResult { + // Extract base64-encoded iris code arrays + let iris_codes_str: String = dict_obj.get_item("iris_codes")?.unwrap().extract()?; + let mask_codes_str: String = dict_obj.get_item("mask_codes")?.unwrap().extract()?; + + // Convert the base64 strings into PyIrisCodeArrays + let code = PyIrisCodeArray::from_base64(iris_codes_str); + let mask = PyIrisCodeArray::from_base64(mask_codes_str); + + // Construct and return PyIrisCode + Ok(Self(IrisCode { + code: code.0, + mask: mask.0, + })) + } +} + +impl From for PyIrisCode { + fn from(value: IrisCode) -> Self { + Self(value) + } +} diff --git a/iris-mpc-py/src/py_hnsw/pyclasses/iris_code_array.rs b/iris-mpc-py/src/py_hnsw/pyclasses/iris_code_array.rs new file mode 100644 index 000000000..7d12fe3e7 --- /dev/null +++ b/iris-mpc-py/src/py_hnsw/pyclasses/iris_code_array.rs @@ -0,0 +1,46 @@ +use iris_mpc_common::iris_db::iris::IrisCodeArray; +use pyo3::prelude::*; +use rand::rngs::ThreadRng; + +#[pyclass] +#[derive(Clone, Default)] +pub struct PyIrisCodeArray(pub IrisCodeArray); + +#[pymethods] +impl PyIrisCodeArray { + #[new] + pub fn new(input: String) -> Self { + Self::from_base64(input) + } + + pub fn to_base64(&self) -> String { + self.0.to_base64().unwrap() + } + + #[staticmethod] + pub fn from_base64(input: String) -> Self { + Self(IrisCodeArray::from_base64(&input).unwrap()) + } + + #[staticmethod] + pub fn zeros() -> Self { + Self(IrisCodeArray::ZERO) + } + + #[staticmethod] + pub fn ones() -> Self { + Self(IrisCodeArray::ONES) + } + + #[staticmethod] + pub fn uniform_random() -> Self { + let mut rng = ThreadRng::default(); + Self(IrisCodeArray::random_rng(&mut rng)) + } +} + +impl From for PyIrisCodeArray { + fn from(value: IrisCodeArray) -> Self { + Self(value) + } +} diff --git a/iris-mpc-py/src/py_hnsw/pyclasses/mod.rs b/iris-mpc-py/src/py_hnsw/pyclasses/mod.rs new file mode 100644 index 000000000..eea66d959 --- /dev/null +++ b/iris-mpc-py/src/py_hnsw/pyclasses/mod.rs @@ -0,0 +1,5 @@ +pub mod 
graph_store; +pub mod hawk_searcher; +pub mod iris_code; +pub mod iris_code_array; +pub mod plaintext_store; diff --git a/iris-mpc-py/src/py_hnsw/pyclasses/plaintext_store.rs b/iris-mpc-py/src/py_hnsw/pyclasses/plaintext_store.rs new file mode 100644 index 000000000..f1d3fed19 --- /dev/null +++ b/iris-mpc-py/src/py_hnsw/pyclasses/plaintext_store.rs @@ -0,0 +1,52 @@ +use super::iris_code::PyIrisCode; +use iris_mpc_cpu::{ + hawkers::plaintext_store::{PlaintextIris, PlaintextPoint, PlaintextStore}, + py_bindings, +}; +use pyo3::{exceptions::PyIOError, prelude::*}; + +#[pyclass] +#[derive(Clone, Default)] +pub struct PyPlaintextStore(pub PlaintextStore); + +#[pymethods] +impl PyPlaintextStore { + #[new] + pub fn new() -> Self { + Self::default() + } + + pub fn get(&self, id: u32) -> PyIrisCode { + self.0.points[id as usize].data.0.clone().into() + } + + pub fn insert(&mut self, iris: PyIrisCode) -> u32 { + let new_id = self.0.points.len() as u32; + self.0.points.push(PlaintextPoint { + data: PlaintextIris(iris.0), + is_persistent: true, + }); + new_id + } + + pub fn len(&self) -> usize { + self.0.points.len() + } + + pub fn is_empty(&self) -> bool { + self.0.points.is_empty() + } + + #[staticmethod] + #[pyo3(signature = (filename, len=None))] + pub fn read_from_ndjson(filename: String, len: Option) -> PyResult { + let result = py_bindings::plaintext_store::from_ndjson_file(&filename, len) + .map_err(|_| PyIOError::new_err("Unable to read from file"))?; + Ok(Self(result)) + } + + pub fn write_to_ndjson(&self, filename: String) -> PyResult<()> { + py_bindings::plaintext_store::to_ndjson_file(&self.0, &filename) + .map_err(|_| PyIOError::new_err("Unable to write to file")) + } +} diff --git a/iris-mpc-py/src/py_hnsw/pymodule.rs b/iris-mpc-py/src/py_hnsw/pymodule.rs new file mode 100644 index 000000000..b0ceae8e3 --- /dev/null +++ b/iris-mpc-py/src/py_hnsw/pymodule.rs @@ -0,0 +1,15 @@ +use super::pyclasses::{ + graph_store::PyGraphStore, hawk_searcher::PyHawkSearcher, 
iris_code::PyIrisCode, + iris_code_array::PyIrisCodeArray, plaintext_store::PyPlaintextStore, +}; +use pyo3::prelude::*; + +#[pymodule] +fn iris_mpc_py(m: &Bound<'_, PyModule>) -> PyResult<()> { + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + Ok(()) +}