From 8b53a53e7e452ac89b6a0e8779f5f714d4c0eb46 Mon Sep 17 00:00:00 2001 From: Paul Quinn Date: Sat, 28 Dec 2024 10:00:01 -0800 Subject: [PATCH] refactoring of relay-client --- endpoints/Cargo.toml | 2 +- endpoints/src/backend.rs | 2 + endpoints/src/endpoints.rs | 5 +- relay-client/Cargo.toml | 11 +- relay-client/src/bin/manual-test.rs | 580 +++++++++++++--------------- relay-client/src/client.rs | 203 +++++----- relay-client/src/lib.rs | 23 +- 7 files changed, 418 insertions(+), 408 deletions(-) diff --git a/endpoints/Cargo.toml b/endpoints/Cargo.toml index 1dccff95..3cea53c3 100644 --- a/endpoints/Cargo.toml +++ b/endpoints/Cargo.toml @@ -9,7 +9,7 @@ publish = false edition = "2021" license = "MIT OR (Apache-2.0 WITH LLVM-exception)" repository = "https://github.com/worldcoin/orb-software" -rust-version = "1.77.0" +rust-version = "1.82.0" [dependencies] hex = "0.4.3" diff --git a/endpoints/src/backend.rs b/endpoints/src/backend.rs index 683e80f4..4309f549 100644 --- a/endpoints/src/backend.rs +++ b/endpoints/src/backend.rs @@ -7,6 +7,7 @@ pub enum Backend { Prod, Staging, Analysis, + Local, } impl Backend { @@ -66,6 +67,7 @@ impl FromStr for Backend { "prod" | "production" => Ok(Self::Prod), "stage" | "staging" | "dev" | "development" => Ok(Self::Staging), "analysis" | "analysis.ml" | "analysis-ml" => Ok(Self::Analysis), + "local" => Ok(Self::Local), _ => Err(BackendParseErr), } } diff --git a/endpoints/src/endpoints.rs b/endpoints/src/endpoints.rs index 2575bade..8f628fa0 100644 --- a/endpoints/src/endpoints.rs +++ b/endpoints/src/endpoints.rs @@ -10,6 +10,7 @@ pub struct Endpoints { pub ai_volume: Url, pub auth: Url, pub ping: Url, + pub relay: Url, } impl Endpoints { @@ -20,7 +21,7 @@ impl Endpoints { pub fn new(backend: Backend, orb_id: &OrbId) -> Self { let subdomain = match backend { Backend::Prod => "orb", - Backend::Staging => "stage.orb", + Backend::Staging | Backend::Local => "stage.orb", Backend::Analysis => "analysis.ml", }; @@ -47,6 +48,8 @@ impl Endpoints { orb_id, "", ), + relay: Url::parse(&format!("https://relay.{subdomain}.worldcoin.org/")) + .expect("urls with validated orb ids should always parse"), } } } diff --git a/relay-client/Cargo.toml b/relay-client/Cargo.toml index 79dd083c..40440a4d 100644 --- a/relay-client/Cargo.toml +++ b/relay-client/Cargo.toml @@ -1,12 +1,19 @@ [package] name = "orb-relay-client" version = "0.1.0" -edition.workspace = true publish = false +# orb-core can't consume crates that use workspace inheritance :( +edition = "2021" +license = "MIT OR (Apache-2.0 WITH LLVM-exception)" +repository = "https://github.com/worldcoin/orb-software" +rust-version = "1.82.0" + [dependencies] -clap = { version = "4", features = ["derive"] } +clap = { version = "4", features = ["derive", "env"] } +derive_more = { workspace = true, features = ["deref", "from", "into"] } eyre.workspace = true +orb-endpoints.workspace = true orb-relay-messages.workspace = true orb-security-utils = { workspace = true, features = ["reqwest"] } rand = "0.8" diff --git a/relay-client/src/bin/manual-test.rs b/relay-client/src/bin/manual-test.rs index df88f47f..1cf40cb9 100644 --- a/relay-client/src/bin/manual-test.rs +++ b/relay-client/src/bin/manual-test.rs @@ -1,44 +1,14 @@ use clap::Parser; -use eyre::{Ok, Result}; -use orb_relay_client::{client::Client, debug_any, PayloadMatcher}; +use eyre::Result; +use orb_endpoints::{Backend, Endpoints, OrbId}; +use orb_relay_client::client::Client; use orb_relay_messages::{common, self_serve}; use rand::{distributions::Alphanumeric, Rng}; use std::{ - env, - sync::LazyLock, + str::FromStr, time::{Duration, Instant, SystemTime, UNIX_EPOCH}, }; - -static BACKEND_URL: LazyLock = LazyLock::new(|| { - let backend = - env::var("RELAY_TOOL_BACKEND").unwrap_or_else(|_| "stage".to_string()); - match backend.as_str() { - "stage" => "https://relay.stage.orb.worldcoin.org", - "prod" => "https://relay.orb.worldcoin.org", - "local" => "http://127.0.0.1:8443", - _ => panic!("Invalid backend option"), - } - .to_string() -}); -static APP_KEY: LazyLock = LazyLock::new(|| { - env::var("RELAY_TOOL_APP_KEY") - .unwrap_or_else(|_| "OTk3b3RGNTFYMnlYZ0dYODJlNkVZSTZqWlZnOHJUeDI=".to_string()) -}); -static ORB_KEY: LazyLock = LazyLock::new(|| { - env::var("RELAY_TOOL_ORB_KEY") - .unwrap_or_else(|_| "NWZxTTZQRlBwMm15ODhxUjRCS283ZERFMTlzek1ZOTU=".to_string()) -}); - -static ORB_ID: LazyLock = LazyLock::new(|| { - env::var("RELAY_TOOL_ORB_ID").unwrap_or_else(|_| "b222b1a3".to_string()) -}); -static SESSION_ID: LazyLock = LazyLock::new(|| { - env::var("RELAY_TOOL_SESSION_ID") - .unwrap_or_else(|_| "6943c6d9-48bf-4f29-9b60-48c63222e3ea".to_string()) -}); -static RELAY_NAMESPACE: LazyLock = LazyLock::new(|| { - env::var("RELAY_TOOL_RELAY_NAMESPACE").unwrap_or_else(|_| String::new()) -}); +use tracing::{error, info}; #[derive(Parser, Debug)] #[clap(author, version, about, long_about = None)] @@ -47,12 +17,48 @@ struct Args { #[clap(short = 'c', long = "consume-only")] consume_only: bool, /// Run only the stage_producer_app function - #[clap(short = 'p', long = "produce-only")] + #[clap(short = 'p', long = "produce-only", conflicts_with = "consume_only")] produce_only: bool, #[clap(short = 's', long = "start-orb-signup")] start_orb_signup: bool, #[clap(short = 'w', long = "slow-tests")] slow_tests: bool, + + #[clap(long, env = "RELAY_TOOL_ORB_ID", default_value = "b222b1a3")] + orb_id: String, + #[clap( + long, + env = "RELAY_TOOL_SESSION_ID", + default_value = "6943c6d9-48bf-4f29-9b60-48c63222e3ea" + )] + session_id: String, + #[clap(long, env = "RELAY_TOOL_BACKEND", default_value = "staging")] + backend: String, + #[clap( + long, + env = "RELAY_TOOL_APP_KEY", + default_value = "OTk3b3RGNTFYMnlYZ0dYODJlNkVZSTZqWlZnOHJUeDI=" + )] + app_key: String, + #[clap( + long, + env = "RELAY_TOOL_ORB_KEY", + default_value = "NWZxTTZQRlBwMm15ODhxUjRCS283ZERFMTlzek1ZOTU=" + )] + orb_key: String, + #[clap(long, env = "RELAY_TOOL_RELAY_NAMESPACE", default_value = "relay-tool")] + relay_namespace: String, +} + +fn backend_url(args: &Args) -> String { + let backend = Backend::from_str(args.backend.as_str()).unwrap_or(Backend::Staging); + if backend == Backend::Local { + "http://127.0.0.1:8443".to_string() + } else { + let endpoints = + Endpoints::new(backend, &OrbId::from_str(&args.orb_id.as_str()).unwrap()); + endpoints.relay.to_string() + } } #[tokio::main] @@ -62,52 +68,53 @@ async fn main() -> Result<()> { let args = Args::parse(); if args.consume_only { - stage_consumer_app().await?; + stage_consumer_app(&args).await?; } else if args.start_orb_signup { - stage_producer_from_app_start_orb_signup().await?; + stage_producer_from_app_start_orb_signup(&args).await?; } else if args.produce_only { - stage_producer_orb().await?; + stage_producer_orb(&args).await?; } else { - app_to_orb().await?; - orb_to_app().await?; - orb_to_app_with_state_request().await?; - orb_to_app_blocking_send().await?; + app_to_orb(&args).await?; + orb_to_app(&args).await?; + orb_to_app_with_state_request(&args).await?; + orb_to_app_blocking_send(&args).await?; if args.slow_tests { - orb_to_app_with_clients_created_later_and_delay().await?; + orb_to_app_with_clients_created_later_and_delay(&args).await?; } } Ok(()) } -async fn app_to_orb() -> Result<()> { - tracing::info!("== Running App to Orb =="); +async fn app_to_orb(args: &Args) -> Result<()> { + info!("== Running App to Orb =="); let (orb_id, session_id) = get_ids(); let mut app_client = Client::new_as_app( - BACKEND_URL.to_string(), - APP_KEY.to_string(), + backend_url(&args), + args.app_key.to_string(), session_id.to_string(), orb_id.to_string(), + args.relay_namespace.to_string(), ); let now = Instant::now(); app_client.connect().await?; - tracing::info!("Time took to app_connect: {}ms", now.elapsed().as_millis()); + info!("Time took to app_connect: {}ms", now.elapsed().as_millis()); let mut orb_client = Client::new_as_orb( - BACKEND_URL.to_string(), - ORB_KEY.to_string(), + backend_url(&args), + args.orb_key.to_string(), orb_id.to_string(), session_id.to_string(), - RELAY_NAMESPACE.to_string(), + args.relay_namespace.to_string(), ); let now = Instant::now(); orb_client.connect().await?; - tracing::info!("Time took to orb_connect: {}ms", now.elapsed().as_millis()); + info!("Time took to orb_connect: {}ms", now.elapsed().as_millis()); + info!("Sending AnnounceOrbId"); let now = Instant::now(); let time_now = time_now()?; - tracing::info!("Sending time now: {}", time_now); app_client .send(common::v1::AnnounceOrbId { orb_id: time_now.clone(), @@ -115,39 +122,32 @@ async fn app_to_orb() -> Result<()> { hardware_type: common::v1::announce_orb_id::HardwareType::Diamond.into(), }) .await?; - tracing::info!( + info!( "Time took to send a message from the app: {}ms", now.elapsed().as_millis() ); let now = Instant::now(); - 'ext: loop { - #[expect(clippy::never_loop)] - for msg in orb_client.get_buffered_messages().await { - tracing::info!( - "Received message: from: {:?}, to: {:?}, seq: {:?}, payload: {:?}", - msg.src, - msg.dst, - msg.seq, - debug_any(&msg.payload) + match orb_client + .wait_for_msg::(Duration::from_millis(1000)) + .await + { + Ok(msg) => { + assert!( + msg.orb_id == time_now, + "Received orb_id is not the same as sent orb_id" ); - if let Some(common::v1::AnnounceOrbId { orb_id, .. }) = - common::v1::AnnounceOrbId::matches(msg.payload.as_ref().unwrap()) - { - assert!( - orb_id == time_now, - "Received orb_id is not the same as sent orb_id" - ); - break 'ext; - } - unreachable!("Received unexpected message: {msg:?}"); + } + Err(e) => { + error!("Failed to receive AnnounceOrbId: {:?}", e); } } - tracing::info!( + info!( "Time took to receive a message: {}ms", now.elapsed().as_millis() ); + info!("Sending SignupEnded"); let now = Instant::now(); app_client .send(self_serve::orb::v1::SignupEnded { @@ -155,32 +155,24 @@ async fn app_to_orb() -> Result<()> { failure_feedback: [].to_vec(), }) .await?; - tracing::info!( + info!( "Time took to send a second message: {}ms", now.elapsed().as_millis() ); let now = Instant::now(); - 'ext: loop { - #[expect(clippy::never_loop)] - for msg in orb_client.get_buffered_messages().await { - tracing::info!( - "Received message: from: {:?}, to: {:?}, seq: {:?}, payload: {:?}", - msg.src, - msg.dst, - msg.seq, - debug_any(&msg.payload) - ); - if let Some(self_serve::orb::v1::SignupEnded { success, .. }) = - self_serve::orb::v1::SignupEnded::matches(msg.payload.as_ref().unwrap()) - { - assert!(success, "Received: success is not true"); - break 'ext; - } - unreachable!("Received unexpected message: {msg:?}"); + match orb_client + .wait_for_msg::(Duration::from_millis(1000)) + .await + { + Ok(msg) => { + assert!(msg.success, "Received: success is not true"); + } + Err(e) => { + error!("Failed to receive SignupEnded: {:?}", e); } } - tracing::info!( + info!( "Time took to receive a second message: {}ms", now.elapsed().as_millis() ); @@ -195,34 +187,35 @@ async fn app_to_orb() -> Result<()> { Ok(()) } -async fn orb_to_app() -> Result<()> { - tracing::info!("== Running Orb to App =="); +async fn orb_to_app(args: &Args) -> Result<()> { + info!("== Running Orb to App =="); let (orb_id, session_id) = get_ids(); let mut app_client = Client::new_as_app( - BACKEND_URL.to_string(), - APP_KEY.to_string(), + backend_url(&args), + args.app_key.to_string(), session_id.to_string(), orb_id.to_string(), + args.relay_namespace.to_string(), ); let now = Instant::now(); app_client.connect().await?; - tracing::info!("Time took to app_connect: {}ms", now.elapsed().as_millis()); + info!("Time took to app_connect: {}ms", now.elapsed().as_millis()); let mut orb_client = Client::new_as_orb( - BACKEND_URL.to_string(), - ORB_KEY.to_string(), + backend_url(&args), + args.orb_key.to_string(), orb_id.to_string(), session_id.to_string(), - RELAY_NAMESPACE.to_string(), + args.relay_namespace.to_string(), ); let now = Instant::now(); orb_client.connect().await?; - tracing::info!("Time took to orb_connect: {}ms", now.elapsed().as_millis()); + info!("Time took to orb_connect: {}ms", now.elapsed().as_millis()); + info!("Sending AnnounceOrbId"); let now = Instant::now(); let time_now = time_now()?; - tracing::info!("Sending time now: {}", time_now); orb_client .send(common::v1::AnnounceOrbId { orb_id: time_now.clone(), @@ -230,39 +223,32 @@ async fn orb_to_app() -> Result<()> { hardware_type: common::v1::announce_orb_id::HardwareType::Diamond.into(), }) .await?; - tracing::info!( + info!( "Time took to send a message from the app: {}ms", now.elapsed().as_millis() ); let now = Instant::now(); - 'ext: loop { - #[expect(clippy::never_loop)] - for msg in app_client.get_buffered_messages().await { - tracing::info!( - "Received message: from: {:?}, to: {:?}, seq: {:?}, payload: {:?}", - msg.src, - msg.dst, - msg.seq, - debug_any(&msg.payload) + match app_client + .wait_for_msg::(Duration::from_millis(1000)) + .await + { + Ok(msg) => { + assert!( + msg.orb_id == time_now, + "Received orb_id is not the same as sent orb_id" ); - if let Some(common::v1::AnnounceOrbId { orb_id, .. }) = - common::v1::AnnounceOrbId::matches(msg.payload.as_ref().unwrap()) - { - assert!( - orb_id == time_now, - "Received orb_id is not the same as sent orb_id" - ); - break 'ext; - } - unreachable!("Received unexpected message: {msg:?}"); + } + Err(e) => { + error!("Failed to receive AnnounceOrbId: {:?}", e); } } - tracing::info!( + info!( "Time took to receive a message: {}ms", now.elapsed().as_millis() ); + info!("Sending SignupEnded"); let now = Instant::now(); orb_client .send(self_serve::orb::v1::SignupEnded { @@ -270,32 +256,24 @@ async fn orb_to_app() -> Result<()> { failure_feedback: Vec::new(), }) .await?; - tracing::info!( + info!( "Time took to send a second message: {}ms", now.elapsed().as_millis() ); let now = Instant::now(); - 'ext: loop { - #[expect(clippy::never_loop)] - for msg in app_client.get_buffered_messages().await { - tracing::info!( - "Received message: from: {:?}, to: {:?}, seq: {:?}, payload: {:?}", - msg.src, - msg.dst, - msg.seq, - debug_any(&msg.payload) - ); - if let Some(self_serve::orb::v1::SignupEnded { success, .. }) = - self_serve::orb::v1::SignupEnded::matches(msg.payload.as_ref().unwrap()) - { - assert!(success, "Received: success is not true"); - break 'ext; - } - unreachable!("Received unexpected message: {msg:?}"); + match app_client + .wait_for_msg::(Duration::from_millis(1000)) + .await + { + Ok(msg) => { + assert!(msg.success, "Received: success is not true"); + } + Err(e) => { + error!("Failed to receive SignupEnded: {:?}", e); } } - tracing::info!( + info!( "Time took to receive a second message: {}ms", now.elapsed().as_millis() ); @@ -310,117 +288,118 @@ async fn orb_to_app() -> Result<()> { Ok(()) } -async fn orb_to_app_with_state_request() -> Result<()> { - tracing::info!("== Running Orb to App with state request =="); +async fn orb_to_app_with_state_request(args: &Args) -> Result<()> { + info!("== Running Orb to App with state request =="); let (orb_id, session_id) = get_ids(); let mut app_client = Client::new_as_app( - BACKEND_URL.to_string(), - APP_KEY.to_string(), + backend_url(&args), + args.app_key.to_string(), session_id.to_string(), orb_id.to_string(), + args.relay_namespace.to_string(), ); let now = Instant::now(); app_client.connect().await?; - tracing::info!("Time took to app_connect: {}ms", now.elapsed().as_millis()); + info!("Time took to app_connect: {}ms", now.elapsed().as_millis()); let mut orb_client = Client::new_as_orb( - BACKEND_URL.to_string(), - ORB_KEY.to_string(), + backend_url(&args), + args.orb_key.to_string(), orb_id.to_string(), session_id.to_string(), - RELAY_NAMESPACE.to_string(), + args.relay_namespace.to_string(), ); let now = Instant::now(); orb_client.connect().await?; - tracing::info!("Time took to orb_connect: {}ms", now.elapsed().as_millis()); + info!("Time took to orb_connect: {}ms", now.elapsed().as_millis()); + info!("Sending RequestState"); let now = Instant::now(); app_client .send(self_serve::app::v1::RequestState {}) .await?; - tracing::info!( + info!( "Time took to send RequestState from the app: {}ms", now.elapsed().as_millis() ); let now = Instant::now(); - 'ext: loop { - #[expect(clippy::never_loop)] - for msg in app_client.get_buffered_messages().await { - tracing::info!( - "Received message: from: {:?}, to: {:?}, seq: {:?}, payload: {:?}", - msg.src, - msg.dst, - msg.seq, - debug_any(&msg.payload) - ); - break 'ext; + match app_client + .wait_for_payload(Duration::from_millis(1000)) + .await + { + Ok(_) => { + info!("Received RelayPayload"); + } + Err(e) => { + error!("Failed to receive RelayPayload: {:?}", e); } } - tracing::info!( + info!( "Time took to receive a message: {}ms", now.elapsed().as_millis() ); + info!("Sending AnnounceOrbId"); let now = Instant::now(); let time_now = time_now()?; - tracing::info!("Sending time now: {}", time_now); orb_client .send(common::v1::AnnounceOrbId { - orb_id: time_now, + orb_id: time_now.clone(), mode_type: common::v1::announce_orb_id::ModeType::SelfServe.into(), hardware_type: common::v1::announce_orb_id::HardwareType::Diamond.into(), }) .await?; - tracing::info!( + info!( "Time took to send a message from the app: {}ms", now.elapsed().as_millis() ); let now = Instant::now(); - 'ext: loop { - #[expect(clippy::never_loop)] - for msg in app_client.get_buffered_messages().await { - tracing::info!( - "Received message: from: {:?}, to: {:?}, seq: {:?}, payload: {:?}", - msg.src, - msg.dst, - msg.seq, - debug_any(&msg.payload) + match app_client + .wait_for_msg::(Duration::from_millis(1000)) + .await + { + Ok(msg) => { + info!("Received AnnounceOrbId: {:?}", msg); + assert!( + msg.orb_id == time_now, + "Received orb_id is not the same as sent orb_id" ); - break 'ext; + } + Err(e) => { + error!("Failed to receive AnnounceOrbId: {:?}", e); } } - tracing::info!( + info!( "Time took to receive a message: {}ms", now.elapsed().as_millis() ); + info!("Sending RequestState"); let now = Instant::now(); app_client .send(self_serve::app::v1::RequestState {}) .await?; - tracing::info!( + info!( "Time took to send RequestState from the app: {}ms", now.elapsed().as_millis() ); let now = Instant::now(); - 'ext: loop { - #[expect(clippy::never_loop)] - for msg in app_client.get_buffered_messages().await { - tracing::info!( - "Received message: from: {:?}, to: {:?}, seq: {:?}, payload: {:?}", - msg.src, - msg.dst, - msg.seq, - debug_any(&msg.payload) - ); - break 'ext; + match app_client + .wait_for_payload(Duration::from_millis(1000)) + .await + { + Ok(_) => { + info!("Received RelayPayload"); + } + Err(e) => { + error!("Failed to receive RelayPayload: {:?}", e); } } - tracing::info!( + info!( "Time took to receive a message: {}ms", now.elapsed().as_millis() ); @@ -435,34 +414,35 @@ async fn orb_to_app_with_state_request() -> Result<()> { Ok(()) } -async fn orb_to_app_blocking_send() -> Result<()> { - tracing::info!("== Running Orb to App blocking send =="); +async fn orb_to_app_blocking_send(args: &Args) -> Result<()> { + info!("== Running Orb to App blocking send =="); let (orb_id, session_id) = get_ids(); let mut app_client = Client::new_as_app( - BACKEND_URL.to_string(), - APP_KEY.to_string(), + backend_url(&args), + args.app_key.to_string(), session_id.to_string(), orb_id.to_string(), + args.relay_namespace.to_string(), ); let now = Instant::now(); app_client.connect().await?; - tracing::info!("Time took to app_connect: {}ms", now.elapsed().as_millis()); + info!("Time took to app_connect: {}ms", now.elapsed().as_millis()); let mut orb_client = Client::new_as_orb( - BACKEND_URL.to_string(), - ORB_KEY.to_string(), + backend_url(&args), + args.orb_key.to_string(), orb_id.to_string(), session_id.to_string(), - RELAY_NAMESPACE.to_string(), + args.relay_namespace.to_string(), ); let now = Instant::now(); orb_client.connect().await?; - tracing::info!("Time took to orb_connect: {}ms", now.elapsed().as_millis()); + info!("Time took to orb_connect: {}ms", now.elapsed().as_millis()); let now = Instant::now(); let time_now = time_now()?; - tracing::info!("Sending time now: {}", time_now); + info!("Sending AnnounceOrbId"); orb_client .send_blocking( common::v1::AnnounceOrbId { @@ -474,35 +454,28 @@ async fn orb_to_app_blocking_send() -> Result<()> { Duration::from_secs(5), ) .await?; - tracing::info!( + info!( "Time took to send a message from the app: {}ms", now.elapsed().as_millis() ); let now = Instant::now(); - 'ext: loop { - #[expect(clippy::never_loop)] - for msg in app_client.get_buffered_messages().await { - tracing::info!( - "Received message: from: {:?}, to: {:?}, seq: {:?}, payload: {:?}", - msg.src, - msg.dst, - msg.seq, - debug_any(&msg.payload) + match app_client + .wait_for_msg::(Duration::from_millis(1000)) + .await + { + Ok(msg) => { + info!("Received AnnounceOrbId: {:?}", msg); + assert!( + msg.orb_id == time_now, + "Received orb_id is not the same as sent orb_id" ); - if let Some(common::v1::AnnounceOrbId { orb_id, .. }) = - common::v1::AnnounceOrbId::matches(msg.payload.as_ref().unwrap()) - { - assert!( - orb_id == time_now, - "Received orb_id is not the same as sent orb_id" - ); - break 'ext; - } - unreachable!("Received unexpected message: {msg:?}"); + } + Err(e) => { + error!("Failed to receive AnnounceOrbId: {:?}", e); } } - tracing::info!( + info!( "Time took to receive a message: {}ms", now.elapsed().as_millis() ); @@ -517,32 +490,25 @@ async fn orb_to_app_blocking_send() -> Result<()> { Duration::from_secs(5), ) .await?; - tracing::info!( + info!( "Time took to send a second message: {}ms", now.elapsed().as_millis() ); let now = Instant::now(); - 'ext: loop { - #[expect(clippy::never_loop)] - for msg in app_client.get_buffered_messages().await { - tracing::info!( - "Received message: from: {:?}, to: {:?}, seq: {:?}, payload: {:?}", - msg.src, - msg.dst, - msg.seq, - debug_any(&msg.payload) - ); - if let Some(self_serve::orb::v1::SignupEnded { success, .. }) = - self_serve::orb::v1::SignupEnded::matches(msg.payload.as_ref().unwrap()) - { - assert!(success, "Received: success is not true"); - break 'ext; - } - unreachable!("Received unexpected message: {msg:?}"); + match app_client + .wait_for_msg::(Duration::from_millis(1000)) + .await + { + Ok(msg) => { + info!("Received SignupEnded: {:?}", msg); + assert!(msg.success, "Received: success is not true"); + } + Err(e) => { + error!("Failed to receive SignupEnded: {:?}", e); } } - tracing::info!( + info!( "Time took to receive a second message: {}ms", now.elapsed().as_millis() ); @@ -557,23 +523,23 @@ async fn orb_to_app_blocking_send() -> Result<()> { Ok(()) } -async fn orb_to_app_with_clients_created_later_and_delay() -> Result<()> { +async fn orb_to_app_with_clients_created_later_and_delay(args: &Args) -> Result<()> { let (orb_id, session_id) = get_ids(); let mut orb_client = Client::new_as_orb( - BACKEND_URL.to_string(), - ORB_KEY.to_string(), + backend_url(&args), + args.orb_key.to_string(), orb_id.to_string(), session_id.to_string(), - RELAY_NAMESPACE.to_string(), + args.relay_namespace.to_string(), ); let now = Instant::now(); orb_client.connect().await?; - tracing::info!("Time took to orb_connect: {}ms", now.elapsed().as_millis()); + info!("Time took to orb_connect: {}ms", now.elapsed().as_millis()); + info!("Sending AnnounceOrbId"); let now = Instant::now(); let time_now = time_now()?; - tracing::info!("Sending time now: {}", time_now); orb_client .send(common::v1::AnnounceOrbId { orb_id: time_now, @@ -581,39 +547,38 @@ async fn orb_to_app_with_clients_created_later_and_delay() -> Result<()> { hardware_type: common::v1::announce_orb_id::HardwareType::Diamond.into(), }) .await?; - tracing::info!( + info!( "Time took to send a message from the app: {}ms", now.elapsed().as_millis() ); - tracing::info!("Waiting for 60 seconds..."); + info!("Waiting for 60 seconds..."); tokio::time::sleep(tokio::time::Duration::from_secs(60)).await; let mut app_client = Client::new_as_app( - BACKEND_URL.to_string(), - APP_KEY.to_string(), + backend_url(&args), + args.app_key.to_string(), session_id.to_string(), orb_id.to_string(), + args.relay_namespace.to_string(), ); let now = Instant::now(); app_client.connect().await?; - tracing::info!("Time took to app_connect: {}ms", now.elapsed().as_millis()); + info!("Time took to app_connect: {}ms", now.elapsed().as_millis()); let now = Instant::now(); - 'ext: loop { - #[expect(clippy::never_loop)] - for msg in app_client.get_buffered_messages().await { - tracing::info!( - "Received message: from: {:?}, to: {:?}, seq: {:?}, payload: {:?}", - msg.src, - msg.dst, - msg.seq, - debug_any(&msg.payload) - ); - break 'ext; + match app_client + .wait_for_payload(Duration::from_millis(1000)) + .await + { + Ok(_) => { + info!("Received AnnounceOrbId"); + } + Err(e) => { + error!("Failed to receive AnnounceOrbId: {:?}", e); } } - tracing::info!( + info!( "Time took to receive a message: {}ms", now.elapsed().as_millis() ); @@ -640,7 +605,7 @@ fn get_ids() -> (String, String) { .take(10) .map(char::from) .collect(); - tracing::info!("Orb ID: {orb_id}, Session ID: {session_id}"); + info!("Orb ID: {orb_id}, Session ID: {session_id}"); (orb_id, session_id) } @@ -651,47 +616,49 @@ fn time_now() -> Result { .to_string()) } -async fn stage_consumer_app() -> Result<()> { +async fn stage_consumer_app(args: &Args) -> Result<()> { let mut app_client = Client::new_as_app( - BACKEND_URL.to_string(), - APP_KEY.to_string(), - SESSION_ID.to_string(), - ORB_ID.to_string(), + backend_url(&args), + args.app_key.to_string(), + args.session_id.to_string(), + args.orb_id.to_string(), + args.relay_namespace.to_string(), ); let now = Instant::now(); app_client.connect().await?; - tracing::info!("Time took to connect: {}ms", now.elapsed().as_millis()); + info!("Time took to connect: {}ms", now.elapsed().as_millis()); loop { #[expect(clippy::never_loop)] - for msg in app_client.get_buffered_messages().await { - tracing::info!( - "Received message: from: {:?}, to: {:?}, seq: {:?}, payload: {:?}", - msg.src, - msg.dst, - msg.seq, - debug_any(&msg.payload) - ); + match app_client + .wait_for_payload(Duration::from_millis(1000)) + .await + { + Ok(_) => { + info!("Received RelayPayload"); + } + Err(e) => { + error!("Failed to receive RelayPayload: {:?}", e); + } } - tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; } } -async fn stage_producer_orb() -> Result<()> { +async fn stage_producer_orb(args: &Args) -> Result<()> { let mut orb_client = Client::new_as_orb( - BACKEND_URL.to_string(), - ORB_KEY.to_string(), - ORB_ID.to_string(), - SESSION_ID.to_string(), - RELAY_NAMESPACE.to_string(), + backend_url(&args), + args.orb_key.to_string(), + args.orb_id.to_string(), + args.session_id.to_string(), + args.relay_namespace.to_string(), ); let now = Instant::now(); orb_client.connect().await?; - tracing::info!("Time took to orb_connect: {}ms", now.elapsed().as_millis()); + info!("Time took to orb_connect: {}ms", now.elapsed().as_millis()); loop { + info!("Sending AnnounceOrbId"); let time_now = time_now()?; - tracing::info!("Sending time now: {}", time_now); orb_client .send(common::v1::AnnounceOrbId { orb_id: time_now, @@ -704,18 +671,19 @@ async fn stage_producer_orb() -> Result<()> { } } -async fn stage_producer_from_app_start_orb_signup() -> Result<()> { +async fn stage_producer_from_app_start_orb_signup(args: &Args) -> Result<()> { let mut app_client = Client::new_as_app( - BACKEND_URL.to_string(), - APP_KEY.to_string(), - SESSION_ID.to_string(), - ORB_ID.to_string(), + backend_url(&args), + args.app_key.to_string(), + args.session_id.to_string(), + args.orb_id.to_string(), + args.relay_namespace.to_string(), ); let now = Instant::now(); app_client.connect().await?; - tracing::info!("Time took to orb_connect: {}ms", now.elapsed().as_millis()); + info!("Time took to orb_connect: {}ms", now.elapsed().as_millis()); - tracing::info!("Sending StartCapture now"); + info!("Sending StartCapture"); app_client .send(self_serve::app::v1::StartCapture {}) .await?; @@ -723,14 +691,16 @@ async fn stage_producer_from_app_start_orb_signup() -> Result<()> { loop { #[expect(clippy::never_loop)] - for msg in app_client.get_buffered_messages().await { - tracing::info!( - "Received message: from: {:?}, to: {:?}, seq: {:?}, payload: {:?}", - msg.src, - msg.dst, - msg.seq, - debug_any(&msg.payload) - ); + match app_client + .wait_for_payload(Duration::from_millis(1000)) + .await + { + Ok(_) => { + info!("Received RelayPayload"); + } + Err(e) => { + error!("Failed to receive RelayPayload: {:?}", e); + } } } } diff --git a/relay-client/src/client.rs b/relay-client/src/client.rs index ef0339c9..9d98df71 100644 --- a/relay-client/src/client.rs +++ b/relay-client/src/client.rs @@ -21,20 +21,32 @@ use orb_security_utils::reqwest::{ GTS_ROOT_R1_CERT, GTS_ROOT_R2_CERT, GTS_ROOT_R3_CERT, GTS_ROOT_R4_CERT, SFS_ROOT_G2_CERT, }; -use std::{ - any::type_name, - collections::{BTreeMap, VecDeque}, - sync::Arc, -}; +use std::collections::BTreeMap; use tokio::{ sync::{ mpsc::{self, Sender}, - oneshot, Mutex, + oneshot, }, time::{self, Duration}, }; use tokio_stream::{wrappers::ReceiverStream, StreamExt}; use tokio_util::sync::CancellationToken; +use tracing::{debug, error, info, warn}; + +#[derive( + Debug, + Eq, + PartialEq, + Hash, + Ord, + PartialOrd, + Clone, + Copy, + derive_more::Deref, + derive_more::From, + derive_more::Into, +)] +struct AckNum(u64); #[derive(Debug, Clone)] pub struct TokenAuth { @@ -96,7 +108,7 @@ enum OutgoingMessage { /// Client state pub struct Client { - message_buffer: Arc>>, + incoming_rx: Option>, outgoing_tx: Option>, command_tx: Option>, shutdown_token: Option, @@ -137,7 +149,7 @@ impl Client { mode: Mode, ) -> Self { Self { - message_buffer: Arc::new(Mutex::new(VecDeque::new())), + incoming_rx: None, outgoing_tx: None, command_tx: None, shutdown_token: None, @@ -186,13 +198,14 @@ impl Client { token: String, session_id: String, orb_id: String, + namespace: String, ) -> Self { Self::new( url, Auth::Token(TokenAuth { token }), session_id, orb_id, - String::new(), // namespace + namespace, Mode::App, ) } @@ -207,6 +220,7 @@ impl Client { proof: String, session_id: String, orb_id: String, + namespace: String, ) -> Self { Self::new( url, @@ -218,33 +232,11 @@ impl Client { }), session_id, orb_id, - String::new(), // namespace + namespace, Mode::App, ) } - async fn check_for_msg(&self) -> Option { - for msg in self.get_buffered_messages().await { - if let Some(payload) = &msg.payload { - if let Some(specific_payload) = T::matches(payload) { - return Some(specific_payload); - } - tracing::warn!( - "While waiting for payload of type {:?}, we got: {:?}", - type_name::(), - debug_any(&msg.payload) - ); - } - } - None - } - - /// Get buffered messages - pub async fn get_buffered_messages(&self) -> VecDeque { - let mut buffer = self.message_buffer.lock().await; - std::mem::take(&mut *buffer) - } - /// Connect to the Orb-Relay server pub async fn connect(&mut self) -> Result<()> { let shutdown_token = CancellationToken::new(); @@ -252,7 +244,9 @@ impl Client { let (connection_established_tx, connection_established_rx) = oneshot::channel(); - let message_buffer = Arc::clone(&self.message_buffer); + let (incoming_tx, incoming_rx) = mpsc::channel(self.config.max_buffer_size); + self.incoming_rx = Some(incoming_rx); + // TODO: Make the buffer size configurable let (outgoing_tx, mut outgoing_rx) = mpsc::channel(32); self.outgoing_tx = Some(outgoing_tx); @@ -264,24 +258,23 @@ impl Client { let config = self.config.clone(); let no_state = self.no_state(); - tracing::info!( + info!( "Connecting with: src_id: {}, dst_id: {}", - config.src_id, - config.dst_id + config.src_id, config.dst_id ); tokio::spawn(async move { let mut agent = PollerAgent { config: &config, pending_messages: Default::default(), last_message: no_state, - seq: 0, + seq: AckNum(0), }; let mut connection_established_tx = Some(connection_established_tx); loop { if let Err(e) = agent .main_loop( - &message_buffer, + &incoming_tx, shutdown_token.clone(), &mut outgoing_rx, &mut command_rx, @@ -289,18 +282,15 @@ impl Client { ) .await { - tracing::error!("Connection error: {e}"); + error!("Connection error: {e}"); } if shutdown_token.is_cancelled() { - tracing::info!("Connection shutdown"); + info!("Connection shutdown"); break; } - tracing::info!( - "Reconnecting in {}s ...", - config.reconnect_delay.as_secs() - ); + info!("Reconnecting in {}s ...", config.reconnect_delay.as_secs()); tokio::time::sleep(config.reconnect_delay).await; } shutdown_completed_tx.send(()).ok(); @@ -315,23 +305,56 @@ impl Client { Ok(()) } + pub async fn wait_for_payload(&mut self, wait: Duration) -> Result { + let timeout_future = Box::pin(tokio::time::sleep(wait)); + + tokio::select! { + _ = timeout_future => { + return Err(eyre::eyre!( + "Timeout waiting for payload" + )); + } + message = self.incoming_rx.as_mut().expect("Client not connected").recv() => { + if let Some(payload) = message { + return Ok(payload); + } + } + } + + Err(eyre::eyre!("No valid payload received")) + } + /// Wait for a specific message type pub async fn wait_for_msg( - &self, + &mut self, wait: Duration, ) -> Result { - let start_time = tokio::time::Instant::now(); - loop { - if let Some(payload) = self.check_for_msg::().await { - return Ok(payload); + match self.wait_for_payload(wait).await { + Ok(payload) => { + info!( + "Received message: from: {:?}, to: {:?}, seq: {:?}, payload: {:?}", + payload.src, + payload.dst, + payload.seq, + debug_any(&payload.payload) + ); + if let Some(specific_payload) = + T::matches(&payload.payload.as_ref().unwrap()) + { + return Ok(specific_payload); + } else { + return Err(eyre::eyre!( + "Payload does not match expected type {:?}", + std::any::type_name::() + )); + } } - if start_time.elapsed() >= wait { + Err(_) => { return Err(eyre::eyre!( "Timeout waiting for payload of type {:?}", std::any::type_name::() )); } - tokio::time::sleep(Duration::from_millis(100)).await; } } @@ -369,7 +392,7 @@ impl Client { .ok_or_eyre("client not connected")? .send(msg) .await - .inspect_err(|e| tracing::error!("Failed to send payload: {e}")) + .inspect_err(|e| error!("Failed to send payload: {e}")) .wrap_err("Failed to send payload") } @@ -414,7 +437,7 @@ impl Client { ) { // Let's wait for all acks to be received if self.has_pending_messages().await.map_or(false, |n| n > 0) { - tracing::info!( + info!( "Giving {}ms for pending messages to be acked", wait_for_pending_messages.as_millis() ); @@ -422,7 +445,7 @@ impl Client { } // If there are still pending messages, we retry to send them if self.has_pending_messages().await.map_or(false, |n| n > 0) { - tracing::info!("There are still pending messages, replaying..."); + info!("There are still pending messages, replaying..."); if let Ok(()) = self.replay_pending_messages().await { tokio::time::sleep(wait_for_pending_messages).await; } @@ -433,15 +456,15 @@ impl Client { if let Some(shutdown_completed) = self.shutdown_completed.take() { match tokio::time::timeout(wait_for_shutdown, shutdown_completed).await { - Ok(_) => tracing::info!("Shutdown completed successfully."), - Err(_) => tracing::warn!("Timed out waiting for shutdown to complete."), + Ok(_) => info!("Shutdown completed successfully."), + Err(_) => warn!("Timed out waiting for shutdown to complete."), } } } /// Shutdown the client pub fn shutdown(&mut self) { - tracing::info!("Shutting down requested"); + info!("Shutting down requested"); if let Some(token) = self.shutdown_token.take() { token.cancel(); } @@ -459,7 +482,9 @@ impl Client { let start_time = tokio::time::Instant::now(); let mut spam_time = tokio::time::Instant::now(); loop { - if let Some(payload) = self.check_for_msg::().await { + if let Ok(payload) = + self.wait_for_msg::(Duration::from_millis(100)).await + { return Ok(payload); } @@ -474,7 +499,6 @@ impl Client { std::any::type_name::() )); } - tokio::time::sleep(Duration::from_millis(100)).await; } } } @@ -487,9 +511,10 @@ impl Drop for Client { struct PollerAgent<'a> { config: &'a Config, - pending_messages: BTreeMap>)>, + pending_messages: + BTreeMap>)>, last_message: RelayConnectRequest, - seq: u64, + seq: AckNum, } impl<'a> PollerAgent<'a> { @@ -498,7 +523,7 @@ impl<'a> PollerAgent<'a> { // different sources. async fn main_loop( &mut self, - message_buffer: &Arc>>, + incoming_tx: &mpsc::Sender, shutdown_token: CancellationToken, outgoing_rx: &mut mpsc::Receiver, command_rx: &mut mpsc::Receiver, @@ -520,9 +545,9 @@ impl<'a> PollerAgent<'a> { loop { tokio::select! { () = shutdown_token.cancelled() => { - tracing::info!("Shutting down connection"); + info!("Shutting down connection"); if !self.pending_messages.is_empty() { - tracing::warn!("Pending messages {}: {:?}", self.pending_messages.len(), self.pending_messages); + warn!("Pending messages {}: {:?}", self.pending_messages.len(), self.pending_messages); } return Ok(()); } @@ -543,44 +568,41 @@ impl<'a> PollerAgent<'a> { .await .wrap_err("Failed to send outgoing message")?; } else if src.id != self.config.dst_id { - tracing::error!( + error!( "Skipping received message from unexpected source: {:?}: {payload:?}", src.id ); } else { - self.handle_message( - RelayPayload { src: Some(src), dst, seq, payload: Some(payload) }, - message_buffer, - ) - .await?; + let payload = RelayPayload { src: Some(src), dst, seq, payload: Some(payload) }; + incoming_tx.send(payload).await.wrap_err("Failed to handle incoming message")?; } } Some(Ok(RelayConnectResponse { msg: Some(relay_connect_response::Msg::Ack(ack)) })) => { - if let Some((_, Some(ack_tx))) = self.pending_messages.remove(&ack.seq) { + if let Some((_, Some(ack_tx))) = self.pending_messages.remove(&AckNum(ack.seq)) { if ack_tx.send(()).is_err() { // The receiver has been dropped, possibly due to a timeout. That means we // need to increase the timeout at send_blocking(). - tracing::warn!( + warn!( "Failed to send ack back to send_blocking(): receiver dropped" ); } } } Some(Err(e)) => { - tracing::error!("Error receiving message from tonic stream: {e:?}"); + error!("Error receiving message from tonic stream: {e:?}"); return Err(e.into()); } None => { - tracing::info!("Stream ended"); + info!("Stream ended"); return Ok(()); } _ => { - tracing::error!("Received unexpected message: {message:?}"); + error!("Received unexpected message: {message:?}"); } } } Some(outgoing_message) = outgoing_rx.recv() => { - self.seq = self.seq.wrapping_add(1); + self.seq = AckNum(self.seq.wrapping_add(1)); let (payload, maybe_ack_tx) = match outgoing_message { OutgoingMessage::Normal(payload) => (payload, None), OutgoingMessage::Blocking(payload, ack_tx) => (payload, Some(ack_tx)), @@ -592,11 +614,11 @@ impl<'a> PollerAgent<'a> { let relay_message = RelayPayload { src: Some(Entity { id: self.config.src_id.clone(), entity_type: src_t, namespace: self.config.namespace.clone() }), dst: Some(Entity { id: self.config.dst_id.clone(), entity_type: dst_t, namespace: self.config.namespace.clone() }), - seq: self.seq, + seq: self.seq.into(), payload: Some(payload), }; - tracing::debug!("Sending message: from: {:?}, to: {:?}, seq: {:?}, payload: {:?}", + debug!("Sending message: from: {:?}, to: {:?}, seq: {:?}, payload: {:?}", relay_message.src, relay_message.dst, relay_message.seq, debug_any(&relay_message.payload)); self.pending_messages.insert(self.seq, (relay_message.clone().into(), maybe_ack_tx)); @@ -612,15 +634,15 @@ impl<'a> PollerAgent<'a> { let _ = reply_tx.send(self.pending_messages.len()); } Command::Reconnect => { - tracing::info!("Reconnecting..."); + info!("Reconnecting..."); return Ok(()); } } } _ = interval.tick() => { - self.seq = self.seq.wrapping_add(1); + self.seq = AckNum(self.seq.wrapping_add(1)); sender_tx - .send(Heartbeat { seq: self.seq }.into()) + .send(Heartbeat { seq: self.seq.into() }.into()) .await .wrap_err("Failed to send heartbeat")?; }, @@ -633,7 +655,7 @@ impl<'a> PollerAgent<'a> { sender_tx: &Sender, ) -> Result<()> { if !self.pending_messages.is_empty() { - tracing::warn!("Replaying pending messages: {:?}", self.pending_messages); + warn!("Replaying pending messages: {:?}", self.pending_messages); for (_key, (msg, sender)) in self.pending_messages.iter_mut() { sender_tx .send(msg.clone()) @@ -732,7 +754,7 @@ impl<'a> PollerAgent<'a> { }) = message { return if success { - tracing::info!("Successful connection"); + info!("Successful connection"); Ok(()) } else { Err(eyre::eyre!("Failed to establish connection: {error:?}")) @@ -743,19 +765,4 @@ impl<'a> PollerAgent<'a> { "Connection stream ended before receiving ConnectResponse" )) } - - async fn handle_message( - &self, - payload: RelayPayload, - message_buffer: &Arc>>, - ) -> Result<()> { - let mut buffer = message_buffer.lock().await; - if buffer.len() >= self.config.max_buffer_size { - // Remove the oldest message to maintain the buffer size - let msg: Vec = buffer.drain(0..1).collect(); - tracing::warn!("Buffer is full, removing oldest message: {msg:?}"); - } - buffer.push_back(payload); - Ok(()) - } } diff --git a/relay-client/src/lib.rs b/relay-client/src/lib.rs index 45af3150..9c2b2136 100644 --- a/relay-client/src/lib.rs +++ b/relay-client/src/lib.rs @@ -1,5 +1,7 @@ //! Orb-Relay crate -use orb_relay_messages::{common, prost::Name, prost_types::Any, self_serve}; +use orb_relay_messages::{ + common, orb_commands, prost::Name, prost_types::Any, self_serve, +}; pub mod client; @@ -66,6 +68,14 @@ impl PayloadMatcher for self_serve::orb::v1::SignupEnded { } } +impl PayloadMatcher for orb_commands::v1::OrbCommand { + type Output = orb_commands::v1::OrbCommand; + + fn matches(payload: &Any) -> Option { + unpack_any::(payload) + } +} + pub trait IntoPayload { fn into_payload(self) -> Any; } @@ -153,6 +163,15 @@ impl IntoPayload for common::v1::NoState { } } +impl IntoPayload for orb_commands::v1::OrbCommand { + fn into_payload(self) -> Any { + Any::from_msg(&orb_commands::v1::OrbCommand { + commands: self.commands, + }) + .unwrap() + } +} + /// Debug any message pub fn debug_any(any: &Option) -> String { let Some(any) = any else { @@ -164,6 +183,8 @@ pub fn debug_any(any: &Option) -> String { format!("{:?}", w) } else if let Some(w) = unpack_any::(any) { format!("{:?}", w) + } else if let Some(w) = unpack_any::(any) { + format!("{:?}", w) } else { "Error".to_string() }