From 8f58f5de38b6b60251f5e97ba9de81be358f2846 Mon Sep 17 00:00:00 2001 From: yallie Date: Sat, 30 Nov 2024 06:24:32 +0300 Subject: [PATCH 1/2] Added sketch channel based on QUIC network protocol (requires .NET 9.0). --- .../CertificateHelper.cs | 43 ++++ .../CoreRemoting.Channels.Quic.csproj | 33 +++ .../QuicClientChannel.cs | 223 ++++++++++++++++++ .../QuicServerChannel.cs | 116 +++++++++ .../QuicServerConnection.cs | 140 +++++++++++ 5 files changed, 555 insertions(+) create mode 100644 CoreRemoting.Channels.Quic/CertificateHelper.cs create mode 100644 CoreRemoting.Channels.Quic/CoreRemoting.Channels.Quic.csproj create mode 100644 CoreRemoting.Channels.Quic/QuicClientChannel.cs create mode 100644 CoreRemoting.Channels.Quic/QuicServerChannel.cs create mode 100644 CoreRemoting.Channels.Quic/QuicServerConnection.cs diff --git a/CoreRemoting.Channels.Quic/CertificateHelper.cs b/CoreRemoting.Channels.Quic/CertificateHelper.cs new file mode 100644 index 0000000..a651ed9 --- /dev/null +++ b/CoreRemoting.Channels.Quic/CertificateHelper.cs @@ -0,0 +1,43 @@ +using System; +using System.Security.Cryptography; +using System.Security.Cryptography.X509Certificates; + +namespace CoreRemoting.Channels.Quic; + +internal class CertificateHelper +{ + public static X509Certificate2 LoadFromPfx(string pfxFilePath, string pfxPassword) => + X509CertificateLoader.LoadPkcs12FromFile(pfxFilePath, pfxPassword); + + public static X509Certificate2 GenerateSelfSigned(string hostName = "localhost") + { + // generate a new certificate + var now = DateTimeOffset.UtcNow; + SubjectAlternativeNameBuilder sanBuilder = new(); + sanBuilder.AddDnsName(hostName); + + using var ec = ECDsa.Create(ECCurve.NamedCurves.nistP256); + CertificateRequest req = new($"CN={hostName}", ec, HashAlgorithmName.SHA256); + + // Adds purpose + req.CertificateExtensions.Add(new X509EnhancedKeyUsageExtension(new OidCollection + { + new("1.3.6.1.5.5.7.3.1") // serverAuth + }, + false)); + + // Adds usage + req.CertificateExtensions.Add(new X509KeyUsageExtension(X509KeyUsageFlags.DigitalSignature, false)); + + // Adds subject alternate names + req.CertificateExtensions.Add(sanBuilder.Build()); + + // Sign + using var crt = req.CreateSelfSigned(now, now.AddDays(14)); // 14 days is the max duration of a certificate for this type + + var password = Guid.NewGuid().ToString(); + var pfx = crt.Export(X509ContentType.Pfx, password); + var cert = X509CertificateLoader.LoadPkcs12(pfx, password); + return cert; + } +} diff --git a/CoreRemoting.Channels.Quic/CoreRemoting.Channels.Quic.csproj b/CoreRemoting.Channels.Quic/CoreRemoting.Channels.Quic.csproj new file mode 100644 index 0000000..78540f2 --- /dev/null +++ b/CoreRemoting.Channels.Quic/CoreRemoting.Channels.Quic.csproj @@ -0,0 +1,33 @@ + + + + net9.0.0 + CoreRemoting.Channels.Quic + CoreRemoting.Channels.Quic + 1.2.1 + Alexey Yakovlev + Quic channels for CoreRemoting + 2024 Alexey Yakovlev + https://github.com/theRainbird/CoreRemoting + + true + CoreRemoting.Channels.Quic + https://github.com/theRainbird/CoreRemoting.git + git + 1.2.1 + 10 + + + + 1701;1702;CA1416 + + + + 1701;1702;1416 + + + + + + + diff --git a/CoreRemoting.Channels.Quic/QuicClientChannel.cs b/CoreRemoting.Channels.Quic/QuicClientChannel.cs new file mode 100644 index 0000000..6dc5520 --- /dev/null +++ b/CoreRemoting.Channels.Quic/QuicClientChannel.cs @@ -0,0 +1,223 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Net; +using System.Net.Quic; +using System.Net.Security; +using System.Text; +using System.Threading.Tasks; + +namespace CoreRemoting.Channels.Quic; + +/// +/// Client side QUIC channel implementation based on System.Net.Quic. +/// +public class QuicClientChannel : IClientChannel, IRawMessageTransport +{ + internal const int MaxMessageSize = 1024 * 1024 * 128; + internal const string ProtocolName = nameof(CoreRemoting); + + /// + /// Gets or sets the URL this channel is connected to. + /// + public string Url { get; private set; } + + private Uri Uri { get; set; } + + private IRemotingClient Client { get; set; } + + private QuicClientConnectionOptions Options { get; set; } + + private QuicConnection Connection { get; set; } + + private QuicStream ClientStream { get; set; } + + private BinaryReader ClientReader { get; set; } + + private BinaryWriter ClientWriter { get; set; } + + /// + public bool IsConnected { get; private set; } + + /// + public IRawMessageTransport RawMessageTransport => this; + + /// + public NetworkException LastException { get; set; } + + /// + /// Event: fires when the channel is connected. + /// + public event Action Connected; + + /// + public event Action Disconnected; + + /// + public event Action ReceiveMessage; + + /// + public event Action ErrorOccured; + + /// + public void Init(IRemotingClient client) + { + Client = client ?? throw new ArgumentNullException(nameof(client)); + if (!QuicConnection.IsSupported) + throw new NotSupportedException("QUIC is not supported."); + + Url = + "quic://" + + client.Config.ServerHostName + ":" + + Convert.ToString(client.Config.ServerPort) + + "/rpc"; + + Uri = new Uri(Url); + + // prepare QUIC client connection options + Options = new QuicClientConnectionOptions + { + RemoteEndPoint = new IPEndPoint(IPAddress.Loopback, Uri.Port), //new DnsEndPoint(Uri.Host, Uri.Port), + DefaultStreamErrorCode = 0x0A, + DefaultCloseErrorCode = 0x0B, + MaxInboundUnidirectionalStreams = 10, + MaxInboundBidirectionalStreams = 100, + ClientAuthenticationOptions = new SslClientAuthenticationOptions() + { + // accept self-signed certificates generated on-the-fly + RemoteCertificateValidationCallback = (sender, certificate, chain, errors) => true, + ApplicationProtocols = new List() + { + new SslApplicationProtocol(ProtocolName) + } + } + }; + } + + /// + public void Connect() + { + ConnectTask = ConnectTask ?? Task.Factory.StartNew(async () => + { + // connect and open duplex stream + Connection = await QuicConnection.ConnectAsync(Options); + ClientStream = await Connection.OpenOutboundStreamAsync(QuicStreamType.Bidirectional); + ClientReader = new BinaryReader(ClientStream, Encoding.UTF8, leaveOpen: true); + ClientWriter = new BinaryWriter(ClientStream, Encoding.UTF8, leaveOpen: true); + + // prepare handshake message + var handshakeMessage = Array.Empty(); + if (Client.MessageEncryption) + { + handshakeMessage = Client.PublicKey; + } + + // send handshake message + SendMessage(handshakeMessage); + IsConnected = true; + Connected?.Invoke(); + + // start listening for incoming messages + _ = Task.Factory.StartNew(() => StartListening()); + }); + + ConnectTask.ConfigureAwait(false) + .GetAwaiter() + .GetResult(); + } + + private Task ConnectTask { get; set; } + + private void StartListening() + { + try + { + while (IsConnected) + { + var messageSize = ClientReader.Read7BitEncodedInt(); + var message = ClientReader.ReadBytes(Math.Min(messageSize, MaxMessageSize)); + if (message.Length > 0) + { + ReceiveMessage(message); + } + } + } + catch (Exception ex) + { + LastException = ex as NetworkException ?? + new NetworkException(ex.Message, ex); + + ErrorOccured?.Invoke(ex.Message, ex); + Disconnected?.Invoke(); + } + finally + { + Disconnect(); + } + } + + /// + public bool SendMessage(byte[] rawMessage) + { + try + { + if (rawMessage.Length > MaxMessageSize) + throw new InvalidOperationException("Message is too large. Max size: " + + MaxMessageSize + ", actual size: " + rawMessage.Length); + + // message length + message body + ClientWriter.Write7BitEncodedInt(rawMessage.Length); + ClientWriter.Write(rawMessage, 0, rawMessage.Length); + return true; + } + catch (Exception ex) + { + LastException = ex as NetworkException ?? + new NetworkException(ex.Message, ex); + + ErrorOccured?.Invoke(ex.Message, ex); + return false; + } + } + + private Task DisconnectTask { get; set; } + + /// + public void Disconnect() + { + DisconnectTask = DisconnectTask ?? Task.Factory.StartNew(async () => + { + await Connection.CloseAsync(0x0C); + IsConnected = false; + Disconnected?.Invoke(); + }); + } + + /// + public void Dispose() + { + if (Connection == null) + return; + + if (IsConnected) + Disconnect(); + + var task = DisconnectTask; + if (task != null) + task.ConfigureAwait(false) + .GetAwaiter() + .GetResult(); + + Connection.DisposeAsync() + .ConfigureAwait(false) + .GetAwaiter() + .GetResult(); + Connection = null; + + // clean up readers/writers + ClientReader.Dispose(); + ClientReader = null; + ClientWriter.Dispose(); + ClientWriter = null; + } +} diff --git a/CoreRemoting.Channels.Quic/QuicServerChannel.cs b/CoreRemoting.Channels.Quic/QuicServerChannel.cs new file mode 100644 index 0000000..f7c0bc3 --- /dev/null +++ b/CoreRemoting.Channels.Quic/QuicServerChannel.cs @@ -0,0 +1,116 @@ +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Net; +using System.Net.Quic; +using System.Net.Security; +using System.Threading.Tasks; + +namespace CoreRemoting.Channels.Quic; + +/// +/// Server side QUIC channel implementation based on System.Net.Quic. +/// +public class QuicServerChannel : IServerChannel +{ + private IRemotingServer Server { get; set; } + + private ConcurrentDictionary Connections { get; } = + new ConcurrentDictionary(); + + /// + public bool IsListening { get; private set; } + + private QuicServerConnectionOptions Options { get; set; } + + private QuicListener Listener { get; set; } + + private IPEndPoint ListenEndPoint { get; set; } + + /// + public void Init(IRemotingServer server) + { + Server = server ?? throw new ArgumentNullException(nameof(server)); + if (!QuicListener.IsSupported) + throw new NotSupportedException("QUIC is not supported."); + + var url = "quic://" + + Server.Config.HostName + ":" + + Server.Config.NetworkPort + "/rpc"; + + // validate URL and create listener endpoint + var uri = new Uri(url); + var certificate = CertificateHelper.GenerateSelfSigned(uri.DnsSafeHost); + ListenEndPoint = new IPEndPoint(IPAddress.Loopback, uri.Port); // TODO: Loopback + + Options = new QuicServerConnectionOptions() + { + DefaultStreamErrorCode = 0x0A, + DefaultCloseErrorCode = 0x0B, + ServerAuthenticationOptions = new SslServerAuthenticationOptions + { + ServerCertificate = certificate, + ApplicationProtocols = new List() + { + new SslApplicationProtocol(QuicClientChannel.ProtocolName) + }, + } + }; + } + + /// + public void StartListening() + { + _ = Task.Factory.StartNew(async () => + { + // start the listener + Listener = await QuicListener.ListenAsync(new QuicListenerOptions() + { + ListenEndPoint = ListenEndPoint, + ConnectionOptionsCallback = (_, _, _) => ValueTask.FromResult(Options), + ApplicationProtocols = new List() + { + new SslApplicationProtocol(QuicClientChannel.ProtocolName) + }, + }); + + // accept incoming connections + IsListening = true; + while (IsListening) + { + try + { + var connection = await Listener.AcceptConnectionAsync(); + var stream = await connection.AcceptInboundStreamAsync(); + var session = new QuicServerConnection(connection, stream, Server); + var sessionId = session.StartListening(); + Connections[sessionId] = session; + } + catch + { + IsListening = false; // TODO: not sure?? + } + } + }); + } + + /// + public void StopListening() + { + if (Listener != null && IsListening) + { + IsListening = false; + Listener.DisposeAsync() + .ConfigureAwait(false) + .GetAwaiter() + .GetResult(); + } + } + + /// + public void Dispose() + { + StopListening(); + Listener = null; + } +} \ No newline at end of file diff --git a/CoreRemoting.Channels.Quic/QuicServerConnection.cs b/CoreRemoting.Channels.Quic/QuicServerConnection.cs new file mode 100644 index 0000000..10b5eb8 --- /dev/null +++ b/CoreRemoting.Channels.Quic/QuicServerConnection.cs @@ -0,0 +1,140 @@ +using System; +using System.IO; +using System.Net.Quic; +using System.Text; +using System.Threading.Tasks; + +namespace CoreRemoting.Channels.Quic; + +/// +/// Quic server-side connection. +/// +public class QuicServerConnection : IRawMessageTransport +{ + private const int MaxMessageSize = QuicClientChannel.MaxMessageSize; + + /// + /// Initializes a new instance of the class. + /// + public QuicServerConnection(QuicConnection connection, QuicStream stream, IRemotingServer remotingServer) + { + Connection = connection ?? throw new ArgumentNullException(nameof(connection)); + ClientStream = stream ?? throw new ArgumentNullException(nameof(stream)); + RemotingServer = remotingServer ?? throw new ArgumentNullException(nameof(remotingServer)); + ClientReader = new BinaryReader(stream, Encoding.UTF8, true); + ClientWriter = new BinaryWriter(stream, Encoding.UTF8, true); + } + + private QuicConnection Connection { get; set; } + + private QuicStream ClientStream { get; set; } + + private BinaryReader ClientReader { get; set; } + + private BinaryWriter ClientWriter { get; set; } + + private IRemotingServer RemotingServer { get; set; } + + private RemotingSession Session { get; set; } + + /// + public NetworkException LastException { get; set; } + + /// + public event Action ReceiveMessage; + + /// + public event Action ErrorOccured; + + /// + /// Event: fires when a web socket is disconnected. + /// + public event Action Disconnected; + + /// + public bool SendMessage(byte[] rawMessage) + { + try + { + if (rawMessage.Length > MaxMessageSize) + throw new InvalidOperationException("Message is too large. Max size: " + + MaxMessageSize + ", actual size: " + rawMessage.Length); + + // message length + message body + ClientWriter.Write7BitEncodedInt(rawMessage.Length); + ClientWriter.Write(rawMessage, 0, rawMessage.Length); + return true; + } + catch (Exception ex) + { + LastException = ex as NetworkException ?? + new NetworkException(ex.Message, ex); + + ErrorOccured?.Invoke(ex.Message, ex); + return false; + } + } + + /// + /// Starts listening to the incoming messages. + /// + public Guid StartListening() + { + var sessionId = CreateRemotingSession(); + _ = Task.Factory.StartNew(() => ReadIncomingMessages()); + return sessionId; + } + + /// + /// Creates for the incoming QUIC connection. + /// + private Guid CreateRemotingSession() + { + // read handshake message + var clientPublicKey = ReadIncomingMessage(); + + // disable message encryption if handshake is empty + if (clientPublicKey != null && clientPublicKey.Length == 0) + clientPublicKey = null; + + Session = RemotingServer.SessionRepository.CreateSession( + clientPublicKey, RemotingServer, this); + + return Session.SessionId; + } + + private byte[] ReadIncomingMessage() + { + var messageSize = ClientReader.Read7BitEncodedInt(); + return ClientReader.ReadBytes(Math.Min(messageSize, MaxMessageSize)); + } + + private void ReadIncomingMessages() + { + try + { + while (true) + { + var message = ReadIncomingMessage(); + if (message != null && message.Length > 0) + { + // flush received QUIC message + ReceiveMessage(message); + } + } + } + catch (Exception ex) + { + LastException = ex as NetworkException ?? + new NetworkException(ex.Message, ex); + + ErrorOccured?.Invoke(ex.Message, LastException); + Disconnected?.Invoke(); + } + finally + { + Connection?.DisposeAsync().ConfigureAwait(false).GetAwaiter().GetResult(); + Connection = null; + } + } +} From 6ad033015a9b390b35d670b05731cabf9bd3e14d Mon Sep 17 00:00:00 2001 From: yallie Date: Sat, 30 Nov 2024 17:37:51 +0300 Subject: [PATCH 2/2] Added a unit test for messages larger than 2 megabytes. --- CoreRemoting.Tests/RpcTests.cs | 49 ++++++++++++++++++++++++ CoreRemoting.Tests/Tools/ITestService.cs | 2 + CoreRemoting.Tests/Tools/TestService.cs | 19 +++++++++ 3 files changed, 70 insertions(+) diff --git a/CoreRemoting.Tests/RpcTests.cs b/CoreRemoting.Tests/RpcTests.cs index b3dddfc..12af61d 100644 --- a/CoreRemoting.Tests/RpcTests.cs +++ b/CoreRemoting.Tests/RpcTests.cs @@ -608,5 +608,54 @@ public void DataTable_roundtrip_works_issue60() var dt2 = proxy.TestDt(dt, 1); Assert.NotNull(dt2); } + + [Fact] + public void Large_messages_are_sent_and_received() + { + // max payload size, in bytes + var maxSize = 2 * 1024 * 1024 + 1; + + using var client = new RemotingClient(new ClientConfig() + { + ConnectionTimeout = 0, + InvocationTimeout = 0, + SendTimeout = 0, + Channel = ClientChannel, + MessageEncryption = false, + ServerPort = _serverFixture.Server.Config.NetworkPort, + }); + + client.Connect(); + var proxy = client.CreateProxy(); + + // shouldn't throw exceptions + Roundtrip("Payload", maxSize); + Roundtrip(new byte[] { 1, 2, 3, 4, 5 }, maxSize); + Roundtrip(new int[] { 12345, 67890 }, maxSize); + + void Roundtrip(T payload, int maxSize) where T : class + { + var lastSize = 0; + try + { + while (true) + { + // a -> aa -> aaaa ... + var dup = proxy.Duplicate(payload); + if (dup.size >= maxSize) + break; + + // save the size for error reporting + lastSize = dup.size; + payload = dup.duplicate; + } + } + catch (Exception ex) + { + throw new InvalidOperationException($"Failed to handle " + + $"payload larger than {lastSize}: {ex.Message}", ex); + } + } + } } } \ No newline at end of file diff --git a/CoreRemoting.Tests/Tools/ITestService.cs b/CoreRemoting.Tests/Tools/ITestService.cs index 7977a95..7685f04 100644 --- a/CoreRemoting.Tests/Tools/ITestService.cs +++ b/CoreRemoting.Tests/Tools/ITestService.cs @@ -38,5 +38,7 @@ public interface ITestService : IBaseService void NonSerializableError(string text, params object[] data); DataTable TestDt(DataTable dt, long num); + + (T duplicate, int size) Duplicate(T sample) where T : class; } } \ No newline at end of file diff --git a/CoreRemoting.Tests/Tools/TestService.cs b/CoreRemoting.Tests/Tools/TestService.cs index 1cc271d..6bdf221 100644 --- a/CoreRemoting.Tests/Tools/TestService.cs +++ b/CoreRemoting.Tests/Tools/TestService.cs @@ -99,5 +99,24 @@ public DataTable TestDt(DataTable dt, long num) dt.Rows.Clear(); return dt; } + + public (T, int) Duplicate(T sample) where T : class + { + return sample switch + { + byte[] arr => (Dup(arr) as T, arr.Length * 2), + int[] iarr => (Dup(iarr) as T, iarr.Length * 2 * sizeof(int)), + string str => ((str + str) as T, str.Length * 2 * sizeof(char)), + _ => throw new ArgumentOutOfRangeException(), + }; + + TItem[] Dup(TItem[] arr) + { + var length = arr.Length; + Array.Resize(ref arr, length * 2); + Array.Copy(arr, 0, arr, length, length); + return arr; + } + } } } \ No newline at end of file