
Commit

wip: keep results around
philsippl committed Jan 3, 2025
1 parent dfc60d7 commit 64a3fe4
Showing 4 changed files with 208 additions and 98 deletions.
17 changes: 17 additions & 0 deletions iris-mpc-gpu/src/dot/distance_comparator.rs
@@ -347,4 +347,21 @@ impl DistanceComparator {
})
.collect::<Vec<_>>()
}

pub fn prepare_match_distances_buffer(&self, max_size: usize) -> Vec<CudaSlice<u16>> {
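        // Allocates one zero-initialized u16 buffer per device, splitting
        // `max_size` evenly across the devices; presumably used to collect
        // match distances (an assumption based on the name).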
(0..self.device_manager.device_count())
.map(|i| {
self.device_manager
.device(i)
.alloc_zeros(max_size / self.device_manager.device_count())
.unwrap()
})
.collect::<Vec<_>>()
}

pub fn prepare_match_distances_counter(&self) -> Vec<CudaSlice<u32>> {
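        // One zero-initialized u32 counter per device, e.g. to track how many
        // distances have been written into the buffer above (an assumption).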
(0..self.device_manager.device_count())
.map(|i| self.device_manager.device(i).alloc_zeros(1).unwrap())
.collect::<Vec<_>>()
}
}
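As a usage sketch (not from the commit itself), the two new helpers above might be wired up roughly like this, assuming a `comparator: DistanceComparator` and a hypothetical capacity:

    let max_matches = 1 << 20; // hypothetical total capacity, split evenly across devices
    let match_distances = comparator.prepare_match_distances_buffer(max_matches);
    let match_counters = comparator.prepare_match_distances_counter();
    // match_distances[i] and match_counters[i] are zero-initialized device
    // allocations on device i, ready to be filled by a matching kernel.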
165 changes: 111 additions & 54 deletions iris-mpc-gpu/src/dot/share_db.rs
Expand Up @@ -132,8 +132,8 @@ pub struct ShareDB {
comms: Vec<Arc<NcclComm>>,
ones: Vec<CudaSlice<u8>>,
intermediate_results: Vec<CudaSlice<i32>>,
pub results: Vec<CudaSlice<u8>>,
pub results_peer: Vec<CudaSlice<u8>>,
pub results: Vec<Vec<CudaSlice<u8>>>,
pub results_peer: Vec<Vec<CudaSlice<u8>>>,
code_length: usize,
}

@@ -183,25 +183,32 @@ impl ShareDB {
// TODO: depending on the batch size, intermediate_results can get quite big, we
// can perform the gemm in chunks to limit this
let mut intermediate_results = vec![];
let mut results = vec![];
let mut results_peer = vec![];
let mut results = vec![vec![]; 2];
let mut results_peer = vec![vec![]; 2];
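        // Two independent result buffer sets per device, selected via the new
        // `results_idx` parameter in the methods below; presumably this lets one
        // batch's results be kept around while the next batch is computed (see
        // the commit title "wip: keep results around").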
let results_len = (max_db_length * query_length).div_ceil(64) * 64;

for idx in 0..n_devices {
unsafe {
intermediate_results.push(device_manager.device(idx).alloc(results_len).unwrap());
results.push(
device_manager
.device(idx)
.alloc(results_len * std::mem::size_of::<u16>())
.unwrap(),
);
results_peer.push(
device_manager
.device(idx)
.alloc(results_len * std::mem::size_of::<u16>())
.unwrap(),
);
}
}

for i in 0..2 {
for idx in 0..n_devices {
unsafe {
results_peer[i].push(
device_manager
.device(idx)
.alloc(results_len * std::mem::size_of::<u16>())
.unwrap(),
);
results[i].push(
device_manager
.device(idx)
.alloc(results_len * std::mem::size_of::<u16>())
.unwrap(),
);
}
}
}

@@ -589,6 +596,7 @@ impl ShareDB {
offset: usize,
streams: &[CudaStream],
multiplier: u16,
results_idx: usize,
) {
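        // `results_idx` selects which of the two result buffer sets allocated in
        // `ShareDB::new` this reduction writes into (0 or 1).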
for idx in 0..self.device_manager.device_count() {
assert!(
@@ -611,7 +619,7 @@
cfg,
(
&self.intermediate_results[idx],
&mut self.results[idx],
&mut self.results[results_idx][idx],
*db_sums.limb_0[idx].device_ptr(),
*db_sums.limb_1[idx].device_ptr(),
*query_sums.limb_0[idx].device_ptr(),
@@ -636,8 +644,17 @@
chunk_sizes: &[usize],
offset: usize,
streams: &[CudaStream],
results_idx: usize,
) {
self.dot_reduce_and_multiply(query_sums, db_sums, chunk_sizes, offset, streams, 1);
self.dot_reduce_and_multiply(
query_sums,
db_sums,
chunk_sizes,
offset,
streams,
1,
results_idx,
);
}

fn single_xor_assign_u8(
Expand Down Expand Up @@ -700,26 +717,33 @@ impl ShareDB {
len: usize,
idx: usize,
streams: &[CudaStream],
results_idx: usize,
) -> CudaSlice<u32> {
assert_eq!(len & 3, 0);
let mut rand = unsafe {
self.device_manager
.device(idx)
.alloc::<u32>(len >> 2)
.unwrap()
.unwrap() // TODO: fix, make this async
};
let mut rand_u8 = self.fill_my_rng_into_u8(&mut rand, idx, streams);
self.single_xor_assign_u8(
&mut rand_u8,
&self.results[idx].slice(..),
&self.results[results_idx][idx].slice(..),
idx,
len,
streams,
);
rand
}

fn otp_decrypt_rng_result(&mut self, len: usize, idx: usize, streams: &[CudaStream]) {
fn otp_decrypt_rng_result(
&mut self,
len: usize,
idx: usize,
streams: &[CudaStream],
results_idx: usize,
) {
assert_eq!(len & 3, 0);
let mut rand = unsafe {
self.device_manager
@@ -729,22 +753,27 @@
};
let rand_u8 = self.fill_their_rng_into_u8(&mut rand, idx, streams);
self.single_xor_assign_u8(
&mut self.results_peer[idx].slice(..),
&mut self.results_peer[results_idx][idx].slice(..),
&rand_u8,
idx,
len,
streams,
);
}

pub fn reshare_results(&mut self, db_sizes: &[usize], streams: &[CudaStream]) {
pub fn reshare_results(
&mut self,
db_sizes: &[usize],
streams: &[CudaStream],
results_idx: usize,
) {
let next_peer = (self.peer_id + 1) % 3;
let prev_peer = (self.peer_id + 2) % 3;

let send_bufs = (0..self.device_manager.device_count())
.map(|idx| {
let len = db_sizes[idx] * self.query_length * 2;
self.otp_encrypt_rng_result(len, idx, streams)
self.otp_encrypt_rng_result(len, idx, streams, results_idx)
})
.collect_vec();

@@ -759,22 +788,28 @@
.send_view(&send_view, next_peer, &streams[idx])
.unwrap();

let mut recv_view = self.results_peer[idx].slice(..len);
let mut recv_view = self.results_peer[results_idx][idx].slice(..len);
self.comms[idx]
.receive_view(&mut recv_view, prev_peer, &streams[idx])
.unwrap();
}
nccl::group_end().unwrap();
for idx in 0..self.device_manager.device_count() {
let len = db_sizes[idx] * self.query_length * 2;
self.otp_decrypt_rng_result(len, idx, streams);
self.otp_decrypt_rng_result(len, idx, streams, results_idx);
}
}

pub fn fetch_results(&self, results: &mut [u16], db_sizes: &[usize], device_id: usize) {
pub fn fetch_results(
&self,
results: &mut [u16],
db_sizes: &[usize],
device_id: usize,
results_idx: usize,
) {
unsafe {
let res_trans =
self.results[device_id].transmute(db_sizes[device_id] * self.query_length);
let res_trans = self.results[results_idx][device_id]
.transmute(db_sizes[device_id] * self.query_length);

self.device_manager
.device(device_id)
@@ -783,25 +818,33 @@
}
}

pub fn result_chunk_shares<'a>(&'a self, db_sizes: &[usize]) -> Vec<ChunkShareView<'a, u16>> {
izip!(db_sizes, self.results.iter(), self.results_peer.iter())
.map(|(&len, xa, xb)| {
// SAFETY: All bit patterns are valid u16 values
let xa_view = unsafe {
xa.transmute(len * self.query_length)
.expect("len is correct")
};
// SAFETY: All bit patterns are valid u16 values
let xb_view = unsafe {
xb.transmute(len * self.query_length)
.expect("len is correct")
};
ChunkShareView {
a: xa_view,
b: xb_view,
}
})
.collect()
pub fn result_chunk_shares<'a>(
&'a self,
db_sizes: &[usize],
results_idx: usize,
) -> Vec<ChunkShareView<'a, u16>> {
izip!(
db_sizes,
self.results[results_idx].iter(),
self.results_peer[results_idx].iter()
)
.map(|(&len, xa, xb)| {
// SAFETY: All bit patterns are valid u16 values
let xa_view = unsafe {
xa.transmute(len * self.query_length)
.expect("len is correct")
};
// SAFETY: All bit patterns are valid u16 values
let xb_view = unsafe {
xb.transmute(len * self.query_length)
.expect("len is correct")
};
ChunkShareView {
a: xa_view,
b: xb_view,
}
})
.collect()
}
}

@@ -902,7 +945,14 @@ mod tests {
&streams,
&blass,
);
engine.dot_reduce(&query_sums, &db_slices.code_sums_gr, &db_sizes, 0, &streams);
engine.dot_reduce(
&query_sums,
&db_slices.code_sums_gr,
&db_sizes,
0,
&streams,
0,
);
device_manager.await_streams(&streams);

let a_nda = random_ndarray::<u16>(shard_db(&db, n_devices), DB_SIZE, WIDTH);
@@ -917,7 +967,7 @@
}

for device_idx in 0..n_devices {
engine.fetch_results(&mut gpu_result, &db_sizes, device_idx);
engine.fetch_results(&mut gpu_result, &db_sizes, device_idx, 0);
let selected_elements: Vec<u16> = vec_column_major
.chunks(DB_SIZE)
.flat_map(|chunk| {
@@ -1004,9 +1054,16 @@ mod tests {
&streams,
&blass,
);
engine.dot_reduce(&query_sums, &db_slices.code_sums_gr, &db_sizes, 0, &streams);
engine.dot_reduce(
&query_sums,
&db_slices.code_sums_gr,
&db_sizes,
0,
&streams,
0,
);
device_manager.await_streams(&streams);
engine.fetch_results(&mut gpu_result[i], &db_sizes, 0);
engine.fetch_results(&mut gpu_result[i], &db_sizes, 0, 0);
}

for i in 0..DB_SIZE * QUERY_SIZE / n_devices {
@@ -1168,8 +1225,8 @@ mod tests {
device_manager.await_streams(&streams);

// TODO: fetch results also for other devices
codes_engine.fetch_results(&mut results_codes[party_id], &db_sizes, 0);
masks_engine.fetch_results(&mut results_masks[party_id], &db_sizes, 0);
codes_engine.fetch_results(&mut results_codes[party_id], &db_sizes, 0, 0);
masks_engine.fetch_results(&mut results_masks[party_id], &db_sizes, 0, 0);
}

// Reconstruct the results
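The resharing path above amounts to a plain XOR one-time pad: the local result share is masked with a device-generated RNG stream before being sent to the next peer, and the share received from the previous peer is unmasked with the corresponding stream (the streams are presumably correlated between adjacent peers). A CPU-side sketch of that idea — hypothetical standalone code, not the CUDA implementation in this file:

    fn otp_encrypt(share: &[u8], my_rng_stream: &[u8]) -> Vec<u8> {
        // Mask each byte of the local share with the RNG stream known to the receiver.
        share.iter().zip(my_rng_stream).map(|(s, r)| s ^ r).collect()
    }

    fn otp_decrypt(received: &mut [u8], their_rng_stream: &[u8]) {
        // Remove the mask using the same stream the sender applied.
        for (b, r) in received.iter_mut().zip(their_rng_stream) {
            *b ^= r;
        }
    }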
5 changes: 5 additions & 0 deletions iris-mpc-gpu/src/helpers/query_processor.rs
@@ -249,6 +249,7 @@ impl DeviceCompactSums {
db_sizes,
offset,
streams,
0,
);
mask_engine.dot_reduce_and_multiply(
&self.mask_query,
@@ -257,6 +258,7 @@
offset,
streams,
2,
0,
);
}

@@ -270,13 +272,15 @@
database_sizes: &[usize],
offset: usize,
streams: &[CudaStream],
results_idx: usize,
) {
code_engine.dot_reduce(
&self.code_query,
&sliced_code_db.code_sums_gr,
database_sizes,
offset,
streams,
results_idx,
);
mask_engine.dot_reduce_and_multiply(
&self.mask_query,
Expand All @@ -285,6 +289,7 @@ impl DeviceCompactSums {
offset,
streams,
2,
results_idx,
);
}
}
