diff --git a/iris-mpc-gpu/benches/matmul.rs b/iris-mpc-gpu/benches/matmul.rs index 9f8f5032a..3b478392b 100644 --- a/iris-mpc-gpu/benches/matmul.rs +++ b/iris-mpc-gpu/benches/matmul.rs @@ -58,7 +58,14 @@ fn bench_memcpy(c: &mut Criterion) { &streams, &blass, ); - engine.dot_reduce(&query_sums, &db_slices.code_sums_gr, &db_sizes, 0, &streams); + engine.dot_reduce( + &query_sums, + &db_slices.code_sums_gr, + &db_sizes, + 0, + &streams, + 0, + ); device_manager.await_streams(&streams); }); }); diff --git a/iris-mpc-gpu/src/dot/distance_comparator.rs b/iris-mpc-gpu/src/dot/distance_comparator.rs index ba5353ff6..1e9285e87 100644 --- a/iris-mpc-gpu/src/dot/distance_comparator.rs +++ b/iris-mpc-gpu/src/dot/distance_comparator.rs @@ -8,12 +8,12 @@ use crate::{ }; use cudarc::{ driver::{ - result::launch_kernel, CudaFunction, CudaSlice, CudaStream, CudaView, DevicePtr, + result::launch_kernel, sys, CudaFunction, CudaSlice, CudaStream, CudaView, DevicePtr, LaunchAsync, }, nvrtc::compile_ptx, }; -use std::{cmp::min, sync::Arc}; +use std::{cmp::min, ffi::c_void, sync::Arc}; const PTX_SRC: &str = include_str!("kernel.cu"); const OPEN_RESULTS_FUNCTION: &str = "openResults"; @@ -161,26 +161,30 @@ impl DistanceComparator { ); self.device_manager.device(i).bind_to_thread().unwrap(); + let ptr_param = |ptr: *const sys::CUdeviceptr| ptr as *mut c_void; + let usize_param = |val: &usize| val as *const usize as *mut _; + let params = &mut [ - *results1[i].device_ptr() as *mut _, - *results2[i].device_ptr() as *mut _, - *results3[i].device_ptr() as *mut _, - *matches_bitmap[i].device_ptr() as *mut _, - db_sizes[i] as *mut _, - self.query_length as *mut _, - offset as *mut _, - num_elements as *mut _, - real_db_sizes[i] as *mut _, - total_db_sizes[i] as *mut _, - *match_distances_buffers_codes[i].a.device_ptr() as *mut _, - *match_distances_buffers_codes[i].b.device_ptr() as *mut _, - *match_distances_buffers_masks[i].a.device_ptr() as *mut _, - *match_distances_buffers_masks[i].b.device_ptr() as *mut _, - *match_distances_counters[i].device_ptr() as *mut _, - *code_dots[i].a.device_ptr() as *mut _, - *code_dots[i].b.device_ptr() as *mut _, - *mask_dots[i].a.device_ptr() as *mut _, - *mask_dots[i].b.device_ptr() as *mut _, + // Results arrays + ptr_param(results1[i].device_ptr()), + ptr_param(results2[i].device_ptr()), + ptr_param(results3[i].device_ptr()), + ptr_param(matches_bitmap[i].device_ptr()), + usize_param(&db_sizes[i]), + usize_param(&self.query_length), + usize_param(&offset), + usize_param(&num_elements), + usize_param(&real_db_sizes[i]), + usize_param(&total_db_sizes[i]), + ptr_param(match_distances_buffers_codes[i].a.device_ptr()), + ptr_param(match_distances_buffers_codes[i].b.device_ptr()), + ptr_param(match_distances_buffers_masks[i].a.device_ptr()), + ptr_param(match_distances_buffers_masks[i].b.device_ptr()), + ptr_param(match_distances_counters[i].device_ptr()), + ptr_param(code_dots[i].a.device_ptr()), + ptr_param(code_dots[i].b.device_ptr()), + ptr_param(mask_dots[i].a.device_ptr()), + ptr_param(mask_dots[i].b.device_ptr()), ]; unsafe { diff --git a/iris-mpc-gpu/src/dot/kernel.cu b/iris-mpc-gpu/src/dot/kernel.cu index b8d43b90a..be37ef6fc 100644 --- a/iris-mpc-gpu/src/dot/kernel.cu +++ b/iris-mpc-gpu/src/dot/kernel.cu @@ -70,11 +70,11 @@ extern "C" __global__ void openResults(unsigned long long *result1, unsigned lon } // Save the corresponding code and mask dots for later (match distributions) - unsigned int match_distances_counter_idx = atomicAdd(&match_distances_counter[0], 1); - match_distances_buffer_codes_a[match_distances_counter_idx] = code_dots_a[idx]; - match_distances_buffer_codes_b[match_distances_counter_idx] = code_dots_b[idx]; - match_distances_buffer_masks_a[match_distances_counter_idx] = mask_dots_a[idx]; - match_distances_buffer_masks_b[match_distances_counter_idx] = mask_dots_b[idx]; + // unsigned int match_distances_counter_idx = atomicAdd(&match_distances_counter[0], 1); + // match_distances_buffer_codes_a[match_distances_counter_idx] = code_dots_a[idx]; + // match_distances_buffer_codes_b[match_distances_counter_idx] = code_dots_b[idx]; + // match_distances_buffer_masks_a[match_distances_counter_idx] = mask_dots_a[idx]; + // match_distances_buffer_masks_b[match_distances_counter_idx] = mask_dots_b[idx]; unsigned int outputIdx = totalDbLen * (queryIdx / ALL_ROTATIONS) + dbIdx + offset; atomicOr(&output[outputIdx / 64], (1ULL << (outputIdx % 64))); diff --git a/iris-mpc-gpu/src/dot/share_db.rs b/iris-mpc-gpu/src/dot/share_db.rs index 693ddf803..f66de90c7 100644 --- a/iris-mpc-gpu/src/dot/share_db.rs +++ b/iris-mpc-gpu/src/dot/share_db.rs @@ -1213,6 +1213,7 @@ mod tests { &db_sizes, 0, &streams, + 0, ); masks_engine.dot_reduce_and_multiply( &mask_query_sums, @@ -1221,6 +1222,7 @@ mod tests { 0, &streams, 2, + 0, ); device_manager.await_streams(&streams); diff --git a/iris-mpc-gpu/src/threshold_ring/cuda/kernel.cu b/iris-mpc-gpu/src/threshold_ring/cuda/kernel.cu index a5931b087..397ec57fc 100644 --- a/iris-mpc-gpu/src/threshold_ring/cuda/kernel.cu +++ b/iris-mpc-gpu/src/threshold_ring/cuda/kernel.cu @@ -24,8 +24,16 @@ template __device__ void xor_assign_inner(T *lhs, T *rhs) { template __device__ void arithmetic_xor_inner(T *res_a, T *lhs_a, T *lhs_b, T *rhs_a, T *rhs_b, T *r1, T *r2) { - T mul = (*lhs_a * *rhs_a) + (*lhs_b * *rhs_a) + (*lhs_a * *rhs_b) + *r1 - *r2; - *res_a = lhs_a + rhs_a - 2 * mul; + T lhs_a_val = *lhs_a; + T lhs_b_val = *lhs_b; + T rhs_a_val = *rhs_a; + T rhs_b_val = *rhs_b; + T r1_val = *r1; + T r2_val = *r2; + + T mul = (lhs_a_val * rhs_a_val) + (lhs_b_val * rhs_a_val) + + (lhs_a_val * rhs_b_val) + r1_val - r2_val; + *res_a = lhs_a_val + rhs_a_val - 2 * mul; } // Computes the local part of the multiplication (including randomness) @@ -172,14 +180,6 @@ __device__ void u32_transpose_pack_u64(U64 *out_a, U64 *out_b, U32 *in_a, } } -__device__ void lift_mul_sub(U32 *mask, U16 *mask_corr1, U16 *mask_corr2, - U16 *code) { - U32 lifted; - finalize_lift(mask, &lifted, mask_corr1, mask_corr2, code); - *mask *= A; - *mask -= lifted; -} - __device__ void finalize_lift(U32 *mask, U32 *code_lift, U16 *mask_corr1, U16 *mask_corr2, U16 *code) { *mask -= (U32)(*mask_corr1) << 16; @@ -188,6 +188,14 @@ __device__ void finalize_lift(U32 *mask, U32 *code_lift, U16 *mask_corr1, mul_lift_b(code_lift, code); } +__device__ void lift_mul_sub(U32 *mask, U16 *mask_corr1, U16 *mask_corr2, + U16 *code) { + U32 lifted; + finalize_lift(mask, &lifted, mask_corr1, mask_corr2, code); + *mask *= A; + *mask -= lifted; +} + __device__ void lifted_sub(U32 *mask, U32 *code, U32 *output, U32 a) { *output = *mask * a - *code; } @@ -445,8 +453,8 @@ extern "C" __global__ void shared_lifted_sub(U32 *mask_a, U32 *mask_b, U32 a, int id, size_t n) { size_t i = blockIdx.x * blockDim.x + threadIdx.x; if (i < n) { - lifted_sub(&mask_a[i], &code_a[i], output_a[i], a); - lifted_sub(&mask_b[i], &code_b[i], output_n[i], a); + lifted_sub(&mask_a[i], &code_a[i], &output_a[i], a); + lifted_sub(&mask_b[i], &code_b[i], &output_b[i], a); switch (id) { case 0: mask_a[i] -= 1; // Transforms the <= into < @@ -545,9 +553,8 @@ extern "C" __global__ void collapse_u64_helper(U64 *inout_a, U64 *in_b, } } -extern "C" __global__ void collapse_sum_assign(u32 *inout_a, U32 *inout_b, +extern "C" __global__ void collapse_sum_assign(U32 *inout_a, U32 *inout_b, size_t n) { - size_t i = blockIdx.x * blockDim.x + threadIdx.x; if (i == 1) { for (size_t j = 1; j < n; j++) { @@ -560,9 +567,9 @@ extern "C" __global__ void collapse_sum_assign(u32 *inout_a, U32 *inout_b, } } -extern "C" __global__ void collapse_sum(u32 *inout_a, U32 *inout_b, input_a, - input_b, size_t inout_index, size_t n) { - +extern "C" __global__ void collapse_sum(U32 *inout_a, U32 *inout_b, + U32 *input_a, U32 *input_b, + size_t inout_index, size_t n) { size_t i = blockIdx.x * blockDim.x + threadIdx.x; if (i == 1) { for (size_t j = 0; j < n; j++) {