Skip to content

Commit

Permalink
page-lock in spawn_blocking (#893)
Browse files Browse the repository at this point in the history
  • Loading branch information
philsippl authored Jan 10, 2025
1 parent 42d3706 commit 1978beb
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 40 deletions.
29 changes: 28 additions & 1 deletion iris-mpc-gpu/src/helpers/mod.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
use crate::threshold_ring::protocol::ChunkShare;
use cudarc::driver::{
result::{self, memcpy_dtoh_async, memcpy_htod_async, stream},
sys::{lib, CUdeviceptr, CUstream, CUstream_st},
sys::{lib, CUdeviceptr, CUstream, CUstream_st, CU_MEMHOSTALLOC_PORTABLE},
CudaDevice, CudaSlice, CudaStream, DevicePtr, DevicePtrMut, DeviceRepr, DriverError,
LaunchConfig,
};
use device_manager::DeviceManager;
use query_processor::CudaVec2DSlicerRawPointer;
use std::sync::Arc;

pub mod comm;
Expand Down Expand Up @@ -167,3 +169,28 @@ pub fn htod_on_stream_sync<T: DeviceRepr>(
};
Ok(buf)
}

/// Registers the host buffers behind `db` with the CUDA driver via
/// `cuMemHostRegister_v2`, once per device and per limb.
///
/// `max_db_length` is split across the devices with integer division; the
/// size passed to each registration call is that per-device share multiplied
/// by `code_length`.
///
/// The status returned by each `cuMemHostRegister_v2` call is discarded, so a
/// failed registration goes unnoticed by the caller.
///
/// # Panics
///
/// * If `device_manager.device_count()` is zero (division by zero).
/// * If binding a device to the current thread fails.
/// * If `db.limb_0` or `db.limb_1` holds fewer entries than there are devices.
pub fn register_host_memory(
    device_manager: Arc<DeviceManager>,
    db: &CudaVec2DSlicerRawPointer,
    max_db_length: usize,
    code_length: usize,
) {
    let per_device_len = max_db_length / device_manager.device_count();

    for (idx, dev) in device_manager.devices().iter().enumerate() {
        // The driver calls below are issued from this thread, so bind the
        // device here first.
        dev.bind_to_thread().unwrap();

        // Same call for both limbs, limb_0 first.
        for limb in [&db.limb_0, &db.limb_1] {
            // NOTE(review): presumably each entry is a host address whose
            // allocation spans at least `per_device_len * code_length` bytes
            // -- confirm against the code that allocates these buffers.
            // NOTE(review): the flag constant comes from the cuMemHostAlloc
            // family; confirm it is the intended value for cuMemHostRegister.
            unsafe {
                let _ = cudarc::driver::sys::lib().cuMemHostRegister_v2(
                    limb[idx] as *mut _,
                    per_device_len * code_length,
                    CU_MEMHOSTALLOC_PORTABLE,
                );
            }
        }
    }
}
1 change: 1 addition & 0 deletions iris-mpc-gpu/src/helpers/query_processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ impl<T> Drop for StreamAwareCudaSlice<T> {

/// Holds the raw memory pointers for the 2D slices.
/// Memory is not freed when the struct is dropped; it must be freed manually.
#[derive(Clone)]
pub struct CudaVec2DSlicerRawPointer {
pub limb_0: Vec<u64>,
pub limb_1: Vec<u64>,
Expand Down
72 changes: 36 additions & 36 deletions iris-mpc-gpu/src/server/actor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,44 +77,44 @@ const KDF_SALT: &str = "111a1a93518f670e9bb0c2c68888e2beb9406d4c4ed571dc77b801e6
const SUPERMATCH_THRESHOLD: usize = 4_000;

pub struct ServerActor {
job_queue: mpsc::Receiver<ServerJob>,
device_manager: Arc<DeviceManager>,
party_id: usize,
job_queue: mpsc::Receiver<ServerJob>,
pub device_manager: Arc<DeviceManager>,
party_id: usize,
// engines
codes_engine: ShareDB,
masks_engine: ShareDB,
batch_codes_engine: ShareDB,
batch_masks_engine: ShareDB,
phase2: Circuits,
phase2_batch: Circuits,
distance_comparator: DistanceComparator,
comms: Vec<Arc<NcclComm>>,
codes_engine: ShareDB,
masks_engine: ShareDB,
batch_codes_engine: ShareDB,
batch_masks_engine: ShareDB,
phase2: Circuits,
phase2_batch: Circuits,
distance_comparator: DistanceComparator,
comms: Vec<Arc<NcclComm>>,
// DB slices
left_code_db_slices: SlicedProcessedDatabase,
left_mask_db_slices: SlicedProcessedDatabase,
right_code_db_slices: SlicedProcessedDatabase,
right_mask_db_slices: SlicedProcessedDatabase,
streams: Vec<Vec<CudaStream>>,
cublas_handles: Vec<Vec<CudaBlas>>,
results: Vec<CudaSlice<u32>>,
batch_results: Vec<CudaSlice<u32>>,
final_results: Vec<CudaSlice<u32>>,
db_match_list_left: Vec<CudaSlice<u64>>,
db_match_list_right: Vec<CudaSlice<u64>>,
batch_match_list_left: Vec<CudaSlice<u64>>,
batch_match_list_right: Vec<CudaSlice<u64>>,
current_db_sizes: Vec<usize>,
query_db_size: Vec<usize>,
max_batch_size: usize,
max_db_size: usize,
return_partial_results: bool,
disable_persistence: bool,
enable_debug_timing: bool,
code_chunk_buffers: Vec<DBChunkBuffers>,
mask_chunk_buffers: Vec<DBChunkBuffers>,
dot_events: Vec<Vec<CUevent>>,
exchange_events: Vec<Vec<CUevent>>,
phase2_events: Vec<Vec<CUevent>>,
pub left_code_db_slices: SlicedProcessedDatabase,
pub left_mask_db_slices: SlicedProcessedDatabase,
pub right_code_db_slices: SlicedProcessedDatabase,
pub right_mask_db_slices: SlicedProcessedDatabase,
streams: Vec<Vec<CudaStream>>,
cublas_handles: Vec<Vec<CudaBlas>>,
results: Vec<CudaSlice<u32>>,
batch_results: Vec<CudaSlice<u32>>,
final_results: Vec<CudaSlice<u32>>,
db_match_list_left: Vec<CudaSlice<u64>>,
db_match_list_right: Vec<CudaSlice<u64>>,
batch_match_list_left: Vec<CudaSlice<u64>>,
batch_match_list_right: Vec<CudaSlice<u64>>,
current_db_sizes: Vec<usize>,
query_db_size: Vec<usize>,
max_batch_size: usize,
max_db_size: usize,
return_partial_results: bool,
disable_persistence: bool,
enable_debug_timing: bool,
code_chunk_buffers: Vec<DBChunkBuffers>,
mask_chunk_buffers: Vec<DBChunkBuffers>,
dot_events: Vec<Vec<CUevent>>,
exchange_events: Vec<Vec<CUevent>>,
phase2_events: Vec<Vec<CUevent>>,
}

const NON_MATCH_ID: u32 = u32::MAX;
Expand Down
34 changes: 31 additions & 3 deletions iris-mpc/src/bin/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,10 @@ use iris_mpc_common::{
sync::SyncState,
task_monitor::TaskMonitor,
},
IRIS_CODE_LENGTH, MASK_CODE_LENGTH,
};
use iris_mpc_gpu::{
helpers::device_manager::DeviceManager,
helpers::{device_manager::DeviceManager, register_host_memory},
server::{
get_dummy_shares_for_deletion, sync_nccl, BatchMetadata, BatchQuery,
BatchQueryEntriesPreprocessed, ServerActor, ServerJobResult,
Expand Down Expand Up @@ -1045,6 +1046,33 @@ async fn server_main(config: Config) -> eyre::Result<()> {
}
};

tracing::info!("Page-lock host memory");
let left_codes = actor.left_code_db_slices.code_gr.clone();
let right_codes = actor.right_code_db_slices.code_gr.clone();
let left_masks = actor.left_mask_db_slices.code_gr.clone();
let right_masks = actor.right_mask_db_slices.code_gr.clone();
let device_manager_clone = actor.device_manager.clone();

let page_lock_handle = spawn_blocking(move || {
for db in [&left_codes, &right_codes] {
register_host_memory(
device_manager_clone.clone(),
db,
config.max_db_size,
IRIS_CODE_LENGTH,
);
}

for db in [&left_masks, &right_masks] {
register_host_memory(
device_manager_clone.clone(),
db,
config.max_db_size,
MASK_CODE_LENGTH,
);
}
});

let now = Instant::now();
let mut now_load_summary = Instant::now();
let mut time_waiting_for_stream = time::Duration::from_secs(0);
Expand Down Expand Up @@ -1139,8 +1167,8 @@ async fn server_main(config: Config) -> eyre::Result<()> {
tracing::info!("Preprocessing db");
actor.preprocess_db();

tracing::info!("Page-lock host memory");
actor.register_host_memory();
tracing::info!("Waiting for page-lock to finish");
page_lock_handle.await?;

tracing::info!(
"Loaded {} records from db into memory [DB sizes: {:?}]",
Expand Down

0 comments on commit 1978beb

Please sign in to comment.