From fb8b97a45548a55a80003f1ca88e253ce0bc44a5 Mon Sep 17 00:00:00 2001 From: yallie Date: Wed, 4 Dec 2024 02:27:30 +0300 Subject: [PATCH] RemotingClient should handle undeserializable remote exceptions. --- CoreRemoting.Tests/Tools/FailingService.cs | 3 -- .../Websocket/WebsocketServerChannel.cs | 6 +-- ...ection.cs => WebsocketServerConnection.cs} | 6 +-- CoreRemoting/RemotingClient.cs | 54 +++++++++++++++---- 4 files changed, 50 insertions(+), 19 deletions(-) rename CoreRemoting/Channels/Websocket/{WebsocketConnection.cs => WebsocketServerConnection.cs} (94%) diff --git a/CoreRemoting.Tests/Tools/FailingService.cs b/CoreRemoting.Tests/Tools/FailingService.cs index dc9bb7e..08e68a2 100644 --- a/CoreRemoting.Tests/Tools/FailingService.cs +++ b/CoreRemoting.Tests/Tools/FailingService.cs @@ -1,7 +1,4 @@ using System; -using System.Data; -using System.Threading.Tasks; -using CoreRemoting.Tests.ExternalTypes; namespace CoreRemoting.Tests.Tools; diff --git a/CoreRemoting/Channels/Websocket/WebsocketServerChannel.cs b/CoreRemoting/Channels/Websocket/WebsocketServerChannel.cs index 2ceaa18..8ffafe7 100644 --- a/CoreRemoting/Channels/Websocket/WebsocketServerChannel.cs +++ b/CoreRemoting/Channels/Websocket/WebsocketServerChannel.cs @@ -14,8 +14,8 @@ public class WebsocketServerChannel : IServerChannel private IRemotingServer Server { get; set; } - private ConcurrentDictionary Connections { get; } = - new ConcurrentDictionary(); + private ConcurrentDictionary Connections { get; } = + new ConcurrentDictionary(); /// public bool IsListening => HttpListener.IsListening; @@ -57,7 +57,7 @@ private async Task ReceiveConnection() // accept websocket request and start a new session var websocketContext = await context.AcceptWebSocketAsync(null); var websocket = websocketContext.WebSocket; - var connection = new WebsocketConnection(websocketContext, websocket, Server); + var connection = new WebsocketServerConnection(websocketContext, websocket, Server); // handle incoming websocket messages var sessionId = connection.StartListening(); diff --git a/CoreRemoting/Channels/Websocket/WebsocketConnection.cs b/CoreRemoting/Channels/Websocket/WebsocketServerConnection.cs similarity index 94% rename from CoreRemoting/Channels/Websocket/WebsocketConnection.cs rename to CoreRemoting/Channels/Websocket/WebsocketServerConnection.cs index 4143ffc..4e97716 100644 --- a/CoreRemoting/Channels/Websocket/WebsocketConnection.cs +++ b/CoreRemoting/Channels/Websocket/WebsocketServerConnection.cs @@ -9,15 +9,15 @@ namespace CoreRemoting.Channels.Websocket; /// /// Websocket connection. /// -public class WebsocketConnection : IRawMessageTransport +public class WebsocketServerConnection : IRawMessageTransport { // note: LOH threshold is ~85 kilobytes private const int BufferSize = 16 * 1024; /// - /// Initializes a new instance of the class. + /// Initializes a new instance of the class. /// - public WebsocketConnection(HttpListenerWebSocketContext websocketContext, WebSocket websocket, IRemotingServer remotingServer) + public WebsocketServerConnection(HttpListenerWebSocketContext websocketContext, WebSocket websocket, IRemotingServer remotingServer) { WebSocketContext = websocketContext ?? throw new ArgumentNullException(nameof(websocketContext)); WebSocket = websocket ?? throw new ArgumentNullException(nameof(websocket)); diff --git a/CoreRemoting/RemotingClient.cs b/CoreRemoting/RemotingClient.cs index 37c427c..808462f 100644 --- a/CoreRemoting/RemotingClient.cs +++ b/CoreRemoting/RemotingClient.cs @@ -434,7 +434,7 @@ private void Authenticate() /// Raw message data private void OnMessage(byte[] rawMessage) { - var message = Serializer.Deserialize(rawMessage); + var message = TryDeserialize(rawMessage); switch (message.MessageType.ToLower()) { @@ -456,6 +456,29 @@ private void OnMessage(byte[] rawMessage) case "session_closed": Disconnect(quiet: true); break; + default: + // TODO: how do we handle invalid wire messages received by the client? + // A wire message could have been tampered with and couldn't be deserialized + break; + } + } + + private WireMessage TryDeserialize(byte[] rawMessage) + { + try + { + return Serializer.Deserialize(rawMessage); + } + catch // TODO: dispatch message deserialization exception? + { + return new WireMessage + { + Data = rawMessage, + Error = true, + Iv = Array.Empty(), + MessageType = "invalid", + UniqueCallKey = Array.Empty(), + }; } } @@ -585,16 +608,27 @@ private void ProcessRpcResultMessage(WireMessage message) if (message.Error) { - var remoteException = - Serializer.Deserialize( - MessageEncryptionManager.GetDecryptedMessageData( - message: message, - serializer: Serializer, - sharedSecret: sharedSecret, - sendersPublicKeyBlob: _serverPublicKeyBlob, - sendersPublicKeySize: _keyPair?.KeySize ?? 0)); + try + { + var remoteException = + Serializer.Deserialize( + MessageEncryptionManager.GetDecryptedMessageData( + message: message, + serializer: Serializer, + sharedSecret: sharedSecret, + sendersPublicKeyBlob: _serverPublicKeyBlob, + sendersPublicKeySize: _keyPair?.KeySize ?? 0)); + + clientRpcContext.RemoteException = remoteException; + } + catch (Exception deserializationException) + { + var remoteException = new RemoteInvocationException( + "Remote exception couldn't be deserialized", + deserializationException); - clientRpcContext.RemoteException = remoteException; + clientRpcContext.RemoteException = remoteException; + } } else {