Skip to content

Commit

Permalink
Fix/35 (#37)
Browse files Browse the repository at this point in the history
* Allow multiple event listeners on same ws connection
* fix some logs
  • Loading branch information
jordy25519 authored Feb 26, 2024
1 parent f39ec12 commit 8b36c4c
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 31 deletions.
2 changes: 1 addition & 1 deletion src/controller.rs
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ impl AppState {
tx_sig: &str,
sub_account_id: Option<u16>,
) -> GatewayResult<TxEventsResponse> {
let signature = Signature::from_str(&tx_sig).map_err(|err| {
let signature = Signature::from_str(tx_sig).map_err(|err| {
warn!(target: LOG_TARGET, "failed to parse transaction signature: {err:?}");
ControllerError::BadRequest(format!("failed to parse transaction signature: {err:?}"))
})?;
Expand Down
2 changes: 0 additions & 2 deletions src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
//! - gateway request/responses
//! - wrappers for presenting drift program types with less implementation detail
//!
use std::io::Empty;

use drift_sdk::{
constants::{ProgramData, BASE_PRECISION, PRICE_PRECISION},
dlob::{self, L2Level, L2Orderbook},
Expand Down
71 changes: 43 additions & 28 deletions src/websocket.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! Websocket server
use std::ops::Neg;
use std::{collections::HashMap, ops::Neg, sync::Arc};

use drift_sdk::{
async_utils::retry_policy::{self},
Expand All @@ -14,9 +14,9 @@ use log::{info, warn};
use rust_decimal::Decimal;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use serde_json::json;
use solana_sdk::account::Account;
use tokio::{
net::{TcpListener, TcpStream},
sync::Mutex,
task::JoinHandle,
};
use tokio_tungstenite::{accept_async, tungstenite::Message};
Expand Down Expand Up @@ -58,21 +58,20 @@ async fn accept_connection(
) {
let addr = stream.peer_addr().expect("peer address");
let ws_stream = accept_async(stream).await.expect("Ws handshake");
info!("accepted Ws connection: {}", addr);
info!(target: LOG_TARGET, "accepted Ws connection: {}", addr);

let (mut ws_out, mut ws_in) = ws_stream.split();
let (message_tx, mut message_rx) = tokio::sync::mpsc::channel::<Message>(32);
let mut stream_handle: Option<JoinHandle<()>> = None;
let subscriptions = Arc::new(Mutex::new(HashMap::<u8, JoinHandle<()>>::default()));

// writes messages to the connection
tokio::spawn(async move {
while let Some(msg) = message_rx.recv().await {
if msg.is_close() {
let _ = ws_out.close().await;
break;
} else {
ws_out.send(msg).await.expect("sent");
}
ws_out.send(msg).await.expect("sent");
}
});

Expand All @@ -84,23 +83,34 @@ async fn accept_connection(
match request.method {
Method::Subscribe => {
// TODO: support subscriptions for individual channels and/or markets
if stream_handle.is_some() {
// no double subs
return;
let mut subscription_map = subscriptions.lock().await;
if subscription_map.contains_key(&request.sub_account_id) {
info!(target: LOG_TARGET, "subscription already exists for: {}", request.sub_account_id);
message_tx
.send(Message::text(
json!({
"error": "bad request",
"reason": "subscription already exists",
})
.to_string(),
))
.await
.unwrap();
continue;
}
info!(target: LOG_TARGET, "subscribing to events for: {}", request.sub_account_id);

let sub_account_address =
wallet.sub_account(request.sub_account_id as u16);
let mut event_stream = EventSubscriber::subscribe(
PubsubClient::new(ws_endpoint.as_str())
.await
.expect("ws connect"),
sub_account_address,
retry_policy::forever(5),
);

let join_handle = tokio::spawn({
let sub_account_address =
wallet.sub_account(request.sub_account_id as u16);
let mut event_stream = EventSubscriber::subscribe(
PubsubClient::new(ws_endpoint.as_str())
.await
.expect("ws connect"),
sub_account_address,
retry_policy::forever(5),
);
let subscription_map = Arc::clone(&subscriptions);
let sub_account_id = request.sub_account_id;
let message_tx = message_tx.clone();
async move {
Expand All @@ -113,7 +123,7 @@ async fn accept_connection(
if data.is_none() {
continue;
}
message_tx
if message_tx
.send(Message::text(
serde_json::to_string(&WsEvent {
data,
Expand All @@ -123,34 +133,39 @@ async fn accept_connection(
.expect("serializes"),
))
.await
.expect("capacity");
.is_err()
{
break;
}
}
warn!(target: LOG_TARGET, "event stream finished: {sub_account_id:?}, sending close");
let _ = message_tx.send(Message::Close(None)).await;
subscription_map.lock().await.remove(&sub_account_id);
}
});

stream_handle = Some(join_handle);
subscription_map.insert(request.sub_account_id, join_handle);
}
Method::Unsubscribe => {
info!(target: LOG_TARGET, "unsubscribing: {}", request.sub_account_id);
info!(target: LOG_TARGET, "unsubscribing events of: {}", request.sub_account_id);
// TODO: support ending by channel, this ends all channels
if let Some(task) = stream_handle.take() {
let mut subscription_map = subscriptions.lock().await;
if let Some(task) = subscription_map.remove(&request.sub_account_id) {
task.abort();
}
}
}
}
Err(err) => {
message_tx
.try_send(Message::text(
.send(Message::text(
json!({
"error": "bad request",
"reason": err.to_string(),
})
.to_string(),
))
.expect("capacity");
.await
.unwrap();
}
},
Message::Close(frame) => {
Expand All @@ -161,7 +176,7 @@ async fn accept_connection(
_ => (),
}
}
info!("closing Ws connection: {}", addr);
info!(target: LOG_TARGET, "closing Ws connection: {}", addr);
}

#[derive(Deserialize, Debug)]
Expand Down

0 comments on commit 8b36c4c

Please sign in to comment.