From fad34a74acbf94d359e3db1cc3fdcddab5550b0c Mon Sep 17 00:00:00 2001 From: Carlo Mazzaferro Date: Mon, 16 Sep 2024 13:07:04 +0200 Subject: [PATCH] =?UTF-8?q?Revert=20"Revert=20"fix:=20also=20use=2064-bit?= =?UTF-8?q?=20counter=20in=20final=20state=20addition"=E2=80=A6=20(#378)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Revert "Revert "fix: also use 64-bit counter in final state addition" (#374)" This reverts commit 1647bc0526e764171ad240c6f8d8f520146fe5ad. --- iris-mpc-gpu/benches/chacha.rs | 34 +-- iris-mpc-gpu/src/rng/aes.cu | 479 ------------------------------ iris-mpc-gpu/src/rng/aes.rs | 106 ------- iris-mpc-gpu/src/rng/chacha.cu | 19 +- iris-mpc-gpu/src/rng/field_fix.cu | 34 --- iris-mpc-gpu/src/rng/mod.rs | 1 - 6 files changed, 12 insertions(+), 661 deletions(-) delete mode 100644 iris-mpc-gpu/src/rng/aes.cu delete mode 100644 iris-mpc-gpu/src/rng/aes.rs delete mode 100644 iris-mpc-gpu/src/rng/field_fix.cu diff --git a/iris-mpc-gpu/benches/chacha.rs b/iris-mpc-gpu/benches/chacha.rs index dfd9b563e..97c914050 100644 --- a/iris-mpc-gpu/benches/chacha.rs +++ b/iris-mpc-gpu/benches/chacha.rs @@ -1,6 +1,6 @@ use criterion::{criterion_group, criterion_main, Criterion}; use cudarc::driver::CudaDevice; -use iris_mpc_gpu::rng::{aes::AesCudaRng, chacha::ChaChaCudaRng}; +use iris_mpc_gpu::rng::chacha::ChaChaCudaRng; pub fn criterion_benchmark_chacha12_runner(c: &mut Criterion, buf_size: usize) { let mut group = c.benchmark_group(format!( @@ -28,30 +28,6 @@ pub fn criterion_benchmark_chacha12_runner(c: &mut Criterion, buf_size: usize) { group.finish(); } -pub fn criterion_benchmark_aes_runner(c: &mut Criterion, buf_size: usize) { - let mut group = c.benchmark_group(format!( - "AES (buf_size = {}MB)", - buf_size * 4 / (1024 * 1024) - )); - - group.throughput(criterion::Throughput::Bytes( - (buf_size * std::mem::size_of::()) as u64, - )); - let mut chacha = AesCudaRng::init(buf_size); - group.bench_function("with copy to 
host", move |b| { - b.iter(|| { - chacha.fill_rng(); - }) - }); - let mut chacha = AesCudaRng::init(buf_size); - group.bench_function("without copy to host", move |b| { - b.iter(|| { - chacha.fill_rng_no_host_copy(); - }) - }); - group.finish(); -} - pub fn criterion_benchmark_chacha12(c: &mut Criterion) { for log_buf_size in 20..=30 { let buf_size = (1usize << log_buf_size) / 4; @@ -59,15 +35,9 @@ pub fn criterion_benchmark_chacha12(c: &mut Criterion) { } } -pub fn criterion_benchmark_aes(c: &mut Criterion) { - for log_buf_size in 20..=30 { - let buf_size = (1usize << log_buf_size) / 4; - criterion_benchmark_aes_runner(c, buf_size); - } -} criterion_group!( name = rng_benches; config = Criterion::default(); - targets = criterion_benchmark_chacha12, criterion_benchmark_aes + targets = criterion_benchmark_chacha12 ); criterion_main!(rng_benches); diff --git a/iris-mpc-gpu/src/rng/aes.cu b/iris-mpc-gpu/src/rng/aes.cu deleted file mode 100644 index af7b923ff..000000000 --- a/iris-mpc-gpu/src/rng/aes.cu +++ /dev/null @@ -1,479 +0,0 @@ -/* - * Parts of this file are copied from https://github.com/pytorch/csprng/, - * with the following license: - * - * Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. - * - * This source code is licensed under the BSD-style license found in the - * PYTORCH_LICENSE file in the current directory of this source tree. - */ - -// #include - -#define uint8_t unsigned char -#define int64_t long long -#define size_t unsigned long long -#define uint32_t unsigned int - -namespace aes -{ - -// This AES implementation is based on -// https://github.com/kokke/tiny-AES-c/blob/master/aes.c -// authored by kokke and et al. and distributed under public domain license. -// -// This is free and unencumbered software released into the public domain. 
-// -// Anyone is free to copy, modify, publish, use, compile, sell, or -// distribute this software, either in source code form or as a compiled -// binary, for any purpose, commercial or non-commercial, and by any -// means. -// -// In jurisdictions that recognize copyright laws, the author or authors -// of this software dedicate any and all copyright interest in the -// software to the public domain. We make this dedication for the benefit -// of the public at large and to the detriment of our heirs and -// successors. We intend this dedication to be an overt act of -// relinquishment in perpetuity of all present and future rights to this -// software under copyright law. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. -// IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR -// OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, -// ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR -// OTHER DEALINGS IN THE SOFTWARE. -// -// For more information, please refer to -// -// Adapted for CUDA by Pavel Belevich - -/*****************************************************************************/ -/* Defines: */ -/*****************************************************************************/ -// The number of columns comprising a state in AES. This is a constant in AES. Value=4 -#define Nb 4 - -#if defined(AES256) && (AES256 == 1) -#define Nk 8 -#define Nr 14 -#elif defined(AES192) && (AES192 == 1) -#define Nk 6 -#define Nr 12 -#else -#define Nk 4 // The number of 32 bit words in a key. -#define Nr 10 // The number of rounds in AES Cipher. 
-#endif - - constexpr size_t block_t_size = 16; - - typedef uint8_t state_t[4][4]; - - // The lookup-tables are marked const so they can be placed in read-only storage instead of RAM - // The numbers below can be computed dynamically trading ROM for RAM - - // This can be useful in (embedded) bootloader applications, where ROM is often limited. - __constant__ const uint8_t sbox[256] = { - // 0 1 2 3 4 5 6 7 8 9 A B C D E F - 0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b, 0xfe, 0xd7, 0xab, 0x76, - 0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0, - 0xb7, 0xfd, 0x93, 0x26, 0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15, - 0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2, 0xeb, 0x27, 0xb2, 0x75, - 0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0, 0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84, - 0x53, 0xd1, 0x00, 0xed, 0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf, - 0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f, 0x50, 0x3c, 0x9f, 0xa8, - 0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5, 0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2, - 0xcd, 0x0c, 0x13, 0xec, 0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73, - 0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14, 0xde, 0x5e, 0x0b, 0xdb, - 0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c, 0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79, - 0xe7, 0xc8, 0x37, 0x6d, 0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08, - 0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f, 0x4b, 0xbd, 0x8b, 0x8a, - 0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, 0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e, - 0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf, - 0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, 0x16}; - - 
__constant__ const uint8_t rsbox[256] = { - 0x52, 0x09, 0x6a, 0xd5, 0x30, 0x36, 0xa5, 0x38, 0xbf, 0x40, 0xa3, 0x9e, 0x81, 0xf3, 0xd7, 0xfb, - 0x7c, 0xe3, 0x39, 0x82, 0x9b, 0x2f, 0xff, 0x87, 0x34, 0x8e, 0x43, 0x44, 0xc4, 0xde, 0xe9, 0xcb, - 0x54, 0x7b, 0x94, 0x32, 0xa6, 0xc2, 0x23, 0x3d, 0xee, 0x4c, 0x95, 0x0b, 0x42, 0xfa, 0xc3, 0x4e, - 0x08, 0x2e, 0xa1, 0x66, 0x28, 0xd9, 0x24, 0xb2, 0x76, 0x5b, 0xa2, 0x49, 0x6d, 0x8b, 0xd1, 0x25, - 0x72, 0xf8, 0xf6, 0x64, 0x86, 0x68, 0x98, 0x16, 0xd4, 0xa4, 0x5c, 0xcc, 0x5d, 0x65, 0xb6, 0x92, - 0x6c, 0x70, 0x48, 0x50, 0xfd, 0xed, 0xb9, 0xda, 0x5e, 0x15, 0x46, 0x57, 0xa7, 0x8d, 0x9d, 0x84, - 0x90, 0xd8, 0xab, 0x00, 0x8c, 0xbc, 0xd3, 0x0a, 0xf7, 0xe4, 0x58, 0x05, 0xb8, 0xb3, 0x45, 0x06, - 0xd0, 0x2c, 0x1e, 0x8f, 0xca, 0x3f, 0x0f, 0x02, 0xc1, 0xaf, 0xbd, 0x03, 0x01, 0x13, 0x8a, 0x6b, - 0x3a, 0x91, 0x11, 0x41, 0x4f, 0x67, 0xdc, 0xea, 0x97, 0xf2, 0xcf, 0xce, 0xf0, 0xb4, 0xe6, 0x73, - 0x96, 0xac, 0x74, 0x22, 0xe7, 0xad, 0x35, 0x85, 0xe2, 0xf9, 0x37, 0xe8, 0x1c, 0x75, 0xdf, 0x6e, - 0x47, 0xf1, 0x1a, 0x71, 0x1d, 0x29, 0xc5, 0x89, 0x6f, 0xb7, 0x62, 0x0e, 0xaa, 0x18, 0xbe, 0x1b, - 0xfc, 0x56, 0x3e, 0x4b, 0xc6, 0xd2, 0x79, 0x20, 0x9a, 0xdb, 0xc0, 0xfe, 0x78, 0xcd, 0x5a, 0xf4, - 0x1f, 0xdd, 0xa8, 0x33, 0x88, 0x07, 0xc7, 0x31, 0xb1, 0x12, 0x10, 0x59, 0x27, 0x80, 0xec, 0x5f, - 0x60, 0x51, 0x7f, 0xa9, 0x19, 0xb5, 0x4a, 0x0d, 0x2d, 0xe5, 0x7a, 0x9f, 0x93, 0xc9, 0x9c, 0xef, - 0xa0, 0xe0, 0x3b, 0x4d, 0xae, 0x2a, 0xf5, 0xb0, 0xc8, 0xeb, 0xbb, 0x3c, 0x83, 0x53, 0x99, 0x61, - 0x17, 0x2b, 0x04, 0x7e, 0xba, 0x77, 0xd6, 0x26, 0xe1, 0x69, 0x14, 0x63, 0x55, 0x21, 0x0c, 0x7d}; - - // The round constant word array, Rcon[i], contains the values given by - // x to the power (i-1) being powers of x (x is denoted as {02}) in the field GF(2^8) - __constant__ const uint8_t Rcon[11] = { - 0x8d, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36}; - -#define getSBoxValue(num) (sbox[(num)]) - -#define getSBoxInvert(num) (rsbox[(num)]) - - // This function 
produces Nb(Nr+1) round keys. The round keys are used in each round to decrypt the states. - __device__ void KeyExpansion(uint8_t *RoundKey, const uint8_t *Key) - { - unsigned int i, j, k; - uint8_t tempa[4]; // Used for the column/row operations - - // The first round key is the key itself. - for (i = 0; i < Nk; ++i) - { - RoundKey[(i * 4) + 0] = Key[(i * 4) + 0]; - RoundKey[(i * 4) + 1] = Key[(i * 4) + 1]; - RoundKey[(i * 4) + 2] = Key[(i * 4) + 2]; - RoundKey[(i * 4) + 3] = Key[(i * 4) + 3]; - } - - // All other round keys are found from the previous round keys. - for (i = Nk; i < Nb * (Nr + 1); ++i) - { - { - k = (i - 1) * 4; - tempa[0] = RoundKey[k + 0]; - tempa[1] = RoundKey[k + 1]; - tempa[2] = RoundKey[k + 2]; - tempa[3] = RoundKey[k + 3]; - } - - if (i % Nk == 0) - { - // This function shifts the 4 bytes in a word to the left once. - // [a0,a1,a2,a3] becomes [a1,a2,a3,a0] - - // Function RotWord() - { - const uint8_t u8tmp = tempa[0]; - tempa[0] = tempa[1]; - tempa[1] = tempa[2]; - tempa[2] = tempa[3]; - tempa[3] = u8tmp; - } - - // SubWord() is a function that takes a four-byte input word and - // applies the S-box to each of the four bytes to produce an output word. - - // Function Subword() - { - tempa[0] = getSBoxValue(tempa[0]); - tempa[1] = getSBoxValue(tempa[1]); - tempa[2] = getSBoxValue(tempa[2]); - tempa[3] = getSBoxValue(tempa[3]); - } - - tempa[0] = tempa[0] ^ Rcon[i / Nk]; - } -#if defined(AES256) && (AES256 == 1) - if (i % Nk == 4) - { - // Function Subword() - { - tempa[0] = getSBoxValue(tempa[0]); - tempa[1] = getSBoxValue(tempa[1]); - tempa[2] = getSBoxValue(tempa[2]); - tempa[3] = getSBoxValue(tempa[3]); - } - } -#endif - j = i * 4; - k = (i - Nk) * 4; - RoundKey[j + 0] = RoundKey[k + 0] ^ tempa[0]; - RoundKey[j + 1] = RoundKey[k + 1] ^ tempa[1]; - RoundKey[j + 2] = RoundKey[k + 2] ^ tempa[2]; - RoundKey[j + 3] = RoundKey[k + 3] ^ tempa[3]; - } - } - - // This function adds the round key to state. 
- // The round key is added to the state by an XOR function. - __device__ void AddRoundKey(uint8_t round, state_t *state, const uint8_t *RoundKey) - { - uint8_t i, j; - for (i = 0; i < 4; ++i) - { - for (j = 0; j < 4; ++j) - { - (*state)[i][j] ^= RoundKey[(round * Nb * 4) + (i * Nb) + j]; - } - } - } - - // The SubBytes Function Substitutes the values in the - // state matrix with values in an S-box. - __device__ void SubBytes(state_t *state) - { - uint8_t i, j; - for (i = 0; i < 4; ++i) - { - for (j = 0; j < 4; ++j) - { - (*state)[j][i] = getSBoxValue((*state)[j][i]); - } - } - } - - // The ShiftRows() function shifts the rows in the state to the left. - // Each row is shifted with different offset. - // Offset = Row number. So the first row is not shifted. - __device__ void ShiftRows(state_t *state) - { - uint8_t temp; - - // Rotate first row 1 columns to left - temp = (*state)[0][1]; - (*state)[0][1] = (*state)[1][1]; - (*state)[1][1] = (*state)[2][1]; - (*state)[2][1] = (*state)[3][1]; - (*state)[3][1] = temp; - - // Rotate second row 2 columns to left - temp = (*state)[0][2]; - (*state)[0][2] = (*state)[2][2]; - (*state)[2][2] = temp; - - temp = (*state)[1][2]; - (*state)[1][2] = (*state)[3][2]; - (*state)[3][2] = temp; - - // Rotate third row 3 columns to left - temp = (*state)[0][3]; - (*state)[0][3] = (*state)[3][3]; - (*state)[3][3] = (*state)[2][3]; - (*state)[2][3] = (*state)[1][3]; - (*state)[1][3] = temp; - } - - __device__ uint8_t xtime(uint8_t x) - { - return ((x << 1) ^ (((x >> 7) & 1) * 0x1b)); - } - - // MixColumns function mixes the columns of the state matrix - __device__ void MixColumns(state_t *state) - { - uint8_t i; - uint8_t Tmp, Tm, t; - for (i = 0; i < 4; ++i) - { - t = (*state)[i][0]; - Tmp = (*state)[i][0] ^ (*state)[i][1] ^ (*state)[i][2] ^ (*state)[i][3]; - Tm = (*state)[i][0] ^ (*state)[i][1]; - Tm = xtime(Tm); - (*state)[i][0] ^= Tm ^ Tmp; - Tm = (*state)[i][1] ^ (*state)[i][2]; - Tm = xtime(Tm); - (*state)[i][1] ^= Tm ^ Tmp; - Tm = 
(*state)[i][2] ^ (*state)[i][3]; - Tm = xtime(Tm); - (*state)[i][2] ^= Tm ^ Tmp; - Tm = (*state)[i][3] ^ t; - Tm = xtime(Tm); - (*state)[i][3] ^= Tm ^ Tmp; - } - } - - __device__ uint8_t Multiply(uint8_t x, uint8_t y) - { - return (((y & 1) * x) ^ - ((y >> 1 & 1) * xtime(x)) ^ - ((y >> 2 & 1) * xtime(xtime(x))) ^ - ((y >> 3 & 1) * xtime(xtime(xtime(x)))) ^ - ((y >> 4 & 1) * xtime(xtime(xtime(xtime(x)))))); /* this last call to xtime() can be omitted */ - } - - // MixColumns function mixes the columns of the state matrix. - // The method used to multiply may be difficult to understand for the inexperienced. - // Please use the references to gain more information. - __device__ void InvMixColumns(state_t *state) - { - int i; - uint8_t a, b, c, d; - for (i = 0; i < 4; ++i) - { - a = (*state)[i][0]; - b = (*state)[i][1]; - c = (*state)[i][2]; - d = (*state)[i][3]; - - (*state)[i][0] = Multiply(a, 0x0e) ^ Multiply(b, 0x0b) ^ Multiply(c, 0x0d) ^ Multiply(d, 0x09); - (*state)[i][1] = Multiply(a, 0x09) ^ Multiply(b, 0x0e) ^ Multiply(c, 0x0b) ^ Multiply(d, 0x0d); - (*state)[i][2] = Multiply(a, 0x0d) ^ Multiply(b, 0x09) ^ Multiply(c, 0x0e) ^ Multiply(d, 0x0b); - (*state)[i][3] = Multiply(a, 0x0b) ^ Multiply(b, 0x0d) ^ Multiply(c, 0x09) ^ Multiply(d, 0x0e); - } - } - - // The SubBytes Function Substitutes the values in the - // state matrix with values in an S-box. 
- __device__ void InvSubBytes(state_t *state) - { - uint8_t i, j; - for (i = 0; i < 4; ++i) - { - for (j = 0; j < 4; ++j) - { - (*state)[j][i] = getSBoxInvert((*state)[j][i]); - } - } - } - - __device__ void InvShiftRows(state_t *state) - { - uint8_t temp; - - // Rotate first row 1 columns to right - temp = (*state)[3][1]; - (*state)[3][1] = (*state)[2][1]; - (*state)[2][1] = (*state)[1][1]; - (*state)[1][1] = (*state)[0][1]; - (*state)[0][1] = temp; - - // Rotate second row 2 columns to right - temp = (*state)[0][2]; - (*state)[0][2] = (*state)[2][2]; - (*state)[2][2] = temp; - - temp = (*state)[1][2]; - (*state)[1][2] = (*state)[3][2]; - (*state)[3][2] = temp; - - // Rotate third row 3 columns to right - temp = (*state)[0][3]; - (*state)[0][3] = (*state)[1][3]; - (*state)[1][3] = (*state)[2][3]; - (*state)[2][3] = (*state)[3][3]; - (*state)[3][3] = temp; - } - - __device__ void encrypt(uint8_t *state, const uint8_t *key) - { - uint8_t RoundKey[176]; - KeyExpansion(RoundKey, key); - - uint8_t round = 0; - - // Add the First round key to the state before starting the rounds. - AddRoundKey(0, (state_t *)state, RoundKey); - - // There will be Nr rounds. - // The first Nr-1 rounds are identical. - // These Nr rounds are executed in the loop below. - // Last one without MixColumns() - for (round = 1;; ++round) - { - SubBytes((state_t *)state); - ShiftRows((state_t *)state); - if (round == Nr) - { - break; - } - MixColumns((state_t *)state); - AddRoundKey(round, (state_t *)state, RoundKey); - } - // Add round key to last round - AddRoundKey(Nr, (state_t *)state, RoundKey); - } - - __device__ void decrypt(uint8_t *state, const uint8_t *key) - { - uint8_t RoundKey[176]; - KeyExpansion(RoundKey, key); - - uint8_t round = 0; - - // Add the First round key to the state before starting the rounds. - AddRoundKey(Nr, (state_t *)state, RoundKey); - - // There will be Nr rounds. - // The first Nr-1 rounds are identical. - // These Nr rounds are executed in the loop below. 
- // Last one without InvMixColumn() - for (round = (Nr - 1);; --round) - { - InvShiftRows((state_t *)state); - InvSubBytes((state_t *)state); - AddRoundKey(round, (state_t *)state, RoundKey); - if (round == 0) - { - break; - } - InvMixColumns((state_t *)state); - } - } - -} - -template -__device__ static void copy_input_to_block(int64_t idx, uint8_t *block, int block_size, - void *input_ptr, int64_t input_numel, int input_type_size, input_index_calc_t input_index_calc) -{ - for (auto i = 0; i < block_size / input_type_size; ++i) - { - const auto linear_index = idx * (block_size / input_type_size) + i; - if (linear_index < input_numel) - { - memcpy( - block + i * input_type_size, - &(reinterpret_cast(input_ptr)[input_index_calc(linear_index)]), - input_type_size); - } - } -} - -template -__device__ static void copy_block_to_output(int64_t idx, uint8_t *block, int output_elem_per_block, - void *output_ptr, int64_t output_numel, int output_type_size, output_index_calc_t output_index_calc) -{ - for (auto i = 0; i < output_elem_per_block; ++i) - { - const auto linear_index = idx * output_elem_per_block + i; - if (linear_index < output_numel) - { - memcpy( - &(reinterpret_cast(output_ptr)[output_index_calc(linear_index)]), - block + i * output_type_size, - output_type_size); - } - } -} - -template -__device__ static void block_cipher_kernel_cuda( - const uint8_t *key_bytes, - int output_elem_per_block, void *output_ptr, int64_t output_numel, - int output_type_size, output_index_calc_t output_index_calc) -{ - const auto idx = blockIdx.x * blockDim.x + threadIdx.x; - uint8_t block[aes::block_t_size]; - memset(&block, 0, aes::block_t_size); // is it ok to use zeros as padding? 
- // if (input_ptr != nullptr) - // { - // copy_input_to_block(idx, block, aes::block_t_size, input_ptr, input_numel, input_type_size, input_index_calc); - // } - - uint8_t idx_block[aes::block_t_size]; - memset(&idx_block, 0, aes::block_t_size); - *(reinterpret_cast(idx_block)) = idx; - aes::encrypt(idx_block, key_bytes); - for (size_t i = 0; i < aes::block_t_size; i++) - { - block[i] ^= idx_block[i]; - } - copy_block_to_output(idx, block, output_elem_per_block, output_ptr, output_numel, output_type_size, output_index_calc); -} - -extern "C" __global__ void aes_128_rng(const uint8_t *key_bytes, - int output_elem_per_block, void *output_ptr, int64_t output_numel, - int output_type_size) -{ - block_cipher_kernel_cuda(key_bytes, output_elem_per_block, output_ptr, output_numel, output_type_size, [](uint32_t idx) - { return idx; }); -} diff --git a/iris-mpc-gpu/src/rng/aes.rs b/iris-mpc-gpu/src/rng/aes.rs deleted file mode 100644 index 5c6bedfd9..000000000 --- a/iris-mpc-gpu/src/rng/aes.rs +++ /dev/null @@ -1,106 +0,0 @@ -use cudarc::{ - driver::{CudaDevice, CudaFunction, CudaSlice, LaunchAsync, LaunchConfig}, - nvrtc::compile_ptx, -}; -use std::{mem, sync::Arc}; - -pub struct AesCudaRng { - buf_size: usize, - devs: Vec>, - kernels: Vec, - rng_chunks: Vec>, - output_buf: Vec, -} - -const AES_PTX_SRC: &str = include_str!("aes.cu"); -const AES_FUNCTION_NAME: &str = "aes_128_rng"; - -impl AesCudaRng { - // buf size in u8 - pub fn init(buf_size: usize) -> Self { - let n_devices = CudaDevice::count().unwrap() as usize; - let mut devs = Vec::new(); - let mut kernels = Vec::new(); - let ptx = compile_ptx(AES_PTX_SRC).unwrap(); - - for i in 0..n_devices { - // This call to CudaDevice::new is only used in context of a benchmark - not - // used in the server binary - let dev = CudaDevice::new(i).unwrap(); - dev.load_ptx(ptx.clone(), AES_FUNCTION_NAME, &[AES_FUNCTION_NAME]) - .unwrap(); - let function = dev.get_func(AES_FUNCTION_NAME, AES_FUNCTION_NAME).unwrap(); - - 
devs.push(dev); - kernels.push(function); - } - - assert!(buf_size % 16 == 0, "buf_size must be a multiple of 16 atm"); - - let buf = vec![0u8; buf_size]; - let rng_chunks = vec![devs[0].htod_sync_copy(&buf).unwrap()]; // just do on device 0 for now - - Self { - buf_size, - devs, - kernels, - rng_chunks, - output_buf: buf, - } - } - - pub fn fill_rng(&mut self) { - self.fill_rng_no_host_copy(); - self.devs[0] - .dtoh_sync_copy_into(&self.rng_chunks[0], &mut self.output_buf[..]) - .unwrap(); - } - pub fn fill_rng_no_host_copy(&mut self) { - let num_kernel_calls = self.buf_size / 16; - let threads_per_block = 256; - let blocks_per_grid = (num_kernel_calls + threads_per_block - 1) / threads_per_block; - let cfg = LaunchConfig { - block_dim: (threads_per_block as u32, 1, 1), - grid_dim: (blocks_per_grid as u32, 1, 1), - shared_mem_bytes: 0, - }; - let key_bytes = [0u8; 16]; - let key_slice = self.devs[0].htod_sync_copy(&key_bytes[..]).unwrap(); - unsafe { - self.kernels[0] - .clone() - .launch( - cfg, - ( - &key_slice, - 16, - &self.rng_chunks[0], - self.buf_size, - mem::size_of::(), - ), - ) - .unwrap(); - } - } - - pub fn data(&self) -> &[u8] { - &self.output_buf - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - #[cfg(feature = "gpu_dependent")] - fn test_aes_rng() { - let mut rng = AesCudaRng::init(1024 * 1024); - rng.fill_rng(); - - let zeros = rng.data().iter().filter(|x| x == &&0).count(); - let expected = 1024 * 1024 / 256; - assert!(1.1 * expected as f64 > zeros as f64); - assert!(1.1 * zeros as f64 > expected as f64); - } -} diff --git a/iris-mpc-gpu/src/rng/chacha.cu b/iris-mpc-gpu/src/rng/chacha.cu index faac89368..41722ea97 100644 --- a/iris-mpc-gpu/src/rng/chacha.cu +++ b/iris-mpc-gpu/src/rng/chacha.cu @@ -106,10 +106,10 @@ extern "C" __global__ void chacha12(uint32_t *d_ciphertext, uint32_t *d_state, int idx = blockIdx.x * blockDim.x + threadIdx.x; // the 64-bit counter part is in state[12] and 13, we add our local counter = // idx here 
may not overflow, caller has to ensure that - uint64_t counter = state[12] | (state[13] << 32); + uint64_t counter = (uint64_t)(state[12]) | ((uint64_t)(state[13]) << 32); counter += idx; - thread_state[12] = counter & 0xFFFFFFFF; - thread_state[13] = counter >> 32; + thread_state[12] = (uint32_t)(counter & 0xFFFFFFFF); + thread_state[13] = (uint32_t)(counter >> 32); // 6 double rounds (8 quarter rounds) for (int i = 0; i < 6; i++) { QUARTERROUND(thread_state, 0, 4, 8, 12); @@ -136,8 +136,8 @@ extern "C" __global__ void chacha12(uint32_t *d_ciphertext, uint32_t *d_state, thread_state[9] += state[9]; thread_state[10] += state[10]; thread_state[11] += state[11]; - thread_state[12] += state[12] + idx; - thread_state[13] += state[13]; + thread_state[12] += (uint32_t)(counter & 0xFFFFFFFF); + thread_state[13] += (uint32_t)(counter >> 32); thread_state[14] += state[14]; thread_state[15] += state[15]; @@ -185,10 +185,10 @@ extern "C" __global__ void chacha12_xor(uint32_t *d_ciphertext, int idx = blockIdx.x * blockDim.x + threadIdx.x; // the 64-bit counter part is in state[12] and 13, we add our local counter = // idx here may not overflow, caller has to ensure that - uint64_t counter = state[12] | (state[13] << 32); + uint64_t counter = (uint64_t)(state[12]) | ((uint64_t)(state[13]) << 32); counter += idx; - thread_state[12] = counter & 0xFFFFFFFF; - thread_state[13] = counter >> 32; + thread_state[12] = (uint32_t)(counter & 0xFFFFFFFF); + thread_state[13] = (uint32_t)(counter >> 32); // 6 double rounds (8 quarter rounds) for (int i = 0; i < 6; i++) { QUARTERROUND(thread_state, 0, 4, 8, 12); @@ -215,7 +215,7 @@ extern "C" __global__ void chacha12_xor(uint32_t *d_ciphertext, thread_state[9] += state[9]; thread_state[10] += state[10]; thread_state[11] += state[11]; - thread_state[12] += state[12] + idx; - thread_state[13] += state[13]; + thread_state[12] += (uint32_t)(counter & 0xFFFFFFFF); + thread_state[13] += (uint32_t)(counter >> 32); thread_state[14] += state[14];
thread_state[15] += state[15]; diff --git a/iris-mpc-gpu/src/rng/field_fix.cu b/iris-mpc-gpu/src/rng/field_fix.cu deleted file mode 100644 index 3ddb26006..000000000 --- a/iris-mpc-gpu/src/rng/field_fix.cu +++ /dev/null @@ -1,34 +0,0 @@ -#define uint16_t unsigned short -#define uint32_t unsigned int -#define uint64_t unsigned long long - -#define P 65519 - -/** - * the chacha12_block function - */ -extern "C" __global__ void fix_fe(uint32_t *d_ciphertext, uint32_t valid_size) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - - if (idx * 1000 >= valid_size) { - return; - } - // each thread looks at 1000 elements - // and has 24 elements to copy from the fix array - - uint16_t *elements = (uint16_t *)d_ciphertext; - uint16_t *fix = &elements[valid_size]; - uint16_t *my_chunk = &elements[idx * 1000]; - uint16_t *my_fix = &fix[idx * 24]; - - int fix_idx = 0; - - for (int i = 0; i < 1000; i++) { - while (my_chunk[i] >= P) { - assert(fix_idx < 24); // should be bound with prob 2^-128 so this is fine - // to remove I guess - my_chunk[i] = my_fix[fix_idx]; - fix_idx++; - } - } -} diff --git a/iris-mpc-gpu/src/rng/mod.rs b/iris-mpc-gpu/src/rng/mod.rs index 4f32cc86b..abf67057f 100644 --- a/iris-mpc-gpu/src/rng/mod.rs +++ b/iris-mpc-gpu/src/rng/mod.rs @@ -1,3 +1,2 @@ -pub mod aes; pub mod chacha; pub mod chacha_corr;