Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
wojciechsromek committed Aug 20, 2024
1 parent ac46beb commit e620ad4
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 27 deletions.
2 changes: 1 addition & 1 deletion iris-mpc-common/src/helpers/key_pair.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ impl SharesEncryptionKeyPair {
}
}

pub async fn download_private_key_from_asm(
async fn download_private_key_from_asm(
client: &SecretsManagerClient,
env: &str,
node_id: &str,
Expand Down
10 changes: 5 additions & 5 deletions iris-mpc-common/src/helpers/smpc_request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ impl SMPCRequest {
let response = match reqwest::get(self.s3_presigned_url.clone()).await {
Ok(response) => response,
Err(e) => {
eprintln!("Failed to send request: {}", e);
tracing::error!("Failed to send request: {}", e);
return Err(SharesDecodingError::RequestError(e));
}
};
Expand All @@ -76,7 +76,7 @@ impl SMPCRequest {
let shares_file: SharesS3Object = match response.json().await {
Ok(file) => file,
Err(e) => {
eprintln!("Failed to parse JSON: {}", e);
tracing::error!("Failed to parse JSON: {}", e);
return Err(SharesDecodingError::RequestError(e));
}
};
Expand All @@ -87,11 +87,11 @@ impl SMPCRequest {
if let Some(value) = shares_file.get(party_id) {
Ok(value.to_string())
} else {
eprintln!("Failed to find field: {}", field_name);
tracing::error!("Failed to find field: {}", field_name);
Err(SharesDecodingError::SecretStringNotFound)
}
} else {
eprintln!("Failed to download file: {}", response.status());
tracing::error!("Failed to download file: {}", response.status());
Err(SharesDecodingError::ResponseContent {
status: response.status(),
url: self.s3_presigned_url.clone(),
Expand All @@ -116,7 +116,7 @@ impl SMPCRequest {
let json_string = String::from_utf8(bytes)
.map_err(SharesDecodingError::DecodedShareParsingToUTF8Error)?;

println!("shares_json_string: {:?}", json_string);
tracing::info!("shares_json_string: {:?}", json_string);
let iris_share: IrisCodesJSON =
serde_json::from_str(&json_string).map_err(SharesDecodingError::SerdeError)?;
iris_share
Expand Down
18 changes: 14 additions & 4 deletions iris-mpc-common/src/helpers/sqs_s3_helper.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use crate::helpers::key_pair::SharesDecodingError;
use aws_config::meta::region::RegionProviderChain;
use aws_sdk_s3::{
error,
error::SdkError,
presigning::PresigningConfig,
primitives::{ByteStream, SdkBody},
Client,
Expand All @@ -20,20 +22,28 @@ pub async fn upload_file_and_generate_presigned_url(
// Create S3 client
let client = Client::new(&config);
let content_bytestream = ByteStream::new(SdkBody::from(contents));

// Create a PutObject request
client
match client
.put_object()
.bucket(bucket)
.key(key)
.body(content_bytestream)
.send()
.await
.expect("Failed to upload file.");
{
Ok(_) => {
tracing::info!("File uploaded successfully.");
}
Err(e) => {
tracing::error!("Error: Failed to upload file: {:?}", e);
}
}

println!("File uploaded successfully.");
tracing::info!("File uploaded successfully.");

// Create a presigned URL for the uploaded file
let presigning_config = match PresigningConfig::expires_in(Duration::from_secs(3600)) {
let presigning_config = match PresigningConfig::expires_in(Duration::from_secs(36000)) {
Ok(config) => config,
Err(e) => return Err(SharesDecodingError::PresigningConfigError(e)),
};
Expand Down
2 changes: 1 addition & 1 deletion iris-mpc-gpu/src/dot/share_db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1155,7 +1155,7 @@ mod tests {
let mask = results_masks[0][i] + results_masks[1][i] + results_masks[2][i];

if i == 0 {
println!("Code: {}, Mask: {}", code, mask);
tracing::info!("Code: {}, Mask: {}", code, mask);
}

reconstructed_codes.push(code);
Expand Down
2 changes: 1 addition & 1 deletion iris-mpc-gpu/src/helpers/device_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ impl DeviceManager {
devices.push(CudaDevice::new(i as usize).unwrap());
}

println!("Found {} devices", devices.len());
tracing::info!("Found {} devices", devices.len());

Self { devices }
}
Expand Down
42 changes: 27 additions & 15 deletions iris-mpc/src/bin/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,28 @@ const RESULT_SQS_AWS_REGION: &str = "eu-central-1";
const RNG_SEED_SERVER: u64 = 42;
const DB_SIZE: usize = 8 * 1_000;
const ENROLLMENT_REQUEST_TYPE: &str = "enrollment";
const PUBLIC_KEY_BASE_URL: &str = "https://d24uxaabh702ht.cloudfront.net";
// const PUBLIC_KEY_BASE_URL: &str = "https://d24uxaabh702ht.cloudfront.net";
const SQS_REQUESTS_BUCKET_NAME: &str = "wf-mpc-stage-smpcv2-sns-requests";
const SQS_REQUESTS_BUCKET_REGION: &str = "eu-north-1";
const SNS_TOPIC_REGION: &str = "eu-north-1";

#[derive(Debug, Parser)]
struct Opt {
#[arg(long, env)]
#[arg(long, env, required = true)]
request_topic_arn: String,

#[arg(long, env)]
#[arg(long, env, required = true)]
response_queue_url: String,

#[arg(long, env, required = true)]
requests_bucket_name: String,

#[arg(long, env, required = true)]
public_key_base_url: String,

#[arg(long, env, required = true)]
requests_bucket_region: String,

#[arg(long, env)]
db_index: Option<usize>,

Expand All @@ -57,6 +66,9 @@ async fn main() -> eyre::Result<()> {
tracing_subscriber::fmt::init();

let Opt {
public_key_base_url,
requests_bucket_name,
requests_bucket_region,
request_topic_arn,
response_queue_url,
db_index,
Expand All @@ -75,7 +87,7 @@ async fn main() -> eyre::Result<()> {

for i in 0..3 {
let public_key_string =
download_public_key(PUBLIC_KEY_BASE_URL.to_string(), i.to_string()).await?;
download_public_key(public_key_base_url.to_string(), i.to_string()).await?;
let public_key_bytes = general_purpose::STANDARD
.decode(public_key_string)
.context("Failed to decode public key")?;
Expand All @@ -86,12 +98,8 @@ async fn main() -> eyre::Result<()> {

let n_repeat = n_repeat.unwrap_or(0);

// THIS IS REQUIRED TO USE THE SQS FROM SECONDARY REGION, URL DOES NOT SUFFICE
let region_provider = Region::new(RESULT_SQS_AWS_REGION);
let shared_config = aws_config::from_env().region(region_provider).load().await;

let sqs_config = aws_config::from_env().region(SNS_TOPIC_REGION).load().await;
let client = Client::new(&sqs_config);
let requests_sns_config = aws_config::from_env().region(SNS_TOPIC_REGION).load().await;
let requests_sns_client = Client::new(&requests_sns_config);

let db = IrisDB::new_random_par(DB_SIZE, &mut StdRng::seed_from_u64(RNG_SEED_SERVER));

Expand All @@ -107,11 +115,15 @@ async fn main() -> eyre::Result<()> {
let thread_responses = responses.clone();

let recv_thread = spawn(async move {
let sqs_client = SqsClient::new(&shared_config);
// // THIS IS REQUIRED TO USE THE SQS FROM SECONDARY REGION, URL DOES NOT
// SUFFICE
let region_provider = Region::new(RESULT_SQS_AWS_REGION);
let results_sqs_config = aws_config::from_env().region(region_provider).load().await;
let results_qs_client = SqsClient::new(&results_sqs_config);
let mut counter = 0;
while counter < N_QUERIES * 3 {
// Receive responses
let msg = sqs_client
let msg = results_qs_client
.receive_message()
.max_number_of_messages(1)
.queue_url(response_queue_url.clone())
Expand All @@ -135,7 +147,7 @@ async fn main() -> eyre::Result<()> {
result.signup_id
);

sqs_client
results_qs_client
.delete_message()
.queue_url(response_queue_url.clone())
.receipt_handle(msg.receipt_handle.unwrap())
Expand Down Expand Up @@ -170,7 +182,7 @@ async fn main() -> eyre::Result<()> {
assert_eq!(result.serial_id.unwrap(), expected_result.unwrap());
}

sqs_client
results_qs_client
.delete_message()
.queue_url(response_queue_url.clone())
.receipt_handle(msg.receipt_handle.unwrap())
Expand Down Expand Up @@ -315,7 +327,7 @@ async fn main() -> eyre::Result<()> {
};

// Send all messages in batch
client
requests_sns_client
.publish()
.topic_arn(request_topic_arn.clone())
.message_group_id(ENROLLMENT_REQUEST_TYPE)
Expand Down

0 comments on commit e620ad4

Please sign in to comment.