From fabc34acf8f92a28e6dff9ef1147abe53055f71e Mon Sep 17 00:00:00 2001 From: Philipp Sippl Date: Mon, 19 Aug 2024 13:53:09 +0200 Subject: [PATCH 1/3] Fix: use correct db size (#252) use correct size --- iris-mpc/src/bin/server.rs | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/iris-mpc/src/bin/server.rs b/iris-mpc/src/bin/server.rs index cc639f7b1..5aa9b74a6 100644 --- a/iris-mpc/src/bin/server.rs +++ b/iris-mpc/src/bin/server.rs @@ -435,6 +435,11 @@ async fn server_main(config: Config) -> eyre::Result<()> { }; if let Some(db_len) = sync_result.must_rollback_storage() { + tracing::warn!( + "Databases are out-of-sync, rolling back (current len: {}, new len: {})", + store_len, + db_len + ); // Rollback the data that we have already loaded. let bit_len = db_len * IRIS_CODE_LENGTH; // TODO: remove the line below if you removed fake data. @@ -454,7 +459,7 @@ async fn server_main(config: Config) -> eyre::Result<()> { device_manager, comms, 8, - DB_SIZE, + store_len + DB_SIZE, // TODO: remove DB_SIZE you removed fake data. DB_BUFFER, ) { Ok((actor, handle)) => { From 3f062b83f62199681a2f7028abea7aee85da934b Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 19 Aug 2024 13:59:32 +0200 Subject: [PATCH 2/3] build(deps): bump aws-sdk-secretsmanager from 1.40.0 to 1.41.0 (#229) Bumps [aws-sdk-secretsmanager](https://github.com/awslabs/aws-sdk-rust) from 1.40.0 to 1.41.0. - [Release notes](https://github.com/awslabs/aws-sdk-rust/releases) - [Commits](https://github.com/awslabs/aws-sdk-rust/commits) --- updated-dependencies: - dependency-name: aws-sdk-secretsmanager dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- Cargo.lock | 13 +++++++------ Cargo.toml | 2 +- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index bf34ca459..2f2b1cdee 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -192,9 +192,9 @@ dependencies = [ [[package]] name = "aws-runtime" -version = "1.3.1" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87c5f920ffd1e0526ec9e70e50bf444db50b204395a0fa7016bbf9e31ea1698f" +checksum = "f42c2d4218de4dcd890a109461e2f799a1a2ba3bcd2cde9af88360f5df9266c6" dependencies = [ "aws-credential-types", "aws-sigv4", @@ -208,6 +208,7 @@ dependencies = [ "fastrand 2.1.0", "http 0.2.12", "http-body 0.4.6", + "once_cell", "percent-encoding", "pin-project-lite", "tracing", @@ -273,9 +274,9 @@ dependencies = [ [[package]] name = "aws-sdk-secretsmanager" -version = "1.40.0" +version = "1.41.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "66a0cc1d41792d2d383746c154f48521715c50f5d59e9cdf36ef763de3c2345f" +checksum = "ebe053ffc4ffe9e15de3c1354a06a64b5fda92f3f6f1013fa4a3276694085d8b" dependencies = [ "aws-credential-types", "aws-runtime", @@ -547,9 +548,9 @@ dependencies = [ [[package]] name = "aws-smithy-runtime-api" -version = "1.7.1" +version = "1.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30819352ed0a04ecf6a2f3477e344d2d1ba33d43e0f09ad9047c12e0d923616f" +checksum = "e086682a53d3aa241192aa110fa8dfce98f2f5ac2ead0de84d41582c7e8fdb96" dependencies = [ "aws-smithy-async", "aws-smithy-types", diff --git a/Cargo.toml b/Cargo.toml index d7222ccf8..89e329432 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,7 +14,7 @@ aws-sdk-kms = { version = "1.37.0" } aws-sdk-sns = { version = "1.37.0" } aws-sdk-sqs = { version = "1.36.0" } aws-sdk-s3 = { version = "1.42.0" } -aws-sdk-secretsmanager = { version = "1.40.0" } +aws-sdk-secretsmanager = { version = "1.41.0" } axum = "0.7" clap = { version = "4", features = ["derive", "env"] } base64 = "0.22.1" From c8caa4b3340695765f6c6177a169156557d04e47 Mon Sep 17 00:00:00 2001 From: Philipp Sippl Date: Mon, 19 Aug 2024 16:21:59 +0200 Subject: [PATCH 3/3] Join results of both eyes (#249) * wip * wip * wip * reset memory * clippy * cleanup * use same matching policy in batch --- iris-mpc-gpu/src/dot/distance_comparator.rs | 105 +++++++--- iris-mpc-gpu/src/dot/kernel.cu | 89 +++++--- iris-mpc-gpu/src/server/actor.rs | 215 +++++++++++--------- 3 files changed, 252 insertions(+), 157 deletions(-) diff --git a/iris-mpc-gpu/src/dot/distance_comparator.rs b/iris-mpc-gpu/src/dot/distance_comparator.rs index e365d9149..206e665b3 100644 --- a/iris-mpc-gpu/src/dot/distance_comparator.rs +++ b/iris-mpc-gpu/src/dot/distance_comparator.rs @@ -8,12 +8,14 @@ use std::sync::Arc; const PTX_SRC: &str = include_str!("kernel.cu"); const OPEN_RESULTS_FUNCTION: &str = "openResults"; -const MERGE_RESULTS_FUNCTION: &str = "mergeResults"; +const MERGE_DB_RESULTS_FUNCTION: &str = "mergeDbResults"; +const MERGE_BATCH_RESULTS_FUNCTION: &str = "mergeBatchResults"; pub struct DistanceComparator { pub device_manager: Arc, pub open_kernels: Vec, - pub merge_kernels: Vec, + pub merge_db_kernels: Vec, + pub merge_batch_kernels: Vec, pub query_length: usize, pub opened_results: Vec>, pub final_results: Vec>, @@ -25,7 +27,8 @@ impl DistanceComparator { pub fn init(query_length: usize, device_manager: Arc) -> Self { let ptx = compile_ptx(PTX_SRC).unwrap(); let mut open_kernels = Vec::new(); - let mut merge_kernels = Vec::new(); + let mut merge_db_kernels = Vec::new(); + let mut merge_batch_kernels = Vec::new(); let mut opened_results = vec![]; let mut final_results = vec![]; @@ -39,24 +42,29 @@ impl DistanceComparator { device .load_ptx(ptx.clone(), "", &[ OPEN_RESULTS_FUNCTION, - MERGE_RESULTS_FUNCTION, + MERGE_DB_RESULTS_FUNCTION, + MERGE_BATCH_RESULTS_FUNCTION, ]) .unwrap(); let open_results_function = device.get_func("", OPEN_RESULTS_FUNCTION).unwrap(); - let merge_results_function = device.get_func("", MERGE_RESULTS_FUNCTION).unwrap(); + let merge_db_results_function = device.get_func("", MERGE_DB_RESULTS_FUNCTION).unwrap(); + let merge_batch_results_function = + device.get_func("", MERGE_BATCH_RESULTS_FUNCTION).unwrap(); opened_results.push(device.htod_copy(results_init_host.clone()).unwrap()); final_results.push(device.htod_copy(final_results_init_host.clone()).unwrap()); open_kernels.push(open_results_function); - merge_kernels.push(merge_results_function); + merge_db_kernels.push(merge_db_results_function); + merge_batch_kernels.push(merge_batch_results_function); } Self { device_manager, open_kernels, - merge_kernels, + merge_db_kernels, + merge_batch_kernels, query_length, opened_results, final_results, @@ -71,10 +79,11 @@ impl DistanceComparator { results1: &[CudaView], results2: &[CudaView], results3: &[CudaView], - results_ptrs: &[CudaSlice], + matches_bitmap: &[CudaSlice], db_sizes: &[usize], real_db_sizes: &[usize], offset: usize, + total_db_sizes: &[usize], streams: &[CudaStream], ) { for i in 0..self.device_manager.device_count() { @@ -98,12 +107,13 @@ impl DistanceComparator { &results1[i], &results2[i], &results3[i], - &results_ptrs[i], + &matches_bitmap[i], db_sizes[i], self.query_length, offset, num_elements, real_db_sizes[i], + total_db_sizes[i], ), ) .unwrap(); @@ -111,34 +121,72 @@ impl DistanceComparator { } } - pub fn merge_results( + pub fn join_db_matches( &self, - match_results_self: &[CudaSlice], - match_results: &[CudaSlice], + matches_bitmap_left: &[CudaSlice], + matches_bitmap_right: &[CudaSlice], + final_results: &[CudaSlice], + db_sizes: &[usize], + streams: &[CudaStream], + ) { + self.join_matches( + matches_bitmap_left, + matches_bitmap_right, + final_results, + db_sizes, + streams, + &self.merge_db_kernels, + ); + } + + pub fn join_batch_matches( + &self, + matches_bitmap_left: &[CudaSlice], + matches_bitmap_right: &[CudaSlice], final_results: &[CudaSlice], streams: &[CudaStream], ) { - let num_elements = self.query_length / ROTATIONS; - let threads_per_block = 256; - let blocks_per_grid = num_elements.div_ceil(threads_per_block); - let cfg = LaunchConfig { - block_dim: (threads_per_block as u32, 1, 1), - grid_dim: (blocks_per_grid as u32, 1, 1), - shared_mem_bytes: 0, - }; + self.join_matches( + matches_bitmap_left, + matches_bitmap_right, + final_results, + &vec![self.query_length; self.device_manager.device_count()], + streams, + &self.merge_batch_kernels, + ); + } + fn join_matches( + &self, + matches_bitmap_left: &[CudaSlice], + matches_bitmap_right: &[CudaSlice], + final_results: &[CudaSlice], + db_sizes: &[usize], + streams: &[CudaStream], + kernels: &[CudaFunction], + ) { for i in 0..self.device_manager.device_count() { + let num_elements = (db_sizes[i] * self.query_length / ROTATIONS).div_ceil(64); + let threads_per_block = 256; + let blocks_per_grid = num_elements.div_ceil(threads_per_block); + let cfg = LaunchConfig { + block_dim: (threads_per_block as u32, 1, 1), + grid_dim: (blocks_per_grid as u32, 1, 1), + shared_mem_bytes: 0, + }; unsafe { - self.merge_kernels[i] + kernels[i] .clone() .launch_on_stream( &streams[i], cfg, ( - &match_results_self[i], - &match_results[i], + &matches_bitmap_left[i], + &matches_bitmap_right[i], &final_results[i], (self.query_length / ROTATIONS) as u64, + db_sizes[i] as u64, + num_elements as u64, ), ) .unwrap(); @@ -180,4 +228,15 @@ impl DistanceComparator { }) .collect::>() } + + pub fn prepare_db_match_list(&self, db_size: usize) -> Vec> { + (0..self.device_manager.device_count()) + .map(|i| { + self.device_manager + .device(i) + .alloc_zeros(db_size * self.query_length / ROTATIONS / 64) + .unwrap() + }) + .collect::>() + } } diff --git a/iris-mpc-gpu/src/dot/kernel.cu b/iris-mpc-gpu/src/dot/kernel.cu index c471578a6..e11e648c1 100644 --- a/iris-mpc-gpu/src/dot/kernel.cu +++ b/iris-mpc-gpu/src/dot/kernel.cu @@ -1,11 +1,14 @@ #define UINT_MAX 0xffffffff #define ROTATIONS 15 +#define ALL_ROTATIONS (2 * ROTATIONS + 1) #define IRIS_CODE_LENGTH 12800 #define U8 unsigned char -extern "C" __global__ void xor_assign_u8(U8 *lhs, U8 *rhs, int n) { +extern "C" __global__ void xor_assign_u8(U8 *lhs, U8 *rhs, int n) +{ int i = blockIdx.x * blockDim.x + threadIdx.x; - if (i < n) { + if (i < n) + { lhs[i] ^= rhs[i]; } } @@ -23,7 +26,7 @@ extern "C" __global__ void matmul_correct_and_reduce(int *c, unsigned short *out } } -extern "C" __global__ void openResults(unsigned long long *result1, unsigned long long *result2, unsigned long long *result3, unsigned int *output, size_t dbLength, size_t queryLength, size_t offset, size_t numElements, size_t realDbLen) +extern "C" __global__ void openResults(unsigned long long *result1, unsigned long long *result2, unsigned long long *result3, unsigned long long *output, size_t chunkLength, size_t queryLength, size_t offset, size_t numElements, size_t realChunkLen, size_t totalDbLen) { size_t idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx < numElements) @@ -31,53 +34,73 @@ extern "C" __global__ void openResults(unsigned long long *result1, unsigned lon unsigned long long result = result1[idx] ^ result2[idx] ^ result3[idx]; for (int i = 0; i < 64; i++) { - unsigned int queryIdx = (idx * 64 + i) / dbLength; - unsigned int dbIdx = (idx * 64 + i) % dbLength; + unsigned int queryIdx = (idx * 64 + i) / chunkLength; + unsigned int dbIdx = (idx * 64 + i) % chunkLength; bool match = (result & (1ULL << i)); // Check if we are out of bounds for the query or db - if (queryIdx >= queryLength || dbIdx >= realDbLen) { + if (queryIdx >= queryLength || dbIdx >= realChunkLen || !match) + { continue; } - // return db element with smallest index - if (match) - atomicMin(&output[queryIdx], dbIdx + offset); + unsigned int outputIdx = totalDbLen * (queryIdx / ALL_ROTATIONS) + dbIdx + offset; + atomicOr(&output[outputIdx / 64], (1ULL << (outputIdx % 64))); } } } -extern "C" __global__ void mergeResults(unsigned int *matchResultsSelf, unsigned int *matchResults, unsigned int *finalResults, size_t queryLength) +extern "C" __global__ void mergeDbResults(unsigned long long *matchResultsLeft, unsigned long long *matchResultsRight, unsigned int *finalResults, size_t queryLength, size_t dbLength, size_t numElements) { size_t idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx < queryLength) + if (idx < numElements) { - bool match = false; - - // Check if there is a match in the db - for (int r = 0; r <= ROTATIONS * 2; r++) + for (int i = 0; i < 64; i++) { - int oldIdx = idx * (2 * ROTATIONS + 1) + r; - if (matchResults[oldIdx] != UINT_MAX) - { - finalResults[idx] = matchResults[oldIdx]; - match = true; - } - } + unsigned int queryIdx = (idx * 64 + i) / dbLength; + unsigned int dbIdx = (idx * 64 + i) % dbLength; + bool matchLeft = (matchResultsLeft[idx] & (1ULL << i)); + bool matchRight = (matchResultsRight[idx] & (1ULL << i)); - // If there is a match in the db, we return the db index - if (match) - return; + // Check bounds + if (queryIdx >= queryLength || dbIdx >= dbLength) + continue; - // Check if there is a match in the query itelf - // We only need to check a single query, since we don't want to rotate double - int oldIdx = idx * (2 * ROTATIONS + 1) + ROTATIONS; - if (matchResultsSelf[oldIdx] != UINT_MAX && oldIdx != matchResultsSelf[oldIdx]) - { - finalResults[idx] = UINT_MAX - 1; // Set to UINT_MAX - 1 to indicate that the match is in the query itself - return; + // Current *AND* policy: only match if both eyes match + if (matchLeft && matchRight) + atomicMin(&finalResults[queryIdx], dbIdx); } + } +} - finalResults[idx] = UINT_MAX; +extern "C" __global__ void mergeBatchResults(unsigned long long *matchResultsSelfLeft, unsigned long long *matchResultsSelfRight, unsigned int *finalResults, size_t queryLength, size_t dbLength, size_t numElements) +{ + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < numElements) + { + for (int i = 0; i < 64; i++) + { + unsigned int queryIdx = (idx * 64 + i) / dbLength; + unsigned int dbIdx = (idx * 64 + i) % dbLength; + + // Check bounds + if (queryIdx >= queryLength || dbIdx >= dbLength) + continue; + + // Query is already considering rotations, ignore rotated db entries + if ((dbIdx - ROTATIONS) % ALL_ROTATIONS != 0) + continue; + + // Only consider results above the diagonal + if (queryIdx <= dbIdx / ALL_ROTATIONS) + continue; + + bool matchLeft = (matchResultsSelfLeft[idx] & (1ULL << i)); + bool matchRight = (matchResultsSelfRight[idx] & (1ULL << i)); + + // Current *AND* policy: only match if both eyes match + if (matchLeft && matchRight) + atomicMin(&finalResults[queryIdx], UINT_MAX - 1); + } } } diff --git a/iris-mpc-gpu/src/server/actor.rs b/iris-mpc-gpu/src/server/actor.rs index ca1d42728..f153820da 100644 --- a/iris-mpc-gpu/src/server/actor.rs +++ b/iris-mpc-gpu/src/server/actor.rs @@ -17,7 +17,7 @@ use cudarc::{ driver::{ result::{self, event::elapsed}, sys::CUevent, - CudaDevice, CudaSlice, CudaStream, DevicePtr, + CudaDevice, CudaSlice, CudaStream, DevicePtr, DeviceSlice, }, nccl::Comm, }; @@ -69,37 +69,37 @@ impl ServerActorHandle { const DB_CHUNK_SIZE: usize = 512; const QUERIES: usize = ROTATIONS * MAX_BATCH_SIZE; pub struct ServerActor { - job_queue: mpsc::Receiver, - device_manager: Arc, - party_id: usize, + job_queue: mpsc::Receiver, + device_manager: Arc, + party_id: usize, // engines - codes_engine: ShareDB, - masks_engine: ShareDB, - batch_codes_engine: ShareDB, - batch_masks_engine: ShareDB, - phase2: Circuits, - phase2_batch: Circuits, - distance_comparator: DistanceComparator, + codes_engine: ShareDB, + masks_engine: ShareDB, + batch_codes_engine: ShareDB, + batch_masks_engine: ShareDB, + phase2: Circuits, + phase2_batch: Circuits, + distance_comparator: DistanceComparator, // DB slices - left_code_db_slices: SlicedProcessedDatabase, - left_mask_db_slices: SlicedProcessedDatabase, - right_code_db_slices: SlicedProcessedDatabase, - right_mask_db_slices: SlicedProcessedDatabase, - streams: Vec>, - cublas_handles: Vec>, - results: Vec>, - batch_results: Vec>, - final_results: Vec>, - current_db_sizes: Vec, - query_db_size: Vec, + left_code_db_slices: SlicedProcessedDatabase, + left_mask_db_slices: SlicedProcessedDatabase, + right_code_db_slices: SlicedProcessedDatabase, + right_mask_db_slices: SlicedProcessedDatabase, + streams: Vec>, + cublas_handles: Vec>, + results: Vec>, + batch_results: Vec>, + final_results: Vec>, + db_match_list_left: Vec>, + db_match_list_right: Vec>, + batch_match_list_left: Vec>, + batch_match_list_right: Vec>, + current_db_sizes: Vec, + query_db_size: Vec, } const NON_MATCH_ID: u32 = u32::MAX; -const RESULTS_INIT_HOST: [u32; MAX_BATCH_SIZE * ROTATIONS] = - [NON_MATCH_ID; MAX_BATCH_SIZE * ROTATIONS]; -const FINAL_RESULTS_INIT_HOST: [u32; MAX_BATCH_SIZE] = [NON_MATCH_ID; MAX_BATCH_SIZE]; - impl ServerActor { #[allow(clippy::too_many_arguments)] pub fn new( @@ -319,6 +319,13 @@ impl ServerActor { let results = distance_comparator.prepare_results(); let batch_results = distance_comparator.prepare_results(); + let db_match_list_left = distance_comparator + .prepare_db_match_list((db_size + db_buffer) / device_manager.device_count()); + let db_match_list_right = distance_comparator + .prepare_db_match_list((db_size + db_buffer) / device_manager.device_count()); + let batch_match_list_left = distance_comparator.prepare_db_match_list(QUERIES); + let batch_match_list_right = distance_comparator.prepare_db_match_list(QUERIES); + let query_db_size = vec![QUERIES; device_manager.device_count()]; for dev in device_manager.devices() { @@ -347,6 +354,10 @@ impl ServerActor { final_results, current_db_sizes, query_db_size, + db_match_list_left, + db_match_list_right, + batch_match_list_left, + batch_match_list_right, }) } @@ -421,13 +432,12 @@ impl ServerActor { &self.cublas_handles[0], )?; - let merged_results_left = self.compare_query_against_db_and_self( + self.compare_query_against_db_and_self( &compact_device_queries_left, &compact_device_sums_left, &mut events, - batch_size, Eye::Left, - )?; + ); /////////////////////////////////////////////////////////////////// // COMPARE RIGHT EYE QUERIES @@ -464,31 +474,35 @@ impl ServerActor { &self.cublas_handles[0], )?; - let merged_results_right = self.compare_query_against_db_and_self( + self.compare_query_against_db_and_self( &compact_device_queries_right, &compact_device_sums_right, &mut events, - batch_size, Eye::Right, - )?; + ); /////////////////////////////////////////////////////////////////// // MERGE LEFT & RIGHT results /////////////////////////////////////////////////////////////////// - let mut merged_results = merged_results_left - .into_iter() - .zip(merged_results_right) - .map(|(left, right)| { - // If both eyes are matches with the same ID, return the ID - // This also covers the case where both are non-matches, since we return - // NON_MATCH in that case as well - if left == right { - left - } else { - NON_MATCH_ID - } - }) - .collect::>(); + + // Merge results and fetch matching indices + // Format: host_results[device_index][query_index] + self.distance_comparator.join_db_matches( + &self.db_match_list_left, + &self.db_match_list_right, + &self.final_results, + &self.current_db_sizes, + &self.streams[0], + ); + + self.distance_comparator.join_batch_matches( + &self.batch_match_list_left, + &self.batch_match_list_right, + &self.final_results, + &self.streams[0], + ); + + self.device_manager.await_streams(&self.streams[0]); // Iterate over a list of tracing payloads, and create logs with mappings to // payloads Log at least a "start" event using a log with trace.id @@ -501,6 +515,20 @@ impl ServerActor { "Protocol finished", ); } + + // Fetch the final results (blocking) + let mut host_results = self + .distance_comparator + .fetch_final_results(&self.final_results); + + // Truncate the results to the batch size + host_results.iter_mut().for_each(|x| x.truncate(batch_size)); + + // Evaluate the results across devices + // Format: merged_results[query_index] + let mut merged_results = + get_merged_results(&host_results, self.device_manager.device_count()); + // List the indices of the queries that did not match. let insertion_list = merged_results .iter() @@ -623,6 +651,16 @@ impl ServerActor { self.device_manager.await_streams(&self.streams[0]); self.device_manager.await_streams(&self.streams[1]); + // Reset the results buffers for reuse + for dst in &[ + &self.db_match_list_left, + &self.db_match_list_right, + &self.batch_match_list_left, + &self.batch_match_list_right, + ] { + reset_slice(self.device_manager.devices(), dst, 0, &self.streams[0]); + } + // ---- END RESULT PROCESSING ---- log_timers(events); @@ -639,9 +677,8 @@ impl ServerActor { compact_device_queries: &DeviceCompactQuery, compact_device_sums: &DeviceCompactSums, events: &mut HashMap<&str, Vec>>, - batch_size: usize, eye_db: Eye, - ) -> eyre::Result> { + ) { let batch_streams = &self.streams[0]; let batch_cublas = &self.cublas_handles[0]; @@ -650,6 +687,12 @@ impl ServerActor { Eye::Left => (&self.left_code_db_slices, &self.left_mask_db_slices), Eye::Right => (&self.right_code_db_slices, &self.right_mask_db_slices), }; + + let (db_match_bitmap, batch_match_bitmap) = match eye_db { + Eye::Left => (&self.db_match_list_left, &self.batch_match_list_left), + Eye::Right => (&self.db_match_list_right, &self.batch_match_list_right), + }; + // Transfer queries to device // ---- START BATCH DEDUP ---- @@ -712,11 +755,12 @@ impl ServerActor { &mut self.phase2_batch, &res, &self.distance_comparator, - &self.batch_results, + batch_match_bitmap, chunk_size, &db_sizes_batch, &db_sizes_batch, 0, + &db_sizes_batch, batch_streams, ); self.phase2_batch.return_result_buffer(res); @@ -882,11 +926,12 @@ impl ServerActor { &mut self.phase2, &res, &self.distance_comparator, - &self.results, + db_match_bitmap, max_chunk_size * QUERIES / 64, &dot_chunk_size, &chunk_size, offset, + &self.current_db_sizes, request_streams, ); self.phase2.return_result_buffer(res); @@ -933,52 +978,10 @@ impl ServerActor { self.device_manager.await_streams(&self.streams[1]); tracing::debug!(party_id = self.party_id, "db search finished"); - // ---- START RESULT PROCESSING ---- - - // Merge results and fetch matching indices - // Format: host_results[device_index][query_index] - self.distance_comparator.merge_results( - &self.batch_results, - &self.results, - &self.final_results, - &self.streams[0], - ); - - self.device_manager.await_streams(&self.streams[0]); - - // Fetch the final results (blocking) - let mut host_results = self - .distance_comparator - .fetch_final_results(&self.final_results); - - // Truncate the results to the batch size - host_results.iter_mut().for_each(|x| x.truncate(batch_size)); - - // Evaluate the results across devices - // Format: merged_results[query_index] - let merged_results = get_merged_results(&host_results, self.device_manager.device_count()); - // Reset the results buffers for reuse - reset_results( - self.device_manager.devices(), - &self.results, - &RESULTS_INIT_HOST, - &self.streams[0], - ); - reset_results( - self.device_manager.devices(), - &self.batch_results, - &RESULTS_INIT_HOST, - &self.streams[0], - ); - reset_results( - self.device_manager.devices(), - &self.final_results, - &FINAL_RESULTS_INIT_HOST, - &self.streams[0], - ); - - Ok(merged_results) + for dst in &[&self.results, &self.batch_results, &self.final_results] { + reset_slice(self.device_manager.devices(), dst, 0xff, &self.streams[0]); + } } } @@ -1027,11 +1030,12 @@ fn open( party: &mut Circuits, x: &[ChunkShare], distance_comparator: &DistanceComparator, - results_ptrs: &[CudaSlice], + matches_bitmap: &[CudaSlice], chunk_size: usize, db_sizes: &[usize], real_db_sizes: &[usize], offset: usize, + total_db_sizes: &[usize], streams: &[CudaStream], ) { let n_devices = x.len(); @@ -1062,10 +1066,11 @@ fn open( &a, &b, &c, - results_ptrs, + matches_bitmap, db_sizes, real_db_sizes, offset, + total_db_sizes, streams, ); } @@ -1109,15 +1114,23 @@ fn distribute_insertions(results: &[usize], db_sizes: &[usize]) -> Vec( devs: &[Arc], - dst: &[CudaSlice], - src: &[u32], + dst: &[CudaSlice], + value: u8, streams: &[CudaStream], ) { for i in 0..devs.len() { devs[i].bind_to_thread().unwrap(); - unsafe { result::memcpy_htod_async(*dst[i].device_ptr(), src, streams[i].stream) }.unwrap(); + unsafe { + result::memset_d8_async( + *dst[i].device_ptr(), + value, + dst[i].num_bytes(), + streams[i].stream, + ) + .unwrap(); + }; } }