
Commit

Add retrieving s3 shares via getObject (#756)
danielle-tfh authored Dec 5, 2024
1 parent 4b3375a commit ef1ae02
Showing 9 changed files with 104 additions and 71 deletions.
8 changes: 5 additions & 3 deletions Cargo.lock

Generated file; diff not rendered.

1 change: 1 addition & 0 deletions iris-mpc-common/Cargo.toml
@@ -50,6 +50,7 @@ serde-big-array.workspace = true

[dev-dependencies]
float_eq = "1"
aws-credential-types = "1.2.1"

[[bin]]
name = "key-manager"
7 changes: 7 additions & 0 deletions iris-mpc-common/src/config/mod.rs
@@ -49,6 +49,9 @@ pub struct Config {
#[serde(default)]
pub public_key_base_url: String,

#[serde(default = "default_shares_bucket_name")]
pub shares_bucket_name: String,

#[serde(default)]
pub clear_db_before_init: bool,

@@ -106,6 +109,10 @@ fn default_shutdown_last_results_sync_timeout_secs() -> u64 {
10
}

fn default_shares_bucket_name() -> String {
"wf-mpc-prod-smpcv2-sns-requests".to_string()
}

impl Config {
pub fn load_config(prefix: &str) -> eyre::Result<Config> {
let settings = config::Config::builder();
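As a side note, here is a minimal sketch (not part of the commit) of how the serde default attribute behaves for the new shares_bucket_name field: when the key is absent from the deserialized input, serde calls the named default function. A standalone struct and serde_json are used purely for illustration, assuming serde (with derive) and serde_json are available.

```rust
// Illustration only: serde default-function behavior for a missing key.
use serde::Deserialize;

fn default_shares_bucket_name() -> String {
    "wf-mpc-prod-smpcv2-sns-requests".to_string()
}

#[derive(Deserialize, Debug)]
struct DemoConfig {
    #[serde(default = "default_shares_bucket_name")]
    shares_bucket_name: String,
}

fn main() {
    // Key absent: serde falls back to the default function.
    let cfg: DemoConfig = serde_json::from_str("{}").unwrap();
    assert_eq!(cfg.shares_bucket_name, "wf-mpc-prod-smpcv2-sns-requests");

    // Key present: the supplied value wins.
    let cfg: DemoConfig =
        serde_json::from_str(r#"{"shares_bucket_name": "my-bucket"}"#).unwrap();
    assert_eq!(cfg.shares_bucket_name, "my-bucket");
}
```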
2 changes: 2 additions & 0 deletions iris-mpc-common/src/helpers/key_pair.rs
@@ -47,6 +47,8 @@ pub enum SharesDecodingError {
url: String,
message: String,
},
#[error("Received error message from S3 for key {}: {}", .key, .message)]
S3ResponseContent { key: String, message: String },
#[error(transparent)]
SerdeError(#[from] serde_json::error::Error),
#[error(transparent)]
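For reference, a minimal sketch (assuming the thiserror crate, as the surrounding enum uses) of the Display output produced by the new S3ResponseContent variant; the key and message values below are hypothetical.

```rust
// Illustration only: message rendering of the new error variant.
use thiserror::Error;

#[derive(Error, Debug)]
enum DemoError {
    #[error("Received error message from S3 for key {}: {}", .key, .message)]
    S3ResponseContent { key: String, message: String },
}

fn main() {
    let err = DemoError::S3ResponseContent {
        key: "iris-shares/abc123".to_string(), // hypothetical key
        message: "NoSuchKey".to_string(),      // hypothetical S3 error text
    };
    // Prints: Received error message from S3 for key iris-shares/abc123: NoSuchKey
    println!("{}", err);
}
```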
80 changes: 35 additions & 45 deletions iris-mpc-common/src/helpers/smpc_request.rs
@@ -1,21 +1,17 @@
use super::{key_pair::SharesDecodingError, sha256::calculate_sha256};
use crate::helpers::key_pair::SharesEncryptionKeyPairs;
use aws_sdk_s3::Client as S3Client;
use aws_sdk_sns::types::MessageAttributeValue;
use aws_sdk_sqs::{
error::SdkError,
operation::{delete_message::DeleteMessageError, receive_message::ReceiveMessageError},
};
use base64::{engine::general_purpose::STANDARD, Engine};
use eyre::Report;
use reqwest::Client;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use serde_json::Value;
use std::{collections::HashMap, sync::LazyLock};
use std::{collections::HashMap, sync::Arc};
use thiserror::Error;
use tokio_retry::{
strategy::{jitter, FixedInterval},
Retry,
};

#[derive(Serialize, Deserialize, Debug)]
pub struct SQSMessage {
@@ -113,7 +109,7 @@ pub const UNIQUENESS_MESSAGE_TYPE: &str = "uniqueness";
pub struct UniquenessRequest {
pub batch_size: Option<usize>,
pub signup_id: String,
pub s3_presigned_url: String,
pub s3_key: String,
pub iris_shares_file_hashes: [String; 3],
}

@@ -196,51 +192,45 @@ impl SharesS3Object {
}
}

static S3_HTTP_CLIENT: LazyLock<Client> = LazyLock::new(Client::new);

impl UniquenessRequest {
pub async fn get_iris_data_by_party_id(
&self,
party_id: usize,
bucket_name: &String,
s3_client: &Arc<S3Client>,
) -> Result<String, SharesDecodingError> {
// Send a GET request to the presigned URL
let retry_strategy = FixedInterval::from_millis(200).map(jitter).take(5);
let response = Retry::spawn(retry_strategy, || async {
S3_HTTP_CLIENT
.get(self.s3_presigned_url.clone())
.send()
.await
})
.await?;

// Ensure the request was successful
if response.status().is_success() {
// Parse the JSON response into the SharesS3Object struct
let shares_file: SharesS3Object = match response.json().await {
Ok(file) => file,
Err(e) => {
tracing::error!("Failed to parse JSON: {}", e);
return Err(SharesDecodingError::RequestError(e));
let response = s3_client
.get_object()
.bucket(bucket_name)
.key(self.s3_key.as_str())
.send()
.await
.map_err(|err| {
tracing::error!("Failed to download file: {}", err);
SharesDecodingError::S3ResponseContent {
key: self.s3_key.clone(),
message: err.to_string(),
}
};

// Construct the field name dynamically
let field_name = format!("iris_share_{}", party_id);
// Access the field dynamically
if let Some(value) = shares_file.get(party_id) {
Ok(value.to_string())
} else {
tracing::error!("Failed to find field: {}", field_name);
Err(SharesDecodingError::SecretStringNotFound)
})?;

let object_body = response.body.collect().await.map_err(|e| {
tracing::error!("Failed to get object body: {}", e);
SharesDecodingError::S3ResponseContent {
key: self.s3_key.clone(),
message: e.to_string(),
}
} else {
tracing::error!("Failed to download file: {}", response.status());
Err(SharesDecodingError::ResponseContent {
status: response.status(),
url: self.s3_presigned_url.clone(),
message: response.text().await.unwrap_or_default(),
})
}
})?;

let bytes = object_body.into_bytes();

let shares_file: SharesS3Object = serde_json::from_slice(&bytes)?;

let field_name = format!("iris_share_{}", party_id);

shares_file.get(party_id).cloned().ok_or_else(|| {
tracing::error!("Failed to find field: {}", field_name);
SharesDecodingError::SecretStringNotFound
})
}

pub fn decrypt_iris_share(
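To make the new call shape concrete, here is a hypothetical caller sketch (the fetch_share wrapper and its names are illustrative, not from the commit) showing the key-plus-bucket signature used with a shared S3 client. The bucket now comes from configuration (shares_bucket_name) rather than from a presigned URL embedded in the request.

```rust
use std::sync::Arc;

use aws_sdk_s3::Client as S3Client;
use iris_mpc_common::helpers::smpc_request::UniquenessRequest;

// Hypothetical wrapper: fetch one party's base64-encoded share by S3 key.
async fn fetch_share(
    request: &UniquenessRequest,
    bucket_name: &str,
    s3_client: &Arc<S3Client>,
    party_id: usize,
) -> eyre::Result<String> {
    let share_b64 = request
        .get_iris_data_by_party_id(party_id, &bucket_name.to_string(), s3_client)
        .await?;
    Ok(share_b64)
}
```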
49 changes: 35 additions & 14 deletions iris-mpc-common/tests/smpc_request.rs
@@ -1,17 +1,16 @@
mod tests {
use aws_credential_types::{provider::SharedCredentialsProvider, Credentials};
use aws_sdk_s3::Client as S3Client;
use base64::{engine::general_purpose::STANDARD, Engine};
use http::StatusCode;
use iris_mpc_common::helpers::{
key_pair::{SharesDecodingError, SharesEncryptionKeyPairs},
sha256::calculate_sha256,
smpc_request::{IrisCodesJSON, UniquenessRequest},
};
use serde_json::json;
use sodiumoxide::crypto::{box_::PublicKey, sealedbox};
use wiremock::{
matchers::{method, path},
Mock, MockServer, ResponseTemplate,
};
use std::sync::Arc;
use wiremock::{matchers::method, Mock, MockServer, ResponseTemplate};

const PREVIOUS_PUBLIC_KEY: &str = "1UY8lKlS7aVj5ZnorSfLIHlG3jg+L4ToVi4K+mLKqFQ=";
const PREVIOUS_PRIVATE_KEY: &str = "X26wWfzP5fKMP7QMz0X3eZsEeF4NhJU92jT69wZg6x8=";
@@ -45,7 +44,7 @@ mod tests {
UniquenessRequest {
batch_size: Some(1),
signup_id: "signup_mock".to_string(),
s3_presigned_url: "https://example.com/mock".to_string(),
s3_key: "mock".to_string(),
iris_shares_file_hashes: hashes,
}
}
@@ -54,7 +53,7 @@ mod tests {
UniquenessRequest {
batch_size: None,
signup_id: "test_signup_id".to_string(),
s3_presigned_url: "https://example.com/package".to_string(),
s3_key: "package".to_string(),
iris_shares_file_hashes: [
"hash_0".to_string(),
"hash_1".to_string(),
@@ -66,34 +65,56 @@
#[tokio::test]
async fn test_retrieve_iris_shares_from_s3_success() {
let mock_server = MockServer::start().await;

// Simulate a successful response from the presigned URL
let bucket_name = "bobTheBucket";
let key = "kateTheKey";
let response_body = json!({
"iris_share_0": "share_0_data",
"iris_share_1": "share_1_data",
"iris_share_2": "share_2_data"
});

let template = ResponseTemplate::new(StatusCode::OK).set_body_json(response_body.clone());
let data = response_body.to_string();

Mock::given(method("GET"))
.and(path("/test_presign_url"))
.respond_with(template)
.respond_with(
ResponseTemplate::new(200)
.insert_header("Content-Type", "application/octet-stream")
.set_body_raw(data, "application/octet-stream"),
)
.mount(&mock_server)
.await;

let credentials =
Credentials::new("test-access-key", "test-secret-key", None, None, "test");
let credentials_provider = SharedCredentialsProvider::new(credentials);
// Configure the S3Client to point to the mock server
let config = aws_config::from_env()
.region("us-west-2")
.endpoint_url(mock_server.uri())
.credentials_provider(credentials_provider)
.load()
.await;
let s3_config = aws_sdk_s3::config::Builder::from(&config)
.endpoint_url(mock_server.uri())
.force_path_style(true)
.build();

let s3_client = Arc::new(S3Client::from_conf(s3_config));

let smpc_request = UniquenessRequest {
batch_size: None,
signup_id: "test_signup_id".to_string(),
s3_presigned_url: mock_server.uri().clone() + "/test_presign_url",
s3_key: key.to_string(),
iris_shares_file_hashes: [
"hash_0".to_string(),
"hash_1".to_string(),
"hash_2".to_string(),
],
};

let result = smpc_request.get_iris_data_by_party_id(0).await;
let result = smpc_request
.get_iris_data_by_party_id(0, &bucket_name.to_string(), &s3_client)
.await;

assert!(result.is_ok());
assert_eq!(result.unwrap(), "share_0_data".to_string());
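A note on the client setup in this test: when the SDK is pointed at a single-host endpoint such as a wiremock server, path-style addressing keeps the bucket in the URL path instead of the hostname. Below is a hedged sketch of the same pattern against an arbitrary local endpoint; the helper name, endpoint URL, and region are assumptions, while every builder call mirrors the test above.

```rust
use aws_sdk_s3::Client as S3Client;

// Illustration: build an S3 client against a local or mock endpoint.
// force_path_style(true) yields requests like http://host/bucket/key rather
// than http://bucket.host/key, which a single-host test endpoint can serve.
async fn s3_client_for_endpoint(endpoint_url: &str) -> S3Client {
    let shared_config = aws_config::from_env()
        .region("us-west-2") // assumed region, as in the test above
        .load()
        .await;
    let s3_config = aws_sdk_s3::config::Builder::from(&shared_config)
        .endpoint_url(endpoint_url)
        .force_path_style(true)
        .build();
    S3Client::from_conf(s3_config)
}
```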
1 change: 1 addition & 0 deletions iris-mpc/Cargo.toml
@@ -11,6 +11,7 @@ repository.workspace = true
aws-config.workspace = true
aws-sdk-sns.workspace = true
aws-sdk-sqs.workspace = true
aws-sdk-s3.workspace = true
axum.workspace = true
tokio.workspace = true
tracing.workspace = true
2 changes: 1 addition & 1 deletion iris-mpc/src/bin/client.rs
@@ -372,7 +372,7 @@ async fn main() -> eyre::Result<()> {
let request_message = UniquenessRequest {
batch_size: None,
signup_id: request_id.to_string(),
s3_presigned_url: presigned_url,
s3_key: presigned_url,
iris_shares_file_hashes,
};

25 changes: 17 additions & 8 deletions iris-mpc/src/bin/server.rs
@@ -1,5 +1,6 @@
#![allow(clippy::needless_range_loop)]

use aws_sdk_s3::Client as S3Client;
use aws_sdk_sns::{types::MessageAttributeValue, Client as SNSClient};
use aws_sdk_sqs::{config::Region, Client};
use axum::{response::IntoResponse, routing::get, Router};
@@ -122,6 +123,7 @@ async fn receive_batch(
party_id: usize,
client: &Client,
sns_client: &SNSClient,
s3_client: &Arc<S3Client>,
config: &Config,
store: &Store,
skip_request_ids: &[String],
@@ -275,17 +277,21 @@ async fn receive_batch(
batch_query.metadata.push(batch_metadata);

let semaphore = Arc::clone(&semaphore);
let s3_client_arc = Arc::clone(s3_client);
let bucket_name = config.shares_bucket_name.clone();
let handle = tokio::spawn(async move {
let _ = semaphore.acquire().await?;

let base_64_encoded_message_payload =
match smpc_request.get_iris_data_by_party_id(party_id).await {
Ok(iris_message_share) => iris_message_share,
Err(e) => {
tracing::error!("Failed to get iris shares: {:?}", e);
eyre::bail!("Failed to get iris shares: {:?}", e);
}
};
let base_64_encoded_message_payload = match smpc_request
.get_iris_data_by_party_id(party_id, &bucket_name, &s3_client_arc)
.await
{
Ok(iris_message_share) => iris_message_share,
Err(e) => {
tracing::error!("Failed to get iris shares: {:?}", e);
eyre::bail!("Failed to get iris shares: {:?}", e);
}
};

let iris_message_share = match smpc_request.decrypt_iris_share(
base_64_encoded_message_payload,
@@ -664,6 +670,7 @@ async fn server_main(config: Config) -> eyre::Result<()> {
let shared_config = aws_config::from_env().region(region_provider).load().await;
let sqs_client = Client::new(&shared_config);
let sns_client = SNSClient::new(&shared_config);
let s3_client = Arc::new(S3Client::new(&shared_config));
let shares_encryption_key_pair =
match SharesEncryptionKeyPairs::from_storage(config.clone()).await {
Ok(key_pair) => key_pair,
@@ -1265,6 +1272,7 @@ async fn server_main(config: Config) -> eyre::Result<()> {
party_id,
&sqs_client,
&sns_client,
&s3_client,
&config,
&store,
&skip_request_ids,
@@ -1318,6 +1326,7 @@ async fn server_main(config: Config) -> eyre::Result<()> {
party_id,
&sqs_client,
&sns_client,
&s3_client,
&config,
&store,
&skip_request_ids,
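Finally, a small sketch (a hypothetical helper, not from the commit) of the sharing pattern used in receive_batch: the Arc around the S3 client is cloned before being moved into each spawned task, so every in-flight fetch holds its own handle to the same underlying client.

```rust
use std::sync::Arc;

use aws_sdk_s3::Client as S3Client;

// Hypothetical helper mirroring the receive_batch pattern: clone the Arc,
// move it into the task, and fetch one object by bucket and key.
fn spawn_share_fetch(
    s3_client: &Arc<S3Client>,
    bucket_name: String,
    s3_key: String,
) -> tokio::task::JoinHandle<eyre::Result<Vec<u8>>> {
    let s3_client = Arc::clone(s3_client);
    tokio::spawn(async move {
        let response = s3_client
            .get_object()
            .bucket(&bucket_name)
            .key(&s3_key)
            .send()
            .await
            .map_err(|err| eyre::eyre!("failed to download {}: {}", s3_key, err))?;
        // Drain the streaming body into memory before handing it back.
        let bytes = response.body.collect().await?.into_bytes();
        Ok(bytes.to_vec())
    })
}
```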
