Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Failing constructor should throw a RemoteInvocationException #86

Merged
merged 3 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/dotnet.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ jobs:
build:

runs-on: ubuntu-latest
timeout-minutes: 5

steps:
- name: Checkout source
Expand Down
22 changes: 22 additions & 0 deletions CoreRemoting.Tests/RpcTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,28 @@ public void NonSerializableError_method_throws_Exception()
}
}

[Fact]
public void Failing_component_constructor_throws_RemoteInvocationException()
{
using var client = new RemotingClient(new ClientConfig()
{
ConnectionTimeout = 3,
InvocationTimeout = 3,
SendTimeout = 3,
Channel = ClientChannel,
MessageEncryption = false,
ServerPort = _serverFixture.Server.Config.NetworkPort,
});

client.Connect();

var proxy = client.CreateProxy<IFailingService>();
var ex = Assert.Throws<RemoteInvocationException>(() => proxy.Hello());

Assert.NotNull(ex);
Assert.Contains("FailingService", ex.Message);
}

[Fact]
public async Task Disposed_client_subscription_doesnt_break_other_clients()
{
Expand Down
4 changes: 4 additions & 0 deletions CoreRemoting.Tests/ServerFixture.cs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ public ServerFixture()
// Service for enum tests
container.RegisterService<IEnumTestService, EnumTestService>(
lifetime: ServiceLifetime.Singleton);

// Service for session tests
container.RegisterService<IFailingService, FailingService>(
lifetime: ServiceLifetime.SingleCall);
}
};
}
Expand Down
16 changes: 8 additions & 8 deletions CoreRemoting.Tests/SessionTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ public SessionTests(ServerFixture serverFixture)
_serverFixture = serverFixture;
_serverFixture.Start();
}

[Fact]
public void Client_Connect_should_create_new_session_AND_Disconnect_should_close_session()
{
Expand Down Expand Up @@ -54,7 +54,7 @@ public void Client_Connect_should_create_new_session_AND_Disconnect_should_close

client.Dispose();
});

var clientThread1 = new Thread(() => clientAction(0));
var clientThread2 = new Thread(() => clientAction(1));

Expand Down Expand Up @@ -94,15 +94,15 @@ public void Client_Connect_should_throw_exception_on_invalid_auth_credentials()
AuthenticateFake = credentials => credentials[1].Value == "secret"
}
};

var server = new RemotingServer(serverConfig);
server.Start();

try
{
var clientAction = new Action<string, bool>((password, shouldThrow) =>
{
using var client =
using var client =
new RemotingClient(new ClientConfig()
{
ConnectionTimeout = 0,
Expand All @@ -114,7 +114,7 @@ public void Client_Connect_should_throw_exception_on_invalid_auth_credentials()
new Credential() {Name = "Password", Value = password }
}
});

if (shouldThrow)
Assert.Throws<SecurityException>(() => client.Connect());
else
Expand All @@ -124,7 +124,7 @@ public void Client_Connect_should_throw_exception_on_invalid_auth_credentials()
var clientThread1 = new Thread(() => clientAction("wrong", true));
clientThread1.Start();
clientThread1.Join();

var clientThread2 = new Thread(() => clientAction("secret", false));
clientThread2.Start();
clientThread2.Join();
Expand Down Expand Up @@ -161,15 +161,15 @@ public void RemotingSession_Dispose_should_disconnect_client()
{
waitForDisconnect.Set();
};

client.Connect();
var proxy = client.CreateProxy<ITestService>();

proxy.TestMethod(null);

waitForDisconnect.Wait();
Assert.False(client.IsConnected);

client.Dispose();
}
}
Expand Down
15 changes: 15 additions & 0 deletions CoreRemoting.Tests/Tools/FailingService.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
using System;

namespace CoreRemoting.Tests.Tools;

public class FailingService : IFailingService
{
public FailingService()
{
throw new NotImplementedException();
}

public void Hello()
{
}
}
11 changes: 11 additions & 0 deletions CoreRemoting.Tests/Tools/IFailingService.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
using System;
using System.Data;
using System.Threading.Tasks;
using CoreRemoting.Tests.ExternalTypes;

namespace CoreRemoting.Tests.Tools;

public interface IFailingService
{
void Hello();
}
36 changes: 0 additions & 36 deletions CoreRemoting.sln.DotSettings.user

This file was deleted.

6 changes: 3 additions & 3 deletions CoreRemoting/Channels/Websocket/WebsocketServerChannel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ public class WebsocketServerChannel : IServerChannel

private IRemotingServer Server { get; set; }

private ConcurrentDictionary<Guid, WebsocketConnection> Connections { get; } =
new ConcurrentDictionary<Guid, WebsocketConnection>();
private ConcurrentDictionary<Guid, WebsocketServerConnection> Connections { get; } =
new ConcurrentDictionary<Guid, WebsocketServerConnection>();

