Skip to content

Commit

Permalink
Support e2e encryption using curve25519 keypair (#115)
Browse files Browse the repository at this point in the history
* Change keypair type in ecdh encryption util

* Use Ed25519Key in ecdh utils to represent other party

* Add encryption props to text serde

* Add encryption props to api

* Fix tests after switching to dh keypair

Co-authored-by: Alexey Tsymbal <alexey@gotiggy.com>
  • Loading branch information
tsmbl and Alexey Tsymbal authored Apr 8, 2022
1 parent 8f5352e commit 6e24559
Show file tree
Hide file tree
Showing 5 changed files with 360 additions and 231 deletions.
43 changes: 21 additions & 22 deletions src/api/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import { sleep, waitForFinality, Wallet_ } from '../utils';
import { ENCRYPTION_OVERHEAD_BYTES } from '../utils/ecdh-encryption';
import { CyclicByteBuffer } from '../utils/cyclic-bytebuffer';
import ByteBuffer from 'bytebuffer';
import { TextSerdeFactory } from './text-serde';
import { EncryptionProps, TextSerdeFactory } from './text-serde';

// TODO: Switch from types to classes

Expand Down Expand Up @@ -254,9 +254,9 @@ export async function getDialectProgramAddress(

function parseMessages(
{ messages: rawMessagesBuffer, members, encrypted }: RawDialect,
user?: anchor.web3.Keypair,
encryptionProps?: EncryptionProps,
) {
if (encrypted && !user) {
if (encrypted && !encryptionProps) {
return [];
}
const messagesBuffer = new CyclicByteBuffer(
Expand All @@ -270,7 +270,7 @@ function parseMessages(
encrypted,
members,
},
user,
encryptionProps,
);
const allMessages: Message[] = messagesBuffer.items().map(({ buffer }) => {
const byteBuffer = new ByteBuffer(buffer.length).append(buffer).flip();
Expand All @@ -288,29 +288,29 @@ function parseMessages(
return allMessages.reverse();
}

function parseRawDialect(rawDialect: RawDialect, user?: anchor.web3.Keypair) {
function parseRawDialect(
rawDialect: RawDialect,
encryptionProps?: EncryptionProps,
) {
return {
encrypted: rawDialect.encrypted,
members: rawDialect.members,
nextMessageIdx: rawDialect.messages.writeOffset,
lastMessageTimestamp: rawDialect.lastMessageTimestamp * 1000,
messages: parseMessages(rawDialect, user),
messages: parseMessages(rawDialect, encryptionProps),
};
}

export async function getDialect(
program: anchor.Program,
publicKey: PublicKey,
user?: anchor.web3.Keypair | Wallet,
encryptionProps?: EncryptionProps,
): Promise<DialectAccount> {
const rawDialect = (await program.account.dialectAccount.fetch(
publicKey,
)) as RawDialect;
const account = await program.provider.connection.getAccountInfo(publicKey);
const dialect = parseRawDialect(
rawDialect,
user && 'secretKey' in user ? user : undefined,
);
const dialect = parseRawDialect(rawDialect, encryptionProps);
return {
...account,
publicKey: publicKey,
Expand All @@ -321,14 +321,15 @@ export async function getDialect(
export async function getDialects(
program: anchor.Program,
user: anchor.web3.Keypair | Wallet,
encryptionProps?: EncryptionProps,
): Promise<DialectAccount[]> {
const metadata = await getMetadata(program, user.publicKey);
const enabledSubscriptions = metadata.subscriptions.filter(
(it) => it.enabled,
);
return Promise.all(
enabledSubscriptions.map(async ({ pubkey }) =>
getDialect(program, pubkey, user),
getDialect(program, pubkey, encryptionProps),
),
).then((dialects) =>
dialects.sort(
Expand All @@ -341,13 +342,13 @@ export async function getDialects(
export async function getDialectForMembers(
program: anchor.Program,
members: Member[],
user?: anchor.web3.Keypair,
encryptionProps?: EncryptionProps,
): Promise<DialectAccount> {
const sortedMembers = members.sort((a, b) =>
a.publicKey.toBuffer().compare(b.publicKey.toBuffer()),
);
const [publicKey] = await getDialectProgramAddress(program, sortedMembers);
return await getDialect(program, publicKey, user);
return await getDialect(program, publicKey, encryptionProps);
}

export async function findDialects(
Expand Down Expand Up @@ -395,7 +396,8 @@ export async function createDialect(
program: anchor.Program,
owner: anchor.web3.Keypair | Wallet,
members: Member[],
encrypted = true,
encrypted = false,
encryptionProps?: EncryptionProps,
): Promise<DialectAccount> {
const sortedMembers = members.sort((a, b) =>
a.publicKey.toBuffer().compare(b.publicKey.toBuffer()),
Expand Down Expand Up @@ -425,11 +427,7 @@ export async function createDialect(
},
);
await waitForFinality(program, tx);
return await getDialectForMembers(
program,
members,
'secretKey' in owner ? owner : undefined,
);
return await getDialectForMembers(program, members, encryptionProps);
}

export async function deleteDialect(
Expand Down Expand Up @@ -470,6 +468,7 @@ export async function sendMessage(
{ dialect, publicKey }: DialectAccount,
sender: anchor.web3.Keypair | Wallet,
text: string,
encryptionProps?: EncryptionProps,
): Promise<Message> {
const [dialectPublicKey, nonce] = await getDialectProgramAddress(
program,
Expand All @@ -480,7 +479,7 @@ export async function sendMessage(
encrypted: dialect.encrypted,
members: dialect.members,
},
sender && 'secretKey' in sender ? sender : undefined,
encryptionProps,
);
const serializedText = textSerde.serialize(text);
await program.rpc.sendMessage(
Expand All @@ -498,7 +497,7 @@ export async function sendMessage(
signers: sender && 'secretKey' in sender ? [sender] : [],
},
);
const d = await getDialect(program, publicKey, sender);
const d = await getDialect(program, publicKey, encryptionProps);
return d.dialect.messages[d.dialect.nextMessageIdx - 1]; // TODO: Support ring
}

Expand Down
45 changes: 27 additions & 18 deletions src/api/text-serde.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,14 @@ import {
generateRandomNonceWithPrefix,
NONCE_SIZE_BYTES,
} from '../utils/nonce-generator';
import { ecdhDecrypt, ecdhEncrypt } from '../utils/ecdh-encryption';
import {
Curve25519KeyPair,
ecdhDecrypt,
ecdhEncrypt,
Ed25519Key,
} from '../utils/ecdh-encryption';
import * as anchor from '@project-serum/anchor';
import { PublicKey } from '@solana/web3.js';

export interface TextSerde {
serialize(text: string): Uint8Array;
Expand All @@ -17,37 +23,35 @@ export class EncryptedTextSerde implements TextSerde {
new UnencryptedTextSerde();

constructor(
private readonly user: anchor.web3.Keypair,
private readonly encryptionProps: EncryptionProps,
private readonly members: Member[],
) {}

deserialize(bytes: Uint8Array): string {
const encryptionNonce = bytes.slice(0, NONCE_SIZE_BYTES);
const encryptedText = bytes.slice(NONCE_SIZE_BYTES, bytes.length);
const otherMember = this.findOtherMember(this.user.publicKey);
const otherMember = this.findOtherMember(
new PublicKey(this.encryptionProps.ed25519PublicKey),
);
const encodedText = ecdhDecrypt(
encryptedText,
{
secretKey: this.user.secretKey,
publicKey: this.user.publicKey.toBytes(),
},
otherMember.publicKey.toBuffer(),
this.encryptionProps.diffieHellmanKeyPair,
otherMember.publicKey.toBytes(),
encryptionNonce,
);
return this.unencryptedTextSerde.deserialize(encodedText);
}

serialize(text: string): Uint8Array {
const senderMemberIdx = this.findMemberIdx(this.user.publicKey);
const publicKey = new PublicKey(this.encryptionProps.ed25519PublicKey);
const senderMemberIdx = this.findMemberIdx(publicKey);
const textBytes = this.unencryptedTextSerde.serialize(text);
const otherMember = this.findOtherMember(this.user.publicKey);
const otherMember = this.findOtherMember(publicKey);
const encryptionNonce = generateRandomNonceWithPrefix(senderMemberIdx);
const encryptedText = ecdhEncrypt(
textBytes,
{
secretKey: this.user.secretKey,
publicKey: this.user.publicKey.toBytes(),
},
this.encryptionProps.diffieHellmanKeyPair,

otherMember.publicKey.toBytes(),
encryptionNonce,
);
Expand Down Expand Up @@ -88,17 +92,22 @@ export type DialectAttributes = {
members: Member[];
};

export interface EncryptionProps {
diffieHellmanKeyPair: Curve25519KeyPair;
ed25519PublicKey: Ed25519Key;
}

export class TextSerdeFactory {
static create(
{ encrypted, members }: DialectAttributes,
user?: anchor.web3.Keypair,
encryptionProps?: EncryptionProps,
): TextSerde {
if (!encrypted) {
return new UnencryptedTextSerde();
}
if (encrypted && user) {
return new EncryptedTextSerde(user, members);
if (encrypted && encryptionProps) {
return new EncryptedTextSerde(encryptionProps, members);
}
throw new Error('Cannot proceed with encrypted dialect w/o user identity');
throw new Error('Cannot proceed without encryptionProps');
}
}
50 changes: 35 additions & 15 deletions src/utils/ecdh-encryption.spec.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,29 @@
import { expect } from 'chai';
import {
Curve25519KeyPair,
ecdhDecrypt,
ecdhEncrypt,
ENCRYPTION_OVERHEAD_BYTES,
generateEd25519KeyPair,
} from './ecdh-encryption';
import { randomBytes } from 'tweetnacl';
import { NONCE_SIZE_BYTES } from './nonce-generator';
import { Keypair } from '@solana/web3.js';
import ed2curve from 'ed2curve';

function generateKeypair() {
const { publicKey, secretKey } = new Keypair();
const curve25519: Curve25519KeyPair = ed2curve.convertKeyPair({
publicKey: publicKey.toBytes(),
secretKey,
})!;
return {
ed25519: {
publicKey: publicKey.toBytes(),
secretKey,
},
curve25519,
};
}

describe('ECDH encryptor/decryptor test', async () => {
/*
Expand All @@ -25,10 +42,13 @@ describe('ECDH encryptor/decryptor test', async () => {
const sizesComparison = messageSizes.map((size) => {
const unencrypted = randomBytes(size);
const nonce = randomBytes(NONCE_SIZE_BYTES);
const keyPair1 = generateKeypair();
const keyPair2 = generateKeypair();

const encrypted = ecdhEncrypt(
unencrypted,
generateEd25519KeyPair(),
generateEd25519KeyPair().publicKey,
keyPair1.curve25519,
keyPair2.ed25519.publicKey,
nonce,
);
return {
Expand All @@ -47,19 +67,19 @@ describe('ECDH encryptor/decryptor test', async () => {
// given
const unencrypted = randomBytes(10);
const nonce = randomBytes(NONCE_SIZE_BYTES);
const party1KeyPair = generateEd25519KeyPair();
const party2KeyPair = generateEd25519KeyPair();
const party1KeyPair = generateKeypair();
const party2KeyPair = generateKeypair();
const encrypted = ecdhEncrypt(
unencrypted,
party1KeyPair,
party2KeyPair.publicKey,
party1KeyPair.curve25519,
party2KeyPair.ed25519.publicKey,
nonce,
);
// when
const decrypted = ecdhDecrypt(
encrypted,
party1KeyPair,
party2KeyPair.publicKey,
party1KeyPair.curve25519,
party2KeyPair.ed25519.publicKey,
nonce,
);
// then
Expand All @@ -71,19 +91,19 @@ describe('ECDH encryptor/decryptor test', async () => {
// given
const unencrypted = randomBytes(10);
const nonce = randomBytes(NONCE_SIZE_BYTES);
const party1KeyPair = generateEd25519KeyPair();
const party2KeyPair = generateEd25519KeyPair();
const party1KeyPair = generateKeypair();
const party2KeyPair = generateKeypair();
const encrypted = ecdhEncrypt(
unencrypted,
party1KeyPair,
party2KeyPair.publicKey,
party1KeyPair.curve25519,
party2KeyPair.ed25519.publicKey,
nonce,
);
// when
const decrypted = ecdhDecrypt(
encrypted,
party2KeyPair,
party1KeyPair.publicKey,
party2KeyPair.curve25519,
party1KeyPair.ed25519.publicKey,
nonce,
);
// then
Expand Down
Loading

0 comments on commit 6e24559

Please sign in to comment.