From 5f21f6ccd67894cb60fea784e28826b0cd701788 Mon Sep 17 00:00:00 2001 From: Mitch Capper Date: Sun, 14 Feb 2021 01:50:16 -0800 Subject: [PATCH] Squashed commit of the following: Improved callback support added Task support and passing SessionRecord This modifies the existing callback support to work with Task based callbacks. It also passes the session version (when available) to the callback (as not yet committed for prekey), needed for protocol version check code that already exists. --- .../SessionBuilderTest.cs | 8 ++-- .../DecryptionCallback.cs | 37 ++++++++++--------- libsignal-protocol-dotnet/SessionCipher.cs | 23 +++++++----- .../groups/GroupCipher.cs | 9 +++-- 4 files changed, 44 insertions(+), 33 deletions(-) diff --git a/libsignal-protocol-dotnet-tests/SessionBuilderTest.cs b/libsignal-protocol-dotnet-tests/SessionBuilderTest.cs index 4ab868e..fbad9ff 100644 --- a/libsignal-protocol-dotnet-tests/SessionBuilderTest.cs +++ b/libsignal-protocol-dotnet-tests/SessionBuilderTest.cs @@ -1,4 +1,4 @@ -/** +/** * Copyright (C) 2016 langboost * * This program is free software: you can redistribute it and/or modify @@ -25,6 +25,7 @@ using System; using System.Collections.Generic; using System.Text; +using System.Threading.Tasks; namespace libsignal_test { @@ -46,10 +47,11 @@ public BobDecryptionCallback(SignalProtocolStore bobStore, String originalMessag this.originalMessage = originalMessage; } - public void handlePlaintext(byte[] plaintext) + public Task handlePlaintext(byte[] plaintext, uint sessionVersion) { Assert.AreEqual(originalMessage, Encoding.UTF8.GetString(plaintext)); Assert.IsFalse(bobStore.ContainsSession(ALICE_ADDRESS)); + return Task.CompletedTask; } } @@ -87,7 +89,7 @@ public void testBasicPreKeyV3() bobStore.StoreSignedPreKey(22, new SignedPreKeyRecord(22, DateUtil.currentTimeMillis(), bobSignedPreKeyPair, bobSignedPreKeySignature)); SessionCipher bobSessionCipher = new SessionCipher(bobStore, ALICE_ADDRESS); - byte[] plaintext = bobSessionCipher.decrypt(incomingMessage, new BobDecryptionCallback(bobStore, originalMessage)); + byte[] plaintext = bobSessionCipher.decrypt(incomingMessage, new BobDecryptionCallback(bobStore, originalMessage)).Result; Assert.IsTrue(bobStore.ContainsSession(ALICE_ADDRESS)); Assert.AreEqual((uint)3, bobStore.LoadSession(ALICE_ADDRESS).getSessionState().getSessionVersion()); diff --git a/libsignal-protocol-dotnet/DecryptionCallback.cs b/libsignal-protocol-dotnet/DecryptionCallback.cs index b72b20d..8c3fce0 100644 --- a/libsignal-protocol-dotnet/DecryptionCallback.cs +++ b/libsignal-protocol-dotnet/DecryptionCallback.cs @@ -1,24 +1,25 @@ -/** - * Copyright (C) 2016 smndtrl, langboost - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - */ - +using libsignal.state; +using System.Threading.Tasks; +/** +* Copyright (C) 2016 smndtrl, langboost +* +* This program is free software: you can redistribute it and/or modify +* it under the terms of the GNU General Public License as published by +* the Free Software Foundation, either version 3 of the License, or +* (at your option) any later version. +* +* This program is distributed in the hope that it will be useful, +* but WITHOUT ANY WARRANTY; without even the implied warranty of +* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +* GNU General Public License for more details. +* +* You should have received a copy of the GNU General Public License +* along with this program. If not, see . +*/ namespace libsignal { public interface DecryptionCallback { - void handlePlaintext(byte[] plaintext); + Task handlePlaintext(byte[] plaintext, uint sessionVersion); } } diff --git a/libsignal-protocol-dotnet/SessionCipher.cs b/libsignal-protocol-dotnet/SessionCipher.cs index e269093..8a66849 100644 --- a/libsignal-protocol-dotnet/SessionCipher.cs +++ b/libsignal-protocol-dotnet/SessionCipher.cs @@ -7,6 +7,7 @@ using libsignal.state; using libsignal.util; using Strilanc.Value; +using System.Threading.Tasks; namespace libsignal { @@ -117,7 +118,9 @@ public CiphertextMessage encrypt(byte[] paddedMessage) /// when the of the sender is untrusted. public byte[] decrypt(PreKeySignalMessage ciphertext) { - return decrypt(ciphertext, new NullDecryptionCallback()); + var tsk = (decrypt(ciphertext, new NullDecryptionCallback())); + tsk.Wait(); + return tsk.Result; } /// @@ -137,7 +140,7 @@ public byte[] decrypt(PreKeySignalMessage ciphertext) /// /// when the message is formatted incorrectly. /// when the of the sender is untrusted. - public byte[] decrypt(PreKeySignalMessage ciphertext, DecryptionCallback callback) + public Task decrypt(PreKeySignalMessage ciphertext, DecryptionCallback callback) { lock (SESSION_LOCK) { @@ -147,7 +150,7 @@ public byte[] decrypt(PreKeySignalMessage ciphertext, DecryptionCallback callbac identityKeyStore.SaveIdentity(remoteAddress, sessionRecord.getSessionState().getRemoteIdentityKey()); - callback.handlePlaintext(plaintext); + callback.handlePlaintext(plaintext, sessionRecord.getSessionState().getSessionVersion()).Wait(); sessionStore.StoreSession(remoteAddress, sessionRecord); @@ -156,7 +159,7 @@ public byte[] decrypt(PreKeySignalMessage ciphertext, DecryptionCallback callbac preKeyStore.RemovePreKey(unsignedPreKeyId.ForceGetValue()); } - return plaintext; + return Task.FromResult(plaintext); } } @@ -172,7 +175,9 @@ public byte[] decrypt(PreKeySignalMessage ciphertext, DecryptionCallback callbac /// if there is no established session for this contact. public byte[] decrypt(SignalMessage ciphertext) { - return decrypt(ciphertext, new NullDecryptionCallback()); + var tsk = decrypt(ciphertext, new NullDecryptionCallback()); + tsk.Wait(); + return tsk.Result; } /// @@ -189,7 +194,7 @@ public byte[] decrypt(SignalMessage ciphertext) /// if the input is a message formatted by a protocol version that is /// no longer supported. /// if there is no established session for this contact. - public byte[] decrypt(SignalMessage ciphertext, DecryptionCallback callback) + public Task decrypt(SignalMessage ciphertext, DecryptionCallback callback) { lock (SESSION_LOCK) { @@ -207,11 +212,11 @@ public byte[] decrypt(SignalMessage ciphertext, DecryptionCallback callback) throw new UntrustedIdentityException(remoteAddress.Name, sessionRecord.getSessionState().getRemoteIdentityKey()); } - callback.handlePlaintext(plaintext); + callback.handlePlaintext(plaintext, sessionRecord.getSessionState().getSessionVersion()).Wait();//no async in a lock sessionStore.StoreSession(remoteAddress, sessionRecord); - return plaintext; + return Task.FromResult(plaintext); } } @@ -391,7 +396,7 @@ private byte[] getPlaintext(MessageKeys messageKeys, byte[] cipherText) private class NullDecryptionCallback : DecryptionCallback { - public void handlePlaintext(byte[] plaintext) { } + public Task handlePlaintext(byte[] plaintext, uint sessionVersion) => Task.CompletedTask; } } } diff --git a/libsignal-protocol-dotnet/groups/GroupCipher.cs b/libsignal-protocol-dotnet/groups/GroupCipher.cs index 4b5d693..4b39221 100644 --- a/libsignal-protocol-dotnet/groups/GroupCipher.cs +++ b/libsignal-protocol-dotnet/groups/GroupCipher.cs @@ -19,7 +19,9 @@ using libsignal.groups.ratchet; using libsignal.groups.state; using libsignal.protocol; +using libsignal.state; using libsignal.util; +using System.Threading.Tasks; namespace libsignal.groups { @@ -127,7 +129,7 @@ public byte[] decrypt(byte[] senderKeyMessageBytes, DecryptionCallback callback) byte[] plaintext = getPlainText(senderKey.getIv(), senderKey.getCipherKey(), senderKeyMessage.getCipherText()); - callback.handlePlaintext(plaintext); + callback.handlePlaintext(plaintext, 931).Wait();//931 random as we don't have the real version senderKeyStore.storeSenderKey(senderKeyId, record); @@ -210,7 +212,8 @@ private byte[] getCipherText(byte[] iv, byte[] key, byte[] plaintext) private class NullDecryptionCallback : DecryptionCallback { - public void handlePlaintext(byte[] plaintext) { } - } + public Task handlePlaintext(byte[] plaintext, uint sessionVersion) => Task.CompletedTask; + + } } }