diff --git a/Cargo.lock b/Cargo.lock index 9330efe36..0cc4bde74 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -679,9 +679,9 @@ dependencies = [ [[package]] name = "aws-smithy-types" -version = "1.2.8" +version = "1.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07c9cdc179e6afbf5d391ab08c85eac817b51c87e1892a5edb5f7bbdc64314b4" +checksum = "4fbd94a32b3a7d55d3806fe27d98d3ad393050439dd05eb53ece36ec5e3d3510" dependencies = [ "base64-simd", "bytes", @@ -868,7 +868,7 @@ dependencies = [ "bitflags 2.6.0", "cexpr", "clang-sys", - "itertools 0.10.5", + "itertools 0.12.1", "lazy_static", "lazycell", "log", @@ -2623,6 +2623,7 @@ name = "iris-mpc" version = "0.1.0" dependencies = [ "aws-config", + "aws-sdk-s3", "aws-sdk-sns", "aws-sdk-sqs", "axum", @@ -2659,6 +2660,7 @@ name = "iris-mpc-common" version = "0.1.0" dependencies = [ "aws-config", + "aws-credential-types", "aws-sdk-kms", "aws-sdk-s3", "aws-sdk-secretsmanager", diff --git a/iris-mpc-common/Cargo.toml b/iris-mpc-common/Cargo.toml index d9a287689..cbec4c6f1 100644 --- a/iris-mpc-common/Cargo.toml +++ b/iris-mpc-common/Cargo.toml @@ -50,6 +50,7 @@ serde-big-array.workspace = true [dev-dependencies] float_eq = "1" +aws-credential-types = "1.2.1" [[bin]] name = "key-manager" diff --git a/iris-mpc-common/src/config/mod.rs b/iris-mpc-common/src/config/mod.rs index b21aea8cc..71ff9d051 100644 --- a/iris-mpc-common/src/config/mod.rs +++ b/iris-mpc-common/src/config/mod.rs @@ -49,6 +49,9 @@ pub struct Config { #[serde(default)] pub public_key_base_url: String, + #[serde(default = "default_shares_bucket_name")] + pub shares_bucket_name: String, + #[serde(default)] pub clear_db_before_init: bool, @@ -106,6 +109,10 @@ fn default_shutdown_last_results_sync_timeout_secs() -> u64 { 10 } +fn default_shares_bucket_name() -> String { + "wf-mpc-prod-smpcv2-sns-requests".to_string() +} + impl Config { pub fn load_config(prefix: &str) -> eyre::Result { let settings = config::Config::builder(); diff --git a/iris-mpc-common/src/helpers/key_pair.rs b/iris-mpc-common/src/helpers/key_pair.rs index c41554539..cebf7e56c 100644 --- a/iris-mpc-common/src/helpers/key_pair.rs +++ b/iris-mpc-common/src/helpers/key_pair.rs @@ -47,6 +47,8 @@ pub enum SharesDecodingError { url: String, message: String, }, + #[error("Received error message from S3 for key {}: {}", .key, .message)] + S3ResponseContent { key: String, message: String }, #[error(transparent)] SerdeError(#[from] serde_json::error::Error), #[error(transparent)] diff --git a/iris-mpc-common/src/helpers/smpc_request.rs b/iris-mpc-common/src/helpers/smpc_request.rs index af97c5b40..04863df66 100644 --- a/iris-mpc-common/src/helpers/smpc_request.rs +++ b/iris-mpc-common/src/helpers/smpc_request.rs @@ -1,5 +1,6 @@ use super::{key_pair::SharesDecodingError, sha256::calculate_sha256}; use crate::helpers::key_pair::SharesEncryptionKeyPairs; +use aws_sdk_s3::Client as S3Client; use aws_sdk_sns::types::MessageAttributeValue; use aws_sdk_sqs::{ error::SdkError, @@ -7,15 +8,10 @@ use aws_sdk_sqs::{ }; use base64::{engine::general_purpose::STANDARD, Engine}; use eyre::Report; -use reqwest::Client; use serde::{Deserialize, Deserializer, Serialize, Serializer}; use serde_json::Value; -use std::{collections::HashMap, sync::LazyLock}; +use std::{collections::HashMap, sync::Arc}; use thiserror::Error; -use tokio_retry::{ - strategy::{jitter, FixedInterval}, - Retry, -}; #[derive(Serialize, Deserialize, Debug)] pub struct SQSMessage { @@ -113,7 +109,7 @@ pub const UNIQUENESS_MESSAGE_TYPE: &str = "uniqueness"; pub struct UniquenessRequest { pub batch_size: Option, pub signup_id: String, - pub s3_presigned_url: String, + pub s3_key: String, pub iris_shares_file_hashes: [String; 3], } @@ -196,51 +192,45 @@ impl SharesS3Object { } } -static S3_HTTP_CLIENT: LazyLock = LazyLock::new(Client::new); - impl UniquenessRequest { pub async fn get_iris_data_by_party_id( &self, party_id: usize, + bucket_name: &String, + s3_client: &Arc, ) -> Result { - // Send a GET request to the presigned URL - let retry_strategy = FixedInterval::from_millis(200).map(jitter).take(5); - let response = Retry::spawn(retry_strategy, || async { - S3_HTTP_CLIENT - .get(self.s3_presigned_url.clone()) - .send() - .await - }) - .await?; - - // Ensure the request was successful - if response.status().is_success() { - // Parse the JSON response into the SharesS3Object struct - let shares_file: SharesS3Object = match response.json().await { - Ok(file) => file, - Err(e) => { - tracing::error!("Failed to parse JSON: {}", e); - return Err(SharesDecodingError::RequestError(e)); + let response = s3_client + .get_object() + .bucket(bucket_name) + .key(self.s3_key.as_str()) + .send() + .await + .map_err(|err| { + tracing::error!("Failed to download file: {}", err); + SharesDecodingError::S3ResponseContent { + key: self.s3_key.clone(), + message: err.to_string(), } - }; - - // Construct the field name dynamically - let field_name = format!("iris_share_{}", party_id); - // Access the field dynamically - if let Some(value) = shares_file.get(party_id) { - Ok(value.to_string()) - } else { - tracing::error!("Failed to find field: {}", field_name); - Err(SharesDecodingError::SecretStringNotFound) + })?; + + let object_body = response.body.collect().await.map_err(|e| { + tracing::error!("Failed to get object body: {}", e); + SharesDecodingError::S3ResponseContent { + key: self.s3_key.clone(), + message: e.to_string(), } - } else { - tracing::error!("Failed to download file: {}", response.status()); - Err(SharesDecodingError::ResponseContent { - status: response.status(), - url: self.s3_presigned_url.clone(), - message: response.text().await.unwrap_or_default(), - }) - } + })?; + + let bytes = object_body.into_bytes(); + + let shares_file: SharesS3Object = serde_json::from_slice(&bytes)?; + + let field_name = format!("iris_share_{}", party_id); + + shares_file.get(party_id).cloned().ok_or_else(|| { + tracing::error!("Failed to find field: {}", field_name); + SharesDecodingError::SecretStringNotFound + }) } pub fn decrypt_iris_share( diff --git a/iris-mpc-common/tests/smpc_request.rs b/iris-mpc-common/tests/smpc_request.rs index 273c65008..1c2e7d5fd 100644 --- a/iris-mpc-common/tests/smpc_request.rs +++ b/iris-mpc-common/tests/smpc_request.rs @@ -1,6 +1,7 @@ mod tests { + use aws_credential_types::{provider::SharedCredentialsProvider, Credentials}; + use aws_sdk_s3::Client as S3Client; use base64::{engine::general_purpose::STANDARD, Engine}; - use http::StatusCode; use iris_mpc_common::helpers::{ key_pair::{SharesDecodingError, SharesEncryptionKeyPairs}, sha256::calculate_sha256, @@ -8,10 +9,8 @@ mod tests { }; use serde_json::json; use sodiumoxide::crypto::{box_::PublicKey, sealedbox}; - use wiremock::{ - matchers::{method, path}, - Mock, MockServer, ResponseTemplate, - }; + use std::sync::Arc; + use wiremock::{matchers::method, Mock, MockServer, ResponseTemplate}; const PREVIOUS_PUBLIC_KEY: &str = "1UY8lKlS7aVj5ZnorSfLIHlG3jg+L4ToVi4K+mLKqFQ="; const PREVIOUS_PRIVATE_KEY: &str = "X26wWfzP5fKMP7QMz0X3eZsEeF4NhJU92jT69wZg6x8="; @@ -45,7 +44,7 @@ mod tests { UniquenessRequest { batch_size: Some(1), signup_id: "signup_mock".to_string(), - s3_presigned_url: "https://example.com/mock".to_string(), + s3_key: "mock".to_string(), iris_shares_file_hashes: hashes, } } @@ -54,7 +53,7 @@ mod tests { UniquenessRequest { batch_size: None, signup_id: "test_signup_id".to_string(), - s3_presigned_url: "https://example.com/package".to_string(), + s3_key: "package".to_string(), iris_shares_file_hashes: [ "hash_0".to_string(), "hash_1".to_string(), @@ -66,26 +65,46 @@ mod tests { #[tokio::test] async fn test_retrieve_iris_shares_from_s3_success() { let mock_server = MockServer::start().await; - - // Simulate a successful response from the presigned URL + let bucket_name = "bobTheBucket"; + let key = "kateTheKey"; let response_body = json!({ "iris_share_0": "share_0_data", "iris_share_1": "share_1_data", "iris_share_2": "share_2_data" }); - let template = ResponseTemplate::new(StatusCode::OK).set_body_json(response_body.clone()); + let data = response_body.to_string(); Mock::given(method("GET")) - .and(path("/test_presign_url")) - .respond_with(template) + .respond_with( + ResponseTemplate::new(200) + .insert_header("Content-Type", "application/octet-stream") + .set_body_raw(data, "application/octet-stream"), + ) .mount(&mock_server) .await; + let credentials = + Credentials::new("test-access-key", "test-secret-key", None, None, "test"); + let credentials_provider = SharedCredentialsProvider::new(credentials); + // Configure the S3Client to point to the mock server + let config = aws_config::from_env() + .region("us-west-2") + .endpoint_url(mock_server.uri()) + .credentials_provider(credentials_provider) + .load() + .await; + let s3_config = aws_sdk_s3::config::Builder::from(&config) + .endpoint_url(mock_server.uri()) + .force_path_style(true) + .build(); + + let s3_client = Arc::new(S3Client::from_conf(s3_config)); + let smpc_request = UniquenessRequest { batch_size: None, signup_id: "test_signup_id".to_string(), - s3_presigned_url: mock_server.uri().clone() + "/test_presign_url", + s3_key: key.to_string(), iris_shares_file_hashes: [ "hash_0".to_string(), "hash_1".to_string(), @@ -93,7 +112,9 @@ mod tests { ], }; - let result = smpc_request.get_iris_data_by_party_id(0).await; + let result = smpc_request + .get_iris_data_by_party_id(0, &bucket_name.to_string(), &s3_client) + .await; assert!(result.is_ok()); assert_eq!(result.unwrap(), "share_0_data".to_string()); diff --git a/iris-mpc/Cargo.toml b/iris-mpc/Cargo.toml index 53454799e..605a831d8 100644 --- a/iris-mpc/Cargo.toml +++ b/iris-mpc/Cargo.toml @@ -11,6 +11,7 @@ repository.workspace = true aws-config.workspace = true aws-sdk-sns.workspace = true aws-sdk-sqs.workspace = true +aws-sdk-s3.workspace = true axum.workspace = true tokio.workspace = true tracing.workspace = true diff --git a/iris-mpc/src/bin/client.rs b/iris-mpc/src/bin/client.rs index 0ab9f77af..cc0cf0529 100644 --- a/iris-mpc/src/bin/client.rs +++ b/iris-mpc/src/bin/client.rs @@ -372,7 +372,7 @@ async fn main() -> eyre::Result<()> { let request_message = UniquenessRequest { batch_size: None, signup_id: request_id.to_string(), - s3_presigned_url: presigned_url, + s3_key: presigned_url, iris_shares_file_hashes, }; diff --git a/iris-mpc/src/bin/server.rs b/iris-mpc/src/bin/server.rs index 842d4f220..26d69a13b 100644 --- a/iris-mpc/src/bin/server.rs +++ b/iris-mpc/src/bin/server.rs @@ -1,5 +1,6 @@ #![allow(clippy::needless_range_loop)] +use aws_sdk_s3::Client as S3Client; use aws_sdk_sns::{types::MessageAttributeValue, Client as SNSClient}; use aws_sdk_sqs::{config::Region, Client}; use axum::{response::IntoResponse, routing::get, Router}; @@ -122,6 +123,7 @@ async fn receive_batch( party_id: usize, client: &Client, sns_client: &SNSClient, + s3_client: &Arc, config: &Config, store: &Store, skip_request_ids: &[String], @@ -275,17 +277,21 @@ async fn receive_batch( batch_query.metadata.push(batch_metadata); let semaphore = Arc::clone(&semaphore); + let s3_client_arc = Arc::clone(s3_client); + let bucket_name = config.shares_bucket_name.clone(); let handle = tokio::spawn(async move { let _ = semaphore.acquire().await?; - let base_64_encoded_message_payload = - match smpc_request.get_iris_data_by_party_id(party_id).await { - Ok(iris_message_share) => iris_message_share, - Err(e) => { - tracing::error!("Failed to get iris shares: {:?}", e); - eyre::bail!("Failed to get iris shares: {:?}", e); - } - }; + let base_64_encoded_message_payload = match smpc_request + .get_iris_data_by_party_id(party_id, &bucket_name, &s3_client_arc) + .await + { + Ok(iris_message_share) => iris_message_share, + Err(e) => { + tracing::error!("Failed to get iris shares: {:?}", e); + eyre::bail!("Failed to get iris shares: {:?}", e); + } + }; let iris_message_share = match smpc_request.decrypt_iris_share( base_64_encoded_message_payload, @@ -664,6 +670,7 @@ async fn server_main(config: Config) -> eyre::Result<()> { let shared_config = aws_config::from_env().region(region_provider).load().await; let sqs_client = Client::new(&shared_config); let sns_client = SNSClient::new(&shared_config); + let s3_client = Arc::new(S3Client::new(&shared_config)); let shares_encryption_key_pair = match SharesEncryptionKeyPairs::from_storage(config.clone()).await { Ok(key_pair) => key_pair, @@ -1265,6 +1272,7 @@ async fn server_main(config: Config) -> eyre::Result<()> { party_id, &sqs_client, &sns_client, + &s3_client, &config, &store, &skip_request_ids, @@ -1318,6 +1326,7 @@ async fn server_main(config: Config) -> eyre::Result<()> { party_id, &sqs_client, &sns_client, + &s3_client, &config, &store, &skip_request_ids,