Skip to content

Commit

Permalink
Feat/Python HNSW Bindings (#672)
Browse files Browse the repository at this point in the history
Implement Python bindings for HNSW graph search functionality over plaintext iris codes using the PyO3 library and Maturin build layer.  Provides basic Python bindings for iris code, vector store, and graph store data structures, data serialization of iris codes using base64 encoding compatible with Open IRIS, and serialization of vector and graph stores to and from file.

Serialization of the vector store is implemented using the NDJSON ("Newline Delimited JSON") file format, which allows the use of a single large database file of iris code test data from which entries can be streamed rather than read into memory as a single block.

Build and deployment instructions for the new Python bindings can be found in the `README.md` file of the new `iris-mpc-py` crate.  Usage details are also found in `README.md`, and an example Python script exercising the functionality is available in `examples-py/test_integration.py`.

---------

Co-authored-by: Bryan Gillespie <bryan@inversed.tech>
  • Loading branch information
bgillesp and Bryan Gillespie authored Nov 25, 2024
1 parent 925b179 commit a0a27f8
Show file tree
Hide file tree
Showing 25 changed files with 935 additions and 20 deletions.
97 changes: 96 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ members = [
"iris-mpc-common",
"iris-mpc-upgrade",
"iris-mpc-store",
"iris-mpc-py",
]
resolver = "2"

Expand All @@ -28,11 +29,14 @@ bytemuck = { version = "1.17", features = ["derive"] }
dotenvy = "0.15"
eyre = "0.6"
futures = "0.3.30"
hawk-pack = { git = "https://github.com/Inversed-Tech/hawk-pack.git", rev = "29e888ed" }
hex = "0.4.3"
itertools = "0.13"
num-traits = "0.2"
serde = { version = "1.0", features = ["derive"] }
serde-big-array = "0.5.1"
serde_json = "1"
bincode = "1.3.3"
sqlx = { version = "0.8", features = ["runtime-tokio-native-tls", "postgres"] }
tracing = "0.1.40"
tracing-subscriber = { version = "0.3.15", features = ["env-filter"] }
Expand Down
4 changes: 2 additions & 2 deletions iris-mpc-common/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ wiremock = "0.6.1"
digest = "0.10.7"
ring = "0.17.8"
data-encoding = "2.6.0"
bincode = "1.3.3"
serde-big-array = "0.5.1"
bincode.workspace = true
serde-big-array.workspace = true

[dev-dependencies]
float_eq = "1"
Expand Down
8 changes: 5 additions & 3 deletions iris-mpc-common/src/iris_db/iris.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@ use rand::{
distributions::{Bernoulli, Distribution},
Rng,
};
use serde::{Deserialize, Serialize};
use serde_big_array::BigArray;

pub const MATCH_THRESHOLD_RATIO: f64 = 0.375;

#[repr(transparent)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct IrisCodeArray(pub [u64; Self::IRIS_CODE_SIZE_U64]);
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct IrisCodeArray(#[serde(with = "BigArray")] pub [u64; Self::IRIS_CODE_SIZE_U64]);
impl Default for IrisCodeArray {
fn default() -> Self {
Self::ZERO
Expand Down Expand Up @@ -141,7 +143,7 @@ impl std::ops::BitXor for IrisCodeArray {
}
}

#[derive(Clone, Debug, PartialEq)]
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct IrisCode {
pub code: IrisCodeArray,
pub mask: IrisCodeArray,
Expand Down
6 changes: 3 additions & 3 deletions iris-mpc-cpu/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@ async-channel = "2.3.1"
async-stream = "0.2"
async-trait = "~0.1"
backoff = {version="0.4.0", features = ["tokio"]}
bincode = "1.3.3"
bincode.workspace = true
bytes = "1.7"
bytemuck.workspace = true
dashmap = "6.1.0"
eyre.workspace = true
futures.workspace = true
hawk-pack = { git = "https://github.com/Inversed-Tech/hawk-pack.git", rev = "4e6de24" }
hawk-pack.workspace = true
iris-mpc-common = { path = "../iris-mpc-common" }
itertools.workspace = true
num-traits.workspace = true
Expand Down Expand Up @@ -47,4 +47,4 @@ name = "hnsw"
harness = false

[[example]]
name = "hnsw-ex"
name = "hnsw-ex"
25 changes: 14 additions & 11 deletions iris-mpc-cpu/src/hawkers/plaintext_store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,10 @@ use iris_mpc_common::iris_db::{
iris::{IrisCode, MATCH_THRESHOLD_RATIO},
};
use rand::{CryptoRng, RngCore, SeedableRng};
use serde::{Deserialize, Serialize};
use std::ops::{Index, IndexMut};

#[derive(Default, Debug, Clone)]
pub struct PlaintextStore {
pub points: Vec<PlaintextPoint>,
}

#[derive(Default, Debug, Clone)]
#[derive(Default, Debug, Clone, Serialize, Deserialize)]
pub struct PlaintextIris(pub IrisCode);

impl PlaintextIris {
Expand Down Expand Up @@ -47,17 +43,19 @@ impl PlaintextIris {
}
}

#[derive(Clone, Default, Debug)]
// TODO refactor away is_persistent flag; should probably be stored in a
// separate buffer instead whenever working with non-persistent iris codes
#[derive(Clone, Default, Debug, Serialize, Deserialize)]
pub struct PlaintextPoint {
/// Whatever encoding of a vector.
data: PlaintextIris,
pub data: PlaintextIris,
/// Distinguish between queries that are pending, and those that were
/// ultimately accepted into the vector store.
is_persistent: bool,
pub is_persistent: bool,
}

#[derive(Copy, Debug, Clone, PartialEq, Eq, Hash, serde::Deserialize, serde::Serialize)]
pub struct PointId(u32);
#[derive(Copy, Default, Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct PointId(pub u32);

impl<T> Index<PointId> for Vec<T> {
type Output = T;
Expand Down Expand Up @@ -85,6 +83,11 @@ impl From<u32> for PointId {
}
}

#[derive(Default, Debug, Clone, Serialize, Deserialize)]
pub struct PlaintextStore {
pub points: Vec<PlaintextPoint>,
}

impl PlaintextStore {
pub fn prepare_query(&mut self, raw_query: IrisCode) -> <Self as VectorStore>::QueryRef {
self.points.push(PlaintextPoint {
Expand Down
1 change: 1 addition & 0 deletions iris-mpc-cpu/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ pub(crate) mod network;
#[rustfmt::skip]
pub(crate) mod proto_generated;
pub mod protocol;
pub mod py_bindings;
pub(crate) mod shares;
Loading

0 comments on commit a0a27f8

Please sign in to comment.