Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved callback support added Task support and passing SessionRecord #5

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions libsignal-protocol-dotnet-tests/SessionBuilderTest.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/**
/**
* Copyright (C) 2016 langboost
*
* This program is free software: you can redistribute it and/or modify
Expand All @@ -25,6 +25,7 @@
using System;
using System.Collections.Generic;
using System.Text;
using System.Threading.Tasks;

namespace libsignal_test
{
Expand All @@ -46,10 +47,11 @@ public BobDecryptionCallback(SignalProtocolStore bobStore, String originalMessag
this.originalMessage = originalMessage;
}

public void handlePlaintext(byte[] plaintext)
public Task handlePlaintext(byte[] plaintext, uint sessionVersion)
{
Assert.AreEqual(originalMessage, Encoding.UTF8.GetString(plaintext));
Assert.IsFalse(bobStore.ContainsSession(ALICE_ADDRESS));
return Task.CompletedTask;
}
}

Expand Down Expand Up @@ -87,7 +89,7 @@ public void testBasicPreKeyV3()
bobStore.StoreSignedPreKey(22, new SignedPreKeyRecord(22, DateUtil.currentTimeMillis(), bobSignedPreKeyPair, bobSignedPreKeySignature));

SessionCipher bobSessionCipher = new SessionCipher(bobStore, ALICE_ADDRESS);
byte[] plaintext = bobSessionCipher.decrypt(incomingMessage, new BobDecryptionCallback(bobStore, originalMessage));
byte[] plaintext = bobSessionCipher.decrypt(incomingMessage, new BobDecryptionCallback(bobStore, originalMessage)).Result;

Assert.IsTrue(bobStore.ContainsSession(ALICE_ADDRESS));
Assert.AreEqual((uint)3, bobStore.LoadSession(ALICE_ADDRESS).getSessionState().getSessionVersion());
Expand Down
37 changes: 19 additions & 18 deletions libsignal-protocol-dotnet/DecryptionCallback.cs
Original file line number Diff line number Diff line change
@@ -1,24 +1,25 @@
/**
* Copyright (C) 2016 smndtrl, langboost
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

using libsignal.state;
using System.Threading.Tasks;
/**
* Copyright (C) 2016 smndtrl, langboost
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
namespace libsignal
{
public interface DecryptionCallback
{
void handlePlaintext(byte[] plaintext);
Task handlePlaintext(byte[] plaintext, uint sessionVersion);
}
}
23 changes: 14 additions & 9 deletions libsignal-protocol-dotnet/SessionCipher.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using libsignal.state;
using libsignal.util;
using Strilanc.Value;
using System.Threading.Tasks;

namespace libsignal
{
Expand Down Expand Up @@ -117,7 +118,9 @@ public CiphertextMessage encrypt(byte[] paddedMessage)
/// <exception cref="UntrustedIdentityException">when the <see cref="IdentityKey"/> of the sender is untrusted.</exception>
public byte[] decrypt(PreKeySignalMessage ciphertext)
{
return decrypt(ciphertext, new NullDecryptionCallback());
var tsk = (decrypt(ciphertext, new NullDecryptionCallback()));
tsk.Wait();
return tsk.Result;
}

/// <summary>
Expand All @@ -137,7 +140,7 @@ public byte[] decrypt(PreKeySignalMessage ciphertext)
///
/// <exception cref="InvalidKeyException">when the message is formatted incorrectly.</exception>
/// <exception cref="UntrustedIdentityException">when the <see cref="IdentityKey"/> of the sender is untrusted.</exception>
public byte[] decrypt(PreKeySignalMessage ciphertext, DecryptionCallback callback)
public Task<byte[]> decrypt(PreKeySignalMessage ciphertext, DecryptionCallback callback)
{
lock (SESSION_LOCK)
{
Expand All @@ -147,7 +150,7 @@ public byte[] decrypt(PreKeySignalMessage ciphertext, DecryptionCallback callbac

identityKeyStore.SaveIdentity(remoteAddress, sessionRecord.getSessionState().getRemoteIdentityKey());

callback.handlePlaintext(plaintext);
callback.handlePlaintext(plaintext, sessionRecord.getSessionState().getSessionVersion()).Wait();

sessionStore.StoreSession(remoteAddress, sessionRecord);

Expand All @@ -156,7 +159,7 @@ public byte[] decrypt(PreKeySignalMessage ciphertext, DecryptionCallback callbac
preKeyStore.RemovePreKey(unsignedPreKeyId.ForceGetValue());
}

return plaintext;
return Task.FromResult(plaintext);
}
}

Expand All @@ -172,7 +175,9 @@ public byte[] decrypt(PreKeySignalMessage ciphertext, DecryptionCallback callbac
/// <exception cref="NoSessionException">if there is no established session for this contact.</exception>
public byte[] decrypt(SignalMessage ciphertext)
{
return decrypt(ciphertext, new NullDecryptionCallback());
var tsk = decrypt(ciphertext, new NullDecryptionCallback());
tsk.Wait();
return tsk.Result;
}

/// <summary>
Expand All @@ -189,7 +194,7 @@ public byte[] decrypt(SignalMessage ciphertext)
/// <exception cref="LegacyMessageException">if the input is a message formatted by a protocol version that is
/// no longer supported.</exception>
/// <exception cref="NoSessionException">if there is no established session for this contact.</exception>
public byte[] decrypt(SignalMessage ciphertext, DecryptionCallback callback)
public Task<byte[]> decrypt(SignalMessage ciphertext, DecryptionCallback callback)
{
lock (SESSION_LOCK)
{
Expand All @@ -207,11 +212,11 @@ public byte[] decrypt(SignalMessage ciphertext, DecryptionCallback callback)
throw new UntrustedIdentityException(remoteAddress.Name, sessionRecord.getSessionState().getRemoteIdentityKey());
}

callback.handlePlaintext(plaintext);
callback.handlePlaintext(plaintext, sessionRecord.getSessionState().getSessionVersion()).Wait();//no async in a lock

sessionStore.StoreSession(remoteAddress, sessionRecord);

return plaintext;
return Task.FromResult(plaintext);
}
}

Expand Down Expand Up @@ -391,7 +396,7 @@ private byte[] getPlaintext(MessageKeys messageKeys, byte[] cipherText)
private class NullDecryptionCallback : DecryptionCallback
{

public void handlePlaintext(byte[] plaintext) { }
public Task handlePlaintext(byte[] plaintext, uint sessionVersion) => Task.CompletedTask;
}
}
}
9 changes: 6 additions & 3 deletions libsignal-protocol-dotnet/groups/GroupCipher.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
using libsignal.groups.ratchet;
using libsignal.groups.state;
using libsignal.protocol;
using libsignal.state;
using libsignal.util;
using System.Threading.Tasks;

namespace libsignal.groups
{
Expand Down Expand Up @@ -127,7 +129,7 @@ public byte[] decrypt(byte[] senderKeyMessageBytes, DecryptionCallback callback)

byte[] plaintext = getPlainText(senderKey.getIv(), senderKey.getCipherKey(), senderKeyMessage.getCipherText());

callback.handlePlaintext(plaintext);
callback.handlePlaintext(plaintext, 931).Wait();//931 random as we don't have the real version

senderKeyStore.storeSenderKey(senderKeyId, record);

Expand Down Expand Up @@ -210,7 +212,8 @@ private byte[] getCipherText(byte[] iv, byte[] key, byte[] plaintext)

private class NullDecryptionCallback : DecryptionCallback
{
public void handlePlaintext(byte[] plaintext) { }
}
public Task handlePlaintext(byte[] plaintext, uint sessionVersion) => Task.CompletedTask;

}
}
}