Skip to content

Commit

Permalink
Update QUINN to 0.8.5 in bridge
Browse files Browse the repository at this point in the history
  • Loading branch information
TheHellBox committed Nov 9, 2024
1 parent a18bc05 commit b41c36f
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 107 deletions.
8 changes: 4 additions & 4 deletions kissmp-bridge/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,20 @@ bincode = "1.3"
serde = { version = "1.0", features = ["derive"] }
serde_json="1.0"
futures = "0.3.5"
quinn = "0.7.1"
quinn = {version="0.8.5", features = ["tls-rustls"]}
rustls = { version = "0.20.3", features = ["dangerous_configuration"] }
# Held back due to rustls using webpki 0.21
webpki = "0.21"
anyhow = "1.0.32"
reqwest = { version = "0.11", default-features = false, features=["rustls-tls"] }
tiny_http="0.8"
tokio-stream="0.1.5"
rustls = { version = "0.19", features = ["dangerous_configuration"] }
tokio = { version = "1.4", features = ["time", "macros", "sync", "io-util", "net"] }
discord-rpc-client = {version = "0.3", optional = true}
discord-rpc-client = {version = "0.4", optional = true}
percent-encoding = "2.1"
audiopus = "0.2"
rodio = "0.14"
cpal = "0.13"
fon = "0.5.0"
log = "0.4"
indoc = "1.0"
indoc = "1.0"
195 changes: 92 additions & 103 deletions kissmp-bridge/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,13 @@ pub mod http_proxy;
pub mod voice_chat;

use futures::stream::FuturesUnordered;
use futures::{StreamExt};
use futures::StreamExt;
use quinn::IdleTimeout;
use rustls::{Certificate, ServerName};
use std::convert::TryFrom;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, ToSocketAddrs};
use std::sync::Arc;
use std::time::SystemTime;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, WriteHalf};
use tokio::net::{TcpListener, TcpStream};
#[macro_use]
Expand All @@ -18,9 +23,7 @@ pub struct DiscordState {
pub server_name: Option<String>,
}

