diff --git a/src/manager.rs b/src/manager.rs index aa2576a..8fbaa07 100644 --- a/src/manager.rs +++ b/src/manager.rs @@ -19,16 +19,15 @@ use crate::lsps1::service::{LSPS1ServiceConfig, LSPS1ServiceHandler}; use crate::lsps2::client::{LSPS2ClientConfig, LSPS2ClientHandler}; use crate::lsps2::msgs::LSPS2Message; use crate::lsps2::service::{LSPS2ServiceConfig, LSPS2ServiceHandler}; -use crate::prelude::{HashMap, ToString, Vec}; +use crate::prelude::{HashMap, HashSet, ToString, Vec}; use crate::sync::{Arc, Mutex, RwLock}; use lightning::chain::{self, BestBlock, Confirm, Filter, Listen}; use lightning::ln::channelmanager::{AChannelManager, ChainParameters}; use lightning::ln::features::{InitFeatures, NodeFeatures}; -use lightning::ln::msgs::{ErrorAction, ErrorMessage, LightningError}; +use lightning::ln::msgs::{ErrorAction, LightningError}; use lightning::ln::peer_handler::CustomMessageHandler; use lightning::ln::wire::CustomMessageReader; -use lightning::ln::ChannelId; use lightning::sign::EntropySource; use lightning::util::logger::Level; use lightning::util::ser::Readable; @@ -94,6 +93,8 @@ where pending_messages: Arc, pending_events: Arc, request_id_to_method_map: Mutex>, + // We ignore peers if they send us bogus data. + ignored_peers: RwLock>, lsps0_client_handler: LSPS0ClientHandler, lsps0_service_handler: Option, #[cfg(lsps1)] @@ -126,6 +127,7 @@ where where { let pending_messages = Arc::new(MessageQueue::new()); let pending_events = Arc::new(EventQueue::new()); + let ignored_peers = RwLock::new(HashSet::new()); let lsps0_client_handler = LSPS0ClientHandler::new( entropy_source.clone(), @@ -192,6 +194,7 @@ where { pending_messages, pending_events, request_id_to_method_map: Mutex::new(HashMap::new()), + ignored_peers, lsps0_client_handler, lsps0_service_handler, #[cfg(lsps1)] @@ -480,41 +483,62 @@ where fn handle_custom_message( &self, msg: Self::CustomMessage, sender_node_id: &PublicKey, ) -> Result<(), lightning::ln::msgs::LightningError> { + { + if self.ignored_peers.read().unwrap().contains(&sender_node_id) { + let err = format!("Ignoring message from peer {}.", sender_node_id); + return Err(LightningError { + err, + action: ErrorAction::IgnoreAndLog(Level::Trace), + }); + } + } + let message = { - let mut request_id_to_method_map = self.request_id_to_method_map.lock().unwrap(); - LSPSMessage::from_str_with_id_map(&msg.payload, &mut request_id_to_method_map).map_err( - |_| { - let error = ResponseError { - code: JSONRPC_INVALID_MESSAGE_ERROR_CODE, - message: JSONRPC_INVALID_MESSAGE_ERROR_MESSAGE.to_string(), - data: None, - }; - - self.pending_messages.enqueue(sender_node_id, LSPSMessage::Invalid(error)); - let err = format!("Failed to deserialize invalid LSPS message."); - let err_msg = - Some(ErrorMessage { channel_id: ChannelId([0; 32]), data: err.clone() }); - LightningError { err, action: ErrorAction::DisconnectPeer { msg: err_msg } } - }, - )? + { + let mut request_id_to_method_map = self.request_id_to_method_map.lock().unwrap(); + LSPSMessage::from_str_with_id_map(&msg.payload, &mut request_id_to_method_map) + } + .map_err(|_| { + let error = ResponseError { + code: JSONRPC_INVALID_MESSAGE_ERROR_CODE, + message: JSONRPC_INVALID_MESSAGE_ERROR_MESSAGE.to_string(), + data: None, + }; + + self.pending_messages.enqueue(sender_node_id, LSPSMessage::Invalid(error)); + self.ignored_peers.write().unwrap().insert(*sender_node_id); + let err = format!( + "Failed to deserialize invalid LSPS message. Ignoring peer {} from now on.", + sender_node_id + ); + LightningError { err, action: ErrorAction::IgnoreAndLog(Level::Info) } + })? }; self.handle_lsps_message(message, sender_node_id) } fn get_and_clear_pending_msg(&self) -> Vec<(PublicKey, Self::CustomMessage)> { - let mut request_id_to_method_map = self.request_id_to_method_map.lock().unwrap(); - self.pending_messages - .get_and_clear_pending_msgs() + let pending_messages = self.pending_messages.get_and_clear_pending_msgs(); + + let mut request_ids_and_methods = pending_messages .iter() - .map(|(public_key, lsps_message)| { - if let Some((request_id, method)) = lsps_message.get_request_id_and_method() { - request_id_to_method_map.insert(request_id, method); - } - ( - *public_key, - RawLSPSMessage { payload: serde_json::to_string(&lsps_message).unwrap() }, - ) + .filter_map(|(_, msg)| msg.get_request_id_and_method()) + .peekable(); + + if request_ids_and_methods.peek().is_some() { + let mut request_id_to_method_map_lock = self.request_id_to_method_map.lock().unwrap(); + for (request_id, method) in request_ids_and_methods { + request_id_to_method_map_lock.insert(request_id, method); + } + } + + pending_messages + .into_iter() + .filter_map(|(public_key, msg)| { + serde_json::to_string(&msg) + .ok() + .and_then(|payload| Some((public_key, RawLSPSMessage { payload }))) }) .collect() }