From de6b4abed8ab18b51c7d6bf5378a4af92b146fb1 Mon Sep 17 00:00:00 2001
From: Philipp Sippl
Date: Tue, 17 Sep 2024 11:30:55 +0200
Subject: [PATCH] re-enable streaming no heartbeat (#390)

* check deletion index

* Revert "Revert "Stream DB entries to memory (#356)" (#384)"

This reverts commit 844045df8f5d4b3085acca3656d7dce149dc21e8.

* some bounds checks

* disable heartbeat

* bump

* bump more

---------

Co-authored-by: Wojciech Sromek <157375010+wojciechsromek@users.noreply.github.com>
---
 deploy/stage/common-values-iris-mpc.yaml |   2 +-
 iris-mpc-common/src/galois_engine.rs     |   2 -
 iris-mpc-gpu/benches/matmul.rs           |   3 +-
 iris-mpc-gpu/src/dot/share_db.rs         | 319 +++++++++++------------
 iris-mpc-gpu/src/server/actor.rs         | 210 ++++++++++-----
 iris-mpc-gpu/tests/e2e.rs                |  24 +-
 iris-mpc-store/src/lib.rs                |  14 +-
 iris-mpc/src/bin/server.rs               | 203 ++++++---------
 8 files changed, 390 insertions(+), 387 deletions(-)

diff --git a/deploy/stage/common-values-iris-mpc.yaml b/deploy/stage/common-values-iris-mpc.yaml
index 6ebdf4ab7..a41955995 100644
--- a/deploy/stage/common-values-iris-mpc.yaml
+++ b/deploy/stage/common-values-iris-mpc.yaml
@@ -1,4 +1,4 @@
-image: "ghcr.io/worldcoin/iris-mpc:v0.6.4"
+image: "ghcr.io/worldcoin/iris-mpc:v0.6.5"
 
 environment: stage
 replicaCount: 1
diff --git a/iris-mpc-common/src/galois_engine.rs b/iris-mpc-common/src/galois_engine.rs
index 3f0ae4425..21d266a36 100644
--- a/iris-mpc-common/src/galois_engine.rs
+++ b/iris-mpc-common/src/galois_engine.rs
@@ -419,8 +419,6 @@ pub mod degree4 {
             let res =
                 0.5f64 - (dot_codes as i16) as f64 / (2f64 * dot_masks as f64);
 
-            println!("{} {}", dot_codes, dot_masks);
-
             // Without rotations
             if rot_idx == 15 {
                 assert_float_eq!(
diff --git a/iris-mpc-gpu/benches/matmul.rs b/iris-mpc-gpu/benches/matmul.rs
index 060e867b5..9f8f5032a 100644
--- a/iris-mpc-gpu/benches/matmul.rs
+++ b/iris-mpc-gpu/benches/matmul.rs
@@ -38,7 +38,8 @@ fn bench_memcpy(c: &mut Criterion) {
     let preprocessed_query = preprocess_query(&query);
     let streams = device_manager.fork_streams();
     let blass = device_manager.create_cublas(&streams);
-    let (db_slices, db_sizes) = engine.load_db(&db, DB_SIZE, DB_SIZE, false);
+    let mut db_slices = engine.alloc_db(DB_SIZE);
+    let db_sizes = engine.load_full_db(&mut db_slices, &db);
 
     group.throughput(Throughput::Elements((DB_SIZE * QUERY_SIZE / 31) as u64));
    group.sample_size(10);
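
The bench change above is the template for every call site in this patch: the old one-shot load_db(&db, DB_SIZE, DB_SIZE, false) allocated device memory and copied the host data in a single call, while the new API splits allocation (alloc_db, sized to a maximum capacity) from filling (load_full_db, which returns the per-device entry counts). A minimal sketch of the new call shape; Engine and DbSlices below are hypothetical stand-ins, not the crate's real ShareDB and SlicedProcessedDatabase types:

    // Sketch only: mock types illustrating the allocate-once / fill-later split.
    struct Engine { n_devices: usize, code_length: usize }
    struct DbSlices { capacity_per_device: usize, lens: Vec<usize> }

    impl Engine {
        // Allocate fixed-capacity per-device slices up front; contents come later.
        fn alloc_db(&self, max_db_len: usize) -> DbSlices {
            DbSlices {
                capacity_per_device: max_db_len / self.n_devices,
                lens: vec![0; self.n_devices],
            }
        }

        // Fill the pre-allocated slices and report per-device entry counts.
        fn load_full_db(&self, db: &mut DbSlices, entries: &[u16]) -> Vec<usize> {
            let n_records = entries.len() / self.code_length;
            for i in 0..n_records {
                db.lens[i % self.n_devices] += 1; // records go round-robin over devices
            }
            assert!(db.lens.iter().all(|&l| l <= db.capacity_per_device));
            db.lens.clone()
        }
    }

    fn main() {
        let engine = Engine { n_devices: 2, code_length: 4 };
        let mut db_slices = engine.alloc_db(8);
        let db_sizes = engine.load_full_db(&mut db_slices, &[0u16; 20]); // 5 records
        assert_eq!(db_sizes, vec![3, 2]);
    }
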
diff --git a/iris-mpc-gpu/src/dot/share_db.rs b/iris-mpc-gpu/src/dot/share_db.rs
index 467385996..a67d18dc3 100644
--- a/iris-mpc-gpu/src/dot/share_db.rs
+++ b/iris-mpc-gpu/src/dot/share_db.rs
@@ -18,7 +18,7 @@ use cudarc::{
         CudaBlas,
     },
     driver::{
-        result::{malloc_async, malloc_managed},
+        result::{self, malloc_async, malloc_managed},
         sys::{CUdeviceptr, CUmemAttach_flags},
         CudaFunction, CudaSlice, CudaStream, CudaView, DevicePtr, DeviceSlice, LaunchAsync,
         LaunchConfig,
@@ -108,28 +108,6 @@ pub fn gemm(
     }
 }
 
-fn chunking<T: Clone>(
-    slice: &[T],
-    n_chunks: usize,
-    chunk_size: usize,
-    element_size: usize,
-    alternating: bool,
-) -> Vec<Vec<T>> {
-    if alternating {
-        let mut result = vec![Vec::new(); n_chunks];
-
-        for (i, chunk) in slice.chunks(element_size).enumerate() {
-            result[i % n_chunks].extend_from_slice(chunk);
-        }
-        result
-    } else {
-        slice
-            .chunks(chunk_size)
-            .map(|chunk| chunk.to_vec())
-            .collect()
-    }
-}
-
 pub struct SlicedProcessedDatabase {
     pub code_gr: CudaVec2DSlicerRawPointer,
     pub code_sums_gr: CudaVec2DSlicerU32,
@@ -251,156 +229,143 @@ impl ShareDB {
         }
     }
 
-    #[allow(clippy::type_complexity)]
-    pub fn load_db(
-        &self,
-        db_entries: &[u16],
-        db_length: usize, // TODO: should handle different sizes for each device
-        max_db_length: usize,
-        alternating_chunks: bool,
-    ) -> (SlicedProcessedDatabase, Vec<usize>) {
-        let mut a1_host = db_entries
-            .par_iter()
-            .map(|&x: &u16| (x >> 8) as i8)
-            .collect::<Vec<_>>();
-        let mut a0_host = db_entries.par_iter().map(|&x| x as i8).collect::<Vec<_>>();
-
-        // TODO: maybe use gemm here already to speed up loading (we'll need to correct
-        // the results as well)
-        a1_host
-            .par_iter_mut()
-            .for_each(|x| *x = (*x as i32 - 128) as i8);
-
-        a0_host
-            .par_iter_mut()
-            .for_each(|x| *x = (*x as i32 - 128) as i8);
-
-        let a1_sums: Vec<u32> = a1_host
-            .par_chunks(self.code_length)
-            .map(|row| row.par_iter().map(|&x| x as u32).sum::<u32>())
-            .collect();
+    pub fn alloc_db(&self, max_db_length: usize) -> SlicedProcessedDatabase {
+        let max_size = max_db_length / self.device_manager.device_count();
+        let (db0_sums, (db1_sums, (db0, db1))) = self
+            .device_manager
+            .devices()
+            .iter()
+            .map(|device| unsafe {
+                (
+                    StreamAwareCudaSlice::from(device.alloc(max_size).unwrap()),
+                    (
+                        StreamAwareCudaSlice::from(device.alloc(max_size).unwrap()),
+                        (
+                            malloc_managed(
+                                max_size * self.code_length,
+                                CUmemAttach_flags::CU_MEM_ATTACH_GLOBAL,
+                            )
+                            .unwrap(),
+                            malloc_managed(
+                                max_size * self.code_length,
+                                CUmemAttach_flags::CU_MEM_ATTACH_GLOBAL,
+                            )
+                            .unwrap(),
+                        ),
+                    ),
+                )
+            })
+            .unzip();
 
-        let a0_sums: Vec<u32> = a0_host
-            .par_chunks(self.code_length)
-            .map(|row| row.par_iter().map(|&x| x as u32).sum::<u32>())
-            .collect();
+        for dev in self.device_manager.devices() {
+            dev.synchronize().unwrap();
+        }
 
-        // Split up db and load to all devices
-        let chunk_size = db_length / self.device_manager.device_count();
-        let max_size = max_db_length / self.device_manager.device_count();
+        SlicedProcessedDatabase {
+            code_gr: CudaVec2DSlicerRawPointer {
+                limb_0: db0,
+                limb_1: db1,
+            },
+            code_sums_gr: CudaVec2DSlicerU32 {
+                limb_0: db0_sums,
+                limb_1: db1_sums,
+            },
+        }
+    }
 
-        // DB sums
-        let db1_sums = chunking(
-            &a1_sums,
-            self.device_manager.device_count(),
-            chunk_size,
-            1,
-            alternating_chunks,
-        );
-        let db0_sums = chunking(
-            &a0_sums,
-            self.device_manager.device_count(),
-            chunk_size,
-            1,
-            alternating_chunks,
-        );
+    pub fn load_single_record(
+        index: usize,
+        db: &CudaVec2DSlicerRawPointer,
+        record: &[u16],
+        n_shards: usize,
+        code_length: usize,
+    ) {
+        assert!(record.len() == code_length);
 
-        let db1_sums = db1_sums
+        let a0_host = record
             .iter()
-            .enumerate()
-            .map(|(idx, chunk)| {
-                let mut slice = unsafe { self.device_manager.device(idx).alloc(max_size).unwrap() };
-                self.device_manager
-                    .htod_copy_into(chunk.to_vec(), &mut slice, idx)
-                    .unwrap();
-                StreamAwareCudaSlice::from(slice)
-            })
+            .map(|&x| ((x as i8) as i32 - 128) as i8)
             .collect::<Vec<_>>();
-        let db0_sums = db0_sums
+
+        let a1_host = record
             .iter()
-            .enumerate()
-            .map(|(idx, chunk)| {
-                let mut slice = unsafe { self.device_manager.device(idx).alloc(max_size).unwrap() };
-                self.device_manager
-                    .htod_copy_into(chunk.to_vec(), &mut slice, idx)
-                    .unwrap();
-                StreamAwareCudaSlice::from(slice)
-            })
+            .map(|&x: &u16| ((x >> 8) as i32 - 128) as i8)
             .collect::<Vec<_>>();
 
-        // DB codes
-        let db1 = chunking(
-            &a1_host,
-            self.device_manager.device_count(),
-            chunk_size * self.code_length,
-            self.code_length,
-            alternating_chunks,
-        );
-
-        let db0 = chunking(
-            &a0_host,
-            self.device_manager.device_count(),
-            chunk_size * self.code_length,
-            self.code_length,
-            alternating_chunks,
-        );
+        let device_index = index % n_shards;
+        let device_db_index = index / n_shards;
 
-        assert!(
-            db0.iter()
-                .zip(db1.iter())
-                .all(|(chunk0, chunk1)| chunk0.len() == chunk1.len()),
-            "db0 and db1 chunks must have the same length"
-        );
+        unsafe {
+            std::ptr::copy(
+                a0_host.as_ptr() as *const _,
+                (db.limb_0[device_index] + (device_db_index * code_length) as u64) as *mut _,
+                code_length,
+            );
 
-        let db_lens = db0
-            .iter()
-            .map(|chunk| chunk.len() / self.code_length)
-            .collect::<Vec<_>>();
+            std::ptr::copy(
+                a1_host.as_ptr() as *const _,
+                (db.limb_1[device_index] + (device_db_index * code_length) as u64) as *mut _,
+                code_length,
+            );
+        };
+    }
 
-        let db1 = db1
-            .iter()
-            .map(|chunk| unsafe {
-                let mem = malloc_managed(
-                    max_size * self.code_length,
-                    CUmemAttach_flags::CU_MEM_ATTACH_GLOBAL,
-                )
-                .unwrap();
+    pub fn preprocess_db(&self, db: &mut SlicedProcessedDatabase, db_lens: &[usize]) {
+        let code_len = self.code_length;
+        for device_index in 0..self.device_manager.device_count() {
+            for (limbs, sum_slices) in [
+                (&db.code_gr.limb_0, &mut db.code_sums_gr.limb_0),
+                (&db.code_gr.limb_1, &mut db.code_sums_gr.limb_1),
+            ] {
+                let sums = (0..db_lens[device_index])
+                    .into_par_iter()
+                    .map(|idx| {
+                        let slice: &[i8] = unsafe {
+                            std::slice::from_raw_parts(
+                                (limbs[device_index] + (idx * code_len) as u64) as *const _,
+                                code_len,
+                            )
+                        };
+                        slice.iter().map(|&x| x as u32).sum::<u32>()
+                    })
+                    .collect::<Vec<_>>();
 
-                std::ptr::copy(chunk.as_ptr() as *const _, mem as *mut _, chunk.len());
-                mem
-            })
-            .collect::<Vec<_>>();
-        let db0 = db0
-            .iter()
-            .map(|chunk| unsafe {
-                let mem = malloc_managed(
-                    max_size * self.code_length,
-                    CUmemAttach_flags::CU_MEM_ATTACH_GLOBAL,
-                )
-                .unwrap();
+                self.device_manager
+                    .device(device_index)
+                    .bind_to_thread()
+                    .unwrap();
+                unsafe {
+                    result::memcpy_htod_sync(sum_slices[device_index].cu_device_ptr, &sums)
+                        .unwrap();
+                }
+            }
+        }
+    }
 
-                std::ptr::copy(chunk.as_ptr() as *const _, mem as *mut _, chunk.len());
-                mem
-            })
-            .collect::<Vec<_>>();
+    #[allow(clippy::type_complexity)]
+    pub fn load_full_db(&self, db: &mut SlicedProcessedDatabase, db_entries: &[u16]) -> Vec<usize> {
+        assert!(db_entries.len() % self.code_length == 0);
 
-        for dev in self.device_manager.devices() {
-            dev.synchronize().unwrap();
+        let code_length = self.code_length;
+        let n_shards = self.device_manager.device_count();
+        db_entries
+            .par_chunks(self.code_length)
+            .enumerate()
+            .for_each(|(idx, chunk)| {
+                Self::load_single_record(idx, &db.code_gr, chunk, n_shards, code_length);
+            });
+
+        // Calculate the number of entries per shard
+        let mut db_lens = vec![db_entries.len() / self.code_length / n_shards; n_shards];
+        for i in 0..db_lens.len() {
+            if i < (db_entries.len() / self.code_length) % n_shards {
+                db_lens[i] += 1;
+            }
         }
 
-        (
-            SlicedProcessedDatabase {
-                code_gr: CudaVec2DSlicerRawPointer {
-                    limb_0: db0,
-                    limb_1: db1,
-                },
-                code_sums_gr: CudaVec2DSlicerU32 {
-                    limb_0: db0_sums,
-                    limb_1: db1_sums,
-                },
-            },
-            db_lens,
-        )
+        self.preprocess_db(db, &db_lens);
+
+        db_lens
     }
 
     pub fn query_sums(
@@ -767,6 +732,7 @@ mod tests {
         galois_engine::degree4::{GaloisRingIrisCodeShare, GaloisRingTrimmedMaskCodeShare},
         iris_db::db::IrisDB,
     };
+    use itertools::Itertools;
     use ndarray::Array2;
     use num_traits::FromPrimitive;
     use rand::{rngs::StdRng, Rng, SeedableRng};
@@ -800,6 +766,18 @@ mod tests {
             .collect()
     }
 
+    fn shard_db(db: &[u16], n_shards: usize) -> Vec<u16> {
+        let mut res: Vec<Vec<u16>> = vec![vec![]; n_shards];
+        db.iter()
+            .chunks(WIDTH)
+            .into_iter()
+            .enumerate()
+            .for_each(|(i, chunk)| {
+                res[i % n_shards].extend(chunk);
+            });
+        res.into_iter().flatten().collect::<Vec<_>>()
+    }
+
     /// Test to verify the matmul operation for random matrices in the field
     #[test]
     #[cfg(feature = "gpu_dependent")]
@@ -827,7 +805,8 @@ mod tests {
             .htod_transfer_query(&preprocessed_query, &streams, QUERY_SIZE, IRIS_CODE_LENGTH)
             .unwrap();
         let query_sums = engine.query_sums(&preprocessed_query, &streams, &blass);
-        let (db_slices, db_sizes) = engine.load_db(&db, DB_SIZE, DB_SIZE, false);
+        let mut db_slices = engine.alloc_db(DB_SIZE);
+        let db_sizes = engine.load_full_db(&mut db_slices, &db);
 
         engine.dot(
             &preprocessed_query,
@@ -840,7 +819,7 @@ mod tests {
         engine.dot_reduce(&query_sums, &db_slices.code_sums_gr, &db_sizes, 0, &streams);
         device_manager.await_streams(&streams);
 
-        let a_nda = random_ndarray::<u16>(db.clone(), DB_SIZE, WIDTH);
+        let a_nda = random_ndarray::<u16>(shard_db(&db, n_devices), DB_SIZE, WIDTH);
         let b_nda = random_ndarray::<u16>(query.clone(), QUERY_SIZE, WIDTH);
         let c_nda = a_nda.dot(&b_nda.t());
 
@@ -928,7 +907,9 @@ mod tests {
             .htod_transfer_query(&preprocessed_query, &streams, QUERY_SIZE, IRIS_CODE_LENGTH)
             .unwrap();
         let query_sums = engine.query_sums(&preprocessed_query, &streams, &blass);
-        let (db_slices, db_sizes) = engine.load_db(&codes_db, DB_SIZE, DB_SIZE, false);
+        let mut db_slices = engine.alloc_db(DB_SIZE);
+        let db_sizes = engine.load_full_db(&mut db_slices, &codes_db);
+
         engine.dot(
             &preprocessed_query,
             &db_slices.code_gr,
@@ -945,7 +926,8 @@ mod tests {
         for i in 0..DB_SIZE * QUERY_SIZE / n_devices {
             assert_eq!(
                 (gpu_result[0][i] + gpu_result[1][i] + gpu_result[2][i]),
-                (db.db[i / (DB_SIZE / n_devices)].mask & db.db[i % (DB_SIZE / n_devices)].mask)
+                (db.db[i / (DB_SIZE / n_devices)].mask
+                    & db.db[(i % (DB_SIZE / n_devices)) * n_devices].mask)
                     .count_ones() as u16
             );
         }
@@ -1056,10 +1038,10 @@ mod tests {
             .unwrap();
         let code_query_sums = codes_engine.query_sums(&code_query, &streams, &blass);
         let mask_query_sums = masks_engine.query_sums(&mask_query, &streams, &blass);
-        let (code_db_slices, db_sizes) =
-            codes_engine.load_db(&codes_db, DB_SIZE, DB_SIZE, false);
-        let (mask_db_slices, mask_db_sizes) =
-            masks_engine.load_db(&masks_db, DB_SIZE, DB_SIZE, false);
+        let mut code_db_slices = codes_engine.alloc_db(DB_SIZE);
+        let db_sizes = codes_engine.load_full_db(&mut code_db_slices, &codes_db);
+        let mut mask_db_slices = masks_engine.alloc_db(DB_SIZE);
+        let mask_db_sizes = masks_engine.load_full_db(&mut mask_db_slices, &masks_db);
 
         assert_eq!(db_sizes, mask_db_sizes);
 
@@ -1111,10 +1093,6 @@ mod tests {
             let code = results_codes[0][i] + results_codes[1][i] + results_codes[2][i];
             let mask = results_masks[0][i] + results_masks[1][i] + results_masks[2][i];
 
-            if i == 0 {
-                tracing::info!("Code: {}, Mask: {}", code, mask);
-            }
-
             reconstructed_codes.push(code);
             reconstructed_masks.push(mask);
         }
@@ -1129,12 +1107,9 @@ mod tests {
         // Compare against plain reference implementation
         let reference_dists = db.calculate_distances(&db.db[0]);
 
-        println!("Dists: {:?}", dists[0..10].to_vec());
-        println!("Ref Dists: {:?}", reference_dists[0..10].to_vec());
-
         // TODO: check for all devices and the whole query
         for i in 0..DB_SIZE / n_devices {
-            assert_float_eq!(dists[i], reference_dists[i], abs <= 1e-6);
+            assert_float_eq!(dists[i], reference_dists[i * n_devices], abs <= 1e-6);
         }
     }
 }
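
The new loader shards records round-robin across GPUs: record index lands on device index % n_shards at slot index / n_shards, which is why the reference checks above now index with i * n_devices, and why load_full_db gives the first n_records % n_shards devices one extra entry. A small self-contained check of that arithmetic (plain Rust, independent of the CUDA types in the diff):

    fn shard_of(index: usize, n_shards: usize) -> (usize, usize) {
        // Device that owns the record, and its slot within that device.
        (index % n_shards, index / n_shards)
    }

    fn shard_lens(n_records: usize, n_shards: usize) -> Vec<usize> {
        // Same formula as load_full_db: the first n_records % n_shards
        // devices hold one extra record.
        let mut lens = vec![n_records / n_shards; n_shards];
        for i in 0..n_records % n_shards {
            lens[i] += 1;
        }
        lens
    }

    fn main() {
        let (n_records, n_shards) = (10, 4);
        let mut counts = vec![0usize; n_shards];
        for idx in 0..n_records {
            let (dev, slot) = shard_of(idx, n_shards);
            assert_eq!(slot, counts[dev]); // slots fill densely per device
            counts[dev] += 1;
        }
        assert_eq!(counts, shard_lens(n_records, n_shards)); // [3, 3, 2, 2]
    }
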
diff --git a/iris-mpc-gpu/src/server/actor.rs b/iris-mpc-gpu/src/server/actor.rs
index 42edaaf95..7b3577b8a 100644
--- a/iris-mpc-gpu/src/server/actor.rs
+++ b/iris-mpc-gpu/src/server/actor.rs
@@ -96,6 +96,7 @@ pub struct ServerActor {
     current_db_sizes: Vec<usize>,
     query_db_size: Vec<usize>,
     max_batch_size: usize,
+    max_db_size: usize,
 }
 
 const NON_MATCH_ID: u32 = u32::MAX;
@@ -105,23 +106,17 @@ impl ServerActor {
     pub fn new(
         party_id: usize,
         chacha_seeds: ([u32; 8], [u32; 8]),
-        left_eye_db: IrisCodeDbSlice,
-        right_eye_db: IrisCodeDbSlice,
         job_queue_size: usize,
-        db_size: usize,
-        db_buffer: usize,
+        max_db_size: usize,
         max_batch_size: usize,
     ) -> eyre::Result<(Self, ServerActorHandle)> {
         let device_manager = Arc::new(DeviceManager::init());
         Self::new_with_device_manager(
             party_id,
             chacha_seeds,
-            left_eye_db,
-            right_eye_db,
             device_manager,
             job_queue_size,
-            db_size,
-            db_buffer,
+            max_db_size,
             max_batch_size,
         )
     }
@@ -129,12 +124,9 @@ impl ServerActor {
     pub fn new_with_device_manager(
         party_id: usize,
         chacha_seeds: ([u32; 8], [u32; 8]),
-        left_eye_db: IrisCodeDbSlice,
-        right_eye_db: IrisCodeDbSlice,
         device_manager: Arc<DeviceManager>,
         job_queue_size: usize,
-        db_size: usize,
-        db_buffer: usize,
+        max_db_size: usize,
         max_batch_size: usize,
     ) -> eyre::Result<(Self, ServerActorHandle)> {
         let ids = device_manager.get_ids_from_magic(0);
 
         Self::new_with_device_manager_and_comms(
             party_id,
             chacha_seeds,
-            left_eye_db,
-            right_eye_db,
             device_manager,
             comms,
             job_queue_size,
-            db_size,
-            db_buffer,
+            max_db_size,
             max_batch_size,
         )
     }
@@ -157,48 +146,20 @@ impl ServerActor {
     pub fn new_with_device_manager_and_comms(
         party_id: usize,
         chacha_seeds: ([u32; 8], [u32; 8]),
-        left_eye_db: IrisCodeDbSlice,
-        right_eye_db: IrisCodeDbSlice,
         device_manager: Arc<DeviceManager>,
         comms: Vec<Arc<Comm>>,
         job_queue_size: usize,
-        db_size: usize,
-        db_buffer: usize,
+        max_db_size: usize,
         max_batch_size: usize,
     ) -> eyre::Result<(Self, ServerActorHandle)> {
-        assert!(
-            [left_eye_db.0.len(), right_eye_db.0.len(),]
-                .iter()
-                .all(|&x| x == db_size * IRIS_CODE_LENGTH),
-            "Internal DB mismatch, left and right iris code db sizes differ, expected {}, left \
-             has {}, while right has {}",
-            db_size * IRIS_CODE_LENGTH,
-            left_eye_db.0.len(),
-            right_eye_db.0.len()
-        );
-
-        assert!(
-            [left_eye_db.1.len(), right_eye_db.1.len()]
-                .iter()
-                .all(|&x| x == db_size * MASK_CODE_LENGTH),
-            "Internal DB mismatch, left and right mask code db sizes differ, expected {}, left \
-             has {}, while right has {}",
-            db_size * MASK_CODE_LENGTH,
-            left_eye_db.1.len(),
-            right_eye_db.1.len()
-        );
-
         let (tx, rx) = mpsc::channel(job_queue_size);
         let actor = Self::init(
             party_id,
             chacha_seeds,
-            left_eye_db,
-            right_eye_db,
             device_manager,
             comms,
             rx,
-            db_size,
-            db_buffer,
+            max_db_size,
             max_batch_size,
         )?;
         Ok((actor, ServerActorHandle { job_queue: tx }))
@@ -208,13 +169,10 @@ impl ServerActor {
     fn init(
         party_id: usize,
         chacha_seeds: ([u32; 8], [u32; 8]),
-        left_eye_db: IrisCodeDbSlice,
-        right_eye_db: IrisCodeDbSlice,
         device_manager: Arc<DeviceManager>,
         comms: Vec<Arc<Comm>>,
         job_queue: mpsc::Receiver<ServerJob>,
-        db_size: usize,
-        db_buffer: usize,
+        max_db_size: usize,
         max_batch_size: usize,
     ) -> eyre::Result<Self> {
         let mut kdf_nonce = 0;
@@ -255,23 +213,10 @@ impl ServerActor {
             comms.clone(),
         );
 
-        // load left and right eye databases to device
-        let (left_code_db_slices, current_db_sizes) =
-            codes_engine.load_db(left_eye_db.0, db_size, db_size + db_buffer, true);
-        let (left_mask_db_slices, left_mask_db_sizes) =
-            masks_engine.load_db(left_eye_db.1, db_size, db_size + db_buffer, true);
-
-        let (right_code_db_slices, right_db_sizes) =
-            codes_engine.load_db(right_eye_db.0, db_size, db_size + db_buffer, true);
-        let (right_mask_db_slices, right_mask_db_sizes) =
-            masks_engine.load_db(right_eye_db.1, db_size, db_size + db_buffer, true);
-
-        assert!(
-            [left_mask_db_sizes, right_mask_db_sizes, right_db_sizes]
-                .iter()
-                .all(|size| size == &current_db_sizes),
-            "Code and mask db sizes mismatch"
-        );
+        let left_code_db_slices = codes_engine.alloc_db(max_db_size);
+        let left_mask_db_slices = masks_engine.alloc_db(max_db_size);
+        let right_code_db_slices = codes_engine.alloc_db(max_db_size);
+        let right_mask_db_slices = masks_engine.alloc_db(max_db_size);
 
         // Engines for inflight queries
         let batch_codes_engine = ShareDB::init(
@@ -343,15 +288,17 @@ impl ServerActor {
         let results = distance_comparator.prepare_results();
         let batch_results = distance_comparator.prepare_results();
 
-        let db_match_list_left = distance_comparator
-            .prepare_db_match_list((db_size + db_buffer) / device_manager.device_count());
-        let db_match_list_right = distance_comparator
-            .prepare_db_match_list((db_size + db_buffer) / device_manager.device_count());
+        let db_match_list_left =
+            distance_comparator.prepare_db_match_list(max_db_size / device_manager.device_count());
+        let db_match_list_right =
+            distance_comparator.prepare_db_match_list(max_db_size / device_manager.device_count());
         let batch_match_list_left = distance_comparator.prepare_db_match_list(n_queries);
         let batch_match_list_right = distance_comparator.prepare_db_match_list(n_queries);
 
         let query_db_size = vec![n_queries; device_manager.device_count()];
 
+        let current_db_sizes = vec![0; device_manager.device_count()];
+
         for dev in device_manager.devices() {
             dev.synchronize().unwrap();
         }
@@ -384,6 +331,7 @@ impl ServerActor {
             batch_match_list_left,
             batch_match_list_right,
             max_batch_size,
+            max_db_size,
         })
     }
 
@@ -398,6 +346,104 @@ impl ServerActor {
         tracing::info!("Server Actor finished due to all job queues being closed");
     }
 
+    pub fn load_full_db(
+        &mut self,
+        left: &IrisCodeDbSlice,
+        right: &IrisCodeDbSlice,
+        db_size: usize,
+    ) {
+        assert!(
+            [left.0.len(), right.0.len(),]
+                .iter()
+                .all(|&x| x == db_size * IRIS_CODE_LENGTH),
+            "Internal DB mismatch, left and right iris code db sizes differ, expected {}, left \
+             has {}, while right has {}",
+            db_size * IRIS_CODE_LENGTH,
+            left.0.len(),
+            right.0.len()
+        );
+
+        assert!(
+            [left.1.len(), right.1.len()]
+                .iter()
+                .all(|&x| x == db_size * MASK_CODE_LENGTH),
+            "Internal DB mismatch, left and right mask code db sizes differ, expected {}, left \
+             has {}, while right has {}",
+            db_size * MASK_CODE_LENGTH,
+            left.1.len(),
+            right.1.len()
+        );
+
+        let db_lens1 = self
+            .codes_engine
+            .load_full_db(&mut self.left_code_db_slices, left.0);
+        let db_lens2 = self
+            .masks_engine
+            .load_full_db(&mut self.left_mask_db_slices, left.1);
+        let db_lens3 = self
+            .codes_engine
+            .load_full_db(&mut self.right_code_db_slices, right.0);
+        let db_lens4 = self
+            .masks_engine
+            .load_full_db(&mut self.right_mask_db_slices, right.1);
+
+        assert_eq!(db_lens1, db_lens2);
+        assert_eq!(db_lens1, db_lens3);
+        assert_eq!(db_lens1, db_lens4);
+
+        self.current_db_sizes = db_lens1;
+    }
+
+    pub fn load_single_record(
+        &mut self,
+        index: usize,
+        left_code: &[u16],
+        left_mask: &[u16],
+        right_code: &[u16],
+        right_mask: &[u16],
+    ) {
+        ShareDB::load_single_record(
+            index,
+            &self.left_code_db_slices.code_gr,
+            left_code,
+            self.device_manager.device_count(),
+            IRIS_CODE_LENGTH,
+        );
+        ShareDB::load_single_record(
+            index,
+            &self.left_mask_db_slices.code_gr,
+            left_mask,
+            self.device_manager.device_count(),
+            MASK_CODE_LENGTH,
+        );
+        ShareDB::load_single_record(
+            index,
+            &self.right_code_db_slices.code_gr,
+            right_code,
+            self.device_manager.device_count(),
+            IRIS_CODE_LENGTH,
+        );
+        ShareDB::load_single_record(
+            index,
+            &self.right_mask_db_slices.code_gr,
+            right_mask,
+            self.device_manager.device_count(),
+            MASK_CODE_LENGTH,
+        );
+        self.current_db_sizes[index % self.device_manager.device_count()] += 1;
+    }
+
+    pub fn preprocess_db(&mut self) {
+        self.codes_engine
+            .preprocess_db(&mut self.left_code_db_slices, &self.current_db_sizes);
+        self.masks_engine
+            .preprocess_db(&mut self.left_mask_db_slices, &self.current_db_sizes);
+        self.codes_engine
+            .preprocess_db(&mut self.right_code_db_slices, &self.current_db_sizes);
+        self.masks_engine
+            .preprocess_db(&mut self.right_mask_db_slices, &self.current_db_sizes);
+    }
+
     fn process_batch_query(
         &mut self,
         batch: BatchQuery,
@@ -438,6 +484,14 @@ impl ServerActor {
         for deletion_index in batch.deletion_requests_indices.clone() {
             let device_index = deletion_index % self.device_manager.device_count() as u32;
             let device_db_index = deletion_index / self.device_manager.device_count() as u32;
+            if device_db_index as usize >= self.current_db_sizes[device_index as usize] {
+                tracing::warn!(
+                    "Deletion index {} is out of bounds for device {}",
+                    deletion_index,
+                    device_index
+                );
+                continue;
+            }
             self.device_manager
                 .device(device_index as usize)
                 .bind_to_thread()
                 .unwrap();
@@ -716,6 +770,19 @@ impl ServerActor {
 
         // Write back to in-memory db
         let previous_total_db_size = self.current_db_sizes.iter().sum::<usize>();
+        let n_insertions = insertion_list.iter().map(|x| x.len()).sum::<usize>();
+
+        // Check if we actually have space left to write the new entries
+        if previous_total_db_size + n_insertions > self.max_db_size {
+            tracing::error!(
+                "Cannot write new entries, since DB size would be exceeded, current: {}, batch \
+                 insertions: {}, max: {}",
+                previous_total_db_size,
+                n_insertions,
+                self.max_db_size
+            );
+            eyre::bail!("DB size exceeded");
+        }
 
         record_stream_time!(
             &self.device_manager,
@@ -743,7 +810,6 @@ impl ServerActor {
                     self.current_db_sizes[i] += 1;
                 }
 
-                // DEBUG
                 tracing::debug!(
                     "Updating DB size on device {}: {:?}",
                     i,
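
With the database no longer fixed at construction time, the actor starts from current_db_sizes of all zeros and must guard insertions against the preallocated capacity, which is what the new max_db_size check above does. A minimal sketch of that guard, with plain values standing in for the actor's fields and a simplified round-robin placement:

    // Sketch only: mock of the capacity check added to process_batch_query.
    fn try_insert(current_db_sizes: &mut [usize], max_db_size: usize, batch: usize) -> Result<(), String> {
        let total: usize = current_db_sizes.iter().sum();
        if total + batch > max_db_size {
            return Err(format!(
                "Cannot write new entries, current: {}, batch insertions: {}, max: {}",
                total, batch, max_db_size
            ));
        }
        // Distribute the new records over the devices (simplified round-robin).
        for i in 0..batch {
            current_db_sizes[i % current_db_sizes.len()] += 1;
        }
        Ok(())
    }

    fn main() {
        let mut sizes = vec![0usize; 4];
        assert!(try_insert(&mut sizes, 10, 8).is_ok());
        assert_eq!(sizes, vec![2, 2, 2, 2]);
        assert!(try_insert(&mut sizes, 10, 3).is_err()); // would exceed max
    }
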
diff --git a/iris-mpc-gpu/tests/e2e.rs b/iris-mpc-gpu/tests/e2e.rs
index d5641279c..b6a677bad 100644
--- a/iris-mpc-gpu/tests/e2e.rs
+++ b/iris-mpc-gpu/tests/e2e.rs
@@ -110,16 +110,14 @@ async fn e2e_test() -> Result<()> {
             let actor = match ServerActor::new_with_device_manager_and_comms(
                 0,
                 chacha_seeds0,
-                (&db0.0, &db0.1),
-                (&db0.0, &db0.1),
                 device_manager0,
                 comms0,
                 8,
-                DB_SIZE,
-                DB_BUFFER,
+                DB_SIZE + DB_BUFFER,
                 MAX_BATCH_SIZE,
             ) {
-                Ok((actor, handle)) => {
+                Ok((mut actor, handle)) => {
+                    actor.load_full_db(&(&db0.0, &db0.1), &(&db0.0, &db0.1), DB_SIZE);
                     tx0.send(Ok(handle)).unwrap();
                     actor
                 }
@@ -137,16 +135,14 @@ async fn e2e_test() -> Result<()> {
             let actor = match ServerActor::new_with_device_manager_and_comms(
                 1,
                 chacha_seeds1,
-                (&db1.0, &db1.1),
-                (&db1.0, &db1.1),
                 device_manager1,
                 comms1,
                 8,
-                DB_SIZE,
-                DB_BUFFER,
+                DB_SIZE + DB_BUFFER,
                 MAX_BATCH_SIZE,
             ) {
-                Ok((actor, handle)) => {
+                Ok((mut actor, handle)) => {
+                    actor.load_full_db(&(&db1.0, &db1.1), &(&db1.0, &db1.1), DB_SIZE);
                     tx1.send(Ok(handle)).unwrap();
                     actor
                 }
@@ -164,16 +160,14 @@ async fn e2e_test() -> Result<()> {
             let actor = match ServerActor::new_with_device_manager_and_comms(
                 2,
                 chacha_seeds2,
-                (&db2.0, &db2.1),
-                (&db2.0, &db2.1),
                 device_manager2,
                 comms2,
                 8,
-                DB_SIZE,
-                DB_BUFFER,
+                DB_SIZE + DB_BUFFER,
                 MAX_BATCH_SIZE,
             ) {
-                Ok((actor, handle)) => {
+                Ok((mut actor, handle)) => {
+                    actor.load_full_db(&(&db2.0, &db2.1), &(&db2.0, &db2.1), DB_SIZE);
                     tx2.send(Ok(handle)).unwrap();
                     actor
                 }
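
The test changes follow the same lifecycle the server now uses: construct the actor with empty tables, hydrate it, and only then publish its handle before entering the blocking run loop. A self-contained sketch of that ordering with std threads and channels; Actor and Handle are hypothetical stand-ins for ServerActor and ServerActorHandle:

    use std::{sync::mpsc, thread};

    struct Actor { db_len: usize } // hypothetical stand-in for ServerActor
    struct Handle;                 // hypothetical stand-in for ServerActorHandle

    fn main() {
        let (tx, rx) = mpsc::channel();
        let worker = thread::spawn(move || {
            // 1. Build the actor with empty, pre-allocated tables.
            let mut actor = Actor { db_len: 0 };
            // 2. Hydrate it before anyone can submit work (stands in for load_full_db).
            actor.db_len = 1000;
            // 3. Only then publish the handle, so callers never see a half-loaded actor.
            tx.send(Handle).unwrap();
            actor.db_len // stands in for actor.run()
        });
        let _handle: Handle = rx.recv().unwrap();
        assert_eq!(worker.join().unwrap(), 1000);
    }
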
diff --git a/iris-mpc-store/src/lib.rs b/iris-mpc-store/src/lib.rs
index 583e4c0ff..1f1aa93a6 100644
--- a/iris-mpc-store/src/lib.rs
+++ b/iris-mpc-store/src/lib.rs
@@ -78,7 +78,7 @@ struct StoredState {
     request_id: String,
 }
 
-#[derive(Clone)]
+#[derive(Clone, Debug)]
 pub struct Store {
     pool: PgPool,
 }
@@ -334,7 +334,7 @@ DO UPDATE SET right_code = EXCLUDED.right_code, right_mask = EXCLUDED.right_mask
             self.rollback(0).await?;
         }
 
-        let mut tx = self.tx().await.unwrap();
+        let mut tx = self.tx().await?;
 
         for i in 0..db_size {
             if (i % 1000) == 0 {
@@ -367,10 +367,20 @@ DO UPDATE SET right_code = EXCLUDED.right_code, right_mask = EXCLUDED.right_mask
                 right_mask: &mask.coefs,
             }])
             .await?;
+
+            if (i % 1000) == 0 {
+                tx.commit().await?;
+                tx = self.tx().await?;
+            }
         }
         tracing::info!("Completed initialization of iris db, committing...");
         tx.commit().await?;
         tracing::info!("Committed");
+
+        tracing::info!(
+            "Initialized iris db with {} entries",
+            self.count_irises().await?
+        );
         Ok(())
     }
 }
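
The store change above bounds transaction size during seeding: instead of one giant transaction over the whole init run, it commits every 1000 inserted rows and reopens a fresh transaction. A sketch of the pattern with a mock transaction type (the real code uses sqlx transactions; Tx here is hypothetical):

    // Sketch only: periodic-commit loop with a mock transaction.
    struct Tx { pending: usize }
    impl Tx {
        fn begin() -> Tx { Tx { pending: 0 } }
        fn insert(&mut self) { self.pending += 1; }
        fn commit(self) -> usize { self.pending }
    }

    fn main() {
        let mut committed = 0;
        let mut tx = Tx::begin();
        for i in 0..2500 {
            tx.insert();
            // Commit at the end of each 1000-row batch and start a fresh
            // transaction, so uncommitted state stays bounded regardless of
            // the total db size.
            if i % 1000 == 999 {
                committed += tx.commit();
                tx = Tx::begin();
            }
        }
        committed += tx.commit(); // final partial batch
        assert_eq!(committed, 2500);
    }
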
diff --git a/iris-mpc/src/bin/server.rs b/iris-mpc/src/bin/server.rs
index d65c504e5..3d1a4dc98 100644
--- a/iris-mpc/src/bin/server.rs
+++ b/iris-mpc/src/bin/server.rs
@@ -21,13 +21,12 @@ use iris_mpc_common::{
         sync::SyncState,
         task_monitor::TaskMonitor,
     },
-    IrisCodeDb,
     IRIS_CODE_LENGTH, MASK_CODE_LENGTH,
 };
 use iris_mpc_gpu::{
     helpers::device_manager::DeviceManager,
     server::{
-        get_dummy_shares_for_deletion, heartbeat_nccl::start_heartbeat, sync_nccl, BatchMetadata,
-        BatchQuery, ServerActor, ServerJobResult,
+        get_dummy_shares_for_deletion, sync_nccl, BatchMetadata, BatchQuery, ServerActor,
+        ServerJobResult,
     },
 };
 use iris_mpc_store::{Store, StoredIrisRef};
@@ -49,7 +48,7 @@ use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
 
 const REGION: &str = "eu-north-1";
-const DB_BUFFER: usize = 8 * 1_000;
+const MAX_DB_SIZE: usize = 8 * 2_000;
 const RNG_SEED_INIT_DB: u64 = 42;
 const SQS_POLLING_INTERVAL: Duration = Duration::from_secs(1);
 const MAX_CONCURRENT_REQUESTS: usize = 32;
@@ -448,77 +447,6 @@ async fn initialize_chacha_seeds(
     Ok(chacha_seeds)
 }
 
-async fn initialize_iris_dbs(
-    party_id: usize,
-    store: &Store,
-    config: &Config,
-) -> eyre::Result<(IrisCodeDb, IrisCodeDb, usize)> {
-    // Generate or load DB
-
-    tracing::info!("Initialize persistent iris db with randomly generated shares");
-    store
-        .init_db_with_random_shares(
-            RNG_SEED_INIT_DB,
-            party_id,
-            config.init_db_size,
-            config.clear_db_before_init,
-        )
-        .await
-        .expect("failed to initialise db");
-
-    let count_irises = store.count_irises().await?;
-    tracing::info!("Initialize iris db: Counted {} entries in DB", count_irises);
-
-    let mut left_codes_db: Vec<u16> = vec![0u16; count_irises * IRIS_CODE_LENGTH];
-    let mut left_masks_db: Vec<u16> = vec![0u16; count_irises * MASK_CODE_LENGTH];
-    let mut right_codes_db: Vec<u16> = vec![0u16; count_irises * IRIS_CODE_LENGTH];
-    let mut right_masks_db: Vec<u16> = vec![0u16; count_irises * MASK_CODE_LENGTH];
-
-    let parallelism = config
-        .database
-        .as_ref()
-        .ok_or(eyre!("Missing database config"))?
-        .load_parallelism;
-
-    tracing::info!(
-        "Initialize iris db: Loading from DB (parallelism: {})",
-        parallelism
-    );
-    // Load DB from persistent storage.
-    let mut store_len = 0;
-    let mut stream = store.stream_irises_par(parallelism).await;
-    while let Some(iris) = stream.try_next().await? {
-        let iris_index = iris.index() - 1;
-        if iris_index >= count_irises {
-            return Err(eyre!("Inconsistent iris index {}", iris_index));
-        }
-
-        let start_code = iris_index * IRIS_CODE_LENGTH;
-        let start_mask = iris_index * MASK_CODE_LENGTH;
-        left_codes_db[start_code..start_code + IRIS_CODE_LENGTH].copy_from_slice(iris.left_code());
-        left_masks_db[start_mask..start_mask + MASK_CODE_LENGTH].copy_from_slice(iris.left_mask());
-        right_codes_db[start_code..start_code + IRIS_CODE_LENGTH]
-            .copy_from_slice(iris.right_code());
-        right_masks_db[start_mask..start_mask + MASK_CODE_LENGTH]
-            .copy_from_slice(iris.right_mask());
-
-        store_len += 1;
-        if (store_len % 10000) == 0 {
-            tracing::info!("Initialize iris db: Loaded {} entries from DB", store_len);
-        }
-    }
-    tracing::info!(
-        "Initialize iris db: Loaded {} entries from DB, done!",
-        store_len
-    );
-
-    Ok((
-        (left_codes_db, left_masks_db),
-        (right_codes_db, right_masks_db),
-        count_irises,
-    ))
-}
-
 async fn send_results_to_sns(
     result_events: Vec<String>,
     sns_client: &SNSClient,
@@ -610,9 +538,29 @@ async fn server_main(config: Config) -> eyre::Result<()> {
     )
     .await?;
 
-    tracing::info!("Initialize iris db");
-    let (mut left_iris_db, mut right_iris_db, store_len) =
-        initialize_iris_dbs(party_id, &store, &config).await?;
+    let store_len = store.count_irises().await?;
+
+    // Seed the persistent storage with random shares if configured and db is still
+    // empty.
+    if store_len == 0 && config.init_db_size > 0 {
+        tracing::info!("Initialize persistent iris db with randomly generated shares");
+        store
+            .init_db_with_random_shares(
+                RNG_SEED_INIT_DB,
+                party_id,
+                config.init_db_size,
+                config.clear_db_before_init,
+            )
+            .await?;
+    }
+
+    // Fetch again in case we've just initialized the DB
+    let store_len = store.count_irises().await?;
+
+    if store_len > MAX_DB_SIZE {
+        tracing::error!("Database size exceeds maximum allowed size: {}", store_len);
+        eyre::bail!("Database size exceeds maximum allowed size: {}", store_len);
+    }
 
     let my_state = SyncState {
         db_len: store_len as u64,
@@ -622,17 +570,26 @@ async fn server_main(config: Config) -> eyre::Result<()> {
     tracing::info!("Preparing task monitor");
     let mut background_tasks = TaskMonitor::new();
 
-    let (tx, rx) = oneshot::channel();
-    let _heartbeat = background_tasks.spawn(start_heartbeat(config.party_id, tx));
+    // DEBUG: disable heartbeat
+    // let (tx, rx) = oneshot::channel();
+    // let _heartbeat = background_tasks.spawn(start_heartbeat(config.party_id,
+    // tx));
 
-    background_tasks.check_tasks();
-    tracing::info!("Heartbeat starting...");
-    rx.await??;
-    tracing::info!("Heartbeat started.");
+    // background_tasks.check_tasks();
+    // tracing::info!("Heartbeat starting...");
+    // rx.await??;
+    // tracing::info!("Heartbeat started.");
 
-    // a bit convoluted, but we need to create the actor on the thread already,
+    // Start the actor in separate task.
+    // A bit convoluted, but we need to create the actor on the thread already,
     // since it blocks a lot and is `!Send`, we get back the handle via the oneshot
     // channel
+    let parallelism = config
+        .database
+        .as_ref()
+        .ok_or(eyre!("Missing database config"))?
+ .load_parallelism; + let (tx, rx) = oneshot::channel(); background_tasks.spawn_blocking(move || { let device_manager = Arc::new(DeviceManager::init()); @@ -650,41 +607,56 @@ async fn server_main(config: Config) -> eyre::Result<()> { } }; - tracing::info!("DB: check if rollback needed"); if let Some(db_len) = sync_result.must_rollback_storage() { - tracing::warn!( - "Databases are out-of-sync, rolling back (current len: {}, new len: {})", - store_len, - db_len - ); - // Rollback the data that we have already loaded. - let bit_len_code = db_len * IRIS_CODE_LENGTH; - let bit_len_mask = db_len * MASK_CODE_LENGTH; - - // TODO: remove the line below if you removed fake data. - let bit_len_code = bit_len_code + (left_iris_db.0.len() - store_len * IRIS_CODE_LENGTH); - let bit_len_mask = bit_len_mask + (left_iris_db.1.len() - store_len * MASK_CODE_LENGTH); - left_iris_db.0.truncate(bit_len_code); - left_iris_db.1.truncate(bit_len_mask); - right_iris_db.0.truncate(bit_len_code); - right_iris_db.1.truncate(bit_len_mask); + tracing::error!("Databases are out-of-sync: {:?}", sync_result); + if db_len + max_rollback < store_len { + return Err(eyre!( + "Refusing to rollback so much (from {} to {})", + store_len, + db_len, + )); + } + tokio::runtime::Handle::current().block_on(async { store.rollback(db_len).await })?; + tracing::error!("Rolled back to db_len={}", db_len); } tracing::info!("Starting server actor"); match ServerActor::new_with_device_manager_and_comms( config.party_id, chacha_seeds, - (&left_iris_db.0, &left_iris_db.1), - (&right_iris_db.0, &right_iris_db.1), device_manager, comms, 8, - store_len, - DB_BUFFER, + MAX_DB_SIZE, config.max_batch_size, ) { - Ok((actor, handle)) => { - tx.send(Ok((handle, sync_result))).unwrap(); + Ok((mut actor, handle)) => { + tracing::info!( + "Initialize iris db: Loading from DB (parallelism: {})", + parallelism + ); + tokio::runtime::Handle::current().block_on(async { + let mut stream = store.stream_irises_par(parallelism).await; + + while let Some(iris) = stream.try_next().await? { + if iris.index() > store_len { + return Err(eyre!("Inconsistent iris index {}", iris.index())); + } + actor.load_single_record( + iris.index() - 1, + iris.left_code(), + iris.left_mask(), + iris.right_code(), + iris.right_mask(), + ); + } + + actor.preprocess_db(); + + eyre::Ok(()) + })?; + + tx.send(Ok((handle, sync_result, store))).unwrap(); actor.run(); // forever } Err(e) => { @@ -695,20 +667,7 @@ async fn server_main(config: Config) -> eyre::Result<()> { Ok(()) }); - let (mut handle, sync_result) = rx.await??; - - if let Some(db_len) = sync_result.must_rollback_storage() { - tracing::error!("Databases are out-of-sync: {:?}", sync_result); - if db_len + max_rollback < store_len { - return Err(eyre!( - "Refusing to rollback so much (from {} to {})", - store_len, - db_len, - )); - } - store.rollback(db_len).await?; - tracing::error!("Rolled back to db_len={}", db_len); - } + let (mut handle, sync_result, store) = rx.await??; let mut skip_request_ids = sync_result.deleted_request_ids();
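
Taken together, the server now streams rows from Postgres straight into the actor's preallocated device buffers, bounds-checks each 1-based store index against the known length, and runs the sum preprocessing once after the last record, instead of first materializing four full host-side Vecs. A compact sketch of that flow; Record and Actor below are hypothetical stand-ins for the store rows and the GPU actor:

    // Sketch only: streaming hydration with a single preprocessing pass at the end.
    struct Record { index: usize }
    struct Actor { loaded: usize, preprocessed: bool }

    impl Actor {
        fn load_single_record(&mut self, _zero_based_idx: usize) { self.loaded += 1; }
        fn preprocess_db(&mut self) { self.preprocessed = true; }
    }

    fn main() {
        let store_len = 3;
        let rows = vec![Record { index: 1 }, Record { index: 2 }, Record { index: 3 }];
        let mut actor = Actor { loaded: 0, preprocessed: false };
        for iris in rows {
            // Store indices are 1-based; reject anything beyond the known length.
            assert!(iris.index <= store_len, "Inconsistent iris index {}", iris.index);
            actor.load_single_record(iris.index - 1);
        }
        // Sums are computed once, after all records are in place.
        actor.preprocess_db();
        assert_eq!((actor.loaded, actor.preprocessed), (3, true));
    }
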