async fn read_pascal_bytes<R: AsyncRead + Unpin>(
stream: &mut R
) -> Result<Vec<u8>, anyhow::Error> {
async fn read_pascal_bytes<R: AsyncRead + Unpin>(stream: &mut R) -> Result<Vec<u8>, anyhow::Error> {
let mut buffer = [0; 4];
stream.read_exact(&mut buffer).await?;
let len = u32::from_le_bytes(buffer) as usize;
Expand All @@ -31,8 +34,8 @@ async fn read_pascal_bytes<R: AsyncRead + Unpin>(

async fn write_pascal_bytes<W: AsyncWrite + Unpin>(
stream: &mut W,
bytes: &mut Vec<u8>
) -> Result<(), anyhow::Error>{
bytes: &mut Vec<u8>,
) -> Result<(), anyhow::Error> {
let len = bytes.len() as u32;
let mut data = Vec::with_capacity(len as usize + 4);
data.append(&mut len.to_le_bytes().to_vec());
Expand All @@ -56,28 +59,25 @@ async fn main() {
let listener = TcpListener::bind(bind_addr).await.unwrap();
info!("Bridge is running!");
while let Ok((mut client_stream, _)) = listener.accept().await {

info!("Attempting to connect to a server...");

let addr = {
let address_string =
String::from_utf8(
read_pascal_bytes(&mut client_stream).await.unwrap()
).unwrap();
let address_string =
String::from_utf8(read_pascal_bytes(&mut client_stream).await.unwrap()).unwrap();

let mut socket_addrs = match address_string.to_socket_addrs() {
Ok(socket_addrs) => socket_addrs,
Err(e) => {
error!("Failed to parse address: {}", e);
continue;
},
};
Ok(socket_addrs) => socket_addrs,
Err(e) => {
error!("Failed to parse address: {}", e);
continue;
}
};
match socket_addrs.next() {
Some(addr) => addr,
None => {
error!("Could not find address: {}", address_string);
continue;
},
}
}
};

Expand All @@ -89,56 +89,42 @@ async fn main() {
async fn connect_to_server(
addr: SocketAddr,
client_stream: TcpStream,
discord_tx: std::sync::mpsc::Sender<DiscordState>
discord_tx: std::sync::mpsc::Sender<DiscordState>,
) -> () {
let endpoint = {
let mut client_cfg = quinn::ClientConfig::default();

let rustls_config = rustls::ClientConfig::builder()
.with_safe_defaults()
.with_custom_certificate_verifier(Arc::new(AcceptAnyCertificate))
.with_no_client_auth();
let mut client_cfg = quinn::ClientConfig::new(Arc::new(rustls_config));

let mut transport = quinn::TransportConfig::default();
transport
.max_idle_timeout(Some(SERVER_IDLE_TIMEOUT))
.unwrap();
transport.max_idle_timeout(Some(IdleTimeout::try_from(SERVER_IDLE_TIMEOUT).unwrap()));
client_cfg.transport = std::sync::Arc::new(transport);

let tls_cfg = std::sync::Arc::get_mut(&mut client_cfg.crypto).unwrap();
tls_cfg
.dangerous()
.set_certificate_verifier(std::sync::Arc::new(AcceptAnyCertificate));

let mut endpoint = quinn::Endpoint::builder();
endpoint.default_client_config(client_cfg);

let bind_from = match addr {
SocketAddr::V4(_) => IpAddr::from(Ipv4Addr::UNSPECIFIED),
SocketAddr::V6(_) => IpAddr::from(Ipv6Addr::UNSPECIFIED),
};

let mut endpoint = quinn::Endpoint::client(addr).unwrap();
endpoint.set_default_client_config(client_cfg);
endpoint
.bind(&SocketAddr::new(bind_from, 0))
.unwrap().0
};

let server_connection = match endpoint.connect(&addr, "kissmp").unwrap().await {
let server_connection = match endpoint.connect(addr, "kissmp").unwrap().await {
Ok(c) => c,
Err(e) => {
error!("Failed to connect to the server: {}", e);
return;
},
}
};

let (client_stream_reader, mut client_stream_writer) =
tokio::io::split(client_stream);
let (client_stream_reader, mut client_stream_writer) = tokio::io::split(client_stream);

let _ = client_stream_writer.write_all(CONNECTED_BYTE).await;

let (client_event_sender, client_event_receiver) =
tokio::sync::mpsc::unbounded_channel::<(bool, shared::ClientCommand)>();
let (server_commands_sender, server_commands_receiver) =
tokio::sync::mpsc::channel::<shared::ServerCommand>(256);
let (vc_recording_sender, vc_recording_receiver) =
std::sync::mpsc::channel();
let (vc_playback_sender, vc_playback_receiver) =
std::sync::mpsc::channel();
let (vc_recording_sender, vc_recording_receiver) = std::sync::mpsc::channel();
let (vc_playback_sender, vc_playback_receiver) = std::sync::mpsc::channel();

// TODO: Use a struct that can hold either a JoinHandle or a bare future so
// additional tasks that do not depend on using tokio::spawn can be added.
Expand All @@ -148,18 +134,25 @@ async fn connect_to_server(
Ok(handle) => {
non_critical_tasks.push(handle);
debug!("Playback OK")
},
Err(e) => {error!("Failed to set up voice chat playback: {}", e)},
}
Err(e) => {
error!("Failed to set up voice chat playback: {}", e)
}
};

match voice_chat::try_create_vc_recording_task(client_event_sender.clone(), vc_recording_receiver) {
match voice_chat::try_create_vc_recording_task(
client_event_sender.clone(),
vc_recording_receiver,
) {
Ok(handle) => {
non_critical_tasks.push(handle);
debug!("Recording OK")
},
Err(e) => {error!("Failed to set up voice chat recording: {}", e)},
}
Err(e) => {
error!("Failed to set up voice chat recording: {}", e)
}
};

tokio::spawn(async move {
debug!("Starting tasks");
match tokio::try_join!(
Expand All @@ -168,28 +161,25 @@ async fn connect_to_server(
match result {
Err(e) => warn!("Non-critical task failed: {}", e),
Ok(Err(e)) => warn!("Non-critical task died with exception: {}", e),
_ => ()
_ => (),
}
}
Ok(())
},
client_outgoing(
server_commands_receiver,
client_stream_writer),
client_outgoing(server_commands_receiver, client_stream_writer),
client_incoming(
server_connection.connection.clone(),
vc_playback_sender.clone(),
client_stream_reader,
vc_recording_sender,
client_event_sender),
server_outgoing(
server_connection.connection.clone(),
client_event_receiver),
client_event_sender
),
server_outgoing(server_connection.connection.clone(), client_event_receiver),
server_incoming(
server_commands_sender,
vc_playback_sender,
server_connection),

server_connection
),
) {
Ok(_) => debug!("Tasks completed successfully"),
Err(e) => warn!("Tasks ended due to exception: {}", e),
Expand All @@ -211,7 +201,9 @@ fn server_command_to_client_bytes(command: shared::ServerCommand) -> Vec<u8> {
result.append(&mut data.clone());
result
}
shared::ServerCommand::VoiceChatPacket(_, _, _) => panic!("Voice packets have to handled by the bridge itself."),
shared::ServerCommand::VoiceChatPacket(_, _, _) => {
panic!("Voice packets have to handled by the bridge itself.")
}
_ => {
let json = serde_json::to_string(&command).unwrap();
//println!("{:?}", json);
Expand All @@ -228,10 +220,12 @@ type AHResult = Result<(), anyhow::Error>;

async fn client_outgoing(
mut server_commands_receiver: tokio::sync::mpsc::Receiver<shared::ServerCommand>,
mut client_stream_writer: WriteHalf<TcpStream>
mut client_stream_writer: WriteHalf<TcpStream>,
) -> AHResult {
while let Some(server_command) = server_commands_receiver.recv().await {
client_stream_writer.write_all(server_command_to_client_bytes(server_command).as_ref()).await?;
client_stream_writer
.write_all(server_command_to_client_bytes(server_command).as_ref())
.await?;
}
debug!("Server outgoing closed");
Ok(())
Expand All @@ -240,52 +234,47 @@ async fn client_outgoing(
async fn server_incoming(
server_commands_sender: tokio::sync::mpsc::Sender<shared::ServerCommand>,
vc_playback_sender: std::sync::mpsc::Sender<voice_chat::VoiceChatPlaybackEvent>,
server_connection: quinn::generic::NewConnection<quinn::crypto::rustls::TlsSession>
server_connection: quinn::NewConnection,
) -> AHResult {
let mut reliable_commands = server_connection
.uni_streams
.map(|stream| async {
Ok::<_, anyhow::Error>(read_pascal_bytes(&mut stream?).await?)
});
.map(|stream| async { Ok::<_, anyhow::Error>(read_pascal_bytes(&mut stream?).await?) });

let mut unreliable_commands = server_connection
.datagrams
.map(|data| async {
Ok::<_, anyhow::Error>(data?.to_vec())
})
.map(|data| async { Ok::<_, anyhow::Error>(data?.to_vec()) })
.buffer_unordered(1024);

loop {
let command_bytes =
tokio::select! {
Some(reliable_command) = reliable_commands.next() => {
reliable_command.await?
},
Some(unreliable_command) = unreliable_commands.next() => {
unreliable_command?
},
else => break
};
let command_bytes = tokio::select! {
Some(reliable_command) = reliable_commands.next() => {
reliable_command.await?
},
Some(unreliable_command) = unreliable_commands.next() => {
unreliable_command?
},
else => break
};
let command = bincode::deserialize::<shared::ServerCommand>(command_bytes.as_ref())?;
match command {
shared::ServerCommand::VoiceChatPacket(client, pos, data) =>{
shared::ServerCommand::VoiceChatPacket(client, pos, data) => {
let _ = vc_playback_sender.send(voice_chat::VoiceChatPlaybackEvent::Packet(
client, pos, data,
));
},
_ => server_commands_sender.send(command).await?
}
_ => server_commands_sender.send(command).await?,
};
};
}
debug!("Server incoming closed");
Ok(())
}

async fn client_incoming(
server_stream: quinn::generic::Connection<quinn::crypto::rustls::TlsSession>,
server_stream: quinn::Connection,
vc_playback_sender: std::sync::mpsc::Sender<voice_chat::VoiceChatPlaybackEvent>,
mut client_stream_reader: tokio::io::ReadHalf<TcpStream>,
vc_recording_sender: std::sync::mpsc::Sender<voice_chat::VoiceChatRecordingEvent>,
client_event_sender: tokio::sync::mpsc::UnboundedSender<(bool, shared::ClientCommand)>
client_event_sender: tokio::sync::mpsc::UnboundedSender<(bool, shared::ClientCommand)>,
) -> AHResult {
let mut buffer = [0; 1];
while let Ok(_) = client_stream_reader.read_exact(&mut buffer).await {
Expand All @@ -300,9 +289,7 @@ async fn client_incoming(
match decoded {
shared::ClientCommand::SpatialUpdate(left_ear, right_ear) => {
let _ = vc_playback_sender.send(
voice_chat::VoiceChatPlaybackEvent::PositionUpdate(
left_ear, right_ear,
),
voice_chat::VoiceChatPlaybackEvent::PositionUpdate(left_ear, right_ear),
);
}
shared::ClientCommand::StartTalking => {
Expand All @@ -325,9 +312,9 @@ async fn client_incoming(
}

async fn server_outgoing(
server_stream: quinn::generic::Connection<quinn::crypto::rustls::TlsSession>,
mut client_event_receiver: tokio::sync::mpsc::UnboundedReceiver<(bool, shared::ClientCommand)>
) -> AHResult {
server_stream: quinn::Connection,
mut client_event_receiver: tokio::sync::mpsc::UnboundedReceiver<(bool, shared::ClientCommand)>,
) -> AHResult {
while let Some((reliable, client_command)) = client_event_receiver.recv().await {
let mut data = bincode::serialize::<shared::ClientCommand>(&client_command)?;
if !reliable {
Expand All @@ -342,14 +329,16 @@ async fn server_outgoing(

struct AcceptAnyCertificate;

impl rustls::ServerCertVerifier for AcceptAnyCertificate {
impl rustls::client::ServerCertVerifier for AcceptAnyCertificate {
fn verify_server_cert(
&self,
_roots: &rustls::RootCertStore,
_presented_certs: &[rustls::Certificate],
_dns_name: webpki::DNSNameRef,
_ocsp_response: &[u8],
) -> Result<rustls::ServerCertVerified, rustls::TLSError> {
Ok(rustls::ServerCertVerified::assertion())
_end_entity: &Certificate,
_: &[Certificate],
_: &ServerName,
scts: &mut dyn Iterator<Item = &[u8]>,
ocsp_response: &[u8],
now: SystemTime,
) -> Result<rustls::client::ServerCertVerified, rustls::TLSError> {
Ok(rustls::client::ServerCertVerified::assertion())
}
}

0 comments on commit b41c36f

Please sign in to comment.