diff --git a/libsignal-protocol-dotnet-tests/SessionBuilderTest.cs b/libsignal-protocol-dotnet-tests/SessionBuilderTest.cs index 4ab868e..fbad9ff 100644 --- a/libsignal-protocol-dotnet-tests/SessionBuilderTest.cs +++ b/libsignal-protocol-dotnet-tests/SessionBuilderTest.cs @@ -1,4 +1,4 @@ -/** +/** * Copyright (C) 2016 langboost * * This program is free software: you can redistribute it and/or modify @@ -25,6 +25,7 @@ using System; using System.Collections.Generic; using System.Text; +using System.Threading.Tasks; namespace libsignal_test { @@ -46,10 +47,11 @@ public BobDecryptionCallback(SignalProtocolStore bobStore, String originalMessag this.originalMessage = originalMessage; } - public void handlePlaintext(byte[] plaintext) + public Task handlePlaintext(byte[] plaintext, uint sessionVersion) { Assert.AreEqual(originalMessage, Encoding.UTF8.GetString(plaintext)); Assert.IsFalse(bobStore.ContainsSession(ALICE_ADDRESS)); + return Task.CompletedTask; } } @@ -87,7 +89,7 @@ public void testBasicPreKeyV3() bobStore.StoreSignedPreKey(22, new SignedPreKeyRecord(22, DateUtil.currentTimeMillis(), bobSignedPreKeyPair, bobSignedPreKeySignature)); SessionCipher bobSessionCipher = new SessionCipher(bobStore, ALICE_ADDRESS); - byte[] plaintext = bobSessionCipher.decrypt(incomingMessage, new BobDecryptionCallback(bobStore, originalMessage)); + byte[] plaintext = bobSessionCipher.decrypt(incomingMessage, new BobDecryptionCallback(bobStore, originalMessage)).Result; Assert.IsTrue(bobStore.ContainsSession(ALICE_ADDRESS)); Assert.AreEqual((uint)3, bobStore.LoadSession(ALICE_ADDRESS).getSessionState().getSessionVersion()); diff --git a/libsignal-protocol-dotnet/DecryptionCallback.cs b/libsignal-protocol-dotnet/DecryptionCallback.cs index b72b20d..8c3fce0 100644 --- a/libsignal-protocol-dotnet/DecryptionCallback.cs +++ b/libsignal-protocol-dotnet/DecryptionCallback.cs @@ -1,24 +1,25 @@ -/** - * Copyright (C) 2016 smndtrl, langboost - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - */ - +using libsignal.state; +using System.Threading.Tasks; +/** +* Copyright (C) 2016 smndtrl, langboost +* +* This program is free software: you can redistribute it and/or modify +* it under the terms of the GNU General Public License as published by +* the Free Software Foundation, either version 3 of the License, or +* (at your option) any later version. +* +* This program is distributed in the hope that it will be useful, +* but WITHOUT ANY WARRANTY; without even the implied warranty of +* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +* GNU General Public License for more details. +* +* You should have received a copy of the GNU General Public License +* along with this program. If not, see . +*/ namespace libsignal { public interface DecryptionCallback { - void handlePlaintext(byte[] plaintext); + Task handlePlaintext(byte[] plaintext, uint sessionVersion); } } diff --git a/libsignal-protocol-dotnet/SessionCipher.cs b/libsignal-protocol-dotnet/SessionCipher.cs index e269093..8a66849 100644 --- a/libsignal-protocol-dotnet/SessionCipher.cs +++ b/libsignal-protocol-dotnet/SessionCipher.cs @@ -7,6 +7,7 @@ using libsignal.state; using libsignal.util; using Strilanc.Value; +using System.Threading.Tasks; namespace libsignal { @@ -117,7 +118,9 @@ public CiphertextMessage encrypt(byte[] paddedMessage) /// when the of the sender is untrusted. public byte[] decrypt(PreKeySignalMessage ciphertext) { - return decrypt(ciphertext, new NullDecryptionCallback()); + var tsk = (decrypt(ciphertext, new NullDecryptionCallback())); + tsk.Wait(); + return tsk.Result; } /// @@ -137,7 +140,7 @@ public byte[] decrypt(PreKeySignalMessage ciphertext) /// /// when the message is formatted incorrectly. /// when the of the sender is untrusted. - public byte[] decrypt(PreKeySignalMessage ciphertext, DecryptionCallback callback) + public Task decrypt(PreKeySignalMessage ciphertext, DecryptionCallback callback) { lock (SESSION_LOCK) { @@ -147,7 +150,7 @@ public byte[] decrypt(PreKeySignalMessage ciphertext, DecryptionCallback callbac identityKeyStore.SaveIdentity(remoteAddress, sessionRecord.getSessionState().getRemoteIdentityKey()); - callback.handlePlaintext(plaintext); + callback.handlePlaintext(plaintext, sessionRecord.getSessionState().getSessionVersion()).Wait(); sessionStore.StoreSession(remoteAddress, sessionRecord); @@ -156,7 +159,7 @@ public byte[] decrypt(PreKeySignalMessage ciphertext, DecryptionCallback callbac preKeyStore.RemovePreKey(unsignedPreKeyId.ForceGetValue()); } - return plaintext; + return Task.FromResult(plaintext); } } @@ -172,7 +175,9 @@ public byte[] decrypt(PreKeySignalMessage ciphertext, DecryptionCallback callbac /// if there is no established session for this contact. public byte[] decrypt(SignalMessage ciphertext) { - return decrypt(ciphertext, new NullDecryptionCallback()); + var tsk = decrypt(ciphertext, new NullDecryptionCallback()); + tsk.Wait(); + return tsk.Result; } /// @@ -189,7 +194,7 @@ public byte[] decrypt(SignalMessage ciphertext) /// if the input is a message formatted by a protocol version that is /// no longer supported. /// if there is no established session for this contact. - public byte[] decrypt(SignalMessage ciphertext, DecryptionCallback callback) + public Task decrypt(SignalMessage ciphertext, DecryptionCallback callback) { lock (SESSION_LOCK) { @@ -207,11 +212,11 @@ public byte[] decrypt(SignalMessage ciphertext, DecryptionCallback callback) throw new UntrustedIdentityException(remoteAddress.Name, sessionRecord.getSessionState().getRemoteIdentityKey()); } - callback.handlePlaintext(plaintext); + callback.handlePlaintext(plaintext, sessionRecord.getSessionState().getSessionVersion()).Wait();//no async in a lock sessionStore.StoreSession(remoteAddress, sessionRecord); - return plaintext; + return Task.FromResult(plaintext); } } @@ -391,7 +396,7 @@ private byte[] getPlaintext(MessageKeys messageKeys, byte[] cipherText) private class NullDecryptionCallback : DecryptionCallback { - public void handlePlaintext(byte[] plaintext) { } + public Task handlePlaintext(byte[] plaintext, uint sessionVersion) => Task.CompletedTask; } } } diff --git a/libsignal-protocol-dotnet/groups/GroupCipher.cs b/libsignal-protocol-dotnet/groups/GroupCipher.cs index 4b5d693..4b39221 100644 --- a/libsignal-protocol-dotnet/groups/GroupCipher.cs +++ b/libsignal-protocol-dotnet/groups/GroupCipher.cs @@ -19,7 +19,9 @@ using libsignal.groups.ratchet; using libsignal.groups.state; using libsignal.protocol; +using libsignal.state; using libsignal.util; +using System.Threading.Tasks; namespace libsignal.groups { @@ -127,7 +129,7 @@ public byte[] decrypt(byte[] senderKeyMessageBytes, DecryptionCallback callback) byte[] plaintext = getPlainText(senderKey.getIv(), senderKey.getCipherKey(), senderKeyMessage.getCipherText()); - callback.handlePlaintext(plaintext); + callback.handlePlaintext(plaintext, 931).Wait();//931 random as we don't have the real version senderKeyStore.storeSenderKey(senderKeyId, record); @@ -210,7 +212,8 @@ private byte[] getCipherText(byte[] iv, byte[] key, byte[] plaintext) private class NullDecryptionCallback : DecryptionCallback { - public void handlePlaintext(byte[] plaintext) { } - } + public Task handlePlaintext(byte[] plaintext, uint sessionVersion) => Task.CompletedTask; + + } } }