diff --git a/rust/theoros/src/handlers/rest/get_calldata.rs b/rust/theoros/src/handlers/rest/get_calldata.rs index 61927856..64435089 100644 --- a/rust/theoros/src/handlers/rest/get_calldata.rs +++ b/rust/theoros/src/handlers/rest/get_calldata.rs @@ -1,66 +1,89 @@ use std::str::FromStr; use alloy::hex; -use axum::extract::{Query, State}; -use axum::Json; +use axum::{ + extract::{Query, State}, + Json, +}; use serde::{Deserialize, Serialize}; use utoipa::{IntoParams, ToResponse, ToSchema}; -use crate::configs::evm_config::EvmChainName; -use crate::errors::GetCalldataError; -use crate::extractors::PathExtractor; -use crate::types::calldata::{AsCalldata, Calldata}; -use crate::AppState; +use crate::{ + configs::evm_config::EvmChainName, + errors::GetCalldataError, + types::calldata::{AsCalldata, Calldata}, + AppState, +}; -#[derive(Default, Deserialize, IntoParams, ToSchema)] -pub struct GetCalldataQuery {} +#[derive(Deserialize, IntoParams, ToSchema)] +pub struct GetCalldataQuery { + pub chain: String, + #[serde(deserialize_with = "deserialize_feed_ids")] + pub feed_ids: Vec, +} #[derive(Debug, Serialize, Deserialize, ToResponse, ToSchema)] pub struct GetCalldataResponse { - pub calldata: Calldata, + pub feed_id: String, pub encoded_calldata: String, } #[utoipa::path( get, - path = "/v1/calldata/{chain_name}/{feed_id}", + path = "/v1/calldata", + params( + GetCalldataQuery + ), responses( ( status = 200, - description = "Constructs the calldata used to update the feed id specified", + description = "Constructs the calldata used to update the specified feed IDs", body = [GetCalldataResponse] ), ( status = 404, - description = "Unknown Feed Id", - body = [GetCalldataError] + description = "Unknown Feed ID", + body = GetCalldataError ) ), - params( - GetCalldataQuery - ), )] pub async fn get_calldata( State(state): State, - PathExtractor(path_args): PathExtractor<(String, String)>, - Query(_params): Query, -) -> Result, GetCalldataError> { + Query(params): Query, +) -> Result>, GetCalldataError> { let started_at = std::time::Instant::now(); - let (raw_chain_name, feed_id) = path_args; - let chain_name = EvmChainName::from_str(&raw_chain_name) - .map_err(|_| GetCalldataError::ChainNotSupported(raw_chain_name.clone()))?; + + let chain_name = + EvmChainName::from_str(¶ms.chain).map_err(|_| GetCalldataError::ChainNotSupported(params.chain.clone()))?; let stored_feed_ids = state.storage.feed_ids(); - if !stored_feed_ids.contains(&feed_id).await { - return Err(GetCalldataError::FeedNotFound(feed_id)); - }; - let calldata = Calldata::build_from(&state, chain_name, feed_id) - .await - .map_err(|e| GetCalldataError::CalldataError(e.to_string()))?; + // Check if all requested feed IDs are supported. + if let Some(missing_id) = stored_feed_ids.contains_vec(¶ms.feed_ids).await { + return Err(GetCalldataError::FeedNotFound(missing_id)); + } + + // Build calldata for each feed ID. + let mut responses = Vec::with_capacity(params.feed_ids.len()); + for feed_id in ¶ms.feed_ids { + let calldata = Calldata::build_from(&state, chain_name, feed_id.clone()) + .await + .map_err(|e| GetCalldataError::CalldataError(e.to_string()))?; + + let response = + GetCalldataResponse { feed_id: feed_id.clone(), encoded_calldata: hex::encode(calldata.as_bytes()) }; + responses.push(response); + } - let response = - GetCalldataResponse { calldata: calldata.clone(), encoded_calldata: hex::encode(calldata.as_bytes()) }; tracing::info!("🌐 get_calldata - {:?}", started_at.elapsed()); - Ok(Json(response)) + Ok(Json(responses)) +} + +/// Deserialize a list of feed ids "A, B, C" into a Vec = [A, B, C]. +fn deserialize_feed_ids<'de, D>(deserializer: D) -> Result, D::Error> +where + D: serde::Deserializer<'de>, +{ + let s: String = String::deserialize(deserializer)?; + Ok(s.split(',').map(|s| s.trim().to_string()).collect()) } diff --git a/rust/theoros/src/handlers/websocket/subscribe_to_calldata.rs b/rust/theoros/src/handlers/websocket/subscribe_to_calldata.rs index 138708af..b7031294 100644 --- a/rust/theoros/src/handlers/websocket/subscribe_to_calldata.rs +++ b/rust/theoros/src/handlers/websocket/subscribe_to_calldata.rs @@ -5,7 +5,7 @@ use std::{ }; use alloy::hex; -use anyhow::{anyhow, Result}; +use anyhow::Result; use axum::{ extract::{ ws::{Message, WebSocket, WebSocketUpgrade}, @@ -21,14 +21,16 @@ use serde::{Deserialize, Serialize}; use tokio::sync::broadcast::Receiver; use utoipa::ToSchema; -use crate::constants::{MAX_CLIENT_MESSAGE_SIZE, PING_INTERVAL_DURATION}; -use crate::types::calldata::AsCalldata; -use crate::types::hyperlane::NewUpdatesAvailableEvent; -use crate::AppState; -use crate::{configs::evm_config::EvmChainName, types::calldata::Calldata}; +use crate::{ + configs::evm_config::EvmChainName, + constants::{MAX_CLIENT_MESSAGE_SIZE, PING_INTERVAL_DURATION}, + types::{ + calldata::{AsCalldata, Calldata}, + hyperlane::NewUpdatesAvailableEvent, + }, + AppState, +}; -// TODO: add config for the client -/// Configuration for a specific data feed. #[derive(Clone)] pub struct DataFeedClientConfig {} @@ -45,9 +47,7 @@ enum ClientMessage { pub struct RpcDataFeed { pub feed_id: String, /// The calldata binary represented as a hex string. - #[serde(skip_serializing_if = "Option::is_none")] - #[schema(value_type = Option)] - pub encoded_calldata: Option, + pub encoded_calldata: String, } #[derive(Serialize, Debug, Clone)] @@ -56,7 +56,7 @@ enum ServerMessage { #[serde(rename = "response")] Response(ServerResponseMessage), #[serde(rename = "data_feed_update")] - DataFeedUpdate { data_feed: RpcDataFeed }, + DataFeedUpdate { data_feeds: Vec }, } #[derive(Serialize, Debug, Clone)] @@ -68,6 +68,10 @@ enum ServerResponseMessage { Err { error: String }, } +/// WebSocket route handler. +/// +/// Upgrades the HTTP connection to a WebSocket connection and spawns a new +/// subscriber to handle incoming and outgoing messages. pub async fn ws_route_handler( ws: WebSocketUpgrade, AxumState(state): AxumState, @@ -76,6 +80,7 @@ pub async fn ws_route_handler( ws.max_message_size(MAX_CLIENT_MESSAGE_SIZE).on_upgrade(move |socket| websocket_handler(socket, state)) } +/// Handles the WebSocket connection for a single client. #[tracing::instrument(skip(stream, state))] async fn websocket_handler(stream: WebSocket, state: AppState) { let ws_state = state.ws.clone(); @@ -90,6 +95,10 @@ async fn websocket_handler(stream: WebSocket, state: AppState) { pub type SubscriberId = usize; +/// Represents a client connected via WebSocket. +/// +/// Manages subscriptions to data feeds, handles incoming client messages, +/// and sends updates to the client. pub struct Subscriber { id: SubscriberId, closed: bool, @@ -104,6 +113,7 @@ pub struct Subscriber { } impl Subscriber { + /// Creates a new `Subscriber` instance. pub fn new( id: SubscriberId, state: Arc, @@ -125,33 +135,42 @@ impl Subscriber { } } + /// Runs the subscriber event loop, handling messages and updates. #[tracing::instrument(skip(self))] pub async fn run(&mut self) { while !self.closed { if let Err(e) = self.handle_next().await { - tracing::error!(subscriber = self.id, error = ?e, "Error Handling Subscriber Message."); + tracing::error!( + subscriber = self.id, + error = ?e, + "Error handling subscriber message." + ); break; } } } + /// Handles the next event, whether it's an incoming message, a data feed update, or a ping. async fn handle_next(&mut self) -> Result<()> { tokio::select! { - maybe_update_feeds_event = self.feeds_receiver.recv() => { - match maybe_update_feeds_event { + maybe_update = self.feeds_receiver.recv() => { + match maybe_update { Ok(_) => self.handle_data_feeds_update().await, - Err(e) => Err(anyhow!("Failed to receive update from store: {:?}", e)), + Err(e) => anyhow::bail!("Failed to receive update from store: {:?}", e), } }, - maybe_message_or_err = self.receiver.next() => { - self.handle_client_message( - maybe_message_or_err.ok_or(anyhow!("Client channel is closed"))?? - ).await - }, - _ = self.ping_interval.tick() => { - if !self.responded_to_ping { - return Err(anyhow!("Subscriber did not respond to ping. Closing connection.")); + maybe_message = self.receiver.next() => { + match maybe_message { + Some(Ok(message)) => self.handle_client_message(message).await, + Some(Err(e)) => anyhow::bail!("WebSocket error: {:?}", e), + None => { + self.closed = true; + Ok(()) + } } + }, + _ = self.ping_interval.tick() => { + anyhow::ensure!(self.responded_to_ping, "Subscriber did not respond to ping. Closing connection."); self.responded_to_ping = false; self.sender.send(Message::Ping(vec![])).await?; Ok(()) @@ -159,110 +178,108 @@ impl Subscriber { } } + /// Handles data feed updates by sending new data to the client for all subscribed feeds. async fn handle_data_feeds_update(&mut self) -> Result<()> { - if self.active_chain.is_none() { + if self.active_chain.is_none() || self.data_feeds_with_config.is_empty() { return Ok(()); } - tracing::debug!(subscriber = self.id, "Handling Data Feeds Update."); - // Retrieve the updates for subscribed feed ids at the given slot - let feed_ids = self.data_feeds_with_config.keys().cloned().collect::>(); - // TODO: add support for multiple feeds - let feed_id = feed_ids.first().unwrap(); - let calldata = - Calldata::build_from(self.state.as_ref(), self.active_chain.unwrap(), feed_id.to_owned()).await?; + tracing::debug!(subscriber = self.id, "Handling data feeds update."); + + // Retrieve the list of subscribed feed IDs. + let feed_ids: Vec = self.data_feeds_with_config.keys().cloned().collect(); + + let mut data_feeds = Vec::with_capacity(feed_ids.len()); + // Build calldata for each subscribed feed and collect them. + for feed_id in feed_ids { + match Calldata::build_from(self.state.as_ref(), self.active_chain.unwrap(), feed_id.clone()).await { + Ok(calldata) => { + data_feeds.push(RpcDataFeed { + feed_id: feed_id.clone(), + encoded_calldata: hex::encode(calldata.as_bytes()), + }); + } + Err(e) => { + tracing::error!("Error building calldata for {}: {}", feed_id, e); + } + } + } + + // Send a single update containing all data feeds. + if !data_feeds.is_empty() { + let update = ServerMessage::DataFeedUpdate { data_feeds }; + let message = serde_json::to_string(&update)?; + self.sender.send(Message::Text(message)).await?; + } - let message = serde_json::to_string(&ServerMessage::DataFeedUpdate { - data_feed: RpcDataFeed { - feed_id: feed_id.clone(), - encoded_calldata: Some(hex::encode(calldata.as_bytes())), - }, - })?; - self.sender.send(message.into()).await?; Ok(()) } + /// Processes messages received from the client. #[tracing::instrument(skip(self, message))] async fn handle_client_message(&mut self, message: Message) -> Result<()> { - let maybe_client_message = match message { + match message { Message::Close(_) => { - // Closing the connection. We don't remove it from the subscribers - // list, instead when the Subscriber struct is dropped the channel - // to subscribers list will be closed and it will eventually get - // removed. tracing::trace!(id = self.id, "πŸ“¨ [CLOSE]"); - - // Send the close message to gracefully shut down the connection - // Otherwise the client might get an abnormal Websocket closure - // error. self.sender.close().await?; self.closed = true; - return Ok(()); + Ok(()) } - Message::Text(text) => serde_json::from_str::(&text), - Message::Binary(data) => serde_json::from_slice::(&data), - Message::Ping(_) => { - // Axum will send Pong automatically - return Ok(()); + Message::Text(text) => self.process_client_message(&text).await, + Message::Binary(data) => { + let text = String::from_utf8(data)?; + self.process_client_message(&text).await } + Message::Ping(_) => Ok(()), // Axum handles PONG responses automatically. Message::Pong(_) => { self.responded_to_ping = true; - return Ok(()); + Ok(()) } - }; + } + } - match maybe_client_message { + /// Parses and processes a client message in text format. + async fn process_client_message(&mut self, text: &str) -> Result<()> { + let client_message: ClientMessage = match serde_json::from_str(text) { + Ok(msg) => msg, Err(e) => { - tracing::error!("πŸ˜Άβ€πŸŒ«οΈ Client disconnected/error occurred. Closing the channel."); - self.sender - .send( - serde_json::to_string(&ServerMessage::Response(ServerResponseMessage::Err { - error: e.to_string(), - }))? - .into(), - ) - .await?; + tracing::error!("Invalid client message format: {}", e); + let message = ServerMessage::Response(ServerResponseMessage::Err { error: e.to_string() }); + self.sender.send(Message::Text(serde_json::to_string(&message)?)).await?; return Ok(()); } + }; - Ok(ClientMessage::Subscribe { ids: feed_ids, chain_name }) => { + match client_message { + ClientMessage::Subscribe { ids, chain_name } => { let stored_feed_ids = self.state.storage.feed_ids(); - // If there is a single feed id that is not found, we don't subscribe to any of the - // asked feed ids and return an error to be more explicit and clear. - match stored_feed_ids.contains_vec(&feed_ids).await { - // TODO: return multiple missing ids - Some(missing_id) => { - self.sender - .send( - serde_json::to_string(&ServerMessage::Response(ServerResponseMessage::Err { - error: format!("Can't subscribe: at least one of the requested feed ids is not supported ({:?})", missing_id), - }))? - .into(), - ) - .await?; - return Ok(()); - } - None => { - for feed_id in feed_ids { - self.data_feeds_with_config.insert(feed_id, DataFeedClientConfig {}); - // TODO: Assert that the chain is supported by theoros - self.active_chain = Some(chain_name); - } - } + // Check if all requested feed IDs are supported. + if let Some(missing_id) = stored_feed_ids.contains_vec(&ids).await { + let message = ServerResponseMessage::Err { + error: format!("Can't subscribe: feed ID not supported ({:?})", missing_id), + }; + self.sender.send(Message::Text(serde_json::to_string(&ServerMessage::Response(message))?)).await?; + return Ok(()); + } + + // Subscribe to the requested feed IDs. + self.active_chain = Some(chain_name); + for feed_id in ids { + self.data_feeds_with_config.insert(feed_id, DataFeedClientConfig {}); } } - Ok(ClientMessage::Unsubscribe { ids: feed_ids }) => { - for feed_id in feed_ids { + ClientMessage::Unsubscribe { ids } => { + for feed_id in ids { self.data_feeds_with_config.remove(&feed_id); } } } + // Acknowledge the successful processing of the client message. self.sender - .send(serde_json::to_string(&ServerMessage::Response(ServerResponseMessage::Success))?.into()) + .send(Message::Text(serde_json::to_string(&ServerMessage::Response(ServerResponseMessage::Success))?)) .await?; - Ok(()) } } diff --git a/rust/theoros/src/services/api/router.rs b/rust/theoros/src/services/api/router.rs index 67e54182..9a5d9c90 100644 --- a/rust/theoros/src/services/api/router.rs +++ b/rust/theoros/src/services/api/router.rs @@ -37,11 +37,11 @@ async fn handler_404() -> impl IntoResponse { } fn ws_route(state: AppState) -> Router { - Router::new().route("/ws", get(ws_route_handler)).with_state(state) + Router::new().route("/ws/calldata", get(ws_route_handler)).with_state(state) } fn calldata_routes(state: AppState) -> Router { - Router::new().route("/calldata/:chain_name/:feed_id", get(get_calldata)).with_state(state) + Router::new().route("/calldata", get(get_calldata)).with_state(state) } fn data_feeds_routes(state: AppState) -> Router { diff --git a/rust/theoros/src/storage/checkpoints.rs b/rust/theoros/src/storage/checkpoints.rs index 6089cd53..f2092605 100644 --- a/rust/theoros/src/storage/checkpoints.rs +++ b/rust/theoros/src/storage/checkpoints.rs @@ -58,7 +58,7 @@ impl SignedCheckpointsStorage { pub async fn get(&self, validators: &[Felt], searched_nonce: u32) -> Vec<(Felt, SignedCheckpointWithMessageId)> { let lock = self.0.read().await; - let mut checkpoints = Vec::new(); + let mut checkpoints = Vec::with_capacity(lock.len()); // Iterate over the map with tuple key (validator, message_id) for ((validator, nonce), checkpoint) in lock.iter() { // Only include if validator is in the provided list and message_id matches diff --git a/rust/theoros/src/storage/feed_id.rs b/rust/theoros/src/storage/feed_id.rs index a1fe8733..6195470a 100644 --- a/rust/theoros/src/storage/feed_id.rs +++ b/rust/theoros/src/storage/feed_id.rs @@ -21,12 +21,6 @@ impl FeedIdsStorage { lock.remove(feed_id); } - /// Checks if the storage contains the given feed ID. - pub async fn contains(&self, feed_id: &str) -> bool { - let lock = self.0.read().await; - lock.contains(feed_id) - } - /// Checks if all feed IDs in the given vector are present in the storage. /// Returns None if all IDs are present, or Some(id) with the first missing ID. pub async fn contains_vec(&self, feed_ids: &[String]) -> Option {