From c2ba1ebe60846f34184ac66b7c6968e6de0e2502 Mon Sep 17 00:00:00 2001 From: philsippl Date: Sun, 5 Jan 2025 16:55:10 +0000 Subject: [PATCH] fix: ignore phantom matchers --- iris-mpc-gpu/src/dot/distance_comparator.rs | 3 ++- iris-mpc-gpu/src/dot/kernel.cu | 14 +++++++++----- iris-mpc-gpu/src/server/actor.rs | 15 +++++++++++++++ 3 files changed, 26 insertions(+), 6 deletions(-) diff --git a/iris-mpc-gpu/src/dot/distance_comparator.rs b/iris-mpc-gpu/src/dot/distance_comparator.rs index 1e9285e87..266eae775 100644 --- a/iris-mpc-gpu/src/dot/distance_comparator.rs +++ b/iris-mpc-gpu/src/dot/distance_comparator.rs @@ -144,6 +144,7 @@ impl DistanceComparator { match_distances_counters: &[CudaSlice], code_dots: &[ChunkShareView], mask_dots: &[ChunkShareView], + batch_size: usize, streams: &[CudaStream], ) { for i in 0..self.device_manager.device_count() { @@ -171,7 +172,7 @@ impl DistanceComparator { ptr_param(results3[i].device_ptr()), ptr_param(matches_bitmap[i].device_ptr()), usize_param(&db_sizes[i]), - usize_param(&self.query_length), + usize_param(&(batch_size * ROTATIONS)), usize_param(&offset), usize_param(&num_elements), usize_param(&real_db_sizes[i]), diff --git a/iris-mpc-gpu/src/dot/kernel.cu b/iris-mpc-gpu/src/dot/kernel.cu index be37ef6fc..16084f17d 100644 --- a/iris-mpc-gpu/src/dot/kernel.cu +++ b/iris-mpc-gpu/src/dot/kernel.cu @@ -70,12 +70,16 @@ extern "C" __global__ void openResults(unsigned long long *result1, unsigned lon } // Save the corresponding code and mask dots for later (match distributions) - // unsigned int match_distances_counter_idx = atomicAdd(&match_distances_counter[0], 1); - // match_distances_buffer_codes_a[match_distances_counter_idx] = code_dots_a[idx]; - // match_distances_buffer_codes_b[match_distances_counter_idx] = code_dots_b[idx]; - // match_distances_buffer_masks_a[match_distances_counter_idx] = mask_dots_a[idx]; - // match_distances_buffer_masks_b[match_distances_counter_idx] = mask_dots_b[idx]; + if (match_distances_counter[0] < UINT_MAX) + { + unsigned int match_distances_counter_idx = atomicAdd(&match_distances_counter[0], 1); + match_distances_buffer_codes_a[match_distances_counter_idx] = code_dots_a[idx]; + match_distances_buffer_codes_b[match_distances_counter_idx] = code_dots_b[idx]; + match_distances_buffer_masks_a[match_distances_counter_idx] = mask_dots_a[idx]; + match_distances_buffer_masks_b[match_distances_counter_idx] = mask_dots_b[idx]; + } + // Mark which results are matches with a bit in the output unsigned int outputIdx = totalDbLen * (queryIdx / ALL_ROTATIONS) + dbIdx + offset; atomicOr(&output[outputIdx / 64], (1ULL << (outputIdx % 64))); } diff --git a/iris-mpc-gpu/src/server/actor.rs b/iris-mpc-gpu/src/server/actor.rs index 1ac951817..f8031b67b 100644 --- a/iris-mpc-gpu/src/server/actor.rs +++ b/iris-mpc-gpu/src/server/actor.rs @@ -680,6 +680,7 @@ impl ServerActor { &compact_device_sums_left, &mut events, Eye::Left, + batch_size, ); /////////////////////////////////////////////////////////////////// @@ -727,6 +728,7 @@ impl ServerActor { &compact_device_sums_right, &mut events, Eye::Right, + batch_size, ); /////////////////////////////////////////////////////////////////// @@ -1069,6 +1071,7 @@ impl ServerActor { compact_device_sums: &DeviceCompactSums, events: &mut HashMap<&str, Vec>>, eye_db: Eye, + batch_size: usize, ) { let batch_streams = &self.streams[0]; let batch_cublas = &self.cublas_handles[0]; @@ -1101,6 +1104,15 @@ impl ServerActor { ), }; + // copy counters to host + let counter = self + .device_manager + .device(0) + .dtoh_sync_copy(&match_distances_counters[0]) + .unwrap(); + + tracing::info!("counter: {:?}", counter); + // ---- START BATCH DEDUP ---- tracing::info!(party_id = self.party_id, "Starting batch deduplication"); @@ -1439,6 +1451,7 @@ impl ServerActor { match_distances_counters, &code_dots, &mask_dots, + batch_size, request_streams, ); self.phase2.return_result_buffer(res); @@ -1620,6 +1633,7 @@ fn open( match_distances_counters: &[CudaSlice], code_dots: &[ChunkShareView], mask_dots: &[ChunkShareView], + batch_size: usize, streams: &[CudaStream], ) { let n_devices = x.len(); @@ -1661,6 +1675,7 @@ fn open( match_distances_counters, code_dots, mask_dots, + batch_size, streams, ); }