diff --git a/Cargo.toml b/Cargo.toml index b149e0e..d0297ea 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,8 +33,8 @@ criterion = "0.3" bench = false [[bench]] -name = "pub_sub" -harness = false +name = "pub_sub" +harness = false bench = false # Don't actually benchmark this, until we fix it [[bench]] diff --git a/src/codec/framed.rs b/src/codec/framed.rs new file mode 100644 index 0000000..64c650f --- /dev/null +++ b/src/codec/framed.rs @@ -0,0 +1,34 @@ +//! General types and traits to facilitate compatibility across async runtimes + +use crate::codec::ZmqCodec; + +// We use dynamic dispatch to avoid complicated generics and simplify things +type Inner = futures_codec::Framed, ZmqCodec>; + +// Enables us to have multiple bounds on the dyn trait in `InnerFramed` +pub trait Frameable: futures::AsyncWrite + futures::AsyncRead + Unpin + Send {} +impl Frameable for T where T: futures::AsyncWrite + futures::AsyncRead + Unpin + Send {} + +/// Equivalent to [`futures_codec::Framed`] or +/// [`tokio_util::codec::Framed`] +pub(crate) struct FramedIo(Inner); +impl FramedIo { + pub fn new(frameable: Box) -> Self { + let inner = futures_codec::Framed::new(frameable, ZmqCodec::new()); + Self(inner) + } +} + +impl std::ops::Deref for FramedIo { + type Target = Inner; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl std::ops::DerefMut for FramedIo { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} diff --git a/src/codec/mod.rs b/src/codec/mod.rs index 6e19a5a..fd9f8e0 100644 --- a/src/codec/mod.rs +++ b/src/codec/mod.rs @@ -1,11 +1,13 @@ mod command; mod error; +mod framed; mod greeting; mod mechanism; mod zmq_codec; pub(crate) use command::{ZmqCommand, ZmqCommandName}; pub(crate) use error::{CodecError, CodecResult}; +pub(crate) use framed::FramedIo; pub(crate) use greeting::ZmqGreeting; pub(crate) use zmq_codec::ZmqCodec; diff --git a/src/dealer_router.rs b/src/dealer_router.rs index ac7e789..9afc2b9 100644 --- a/src/dealer_router.rs +++ b/src/dealer_router.rs @@ -2,21 +2,21 @@ use async_trait::async_trait; use dashmap::DashMap; use futures::channel::{mpsc, oneshot}; use futures::lock::Mutex; +use futures::stream::{FuturesUnordered, StreamExt}; use futures::SinkExt; -use futures_codec::Framed; +use std::collections::HashMap; use std::convert::TryInto; use std::sync::Arc; -use tokio::net::TcpStream; +use crate::codec::FramedIo; use crate::codec::*; use crate::endpoint::{Endpoint, TryIntoEndpoint}; use crate::error::*; use crate::message::*; -use crate::util::*; -use crate::{util, MultiPeer, Socket, SocketBackend}; +use crate::transport; +use crate::util::{self, Peer, PeerIdentity}; +use crate::{MultiPeer, Socket, SocketBackend}; use crate::{SocketType, ZmqResult}; -use futures::stream::{FuturesUnordered, StreamExt}; -use std::collections::HashMap; struct RouterSocketBackend { pub(crate) peers: Arc>, @@ -97,8 +97,12 @@ impl Socket for RouterSocket { async fn bind(&mut self, endpoint: impl TryIntoEndpoint + 'async_trait) -> ZmqResult { let endpoint = endpoint.try_into()?; - let (endpoint, stop_handle) = - util::start_accepting_connections(endpoint, self.backend.clone()).await?; + let Endpoint::Tcp(host, port) = endpoint; + + let cloned_backend = self.backend.clone(); + let cback = move |result| util::peer_connected(result, cloned_backend.clone()); + let (endpoint, stop_handle) = transport::tcp::begin_accept(host, port, cback).await?; + self.binds.insert(endpoint.clone(), stop_handle); Ok(endpoint) } @@ -159,7 +163,7 @@ impl RouterSocket { } pub struct DealerSocket { - pub(crate) _inner: Framed, + pub(crate) _inner: FramedIo, } impl DealerSocket { diff --git a/src/error.rs b/src/error.rs index e45a600..16464a4 100644 --- a/src/error.rs +++ b/src/error.rs @@ -3,6 +3,8 @@ use crate::ZmqMessage; use thiserror::Error; +pub type ZmqResult = Result; + #[derive(Error, Debug)] pub enum ZmqError { #[error(transparent)] diff --git a/src/lib.rs b/src/lib.rs index de15f82..6da910a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,14 +1,4 @@ #![recursion_limit = "1024"] -#[macro_use] -extern crate enum_primitive_derive; -use num_traits::ToPrimitive; - -use async_trait::async_trait; -use futures::channel::{mpsc, oneshot}; -use std::convert::TryFrom; -use std::fmt::{Debug, Display}; - -use futures_codec::Framed; mod codec; mod dealer_router; @@ -20,21 +10,30 @@ mod r#pub; mod rep; mod req; mod sub; +mod transport; pub mod util; -use crate::codec::*; pub use crate::dealer_router::*; pub use crate::endpoint::{Endpoint, Host, Transport, TryIntoEndpoint}; -pub use crate::error::ZmqError; +pub use crate::error::{ZmqError, ZmqResult}; pub use crate::r#pub::*; pub use crate::rep::*; pub use crate::req::*; pub use crate::sub::*; -use crate::util::*; pub use message::*; -use std::collections::HashMap; -pub type ZmqResult = Result; +use crate::codec::*; +use crate::util::*; + +#[macro_use] +extern crate enum_primitive_derive; + +use async_trait::async_trait; +use futures::channel::{mpsc, oneshot}; +use num_traits::ToPrimitive; +use std::collections::HashMap; +use std::convert::TryFrom; +use std::fmt::{Debug, Display}; #[derive(Clone, Copy, Debug, PartialEq, Primitive)] pub enum SocketType { diff --git a/src/pub.rs b/src/pub.rs index d1b6bdd..8cb7911 100644 --- a/src/pub.rs +++ b/src/pub.rs @@ -1,8 +1,10 @@ use crate::codec::*; use crate::endpoint::{Endpoint, TryIntoEndpoint}; use crate::message::*; +use crate::transport; use crate::util::*; use crate::{util, MultiPeer, NonBlockingSend, Socket, SocketBackend, SocketType, ZmqResult}; + use async_trait::async_trait; use dashmap::DashMap; use futures::channel::{mpsc, oneshot}; @@ -145,8 +147,12 @@ impl Socket for PubSocket { async fn bind(&mut self, endpoint: impl TryIntoEndpoint + 'async_trait) -> ZmqResult { let endpoint = endpoint.try_into()?; - let (endpoint, stop_handle) = - util::start_accepting_connections(endpoint, self.backend.clone()).await?; + let Endpoint::Tcp(host, port) = endpoint; + + let cloned_backend = self.backend.clone(); + let cback = move |result| util::peer_connected(result, cloned_backend.clone()); + let (endpoint, stop_handle) = transport::tcp::begin_accept(host, port, cback).await?; + self.binds.insert(endpoint.clone(), stop_handle); Ok(endpoint) } @@ -155,8 +161,8 @@ impl Socket for PubSocket { let endpoint = endpoint.try_into()?; let Endpoint::Tcp(host, port) = endpoint; - let raw_socket = tokio::net::TcpStream::connect((host.to_string().as_str(), port)).await?; - util::peer_connected(raw_socket, self.backend.clone()).await; + let connect_result = transport::tcp::connect(host, port).await; + util::peer_connected(connect_result, self.backend.clone()).await; Ok(()) } diff --git a/src/rep.rs b/src/rep.rs index 4585961..163e89a 100644 --- a/src/rep.rs +++ b/src/rep.rs @@ -2,15 +2,17 @@ use crate::codec::*; use crate::endpoint::{Endpoint, TryIntoEndpoint}; use crate::error::*; use crate::fair_queue::FairQueue; +use crate::transport; use crate::util::FairQueueProcessor; use crate::*; use crate::{util, SocketType, ZmqResult}; + use async_trait::async_trait; use dashmap::DashMap; -use futures_util::sink::SinkExt; +use futures::SinkExt; +use futures::StreamExt; use std::collections::HashMap; use std::sync::Arc; -use tokio::stream::StreamExt; struct RepPeer { pub(crate) _identity: PeerIdentity, @@ -66,8 +68,12 @@ impl Socket for RepSocket { async fn bind(&mut self, endpoint: impl TryIntoEndpoint + 'async_trait) -> ZmqResult { let endpoint = endpoint.try_into()?; - let (endpoint, stop_handle) = - util::start_accepting_connections(endpoint, self.backend.clone()).await?; + let Endpoint::Tcp(host, port) = endpoint; + + let cloned_backend = self.backend.clone(); + let cback = move |result| util::peer_connected(result, cloned_backend.clone()); + let (endpoint, stop_handle) = transport::tcp::begin_accept(host, port, cback).await?; + self.binds.insert(endpoint.clone(), stop_handle); Ok(endpoint) } @@ -76,8 +82,8 @@ impl Socket for RepSocket { let endpoint = endpoint.try_into()?; let Endpoint::Tcp(host, port) = endpoint; - let raw_socket = tokio::net::TcpStream::connect((host.to_string().as_str(), port)).await?; - util::peer_connected(raw_socket, self.backend.clone()).await; + let connect_result = transport::tcp::connect(host, port).await; + util::peer_connected(connect_result, self.backend.clone()).await; Ok(()) } diff --git a/src/req.rs b/src/req.rs index f5867e7..3614465 100644 --- a/src/req.rs +++ b/src/req.rs @@ -1,18 +1,19 @@ use crate::codec::*; use crate::endpoint::{Endpoint, TryIntoEndpoint}; use crate::error::*; +use crate::transport; use crate::util::{self, Peer, PeerIdentity}; use crate::*; use crate::{SocketType, ZmqResult}; + use async_trait::async_trait; use crossbeam::queue::SegQueue; use dashmap::DashMap; use futures::channel::{mpsc, oneshot}; use futures::lock::Mutex; -use futures_util::sink::SinkExt; +use futures::{SinkExt, StreamExt}; use std::collections::HashMap; use std::sync::Arc; -use tokio::stream::StreamExt; struct ReqSocketBackend { pub(crate) peers: DashMap, @@ -124,8 +125,12 @@ impl Socket for ReqSocket { async fn bind(&mut self, endpoint: impl TryIntoEndpoint + 'async_trait) -> ZmqResult { let endpoint = endpoint.try_into()?; - let (endpoint, stop_handle) = - util::start_accepting_connections(endpoint, self.backend.clone()).await?; + let Endpoint::Tcp(host, port) = endpoint; + + let cloned_backend = self.backend.clone(); + let cback = move |result| util::peer_connected(result, cloned_backend.clone()); + let (endpoint, stop_handle) = transport::tcp::begin_accept(host, port, cback).await?; + self.binds.insert(endpoint.clone(), stop_handle); Ok(endpoint) } @@ -134,8 +139,8 @@ impl Socket for ReqSocket { let endpoint = endpoint.try_into()?; let Endpoint::Tcp(host, port) = endpoint; - let raw_socket = tokio::net::TcpStream::connect((host.to_string().as_str(), port)).await?; - util::peer_connected(raw_socket, self.backend.clone()).await; + let connect_result = transport::tcp::connect(host, port).await; + util::peer_connected(connect_result, self.backend.clone()).await; Ok(()) } diff --git a/src/sub.rs b/src/sub.rs index f614183..1a5b557 100644 --- a/src/sub.rs +++ b/src/sub.rs @@ -2,15 +2,15 @@ use crate::codec::*; use crate::endpoint::{Endpoint, TryIntoEndpoint}; use crate::fair_queue::FairQueue; use crate::message::*; +use crate::transport; use crate::util::*; use crate::{util, BlockingRecv, MultiPeer, Socket, SocketBackend, SocketType, ZmqResult}; + use async_trait::async_trait; use bytes::{BufMut, BytesMut}; use dashmap::DashMap; use futures::channel::{mpsc, oneshot}; -use futures::SinkExt; -use futures::StreamExt; - +use futures::{SinkExt, StreamExt}; use std::collections::HashMap; use std::sync::Arc; @@ -147,8 +147,12 @@ impl Socket for SubSocket { async fn bind(&mut self, endpoint: impl TryIntoEndpoint + 'async_trait) -> ZmqResult { let endpoint = endpoint.try_into()?; - let (endpoint, stop_handle) = - util::start_accepting_connections(endpoint, self.backend.clone()).await?; + let Endpoint::Tcp(host, port) = endpoint; + + let cloned_backend = self.backend.clone(); + let cback = move |result| util::peer_connected(result, cloned_backend.clone()); + let (endpoint, stop_handle) = transport::tcp::begin_accept(host, port, cback).await?; + self.binds.insert(endpoint.clone(), stop_handle); Ok(endpoint) } @@ -157,8 +161,8 @@ impl Socket for SubSocket { let endpoint = endpoint.try_into()?; let Endpoint::Tcp(host, port) = endpoint; - let raw_socket = tokio::net::TcpStream::connect((host.to_string().as_str(), port)).await?; - util::peer_connected(raw_socket, self.backend.clone()).await; + let connect_result = transport::tcp::connect(host, port).await; + util::peer_connected(connect_result, self.backend.clone()).await; Ok(()) } diff --git a/src/transport/mod.rs b/src/transport/mod.rs new file mode 100644 index 0000000..fcb722b --- /dev/null +++ b/src/transport/mod.rs @@ -0,0 +1 @@ +pub mod tcp; diff --git a/src/transport/tcp/mod.rs b/src/transport/tcp/mod.rs new file mode 100644 index 0000000..b033c16 --- /dev/null +++ b/src/transport/tcp/mod.rs @@ -0,0 +1,30 @@ +// TODO: Conditionally compile things +mod tokio; + +use self::tokio as tk; +use crate::codec::FramedIo; +use crate::endpoint::{Endpoint, Host, Port}; +use crate::ZmqResult; + +use std::net::SocketAddr; + +pub(crate) async fn connect(host: Host, port: Port) -> ZmqResult<(FramedIo, SocketAddr)> { + tk::connect(host, port).await +} + +/// Spawns an async task that listens for tcp connections at the provided +/// address. +/// +/// `cback` will be invoked when a tcp connection is accepted. If the result was +/// `Ok`, we get a tuple containing the framed raw socket, along with the ip +/// address of the remote connection accepted. +pub(crate) async fn begin_accept( + host: Host, + port: Port, + cback: impl Fn(ZmqResult<(FramedIo, SocketAddr)>) -> T + Send + 'static, +) -> ZmqResult<(Endpoint, futures::channel::oneshot::Sender)> +where + T: std::future::Future + Send + 'static, +{ + tk::begin_accept(host, port, cback).await +} diff --git a/src/transport/tcp/tokio.rs b/src/transport/tcp/tokio.rs new file mode 100644 index 0000000..0e8ed81 --- /dev/null +++ b/src/transport/tcp/tokio.rs @@ -0,0 +1,57 @@ +//! Tokio-specific utilities for the functionality in [`super::compat`] + +use crate::codec::FramedIo; +use crate::endpoint::{Endpoint, Host, Port}; +use crate::ZmqResult; + +use futures::{select, FutureExt}; +use std::net::SocketAddr; +use tokio_util::compat::Tokio02AsyncReadCompatExt; + +pub(crate) async fn connect(host: Host, port: Port) -> ZmqResult<(FramedIo, SocketAddr)> { + let raw_socket = tokio::net::TcpStream::connect((host.to_string().as_str(), port)).await?; + let remote_addr = raw_socket.peer_addr()?; + let boxed_sock = Box::new(raw_socket.compat()); + Ok((FramedIo::new(boxed_sock), remote_addr)) +} + +pub(crate) async fn begin_accept( + mut host: Host, + port: Port, + cback: impl Fn(ZmqResult<(FramedIo, SocketAddr)>) -> T + Send + 'static, +) -> ZmqResult<(Endpoint, futures::channel::oneshot::Sender)> +where + T: std::future::Future + Send + 'static, +{ + let mut listener = tokio::net::TcpListener::bind((host.to_string().as_str(), port)).await?; + let resolved_addr = listener.local_addr()?; + let (stop_handle, stop_callback) = futures::channel::oneshot::channel::(); + tokio::spawn(async move { + let mut stop_callback = stop_callback.fuse(); + loop { + select! { + incoming = listener.accept().fuse() => { + let maybe_accepted: Result<_, _> = incoming.map(|(raw_sock, remote_addr)| { + let raw_sock = FramedIo::new(Box::new(raw_sock.compat())); + (raw_sock, remote_addr) + }).map_err(|err| err.into()); + tokio::spawn(cback(maybe_accepted.into())); + }, + _ = stop_callback => { + break + } + } + } + }); + debug_assert_ne!(resolved_addr.port(), 0); + let port = resolved_addr.port(); + let resolved_host: Host = resolved_addr.ip().into(); + if let Host::Ipv4(ip) = host { + debug_assert_eq!(ip, resolved_addr.ip()); + host = resolved_host; + } else if let Host::Ipv6(ip) = host { + debug_assert_eq!(ip, resolved_addr.ip()); + host = resolved_host; + } + Ok((Endpoint::Tcp(host, port), stop_handle)) +} diff --git a/src/util.rs b/src/util.rs index 5861ad3..e9e9fc8 100644 --- a/src/util.rs +++ b/src/util.rs @@ -1,18 +1,15 @@ use crate::codec::CodecResult; -use crate::endpoint::Endpoint; +use crate::codec::FramedIo; use crate::fair_queue::FairQueue; use crate::*; use bytes::Bytes; use futures::lock::Mutex; use futures::stream::StreamExt; -use futures::{select, SinkExt}; -use futures_util::future::FutureExt; +use futures::SinkExt; use std::convert::{TryFrom, TryInto}; +use std::net::SocketAddr; use std::sync::Arc; -use tokio::net::TcpStream; -use tokio_util::compat::Compat; -use tokio_util::compat::Tokio02AsyncReadCompatExt; use uuid::Uuid; #[derive(PartialEq, Eq, PartialOrd, Ord, Debug, Hash, Clone)] @@ -91,14 +88,12 @@ pub fn sockets_compatible(one: SocketType, another: SocketType) -> bool { COMPATIBILITY_MATRIX[row_index * 11 + col_index] != 0 } -pub(crate) async fn greet_exchange( - socket: &mut Framed, ZmqCodec>, -) -> ZmqResult<()> { - socket +pub(crate) async fn greet_exchange(raw_socket: &mut FramedIo) -> ZmqResult<()> { + raw_socket .send(Message::Greeting(ZmqGreeting::default())) .await?; - let greeting: Option> = socket.next().await; + let greeting: Option> = raw_socket.next().await; match greeting { Some(Ok(Message::Greeting(greet))) => match greet.version { @@ -110,13 +105,13 @@ pub(crate) async fn greet_exchange( } pub(crate) async fn ready_exchange( - socket: &mut Framed, ZmqCodec>, + raw_socket: &mut FramedIo, socket_type: SocketType, ) -> ZmqResult { let ready = ZmqCommand::ready(socket_type); - socket.send(Message::Command(ready)).await?; + raw_socket.send(Message::Command(ready)).await?; - let ready_repl: Option> = socket.next().await; + let ready_repl: Option> = raw_socket.next().await; match ready_repl { Some(Ok(Message::Command(command))) => match command.name { ZmqCommandName::READY => { @@ -148,9 +143,11 @@ pub(crate) async fn ready_exchange( } } -pub(crate) async fn peer_connected(socket: tokio::net::TcpStream, backend: Arc) { - let mut raw_socket = Framed::new(socket.compat(), ZmqCodec::new()); - +pub(crate) async fn peer_connected( + accept_result: ZmqResult<(FramedIo, SocketAddr)>, + backend: Arc, +) { + let (mut raw_socket, _remote_addr) = accept_result.expect("Failed to accept"); greet_exchange(&mut raw_socket) .await .expect("Failed to exchange greetings"); @@ -199,48 +196,6 @@ pub(crate) async fn peer_connected(socket: tokio::net::TcpStream, backend: Arc, -) -> ZmqResult<(Endpoint, futures::channel::oneshot::Sender)> { - let Endpoint::Tcp(mut host, port) = endpoint; - - let mut listener = tokio::net::TcpListener::bind((host.to_string().as_str(), port)).await?; - let resolved_addr = listener.local_addr()?; - let (stop_handle, stop_callback) = futures::channel::oneshot::channel::(); - tokio::spawn(async move { - let mut stop_callback = stop_callback.fuse(); - loop { - select! { - incoming = listener.accept().fuse() => { - let (socket, _) = incoming.expect("Failed to accept connection"); - tokio::spawn(peer_connected(socket, backend.clone())); - }, - _ = stop_callback => { - break - } - } - } - }); - debug_assert_ne!(resolved_addr.port(), 0); - let port = resolved_addr.port(); - let resolved_host: Host = resolved_addr.ip().into(); - if let Host::Ipv4(ip) = host { - debug_assert_eq!(ip, resolved_addr.ip()); - host = resolved_host; - } else if let Host::Ipv6(ip) = host { - debug_assert_eq!(ip, resolved_addr.ip()); - host = resolved_host; - } - Ok((Endpoint::Tcp(host, port), stop_handle)) -} - pub(crate) struct FairQueueProcessor { pub(crate) fair_queue_stream: FairQueue, PeerIdentity>, pub(crate) socket_incoming_queue: mpsc::Sender<(PeerIdentity, Message)>,