Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
philsippl committed Jan 5, 2025
1 parent 70bce2e commit 0f26aa5
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 44 deletions.
9 changes: 8 additions & 1 deletion iris-mpc-gpu/benches/matmul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,14 @@ fn bench_memcpy(c: &mut Criterion) {
&streams,
&blass,
);
engine.dot_reduce(&query_sums, &db_slices.code_sums_gr, &db_sizes, 0, &streams);
engine.dot_reduce(
&query_sums,
&db_slices.code_sums_gr,
&db_sizes,
0,
&streams,
0,
);
device_manager.await_streams(&streams);
});
});
Expand Down
46 changes: 25 additions & 21 deletions iris-mpc-gpu/src/dot/distance_comparator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ use crate::{
};
use cudarc::{
driver::{
result::launch_kernel, CudaFunction, CudaSlice, CudaStream, CudaView, DevicePtr,
result::launch_kernel, sys, CudaFunction, CudaSlice, CudaStream, CudaView, DevicePtr,
LaunchAsync,
},
nvrtc::compile_ptx,
};
use std::{cmp::min, sync::Arc};
use std::{cmp::min, ffi::c_void, sync::Arc};

const PTX_SRC: &str = include_str!("kernel.cu");
const OPEN_RESULTS_FUNCTION: &str = "openResults";
Expand Down Expand Up @@ -161,26 +161,30 @@ impl DistanceComparator {
);
self.device_manager.device(i).bind_to_thread().unwrap();

let ptr_param = |ptr: *const sys::CUdeviceptr| ptr as *mut c_void;
let usize_param = |val: &usize| val as *const usize as *mut _;

let params = &mut [
*results1[i].device_ptr() as *mut _,
*results2[i].device_ptr() as *mut _,
*results3[i].device_ptr() as *mut _,
*matches_bitmap[i].device_ptr() as *mut _,
db_sizes[i] as *mut _,
self.query_length as *mut _,
offset as *mut _,
num_elements as *mut _,
real_db_sizes[i] as *mut _,
total_db_sizes[i] as *mut _,
*match_distances_buffers_codes[i].a.device_ptr() as *mut _,
*match_distances_buffers_codes[i].b.device_ptr() as *mut _,
*match_distances_buffers_masks[i].a.device_ptr() as *mut _,
*match_distances_buffers_masks[i].b.device_ptr() as *mut _,
*match_distances_counters[i].device_ptr() as *mut _,
*code_dots[i].a.device_ptr() as *mut _,
*code_dots[i].b.device_ptr() as *mut _,
*mask_dots[i].a.device_ptr() as *mut _,
*mask_dots[i].b.device_ptr() as *mut _,
// Results arrays
ptr_param(results1[i].device_ptr()),
ptr_param(results2[i].device_ptr()),
ptr_param(results3[i].device_ptr()),
ptr_param(matches_bitmap[i].device_ptr()),
usize_param(&db_sizes[i]),
usize_param(&self.query_length),
usize_param(&offset),
usize_param(&num_elements),
usize_param(&real_db_sizes[i]),
usize_param(&total_db_sizes[i]),
ptr_param(match_distances_buffers_codes[i].a.device_ptr()),
ptr_param(match_distances_buffers_codes[i].b.device_ptr()),
ptr_param(match_distances_buffers_masks[i].a.device_ptr()),
ptr_param(match_distances_buffers_masks[i].b.device_ptr()),
ptr_param(match_distances_counters[i].device_ptr()),
ptr_param(code_dots[i].a.device_ptr()),
ptr_param(code_dots[i].b.device_ptr()),
ptr_param(mask_dots[i].a.device_ptr()),
ptr_param(mask_dots[i].b.device_ptr()),
];