/// <inheritdoc/>
public bool IsListening => HttpListener.IsListening;
Expand Down Expand Up @@ -57,7 +57,7 @@ private async Task ReceiveConnection()
// accept websocket request and start a new session
var websocketContext = await context.AcceptWebSocketAsync(null);
var websocket = websocketContext.WebSocket;
var connection = new WebsocketConnection(websocketContext, websocket, Server);
var connection = new WebsocketServerConnection(websocketContext, websocket, Server);

// handle incoming websocket messages
var sessionId = connection.StartListening();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@ namespace CoreRemoting.Channels.Websocket;
/// <summary>
/// Websocket connection.
/// </summary>
public class WebsocketConnection : IRawMessageTransport
public class WebsocketServerConnection : IRawMessageTransport
{
// note: LOH threshold is ~85 kilobytes
private const int BufferSize = 16 * 1024;

/// <summary>
/// Initializes a new instance of the <see cref="WebsocketConnection"/> class.
/// Initializes a new instance of the <see cref="WebsocketServerConnection"/> class.
/// </summary>
public WebsocketConnection(HttpListenerWebSocketContext websocketContext, WebSocket websocket, IRemotingServer remotingServer)
public WebsocketServerConnection(HttpListenerWebSocketContext websocketContext, WebSocket websocket, IRemotingServer remotingServer)
{
WebSocketContext = websocketContext ?? throw new ArgumentNullException(nameof(websocketContext));
WebSocket = websocket ?? throw new ArgumentNullException(nameof(websocket));
Expand Down
54 changes: 44 additions & 10 deletions CoreRemoting/RemotingClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ private void Authenticate()
/// <param name="rawMessage">Raw message data</param>
private void OnMessage(byte[] rawMessage)
{
var message = Serializer.Deserialize<WireMessage>(rawMessage);
var message = TryDeserialize(rawMessage);

switch (message.MessageType.ToLower())
{
Expand All @@ -456,6 +456,29 @@ private void OnMessage(byte[] rawMessage)
case "session_closed":
Disconnect(quiet: true);
break;
default:
// TODO: how do we handle invalid wire messages received by the client?
// A wire message could have been tampered with and couldn't be deserialized
break;
}
}

private WireMessage TryDeserialize(byte[] rawMessage)
{
try
{
return Serializer.Deserialize<WireMessage>(rawMessage);
}
catch // TODO: dispatch message deserialization exception?
{
return new WireMessage
{
Data = rawMessage,
Error = true,
Iv = Array.Empty<byte>(),
MessageType = "invalid",
UniqueCallKey = Array.Empty<byte>(),
};
}
}

Expand Down Expand Up @@ -585,16 +608,27 @@ private void ProcessRpcResultMessage(WireMessage message)

if (message.Error)
{
var remoteException =
Serializer.Deserialize<RemoteInvocationException>(
MessageEncryptionManager.GetDecryptedMessageData(
message: message,
serializer: Serializer,
sharedSecret: sharedSecret,
sendersPublicKeyBlob: _serverPublicKeyBlob,
sendersPublicKeySize: _keyPair?.KeySize ?? 0));
try
{
var remoteException =
Serializer.Deserialize<RemoteInvocationException>(
MessageEncryptionManager.GetDecryptedMessageData(
message: message,
serializer: Serializer,
sharedSecret: sharedSecret,
sendersPublicKeyBlob: _serverPublicKeyBlob,
sendersPublicKeySize: _keyPair?.KeySize ?? 0));

clientRpcContext.RemoteException = remoteException;
}
catch (Exception deserializationException)
{
var remoteException = new RemoteInvocationException(
"Remote exception couldn't be deserialized",
deserializationException);

clientRpcContext.RemoteException = remoteException;
clientRpcContext.RemoteException = remoteException;
}
}
else
{
Expand Down
2 changes: 1 addition & 1 deletion CoreRemoting/RemotingSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ private void ProcessRpcMessage(WireMessage request)
serverRpcContext.Exception =
new RemoteInvocationException(
message: ex.Message,
innerEx: ex.GetType().IsSerializable ? ex : null);
innerEx: ex.ToSerializable());

if (oneWay)
return;
Expand Down
6 changes: 5 additions & 1 deletion CoreRemoting/Serialization/ExceptionExtensions.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System;
using Castle.MicroKernel.ComponentActivator;
using System;
using System.Linq;

namespace CoreRemoting.Serialization;
Expand All @@ -20,6 +21,9 @@ public static class ExceptionExtensions
agg.InnerException.IsSerializable() &&
agg.GetType().IsSerializable,

// pesky exception that looks like serializable but really isn't
ComponentActivatorException cax => false,

_ => ex.GetType().IsSerializable &&
ex.InnerException.IsSerializable()
};
Expand Down
Loading