Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
philsippl committed Aug 17, 2024
1 parent 0563af3 commit 5c46395
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 94 deletions.
64 changes: 32 additions & 32 deletions iris-mpc-gpu/src/dot/distance_comparator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,52 +121,52 @@ impl DistanceComparator {
}
}

pub fn merge_db_results(
pub fn join_db_matches(
&self,
matches_bitmap_left: &[CudaSlice<u64>],
matches_bitmap_right: &[CudaSlice<u64>],
final_results: &[CudaSlice<u32>],
db_sizes: &[usize],
streams: &[CudaStream],
) {
for i in 0..self.device_manager.device_count() {
let num_elements = (db_sizes[i] * self.query_length / ROTATIONS).div_ceil(64);
let threads_per_block = 256;
let blocks_per_grid = num_elements.div_ceil(threads_per_block);
let cfg = LaunchConfig {
block_dim: (threads_per_block as u32, 1, 1),
grid_dim: (blocks_per_grid as u32, 1, 1),
shared_mem_bytes: 0,
};
unsafe {
self.merge_db_kernels[i]
.clone()
.launch_on_stream(
&streams[i],
cfg,
(
&matches_bitmap_left[i],
&matches_bitmap_right[i],
&final_results[i],
(self.query_length / ROTATIONS) as u64,
db_sizes[i] as u64,
num_elements as u64,
),
)
.unwrap();
}
}
self.join_matches(
matches_bitmap_left,
matches_bitmap_right,
final_results,
db_sizes,
streams,
&self.merge_db_kernels,
);
}

pub fn merge_batch_results(
pub fn join_batch_matches(
&self,
matches_bitmap_left: &[CudaSlice<u64>],
matches_bitmap_right: &[CudaSlice<u64>],
final_results: &[CudaSlice<u32>],
streams: &[CudaStream],
) {
self.join_matches(
matches_bitmap_left,
matches_bitmap_right,
final_results,
&vec![self.query_length; self.device_manager.device_count()],
streams,
&self.merge_batch_kernels,
);
}

fn join_matches(
&self,
matches_bitmap_left: &[CudaSlice<u64>],
matches_bitmap_right: &[CudaSlice<u64>],
final_results: &[CudaSlice<u32>],
db_sizes: &[usize],
streams: &[CudaStream],
kernels: &[CudaFunction],
) {
for i in 0..self.device_manager.device_count() {
let num_elements = (self.query_length * self.query_length / ROTATIONS).div_ceil(64);
let num_elements = (db_sizes[i] * self.query_length / ROTATIONS).div_ceil(64);
let threads_per_block = 256;
let blocks_per_grid = num_elements.div_ceil(threads_per_block);
let cfg = LaunchConfig {
Expand All @@ -175,7 +175,7 @@ impl DistanceComparator {
shared_mem_bytes: 0,
};
unsafe {
self.merge_batch_kernels[i]
kernels[i]
.clone()
.launch_on_stream(
&streams[i],
Expand All @@ -185,7 +185,7 @@ impl DistanceComparator {
&matches_bitmap_right[i],
&final_results[i],
(self.query_length / ROTATIONS) as u64,
self.query_length as u64,
db_sizes[i] as u64,
num_elements as u64,
),
)
Expand Down
8 changes: 4 additions & 4 deletions iris-mpc-gpu/src/dot/kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,20 @@ extern "C" __global__ void matmul_correct_and_reduce(int *c, unsigned short *out
}
}

extern "C" __global__ void openResults(unsigned long long *result1, unsigned long long *result2, unsigned long long *result3, unsigned long long *output, size_t dbLength, size_t queryLength, size_t offset, size_t numElements, size_t realDbLen, size_t totalDbLen)
extern "C" __global__ void openResults(unsigned long long *result1, unsigned long long *result2, unsigned long long *result3, unsigned long long *output, size_t chunkLength, size_t queryLength, size_t offset, size_t numElements, size_t realChunkLen, size_t totalDbLen)
{
size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < numElements)
{
unsigned long long result = result1[idx] ^ result2[idx] ^ result3[idx];
for (int i = 0; i < 64; i++)
{
unsigned int queryIdx = (idx * 64 + i) / dbLength;
unsigned int dbIdx = (idx * 64 + i) % dbLength;
unsigned int queryIdx = (idx * 64 + i) / chunkLength;
unsigned int dbIdx = (idx * 64 + i) % chunkLength;
bool match = (result & (1ULL << i));

// Check if we are out of bounds for the query or db
if (queryIdx >= queryLength || dbIdx >= realDbLen || !match)
if (queryIdx >= queryLength || dbIdx >= realChunkLen || !match)
{
continue;
}
Expand Down
69 changes: 11 additions & 58 deletions iris-mpc-gpu/src/server/actor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,6 @@ pub struct ServerActor {

const NON_MATCH_ID: u32 = u32::MAX;

const RESULTS_INIT_HOST: [u32; MAX_BATCH_SIZE * ROTATIONS] =
[NON_MATCH_ID; MAX_BATCH_SIZE * ROTATIONS];
const FINAL_RESULTS_INIT_HOST: [u32; MAX_BATCH_SIZE] = [NON_MATCH_ID; MAX_BATCH_SIZE];

impl ServerActor {
#[allow(clippy::too_many_arguments)]
pub fn new(
Expand Down Expand Up @@ -491,15 +487,15 @@ impl ServerActor {

// Merge results and fetch matching indices
// Format: host_results[device_index][query_index]
self.distance_comparator.merge_db_results(
self.distance_comparator.join_db_matches(
&self.db_match_list_left,
&self.db_match_list_right,
&self.final_results,
&self.current_db_sizes,
&self.streams[0],
);

self.distance_comparator.merge_batch_results(
self.distance_comparator.join_batch_matches(
&self.batch_match_list_left,
&self.batch_match_list_right,
&self.final_results,
Expand All @@ -525,8 +521,6 @@ impl ServerActor {
.distance_comparator
.fetch_final_results(&self.final_results);

println!("host_results: {:?}", host_results[0]);

// Truncate the results to the batch size
host_results.iter_mut().for_each(|x| x.truncate(batch_size));

Expand Down Expand Up @@ -664,7 +658,7 @@ impl ServerActor {
&self.batch_match_list_left,
&self.batch_match_list_right,
] {
reset_results_bitmap(self.device_manager.devices(), dst, &self.streams[0]);
reset_slice(self.device_manager.devices(), dst, 0, &self.streams[0]);
}

// ---- END RESULT PROCESSING ----
Expand Down Expand Up @@ -984,44 +978,10 @@ impl ServerActor {
self.device_manager.await_streams(&self.streams[1]);
tracing::debug!(party_id = self.party_id, "db search finished");

// ---- START RESULT PROCESSING ----

self.device_manager.await_streams(&self.streams[0]);

// Fetch the final results (blocking)
// let mut host_results = self
// .distance_comparator
// .fetch_final_results(&self.final_results);

// // Truncate the results to the batch size
// host_results.iter_mut().for_each(|x| x.truncate(batch_size));

// // Evaluate the results across devices
// // Format: merged_results[query_index]
// let merged_results = get_merged_results(&host_results,
// self.device_manager.device_count());

// Reset the results buffers for reuse
reset_results(
self.device_manager.devices(),
&self.results,
&RESULTS_INIT_HOST,
&self.streams[0],
);
reset_results(
self.device_manager.devices(),
&self.batch_results,
&RESULTS_INIT_HOST,
&self.streams[0],
);
reset_results(
self.device_manager.devices(),
&self.final_results,
&FINAL_RESULTS_INIT_HOST,
&self.streams[0],
);

// Ok(merged_results)
for dst in &[&self.results, &self.batch_results, &self.final_results] {
reset_slice(self.device_manager.devices(), dst, 0xff, &self.streams[0]);
}
}
}

Expand Down Expand Up @@ -1129,7 +1089,7 @@ fn get_merged_results(host_results: &[Vec<u32>], n_devices: usize) -> Vec<u32> {
results.push(match_entry);

// DEBUG
println!(
tracing::debug!(
"Query {}: match={} [index: {}]",
j,
match_entry != NON_MATCH_ID,
Expand All @@ -1154,25 +1114,18 @@ fn distribute_insertions(results: &[usize], db_sizes: &[usize]) -> Vec<Vec<usize
ret
}

fn reset_results(
fn reset_slice<T>(
devs: &[Arc<CudaDevice>],
dst: &[CudaSlice<u32>],
src: &[u32],
dst: &[CudaSlice<T>],
value: u8,
streams: &[CudaStream],
) {
for i in 0..devs.len() {
devs[i].bind_to_thread().unwrap();
unsafe { result::memcpy_htod_async(*dst[i].device_ptr(), src, streams[i].stream) }.unwrap();
}
}

fn reset_results_bitmap(devs: &[Arc<CudaDevice>], dst: &[CudaSlice<u64>], streams: &[CudaStream]) {
for i in 0..devs.len() {
devs[i].bind_to_thread().unwrap();
unsafe {
result::memset_d8_async(
*dst[i].device_ptr(),
0,
value,
dst[i].num_bytes(),
streams[i].stream,
)
Expand Down

0 comments on commit 5c46395

Please sign in to comment.