diff --git a/iris-mpc-gpu/src/dot/distance_comparator.rs b/iris-mpc-gpu/src/dot/distance_comparator.rs
index 455a9c70a..eb5a0c7c6 100644
--- a/iris-mpc-gpu/src/dot/distance_comparator.rs
+++ b/iris-mpc-gpu/src/dot/distance_comparator.rs
@@ -347,4 +347,21 @@ impl DistanceComparator {
             })
             .collect::<Vec<_>>()
     }
+
+    pub fn prepare_match_distances_buffer(&self, max_size: usize) -> Vec<CudaSlice<u16>> {
+        (0..self.device_manager.device_count())
+            .map(|i| {
+                self.device_manager
+                    .device(i)
+                    .alloc_zeros(max_size / self.device_manager.device_count())
+                    .unwrap()
+            })
+            .collect::<Vec<_>>()
+    }
+
+    pub fn prepare_match_distances_counter(&self) -> Vec<CudaSlice<u32>> {
+        (0..self.device_manager.device_count())
+            .map(|i| self.device_manager.device(i).alloc_zeros(1).unwrap())
+            .collect::<Vec<_>>()
+    }
 }
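The two helpers above give each device an equal slice of a `max_size`-entry distances buffer plus a one-element counter that device code can bump atomically to reserve a write slot. As a minimal CPU-side sketch of that reserve-then-write pattern (the function name and atomic types are illustrative only, not part of the diff; the real kernel would use CUDA `atomicAdd` on the `u32` counter):

```rust
use std::sync::atomic::{AtomicU16, AtomicU32, Ordering};

// CPU analogue of the pattern the per-device counter enables on the GPU:
// atomically reserve the next slot, write the distance there, and silently
// drop anything past the end of the fixed-size buffer.
fn record_match_distance(buffer: &[AtomicU16], counter: &AtomicU32, distance: u16) {
    let slot = counter.fetch_add(1, Ordering::Relaxed) as usize;
    if slot < buffer.len() {
        buffer[slot].store(distance, Ordering::Relaxed);
    }
}
```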
diff --git a/iris-mpc-gpu/src/dot/share_db.rs b/iris-mpc-gpu/src/dot/share_db.rs
index b6dcd210a..5842a4d91 100644
--- a/iris-mpc-gpu/src/dot/share_db.rs
+++ b/iris-mpc-gpu/src/dot/share_db.rs
@@ -132,8 +132,8 @@ pub struct ShareDB {
     comms: Vec<Arc<Comm>>,
     ones: Vec<CudaSlice<u8>>,
     intermediate_results: Vec<CudaSlice<i32>>,
-    pub results: Vec<CudaSlice<u8>>,
-    pub results_peer: Vec<CudaSlice<u8>>,
+    pub results: Vec<Vec<CudaSlice<u8>>>,
+    pub results_peer: Vec<Vec<CudaSlice<u8>>>,
     code_length: usize,
 }

@@ -183,25 +183,32 @@ impl ShareDB {
         // TODO: depending on the batch size, intermediate_results can get quite big, we
        // can perform the gemm in chunks to limit this
         let mut intermediate_results = vec![];
-        let mut results = vec![];
-        let mut results_peer = vec![];
+        let mut results = vec![vec![]; 2];
+        let mut results_peer = vec![vec![]; 2];
         let results_len = (max_db_length * query_length).div_ceil(64) * 64;

         for idx in 0..n_devices {
             unsafe {
                 intermediate_results.push(device_manager.device(idx).alloc(results_len).unwrap());
-                results.push(
-                    device_manager
-                        .device(idx)
-                        .alloc(results_len * std::mem::size_of::<u16>())
-                        .unwrap(),
-                );
-                results_peer.push(
-                    device_manager
-                        .device(idx)
-                        .alloc(results_len * std::mem::size_of::<u16>())
-                        .unwrap(),
-                );
+            }
+        }
+
+        for i in 0..2 {
+            for idx in 0..n_devices {
+                unsafe {
+                    results_peer[i].push(
+                        device_manager
+                            .device(idx)
+                            .alloc(results_len * std::mem::size_of::<u16>())
+                            .unwrap(),
+                    );
+                    results[i].push(
+                        device_manager
+                            .device(idx)
+                            .alloc(results_len * std::mem::size_of::<u16>())
+                            .unwrap(),
+                    );
+                }
             }
         }

@@ -589,6 +596,7 @@ impl ShareDB {
         offset: usize,
         streams: &[CudaStream],
         multiplier: u16,
+        results_idx: usize,
     ) {
         for idx in 0..self.device_manager.device_count() {
             assert!(
@@ -611,7 +619,7 @@
                 cfg,
                 (
                     &self.intermediate_results[idx],
-                    &mut self.results[idx],
+                    &mut self.results[results_idx][idx],
                     *db_sums.limb_0[idx].device_ptr(),
                     *db_sums.limb_1[idx].device_ptr(),
                     *query_sums.limb_0[idx].device_ptr(),
@@ -636,8 +644,17 @@
         chunk_sizes: &[usize],
         offset: usize,
         streams: &[CudaStream],
+        results_idx: usize,
     ) {
-        self.dot_reduce_and_multiply(query_sums, db_sums, chunk_sizes, offset, streams, 1);
+        self.dot_reduce_and_multiply(
+            query_sums,
+            db_sums,
+            chunk_sizes,
+            offset,
+            streams,
+            1,
+            results_idx,
+        );
     }

     fn single_xor_assign_u8(
@@ -700,18 +717,19 @@
         len: usize,
         idx: usize,
         streams: &[CudaStream],
+        results_idx: usize,
     ) -> CudaSlice<u32> {
         assert_eq!(len & 3, 0);
         let mut rand = unsafe {
             self.device_manager
                 .device(idx)
                 .alloc::<u32>(len >> 2)
-                .unwrap()
+                .unwrap() // TODO: fix, make this async
         };
         let mut rand_u8 = self.fill_my_rng_into_u8(&mut rand, idx, streams);
         self.single_xor_assign_u8(
             &mut rand_u8,
-            &self.results[idx].slice(..),
+            &self.results[results_idx][idx].slice(..),
             idx,
             len,
             streams,
         );
         rand
     }

-    fn otp_decrypt_rng_result(&mut self, len: usize, idx: usize, streams: &[CudaStream]) {
+    fn otp_decrypt_rng_result(
+        &mut self,
+        len: usize,
+        idx: usize,
+        streams: &[CudaStream],
+        results_idx: usize,
+    ) {
         assert_eq!(len & 3, 0);
         let mut rand = unsafe {
             self.device_manager
@@ -729,7 +753,7 @@
         };
         let rand_u8 = self.fill_their_rng_into_u8(&mut rand, idx, streams);
         self.single_xor_assign_u8(
-            &mut self.results_peer[idx].slice(..),
+            &mut self.results_peer[results_idx][idx].slice(..),
             &rand_u8,
             idx,
             len,
         );
     }

-    pub fn reshare_results(&mut self, db_sizes: &[usize], streams: &[CudaStream]) {
+    pub fn reshare_results(
+        &mut self,
+        db_sizes: &[usize],
+        streams: &[CudaStream],
+        results_idx: usize,
+    ) {
         let next_peer = (self.peer_id + 1) % 3;
         let prev_peer = (self.peer_id + 2) % 3;

         let send_bufs = (0..self.device_manager.device_count())
             .map(|idx| {
                 let len = db_sizes[idx] * self.query_length * 2;
-                self.otp_encrypt_rng_result(len, idx, streams)
+                self.otp_encrypt_rng_result(len, idx, streams, results_idx)
             })
             .collect_vec();
@@ -759,7 +788,7 @@
                 .send_view(&send_view, next_peer, &streams[idx])
                 .unwrap();

-            let mut recv_view = self.results_peer[idx].slice(..len);
+            let mut recv_view = self.results_peer[results_idx][idx].slice(..len);
             self.comms[idx]
                 .receive_view(&mut recv_view, prev_peer, &streams[idx])
                 .unwrap();
         }
         nccl::group_end().unwrap();
         for idx in 0..self.device_manager.device_count() {
             let len = db_sizes[idx] * self.query_length * 2;
-            self.otp_decrypt_rng_result(len, idx, streams);
+            self.otp_decrypt_rng_result(len, idx, streams, results_idx);
         }
     }

-    pub fn fetch_results(&self, results: &mut [u16], db_sizes: &[usize], device_id: usize) {
+    pub fn fetch_results(
+        &self,
+        results: &mut [u16],
+        db_sizes: &[usize],
+        device_id: usize,
+        results_idx: usize,
+    ) {
         unsafe {
-            let res_trans =
-                self.results[device_id].transmute(db_sizes[device_id] * self.query_length);
+            let res_trans = self.results[results_idx][device_id]
+                .transmute(db_sizes[device_id] * self.query_length);

             self.device_manager
                 .device(device_id)
@@ -783,25 +818,33 @@
         }
     }

-    pub fn result_chunk_shares<'a>(&'a self, db_sizes: &[usize]) -> Vec<ChunkShareView<'a, u16>> {
-        izip!(db_sizes, self.results.iter(), self.results_peer.iter())
-            .map(|(&len, xa, xb)| {
-                // SAFETY: All bit patterns are valid u16 values
-                let xa_view = unsafe {
-                    xa.transmute(len * self.query_length)
-                        .expect("len is correct")
-                };
-                // SAFETY: All bit patterns are valid u16 values
-                let xb_view = unsafe {
-                    xb.transmute(len * self.query_length)
-                        .expect("len is correct")
-                };
-                ChunkShareView {
-                    a: xa_view,
-                    b: xb_view,
-                }
-            })
-            .collect()
+    pub fn result_chunk_shares<'a>(
+        &'a self,
+        db_sizes: &[usize],
+        results_idx: usize,
+    ) -> Vec<ChunkShareView<'a, u16>> {
+        izip!(
+            db_sizes,
+            self.results[results_idx].iter(),
+            self.results_peer[results_idx].iter()
+        )
+        .map(|(&len, xa, xb)| {
+            // SAFETY: All bit patterns are valid u16 values
+            let xa_view = unsafe {
+                xa.transmute(len * self.query_length)
+                    .expect("len is correct")
+            };
+            // SAFETY: All bit patterns are valid u16 values
+            let xb_view = unsafe {
+                xb.transmute(len * self.query_length)
+                    .expect("len is correct")
+            };
+            ChunkShareView {
+                a: xa_view,
+                b: xb_view,
+            }
+        })
+        .collect()
     }
 }

@@ -902,7 +945,14 @@ mod tests {
             &streams,
             &blass,
         );
-        engine.dot_reduce(&query_sums, &db_slices.code_sums_gr, &db_sizes, 0, &streams);
+        engine.dot_reduce(
+            &query_sums,
+            &db_slices.code_sums_gr,
+            &db_sizes,
+            0,
+            &streams,
+            0,
+        );
         device_manager.await_streams(&streams);

         let a_nda = random_ndarray::<u64>(shard_db(&db, n_devices), DB_SIZE, WIDTH);
@@ -917,7 +967,7 @@
         }

         for device_idx in 0..n_devices {
-            engine.fetch_results(&mut gpu_result, &db_sizes, device_idx);
+            engine.fetch_results(&mut gpu_result, &db_sizes, device_idx, 0);
             let selected_elements: Vec<u16> = vec_column_major
                 .chunks(DB_SIZE)
                 .flat_map(|chunk| {
@@ -1004,9 +1054,16 @@
             &streams,
             &blass,
         );
-        engine.dot_reduce(&query_sums, &db_slices.code_sums_gr, &db_sizes, 0, &streams);
+        engine.dot_reduce(
+            &query_sums,
+            &db_slices.code_sums_gr,
+            &db_sizes,
+            0,
+            &streams,
+            0,
+        );
         device_manager.await_streams(&streams);
-        engine.fetch_results(&mut gpu_result[i], &db_sizes, 0);
+        engine.fetch_results(&mut gpu_result[i], &db_sizes, 0, 0);
     }

     for i in 0..DB_SIZE * QUERY_SIZE / n_devices {
@@ -1168,8 +1225,8 @@
         device_manager.await_streams(&streams);

         // TODO: fetch results also for other devices
-        codes_engine.fetch_results(&mut results_codes[party_id], &db_sizes, 0);
-        masks_engine.fetch_results(&mut results_masks[party_id], &db_sizes, 0);
+        codes_engine.fetch_results(&mut results_codes[party_id], &db_sizes, 0, 0);
+        masks_engine.fetch_results(&mut results_masks[party_id], &db_sizes, 0, 0);
     }

     // Reconstruct the results
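One orientation note on the reshare path above (not part of the diff): `otp_encrypt_rng_result` XORs the local result share with a keystream drawn from the RNG shared with the next peer (`fill_my_rng_into_u8`), and `otp_decrypt_rng_result` strips the previous peer's keystream from the received buffer (`fill_their_rng_into_u8`). A host-side sketch of that invariant with plain slices (hypothetical helper names):

```rust
// Sketch of the one-time pad used by reshare_results: sender and receiver
// hold the same seeded RNG, so the keystream XORed in by the sender is
// exactly the keystream the receiver XORs back out (XOR is its own inverse).
fn otp_encrypt(share: &[u8], keystream_with_next_peer: &[u8]) -> Vec<u8> {
    share
        .iter()
        .zip(keystream_with_next_peer)
        .map(|(s, k)| s ^ k)
        .collect()
}

fn otp_decrypt(received: &mut [u8], keystream_with_prev_peer: &[u8]) {
    for (r, k) in received.iter_mut().zip(keystream_with_prev_peer) {
        *r ^= k;
    }
}
```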
diff --git a/iris-mpc-gpu/src/helpers/query_processor.rs b/iris-mpc-gpu/src/helpers/query_processor.rs
index de27e4e5d..133b80347 100644
--- a/iris-mpc-gpu/src/helpers/query_processor.rs
+++ b/iris-mpc-gpu/src/helpers/query_processor.rs
@@ -249,6 +249,7 @@ impl DeviceCompactSums {
             db_sizes,
             offset,
             streams,
+            0,
         );
         mask_engine.dot_reduce_and_multiply(
             &self.mask_query,
@@ -257,6 +258,7 @@
             offset,
             streams,
             2,
+            0,
         );
     }

@@ -270,6 +272,7 @@
         database_sizes: &[usize],
         offset: usize,
         streams: &[CudaStream],
+        results_idx: usize,
     ) {
         code_engine.dot_reduce(
             &self.code_query,
@@ -277,6 +280,7 @@
             database_sizes,
             offset,
             streams,
+            results_idx,
         );
         mask_engine.dot_reduce_and_multiply(
             &self.mask_query,
@@ -285,6 +289,7 @@
             offset,
             streams,
             2,
+            results_idx,
         );
     }
 }
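Note on the two paths above: the first pair of hunks pins the hardcoded index `0`, while the second threads the caller's `results_idx` through to both engines. The actor diff below alternates that index per DB chunk; a minimal sketch of the ping-pong selection (chunk count illustrative):

```rust
// Double-buffer selection as used in the actor below: chunk k writes its
// dot-product output into result set k % 2, so the buffers of chunk k - 1
// (the other set) can still be reshared and fed into phase 2 in parallel.
fn main() {
    for db_chunk_idx in 0..4 {
        let write_set = db_chunk_idx % 2;
        let draining_set = (db_chunk_idx + 1) % 2;
        println!("chunk {db_chunk_idx}: produce results[{write_set}], drain results[{draining_set}]");
    }
}
```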
diff --git a/iris-mpc-gpu/src/server/actor.rs b/iris-mpc-gpu/src/server/actor.rs
index 6711c0884..9071ecdcb 100644
--- a/iris-mpc-gpu/src/server/actor.rs
+++ b/iris-mpc-gpu/src/server/actor.rs
@@ -77,44 +77,48 @@ const KDF_SALT: &str = "111a1a93518f670e9bb0c2c68888e2beb9406d4c4ed571dc77b801e6
 const SUPERMATCH_THRESHOLD: usize = 4_000;

 pub struct ServerActor {
-    job_queue:              mpsc::Receiver<ServerJob>,
-    device_manager:         Arc<DeviceManager>,
-    party_id:               usize,
+    job_queue: mpsc::Receiver<ServerJob>,
+    device_manager: Arc<DeviceManager>,
+    party_id: usize,
     // engines
-    codes_engine:           ShareDB,
-    masks_engine:           ShareDB,
-    batch_codes_engine:     ShareDB,
-    batch_masks_engine:     ShareDB,
-    phase2:                 Circuits,
-    phase2_batch:           Circuits,
-    distance_comparator:    DistanceComparator,
-    comms:                  Vec<Arc<Comm>>,
+    codes_engine: ShareDB,
+    masks_engine: ShareDB,
+    batch_codes_engine: ShareDB,
+    batch_masks_engine: ShareDB,
+    phase2: Circuits,
+    phase2_batch: Circuits,
+    distance_comparator: DistanceComparator,
+    comms: Vec<Arc<Comm>>,
     // DB slices
-    left_code_db_slices:    SlicedProcessedDatabase,
-    left_mask_db_slices:    SlicedProcessedDatabase,
-    right_code_db_slices:   SlicedProcessedDatabase,
-    right_mask_db_slices:   SlicedProcessedDatabase,
-    streams:                Vec<Vec<CudaStream>>,
-    cublas_handles:         Vec<Vec<CudaBlas>>,
-    results:                Vec<CudaSlice<u32>>,
-    batch_results:          Vec<CudaSlice<u32>>,
-    final_results:          Vec<CudaSlice<u32>>,
-    db_match_list_left:     Vec<CudaSlice<u64>>,
-    db_match_list_right:    Vec<CudaSlice<u64>>,
-    batch_match_list_left:  Vec<CudaSlice<u64>>,
+    left_code_db_slices: SlicedProcessedDatabase,
+    left_mask_db_slices: SlicedProcessedDatabase,
+    right_code_db_slices: SlicedProcessedDatabase,
+    right_mask_db_slices: SlicedProcessedDatabase,
+    streams: Vec<Vec<CudaStream>>,
+    cublas_handles: Vec<Vec<CudaBlas>>,
+    results: Vec<CudaSlice<u32>>,
+    batch_results: Vec<CudaSlice<u32>>,
+    final_results: Vec<CudaSlice<u32>>,
+    db_match_list_left: Vec<CudaSlice<u64>>,
+    db_match_list_right: Vec<CudaSlice<u64>>,
+    batch_match_list_left: Vec<CudaSlice<u64>>,
     batch_match_list_right: Vec<CudaSlice<u64>>,
-    current_db_sizes:       Vec<usize>,
-    query_db_size:          Vec<usize>,
-    max_batch_size:         usize,
-    max_db_size:            usize,
+    current_db_sizes: Vec<usize>,
+    query_db_size: Vec<usize>,
+    max_batch_size: usize,
+    max_db_size: usize,
     return_partial_results: bool,
-    disable_persistence:    bool,
-    enable_debug_timing:    bool,
-    code_chunk_buffers:     Vec<DBChunkBuffers>,
-    mask_chunk_buffers:     Vec<DBChunkBuffers>,
-    dot_events:             Vec<Vec<CUevent>>,
-    exchange_events:        Vec<Vec<CUevent>>,
-    phase2_events:          Vec<Vec<CUevent>>,
+    disable_persistence: bool,
+    enable_debug_timing: bool,
+    code_chunk_buffers: Vec<DBChunkBuffers>,
+    mask_chunk_buffers: Vec<DBChunkBuffers>,
+    dot_events: Vec<Vec<CUevent>>,
+    exchange_events: Vec<Vec<CUevent>>,
+    phase2_events: Vec<Vec<CUevent>>,
+    match_distances_buffer_left: Vec<CudaSlice<u16>>,
+    match_distances_buffer_right: Vec<CudaSlice<u16>>,
+    match_distances_counter_left: Vec<CudaSlice<u32>>,
+    match_distances_counter_right: Vec<CudaSlice<u32>>,
 }

 const NON_MATCH_ID: u32 = u32::MAX;

@@ -350,6 +354,14 @@ impl ServerActor {
         let exchange_events = vec![device_manager.create_events(); 2];
         let phase2_events = vec![device_manager.create_events(); 2];

+        // Buffers and counters for match distribution
+        let match_distances_buffer_left =
+            distance_comparator.prepare_match_distances_buffer(1_000_000); // TODO
+        let match_distances_buffer_right =
+            distance_comparator.prepare_match_distances_buffer(1_000_000); // TODO
+        let match_distances_counter_left = distance_comparator.prepare_match_distances_counter();
+        let match_distances_counter_right = distance_comparator.prepare_match_distances_counter();
+
         for dev in device_manager.devices() {
             dev.synchronize().unwrap();
         }
@@ -391,6 +403,10 @@
             dot_events,
             exchange_events,
             phase2_events,
+            match_distances_buffer_left,
+            match_distances_buffer_right,
+            match_distances_counter_left,
+            match_distances_counter_right,
         })
     }

@@ -1102,18 +1118,22 @@
             {
                 tracing::info!(party_id = self.party_id, "batch_reshare start");
                 self.batch_codes_engine
-                    .reshare_results(&self.query_db_size, batch_streams);
+                    .reshare_results(&self.query_db_size, batch_streams, 0);

                 tracing::info!(party_id = self.party_id, "batch_reshare masks start");
                 self.batch_masks_engine
-                    .reshare_results(&self.query_db_size, batch_streams);
+                    .reshare_results(&self.query_db_size, batch_streams, 0);
                 tracing::info!(party_id = self.party_id, "batch_reshare end");
             }
         );

         let db_sizes_batch =
             vec![self.max_batch_size * ROTATIONS; self.device_manager.device_count()];
-        let code_dots_batch = self.batch_codes_engine.result_chunk_shares(&db_sizes_batch);
-        let mask_dots_batch = self.batch_masks_engine.result_chunk_shares(&db_sizes_batch);
+        let code_dots_batch = self
+            .batch_codes_engine
+            .result_chunk_shares(&db_sizes_batch, 0);
+        let mask_dots_batch = self
+            .batch_masks_engine
+            .result_chunk_shares(&db_sizes_batch, 0);

         record_stream_time!(
             &self.device_manager,
@@ -1297,6 +1317,7 @@
                     &dot_chunk_size,
                     offset,
                     request_streams,
+                    db_chunk_idx % 2,
                 );
             }
         );
@@ -1311,10 +1332,16 @@
                 "db_reshare",
                 self.enable_debug_timing,
                 {
-                    self.codes_engine
-                        .reshare_results(&dot_chunk_size, request_streams);
-                    self.masks_engine
-                        .reshare_results(&dot_chunk_size, request_streams);
+                    self.codes_engine.reshare_results(
+                        &dot_chunk_size,
+                        request_streams,
+                        db_chunk_idx % 2,
+                    );
+                    self.masks_engine.reshare_results(
+                        &dot_chunk_size,
+                        request_streams,
+                        db_chunk_idx % 2,
+                    );
                 }
             );
@@ -1326,8 +1353,12 @@
             // ---- START PHASE 2 ----
             let max_chunk_size = dot_chunk_size.iter().max().copied().unwrap();
             let phase_2_chunk_sizes = vec![max_chunk_size; self.device_manager.device_count()];
-            let code_dots = self.codes_engine.result_chunk_shares(&phase_2_chunk_sizes);
-            let mask_dots = self.masks_engine.result_chunk_shares(&phase_2_chunk_sizes);
+            let code_dots = self
+                .codes_engine
+                .result_chunk_shares(&phase_2_chunk_sizes, db_chunk_idx % 2);
+            let mask_dots = self
+                .masks_engine
+                .result_chunk_shares(&phase_2_chunk_sizes, db_chunk_idx % 2);
             {
                 assert_eq!(
                     (max_chunk_size * self.max_batch_size * ROTATIONS) % 64,
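For context on the four new `ServerActor` fields: after a batch, each device's counter records how many distance shares its kernel reserved in that device's slice of the (still TODO-sized) 1,000,000-entry buffer. A host-side drain could look like the following sketch; `drain_match_distances` is hypothetical, and it assumes cudarc's `dtoh_sync_copy` together with the `DeviceManager` accessors already used in this PR:

```rust
use cudarc::driver::CudaSlice;
use iris_mpc_gpu::helpers::device_manager::DeviceManager; // path assumed from this repo's layout

// Hypothetical helper, not part of this PR: copy back each device's
// counter, then keep only the valid prefix of its distances buffer.
fn drain_match_distances(
    device_manager: &DeviceManager,
    buffers: &[CudaSlice<u16>],
    counters: &[CudaSlice<u32>],
) -> Vec<Vec<u16>> {
    (0..device_manager.device_count())
        .map(|i| {
            let device = device_manager.device(i);
            // One u32 per device: number of slots the kernel reserved.
            let count = device.dtoh_sync_copy(&counters[i]).unwrap()[0] as usize;
            let all = device.dtoh_sync_copy(&buffers[i]).unwrap();
            all[..count.min(all.len())].to_vec()
        })
        .collect()
}
```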