unsafe {
Expand Down
10 changes: 5 additions & 5 deletions iris-mpc-gpu/src/dot/kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,11 @@ extern "C" __global__ void openResults(unsigned long long *result1, unsigned lon
}

// Save the corresponding code and mask dots for later (match distributions)
unsigned int match_distances_counter_idx = atomicAdd(&match_distances_counter[0], 1);
match_distances_buffer_codes_a[match_distances_counter_idx] = code_dots_a[idx];
match_distances_buffer_codes_b[match_distances_counter_idx] = code_dots_b[idx];
match_distances_buffer_masks_a[match_distances_counter_idx] = mask_dots_a[idx];
match_distances_buffer_masks_b[match_distances_counter_idx] = mask_dots_b[idx];
// unsigned int match_distances_counter_idx = atomicAdd(&match_distances_counter[0], 1);
// match_distances_buffer_codes_a[match_distances_counter_idx] = code_dots_a[idx];
// match_distances_buffer_codes_b[match_distances_counter_idx] = code_dots_b[idx];
// match_distances_buffer_masks_a[match_distances_counter_idx] = mask_dots_a[idx];
// match_distances_buffer_masks_b[match_distances_counter_idx] = mask_dots_b[idx];

unsigned int outputIdx = totalDbLen * (queryIdx / ALL_ROTATIONS) + dbIdx + offset;
atomicOr(&output[outputIdx / 64], (1ULL << (outputIdx % 64)));
Expand Down
2 changes: 2 additions & 0 deletions iris-mpc-gpu/src/dot/share_db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1213,6 +1213,7 @@ mod tests {
&db_sizes,
0,
&streams,
0,
);
masks_engine.dot_reduce_and_multiply(
&mask_query_sums,
Expand All @@ -1221,6 +1222,7 @@ mod tests {
0,
&streams,
2,
0,
);

device_manager.await_streams(&streams);
Expand Down
41 changes: 24 additions & 17 deletions iris-mpc-gpu/src/threshold_ring/cuda/kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,16 @@ template <typename T> __device__ void xor_assign_inner(T *lhs, T *rhs) {
// Writes a single result through res_a: the pointed-to values of lhs_a and
// rhs_a added together, minus twice `mul`. `mul` is the sum of the cross
// products lhs_a*rhs_a + lhs_b*rhs_a + lhs_a*rhs_b, plus *r1, minus *r2; the
// product lhs_b*rhs_b is not part of it. No pointer other than res_a is
// written through.
// NOTE(review): presumably the _a/_b operands are the two locally held shares
// of each input and r1/r2 are masking randomness -- confirm against callers.
template <typename T>
__device__ void arithmetic_xor_inner(T *res_a, T *lhs_a, T *lhs_b, T *rhs_a,
T *rhs_b, T *r1, T *r2) {
// NOTE(review): the next two statements look like the pre-change lines of this
// diff view. The second one adds the pointers lhs_a and rhs_a themselves
// rather than the values they point to, and `mul` is declared a second time
// further down, so they appear to be superseded by the value-based statements
// that follow -- confirm before treating this text as compilable source.
T mul = (*lhs_a * *rhs_a) + (*lhs_b * *rhs_a) + (*lhs_a * *rhs_b) + *r1 - *r2;
*res_a = lhs_a + rhs_a - 2 * mul;
// Read every operand once into a local before combining them.
T lhs_a_val = *lhs_a;
T lhs_b_val = *lhs_b;
T rhs_a_val = *rhs_a;
T rhs_b_val = *rhs_b;
T r1_val = *r1;
T r2_val = *r2;

// Three cross products plus r1 minus r2.
T mul = (lhs_a_val * rhs_a_val) + (lhs_b_val * rhs_a_val) +
(lhs_a_val * rhs_b_val) + r1_val - r2_val;
// Sum of the two _a values minus twice the product term.
*res_a = lhs_a_val + rhs_a_val - 2 * mul;
}

// Computes the local part of the multiplication (including randomness)
Expand Down Expand Up @@ -172,14 +180,6 @@ __device__ void u32_transpose_pack_u64(U64 *out_a, U64 *out_b, U32 *in_a,
}
}

// Calls finalize_lift on the arguments, then replaces *mask with *mask * A
// minus the value finalize_lift stored in the local `lifted`.
// `A` is not defined in the visible text.
// NOTE(review): finalize_lift is defined after this point in the text, and an
// identical copy of this function follows it. This looks like the pre-move
// position of the definition -- confirm which copy is live.
__device__ void lift_mul_sub(U32 *mask, U16 *mask_corr1, U16 *mask_corr2,
U16 *code) {
// Receives finalize_lift's second (code_lift) output.
U32 lifted;
finalize_lift(mask, &lifted, mask_corr1, mask_corr2, code);
// Scale the mask value that finalize_lift left in place, then subtract.
*mask *= A;
*mask -= lifted;
}

__device__ void finalize_lift(U32 *mask, U32 *code_lift, U16 *mask_corr1,
U16 *mask_corr2, U16 *code) {
*mask -= (U32)(*mask_corr1) << 16;
Expand All @@ -188,6 +188,14 @@ __device__ void finalize_lift(U32 *mask, U32 *code_lift, U16 *mask_corr1,
mul_lift_b(code_lift, code);
}

// Calls finalize_lift on the arguments, then replaces *mask with *mask * A
// minus the value finalize_lift stored in the local `lifted`.
// The result is written back through `mask`; nothing is returned.
// `A` is not defined in the visible text.
// NOTE(review): the visible part of finalize_lift modifies *mask in place
// (it subtracts *mask_corr1 shifted left by 16) and passes its code_lift
// argument to mul_lift_b; some of its lines are collapsed in this view, so
// its full effect should be confirmed in the source file.
__device__ void lift_mul_sub(U32 *mask, U16 *mask_corr1, U16 *mask_corr2,
U16 *code) {
// Receives finalize_lift's second (code_lift) output.
U32 lifted;
finalize_lift(mask, &lifted, mask_corr1, mask_corr2, code);
// Scale the mask value that finalize_lift left in place, then subtract.
*mask *= A;
*mask -= lifted;
}

// Stores (*mask * a) - *code into *output.
// `mask` and `code` are only read; `output` is the only location written.
// NOTE(review): the arithmetic is done in U32, whose definition is not
// visible here -- presumably an unsigned 32-bit type with wrap-around; confirm.
__device__ void lifted_sub(U32 *mask, U32 *code, U32 *output, U32 a) {
*output = *mask * a - *code;
}
Expand Down Expand Up @@ -445,8 +453,8 @@ extern "C" __global__ void shared_lifted_sub(U32 *mask_a, U32 *mask_b,
U32 a, int id, size_t n) {
size_t i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < n) {
lifted_sub(&mask_a[i], &code_a[i], output_a[i], a);
lifted_sub(&mask_b[i], &code_b[i], output_n[i], a);
lifted_sub(&mask_a[i], &code_a[i], &output_a[i], a);
lifted_sub(&mask_b[i], &code_b[i], &output_b[i], a);
switch (id) {
case 0:
mask_a[i] -= 1; // Transforms the <= into <
Expand Down Expand Up @@ -545,9 +553,8 @@ extern "C" __global__ void collapse_u64_helper(U64 *inout_a, U64 *in_b,
}
}

extern "C" __global__ void collapse_sum_assign(u32 *inout_a, U32 *inout_b,
extern "C" __global__ void collapse_sum_assign(U32 *inout_a, U32 *inout_b,
size_t n) {

size_t i = blockIdx.x * blockDim.x + threadIdx.x;
if (i == 1) {
for (size_t j = 1; j < n; j++) {
Expand All @@ -560,9 +567,9 @@ extern "C" __global__ void collapse_sum_assign(u32 *inout_a, U32 *inout_b,
}
}

extern "C" __global__ void collapse_sum(u32 *inout_a, U32 *inout_b, input_a,
input_b, size_t inout_index, size_t n) {

extern "C" __global__ void collapse_sum(U32 *inout_a, U32 *inout_b,
U32 *input_a, U32 *input_b,
size_t inout_index, size_t n) {
size_t i = blockIdx.x * blockDim.x + threadIdx.x;
if (i == 1) {
for (size_t j = 0; j < n; j++) {
Expand Down

0 comments on commit 0f26aa5

Please sign in to comment.