diff --git a/.github/workflows/build-all-targets.yml b/.github/workflows/build-all-targets.yml index 368ec8d38..ad47edf97 100644 --- a/.github/workflows/build-all-targets.yml +++ b/.github/workflows/build-all-targets.yml @@ -4,7 +4,7 @@ on: push: concurrency: - group: '${{ github.workflow }} @ ${{ github.event.pull_request.head.label || github.head_ref || github.ref }}' + group: "${{ github.workflow }} @ ${{ github.event.pull_request.head.label || github.head_ref || github.ref }}" cancel-in-progress: true jobs: @@ -13,6 +13,8 @@ jobs: steps: - name: Checkout code uses: actions/checkout@v4 + - name: Install Dependencies + run: sudo apt install protobuf-compiler - name: Cache build products uses: Swatinem/rust-cache@v2.7.3 with: @@ -31,6 +33,8 @@ jobs: steps: - name: Checkout code uses: actions/checkout@v4 + - name: Install Dependencies + run: sudo apt install protobuf-compiler - name: Cache build products uses: Swatinem/rust-cache@v2.7.3 with: diff --git a/.github/workflows/doc.yml b/.github/workflows/doc.yml index b776b29cd..6a0d0ac49 100644 --- a/.github/workflows/doc.yml +++ b/.github/workflows/doc.yml @@ -4,7 +4,7 @@ on: push: concurrency: - group: '${{ github.workflow }} @ ${{ github.event.pull_request.head.label || github.head_ref || github.ref }}' + group: "${{ github.workflow }} @ ${{ github.event.pull_request.head.label || github.head_ref || github.ref }}" cancel-in-progress: true jobs: @@ -13,6 +13,8 @@ jobs: steps: - name: Checkout code uses: actions/checkout@v4 + - name: Install Dependencies + run: sudo apt install protobuf-compiler - name: Show errors inline uses: r7kamura/rust-problem-matchers@v1 - name: Install Rust nightly diff --git a/.github/workflows/lint-clippy.yaml b/.github/workflows/lint-clippy.yaml index 760c79520..d5620c21c 100644 --- a/.github/workflows/lint-clippy.yaml +++ b/.github/workflows/lint-clippy.yaml @@ -4,7 +4,7 @@ on: push: concurrency: - group: '${{ github.workflow }} @ ${{ github.event.pull_request.head.label || github.head_ref || github.ref }}' + group: "${{ github.workflow }} @ ${{ github.event.pull_request.head.label || github.head_ref || github.ref }}" cancel-in-progress: true permissions: @@ -21,6 +21,8 @@ jobs: steps: - name: Checkout code uses: actions/checkout@v4 + - name: Install Dependencies + run: sudo apt install protobuf-compiler - name: Install Rust nightly run: rustup toolchain install nightly-2024-07-10 - name: Set Rust nightly as default diff --git a/.github/workflows/run-unit-tests.yaml b/.github/workflows/run-unit-tests.yaml index 382511b2a..bf26e176e 100644 --- a/.github/workflows/run-unit-tests.yaml +++ b/.github/workflows/run-unit-tests.yaml @@ -5,10 +5,10 @@ on: branches: - main pull_request: - types: [ opened, synchronize ] + types: [opened, synchronize] concurrency: - group: '${{ github.workflow }} @ ${{ github.event.pull_request.head.label || github.head_ref || github.ref }}' + group: "${{ github.workflow }} @ ${{ github.event.pull_request.head.label || github.head_ref || github.ref }}" cancel-in-progress: true jobs: @@ -28,6 +28,8 @@ jobs: steps: - name: Checkout code uses: actions/checkout@v4 + - name: Install Dependencies + run: sudo apt install protobuf-compiler - name: Cache build products uses: Swatinem/rust-cache@v2.7.3 with: diff --git a/.github/workflows/temp-branch-build-and-push.yaml b/.github/workflows/temp-branch-build-and-push.yaml new file mode 100644 index 000000000..87b38a746 --- /dev/null +++ b/.github/workflows/temp-branch-build-and-push.yaml @@ -0,0 +1,47 @@ +name: Branch - Build and push docker image + +on: + push: + branches: + - "ps/potential-phantom-match" + +concurrency: + group: '${{ github.workflow }} @ ${{ github.event.pull_request.head.label || github.head_ref || github.ref }}' + cancel-in-progress: true + +env: + REGISTRY: ghcr.io + IMAGE_NAME: ${{ github.repository }} + +jobs: + docker: + runs-on: + labels: ubuntu-22.04-64core + permissions: + packages: write + contents: read + attestations: write + id-token: write + steps: + - name: Checkout + uses: actions/checkout@v4 + - name: Set up QEMU + uses: docker/setup-qemu-action@v3 + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + - name: Log in to the Container registry + uses: docker/login-action@v3 + with: + registry: ${{ env.REGISTRY }} + username: ${{ github.repository_owner }} + password: ${{ secrets.GITHUB_TOKEN }} + - name: Build and Push + uses: docker/build-push-action@v6 + with: + context: . + push: true + tags: | + ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${{ github.sha }} + platforms: linux/amd64 + cache-from: type=gha + cache-to: type=gha,mode=max diff --git a/.github/workflows/test-gpu.yaml b/.github/workflows/test-gpu.yaml index a0f9af714..633cf7010 100644 --- a/.github/workflows/test-gpu.yaml +++ b/.github/workflows/test-gpu.yaml @@ -30,8 +30,8 @@ jobs: sudo ln -sf /usr/bin/gcc-11 /usr/bin/gcc gcc --version - - name: Install OpenSSL && pkg-config - run: sudo apt-get update && sudo apt-get install -y pkg-config libssl-dev + - name: Install OpenSSL && pkg-config && protobuf-compiler + run: sudo apt-get update && sudo apt-get install -y pkg-config libssl-dev protobuf-compiler - name: Install CUDA and NCCL dependencies if: steps.cache-cuda-nccl.outputs.cache-hit != 'true' diff --git a/Cargo.lock b/Cargo.lock index fb3c06e6f..89face5ee 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 3 +version = 4 [[package]] name = "addr2line" @@ -42,7 +42,7 @@ dependencies = [ [[package]] name = "aes-prng" version = "0.2.1" -source = "git+https://github.com/tf-encrypted/aes-prng.git?branch=dragos/display#ebe79b1173ab6698c69d18dc464294f1893b44bb" +source = "git+https://github.com/tf-encrypted/aes-prng.git?branch=dragos%2Fdisplay#ebe79b1173ab6698c69d18dc464294f1893b44bb" dependencies = [ "aes", "byteorder", @@ -147,6 +147,12 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "anyhow" +version = "1.0.93" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c95c10ba0b00a02636238b814946408b1322d5ac4760326e6fb8ec956d85775" + [[package]] name = "arraydeque" version = "0.5.1" @@ -175,6 +181,28 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "async-stream" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b5a71a6f37880a80d1d7f19efd781e4b5de42c88f0722cc13bcb6cc2cfe8476" +dependencies = [ + "async-stream-impl", + "futures-core", + "pin-project-lite", +] + +[[package]] +name = "async-stream-impl" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.85", +] + [[package]] name = "async-trait" version = "0.1.83" @@ -209,9 +237,9 @@ checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" [[package]] name = "aws-config" -version = "1.5.9" +version = "1.5.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2d6448cfb224dd6a9b9ac734f58622dd0d4751f3589f3b777345745f46b2eb14" +checksum = "9b49afaa341e8dd8577e1a2200468f98956d6eda50bcf4a53246cc00174ba924" dependencies = [ "aws-credential-types", "aws-runtime", @@ -220,7 +248,7 @@ dependencies = [ "aws-sdk-sts", "aws-smithy-async", "aws-smithy-http", - "aws-smithy-json", + "aws-smithy-json 0.60.7", "aws-smithy-runtime", "aws-smithy-runtime-api", "aws-smithy-types", @@ -278,9 +306,9 @@ dependencies = [ [[package]] name = "aws-runtime" -version = "1.4.3" +version = "1.4.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a10d5c055aa540164d9561a0e2e74ad30f0dcf7393c3a92f6733ddf9c5762468" +checksum = "b5ac934720fbb46206292d2c75b57e67acfc56fe7dfd34fb9a02334af08409ea" dependencies = [ "aws-credential-types", "aws-sigv4", @@ -312,7 +340,7 @@ dependencies = [ "aws-runtime", "aws-smithy-async", "aws-smithy-http", - "aws-smithy-json", + "aws-smithy-json 0.60.7", "aws-smithy-runtime", "aws-smithy-runtime-api", "aws-smithy-types", @@ -326,9 +354,9 @@ dependencies = [ [[package]] name = "aws-sdk-s3" -version = "1.58.0" +version = "1.65.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0656a79cf5e6ab0d4bb2465cd750a7a2fd7ea26c062183ed94225f5782e22365" +checksum = "d3ba2c5c0f2618937ce3d4a5ad574b86775576fa24006bcb3128c6e2cbf3c34e" dependencies = [ "aws-credential-types", "aws-runtime", @@ -337,7 +365,7 @@ dependencies = [ "aws-smithy-checksums", "aws-smithy-eventstream", "aws-smithy-http", - "aws-smithy-json", + "aws-smithy-json 0.61.1", "aws-smithy-runtime", "aws-smithy-runtime-api", "aws-smithy-types", @@ -368,7 +396,7 @@ dependencies = [ "aws-runtime", "aws-smithy-async", "aws-smithy-http", - "aws-smithy-json", + "aws-smithy-json 0.60.7", "aws-smithy-runtime", "aws-smithy-runtime-api", "aws-smithy-types", @@ -391,7 +419,7 @@ dependencies = [ "aws-runtime", "aws-smithy-async", "aws-smithy-http", - "aws-smithy-json", + "aws-smithy-json 0.60.7", "aws-smithy-query", "aws-smithy-runtime", "aws-smithy-runtime-api", @@ -414,7 +442,7 @@ dependencies = [ "aws-runtime", "aws-smithy-async", "aws-smithy-http", - "aws-smithy-json", + "aws-smithy-json 0.60.7", "aws-smithy-runtime", "aws-smithy-runtime-api", "aws-smithy-types", @@ -428,15 +456,15 @@ dependencies = [ [[package]] name = "aws-sdk-sso" -version = "1.47.0" +version = "1.50.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8776850becacbd3a82a4737a9375ddb5c6832a51379f24443a98e61513f852c" +checksum = "05ca43a4ef210894f93096039ef1d6fa4ad3edfabb3be92b80908b9f2e4b4eab" dependencies = [ "aws-credential-types", "aws-runtime", "aws-smithy-async", "aws-smithy-http", - "aws-smithy-json", + "aws-smithy-json 0.61.1", "aws-smithy-runtime", "aws-smithy-runtime-api", "aws-smithy-types", @@ -450,15 +478,15 @@ dependencies = [ [[package]] name = "aws-sdk-ssooidc" -version = "1.48.0" +version = "1.51.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0007b5b8004547133319b6c4e87193eee2a0bcb3e4c18c75d09febe9dab7b383" +checksum = "abaf490c2e48eed0bb8e2da2fb08405647bd7f253996e0f93b981958ea0f73b0" dependencies = [ "aws-credential-types", "aws-runtime", "aws-smithy-async", "aws-smithy-http", - "aws-smithy-json", + "aws-smithy-json 0.61.1", "aws-smithy-runtime", "aws-smithy-runtime-api", "aws-smithy-types", @@ -472,15 +500,15 @@ dependencies = [ [[package]] name = "aws-sdk-sts" -version = "1.47.0" +version = "1.51.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fffaa356e7f1c725908b75136d53207fa714e348f365671df14e95a60530ad3" +checksum = "b68fde0d69c8bfdc1060ea7da21df3e39f6014da316783336deff0a9ec28f4bf" dependencies = [ "aws-credential-types", "aws-runtime", "aws-smithy-async", "aws-smithy-http", - "aws-smithy-json", + "aws-smithy-json 0.61.1", "aws-smithy-query", "aws-smithy-runtime", "aws-smithy-runtime-api", @@ -495,9 +523,9 @@ dependencies = [ [[package]] name = "aws-sigv4" -version = "1.2.5" +version = "1.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5619742a0d8f253be760bfbb8e8e8368c69e3587e4637af5754e488a611499b1" +checksum = "7d3820e0c08d0737872ff3c7c1f21ebbb6693d832312d6152bf18ef50a5471c2" dependencies = [ "aws-credential-types", "aws-smithy-eventstream", @@ -595,6 +623,15 @@ dependencies = [ "aws-smithy-types", ] +[[package]] +name = "aws-smithy-json" +version = "0.61.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee4e69cc50921eb913c6b662f8d909131bb3e6ad6cb6090d3a39b66fc5c52095" +dependencies = [ + "aws-smithy-types", +] + [[package]] name = "aws-smithy-query" version = "0.60.7" @@ -607,9 +644,9 @@ dependencies = [ [[package]] name = "aws-smithy-runtime" -version = "1.7.3" +version = "1.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be28bd063fa91fd871d131fc8b68d7cd4c5fa0869bea68daca50dcb1cbd76be2" +checksum = "9f20685047ca9d6f17b994a07f629c813f08b5bce65523e47124879e60103d45" dependencies = [ "aws-smithy-async", "aws-smithy-http", @@ -634,9 +671,9 @@ dependencies = [ [[package]] name = "aws-smithy-runtime-api" -version = "1.7.2" +version = "1.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e086682a53d3aa241192aa110fa8dfce98f2f5ac2ead0de84d41582c7e8fdb96" +checksum = "92165296a47a812b267b4f41032ff8069ab7ff783696d217f0994a0d7ab585cd" dependencies = [ "aws-smithy-async", "aws-smithy-types", @@ -651,9 +688,9 @@ dependencies = [ [[package]] name = "aws-smithy-types" -version = "1.2.8" +version = "1.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07c9cdc179e6afbf5d391ab08c85eac817b51c87e1892a5edb5f7bbdc64314b4" +checksum = "4fbd94a32b3a7d55d3806fe27d98d3ad393050439dd05eb53ece36ec5e3d3510" dependencies = [ "base64-simd", "bytes", @@ -726,7 +763,7 @@ dependencies = [ "serde_urlencoded", "sync_wrapper 1.0.1", "tokio", - "tower", + "tower 0.5.1", "tower-layer", "tower-service", "tracing", @@ -753,6 +790,20 @@ dependencies = [ "tracing", ] +[[package]] +name = "backoff" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b62ddb9cb1ec0a098ad4bbf9344d0713fa193ae1a80af55febcff2627b6a00c1" +dependencies = [ + "futures-core", + "getrandom", + "instant", + "pin-project-lite", + "rand", + "tokio", +] + [[package]] name = "backtrace" version = "0.3.71" @@ -886,7 +937,7 @@ dependencies = [ "base64 0.13.1", "bitvec", "hex", - "indexmap", + "indexmap 2.6.0", "js-sys", "once_cell", "rand", @@ -1399,6 +1450,27 @@ dependencies = [ "typenum", ] +[[package]] +name = "csv" +version = "1.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "acdc4883a9c96732e4733212c01447ebd805833b7275a73ca3ee080fd77afdaf" +dependencies = [ + "csv-core", + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "csv-core" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5efa2b3d7902f4b634a20cae3c9c4e6209dc4779feb6863329607560143efa70" +dependencies = [ + "memchr", +] + [[package]] name = "cudarc" version = "0.12.1" @@ -1755,6 +1827,12 @@ dependencies = [ "subtle", ] +[[package]] +name = "fixedbitset" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" + [[package]] name = "fixedbitset" version = "0.5.7" @@ -1869,7 +1947,7 @@ version = "7.6.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d9b724496da7c26fcce66458526ce68fc2ecf4aaaa994281cf322ded5755520c" dependencies = [ - "fixedbitset", + "fixedbitset 0.5.7", "futures-buffered", "futures-core", "futures-lite", @@ -2043,7 +2121,7 @@ dependencies = [ "futures-sink", "futures-util", "http 0.2.12", - "indexmap", + "indexmap 2.6.0", "slab", "tokio", "tokio-util", @@ -2062,7 +2140,7 @@ dependencies = [ "futures-core", "futures-sink", "http 1.1.0", - "indexmap", + "indexmap 2.6.0", "slab", "tokio", "tokio-util", @@ -2079,6 +2157,12 @@ dependencies = [ "crunchy", ] +[[package]] +name = "hashbrown" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" + [[package]] name = "hashbrown" version = "0.14.5" @@ -2121,7 +2205,7 @@ dependencies = [ [[package]] name = "hawk-pack" version = "0.1.0" -source = "git+https://github.com/Inversed-Tech/hawk-pack.git?rev=4e6de24#4e6de24f7422923f8cccd8571ef03407e8dbbb99" +source = "git+https://github.com/Inversed-Tech/hawk-pack.git?rev=ba995e09#ba995e096116a564a0cc8fc43a6b75b2515c34ff" dependencies = [ "aes-prng 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)", "criterion", @@ -2130,6 +2214,7 @@ dependencies = [ "futures", "rand", "rand_core", + "rand_distr", "serde", "serde_json", "sqlx", @@ -2364,7 +2449,7 @@ dependencies = [ "hyper 1.5.0", "hyper-util", "log", - "rustls 0.23.16", + "rustls 0.23.18", "rustls-native-certs 0.8.0", "rustls-pki-types", "tokio", @@ -2372,6 +2457,19 @@ dependencies = [ "tower-service", ] +[[package]] +name = "hyper-timeout" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b90d566bffbce6a75bd8b09a05aa8c2cb1fabb6cb348f8840c9e4c90a0d83b0" +dependencies = [ + "hyper 1.5.0", + "hyper-util", + "pin-project-lite", + "tokio", + "tower-service", +] + [[package]] name = "hyper-tls" version = "0.5.0" @@ -2476,6 +2574,16 @@ version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ce23b50ad8242c51a442f3ff322d56b02f08852c77e4c0b4d3fd684abc89c683" +[[package]] +name = "indexmap" +version = "1.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" +dependencies = [ + "autocfg", + "hashbrown 0.12.3", +] + [[package]] name = "indexmap" version = "2.6.0" @@ -2546,10 +2654,12 @@ name = "iris-mpc" version = "0.1.0" dependencies = [ "aws-config", + "aws-sdk-s3", "aws-sdk-sns", "aws-sdk-sqs", "axum", "base64 0.22.1", + "bincode", "bytemuck", "clap", "criterion", @@ -2564,6 +2674,7 @@ dependencies = [ "ndarray", "rand", "reqwest 0.12.9", + "serde", "serde_json", "sha2", "sodiumoxide", @@ -2580,6 +2691,7 @@ name = "iris-mpc-common" version = "0.1.0" dependencies = [ "aws-config", + "aws-credential-types", "aws-sdk-kms", "aws-sdk-s3", "aws-sdk-secretsmanager", @@ -2624,12 +2736,15 @@ dependencies = [ name = "iris-mpc-cpu" version = "0.1.0" dependencies = [ - "aes-prng 0.2.1 (git+https://github.com/tf-encrypted/aes-prng.git?branch=dragos/display)", + "aes-prng 0.2.1 (git+https://github.com/tf-encrypted/aes-prng.git?branch=dragos%2Fdisplay)", "async-channel", + "async-stream", "async-trait", + "backoff", "bincode", "bytemuck", "bytes", + "clap", "criterion", "dashmap", "eyre", @@ -2638,13 +2753,20 @@ dependencies = [ "iris-mpc-common", "itertools 0.13.0", "num-traits", + "prost", "rand", "rstest", "serde", + "serde_json", "static_assertions", "tokio", + "tokio-stream", + "tonic", + "tonic-build", "tracing", + "tracing-subscriber", "tracing-test", + "uuid", ] [[package]] @@ -2663,6 +2785,7 @@ dependencies = [ "hex", "iris-mpc-common", "itertools 0.13.0", + "memmap2", "metrics 0.22.3", "metrics-exporter-statsd 0.7.0", "ndarray", @@ -2681,17 +2804,34 @@ dependencies = [ "uuid", ] +[[package]] +name = "iris-mpc-py" +version = "0.1.0" +dependencies = [ + "hawk-pack", + "iris-mpc-common", + "iris-mpc-cpu", + "pyo3", + "rand", +] + [[package]] name = "iris-mpc-store" version = "0.1.0" dependencies = [ + "async-trait", + "aws-sdk-s3", "bytemuck", + "bytes", + "csv", "dotenvy", "eyre", "futures", + "hex", "iris-mpc-common", "itertools 0.13.0", "rand", + "rayon", "serde", "serde_json", "sqlx", @@ -2710,19 +2850,25 @@ dependencies = [ "float_eq", "futures", "futures-concurrency", + "hkdf", "indicatif", "iris-mpc-common", "iris-mpc-store", "itertools 0.13.0", "mpc", + "prost", "rand", "rand_chacha", "rcgen", "serde", "serde-big-array", + "sha2", "sqlx", + "thiserror", "tokio", "tokio-native-tls", + "tonic", + "tonic-build", "tracing", "tracing-subscriber", ] @@ -2997,6 +3143,24 @@ version = "2.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" +[[package]] +name = "memmap2" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fd3f7eed9d3848f8b98834af67102b720745c4ec028fcd0aa0239277e7de374f" +dependencies = [ + "libc", +] + +[[package]] +name = "memoffset" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a" +dependencies = [ + "autocfg", +] + [[package]] name = "metrics" version = "0.22.3" @@ -3028,7 +3192,7 @@ dependencies = [ "hyper 1.5.0", "hyper-rustls 0.27.3", "hyper-util", - "indexmap", + "indexmap 2.6.0", "ipnet", "metrics 0.23.0", "metrics-util", @@ -3204,6 +3368,12 @@ dependencies = [ "url", ] +[[package]] +name = "multimap" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "defc4c55412d89136f966bbb339008b474350e5e6e78d2714439c386b3137a03" + [[package]] name = "native-tls" version = "0.2.12" @@ -3422,7 +3592,7 @@ dependencies = [ "ahash", "futures-core", "http 1.1.0", - "indexmap", + "indexmap 2.6.0", "itertools 0.11.0", "itoa", "once_cell", @@ -3651,6 +3821,16 @@ dependencies = [ "sha2", ] +[[package]] +name = "petgraph" +version = "0.6.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4c5cc86750666a3ed20bdaf5ca2a0344f9c67674cae0515bec2da16fbaa47db" +dependencies = [ + "fixedbitset 0.4.2", + "indexmap 2.6.0", +] + [[package]] name = "pin-project" version = "1.1.7" @@ -3806,6 +3986,122 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "prost" +version = "0.13.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b0487d90e047de87f984913713b85c601c05609aad5b0df4b4573fbf69aa13f" +dependencies = [ + "bytes", + "prost-derive", +] + +[[package]] +name = "prost-build" +version = "0.13.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c1318b19085f08681016926435853bbf7858f9c082d0999b80550ff5d9abe15" +dependencies = [ + "bytes", + "heck 0.5.0", + "itertools 0.13.0", + "log", + "multimap", + "once_cell", + "petgraph", + "prettyplease", + "prost", + "prost-types", + "regex", + "syn 2.0.85", + "tempfile", +] + +[[package]] +name = "prost-derive" +version = "0.13.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e9552f850d5f0964a4e4d0bf306459ac29323ddfbae05e35a7c0d35cb0803cc5" +dependencies = [ + "anyhow", + "itertools 0.13.0", + "proc-macro2", + "quote", + "syn 2.0.85", +] + +[[package]] +name = "prost-types" +version = "0.13.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4759aa0d3a6232fb8dbdb97b61de2c20047c68aca932c7ed76da9d788508d670" +dependencies = [ + "prost", +] + +[[package]] +name = "pyo3" +version = "0.22.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f402062616ab18202ae8319da13fa4279883a2b8a9d9f83f20dbade813ce1884" +dependencies = [ + "cfg-if", + "indoc", + "libc", + "memoffset", + "once_cell", + "portable-atomic", + "pyo3-build-config", + "pyo3-ffi", + "pyo3-macros", + "unindent", +] + +[[package]] +name = "pyo3-build-config" +version = "0.22.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b14b5775b5ff446dd1056212d778012cbe8a0fbffd368029fd9e25b514479c38" +dependencies = [ + "once_cell", + "target-lexicon", +] + +[[package]] +name = "pyo3-ffi" +version = "0.22.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ab5bcf04a2cdcbb50c7d6105de943f543f9ed92af55818fd17b660390fc8636" +dependencies = [ + "libc", + "pyo3-build-config", +] + +[[package]] +name = "pyo3-macros" +version = "0.22.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fd24d897903a9e6d80b968368a34e1525aeb719d568dba8b3d4bfa5dc67d453" +dependencies = [ + "proc-macro2", + "pyo3-macros-backend", + "quote", + "syn 2.0.85", +] + +[[package]] +name = "pyo3-macros-backend" +version = "0.22.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "36c011a03ba1e50152b4b394b479826cad97e7a21eb52df179cd91ac411cbfbe" +dependencies = [ + "heck 0.5.0", + "proc-macro2", + "pyo3-build-config", + "quote", + "syn 2.0.85", +] + [[package]] name = "quanta" version = "0.12.3" @@ -3872,6 +4168,16 @@ dependencies = [ "getrandom", ] +[[package]] +name = "rand_distr" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32cb0b9bc82b0a0876c2dd994a7e7a2683d3e7390ca40e6886785ef0c7e3ee31" +dependencies = [ + "num-traits", + "rand", +] + [[package]] name = "raw-cpuid" version = "11.2.0" @@ -4266,13 +4572,14 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.16" +version = "0.23.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eee87ff5d9b36712a58574e12e9f0ea80f915a5b0ac518d322b24a465617925e" +checksum = "9c9cc1d47e243d655ace55ed38201c19ae02c148ae56412ab8750e8f0166ab7f" dependencies = [ "aws-lc-rs", "log", "once_cell", + "ring", "rustls-pki-types", "rustls-webpki 0.102.8", "subtle", @@ -4504,7 +4811,7 @@ version = "1.0.132" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d726bfaff4b320266d395898905d0eba0345aae23b54aee3a737e260fd46db03" dependencies = [ - "indexmap", + "indexmap 2.6.0", "itoa", "memchr", "ryu", @@ -4771,7 +5078,7 @@ dependencies = [ "hashbrown 0.14.5", "hashlink 0.9.1", "hex", - "indexmap", + "indexmap 2.6.0", "log", "memchr", "native-tls", @@ -5061,6 +5368,12 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" +[[package]] +name = "target-lexicon" +version = "0.12.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" + [[package]] name = "telemetry-batteries" version = "0.1.0" @@ -5263,7 +5576,7 @@ version = "0.26.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c7bc40d0e5a97695bb96e27995cd3a08538541b0a846f65bba7a359f36700d4" dependencies = [ - "rustls 0.23.16", + "rustls 0.23.18", "rustls-pki-types", "tokio", ] @@ -5320,13 +5633,80 @@ version = "0.22.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4ae48d6208a266e853d946088ed816055e556cc6028c5e8e2b84d9fa5dd7c7f5" dependencies = [ - "indexmap", + "indexmap 2.6.0", "serde", "serde_spanned", "toml_datetime", "winnow", ] +[[package]] +name = "tonic" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "877c5b330756d856ffcc4553ab34a5684481ade925ecc54bcd1bf02b1d0d4d52" +dependencies = [ + "async-stream", + "async-trait", + "axum", + "base64 0.22.1", + "bytes", + "h2 0.4.6", + "http 1.1.0", + "http-body 1.0.1", + "http-body-util", + "hyper 1.5.0", + "hyper-timeout", + "hyper-util", + "percent-encoding", + "pin-project", + "prost", + "rustls-native-certs 0.8.0", + "rustls-pemfile 2.2.0", + "socket2 0.5.7", + "tokio", + "tokio-rustls 0.26.0", + "tokio-stream", + "tower 0.4.13", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "tonic-build" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9557ce109ea773b399c9b9e5dca39294110b74f1f342cb347a80d1fce8c26a11" +dependencies = [ + "prettyplease", + "proc-macro2", + "prost-build", + "prost-types", + "quote", + "syn 2.0.85", +] + +[[package]] +name = "tower" +version = "0.4.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c" +dependencies = [ + "futures-core", + "futures-util", + "indexmap 1.9.3", + "pin-project", + "pin-project-lite", + "rand", + "slab", + "tokio", + "tokio-util", + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = "tower" version = "0.5.1" @@ -5610,6 +5990,12 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e" +[[package]] +name = "unindent" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce" + [[package]] name = "untrusted" version = "0.9.0" diff --git a/Cargo.toml b/Cargo.toml index 7416fa873..765c0c9a0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,6 +6,7 @@ members = [ "iris-mpc-common", "iris-mpc-upgrade", "iris-mpc-store", + "iris-mpc-py", ] resolver = "2" @@ -15,24 +16,31 @@ license = "MIT OR (Apache-2.0 WITH LLVM-exception)" repository = "https://github.com/worldcoin/iris-mpc" [workspace.dependencies] -aws-config = { version = "1.5.4", features = ["behavior-version-latest"] } +aws-config = { version = "1.5.10", features = ["behavior-version-latest"] } aws-sdk-kms = { version = "1.44.0" } aws-sdk-sns = { version = "1.44.0" } aws-sdk-sqs = { version = "1.36.0" } -aws-sdk-s3 = { version = "1.50.0" } +aws-sdk-s3 = { version = "1.65.0" } aws-sdk-secretsmanager = { version = "1.47.0" } +async-trait = "0.1.83" axum = "0.7" clap = { version = "4", features = ["derive", "env"] } +csv = "1.3.1" base64 = "0.22.1" +bytes = "1.5" bytemuck = { version = "1.17", features = ["derive"] } dotenvy = "0.15" eyre = "0.6" futures = "0.3.30" +hawk-pack = { git = "https://github.com/Inversed-Tech/hawk-pack.git", rev = "ba995e09" } hex = "0.4.3" itertools = "0.13" num-traits = "0.2" +memmap2 = "0.9.5" serde = { version = "1.0", features = ["derive"] } +serde-big-array = "0.5.1" serde_json = "1" +bincode = "1.3.3" sqlx = { version = "0.8", features = ["runtime-tokio-native-tls", "postgres"] } tracing = "0.1.40" tracing-subscriber = { version = "0.3.15", features = ["env-filter"] } diff --git a/Dockerfile b/Dockerfile index 4ea9b9ef3..147996daf 100644 --- a/Dockerfile +++ b/Dockerfile @@ -12,6 +12,7 @@ RUN apt-get update && apt-get install -y \ devscripts \ debhelper \ ca-certificates \ + protobuf-compiler \ wget RUN curl https://sh.rustup.rs -sSf | sh -s -- -y @@ -27,7 +28,7 @@ RUN cargo install cargo-build-deps \ FROM --platform=linux/amd64 build-image as build-app WORKDIR /src/gpu-iris-mpc COPY . . -RUN cargo build --release --target x86_64-unknown-linux-gnu --bin nccl --bin server --bin client --bin key-manager --bin upgrade-server --bin upgrade-client --bin upgrade-checker +RUN cargo build --release --target x86_64-unknown-linux-gnu --bin nccl --bin server --bin client --bin key-manager --bin upgrade-server --bin upgrade-client --bin upgrade-checker --bin reshare-server --bin reshare-client FROM --platform=linux/amd64 ghcr.io/worldcoin/iris-mpc-base:cuda12_2-nccl2_22_3_1 ENV DEBIAN_FRONTEND=noninteractive @@ -40,6 +41,8 @@ COPY --from=build-app /src/gpu-iris-mpc/target/x86_64-unknown-linux-gnu/release/ COPY --from=build-app /src/gpu-iris-mpc/target/x86_64-unknown-linux-gnu/release/upgrade-server /bin/upgrade-server COPY --from=build-app /src/gpu-iris-mpc/target/x86_64-unknown-linux-gnu/release/upgrade-client /bin/upgrade-client COPY --from=build-app /src/gpu-iris-mpc/target/x86_64-unknown-linux-gnu/release/upgrade-checker /bin/upgrade-checker +COPY --from=build-app /src/gpu-iris-mpc/target/x86_64-unknown-linux-gnu/release/reshare-server /bin/reshare-server +COPY --from=build-app /src/gpu-iris-mpc/target/x86_64-unknown-linux-gnu/release/reshare-client /bin/reshare-client USER 65534 ENTRYPOINT ["/bin/server"] diff --git a/Dockerfile.base b/Dockerfile.base index c82f08130..5ddd4c845 100644 --- a/Dockerfile.base +++ b/Dockerfile.base @@ -1,6 +1,6 @@ FROM --platform=linux/amd64 ubuntu:22.04 as build-image -RUN apt-get update && apt-get install -y pkg-config wget libssl-dev ca-certificates \ +RUN apt-get update && apt-get install -y pkg-config wget libssl-dev ca-certificates protobuf-compiler \ && rm -rf /var/lib/apt/lists/* RUN wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb \ && dpkg -i cuda-keyring_1.1-1_all.deb \ @@ -41,8 +41,8 @@ RUN cd /tmp \ && git checkout ${AWS_OFI_NCCL_VERSION} \ && ./autogen.sh \ && ./configure --prefix=/opt/aws-ofi-nccl/install \ - --with-libfabric=/opt/amazon/efa/ \ - --with-cuda=/usr/local/cuda \ - --with-nccl=/tmp/nccl/build \ - --with-mpi=/opt/amazon/openmpi/ \ + --with-libfabric=/opt/amazon/efa/ \ + --with-cuda=/usr/local/cuda \ + --with-nccl=/tmp/nccl/build \ + --with-mpi=/opt/amazon/openmpi/ \ && make && make install diff --git a/Dockerfile.debug b/Dockerfile.debug index 83758e93c..2e1a1bd76 100644 --- a/Dockerfile.debug +++ b/Dockerfile.debug @@ -14,6 +14,7 @@ RUN apt-get update && apt-get install -y \ devscripts \ debhelper \ ca-certificates \ + protobuf-compiler \ wget RUN curl https://sh.rustup.rs -sSf | sh -s -- -y diff --git a/Dockerfile.nccl b/Dockerfile.nccl index edf79ecc3..d6c5a3db7 100644 --- a/Dockerfile.nccl +++ b/Dockerfile.nccl @@ -6,6 +6,7 @@ RUN apt-get update && apt-get install -y \ build-essential \ libssl-dev \ ca-certificates \ + protobuf-compiler \ wget RUN wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb \ diff --git a/Dockerfile.nocuda b/Dockerfile.nocuda index 00adf1498..13326ab73 100644 --- a/Dockerfile.nocuda +++ b/Dockerfile.nocuda @@ -12,6 +12,7 @@ RUN apt-get update && apt-get install -y \ devscripts \ debhelper \ ca-certificates \ + protobuf-compiler \ wget RUN curl https://sh.rustup.rs -sSf | sh -s -- -y @@ -28,12 +29,12 @@ FROM --platform=linux/amd64 build-image as build-app WORKDIR /src/gpu-iris-mpc COPY . . -RUN cargo build --release --target x86_64-unknown-linux-gnu --bin seed-v1-dbs --bin upgrade-server --bin upgrade-client --bin upgrade-checker +RUN cargo build --release --target x86_64-unknown-linux-gnu --bin seed-v1-dbs --bin upgrade-server --bin upgrade-client --bin upgrade-checker --bin reshare-server --bin key-manager FROM --platform=linux/amd64 ubuntu:22.04 ENV DEBIAN_FRONTEND=noninteractive -RUN apt-get update && apt-get install -y ca-certificates +RUN apt-get update && apt-get install -y ca-certificates awscli COPY certs /usr/local/share/ca-certificates/ RUN update-ca-certificates @@ -41,6 +42,8 @@ COPY --from=build-app /src/gpu-iris-mpc/target/x86_64-unknown-linux-gnu/release/ COPY --from=build-app /src/gpu-iris-mpc/target/x86_64-unknown-linux-gnu/release/upgrade-server /bin/upgrade-server COPY --from=build-app /src/gpu-iris-mpc/target/x86_64-unknown-linux-gnu/release/upgrade-client /bin/upgrade-client COPY --from=build-app /src/gpu-iris-mpc/target/x86_64-unknown-linux-gnu/release/upgrade-checker /bin/upgrade-checker +COPY --from=build-app /src/gpu-iris-mpc/target/x86_64-unknown-linux-gnu/release/reshare-server /bin/reshare-server +COPY --from=build-app /src/gpu-iris-mpc/target/x86_64-unknown-linux-gnu/release/key-manager /bin/key-manager USER 65534 ENTRYPOINT ["/bin/upgrade-server"] diff --git a/Dockerfile.shares-encoding b/Dockerfile.shares-encoding index 719a5d2e7..ee7a39d4f 100644 --- a/Dockerfile.shares-encoding +++ b/Dockerfile.shares-encoding @@ -12,6 +12,7 @@ RUN apt-get update && apt-get install -y \ devscripts \ debhelper \ ca-certificates \ + protobuf-compiler \ wget RUN curl https://sh.rustup.rs -sSf | sh -s -- -y diff --git a/deny.toml b/deny.toml index 26dd67235..82f47c613 100644 --- a/deny.toml +++ b/deny.toml @@ -5,10 +5,12 @@ all-features = true [advisories] version = 2 ignore = [ - { id = "RUSTSEC-2021-0137", reason = "we will switch to alkali eventually" }, - # https://github.com/mehcode/config-rs/issues/563 - { id = "RUSTSEC-2024-0384", reason = "waiting for `web-time` crate to remove the dependency" }, - { id = "RUSTSEC-2024-0388", reason = "waiting for `mongodb` crate to remove the deprecated dependency" }, + { id = "RUSTSEC-2021-0137", reason = "we will switch to alkali eventually" }, + # https://github.com/mehcode/config-rs/issues/563 + { id = "RUSTSEC-2024-0384", reason = "waiting for `web-time` crate to remove the dependency" }, + { id = "RUSTSEC-2024-0388", reason = "waiting for `mongodb` crate to remove the deprecated dependency" }, + { id = "RUSTSEC-2024-0402", reason = "wating for `index-map` crate to remove the dependency" }, + { id = "RUSTSEC-2024-0421", reason = "waiting for `mongodb` crate to remove the deprecated dependency" }, ] [sources] @@ -35,6 +37,7 @@ allow = [ "MIT", "MPL-2.0", # Although this is copyleft, it is scoped to modifying the original files "OpenSSL", + "Unicode-3.0", "Unicode-DFS-2016", "Unlicense", "Zlib", diff --git a/deploy/e2e/iris-mpc-0.yaml.tpl b/deploy/e2e/iris-mpc-0.yaml.tpl new file mode 100644 index 000000000..72e38feb1 --- /dev/null +++ b/deploy/e2e/iris-mpc-0.yaml.tpl @@ -0,0 +1,221 @@ +iris-mpc-0: + fullnameOverride: "iris-mpc-0" + image: "ghcr.io/worldcoin/iris-mpc:$IRIS_MPC_IMAGE_TAG" + + environment: $ENV + replicaCount: 1 + + strategy: + type: Recreate + + datadog: + enabled: false + + ports: + - containerPort: 3000 + name: health + protocol: TCP + + livenessProbe: + httpGet: + path: /health + port: health + + readinessProbe: + periodSeconds: 30 + httpGet: + path: /ready + port: health + + startupProbe: + initialDelaySeconds: 60 + failureThreshold: 40 + periodSeconds: 30 + httpGet: + path: /ready + port: health + + podSecurityContext: + runAsNonRoot: false + seccompProfile: + type: RuntimeDefault + + resources: + limits: + cpu: 31 + memory: 60Gi + nvidia.com/gpu: 1 + + requests: + cpu: 30 + memory: 55Gi + nvidia.com/gpu: 1 + + imagePullSecrets: + - name: github-secret + + nodeSelector: + kubernetes.io/arch: amd64 + + hostNetwork: false + + tolerations: + - key: "gpuGroup" + operator: "Equal" + value: "dedicated" + effect: "NoSchedule" + + keelPolling: + # -- Specifies whether keel should poll for container updates + enabled: true + + libsDir: + enabled: true + path: "/libs" + size: 2Gi + files: + - path: "/usr/local/cuda-12.2/targets/x86_64-linux/lib/libcublasLt.so.12.2.5.6" + file: "libcublasLt.so.12.2.5.6" + - path: "/usr/local/cuda-12.2/targets/x86_64-linux/lib/libcublas.so.12.2.5.6" + file: "libcublas.so.12.2.5.6" + + preStop: + # preStop.sleepPeriod specifies the time spent in Terminating state before SIGTERM is sent + sleepPeriod: 10 + + # terminationGracePeriodSeconds specifies the grace time between SIGTERM and SIGKILL + terminationGracePeriodSeconds: 180 # 3x SMPC__PROCESSING_TIMEOUT_SECS + + env: + - name: RUST_LOG + value: "info" + + - name: AWS_REGION + value: "$AWS_REGION" + + - name: AWS_ENDPOINT_URL + value: "http://localstack:4566" + + - name: RUST_BACKTRACE + value: "full" + + - name: NCCL_SOCKET_IFNAME + value: "eth0" + + - name: NCCL_COMM_ID + value: "iris-mpc-0.svc.cluster.local:4000" + + - name: SMPC__ENVIRONMENT + value: "$ENV" + + - name: SMPC__AWS__REGION + value: "$AWS_REGION" + + - name: SMPC__SERVICE__SERVICE_NAME + value: "smpcv2-server-$ENV" + + - name: SMPC__DATABASE__URL + valueFrom: + secretKeyRef: + key: DATABASE_AURORA_URL + name: application + + - name: SMPC__DATABASE__MIGRATE + value: "true" + + - name: SMPC__DATABASE__CREATE + value: "true" + + - name: SMPC__DATABASE__LOAD_PARALLELISM + value: "8" + + - name: SMPC__REQUESTS_QUEUE_URL + value: "arn:aws:sns:eu-central-1:000000000000:iris-mpc-input" + + - name: SMPC__RESULTS_TOPIC_ARN + value: "arn:aws:sns:eu-central-1:000000000000:iris-mpc-results" + + - name: SMPC__PROCESSING_TIMEOUT_SECS + value: "60" + + - name: SMPC__PATH + value: "/data/" + + - name: SMPC__KMS_KEY_ARNS + value: '["arn:aws:kms:$AWS_REGION:000000000000:key/00000000-0000-0000-0000-000000000000","arn:aws:kms:$AWS_REGION:000000000000:key/00000000-0000-0000-0000-000000000001","arn:aws:kms:$AWS_REGION:000000000000:key/00000000-0000-0000-0000-000000000002"]' + + - name: SMPC__PARTY_ID + value: "0" + + - name: SMPC__PUBLIC_KEY_BASE_URL + value: "http://wf-$ENV-public-keys.s3.localhost.localstack.cloud:4566" + + - name: SMPC__ENABLE_S3_IMPORTER + value: "false" + + - name: SMPC__SHARES_BUCKET_NAME + value: "wf-smpcv2-stage-sns-requests" + + - name: SMPC__CLEAR_DB_BEFORE_INIT + value: "true" + + - name: SMPC__INIT_DB_SIZE + value: "80000" + + - name: SMPC__MAX_DB_SIZE + value: "110000" + + - name: SMPC__MAX_BATCH_SIZE + value: "64" + + - name: SMPC__SERVICE__METRICS__HOST + valueFrom: + fieldRef: + fieldPath: status.hostIP + + - name: SMPC__SERVICE__METRICS__PORT + value: "8125" + + - name: SMPC__SERVICE__METRICS__QUEUE_SIZE + value: "5000" + + - name: SMPC__SERVICE__METRICS__BUFFER_SIZE + value: "256" + + - name: SMPC__SERVICE__METRICS__PREFIX + value: "smpcv2-$ENV-0" + + - name: SMPC__RETURN_PARTIAL_RESULTS + value: "true" + + - name: SMPC__NODE_HOSTNAMES + value: '["iris-mpc-0.svc.cluster.local","iris-mpc-1.svc.cluster.local","iris-mpc-2.svc.cluster.local"]' + + - name: SMPC__IMAGE_NAME + value: "ghcr.io/worldcoin/iris-mpc:$IRIS_MPC_IMAGE_TAG" + + initContainer: + enabled: true + image: "ghcr.io/worldcoin/iris-mpc:2694d8cbb37c278ed84951ef9aac3af47b21f146" # no-cuda image + name: "iris-mpc-0-copy-cuda-libs" + env: + - name: AWS_REGION + value: "$AWS_REGION" + - name: PARTY_ID + value: "1" + - name: MY_NODE_IP + valueFrom: + fieldRef: + fieldPath: status.hostIP + configMap: + name: "iris-mpc-0-init" + init.sh: | + #!/usr/bin/env bash + set -e + + cd /libs + + aws s3 cp s3://wf-smpcv2-stage-libs/libcublas.so.12.2.5.6 . + aws s3 cp s3://wf-smpcv2-stage-libs/libcublasLt.so.12.2.5.6 . + + key-manager --node-id 0 --env $ENV --endpoint-url "http://localstack:4566" rotate --public-key-bucket-name wf-$ENV-stage-public-keys --region $AWS_REGION diff --git a/deploy/e2e/iris-mpc-1.yaml.tpl b/deploy/e2e/iris-mpc-1.yaml.tpl new file mode 100644 index 000000000..15b3cd127 --- /dev/null +++ b/deploy/e2e/iris-mpc-1.yaml.tpl @@ -0,0 +1,221 @@ +iris-mpc-1: + fullnameOverride: "iris-mpc-1" + image: "ghcr.io/worldcoin/iris-mpc:$IRIS_MPC_IMAGE_TAG" + + environment: $ENV + replicaCount: 1 + + strategy: + type: Recreate + + datadog: + enabled: false + + ports: + - containerPort: 3000 + name: health + protocol: TCP + + livenessProbe: + httpGet: + path: /health + port: health + + readinessProbe: + periodSeconds: 30 + httpGet: + path: /ready + port: health + + startupProbe: + initialDelaySeconds: 60 + failureThreshold: 40 + periodSeconds: 30 + httpGet: + path: /ready + port: health + + podSecurityContext: + runAsNonRoot: false + seccompProfile: + type: RuntimeDefault + + resources: + limits: + cpu: 31 + memory: 60Gi + nvidia.com/gpu: 1 + + requests: + cpu: 30 + memory: 55Gi + nvidia.com/gpu: 1 + + imagePullSecrets: + - name: github-secret + + nodeSelector: + kubernetes.io/arch: amd64 + + hostNetwork: false + + tolerations: + - key: "gpuGroup" + operator: "Equal" + value: "dedicated" + effect: "NoSchedule" + + keelPolling: + # -- Specifies whether keel should poll for container updates + enabled: true + + libsDir: + enabled: true + path: "/libs" + size: 2Gi + files: + - path: "/usr/local/cuda-12.2/targets/x86_64-linux/lib/libcublasLt.so.12.2.5.6" + file: "libcublasLt.so.12.2.5.6" + - path: "/usr/local/cuda-12.2/targets/x86_64-linux/lib/libcublas.so.12.2.5.6" + file: "libcublas.so.12.2.5.6" + + preStop: + # preStop.sleepPeriod specifies the time spent in Terminating state before SIGTERM is sent + sleepPeriod: 10 + + # terminationGracePeriodSeconds specifies the grace time between SIGTERM and SIGKILL + terminationGracePeriodSeconds: 180 # 3x SMPC__PROCESSING_TIMEOUT_SECS + + env: + - name: RUST_LOG + value: "info" + + - name: AWS_REGION + value: "$AWS_REGION" + + - name: AWS_ENDPOINT_URL + value: "http://localstack:4566" + + - name: RUST_BACKTRACE + value: "full" + + - name: NCCL_SOCKET_IFNAME + value: "eth0" + + - name: NCCL_COMM_ID + value: "iris-mpc-1.svc.cluster.local:4000" + + - name: SMPC__ENVIRONMENT + value: "$ENV" + + - name: SMPC__AWS__REGION + value: "$AWS_REGION" + + - name: SMPC__SERVICE__SERVICE_NAME + value: "smpcv2-server-$ENV" + + - name: SMPC__DATABASE__URL + valueFrom: + secretKeyRef: + key: DATABASE_AURORA_URL + name: application + + - name: SMPC__DATABASE__MIGRATE + value: "true" + + - name: SMPC__DATABASE__CREATE + value: "true" + + - name: SMPC__DATABASE__LOAD_PARALLELISM + value: "8" + + - name: SMPC__REQUESTS_QUEUE_URL + value: "arn:aws:sns:eu-central-1:000000000000:iris-mpc-input" + + - name: SMPC__RESULTS_TOPIC_ARN + value: "arn:aws:sns:eu-central-1:000000000000:iris-mpc-results" + + - name: SMPC__PROCESSING_TIMEOUT_SECS + value: "60" + + - name: SMPC__PATH + value: "/data/" + + - name: SMPC__KMS_KEY_ARNS + value: '["arn:aws:kms:$AWS_REGION:000000000000:key/00000000-0000-0000-0000-000000000000","arn:aws:kms:$AWS_REGION:000000000000:key/00000000-0000-0000-0000-000000000001","arn:aws:kms:$AWS_REGION:000000000000:key/00000000-0000-0000-0000-000000000002"]' + + - name: SMPC__PARTY_ID + value: "1" + + - name: SMPC__PUBLIC_KEY_BASE_URL + value: "http://wf-$ENV-public-keys.s3.localhost.localstack.cloud:4566" + + - name: SMPC__ENABLE_S3_IMPORTER + value: "false" + + - name: SMPC__SHARES_BUCKET_NAME + value: "wf-smpcv2-stage-sns-requests" + + - name: SMPC__CLEAR_DB_BEFORE_INIT + value: "true" + + - name: SMPC__INIT_DB_SIZE + value: "80000" + + - name: SMPC__MAX_DB_SIZE + value: "110000" + + - name: SMPC__MAX_BATCH_SIZE + value: "64" + + - name: SMPC__SERVICE__METRICS__HOST + valueFrom: + fieldRef: + fieldPath: status.hostIP + + - name: SMPC__SERVICE__METRICS__PORT + value: "8125" + + - name: SMPC__SERVICE__METRICS__QUEUE_SIZE + value: "5000" + + - name: SMPC__SERVICE__METRICS__BUFFER_SIZE + value: "256" + + - name: SMPC__SERVICE__METRICS__PREFIX + value: "smpcv2-$ENV-1" + + - name: SMPC__RETURN_PARTIAL_RESULTS + value: "true" + + - name: SMPC__NODE_HOSTNAMES + value: '["iris-mpc-0.svc.cluster.local","iris-mpc-1.svc.cluster.local","iris-mpc-2.svc.cluster.local"]' + + - name: SMPC__IMAGE_NAME + value: "ghcr.io/worldcoin/iris-mpc:$IRIS_MPC_IMAGE_TAG" + + initContainer: + enabled: true + image: "ghcr.io/worldcoin/iris-mpc:2694d8cbb37c278ed84951ef9aac3af47b21f146" # no-cuda image + name: "iris-mpc-1-copy-cuda-libs" + env: + - name: AWS_REGION + value: "$AWS_REGION" + - name: PARTY_ID + value: "2" + - name: MY_NODE_IP + valueFrom: + fieldRef: + fieldPath: status.hostIP + configMap: + name: "iris-mpc-1-init" + init.sh: | + #!/usr/bin/env bash + set -e + + cd /libs + + aws s3 cp s3://wf-smpcv2-stage-libs/libcublas.so.12.2.5.6 . + aws s3 cp s3://wf-smpcv2-stage-libs/libcublasLt.so.12.2.5.6 . + + key-manager --node-id 1 --env $ENV --region $AWS_REGION --endpoint-url "http://localstack:4566" rotate --public-key-bucket-name wf-$ENV-public-keys diff --git a/deploy/e2e/iris-mpc-2.yaml.tpl b/deploy/e2e/iris-mpc-2.yaml.tpl new file mode 100644 index 000000000..485734c90 --- /dev/null +++ b/deploy/e2e/iris-mpc-2.yaml.tpl @@ -0,0 +1,221 @@ +iris-mpc-2: + fullnameOverride: "iris-mpc-2" + image: "ghcr.io/worldcoin/iris-mpc:$IRIS_MPC_IMAGE_TAG" + + environment: $ENV + replicaCount: 1 + + strategy: + type: Recreate + + datadog: + enabled: false + + ports: + - containerPort: 3000 + name: health + protocol: TCP + + livenessProbe: + httpGet: + path: /health + port: health + + readinessProbe: + periodSeconds: 30 + httpGet: + path: /ready + port: health + + startupProbe: + initialDelaySeconds: 60 + failureThreshold: 40 + periodSeconds: 30 + httpGet: + path: /ready + port: health + + podSecurityContext: + runAsNonRoot: false + seccompProfile: + type: RuntimeDefault + + resources: + limits: + cpu: 31 + memory: 60Gi + nvidia.com/gpu: 1 + + requests: + cpu: 30 + memory: 55Gi + nvidia.com/gpu: 1 + + imagePullSecrets: + - name: github-secret + + nodeSelector: + kubernetes.io/arch: amd64 + + hostNetwork: false + + tolerations: + - key: "gpuGroup" + operator: "Equal" + value: "dedicated" + effect: "NoSchedule" + + keelPolling: + # -- Specifies whether keel should poll for container updates + enabled: true + + libsDir: + enabled: true + path: "/libs" + size: 2Gi + files: + - path: "/usr/local/cuda-12.2/targets/x86_64-linux/lib/libcublasLt.so.12.2.5.6" + file: "libcublasLt.so.12.2.5.6" + - path: "/usr/local/cuda-12.2/targets/x86_64-linux/lib/libcublas.so.12.2.5.6" + file: "libcublas.so.12.2.5.6" + + preStop: + # preStop.sleepPeriod specifies the time spent in Terminating state before SIGTERM is sent + sleepPeriod: 10 + + # terminationGracePeriodSeconds specifies the grace time between SIGTERM and SIGKILL + terminationGracePeriodSeconds: 180 # 3x SMPC__PROCESSING_TIMEOUT_SECS + + env: + - name: RUST_LOG + value: "info" + + - name: AWS_REGION + value: "$AWS_REGION" + + - name: AWS_ENDPOINT_URL + value: "http://localstack:4566" + + - name: RUST_BACKTRACE + value: "full" + + - name: NCCL_SOCKET_IFNAME + value: "eth0" + + - name: NCCL_COMM_ID + value: "iris-mpc-2.svc.cluster.local:4000" + + - name: SMPC__ENVIRONMENT + value: "$ENV" + + - name: SMPC__AWS__REGION + value: "$AWS_REGION" + + - name: SMPC__SERVICE__SERVICE_NAME + value: "smpcv2-server-$ENV" + + - name: SMPC__DATABASE__URL + valueFrom: + secretKeyRef: + key: DATABASE_AURORA_URL + name: application + + - name: SMPC__DATABASE__MIGRATE + value: "true" + + - name: SMPC__DATABASE__CREATE + value: "true" + + - name: SMPC__DATABASE__LOAD_PARALLELISM + value: "8" + + - name: SMPC__REQUESTS_QUEUE_URL + value: "arn:aws:sns:eu-central-1:000000000000:iris-mpc-input" + + - name: SMPC__RESULTS_TOPIC_ARN + value: "arn:aws:sns:eu-central-1:000000000000:iris-mpc-results" + + - name: SMPC__PROCESSING_TIMEOUT_SECS + value: "60" + + - name: SMPC__PATH + value: "/data/" + + - name: SMPC__KMS_KEY_ARNS + value: '["arn:aws:kms:$AWS_REGION:000000000000:key/00000000-0000-0000-0000-000000000000","arn:aws:kms:$AWS_REGION:000000000000:key/00000000-0000-0000-0000-000000000001","arn:aws:kms:$AWS_REGION:000000000000:key/00000000-0000-0000-0000-000000000002"]' + + - name: SMPC__PARTY_ID + value: "2" + + - name: SMPC__PUBLIC_KEY_BASE_URL + value: "http://wf-$ENV-public-keys.s3.localhost.localstack.cloud:4566" + + - name: SMPC__ENABLE_S3_IMPORTER + value: "false" + + - name: SMPC__SHARES_BUCKET_NAME + value: "wf-smpcv2-stage-sns-requests" + + - name: SMPC__CLEAR_DB_BEFORE_INIT + value: "true" + + - name: SMPC__INIT_DB_SIZE + value: "80000" + + - name: SMPC__MAX_DB_SIZE + value: "110000" + + - name: SMPC__MAX_BATCH_SIZE + value: "64" + + - name: SMPC__SERVICE__METRICS__HOST + valueFrom: + fieldRef: + fieldPath: status.hostIP + + - name: SMPC__SERVICE__METRICS__PORT + value: "8125" + + - name: SMPC__SERVICE__METRICS__QUEUE_SIZE + value: "5000" + + - name: SMPC__SERVICE__METRICS__BUFFER_SIZE + value: "256" + + - name: SMPC__SERVICE__METRICS__PREFIX + value: "smpcv2-$ENV-2" + + - name: SMPC__RETURN_PARTIAL_RESULTS + value: "true" + + - name: SMPC__NODE_HOSTNAMES + value: '["iris-mpc-0.svc.cluster.local","iris-mpc-1.svc.cluster.local","iris-mpc-2.svc.cluster.local"]' + + - name: SMPC__IMAGE_NAME + value: "ghcr.io/worldcoin/iris-mpc:$IRIS_MPC_IMAGE_TAG" + + initContainer: + enabled: true + image: "ghcr.io/worldcoin/iris-mpc:2694d8cbb37c278ed84951ef9aac3af47b21f146" # no-cuda image + name: "iris-mpc-2-copy-cuda-libs" + env: + - name: AWS_REGION + value: "$AWS_REGION" + - name: PARTY_ID + value: "3" + - name: MY_NODE_IP + valueFrom: + fieldRef: + fieldPath: status.hostIP + configMap: + name: "iris-mpc-2-init" + init.sh: | + #!/usr/bin/env bash + set -e + + cd /libs + + aws s3 cp s3://wf-smpcv2-stage-libs/libcublas.so.12.2.5.6 . + aws s3 cp s3://wf-smpcv2-stage-libs/libcublasLt.so.12.2.5.6 . + + key-manager --node-id 2 --env $ENV --region $AWS_REGION --endpoint-url "http://localstack:4566" rotate --public-key-bucket-name wf-$ENV-public-keys diff --git a/deploy/prod/common-values-iris-mpc.yaml b/deploy/prod/common-values-iris-mpc.yaml index 139b7c6b2..858e3d872 100644 --- a/deploy/prod/common-values-iris-mpc.yaml +++ b/deploy/prod/common-values-iris-mpc.yaml @@ -1,4 +1,4 @@ -image: "ghcr.io/worldcoin/iris-mpc:v0.9.10" +image: "ghcr.io/worldcoin/iris-mpc:v0.13.9" environment: prod replicaCount: 1 @@ -18,25 +18,21 @@ ports: # protocol: TCP livenessProbe: - initialDelaySeconds: 300 httpGet: path: /health port: health readinessProbe: - initialDelaySeconds: 300 - periodSeconds: 30 - failureThreshold: 10 httpGet: - path: /health + path: /ready port: health startupProbe: initialDelaySeconds: 900 - failureThreshold: 20 + failureThreshold: 50 periodSeconds: 30 httpGet: - path: /health + path: /ready port: health resources: diff --git a/deploy/prod/common-values-upgrade-server-left.yaml b/deploy/prod/common-values-upgrade-server-left.yaml index a15c07562..ceddfc97a 100644 --- a/deploy/prod/common-values-upgrade-server-left.yaml +++ b/deploy/prod/common-values-upgrade-server-left.yaml @@ -1,4 +1,4 @@ -image: "ghcr.io/worldcoin/iris-mpc:v0.8.25" +image: "ghcr.io/worldcoin/iris-mpc:v0.12.1" environment: prod replicaCount: 1 diff --git a/deploy/prod/common-values-upgrade-server-right.yaml b/deploy/prod/common-values-upgrade-server-right.yaml index 5f8f28507..f34c939ad 100644 --- a/deploy/prod/common-values-upgrade-server-right.yaml +++ b/deploy/prod/common-values-upgrade-server-right.yaml @@ -1,4 +1,4 @@ -image: "ghcr.io/worldcoin/iris-mpc:v0.8.25" +image: "ghcr.io/worldcoin/iris-mpc:v0.12.1" environment: prod replicaCount: 1 diff --git a/deploy/prod/smpcv2-0-prod/values-iris-mpc.yaml b/deploy/prod/smpcv2-0-prod/values-iris-mpc.yaml index 963bd3944..a22402550 100644 --- a/deploy/prod/smpcv2-0-prod/values-iris-mpc.yaml +++ b/deploy/prod/smpcv2-0-prod/values-iris-mpc.yaml @@ -2,9 +2,6 @@ env: - name: RUST_LOG value: "info" - - name: NCCL_DEBUG - value: "INFO" - - name: NCCL_SOCKET_IFNAME value: "eth" @@ -39,7 +36,7 @@ env: value: "true" - name: SMPC__DATABASE__LOAD_PARALLELISM - value: "80" + value: "8" - name: SMPC__AWS__REGION value: "eu-north-1" @@ -57,13 +54,13 @@ env: name: application - name: SMPC__PROCESSING_TIMEOUT_SECS - value: "120" + value: "300" - name: SMPC__HEARTBEAT_INTERVAL_SECS value: "2" - name: SMPC__HEARTBEAT_INITIAL_RETRIES - value: "1000" + value: "65" - name: SMPC__PATH value: "/data/" @@ -80,6 +77,18 @@ env: - name: SMPC__PUBLIC_KEY_BASE_URL value: "https://pki-smpc.worldcoin.org" + - name: SMPC__ENABLE_S3_IMPORTER + value: "true" + + - name: SMPC__DB_CHUNKS_BUCKET_NAME + value: "iris-mpc-db-exporter-store-node-0-prod-eu-north-1" + + - name: SMPC__DB_CHUNKS_FOLDER_NAME + value: "binary_output_16k" + + - name: SMPC__DATABASE__LOAD_PARALLELISM + value: "64" + - name: SMPC__CLEAR_DB_BEFORE_INIT value: "true" @@ -121,6 +130,9 @@ env: - name: SMPC__NODE_HOSTNAMES value: '["iris-mpc-node.1.smpcv2.worldcoin.org","iris-mpc-node.2.smpcv2.worldcoin.org","iris-mpc-node.3.smpcv2.worldcoin.org"]' + - name: SMPC__IMAGE_NAME + value: $(IMAGE_NAME) + initContainer: enabled: true image: "amazon/aws-cli:2.17.62" diff --git a/deploy/prod/smpcv2-0-prod/values-upgrade-server-left.yaml b/deploy/prod/smpcv2-0-prod/values-upgrade-server-left.yaml index 884a3f4a4..23a679a81 100644 --- a/deploy/prod/smpcv2-0-prod/values-upgrade-server-left.yaml +++ b/deploy/prod/smpcv2-0-prod/values-upgrade-server-left.yaml @@ -9,6 +9,8 @@ args: - "left" - "--environment" - "$(ENVIRONMENT)" + - "--healthcheck-port" + - "3000" initContainer: enabled: true diff --git a/deploy/prod/smpcv2-0-prod/values-upgrade-server-right.yaml b/deploy/prod/smpcv2-0-prod/values-upgrade-server-right.yaml index f988173e6..35ad13c73 100644 --- a/deploy/prod/smpcv2-0-prod/values-upgrade-server-right.yaml +++ b/deploy/prod/smpcv2-0-prod/values-upgrade-server-right.yaml @@ -9,6 +9,8 @@ args: - "right" - "--environment" - "$(ENVIRONMENT)" + - "--healthcheck-port" + - "3000" initContainer: enabled: true diff --git a/deploy/prod/smpcv2-1-prod/values-iris-mpc.yaml b/deploy/prod/smpcv2-1-prod/values-iris-mpc.yaml index 6968b3c99..e4aae1f37 100644 --- a/deploy/prod/smpcv2-1-prod/values-iris-mpc.yaml +++ b/deploy/prod/smpcv2-1-prod/values-iris-mpc.yaml @@ -2,9 +2,6 @@ env: - name: RUST_LOG value: "info" - - name: NCCL_DEBUG - value: "INFO" - - name: RUST_BACKTRACE value: "1" @@ -39,7 +36,7 @@ env: value: "true" - name: SMPC__DATABASE__LOAD_PARALLELISM - value: "80" + value: "8" - name: SMPC__AWS__REGION value: "eu-north-1" @@ -57,13 +54,13 @@ env: name: application - name: SMPC__PROCESSING_TIMEOUT_SECS - value: "120" + value: "300" - name: SMPC__HEARTBEAT_INTERVAL_SECS value: "2" - name: SMPC__HEARTBEAT_INITIAL_RETRIES - value: "1000" + value: "65" - name: SMPC__PATH value: "/data/" @@ -80,6 +77,18 @@ env: - name: SMPC__PUBLIC_KEY_BASE_URL value: "https://pki-smpc.worldcoin.org" + - name: SMPC__ENABLE_S3_IMPORTER + value: "true" + + - name: SMPC__DB_CHUNKS_BUCKET_NAME + value: "iris-mpc-db-exporter-store-node-1-prod-eu-north-1" + + - name: SMPC__DB_CHUNKS_FOLDER_NAME + value: "binary_output_16k" + + - name: SMPC__DATABASE__LOAD_PARALLELISM + value: "64" + - name: SMPC__CLEAR_DB_BEFORE_INIT value: "true" @@ -121,6 +130,9 @@ env: - name: SMPC__NODE_HOSTNAMES value: '["iris-mpc-node.1.smpcv2.worldcoin.org","iris-mpc-node.2.smpcv2.worldcoin.org","iris-mpc-node.3.smpcv2.worldcoin.org"]' + - name: SMPC__IMAGE_NAME + value: $(IMAGE_NAME) + initContainer: enabled: true image: "amazon/aws-cli:2.17.62" diff --git a/deploy/prod/smpcv2-1-prod/values-upgrade-server-left.yaml b/deploy/prod/smpcv2-1-prod/values-upgrade-server-left.yaml index d140f92fc..b64075865 100644 --- a/deploy/prod/smpcv2-1-prod/values-upgrade-server-left.yaml +++ b/deploy/prod/smpcv2-1-prod/values-upgrade-server-left.yaml @@ -9,6 +9,8 @@ args: - "left" - "--environment" - "$(ENVIRONMENT)" + - "--healthcheck-port" + - "3000" initContainer: enabled: true diff --git a/deploy/prod/smpcv2-1-prod/values-upgrade-server-right.yaml b/deploy/prod/smpcv2-1-prod/values-upgrade-server-right.yaml index 81a947071..0700f4770 100644 --- a/deploy/prod/smpcv2-1-prod/values-upgrade-server-right.yaml +++ b/deploy/prod/smpcv2-1-prod/values-upgrade-server-right.yaml @@ -9,6 +9,8 @@ args: - "right" - "--environment" - "$(ENVIRONMENT)" + - "--healthcheck-port" + - "3000" initContainer: enabled: true diff --git a/deploy/prod/smpcv2-2-prod/values-iris-mpc.yaml b/deploy/prod/smpcv2-2-prod/values-iris-mpc.yaml index 1620d41d1..6f35d9131 100644 --- a/deploy/prod/smpcv2-2-prod/values-iris-mpc.yaml +++ b/deploy/prod/smpcv2-2-prod/values-iris-mpc.yaml @@ -2,9 +2,6 @@ env: - name: RUST_LOG value: "info" - - name: NCCL_DEBUG - value: "INFO" - - name: RUST_BACKTRACE value: "1" @@ -39,7 +36,7 @@ env: value: "true" - name: SMPC__DATABASE__LOAD_PARALLELISM - value: "80" + value: "8" - name: SMPC__AWS__REGION value: "eu-north-1" @@ -57,13 +54,13 @@ env: name: application - name: SMPC__PROCESSING_TIMEOUT_SECS - value: "120" + value: "300" - name: SMPC__HEARTBEAT_INTERVAL_SECS value: "2" - name: SMPC__HEARTBEAT_INITIAL_RETRIES - value: "1000" + value: "65" - name: SMPC__PATH value: "/data/" @@ -80,6 +77,18 @@ env: - name: SMPC__PUBLIC_KEY_BASE_URL value: "https://pki-smpc.worldcoin.org" + - name: SMPC__ENABLE_S3_IMPORTER + value: "true" + + - name: SMPC__DB_CHUNKS_BUCKET_NAME + value: "iris-mpc-db-exporter-store-node-2-prod-eu-north-1" + + - name: SMPC__DB_CHUNKS_FOLDER_NAME + value: "binary_output_16k" + + - name: SMPC__DATABASE__LOAD_PARALLELISM + value: "64" + - name: SMPC__CLEAR_DB_BEFORE_INIT value: "true" @@ -121,6 +130,9 @@ env: - name: SMPC__NODE_HOSTNAMES value: '["iris-mpc-node.1.smpcv2.worldcoin.org","iris-mpc-node.2.smpcv2.worldcoin.org","iris-mpc-node.3.smpcv2.worldcoin.org"]' + - name: SMPC__IMAGE_NAME + value: $(IMAGE_NAME) + initContainer: enabled: true image: "amazon/aws-cli:2.17.62" diff --git a/deploy/prod/smpcv2-2-prod/values-upgrade-server-left.yaml b/deploy/prod/smpcv2-2-prod/values-upgrade-server-left.yaml index cb4fd532c..3b41b5db5 100644 --- a/deploy/prod/smpcv2-2-prod/values-upgrade-server-left.yaml +++ b/deploy/prod/smpcv2-2-prod/values-upgrade-server-left.yaml @@ -9,6 +9,8 @@ args: - "left" - "--environment" - "$(ENVIRONMENT)" + - "--healthcheck-port" + - "3000" initContainer: enabled: true diff --git a/deploy/prod/smpcv2-2-prod/values-upgrade-server-right.yaml b/deploy/prod/smpcv2-2-prod/values-upgrade-server-right.yaml index e4d486815..a531fc0fc 100644 --- a/deploy/prod/smpcv2-2-prod/values-upgrade-server-right.yaml +++ b/deploy/prod/smpcv2-2-prod/values-upgrade-server-right.yaml @@ -9,6 +9,8 @@ args: - "right" - "--environment" - "$(ENVIRONMENT)" + - "--healthcheck-port" + - "3000" initContainer: enabled: true diff --git a/deploy/stage/common-values-iris-mpc.yaml b/deploy/stage/common-values-iris-mpc.yaml index 5b02b4bf2..b8da0ebb0 100644 --- a/deploy/stage/common-values-iris-mpc.yaml +++ b/deploy/stage/common-values-iris-mpc.yaml @@ -1,4 +1,4 @@ -image: "ghcr.io/worldcoin/iris-mpc:v0.9.10" +image: "ghcr.io/worldcoin/iris-mpc:v0.13.9" environment: stage replicaCount: 1 @@ -15,23 +15,22 @@ ports: protocol: TCP livenessProbe: - initialDelaySeconds: 300 httpGet: path: /health port: health readinessProbe: - initialDelaySeconds: 300 periodSeconds: 30 - failureThreshold: 10 httpGet: - path: /health + path: /ready port: health startupProbe: - initialDelaySeconds: 300 + initialDelaySeconds: 60 + failureThreshold: 60 + periodSeconds: 30 httpGet: - path: /health + path: /ready port: health resources: @@ -83,4 +82,4 @@ preStop: sleepPeriod: 10 # terminationGracePeriodSeconds specifies the grace time between SIGTERM and SIGKILL -terminationGracePeriodSeconds: 180 # 3x SMPC__PROCESSING_TIMEOUT_SECS +terminationGracePeriodSeconds: 20 diff --git a/deploy/stage/common-values-reshare-server.yaml b/deploy/stage/common-values-reshare-server.yaml new file mode 100644 index 000000000..fb6e6257b --- /dev/null +++ b/deploy/stage/common-values-reshare-server.yaml @@ -0,0 +1,141 @@ +image: "ghcr.io/worldcoin/iris-mpc:v0.10.4" + +environment: stage +replicaCount: 0 + +strategy: + type: Recreate + +datadog: + enabled: true + +# Nginx exposes the only port required here +ports: + - containerPort: 3001 + name: health + protocol: TCP + +startupProbe: + httpGet: + path: /health + port: health + +livenessProbe: + httpGet: + path: /health + port: health + +readinessProbe: + periodSeconds: 30 + failureThreshold: 10 + httpGet: + path: /health + port: health + +resources: + limits: + cpu: 4 + memory: 16Gi + requests: + cpu: 4 + memory: 16Gi + +imagePullSecrets: + - name: github-secret + +nodeSelector: + kubernetes.io/arch: amd64 + beta.kubernetes.io/instance-type: t3.2xlarge + +podSecurityContext: + runAsUser: 405 + runAsGroup: 405 + +serviceAccount: + create: true + +command: [ "/bin/reshare-server" ] + +env: + - name: SMPC__DATABASE__URL + valueFrom: + secretKeyRef: + key: DATABASE_AURORA_URL + name: application + - name: RUST_LOG + value: info + - name: ENVIRONMENT + value: stage + +service: + enabled: false + +nginxSidecar: + enabled: true + port: 6443 + secrets: + enabled: true + volumeMount: + - name: mounted-secret-name + mountPath: /etc/nginx/cert + volume: + - name: mounted-secret-name + secret: + secretName: application + items: + - key: certificate.crt + path: certificate.crt + - key: key.pem + path: key.pem + optional: false + config: + nginx.conf: | + worker_processes auto; + + error_log /dev/stderr notice; + pid /tmp/nginx.pid; + + events { + worker_connections 1024; + } + + http { + proxy_temp_path /tmp/proxy_temp; + client_body_temp_path /tmp/client_temp; + fastcgi_temp_path /tmp/fastcgi_temp; + uwsgi_temp_path /tmp/uwsgi_temp; + scgi_temp_path /tmp/scgi_temp; + + log_format basic '$remote_addr [$time_local] ' + '$status $bytes_sent'; + + server { + listen 6443 ssl; + http2 on; + + ssl_certificate /etc/nginx/cert/certificate.crt; + ssl_certificate_key /etc/nginx/cert/key.pem; + + ssl_protocols TLSv1.3; + ssl_ciphers HIGH:!aNULL:!MD5; + + # Enable session resumption to improve performance + ssl_session_cache shared:SSL:10m; + ssl_session_timeout 1h; + + location / { + # Forward gRPC traffic to the gRPC server on port 7000 + grpc_pass grpc://127.0.0.1:7000; + error_page 502 = /error502grpc; # Custom error page for GRPC backend issues + } + + # Custom error page + location = /error502grpc { + internal; + default_type text/plain; + return 502 "Bad Gateway: gRPC server unreachable."; + } + + access_log /dev/stdout basic; + } + } diff --git a/deploy/stage/common-values-upgrade-server-left.yaml b/deploy/stage/common-values-upgrade-server-left.yaml index e5f577ac5..46ea80cda 100644 --- a/deploy/stage/common-values-upgrade-server-left.yaml +++ b/deploy/stage/common-values-upgrade-server-left.yaml @@ -1,4 +1,4 @@ -image: "ghcr.io/worldcoin/iris-mpc:v0.8.25" +image: "ghcr.io/worldcoin/iris-mpc:v0.10.2" environment: stage replicaCount: 1 diff --git a/deploy/stage/common-values-upgrade-server-right.yaml b/deploy/stage/common-values-upgrade-server-right.yaml index e5f577ac5..46ea80cda 100644 --- a/deploy/stage/common-values-upgrade-server-right.yaml +++ b/deploy/stage/common-values-upgrade-server-right.yaml @@ -1,4 +1,4 @@ -image: "ghcr.io/worldcoin/iris-mpc:v0.8.25" +image: "ghcr.io/worldcoin/iris-mpc:v0.10.2" environment: stage replicaCount: 1 diff --git a/deploy/stage/smpcv2-0-stage/values-iris-mpc.yaml b/deploy/stage/smpcv2-0-stage/values-iris-mpc.yaml index ef3c3066c..111e913cc 100644 --- a/deploy/stage/smpcv2-0-stage/values-iris-mpc.yaml +++ b/deploy/stage/smpcv2-0-stage/values-iris-mpc.yaml @@ -65,11 +65,26 @@ env: - name: SMPC__PUBLIC_KEY_BASE_URL value: "https://pki-smpcv2-stage.worldcoin.org" + - name: SMPC__SHARES_BUCKET_NAME + value: "wf-smpcv2-stage-sns-requests" + + - name: SMPC__ENABLE_S3_IMPORTER + value: "true" + + - name: SMPC__DB_CHUNKS_BUCKET_NAME + value: "iris-mpc-db-exporter-store-node-0-stage-eu-north-1" + + - name: SMPC__DB_CHUNKS_FOLDER_NAME + value: "binary_output_16k" + + - name: SMPC__LOAD_CHUNKS_PARALLELISM + value: "32" + - name: SMPC__CLEAR_DB_BEFORE_INIT value: "true" - name: SMPC__INIT_DB_SIZE - value: "0" + value: "800000" - name: SMPC__MAX_DB_SIZE value: "1000000" @@ -90,7 +105,7 @@ env: - name: SMPC__SERVICE__METRICS__BUFFER_SIZE value: "256" - + - name: SMPC__SERVICE__METRICS__PREFIX value: "smpcv2-0" @@ -100,6 +115,9 @@ env: - name: SMPC__NODE_HOSTNAMES value: '["iris-mpc-node.1.stage.smpcv2.worldcoin.dev","iris-mpc-node.2.stage.smpcv2.worldcoin.dev","iris-mpc-node.3.stage.smpcv2.worldcoin.dev"]' + - name: SMPC__IMAGE_NAME + value: $(IMAGE_NAME) + initContainer: enabled: true image: "amazon/aws-cli:2.17.62" diff --git a/deploy/stage/smpcv2-0-stage/values-upgrade-server-left.yaml b/deploy/stage/smpcv2-0-stage/values-upgrade-server-left.yaml index 81b44cf68..3af78a4f1 100644 --- a/deploy/stage/smpcv2-0-stage/values-upgrade-server-left.yaml +++ b/deploy/stage/smpcv2-0-stage/values-upgrade-server-left.yaml @@ -9,6 +9,8 @@ args: - "left" - "--environment" - "$(ENVIRONMENT)" + - "--healthcheck-port" + - "3000" initContainer: enabled: true diff --git a/deploy/stage/smpcv2-0-stage/values-upgrade-server-right.yaml b/deploy/stage/smpcv2-0-stage/values-upgrade-server-right.yaml index 45f690e25..56b176fb7 100644 --- a/deploy/stage/smpcv2-0-stage/values-upgrade-server-right.yaml +++ b/deploy/stage/smpcv2-0-stage/values-upgrade-server-right.yaml @@ -9,6 +9,8 @@ args: - "right" - "--environment" - "$(ENVIRONMENT)" + - "--healthcheck-port" + - "3000" initContainer: enabled: true diff --git a/deploy/stage/smpcv2-1-stage/values-iris-mpc.yaml b/deploy/stage/smpcv2-1-stage/values-iris-mpc.yaml index fed0f2b09..f568b95e2 100644 --- a/deploy/stage/smpcv2-1-stage/values-iris-mpc.yaml +++ b/deploy/stage/smpcv2-1-stage/values-iris-mpc.yaml @@ -65,11 +65,26 @@ env: - name: SMPC__PUBLIC_KEY_BASE_URL value: "https://pki-smpcv2-stage.worldcoin.org" + - name: SMPC__SHARES_BUCKET_NAME + value: "wf-smpcv2-stage-sns-requests" + + - name: SMPC__ENABLE_S3_IMPORTER + value: "true" + + - name: SMPC__DB_CHUNKS_BUCKET_NAME + value: "iris-mpc-db-exporter-store-node-1-stage-eu-north-1" + + - name: SMPC__DB_CHUNKS_FOLDER_NAME + value: "binary_output_16k" + + - name: SMPC__LOAD_CHUNKS_PARALLELISM + value: "32" + - name: SMPC__CLEAR_DB_BEFORE_INIT value: "true" - name: SMPC__INIT_DB_SIZE - value: "0" + value: "800000" - name: SMPC__MAX_DB_SIZE value: "1000000" @@ -90,7 +105,7 @@ env: - name: SMPC__SERVICE__METRICS__BUFFER_SIZE value: "256" - + - name: SMPC__SERVICE__METRICS__PREFIX value: "smpcv2-1" @@ -100,6 +115,9 @@ env: - name: SMPC__NODE_HOSTNAMES value: '["iris-mpc-node.1.stage.smpcv2.worldcoin.dev","iris-mpc-node.2.stage.smpcv2.worldcoin.dev","iris-mpc-node.3.stage.smpcv2.worldcoin.dev"]' + - name: SMPC__IMAGE_NAME + value: $(IMAGE_NAME) + initContainer: enabled: true image: "amazon/aws-cli:2.17.62" diff --git a/deploy/stage/smpcv2-1-stage/values-reshare-server.yaml b/deploy/stage/smpcv2-1-stage/values-reshare-server.yaml new file mode 100644 index 000000000..a9bab80bb --- /dev/null +++ b/deploy/stage/smpcv2-1-stage/values-reshare-server.yaml @@ -0,0 +1,61 @@ +args: + - "--bind-addr" + - "0.0.0.0:7000" + - "--db-url" + - "$(SMPC__DATABASE__URL)" + - "--party-id" + - "1" + - "--environment" + - "$(ENVIRONMENT)" + - "--sender1-party-id" + - "0" + - "--sender2-party-id" + - "2" + - "--batch-size" + - "100" + - "--max-buffer-size" + - "10" + - "--healthcheck-port" + - "3001" + +initContainer: + enabled: true + image: "amazon/aws-cli:2.17.62" + name: "reshare-proto-dns-records-updater" + env: + - name: PARTY_ID + value: "2" + - name: MY_POD_IP + valueFrom: + fieldRef: + fieldPath: status.podIP + configMap: + init.sh: | + #!/usr/bin/env bash + + # Set up environment variables + HOSTED_ZONE_ID=$(aws route53 list-hosted-zones-by-name --dns-name "$PARTY_ID".stage.smpcv2.worldcoin.dev --query "HostedZones[].Id" --output text) + + # Generate the JSON content in memory + BATCH_JSON=$(cat <>> key-manager --node-id 2 --env prod rotate --public-key-bucket-name wf-env-stage-public-keys +``` + +This will: + +1. Update the public key in the bucket `wf-env-stage-public-keys` for node 2. +2. Generate a new private key and store aws secrets manager under the secret name: `prod/iris-mpc/ecdh-private-key-2` + +This key will be immediately valid, though the previous key will retain a validity of 24 hours (dictated by the cloudfront caching behavior, +and by application logic that checks against AWSCURRENT and AWSPREVIOUS version of the secret). + + diff --git a/iris-mpc-common/src/bin/key_manager.rs b/iris-mpc-common/src/bin/key_manager.rs index a11398756..bc8346f1b 100644 --- a/iris-mpc-common/src/bin/key_manager.rs +++ b/iris-mpc-common/src/bin/key_manager.rs @@ -15,9 +15,7 @@ use sodiumoxide::crypto::box_::{curve25519xsalsa20poly1305, PublicKey, SecretKey const PUBLIC_KEY_S3_BUCKET_NAME: &str = "wf-smpcv2-stage-public-keys"; const PUBLIC_KEY_S3_KEY_NAME_PREFIX: &str = "public-key"; -const REGION: &str = "eu-north-1"; -/// A fictional versioning CLI #[derive(Debug, Parser)] // requires `derive` feature #[command(name = "key-manager")] #[command(about = "Key manager CLI", long_about = None)] @@ -32,6 +30,12 @@ struct KeyManagerCli { #[arg(short, long, env, default_value = "stage")] env: String, + + #[arg(short, long, env, default_value = "eu-north-1")] + region: String, + + #[arg(short, long, env, default_value = None)] + endpoint_url: Option, } #[derive(Debug, Subcommand)] @@ -67,8 +71,9 @@ async fn main() -> eyre::Result<()> { tracing_subscriber::fmt::init(); let args = KeyManagerCli::parse(); + let region = args.region; - let region_provider = S3Region::new(REGION); + let region_provider = S3Region::new(region.clone()); let shared_config = aws_config::from_env().region(region_provider).load().await; let bucket_key_name = format!("{}-{}", PUBLIC_KEY_S3_KEY_NAME_PREFIX, args.node_id); @@ -86,6 +91,7 @@ async fn main() -> eyre::Result<()> { &private_key_secret_id, dry_run, public_key_bucket_name, + args.endpoint_url, ) .await?; } @@ -101,6 +107,8 @@ async fn main() -> eyre::Result<()> { b64_pub_key, &bucket_key_name, public_key_bucket_name, + region.clone(), + args.endpoint_url, ) .await?; } @@ -108,6 +116,7 @@ async fn main() -> eyre::Result<()> { Ok(()) } +#[allow(clippy::too_many_arguments)] async fn validate_keys( sdk_config: &SdkConfig, secret_id: &str, @@ -115,8 +124,16 @@ async fn validate_keys( b64_pub_key: Option, bucket_key_name: &str, public_key_bucket_name: Option, + region: String, + endpoint_url: Option, ) -> eyre::Result<()> { - let sm_client = SecretsManagerClient::new(sdk_config); + let mut sm_config_builder = aws_sdk_secretsmanager::config::Builder::from(sdk_config); + + if let Some(endpoint_url) = endpoint_url.as_ref() { + sm_config_builder = sm_config_builder.endpoint_url(endpoint_url); + } + + let sm_client = SecretsManagerClient::from_conf(sm_config_builder.build()); let bucket_name = if let Some(bucket_name) = public_key_bucket_name { bucket_name @@ -133,7 +150,7 @@ async fn validate_keys( } else { // Otherwise, get the latest one from S3 using HTTPS let user_pubkey_string = - download_key_from_s3(bucket_name.as_str(), bucket_key_name).await?; + download_key_from_s3(bucket_name.as_str(), bucket_key_name, region.clone()).await?; let user_pubkey = STANDARD.decode(user_pubkey_string.as_bytes()).unwrap(); match PublicKey::from_slice(&user_pubkey) { Some(key) => key, @@ -156,6 +173,7 @@ async fn rotate_keys( private_key_secret_id: &str, dry_run: Option, public_key_bucket_name: Option, + endpoint_url: Option, ) -> eyre::Result<()> { let mut rng = thread_rng(); @@ -169,8 +187,17 @@ async fn rotate_keys( rng.fill(&mut seedbuf); let pk_seed = Seed(seedbuf); - let s3_client = S3Client::new(sdk_config); - let sm_client = SecretsManagerClient::new(sdk_config); + let mut s3_config_builder = aws_sdk_s3::config::Builder::from(sdk_config); + let mut sm_config_builder = aws_sdk_secretsmanager::config::Builder::from(sdk_config); + + if let Some(endpoint_url) = endpoint_url.as_ref() { + s3_config_builder = s3_config_builder.endpoint_url(endpoint_url); + s3_config_builder = s3_config_builder.force_path_style(true); + sm_config_builder = sm_config_builder.endpoint_url(endpoint_url); + } + + let s3_client = S3Client::from_conf(s3_config_builder.build()); + let sm_client = SecretsManagerClient::from_conf(sm_config_builder.build()); let (public_key, private_key) = generate_key_pairs(pk_seed); let pub_key_str = STANDARD.encode(public_key); @@ -231,9 +258,13 @@ async fn rotate_keys( Ok(()) } -async fn download_key_from_s3(bucket: &str, key: &str) -> Result { +async fn download_key_from_s3( + bucket: &str, + key: &str, + region: String, +) -> Result { print!("Downloading key from S3 bucket: {} key: {}", bucket, key); - let s3_url = format!("https://{}.s3.{}.amazonaws.com/{}", bucket, REGION, key); + let s3_url = format!("https://{}.s3.{}.amazonaws.com/{}", bucket, region, key); let client = Client::new(); let response = client.get(&s3_url).send().await?.text().await?; Ok(response) diff --git a/iris-mpc-common/src/config/mod.rs b/iris-mpc-common/src/config/mod.rs index d98a3ec1f..971425b05 100644 --- a/iris-mpc-common/src/config/mod.rs +++ b/iris-mpc-common/src/config/mod.rs @@ -49,6 +49,9 @@ pub struct Config { #[serde(default)] pub public_key_base_url: String, + #[serde(default = "default_shares_bucket_name")] + pub shares_bucket_name: String, + #[serde(default)] pub clear_db_before_init: bool, @@ -76,11 +79,39 @@ pub struct Config { #[serde(default)] pub disable_persistence: bool, + #[serde(default)] + pub enable_debug_timing: bool, + #[serde(default, deserialize_with = "deserialize_yaml_json_string")] pub node_hostnames: Vec, #[serde(default = "default_shutdown_last_results_sync_timeout_secs")] pub shutdown_last_results_sync_timeout_secs: u64, + + #[serde(default)] + pub image_name: String, + + #[serde(default)] + pub enable_s3_importer: bool, + + #[serde(default)] + pub db_chunks_bucket_name: String, + + #[serde(default = "default_load_chunks_parallelism")] + pub load_chunks_parallelism: usize, + + /// Defines the safety overlap to load the DB records >last_modified_at in + /// seconds This is to ensure we don't miss any records that were + /// updated during the DB export to S3 + #[serde(default = "default_db_load_safety_overlap_seconds")] + pub db_load_safety_overlap_seconds: i64, + + #[serde(default)] + pub db_chunks_folder_name: String, +} + +fn default_load_chunks_parallelism() -> usize { + 32 } fn default_processing_timeout_secs() -> u64 { @@ -92,7 +123,7 @@ fn default_max_batch_size() -> usize { } fn default_heartbeat_interval_secs() -> u64 { - 30 + 2 } fn default_heartbeat_initial_retries() -> u64 { @@ -103,6 +134,14 @@ fn default_shutdown_last_results_sync_timeout_secs() -> u64 { 10 } +fn default_shares_bucket_name() -> String { + "wf-mpc-prod-smpcv2-sns-requests".to_string() +} + +fn default_db_load_safety_overlap_seconds() -> i64 { + 60 +} + impl Config { pub fn load_config(prefix: &str) -> eyre::Result { let settings = config::Config::builder(); diff --git a/iris-mpc-common/src/galois.rs b/iris-mpc-common/src/galois.rs index 52304ad70..a5b23a981 100644 --- a/iris-mpc-common/src/galois.rs +++ b/iris-mpc-common/src/galois.rs @@ -1,323 +1,3 @@ -pub mod degree2 { - use crate::id::PartyID; - use rand::{CryptoRng, Rng}; - /// An element of the Galois ring `$\mathbb{Z}_{2^{16}}[x]/(x^2 - x - 1)$`. - #[derive(Copy, Clone, Debug, PartialEq, Eq)] - pub struct GaloisRingElement { - pub coefs: [u16; 2], - } - - impl GaloisRingElement { - pub const ZERO: GaloisRingElement = GaloisRingElement { coefs: [0, 0] }; - pub const ONE: GaloisRingElement = GaloisRingElement { coefs: [1, 0] }; - pub const EXCEPTIONAL_SEQUENCE: [GaloisRingElement; 4] = [ - GaloisRingElement::ZERO, - GaloisRingElement::ONE, - GaloisRingElement { coefs: [0, 1] }, - GaloisRingElement { coefs: [1, 1] }, - ]; - - pub fn random(rng: &mut (impl Rng + CryptoRng)) -> Self { - GaloisRingElement { coefs: rng.gen() } - } - - pub fn inverse(&self) -> Self { - // hard-coded inverses for some elements we need - // too lazy to implement the general case in rust - // and we do not need the general case, since this is only used for the lagrange - // polys, which can be pre-computed anyway - - if *self == GaloisRingElement::ZERO { - panic!("Division by zero"); - } - - if *self == GaloisRingElement::ONE { - return GaloisRingElement::ONE; - } - - if *self == -GaloisRingElement::ONE { - return -GaloisRingElement::ONE; - } - if *self == (GaloisRingElement { coefs: [0, 1] }) { - return GaloisRingElement { coefs: [65535, 1] }; - } - if *self == (GaloisRingElement { coefs: [0, 65535] }) { - return GaloisRingElement { coefs: [1, 65535] }; - } - if *self == (GaloisRingElement { coefs: [1, 1] }) { - return GaloisRingElement { coefs: [2, 65535] }; - } - if *self == (GaloisRingElement { coefs: [1, 65535] }) { - return GaloisRingElement { coefs: [0, 65535] }; - } - if *self == (GaloisRingElement { coefs: [65535, 1] }) { - return GaloisRingElement { coefs: [0, 1] }; - } - - panic!("No inverse for {:?} in LUT", self); - } - } - - impl std::ops::Add for GaloisRingElement { - type Output = Self; - fn add(self, rhs: Self) -> Self::Output { - self.add(&rhs) - } - } - impl std::ops::Add<&GaloisRingElement> for GaloisRingElement { - type Output = Self; - fn add(mut self, rhs: &Self) -> Self::Output { - for i in 0..2 { - self.coefs[i] = self.coefs[i].wrapping_add(rhs.coefs[i]); - } - self - } - } - - impl std::ops::Sub for GaloisRingElement { - type Output = Self; - fn sub(self, rhs: Self) -> Self::Output { - self.sub(&rhs) - } - } - impl std::ops::Sub<&GaloisRingElement> for GaloisRingElement { - type Output = Self; - fn sub(mut self, rhs: &Self) -> Self::Output { - for i in 0..2 { - self.coefs[i] = self.coefs[i].wrapping_sub(rhs.coefs[i]); - } - self - } - } - - impl std::ops::Neg for GaloisRingElement { - type Output = Self; - - fn neg(self) -> Self::Output { - GaloisRingElement { - coefs: [self.coefs[0].wrapping_neg(), self.coefs[1].wrapping_neg()], - } - } - } - - impl std::ops::Mul for GaloisRingElement { - type Output = Self; - fn mul(self, rhs: Self) -> Self::Output { - self.mul(&rhs) - } - } - impl std::ops::Mul<&GaloisRingElement> for GaloisRingElement { - type Output = Self; - fn mul(self, rhs: &Self) -> Self::Output { - GaloisRingElement { - coefs: [ - (self.coefs[0].wrapping_mul(rhs.coefs[0])) - .wrapping_add(self.coefs[1].wrapping_mul(rhs.coefs[1])), - (self.coefs[0].wrapping_mul(rhs.coefs[1])) - .wrapping_add(self.coefs[1].wrapping_mul(rhs.coefs[0])) - .wrapping_add(self.coefs[1].wrapping_mul(rhs.coefs[1])), - ], - } - } - } - - #[derive(Debug, Clone, Copy, PartialEq, Eq)] - pub struct ShamirGaloisRingShare { - pub id: usize, - pub y: GaloisRingElement, - } - impl std::ops::Add for ShamirGaloisRingShare { - type Output = Self; - fn add(self, rhs: Self) -> Self::Output { - assert_eq!(self.id, rhs.id, "ids must be euqal"); - ShamirGaloisRingShare { - id: self.id, - y: self.y + rhs.y, - } - } - } - impl std::ops::Mul for ShamirGaloisRingShare { - type Output = Self; - fn mul(self, rhs: Self) -> Self::Output { - assert_eq!(self.id, rhs.id, "ids must be euqal"); - ShamirGaloisRingShare { - id: self.id, - y: self.y * rhs.y, - } - } - } - impl std::ops::Sub for ShamirGaloisRingShare { - type Output = Self; - fn sub(self, rhs: Self) -> Self::Output { - assert_eq!(self.id, rhs.id, "ids must be euqal"); - ShamirGaloisRingShare { - id: self.id, - y: self.y - rhs.y, - } - } - } - - impl ShamirGaloisRingShare { - pub fn encode_3( - input: &GaloisRingElement, - rng: &mut R, - ) -> [ShamirGaloisRingShare; 3] { - let coefs = [*input, GaloisRingElement::random(rng)]; - (1..=3) - .map(|i| { - let element = GaloisRingElement::EXCEPTIONAL_SEQUENCE[i]; - let share = coefs[0] + coefs[1] * element; - ShamirGaloisRingShare { id: i, y: share } - }) - .collect::>() - .as_slice() - .try_into() - .unwrap() - } - - pub fn encode_3_mat( - input: &[u16; 2], - rng: &mut R, - ) -> [ShamirGaloisRingShare; 3] { - let invec = [input[0], input[1], rng.gen(), rng.gen()]; - let share1 = ShamirGaloisRingShare { - id: 1, - y: GaloisRingElement { - coefs: [ - invec[0].wrapping_add(invec[2]), - invec[1].wrapping_add(invec[3]), - ], - }, - }; - let share2 = ShamirGaloisRingShare { - id: 2, - y: GaloisRingElement { - coefs: [ - invec[0].wrapping_add(invec[3]), - invec[1].wrapping_add(invec[2]).wrapping_add(invec[3]), - ], - }, - }; - let share3 = ShamirGaloisRingShare { - id: 3, - y: GaloisRingElement { - coefs: [ - share2.y.coefs[0].wrapping_add(invec[2]), - share2.y.coefs[1].wrapping_add(invec[3]), - ], - }, - }; - [share1, share2, share3] - } - - pub fn deg_1_lagrange_polys_at_zero( - my_id: PartyID, - other_id: PartyID, - ) -> GaloisRingElement { - let mut res = GaloisRingElement::ONE; - let i = usize::from(my_id) + 1; - let j = usize::from(other_id) + 1; - res = res * (-GaloisRingElement::EXCEPTIONAL_SEQUENCE[j]); - res = res - * (GaloisRingElement::EXCEPTIONAL_SEQUENCE[i] - - GaloisRingElement::EXCEPTIONAL_SEQUENCE[j]) - .inverse(); - res - } - - pub fn deg_2_lagrange_polys_at_zero() -> [GaloisRingElement; 3] { - let mut res = [GaloisRingElement::ONE; 3]; - for i in 1..=3 { - for j in 1..=3 { - if j != i { - res[i - 1] = res[i - 1] * (-GaloisRingElement::EXCEPTIONAL_SEQUENCE[j]); - res[i - 1] = res[i - 1] - * (GaloisRingElement::EXCEPTIONAL_SEQUENCE[i] - - GaloisRingElement::EXCEPTIONAL_SEQUENCE[j]) - .inverse(); - } - } - } - res - } - - pub fn reconstruct_deg_2_shares(shares: &[ShamirGaloisRingShare; 3]) -> GaloisRingElement { - let lagrange_polys_at_zero = Self::deg_2_lagrange_polys_at_zero(); - shares - .iter() - .map(|s| s.y * lagrange_polys_at_zero[s.id - 1]) - .reduce(|a, b| a + b) - .unwrap() - } - } - - #[cfg(test)] - mod tests { - use super::{GaloisRingElement, ShamirGaloisRingShare}; - use rand::thread_rng; - - #[test] - fn inverses() { - for g_e in [ - GaloisRingElement::ONE, - -GaloisRingElement::ONE, - GaloisRingElement::EXCEPTIONAL_SEQUENCE[2], - GaloisRingElement::EXCEPTIONAL_SEQUENCE[3], - GaloisRingElement::EXCEPTIONAL_SEQUENCE[1] - - GaloisRingElement::EXCEPTIONAL_SEQUENCE[2], - GaloisRingElement::EXCEPTIONAL_SEQUENCE[1] - - GaloisRingElement::EXCEPTIONAL_SEQUENCE[3], - GaloisRingElement::EXCEPTIONAL_SEQUENCE[2] - - GaloisRingElement::EXCEPTIONAL_SEQUENCE[1], - GaloisRingElement::EXCEPTIONAL_SEQUENCE[2] - - GaloisRingElement::EXCEPTIONAL_SEQUENCE[3], - GaloisRingElement::EXCEPTIONAL_SEQUENCE[3] - - GaloisRingElement::EXCEPTIONAL_SEQUENCE[1], - GaloisRingElement::EXCEPTIONAL_SEQUENCE[3] - - GaloisRingElement::EXCEPTIONAL_SEQUENCE[2], - ] { - assert_eq!(g_e.inverse() * g_e, GaloisRingElement::ONE); - } - } - #[test] - fn sharing() { - let input1 = GaloisRingElement::random(&mut rand::thread_rng()); - let input2 = GaloisRingElement::random(&mut rand::thread_rng()); - - let shares1 = ShamirGaloisRingShare::encode_3(&input1, &mut thread_rng()); - let shares2 = ShamirGaloisRingShare::encode_3(&input2, &mut thread_rng()); - let shares_mul = [ - shares1[0] * shares2[0], - shares1[1] * shares2[1], - shares1[2] * shares2[2], - ]; - - let reconstructed = ShamirGaloisRingShare::reconstruct_deg_2_shares(&shares_mul); - let expected = input1 * input2; - - assert_eq!(reconstructed, expected); - } - #[test] - fn sharing_mat() { - let input1 = GaloisRingElement::random(&mut rand::thread_rng()); - let input2 = GaloisRingElement::random(&mut rand::thread_rng()); - - let shares1 = ShamirGaloisRingShare::encode_3_mat(&input1.coefs, &mut thread_rng()); - let shares2 = ShamirGaloisRingShare::encode_3_mat(&input2.coefs, &mut thread_rng()); - let shares_mul = [ - shares1[0] * shares2[0], - shares1[1] * shares2[1], - shares1[2] * shares2[2], - ]; - - let reconstructed = ShamirGaloisRingShare::reconstruct_deg_2_shares(&shares_mul); - let expected = input1 * input2; - - assert_eq!(reconstructed, expected); - } - } -} - pub mod degree4 { use crate::id::PartyID; use basis::{Basis, Monomial}; @@ -354,12 +34,23 @@ pub mod degree4 { coefs: [1, 0, 0, 0], basis: PhantomData, }; - pub const EXCEPTIONAL_SEQUENCE: [GaloisRingElement; 5] = [ - GaloisRingElement::ZERO, - GaloisRingElement::ONE, + pub const EXCEPTIONAL_SEQUENCE: [GaloisRingElement; 16] = [ + GaloisRingElement::from_coefs([0, 0, 0, 0]), + GaloisRingElement::from_coefs([1, 0, 0, 0]), GaloisRingElement::from_coefs([0, 1, 0, 0]), GaloisRingElement::from_coefs([1, 1, 0, 0]), GaloisRingElement::from_coefs([0, 0, 1, 0]), + GaloisRingElement::from_coefs([1, 0, 1, 0]), + GaloisRingElement::from_coefs([0, 1, 1, 0]), + GaloisRingElement::from_coefs([1, 1, 1, 0]), + GaloisRingElement::from_coefs([0, 0, 0, 1]), + GaloisRingElement::from_coefs([1, 0, 0, 1]), + GaloisRingElement::from_coefs([0, 1, 0, 1]), + GaloisRingElement::from_coefs([1, 1, 0, 1]), + GaloisRingElement::from_coefs([0, 0, 1, 1]), + GaloisRingElement::from_coefs([1, 0, 1, 1]), + GaloisRingElement::from_coefs([0, 1, 1, 1]), + GaloisRingElement::from_coefs([1, 1, 1, 1]), ]; pub fn encode1(x: &[u16]) -> Option> { if x.len() % 4 != 0 { @@ -388,40 +79,47 @@ pub mod degree4 { ) } + /// Inverse of the element, if it exists + /// + /// # Panics + /// + /// This function panics if the element has no inverse pub fn inverse(&self) -> Self { // hard-coded inverses for some elements we need // too lazy to implement the general case in rust // and we do not need the general case, since this is only used for the lagrange // polys, which can be pre-computed anyway - if *self == GaloisRingElement::ZERO { - panic!("Division by zero"); + if self.coefs.iter().all(|x| x % 2 == 0) { + panic!("Element has no inverse"); } - if *self == GaloisRingElement::ONE { - return GaloisRingElement::ONE; - } + // inversion by exponentition by (p^r -1) * p^(m-1) - 1, with p = 2, r = 4, m = + // 16 + const P: u32 = 2; + const R: u32 = 4; + const M: u32 = 16; + const EXP: u32 = (P.pow(R) - 1) * P.pow(M - 1) - 1; - if *self == -GaloisRingElement::ONE { - return -GaloisRingElement::ONE; - } - if *self == GaloisRingElement::from_coefs([0, 1, 0, 0]) { - return GaloisRingElement::from_coefs([65535, 0, 0, 1]); - } - if *self == GaloisRingElement::from_coefs([0, 65535, 0, 0]) { - return GaloisRingElement::from_coefs([1, 0, 0, 65535]); - } - if *self == GaloisRingElement::from_coefs([1, 1, 0, 0]) { - return GaloisRingElement::from_coefs([2, 65535, 1, 65535]); - } - if *self == GaloisRingElement::from_coefs([1, 65535, 0, 0]) { - return GaloisRingElement::from_coefs([0, 65535, 65535, 65535]); + self.pow(EXP) + } + + /// Basic exponentiation by squaring, not constant time + pub fn pow(&self, mut exp: u32) -> Self { + if exp == 0 { + return Self::ONE; } - if *self == GaloisRingElement::from_coefs([65535, 1, 0, 0]) { - return GaloisRingElement::from_coefs([0, 1, 1, 1]); + let mut x = *self; + let mut y = Self::ONE; + while exp > 1 { + if exp % 2 == 1 { + y = x * y; + exp -= 1; + } + x = x * x; + exp /= 2; } - - panic!("No inverse for {:?} in LUT", self); + x * y } #[allow(non_snake_case)] @@ -719,6 +417,28 @@ pub mod degree4 { res } + // zero-indexed party ids here, party i will map to i+1 in the exceptional + // sequence + pub fn deg_1_lagrange_poly_at_v( + my_id: usize, + other_id: usize, + v: usize, + ) -> GaloisRingElement { + assert!(my_id < 15); + assert!(other_id < 15); + assert!(v < 15); + let i = my_id + 1; + let j = other_id + 1; + let v = v + 1; + let mut res = GaloisRingElement::EXCEPTIONAL_SEQUENCE[v] + - GaloisRingElement::EXCEPTIONAL_SEQUENCE[j]; + res = res + * (GaloisRingElement::EXCEPTIONAL_SEQUENCE[i] + - GaloisRingElement::EXCEPTIONAL_SEQUENCE[j]) + .inverse(); + res + } + pub fn deg_2_lagrange_polys_at_zero() -> [GaloisRingElement; 3] { let mut res = [GaloisRingElement::ONE; 3]; for i in 1..=3 { @@ -753,25 +473,25 @@ pub mod degree4 { use crate::galois::degree4::basis; #[test] - fn inverses() { - for g_e in [ - GaloisRingElement::ONE, - -GaloisRingElement::ONE, - GaloisRingElement::EXCEPTIONAL_SEQUENCE[2], - GaloisRingElement::EXCEPTIONAL_SEQUENCE[3], - GaloisRingElement::EXCEPTIONAL_SEQUENCE[1] - - GaloisRingElement::EXCEPTIONAL_SEQUENCE[2], - GaloisRingElement::EXCEPTIONAL_SEQUENCE[1] - - GaloisRingElement::EXCEPTIONAL_SEQUENCE[3], - GaloisRingElement::EXCEPTIONAL_SEQUENCE[2] - - GaloisRingElement::EXCEPTIONAL_SEQUENCE[1], - GaloisRingElement::EXCEPTIONAL_SEQUENCE[2] - - GaloisRingElement::EXCEPTIONAL_SEQUENCE[3], - GaloisRingElement::EXCEPTIONAL_SEQUENCE[3] - - GaloisRingElement::EXCEPTIONAL_SEQUENCE[1], - GaloisRingElement::EXCEPTIONAL_SEQUENCE[3] - - GaloisRingElement::EXCEPTIONAL_SEQUENCE[2], - ] { + fn exceptional_sequence_is_pairwise_diff_invertible() { + for i in 0..GaloisRingElement::EXCEPTIONAL_SEQUENCE.len() { + for j in 0..GaloisRingElement::EXCEPTIONAL_SEQUENCE.len() { + if i != j { + let diff = GaloisRingElement::EXCEPTIONAL_SEQUENCE[i] + - GaloisRingElement::EXCEPTIONAL_SEQUENCE[j]; + assert_eq!(diff.inverse() * diff, GaloisRingElement::ONE); + } + } + } + } + + #[test] + fn random_inverses() { + for _ in 0..100 { + let mut g_e = GaloisRingElement::random(&mut rand::thread_rng()); + // make it have an inverse + g_e.coefs.iter_mut().for_each(|x| *x |= 1); + assert_eq!(g_e.inverse() * g_e, GaloisRingElement::ONE); } } diff --git a/iris-mpc-common/src/galois_engine.rs b/iris-mpc-common/src/galois_engine.rs index c50927192..fa6e0814b 100644 --- a/iris-mpc-common/src/galois_engine.rs +++ b/iris-mpc-common/src/galois_engine.rs @@ -3,7 +3,7 @@ pub type CompactGaloisRingShares = Vec>; pub mod degree4 { use crate::{ galois::degree4::{basis, GaloisRingElement, ShamirGaloisRingShare}, - iris_db::iris::IrisCodeArray, + iris_db::iris::{IrisCode, IrisCodeArray}, IRIS_CODE_LENGTH, MASK_CODE_LENGTH, }; use base64::{prelude::BASE64_STANDARD, Engine}; @@ -44,9 +44,10 @@ pub mod degree4 { .for_each(|chunk| chunk.rotate_left(by * 4)); } - #[derive(Debug, Clone, PartialEq, Eq, Hash)] + #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] pub struct GaloisRingTrimmedMaskCodeShare { pub id: usize, + #[serde(with = "BigArray")] pub coefs: [u16; MASK_CODE_LENGTH], } @@ -316,6 +317,36 @@ pub mod degree4 { } } + pub struct FullGaloisRingIrisCodeShare { + pub code: GaloisRingIrisCodeShare, + pub mask: GaloisRingTrimmedMaskCodeShare, + } + + impl FullGaloisRingIrisCodeShare { + pub fn encode_iris_code( + iris: &IrisCode, + rng: &mut (impl Rng + CryptoRng), + ) -> [FullGaloisRingIrisCodeShare; 3] { + let [code0, code1, code2] = + GaloisRingIrisCodeShare::encode_iris_code(&iris.code, &iris.mask, rng); + let [mask0, mask1, mask2] = GaloisRingIrisCodeShare::encode_mask_code(&iris.mask, rng); + [ + FullGaloisRingIrisCodeShare { + code: code0, + mask: mask0.into(), + }, + FullGaloisRingIrisCodeShare { + code: code1, + mask: mask1.into(), + }, + FullGaloisRingIrisCodeShare { + code: code2, + mask: mask2.into(), + }, + ] + } + } + #[cfg(test)] mod tests { use crate::{ diff --git a/iris-mpc-common/src/helpers/key_pair.rs b/iris-mpc-common/src/helpers/key_pair.rs index c41554539..bcbda891b 100644 --- a/iris-mpc-common/src/helpers/key_pair.rs +++ b/iris-mpc-common/src/helpers/key_pair.rs @@ -47,6 +47,8 @@ pub enum SharesDecodingError { url: String, message: String, }, + #[error("Received error message from S3 for key {}: {}", .key, .message)] + S3ResponseContent { key: String, message: String }, #[error(transparent)] SerdeError(#[from] serde_json::error::Error), #[error(transparent)] @@ -81,7 +83,13 @@ impl Drop for SharesEncryptionKeyPairs { impl SharesEncryptionKeyPairs { pub async fn from_storage(config: Config) -> Result { - let region_provider = Region::new(REGION); + // use the configured region, fallback to the hardcoded value + let region = config + .aws + .and_then(|aws| aws.region) + .unwrap_or_else(|| REGION.to_owned()); + tracing::info!("Using region: {} for key pair download", region); + let region_provider = Region::new(region); let shared_config = aws_config::from_env().region(region_provider).load().await; let client = SecretsManagerClient::new(&shared_config); @@ -192,6 +200,10 @@ async fn download_private_key_from_asm( version_stage: &str, ) -> Result { let private_key_secret_id: String = format!("{}/iris-mpc/ecdh-private-key-{}", env, node_id); + tracing::info!( + "Downloading private key from Secrets Manager: {}", + private_key_secret_id + ); match client .get_secret_value() .secret_id(private_key_secret_id) diff --git a/iris-mpc-common/src/helpers/mod.rs b/iris-mpc-common/src/helpers/mod.rs index d330d08d6..8731fd9a3 100644 --- a/iris-mpc-common/src/helpers/mod.rs +++ b/iris-mpc-common/src/helpers/mod.rs @@ -5,6 +5,7 @@ pub mod kms_dh; pub mod sha256; pub mod shutdown_handler; pub mod smpc_request; +pub mod smpc_response; pub mod sqs_s3_helper; pub mod sync; pub mod task_monitor; diff --git a/iris-mpc-common/src/helpers/smpc_request.rs b/iris-mpc-common/src/helpers/smpc_request.rs index 53c7b1f72..04863df66 100644 --- a/iris-mpc-common/src/helpers/smpc_request.rs +++ b/iris-mpc-common/src/helpers/smpc_request.rs @@ -1,5 +1,6 @@ use super::{key_pair::SharesDecodingError, sha256::calculate_sha256}; use crate::helpers::key_pair::SharesEncryptionKeyPairs; +use aws_sdk_s3::Client as S3Client; use aws_sdk_sns::types::MessageAttributeValue; use aws_sdk_sqs::{ error::SdkError, @@ -7,15 +8,10 @@ use aws_sdk_sqs::{ }; use base64::{engine::general_purpose::STANDARD, Engine}; use eyre::Report; -use reqwest::Client; use serde::{Deserialize, Deserializer, Serialize, Serializer}; use serde_json::Value; -use std::{collections::HashMap, sync::LazyLock}; +use std::{collections::HashMap, sync::Arc}; use thiserror::Error; -use tokio_retry::{ - strategy::{jitter, FixedInterval}, - Retry, -}; #[derive(Serialize, Deserialize, Debug)] pub struct SQSMessage { @@ -105,7 +101,6 @@ where map.serialize(serializer) } -pub const SMPC_MESSAGE_TYPE_ATTRIBUTE: &str = "message_type"; pub const IDENTITY_DELETION_MESSAGE_TYPE: &str = "identity_deletion"; pub const CIRCUIT_BREAKER_MESSAGE_TYPE: &str = "circuit_breaker"; pub const UNIQUENESS_MESSAGE_TYPE: &str = "uniqueness"; @@ -114,7 +109,7 @@ pub const UNIQUENESS_MESSAGE_TYPE: &str = "uniqueness"; pub struct UniquenessRequest { pub batch_size: Option, pub signup_id: String, - pub s3_presigned_url: String, + pub s3_key: String, pub iris_shares_file_hashes: [String; 3], } @@ -197,51 +192,45 @@ impl SharesS3Object { } } -static S3_HTTP_CLIENT: LazyLock = LazyLock::new(Client::new); - impl UniquenessRequest { pub async fn get_iris_data_by_party_id( &self, party_id: usize, + bucket_name: &String, + s3_client: &Arc, ) -> Result { - // Send a GET request to the presigned URL - let retry_strategy = FixedInterval::from_millis(200).map(jitter).take(5); - let response = Retry::spawn(retry_strategy, || async { - S3_HTTP_CLIENT - .get(self.s3_presigned_url.clone()) - .send() - .await - }) - .await?; - - // Ensure the request was successful - if response.status().is_success() { - // Parse the JSON response into the SharesS3Object struct - let shares_file: SharesS3Object = match response.json().await { - Ok(file) => file, - Err(e) => { - tracing::error!("Failed to parse JSON: {}", e); - return Err(SharesDecodingError::RequestError(e)); + let response = s3_client + .get_object() + .bucket(bucket_name) + .key(self.s3_key.as_str()) + .send() + .await + .map_err(|err| { + tracing::error!("Failed to download file: {}", err); + SharesDecodingError::S3ResponseContent { + key: self.s3_key.clone(), + message: err.to_string(), } - }; - - // Construct the field name dynamically - let field_name = format!("iris_share_{}", party_id); - // Access the field dynamically - if let Some(value) = shares_file.get(party_id) { - Ok(value.to_string()) - } else { - tracing::error!("Failed to find field: {}", field_name); - Err(SharesDecodingError::SecretStringNotFound) + })?; + + let object_body = response.body.collect().await.map_err(|e| { + tracing::error!("Failed to get object body: {}", e); + SharesDecodingError::S3ResponseContent { + key: self.s3_key.clone(), + message: e.to_string(), } - } else { - tracing::error!("Failed to download file: {}", response.status()); - Err(SharesDecodingError::ResponseContent { - status: response.status(), - url: self.s3_presigned_url.clone(), - message: response.text().await.unwrap_or_default(), - }) - } + })?; + + let bytes = object_body.into_bytes(); + + let shares_file: SharesS3Object = serde_json::from_slice(&bytes)?; + + let field_name = format!("iris_share_{}", party_id); + + shares_file.get(party_id).cloned().ok_or_else(|| { + tracing::error!("Failed to find field: {}", field_name); + SharesDecodingError::SecretStringNotFound + }) } pub fn decrypt_iris_share( @@ -299,70 +288,3 @@ impl UniquenessRequest { Ok(self.iris_shares_file_hashes[party_id] == calculate_sha256(stringified_share)) } } - -#[derive(Debug, Serialize, Deserialize, Clone)] -pub struct UniquenessResult { - pub node_id: usize, - pub serial_id: Option, - pub is_match: bool, - pub signup_id: String, - pub matched_serial_ids: Option>, - pub matched_serial_ids_left: Option>, - pub matched_serial_ids_right: Option>, - pub matched_batch_request_ids: Option>, -} - -impl UniquenessResult { - #[allow(clippy::too_many_arguments)] - pub fn new( - node_id: usize, - serial_id: Option, - is_match: bool, - signup_id: String, - matched_serial_ids: Option>, - matched_serial_ids_left: Option>, - matched_serial_ids_right: Option>, - matched_batch_request_ids: Option>, - ) -> Self { - Self { - node_id, - serial_id, - is_match, - signup_id, - matched_serial_ids, - matched_serial_ids_left, - matched_serial_ids_right, - matched_batch_request_ids, - } - } -} - -#[derive(Debug, Serialize, Deserialize, Clone)] -pub struct IdentityDeletionResult { - pub node_id: usize, - pub serial_id: u32, - pub success: bool, -} - -impl IdentityDeletionResult { - pub fn new(node_id: usize, serial_id: u32, success: bool) -> Self { - Self { - node_id, - serial_id, - success, - } - } -} - -pub fn create_message_type_attribute_map( - message_type: &str, -) -> HashMap { - let mut message_attributes_map = HashMap::new(); - let message_type_value = MessageAttributeValue::builder() - .data_type("String") - .string_value(message_type) - .build() - .unwrap(); - message_attributes_map.insert(SMPC_MESSAGE_TYPE_ATTRIBUTE.to_string(), message_type_value); - message_attributes_map -} diff --git a/iris-mpc-common/src/helpers/smpc_response.rs b/iris-mpc-common/src/helpers/smpc_response.rs new file mode 100644 index 000000000..492ecddc3 --- /dev/null +++ b/iris-mpc-common/src/helpers/smpc_response.rs @@ -0,0 +1,78 @@ +use aws_sdk_sns::types::MessageAttributeValue; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +pub const SMPC_MESSAGE_TYPE_ATTRIBUTE: &str = "message_type"; +// Error Reasons +pub const ERROR_FAILED_TO_PROCESS_IRIS_SHARES: &str = "failed_to_process_iris_shares"; + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct UniquenessResult { + pub node_id: usize, + pub serial_id: Option, + pub is_match: bool, + pub signup_id: String, + pub matched_serial_ids: Option>, + pub matched_serial_ids_left: Option>, + pub matched_serial_ids_right: Option>, + pub matched_batch_request_ids: Option>, + pub error: Option, + pub error_reason: Option, +} + +impl UniquenessResult { + #[allow(clippy::too_many_arguments)] + pub fn new( + node_id: usize, + serial_id: Option, + is_match: bool, + signup_id: String, + matched_serial_ids: Option>, + matched_serial_ids_left: Option>, + matched_serial_ids_right: Option>, + matched_batch_request_ids: Option>, + ) -> Self { + Self { + node_id, + serial_id, + is_match, + signup_id, + matched_serial_ids, + matched_serial_ids_left, + matched_serial_ids_right, + matched_batch_request_ids, + error: None, + error_reason: None, + } + } +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct IdentityDeletionResult { + pub node_id: usize, + pub serial_id: u32, + pub success: bool, +} + +impl IdentityDeletionResult { + pub fn new(node_id: usize, serial_id: u32, success: bool) -> Self { + Self { + node_id, + serial_id, + success, + } + } +} + +pub fn create_message_type_attribute_map( + message_type: &str, +) -> HashMap { + let mut message_attributes_map = HashMap::new(); + let message_type_value = MessageAttributeValue::builder() + .data_type("String") + .string_value(message_type) + .build() + .unwrap(); + message_attributes_map.insert(SMPC_MESSAGE_TYPE_ATTRIBUTE.to_string(), message_type_value); + message_attributes_map +} diff --git a/iris-mpc-common/src/iris_db/iris.rs b/iris-mpc-common/src/iris_db/iris.rs index b8acc9e88..1176d5a0b 100644 --- a/iris-mpc-common/src/iris_db/iris.rs +++ b/iris-mpc-common/src/iris_db/iris.rs @@ -4,12 +4,14 @@ use rand::{ distributions::{Bernoulli, Distribution}, Rng, }; +use serde::{Deserialize, Serialize}; +use serde_big_array::BigArray; pub const MATCH_THRESHOLD_RATIO: f64 = 0.375; #[repr(transparent)] -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub struct IrisCodeArray(pub [u64; Self::IRIS_CODE_SIZE_U64]); +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +pub struct IrisCodeArray(#[serde(with = "BigArray")] pub [u64; Self::IRIS_CODE_SIZE_U64]); impl Default for IrisCodeArray { fn default() -> Self { Self::ZERO @@ -141,7 +143,7 @@ impl std::ops::BitXor for IrisCodeArray { } } -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] pub struct IrisCode { pub code: IrisCodeArray, pub mask: IrisCodeArray, diff --git a/iris-mpc-common/tests/smpc_request.rs b/iris-mpc-common/tests/smpc_request.rs index 273c65008..1c2e7d5fd 100644 --- a/iris-mpc-common/tests/smpc_request.rs +++ b/iris-mpc-common/tests/smpc_request.rs @@ -1,6 +1,7 @@ mod tests { + use aws_credential_types::{provider::SharedCredentialsProvider, Credentials}; + use aws_sdk_s3::Client as S3Client; use base64::{engine::general_purpose::STANDARD, Engine}; - use http::StatusCode; use iris_mpc_common::helpers::{ key_pair::{SharesDecodingError, SharesEncryptionKeyPairs}, sha256::calculate_sha256, @@ -8,10 +9,8 @@ mod tests { }; use serde_json::json; use sodiumoxide::crypto::{box_::PublicKey, sealedbox}; - use wiremock::{ - matchers::{method, path}, - Mock, MockServer, ResponseTemplate, - }; + use std::sync::Arc; + use wiremock::{matchers::method, Mock, MockServer, ResponseTemplate}; const PREVIOUS_PUBLIC_KEY: &str = "1UY8lKlS7aVj5ZnorSfLIHlG3jg+L4ToVi4K+mLKqFQ="; const PREVIOUS_PRIVATE_KEY: &str = "X26wWfzP5fKMP7QMz0X3eZsEeF4NhJU92jT69wZg6x8="; @@ -45,7 +44,7 @@ mod tests { UniquenessRequest { batch_size: Some(1), signup_id: "signup_mock".to_string(), - s3_presigned_url: "https://example.com/mock".to_string(), + s3_key: "mock".to_string(), iris_shares_file_hashes: hashes, } } @@ -54,7 +53,7 @@ mod tests { UniquenessRequest { batch_size: None, signup_id: "test_signup_id".to_string(), - s3_presigned_url: "https://example.com/package".to_string(), + s3_key: "package".to_string(), iris_shares_file_hashes: [ "hash_0".to_string(), "hash_1".to_string(), @@ -66,26 +65,46 @@ mod tests { #[tokio::test] async fn test_retrieve_iris_shares_from_s3_success() { let mock_server = MockServer::start().await; - - // Simulate a successful response from the presigned URL + let bucket_name = "bobTheBucket"; + let key = "kateTheKey"; let response_body = json!({ "iris_share_0": "share_0_data", "iris_share_1": "share_1_data", "iris_share_2": "share_2_data" }); - let template = ResponseTemplate::new(StatusCode::OK).set_body_json(response_body.clone()); + let data = response_body.to_string(); Mock::given(method("GET")) - .and(path("/test_presign_url")) - .respond_with(template) + .respond_with( + ResponseTemplate::new(200) + .insert_header("Content-Type", "application/octet-stream") + .set_body_raw(data, "application/octet-stream"), + ) .mount(&mock_server) .await; + let credentials = + Credentials::new("test-access-key", "test-secret-key", None, None, "test"); + let credentials_provider = SharedCredentialsProvider::new(credentials); + // Configure the S3Client to point to the mock server + let config = aws_config::from_env() + .region("us-west-2") + .endpoint_url(mock_server.uri()) + .credentials_provider(credentials_provider) + .load() + .await; + let s3_config = aws_sdk_s3::config::Builder::from(&config) + .endpoint_url(mock_server.uri()) + .force_path_style(true) + .build(); + + let s3_client = Arc::new(S3Client::from_conf(s3_config)); + let smpc_request = UniquenessRequest { batch_size: None, signup_id: "test_signup_id".to_string(), - s3_presigned_url: mock_server.uri().clone() + "/test_presign_url", + s3_key: key.to_string(), iris_shares_file_hashes: [ "hash_0".to_string(), "hash_1".to_string(), @@ -93,7 +112,9 @@ mod tests { ], }; - let result = smpc_request.get_iris_data_by_party_id(0).await; + let result = smpc_request + .get_iris_data_by_party_id(0, &bucket_name.to_string(), &s3_client) + .await; assert!(result.is_ok()); assert_eq!(result.unwrap(), "share_0_data".to_string()); diff --git a/iris-mpc-cpu/.gitignore b/iris-mpc-cpu/.gitignore new file mode 100644 index 000000000..249cda967 --- /dev/null +++ b/iris-mpc-cpu/.gitignore @@ -0,0 +1 @@ +/data \ No newline at end of file diff --git a/iris-mpc-cpu/Cargo.toml b/iris-mpc-cpu/Cargo.toml index 726d3298c..6940012ac 100644 --- a/iris-mpc-cpu/Cargo.toml +++ b/iris-mpc-cpu/Cargo.toml @@ -10,31 +10,51 @@ repository.workspace = true [dependencies] aes-prng = { git = "https://github.com/tf-encrypted/aes-prng.git", branch = "dragos/display"} async-channel = "2.3.1" +async-stream = "0.3.6" async-trait = "~0.1" -bincode = "1.3.3" +backoff = {version="0.4.0", features = ["tokio"]} +bincode.workspace = true bytes = "1.7" bytemuck.workspace = true +clap.workspace = true dashmap = "6.1.0" eyre.workspace = true futures.workspace = true -hawk-pack = { git = "https://github.com/Inversed-Tech/hawk-pack.git", rev = "4e6de24" } +hawk-pack.workspace = true iris-mpc-common = { path = "../iris-mpc-common" } itertools.workspace = true num-traits.workspace = true +prost = "0.13" rand.workspace = true rstest = "0.23.0" serde.workspace = true +serde_json.workspace = true static_assertions.workspace = true tokio.workspace = true +tokio-stream = "0.1" +tonic = "0.12.3" tracing.workspace = true +tracing-subscriber.workspace = true tracing-test = "0.2.5" +uuid.workspace = true [dev-dependencies] criterion = { version = "0.5.1", features = ["async_tokio"] } +[build-dependencies] +tonic-build = "0.12.3" + [[bench]] name = "hnsw" harness = false [[example]] -name = "hnsw-ex" \ No newline at end of file +name = "hnsw-ex" + +[[bin]] +name = "local_hnsw" +path = "bin/local_hnsw.rs" + +[[bin]] +name = "generate_benchmark_data" +path = "bin/generate_benchmark_data.rs" \ No newline at end of file diff --git a/iris-mpc-cpu/benches/hnsw.rs b/iris-mpc-cpu/benches/hnsw.rs index c4fcce26d..62be58814 100644 --- a/iris-mpc-cpu/benches/hnsw.rs +++ b/iris-mpc-cpu/benches/hnsw.rs @@ -1,12 +1,14 @@ use aes_prng::AesRng; use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, SamplingMode}; -use hawk_pack::{graph_store::GraphMem, hnsw_db::HawkSearcher, VectorStore}; +use hawk_pack::{graph_store::GraphMem, HawkSearcher}; use iris_mpc_common::iris_db::{db::IrisDB, iris::IrisCode}; use iris_mpc_cpu::{ database_generators::{create_random_sharing, generate_galois_iris_shares}, execution::local::LocalRuntime, - hawkers::{galois_store::gr_create_ready_made_hawk_searcher, plaintext_store::PlaintextStore}, - protocol::ops::{cross_compare, galois_ring_pairwise_distance, galois_ring_to_rep3}, + hawkers::{galois_store::LocalNetAby3NgStoreProtocol, plaintext_store::PlaintextStore}, + protocol::ops::{ + batch_signed_lift_vec, cross_compare, galois_ring_pairwise_distance, galois_ring_to_rep3, + }, }; use rand::SeedableRng; use tokio::task::JoinSet; @@ -31,18 +33,8 @@ fn bench_plaintext_hnsw(c: &mut Criterion) { for _ in 0..database_size { let raw_query = IrisCode::random_rng(&mut rng); let query = vector.prepare_query(raw_query.clone()); - let neighbors = searcher - .search_to_insert(&mut vector, &mut graph, &query) - .await; - let inserted = vector.insert(&query).await; searcher - .insert_from_search_results( - &mut vector, - &mut graph, - &mut rng, - inserted, - neighbors, - ) + .insert(&mut vector, &mut graph, &query, &mut rng) .await; } (vector, graph) @@ -56,17 +48,8 @@ fn bench_plaintext_hnsw(c: &mut Criterion) { let mut rng = AesRng::seed_from_u64(0_u64); let on_the_fly_query = IrisDB::new_random_rng(1, &mut rng).db[0].clone(); let query = db_vectors.prepare_query(on_the_fly_query); - let neighbors = searcher - .search_to_insert(&mut db_vectors, &mut graph, &query) - .await; searcher - .insert_from_search_results( - &mut db_vectors, - &mut graph, - &mut rng, - query, - neighbors, - ) + .insert(&mut db_vectors, &mut graph, &query, &mut rng) .await; }, criterion::BatchSize::SmallInput, @@ -89,8 +72,7 @@ fn bench_hnsw_primitives(c: &mut Criterion) { let t1 = create_random_sharing(&mut rng, 10_u16); let t2 = create_random_sharing(&mut rng, 10_u16); - let runtime = LocalRuntime::replicated_test_config(); - let ready_sessions = runtime.create_player_sessions().await.unwrap(); + let runtime = LocalRuntime::mock_setup_with_grpc().await.unwrap(); let mut jobs = JoinSet::new(); for (index, player) in runtime.identities.iter().enumerate() { @@ -98,11 +80,25 @@ fn bench_hnsw_primitives(c: &mut Criterion) { let d2i = d2[index].clone(); let t1i = t1[index].clone(); let t2i = t2[index].clone(); - let mut player_session = ready_sessions.get(player).unwrap().clone(); + let mut player_session = runtime.sessions.get(player).unwrap().clone(); jobs.spawn(async move { - cross_compare(&mut player_session, d1i, t1i, d2i, t2i) - .await - .unwrap() + let ds_and_ts = batch_signed_lift_vec(&mut player_session, vec![ + d1i.clone(), + d2i.clone(), + t1i.clone(), + t2i.clone(), + ]) + .await + .unwrap(); + cross_compare( + &mut player_session, + ds_and_ts[0].clone(), + ds_and_ts[1].clone(), + ds_and_ts[2].clone(), + ds_and_ts[3].clone(), + ) + .await + .unwrap() }); } let _outputs = black_box(jobs.join_all().await); @@ -117,8 +113,7 @@ fn bench_gr_primitives(c: &mut Criterion) { .build() .unwrap(); b.to_async(&rt).iter(|| async move { - let runtime = LocalRuntime::replicated_test_config(); - let ready_sessions = runtime.create_player_sessions().await.unwrap(); + let runtime = LocalRuntime::mock_setup_with_grpc().await.unwrap(); let mut rng = AesRng::seed_from_u64(0); let iris_db = IrisDB::new_random_rng(4, &mut rng).db; @@ -135,7 +130,7 @@ fn bench_gr_primitives(c: &mut Criterion) { let x2 = x2[index].clone(); let mut y2 = y2[index].clone(); - let mut player_session = ready_sessions.get(player).unwrap().clone(); + let mut player_session = runtime.sessions.get(player).unwrap().clone(); jobs.spawn(async move { y1.code.preprocess_iris_code_query_share(); y1.mask.preprocess_mask_code_query_share(); @@ -148,6 +143,9 @@ fn bench_gr_primitives(c: &mut Criterion) { let ds_and_ts = galois_ring_to_rep3(&mut player_session, ds_and_ts) .await .unwrap(); + let ds_and_ts = batch_signed_lift_vec(&mut player_session, ds_and_ts) + .await + .unwrap(); cross_compare( &mut player_session, ds_and_ts[0].clone(), @@ -164,47 +162,67 @@ fn bench_gr_primitives(c: &mut Criterion) { }); } +/// To run this benchmark, you need to generate the data first by running the +/// following commands: +/// +/// cargo run --release --bin generate_benchmark_data fn bench_gr_ready_made_hnsw(c: &mut Criterion) { let mut group = c.benchmark_group("gr_ready_made_hnsw"); group.sample_size(10); - for database_size in [1, 10, 100, 1000, 10000, 100000] { + for database_size in [1, 10, 100, 1000, 10_000, 100_000] { let rt = tokio::runtime::Builder::new_multi_thread() .enable_all() .build() .unwrap(); - let (_, secret_searcher) = rt.block_on(async move { + let secret_searcher = rt.block_on(async move { let mut rng = AesRng::seed_from_u64(0_u64); - gr_create_ready_made_hawk_searcher(&mut rng, database_size) - .await - .unwrap() + LocalNetAby3NgStoreProtocol::lazy_setup_from_files_with_grpc( + "./data/store.ndjson", + &format!("./data/graph_{}.dat", database_size), + &mut rng, + database_size, + false, + ) + .await }); + if let Err(e) = secret_searcher { + eprintln!("bench_gr_ready_made_hnsw failed. {e:?}"); + rt.shutdown_timeout(std::time::Duration::from_secs(5)); + return; + } + let (_, secret_searcher) = secret_searcher.unwrap(); + group.bench_function( BenchmarkId::new("gr-big-hnsw-insertions", database_size), |b| { b.to_async(&rt).iter_batched( || secret_searcher.clone(), - |(mut db_vectors, mut db_graph)| async move { + |vectors_graphs| async move { let searcher = HawkSearcher::default(); let mut rng = AesRng::seed_from_u64(0_u64); let on_the_fly_query = IrisDB::new_random_rng(1, &mut rng).db[0].clone(); let raw_query = generate_galois_iris_shares(&mut rng, on_the_fly_query); - let query = db_vectors.prepare_query(raw_query); - let neighbors = searcher - .search_to_insert(&mut db_vectors, &mut db_graph, &query) - .await; - searcher - .insert_from_search_results( - &mut db_vectors, - &mut db_graph, - &mut rng, - query, - neighbors, - ) - .await; + let mut jobs = JoinSet::new(); + + for (vector_store, graph_store) in vectors_graphs.into_iter() { + let mut vector_store = vector_store; + let mut graph_store = graph_store; + + let player_index = vector_store.get_owner_index(); + let query = vector_store.prepare_query(raw_query[player_index].clone()); + let searcher = searcher.clone(); + let mut rng = rng.clone(); + jobs.spawn(async move { + searcher + .insert(&mut vector_store, &mut graph_store, &query, &mut rng) + .await; + }); + } + jobs.join_all().await; }, criterion::BatchSize::SmallInput, ) @@ -216,17 +234,27 @@ fn bench_gr_ready_made_hnsw(c: &mut Criterion) { |b| { b.to_async(&rt).iter_batched( || secret_searcher.clone(), - |(mut db_vectors, mut db_graph)| async move { + |vectors_graphs| async move { let searcher = HawkSearcher::default(); let mut rng = AesRng::seed_from_u64(0_u64); let on_the_fly_query = IrisDB::new_random_rng(1, &mut rng).db[0].clone(); let raw_query = generate_galois_iris_shares(&mut rng, on_the_fly_query); - let query = db_vectors.prepare_query(raw_query); - let neighbors = searcher - .search_to_insert(&mut db_vectors, &mut db_graph, &query) - .await; - searcher.is_match(&mut db_vectors, &neighbors).await; + let mut jobs = JoinSet::new(); + for (vector_store, graph_store) in vectors_graphs.into_iter() { + let mut vector_store = vector_store; + let mut graph_store = graph_store; + let player_index = vector_store.get_owner_index(); + let query = vector_store.prepare_query(raw_query[player_index].clone()); + let searcher = searcher.clone(); + jobs.spawn(async move { + let neighbors = searcher + .search(&mut vector_store, &mut graph_store, &query, 1) + .await; + searcher.is_match(&mut vector_store, &[neighbors]).await; + }); + } + jobs.join_all().await; }, criterion::BatchSize::SmallInput, ) diff --git a/iris-mpc-cpu/bin/generate_benchmark_data.rs b/iris-mpc-cpu/bin/generate_benchmark_data.rs new file mode 100644 index 000000000..f89c285b6 --- /dev/null +++ b/iris-mpc-cpu/bin/generate_benchmark_data.rs @@ -0,0 +1,26 @@ +use aes_prng::AesRng; +use iris_mpc_cpu::{ + hawkers::plaintext_store::PlaintextStore, + py_bindings::{io::write_bin, plaintext_store::to_ndjson_file}, +}; +use rand::SeedableRng; +use std::error::Error; + +#[tokio::main] +async fn main() -> Result<(), Box> { + // create a folder ./iris-mpc-cpu/data if it is non-existent + let crate_root = env!("CARGO_MANIFEST_DIR"); + std::fs::create_dir_all(format!("{crate_root}/data"))?; + let mut rng = AesRng::seed_from_u64(0_u64); + println!("Generating plaintext store with 100_000 irises"); + let mut store = PlaintextStore::create_random_store(&mut rng, 100_000).await?; + println!("Writing store to file"); + to_ndjson_file(&store, &format!("{crate_root}/data/store.ndjson"))?; + + for graph_size in [1, 10, 100, 1000, 10_000, 100_000] { + println!("Generating graph with {} vertices", graph_size); + let graph = store.create_graph(&mut rng, graph_size).await?; + write_bin(&graph, &format!("{crate_root}/data/graph_{graph_size}.dat"))?; + } + Ok(()) +} diff --git a/iris-mpc-cpu/bin/local_hnsw.rs b/iris-mpc-cpu/bin/local_hnsw.rs new file mode 100644 index 000000000..20eca212d --- /dev/null +++ b/iris-mpc-cpu/bin/local_hnsw.rs @@ -0,0 +1,24 @@ +use aes_prng::AesRng; +use clap::Parser; +use iris_mpc_cpu::hawkers::galois_store::LocalNetAby3NgStoreProtocol; +use rand::SeedableRng; +use std::error::Error; + +#[derive(Parser)] +struct Args { + #[clap(short = 'n', default_value = "1000")] + database_size: usize, +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + let args = Args::parse(); + let database_size = args.database_size; + + println!("Starting Local HNSW with {} vectors", database_size); + let mut rng = AesRng::seed_from_u64(0_u64); + + LocalNetAby3NgStoreProtocol::shared_random_setup_with_grpc(&mut rng, database_size).await?; + + Ok(()) +} diff --git a/iris-mpc-cpu/build.rs b/iris-mpc-cpu/build.rs new file mode 100644 index 000000000..cf3860392 --- /dev/null +++ b/iris-mpc-cpu/build.rs @@ -0,0 +1,6 @@ +fn main() { + tonic_build::configure() + .out_dir("src/proto_generated") + .compile_protos(&["src/proto/party_node.proto"], &["src/proto"]) + .unwrap_or_else(|e| panic!("Failed to compile protos {:?}", e)); +} diff --git a/iris-mpc-cpu/examples/hnsw-ex.rs b/iris-mpc-cpu/examples/hnsw-ex.rs index 041f0118a..71c925028 100644 --- a/iris-mpc-cpu/examples/hnsw-ex.rs +++ b/iris-mpc-cpu/examples/hnsw-ex.rs @@ -1,5 +1,5 @@ use aes_prng::AesRng; -use hawk_pack::{graph_store::GraphMem, hnsw_db::HawkSearcher, VectorStore}; +use hawk_pack::{graph_store::GraphMem, HawkSearcher}; use iris_mpc_common::iris_db::iris::IrisCode; use iris_mpc_cpu::hawkers::plaintext_store::PlaintextStore; use rand::SeedableRng; @@ -21,12 +21,8 @@ fn main() { for idx in 0..DATABASE_SIZE { let raw_query = IrisCode::random_rng(&mut rng); let query = vector.prepare_query(raw_query.clone()); - let neighbors = searcher - .search_to_insert(&mut vector, &mut graph, &query) - .await; - let inserted = vector.insert(&query).await; searcher - .insert_from_search_results(&mut vector, &mut graph, &mut rng, inserted, neighbors) + .insert(&mut vector, &mut graph, &query, &mut rng) .await; if idx % 100 == 99 { println!("{}", idx + 1); diff --git a/iris-mpc-cpu/src/database_generators.rs b/iris-mpc-cpu/src/database_generators.rs index 3d38209ed..243981322 100644 --- a/iris-mpc-cpu/src/database_generators.rs +++ b/iris-mpc-cpu/src/database_generators.rs @@ -4,11 +4,12 @@ use iris_mpc_common::{ iris_db::iris::IrisCode, }; use rand::{CryptoRng, Rng, RngCore}; +use serde::{Deserialize, Serialize}; type ShareRing = u16; type ShareRingPlain = RingElement; -#[derive(PartialEq, Eq, Debug, Clone)] +#[derive(PartialEq, Eq, Debug, Clone, Serialize, Deserialize, Hash)] pub struct GaloisRingSharedIris { pub code: GaloisRingIrisCodeShare, pub mask: GaloisRingTrimmedMaskCodeShare, diff --git a/iris-mpc-cpu/src/execution/local.rs b/iris-mpc-cpu/src/execution/local.rs index cc400076b..0f425beae 100644 --- a/iris-mpc-cpu/src/execution/local.rs +++ b/iris-mpc-cpu/src/execution/local.rs @@ -3,81 +3,172 @@ use crate::{ player::*, session::{BootSession, Session, SessionHandles, SessionId}, }, - network::local::LocalNetworkingStore, - protocol::{ - ops::setup_replicated_prf, - prf::{Prf, PrfSeed}, - }, + network::{grpc::setup_local_grpc_networking, local::LocalNetworkingStore, NetworkType}, + protocol::{ops::setup_replicated_prf, prf::PrfSeed}, +}; +use std::{ + collections::{HashMap, HashSet}, + sync::{Arc, LazyLock}, }; -use std::{collections::HashMap, sync::Arc}; -use tokio::task::JoinSet; +use tokio::{sync::Mutex, task::JoinSet}; + +pub fn generate_local_identities() -> Vec { + vec![ + Identity::from("alice"), + Identity::from("bob"), + Identity::from("charlie"), + ] +} + +static USED_PORTS: LazyLock>> = LazyLock::new(|| Mutex::new(HashSet::new())); + +pub async fn get_free_local_addresses(num_ports: usize) -> eyre::Result> { + let mut addresses = vec![]; + let mut listeners = vec![]; + while addresses.len() < num_ports { + let listener = std::net::TcpListener::bind("127.0.0.1:0")?; + let port = listener.local_addr()?.port(); + if USED_PORTS.lock().await.insert(port) { + addresses.push(format!("127.0.0.1:{port}")); + listeners.push(listener); + } else { + tracing::warn!("Port {port} already in use, retrying"); + } + } + tracing::info!("Found free addresses: {addresses:?}"); + Ok(addresses) +} #[derive(Debug, Clone)] pub struct LocalRuntime { pub identities: Vec, pub role_assignments: RoleAssignment, - pub prf_setups: Option>, pub seeds: Vec, + // only one session per player is created + pub sessions: HashMap, } impl LocalRuntime { - pub fn replicated_test_config() -> Self { + pub async fn mock_setup(network_t: NetworkType) -> eyre::Result { let num_parties = 3; - let identities: Vec = vec!["alice".into(), "bob".into(), "charlie".into()]; + let identities = generate_local_identities(); let mut seeds = Vec::new(); for i in 0..num_parties { let mut seed = [0_u8; 16]; seed[0] = i; seeds.push(seed); } - LocalRuntime::new(identities, seeds) + LocalRuntime::new_with_network_type(identities, seeds, network_t).await } - pub fn new(identities: Vec, seeds: Vec) -> Self { + + pub async fn mock_setup_with_channel() -> eyre::Result { + Self::mock_setup(NetworkType::LocalChannel).await + } + + pub async fn mock_setup_with_grpc() -> eyre::Result { + Self::mock_setup(NetworkType::GrpcChannel).await + } + + pub async fn new_with_network_type( + identities: Vec, + seeds: Vec, + network_type: NetworkType, + ) -> eyre::Result { let role_assignments: RoleAssignment = identities .iter() .enumerate() .map(|(index, id)| (Role::new(index), id.clone())) .collect(); - LocalRuntime { - identities, - role_assignments, - prf_setups: None, - seeds, - } - } - - pub async fn create_player_sessions(&self) -> eyre::Result> { - let network = LocalNetworkingStore::from_host_ids(&self.identities); - let sess_id = SessionId::from(0_u128); - let boot_sessions: Vec = (0..self.seeds.len()) - .map(|i| { - let identity = self.identities[i].clone(); - BootSession { - session_id: sess_id, - role_assignments: Arc::new(self.role_assignments.clone()), - networking: Arc::new(network.get_local_network(identity.clone())), - own_identity: identity, + let sess_id = SessionId::from(0_u64); + let boot_sessions = match network_type { + NetworkType::LocalChannel => { + let network = LocalNetworkingStore::from_host_ids(&identities); + let boot_sessions: Vec = (0..seeds.len()) + .map(|i| { + let identity = identities[i].clone(); + BootSession { + session_id: sess_id, + role_assignments: Arc::new(role_assignments.clone()), + networking: Arc::new(network.get_local_network(identity.clone())), + own_identity: identity, + } + }) + .collect(); + boot_sessions + } + NetworkType::GrpcChannel => { + let networks = setup_local_grpc_networking(identities.clone()).await?; + let mut jobs = JoinSet::new(); + for player in networks.iter() { + let player = player.clone(); + jobs.spawn(async move { + player.create_session(sess_id).await.unwrap(); + }); } - }) - .collect(); + jobs.join_all().await; + let boot_sessions: Vec = (0..seeds.len()) + .map(|i| { + let identity = identities[i].clone(); + BootSession { + session_id: sess_id, + role_assignments: Arc::new(role_assignments.clone()), + networking: Arc::new(networks[i].clone()), + own_identity: identity, + } + }) + .collect(); + boot_sessions + } + }; let mut jobs = JoinSet::new(); for (player_id, boot_session) in boot_sessions.iter().enumerate() { - let player_seed = self.seeds[player_id]; + let player_seed = seeds[player_id]; let sess = boot_session.clone(); jobs.spawn(async move { let prf = setup_replicated_prf(&sess, player_seed).await.unwrap(); (sess, prf) }); } - let mut complete_sessions = HashMap::new(); + let mut sessions = HashMap::new(); while let Some(t) = jobs.join_next().await { let (boot_session, prf) = t.unwrap(); - complete_sessions.insert(boot_session.own_identity(), Session { + sessions.insert(boot_session.own_identity(), Session { boot_session, setup: prf, }); } - Ok(complete_sessions) + Ok(LocalRuntime { + identities, + role_assignments, + seeds, + sessions, + }) + } + + pub async fn new(identities: Vec, seeds: Vec) -> eyre::Result { + Self::new_with_network_type(identities, seeds, NetworkType::LocalChannel).await + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_get_free_local_addresses() { + let mut jobs = JoinSet::new(); + let num_ports = 3; + + for _ in 0..100 { + jobs.spawn(async move { + let mut addresses = get_free_local_addresses(num_ports).await.unwrap(); + assert_eq!(addresses.len(), num_ports); + addresses.sort(); + addresses.dedup(); + assert_eq!(addresses.len(), num_ports); + }); + } + jobs.join_all().await; } } diff --git a/iris-mpc-cpu/src/execution/player.rs b/iris-mpc-cpu/src/execution/player.rs index 49690755a..94364f853 100644 --- a/iris-mpc-cpu/src/execution/player.rs +++ b/iris-mpc-cpu/src/execution/player.rs @@ -40,7 +40,7 @@ impl Role { } /// Retrieve index of Role (zero indexed) - pub fn zero_based(&self) -> usize { + pub fn index(&self) -> usize { self.0 as usize } diff --git a/iris-mpc-cpu/src/execution/session.rs b/iris-mpc-cpu/src/execution/session.rs index 8d7f05dd6..fba4403ea 100644 --- a/iris-mpc-cpu/src/execution/session.rs +++ b/iris-mpc-cpu/src/execution/session.rs @@ -5,20 +5,20 @@ use crate::{ }; use eyre::eyre; use serde::{Deserialize, Serialize}; -use std::{collections::HashMap, sync::Arc}; +use std::{collections::HashMap, fmt::Debug, sync::Arc}; #[derive(Serialize, Deserialize, Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)] -pub struct SessionId(pub u128); +pub struct SessionId(pub u64); -impl From for SessionId { - fn from(id: u128) -> Self { +impl From for SessionId { + fn from(id: u64) -> Self { SessionId(id) } } pub type NetworkingImpl = Arc; -#[derive(Clone)] +#[derive(Debug, Clone)] pub struct Session { pub boot_session: BootSession, pub setup: Prf, @@ -32,6 +32,17 @@ pub struct BootSession { pub own_identity: Identity, } +impl Debug for BootSession { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + // TODO: incorporate networking into debug output + f.debug_struct("BootSession") + .field("session_id", &self.session_id) + .field("role_assignments", &self.role_assignments) + .field("own_identity", &self.own_identity) + .finish() + } +} + pub trait SessionHandles { fn session_id(&self) -> SessionId; fn own_role(&self) -> eyre::Result; diff --git a/iris-mpc-cpu/src/hawkers/galois_store.rs b/iris-mpc-cpu/src/hawkers/galois_store.rs index 4fd542753..c960801f8 100644 --- a/iris-mpc-cpu/src/hawkers/galois_store.rs +++ b/iris-mpc-cpu/src/hawkers/galois_store.rs @@ -1,25 +1,68 @@ use super::plaintext_store::PlaintextStore; use crate::{ database_generators::{generate_galois_iris_shares, GaloisRingSharedIris}, - execution::{local::LocalRuntime, player::Identity, session::Session}, + execution::{ + local::{generate_local_identities, LocalRuntime}, + player::Identity, + session::Session, + }, hawkers::plaintext_store::PointId, + network::NetworkType, protocol::ops::{ - cross_compare, galois_ring_pairwise_distance, galois_ring_to_rep3, is_dot_zero, + batch_signed_lift_vec, compare_threshold_and_open, cross_compare, + galois_ring_pairwise_distance, galois_ring_to_rep3, + }, + py_bindings::{io::read_bin, plaintext_store::from_ndjson_file}, + shares::{ + ring_impl::RingElement, + share::{DistanceShare, Share}, }, - shares::{int_ring::IntRing2k, share::Share}, }; use aes_prng::AesRng; use hawk_pack::{ + data_structures::queue::FurthestQueue, graph_store::{graph_mem::Layer, GraphMem}, - hnsw_db::{FurthestQueue, HawkSearcher}, - GraphStore, VectorStore, + GraphStore, HawkSearcher, VectorStore, }; -use iris_mpc_common::iris_db::{db::IrisDB, iris::IrisCode}; +use iris_mpc_common::iris_db::db::IrisDB; use rand::{CryptoRng, RngCore, SeedableRng}; use serde::{Deserialize, Serialize}; -use std::collections::HashMap; +use std::{collections::HashMap, fmt::Debug, sync::Arc, vec}; use tokio::task::JoinSet; +#[derive(Copy, Default, Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub struct VectorId { + id: PointId, +} + +impl From for VectorId { + fn from(id: PointId) -> Self { + VectorId { id } + } +} + +impl From<&PointId> for VectorId { + fn from(id: &PointId) -> Self { + VectorId { id: *id } + } +} + +impl From for VectorId { + fn from(id: usize) -> Self { + VectorId { id: id.into() } + } +} + +type GaloisRingPoint = GaloisRingSharedIris; + +#[derive(Clone, Serialize, Deserialize, Hash, Eq, PartialEq, Debug)] +pub struct Query { + pub query: GaloisRingPoint, + pub processed_query: GaloisRingPoint, +} + +type QueryRef = Arc; + #[derive(Default, Clone)] pub struct Aby3NgStorePlayer { points: Vec, @@ -31,33 +74,34 @@ impl std::fmt::Debug for Aby3NgStorePlayer { } } -#[derive(Eq, PartialEq, Clone, Debug)] -struct GaloisRingPoint { - /// Whatever encoding of a vector. - data: GaloisRingSharedIris, -} - impl Aby3NgStorePlayer { pub fn new_with_shared_db(data: Vec) -> Self { - let points: Vec = data - .into_iter() - .map(|d| GaloisRingPoint { data: d }) - .collect(); - Aby3NgStorePlayer { points } + Aby3NgStorePlayer { points: data } } - pub fn prepare_query(&mut self, raw_query: GaloisRingSharedIris) -> PointId { - self.points.push(GaloisRingPoint { data: raw_query }); + pub fn prepare_query(&mut self, raw_query: GaloisRingSharedIris) -> QueryRef { + let mut preprocessed_query = raw_query.clone(); + preprocessed_query.code.preprocess_iris_code_query_share(); + preprocessed_query.mask.preprocess_mask_code_query_share(); - let point_id = self.points.len() - 1; - point_id.into() + Arc::new(Query { + query: raw_query, + processed_query: preprocessed_query, + }) + } + + pub fn get_vector(&self, vector: &VectorId) -> &GaloisRingPoint { + &self.points[vector.id] } } impl Aby3NgStorePlayer { - fn insert(&mut self, query: &PointId) -> PointId { - // The query is now accepted in the store. It keeps the same ID. - *query + fn insert(&mut self, query: &QueryRef) -> VectorId { + // The query is now accepted in the store. + self.points.push(query.query.clone()); + + let new_id = self.points.len() - 1; + VectorId { id: new_id.into() } } } @@ -68,112 +112,132 @@ pub fn setup_local_player_preloaded_db( Ok(aby3_store) } -pub fn setup_local_aby3_players_with_preloaded_db( +pub async fn setup_local_aby3_players_with_preloaded_db( rng: &mut R, - database: Vec, -) -> eyre::Result { - let mut p0 = Vec::new(); - let mut p1 = Vec::new(); - let mut p2 = Vec::new(); - - for iris in database { - let all_shares = generate_galois_iris_shares(rng, iris); - p0.push(all_shares[0].clone()); - p1.push(all_shares[1].clone()); - p2.push(all_shares[2].clone()); - } - - let player_0 = setup_local_player_preloaded_db(p0)?; - let player_1 = setup_local_player_preloaded_db(p1)?; - let player_2 = setup_local_player_preloaded_db(p2)?; - let players = HashMap::from([ - (Identity::from("alice"), player_0), - (Identity::from("bob"), player_1), - (Identity::from("charlie"), player_2), - ]); - let runtime = LocalRuntime::replicated_test_config(); - Ok(LocalNetAby3NgStoreProtocol { runtime, players }) + plain_store: &PlaintextStore, + network_t: NetworkType, +) -> eyre::Result> { + let identities = generate_local_identities(); + + let mut shared_irises = vec![vec![]; identities.len()]; + + for iris in plain_store.points.iter() { + let all_shares = generate_galois_iris_shares(rng, iris.data.0.clone()); + for (i, shares) in all_shares.iter().enumerate() { + shared_irises[i].push(shares.clone()); + } + } + + let storages: Vec = shared_irises + .into_iter() + .map(|player_irises| setup_local_player_preloaded_db(player_irises).unwrap()) + .collect(); + let runtime = LocalRuntime::mock_setup(network_t).await?; + + let local_stores = identities + .into_iter() + .zip(storages.into_iter()) + .map(|(identity, storage)| LocalNetAby3NgStoreProtocol { + runtime: runtime.clone(), + storage, + owner: identity, + }) + .collect(); + + Ok(local_stores) } #[derive(Debug, Clone)] pub struct LocalNetAby3NgStoreProtocol { - pub players: HashMap, + pub owner: Identity, + pub storage: Aby3NgStorePlayer, pub runtime: LocalRuntime, } -pub fn setup_local_store_aby3_players() -> eyre::Result { - let player_0 = Aby3NgStorePlayer::default(); - let player_1 = Aby3NgStorePlayer::default(); - let player_2 = Aby3NgStorePlayer::default(); - let runtime = LocalRuntime::replicated_test_config(); - let players = HashMap::from([ - (Identity::from("alice"), player_0), - (Identity::from("bob"), player_1), - (Identity::from("charlie"), player_2), - ]); - Ok(LocalNetAby3NgStoreProtocol { runtime, players }) -} - impl LocalNetAby3NgStoreProtocol { - pub fn prepare_query(&mut self, code: Vec) -> PointId { - assert_eq!(code.len(), 3); - assert_eq!(self.players.len(), 3); - let pid0 = self - .players - .get_mut(&Identity::from("alice")) - .unwrap() - .prepare_query(code[0].clone()); - let pid1 = self - .players - .get_mut(&Identity::from("bob")) - .unwrap() - .prepare_query(code[1].clone()); - let pid2 = self - .players - .get_mut(&Identity::from("charlie")) + pub fn get_owner_session(&self) -> Session { + self.runtime.sessions.get(&self.owner).unwrap().clone() + } + + pub fn get_owner_index(&self) -> usize { + self.runtime + .role_assignments + .iter() + .find_map(|(role, id)| { + if id.clone() == self.owner { + Some(role.clone()) + } else { + None + } + }) .unwrap() - .prepare_query(code[2].clone()); - assert_eq!(pid0, pid1); - assert_eq!(pid1, pid2); - pid0 + .index() } } -#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)] -#[serde(bound = "")] -pub struct DistanceShare { - code_dot: Share, - mask_dot: Share, - player: Identity, +pub async fn setup_local_store_aby3_players( + network_t: NetworkType, +) -> eyre::Result> { + let runtime = LocalRuntime::mock_setup(network_t).await?; + let players = generate_local_identities(); + let local_stores = players + .into_iter() + .map(|identity| LocalNetAby3NgStoreProtocol { + runtime: runtime.clone(), + storage: Aby3NgStorePlayer::default(), + owner: identity, + }) + .collect(); + Ok(local_stores) } -async fn eval_pairwise_distances( - mut pairs: Vec<(GaloisRingSharedIris, GaloisRingSharedIris)>, - player_session: &mut Session, -) -> Vec> { - pairs.iter_mut().for_each(|(_x, y)| { - y.code.preprocess_iris_code_query_share(); - y.mask.preprocess_mask_code_query_share(); - }); - let ds_and_ts = galois_ring_pairwise_distance(player_session, &pairs) - .await - .unwrap(); - galois_ring_to_rep3(player_session, ds_and_ts) - .await - .unwrap() +impl LocalNetAby3NgStoreProtocol { + pub fn prepare_query(&mut self, code: GaloisRingSharedIris) -> QueryRef { + self.storage.prepare_query(code) + } + + pub async fn lift_distances( + &mut self, + distances: Vec>, + ) -> eyre::Result>> { + if distances.is_empty() { + return Ok(vec![]); + } + let mut player_session = self.get_owner_session(); + let distances = batch_signed_lift_vec(&mut player_session, distances).await?; + Ok(distances + .chunks(2) + .map(|dot_products| { + DistanceShare::new(dot_products[0].clone(), dot_products[1].clone()) + }) + .collect::>()) + } + + /// Assumes that the first iris of each pair is preprocessed. + async fn eval_pairwise_distances( + &mut self, + pairs: Vec<(GaloisRingSharedIris, GaloisRingSharedIris)>, + ) -> Vec> { + if pairs.is_empty() { + return vec![]; + } + let mut player_session = self.get_owner_session(); + let ds_and_ts = galois_ring_pairwise_distance(&mut player_session, &pairs) + .await + .unwrap(); + galois_ring_to_rep3(&mut player_session, ds_and_ts) + .await + .unwrap() + } } impl VectorStore for LocalNetAby3NgStoreProtocol { - type QueryRef = PointId; // Vector ID, pending insertion. - type VectorRef = PointId; // Vector ID, inserted. - type DistanceRef = Vec>; // Distance represented as shares. + type QueryRef = QueryRef; // Point ID, pending insertion. + type VectorRef = VectorId; // Point ID, inserted. + type DistanceRef = DistanceShare; // Distance represented as shares. async fn insert(&mut self, query: &Self::QueryRef) -> Self::VectorRef { - // The query is now accepted in the store. It keeps the same ID. - for (_id, storage) in self.players.iter_mut() { - storage.insert(query); - } - *query + self.storage.insert(query) } async fn eval_distance( @@ -181,24 +245,10 @@ impl VectorStore for LocalNetAby3NgStoreProtocol { query: &Self::QueryRef, vector: &Self::VectorRef, ) -> Self::DistanceRef { - let ready_sessions = self.runtime.create_player_sessions().await.unwrap(); - let mut jobs = JoinSet::new(); - for player in self.runtime.identities.clone() { - let mut player_session = ready_sessions.get(&player).unwrap().clone(); - let storage = self.players.get(&player).unwrap(); - let query_point = storage.points[*query].clone(); - let vector_point = storage.points[*vector].clone(); - let pairs = vec![(query_point.data, vector_point.data)]; - jobs.spawn(async move { - let ds_and_ts = eval_pairwise_distances(pairs, &mut player_session).await; - DistanceShare { - code_dot: ds_and_ts[0].clone(), - mask_dot: ds_and_ts[1].clone(), - player: player.clone(), - } - }); - } - jobs.join_all().await + let vector_point = self.storage.get_vector(vector); + let pairs = vec![(query.processed_query.clone(), vector_point.clone())]; + let dist = self.eval_pairwise_distances(pairs).await; + self.lift_distances(dist).await.unwrap()[0].clone() } async fn eval_distance_batch( @@ -206,68 +256,25 @@ impl VectorStore for LocalNetAby3NgStoreProtocol { query: &Self::QueryRef, vectors: &[Self::VectorRef], ) -> Vec { - let ready_sessions = self.runtime.create_player_sessions().await.unwrap(); - let mut jobs = JoinSet::new(); - for player in self.runtime.identities.clone() { - let mut player_session = ready_sessions.get(&player).unwrap().clone(); - let storage = self.players.get(&player).unwrap(); - let query_point = storage.points[*query].clone(); - let pairs = vectors - .iter() - .map(|vector_id| { - let vector_point = storage.points[*vector_id].clone(); - (query_point.data.clone(), vector_point.data) - }) - .collect::>(); - jobs.spawn(async move { - let ds_and_ts = eval_pairwise_distances(pairs, &mut player_session).await; - ds_and_ts - .chunks(2) - .map(|dot_products| DistanceShare { - code_dot: dot_products[0].clone(), - mask_dot: dot_products[1].clone(), - player: player.clone(), - }) - .collect::>() - }); + if vectors.is_empty() { + return vec![]; } - // Now we have a vector of 3 vectors of DistanceShares, we need to transpose it - // to a vector of DistanceRef - let mut all_shares = jobs - .join_all() - .await - .into_iter() - .map(|player_shares| player_shares.into_iter()) - .collect::>(); - (0..vectors.len()) - .map(|_| { - all_shares - .iter_mut() - .map(|player_shares| player_shares.next().unwrap()) - .collect::() + let pairs = vectors + .iter() + .map(|vector_id| { + let vector_point = self.storage.get_vector(vector_id); + (query.processed_query.clone(), vector_point.clone()) }) - .collect::>() + .collect::>(); + let dist = self.eval_pairwise_distances(pairs).await; + self.lift_distances(dist).await.unwrap() } async fn is_match(&mut self, distance: &Self::DistanceRef) -> bool { - let ready_sessions = self.runtime.create_player_sessions().await.unwrap(); - let mut jobs = JoinSet::new(); - for distance_share in distance.iter() { - let mut player_session = ready_sessions.get(&distance_share.player).unwrap().clone(); - let code_dot = distance_share.code_dot.clone(); - let mask_dot = distance_share.mask_dot.clone(); - jobs.spawn(async move { - is_dot_zero(&mut player_session, code_dot, mask_dot) - .await - .unwrap() - }); - } - let r0 = jobs.join_next().await.unwrap().unwrap(); - let r1 = jobs.join_next().await.unwrap().unwrap(); - let r2 = jobs.join_next().await.unwrap().unwrap(); - assert_eq!(r0, r1); - assert_eq!(r1, r2); - r0 + let mut player_session = self.get_owner_session(); + compare_threshold_and_open(&mut player_session, distance.clone()) + .await + .unwrap() } async fn less_than( @@ -275,43 +282,58 @@ impl VectorStore for LocalNetAby3NgStoreProtocol { distance1: &Self::DistanceRef, distance2: &Self::DistanceRef, ) -> bool { - let ready_sessions = self.runtime.create_player_sessions().await.unwrap(); - let mut jobs = JoinSet::new(); - for share1 in distance1.iter() { - for share2 in distance2.iter() { - if share1.player == share2.player { - let mut player_session = ready_sessions.get(&share1.player).unwrap().clone(); - let code_dot1 = share1.code_dot.clone(); - let mask_dot1 = share1.mask_dot.clone(); - let code_dot2 = share2.code_dot.clone(); - let mask_dot2 = share2.mask_dot.clone(); - jobs.spawn(async move { - cross_compare( - &mut player_session, - code_dot1, - mask_dot1, - code_dot2, - mask_dot2, - ) - .await - .unwrap() - }); - } - } - } - let res = jobs.join_all().await; - assert_eq!(res[0], res[1]); - assert_eq!(res[0], res[2]); - res[0] + let mut player_session = self.get_owner_session(); + let code_dot1 = distance1.code_dot.clone(); + let mask_dot1 = distance1.mask_dot.clone(); + let code_dot2 = distance2.code_dot.clone(); + let mask_dot2 = distance2.mask_dot.clone(); + cross_compare( + &mut player_session, + code_dot1, + mask_dot1, + code_dot2, + mask_dot2, + ) + .await + .unwrap() } } impl LocalNetAby3NgStoreProtocol { + pub fn get_trivial_share(&self, distance: u16) -> Share { + let player = self.get_owner_index(); + let distance_elem = RingElement(distance as u32); + let zero_elem = RingElement(0_u32); + + match player { + 0 => Share::new(distance_elem, zero_elem), + 1 => Share::new(zero_elem, distance_elem), + 2 => Share::new(zero_elem, zero_elem), + _ => panic!("Invalid player index"), + } + } + + async fn eval_distance_vectors( + &mut self, + vector1: &::VectorRef, + vector2: &::VectorRef, + ) -> ::DistanceRef { + let point1 = self.storage.get_vector(vector1); + let mut point2 = self.storage.get_vector(vector2).clone(); + point2.code.preprocess_iris_code_query_share(); + point2.mask.preprocess_mask_code_query_share(); + let pairs = vec![(point1.clone(), point2.clone())]; + let dist = self.eval_pairwise_distances(pairs).await; + self.lift_distances(dist).await.unwrap()[0].clone() + } + async fn graph_from_plain( &mut self, - graph_store: GraphMem, + graph_store: &GraphMem, + recompute_distances: bool, ) -> GraphMem { let ep = graph_store.get_entry_point().await; + let new_ep = ep.map(|(vector_ref, layer_count)| (VectorId { id: vector_ref }, layer_count)); let layers = graph_store.get_layers(); @@ -320,108 +342,254 @@ impl LocalNetAby3NgStoreProtocol { let links = layer.get_links_map(); let mut shared_links = HashMap::new(); for (source_v, queue) in links { + let source_v = source_v.into(); let mut shared_queue = vec![]; - for (target_v, _) in queue.as_vec_ref() { - // recompute distances of graph edges from scratch - let shared_distance = self.eval_distance(source_v, target_v).await; - shared_queue.push((*target_v, shared_distance)); + for (target_v, dist) in queue.as_vec_ref() { + let target_v = target_v.into(); + let distance = if recompute_distances { + // recompute distances of graph edges from scratch + self.eval_distance_vectors(&source_v, &target_v).await + } else { + DistanceShare::new( + self.get_trivial_share(dist.0), + self.get_trivial_share(dist.1), + ) + }; + shared_queue.push((target_v, distance.clone())); } - shared_links.insert(*source_v, FurthestQueue::from_ascending_vec(shared_queue)); + shared_links.insert( + source_v, + FurthestQueue::from_ascending_vec(shared_queue.clone()), + ); } shared_layers.push(Layer::from_links(shared_links)); } - - GraphMem::from_precomputed(ep, shared_layers) + GraphMem::from_precomputed(new_ep, shared_layers) } } -pub async fn gr_create_ready_made_hawk_searcher( - rng: &mut R, - database_size: usize, -) -> eyre::Result<( - (PlaintextStore, GraphMem), - ( - LocalNetAby3NgStoreProtocol, - GraphMem, - ), -)> { - // makes sure the searcher produces same graph structure by having the same rng - let mut rng_searcher1 = AesRng::from_rng(rng.clone())?; - let cleartext_database = IrisDB::new_random_rng(database_size, rng).db; - - let mut plaintext_vector_store = PlaintextStore::default(); - let mut plaintext_graph_store = GraphMem::new(); - let searcher = HawkSearcher::default(); - - for raw_query in cleartext_database.iter() { - let query = plaintext_vector_store.prepare_query(raw_query.clone()); - let neighbors = searcher - .search_to_insert( - &mut plaintext_vector_store, - &mut plaintext_graph_store, - &query, - ) - .await; - let inserted = plaintext_vector_store.insert(&query).await; - searcher - .insert_from_search_results( - &mut plaintext_vector_store, - &mut plaintext_graph_store, - &mut rng_searcher1, - inserted, - neighbors, - ) - .await; - } - - let mut protocol_store = setup_local_aby3_players_with_preloaded_db(rng, cleartext_database)?; - let protocol_graph = protocol_store - .graph_from_plain(plaintext_graph_store.clone()) - .await; - - let plaintext = (plaintext_vector_store, plaintext_graph_store); - let secret = (protocol_store, protocol_graph); - Ok((plaintext, secret)) -} +impl LocalNetAby3NgStoreProtocol { + /// Generates 3 pairs of vector stores and graphs from a plaintext + /// vector store and graph read from disk, which are returned as well. + /// The network type is specified by the user. + /// A recompute flag is used to determine whether to recompute the distances + /// from stored shares. If recompute is set to false, the distances are + /// naively converted from plaintext. + pub async fn lazy_setup_from_files( + plainstore_file: &str, + plaingraph_file: &str, + rng: &mut R, + database_size: usize, + network_t: NetworkType, + recompute_distances: bool, + ) -> eyre::Result<( + (PlaintextStore, GraphMem), + Vec<(Self, GraphMem)>, + )> { + if database_size > 100_000 { + return Err(eyre::eyre!("Database size too large, max. 100,000")); + } + let generation_comment = "Please, generate benchmark data with cargo run --release --bin \ + generate_benchmark_data."; + let plaintext_vector_store = from_ndjson_file(plainstore_file, Some(database_size)) + .map_err(|e| eyre::eyre!("Cannot find store: {e}. {generation_comment}"))?; + let plaintext_graph_store: GraphMem = read_bin(plaingraph_file) + .map_err(|e| eyre::eyre!("Cannot find graph: {e}. {generation_comment}"))?; -pub async fn ng_create_from_scratch_hawk_searcher( - rng: &mut R, - database_size: usize, -) -> eyre::Result<( - LocalNetAby3NgStoreProtocol, - GraphMem, -)> { - let mut rng_searcher = AesRng::from_rng(rng.clone())?; - let cleartext_database = IrisDB::new_random_rng(database_size, rng).db; - let shared_irises: Vec<_> = (0..database_size) - .map(|id| generate_galois_iris_shares(rng, cleartext_database[id].clone())) - .collect(); + let protocol_stores = + setup_local_aby3_players_with_preloaded_db(rng, &plaintext_vector_store, network_t) + .await?; - let searcher = HawkSearcher::default(); - let mut aby3_store_protocol = setup_local_store_aby3_players().unwrap(); - let mut graph_store = GraphMem::new(); - - let queries = (0..database_size) - .map(|id| aby3_store_protocol.prepare_query(shared_irises[id].clone())) - .collect::>(); - - // insert queries - for query in queries.iter() { - let neighbors = searcher - .search_to_insert(&mut aby3_store_protocol, &mut graph_store, query) - .await; - searcher - .insert_from_search_results( - &mut aby3_store_protocol, - &mut graph_store, - &mut rng_searcher, - *query, - neighbors, - ) - .await; - } - - Ok((aby3_store_protocol, graph_store)) + let mut jobs = JoinSet::new(); + for store in protocol_stores.iter() { + let mut store = store.clone(); + let plaintext_graph_store = plaintext_graph_store.clone(); + jobs.spawn(async move { + ( + store.clone(), + store + .graph_from_plain(&plaintext_graph_store, recompute_distances) + .await, + ) + }); + } + let mut secret_shared_stores = jobs.join_all().await; + secret_shared_stores.sort_by_key(|(store, _)| store.get_owner_index()); + let plaintext = (plaintext_vector_store, plaintext_graph_store); + Ok((plaintext, secret_shared_stores)) + } + + /// Generates 3 pairs of vector stores and graphs from a plaintext + /// vector store and graph read from disk, which are returned as well. + /// Networking is based on gRPC. + pub async fn lazy_setup_from_files_with_grpc( + plainstore_file: &str, + plaingraph_file: &str, + rng: &mut R, + database_size: usize, + recompute_distances: bool, + ) -> eyre::Result<( + (PlaintextStore, GraphMem), + Vec<(Self, GraphMem)>, + )> { + Self::lazy_setup_from_files( + plainstore_file, + plaingraph_file, + rng, + database_size, + NetworkType::GrpcChannel, + recompute_distances, + ) + .await + } + + /// Generates 3 pairs of vector stores and graphs from a random plaintext + /// vector store and graph, which are returned as well. + /// The network type is specified by the user. + /// A recompute flag is used to determine whether to recompute the distances + /// from stored shares. If recompute is set to false, the distances are + /// naively converted from plaintext. + pub async fn lazy_random_setup( + rng: &mut R, + database_size: usize, + network_t: NetworkType, + recompute_distances: bool, + ) -> eyre::Result<( + (PlaintextStore, GraphMem), + Vec<(Self, GraphMem)>, + )> { + let (plaintext_vector_store, plaintext_graph_store) = + PlaintextStore::create_random(rng, database_size).await?; + + let protocol_stores = + setup_local_aby3_players_with_preloaded_db(rng, &plaintext_vector_store, network_t) + .await?; + + let mut jobs = JoinSet::new(); + for store in protocol_stores.iter() { + let mut store = store.clone(); + let plaintext_graph_store = plaintext_graph_store.clone(); + jobs.spawn(async move { + ( + store.clone(), + store + .graph_from_plain(&plaintext_graph_store, recompute_distances) + .await, + ) + }); + } + let mut secret_shared_stores = jobs.join_all().await; + secret_shared_stores.sort_by_key(|(store, _)| store.get_owner_index()); + let plaintext = (plaintext_vector_store, plaintext_graph_store); + Ok((plaintext, secret_shared_stores)) + } + + /// Generates 3 pairs of vector stores and graphs from a random plaintext + /// vector store and graph, which are returned as well. Networking is + /// based on local async_channel. + pub async fn lazy_random_setup_with_local_channel( + rng: &mut R, + database_size: usize, + recompute_distances: bool, + ) -> eyre::Result<( + (PlaintextStore, GraphMem), + Vec<( + LocalNetAby3NgStoreProtocol, + GraphMem, + )>, + )> { + Self::lazy_random_setup( + rng, + database_size, + NetworkType::LocalChannel, + recompute_distances, + ) + .await + } + + /// Generates 3 pairs of vector stores and graphs from a random plaintext + /// vector store and graph, which are returned as well. Networking is + /// based on gRPC. + pub async fn lazy_random_setup_with_grpc( + rng: &mut R, + database_size: usize, + recompute_distances: bool, + ) -> eyre::Result<( + (PlaintextStore, GraphMem), + Vec<( + LocalNetAby3NgStoreProtocol, + GraphMem, + )>, + )> { + Self::lazy_random_setup( + rng, + database_size, + NetworkType::GrpcChannel, + recompute_distances, + ) + .await + } + + /// Generates 3 pairs of vector stores and graphs corresponding to each + /// local player. + pub async fn shared_random_setup( + rng: &mut R, + database_size: usize, + network_t: NetworkType, + ) -> eyre::Result)>> { + let rng_searcher = AesRng::from_rng(rng.clone())?; + let cleartext_database = IrisDB::new_random_rng(database_size, rng).db; + let shared_irises: Vec<_> = (0..database_size) + .map(|id| generate_galois_iris_shares(rng, cleartext_database[id].clone())) + .collect(); + + let mut local_stores = setup_local_store_aby3_players(network_t).await?; + + let mut jobs = JoinSet::new(); + for store in local_stores.iter_mut() { + let mut store = store.clone(); + let role = store.get_owner_index(); + let mut rng_searcher = rng_searcher.clone(); + let queries = (0..database_size) + .map(|id| store.prepare_query(shared_irises[id][role].clone())) + .collect::>(); + jobs.spawn(async move { + let mut graph_store = GraphMem::new(); + let searcher = HawkSearcher::default(); + // insert queries + for query in queries.iter() { + searcher + .insert(&mut store, &mut graph_store, query, &mut rng_searcher) + .await; + } + (store, graph_store) + }); + } + let mut result = jobs.join_all().await; + // preserve order of players + result.sort_by(|(store1, _), (store2, _)| { + store1.get_owner_index().cmp(&store2.get_owner_index()) + }); + Ok(result) + } + + /// Generates 3 pairs of vector stores and graphs corresponding to each + /// local player. Networking is based on local async_channel. + pub async fn shared_random_setup_with_local_channel( + rng: &mut R, + database_size: usize, + ) -> eyre::Result)>> { + Self::shared_random_setup(rng, database_size, NetworkType::LocalChannel).await + } + + /// Generates 3 pairs of vector stores and graphs corresponding to each + /// local player. Networking is based on gRPC. + pub async fn shared_random_setup_with_grpc( + rng: &mut R, + database_size: usize, + ) -> eyre::Result)>> { + Self::shared_random_setup(rng, database_size, NetworkType::GrpcChannel).await + } } #[cfg(test)] @@ -429,8 +597,7 @@ mod tests { use super::*; use crate::database_generators::generate_galois_iris_shares; use aes_prng::AesRng; - use hawk_pack::{graph_store::GraphMem, hnsw_db::HawkSearcher}; - use iris_mpc_common::iris_db::db::IrisDB; + use hawk_pack::{graph_store::GraphMem, HawkSearcher}; use itertools::Itertools; use rand::SeedableRng; use tracing_test::traced_test; @@ -440,47 +607,56 @@ mod tests { let mut rng = AesRng::seed_from_u64(0_u64); let database_size = 10; let cleartext_database = IrisDB::new_random_rng(database_size, &mut rng).db; + let shared_irises: Vec<_> = cleartext_database + .iter() + .map(|iris| generate_galois_iris_shares(&mut rng, iris.clone())) + .collect(); - let mut aby3_store = setup_local_store_aby3_players().unwrap(); - let mut aby3_graph = GraphMem::new(); - let db = HawkSearcher::default(); - - let queries = (0..database_size) - .map(|id| { - aby3_store.prepare_query(generate_galois_iris_shares( - &mut rng, - cleartext_database[id].clone(), - )) - }) - .collect::>(); + let mut stores = setup_local_store_aby3_players(NetworkType::LocalChannel) + .await + .unwrap(); - // insert queries - for query in queries.iter() { - let neighbors = db - .search_to_insert(&mut aby3_store, &mut aby3_graph, query) - .await; - db.insert_from_search_results( - &mut aby3_store, - &mut aby3_graph, - &mut rng, - *query, - neighbors, - ) - .await; + let mut jobs = JoinSet::new(); + for store in stores.iter_mut() { + let player_index = store.get_owner_index(); + let queries = (0..database_size) + .map(|id| store.prepare_query(shared_irises[id][player_index].clone())) + .collect::>(); + let mut store = store.clone(); + let mut rng = rng.clone(); + jobs.spawn(async move { + let mut aby3_graph = GraphMem::new(); + let db = HawkSearcher::default(); + + let mut inserted = vec![]; + // insert queries + for query in queries.iter() { + let inserted_vector = db + .insert(&mut store, &mut aby3_graph, query, &mut rng) + .await; + inserted.push(inserted_vector) + } + tracing::debug!("FINISHED INSERTING"); + // Search for the same codes and find matches. + let mut matching_results = vec![]; + for v in inserted.into_iter() { + let query = store.prepare_query(store.storage.get_vector(&v).clone()); + let neighbors = db.search(&mut store, &mut aby3_graph, &query, 1).await; + tracing::debug!("Finished checking query"); + matching_results.push(db.is_match(&mut store, &[neighbors]).await) + } + matching_results + }); } - println!("FINISHED INSERTING"); - // Search for the same codes and find matches. - for (index, query) in queries.iter().enumerate() { - let neighbors = db - .search_to_insert(&mut aby3_store, &mut aby3_graph, query) - .await; - // assert_eq!(false, true); - tracing::debug!("Finished query"); - assert!( - db.is_match(&mut aby3_store, &neighbors).await, - "failed at index {:?}", - index - ); + let matching_results = jobs.join_all().await; + for (party_id, party_results) in matching_results.iter().enumerate() { + for (index, result) in party_results.iter().enumerate() { + assert!( + *result, + "Failed at index {:?} for party {:?}", + index, party_id + ); + } } } @@ -489,85 +665,70 @@ mod tests { async fn test_gr_premade_hnsw() { let mut rng = AesRng::seed_from_u64(0_u64); let database_size = 10; - let (mut cleartext_data, mut secret_data) = - gr_create_ready_made_hawk_searcher(&mut rng, database_size) - .await - .unwrap(); + let network_t = NetworkType::LocalChannel; + let (mut cleartext_data, secret_data) = LocalNetAby3NgStoreProtocol::lazy_random_setup( + &mut rng, + database_size, + network_t.clone(), + true, + ) + .await + .unwrap(); let mut rng = AesRng::seed_from_u64(0_u64); - let (mut vector_store, mut graph_store) = - ng_create_from_scratch_hawk_searcher(&mut rng, database_size) + let vector_graph_stores = + LocalNetAby3NgStoreProtocol::shared_random_setup(&mut rng, database_size, network_t) .await .unwrap(); - assert_eq!( - vector_store - .players - .get(&Identity::from("alice")) - .unwrap() - .points, - secret_data - .0 - .players - .get(&Identity::from("alice")) - .unwrap() - .points - ); - assert_eq!( - vector_store - .players - .get(&Identity::from("bob")) - .unwrap() - .points, - secret_data - .0 - .players - .get(&Identity::from("bob")) - .unwrap() - .points - ); - assert_eq!( - vector_store - .players - .get(&Identity::from("charlie")) - .unwrap() - .points, - secret_data - .0 - .players - .get(&Identity::from("charlie")) - .unwrap() - .points - ); + for ((v_from_scratch, _), (premade_v, _)) in + vector_graph_stores.iter().zip(secret_data.iter()) + { + assert_eq!(v_from_scratch.storage.points, premade_v.storage.points); + } let hawk_searcher = HawkSearcher::default(); for i in 0..database_size { let cleartext_neighbors = hawk_searcher - .search_to_insert(&mut cleartext_data.0, &mut cleartext_data.1, &i.into()) + .search(&mut cleartext_data.0, &mut cleartext_data.1, &i.into(), 1) .await; assert!( hawk_searcher - .is_match(&mut cleartext_data.0, &cleartext_neighbors) + .is_match(&mut cleartext_data.0, &[cleartext_neighbors]) .await, ); - let secret_neighbors = hawk_searcher - .search_to_insert(&mut secret_data.0, &mut secret_data.1, &i.into()) - .await; - assert!( - hawk_searcher - .is_match(&mut secret_data.0, &secret_neighbors) - .await - ); + let mut jobs = JoinSet::new(); + for (v, g) in vector_graph_stores.iter() { + let hawk_searcher = hawk_searcher.clone(); + let mut v = v.clone(); + let mut g = g.clone(); + let q = v.prepare_query(v.storage.get_vector(&i.into()).clone()); + jobs.spawn(async move { + let secret_neighbors = hawk_searcher.search(&mut v, &mut g, &q, 1).await; + + hawk_searcher.is_match(&mut v, &[secret_neighbors]).await + }); + } + let scratch_results = jobs.join_all().await; + + let mut jobs = JoinSet::new(); + for (v, g) in secret_data.iter() { + let hawk_searcher = hawk_searcher.clone(); + let mut v = v.clone(); + let mut g = g.clone(); + jobs.spawn(async move { + let query = v.prepare_query(v.storage.get_vector(&i.into()).clone()); + let secret_neighbors = hawk_searcher.search(&mut v, &mut g, &query, 1).await; + + hawk_searcher.is_match(&mut v, &[secret_neighbors]).await + }); + } + let premade_results = jobs.join_all().await; - let scratch_secret_neighbors = hawk_searcher - .search_to_insert(&mut vector_store, &mut graph_store, &i.into()) - .await; - assert!( - hawk_searcher - .is_match(&mut vector_store, &scratch_secret_neighbors) - .await, - ); + for (premade_res, scratch_res) in scratch_results.iter().zip(premade_results.iter()) { + assert!(*premade_res && *scratch_res); + } } } @@ -577,21 +738,13 @@ mod tests { let mut rng = AesRng::seed_from_u64(0_u64); let db_dim = 4; let cleartext_database = IrisDB::new_random_rng(db_dim, &mut rng).db; - - let mut aby3_store_protocol = setup_local_store_aby3_players().unwrap(); - - let aby3_preps: Vec<_> = (0..db_dim) - .map(|id| { - aby3_store_protocol.prepare_query(generate_galois_iris_shares( - &mut rng, - cleartext_database[id].clone(), - )) - }) + let shared_irises: Vec<_> = cleartext_database + .iter() + .map(|iris| generate_galois_iris_shares(&mut rng, iris.clone())) .collect(); - let mut aby3_inserts = Vec::new(); - for p in aby3_preps.iter() { - aby3_inserts.push(aby3_store_protocol.insert(p).await); - } + let mut local_stores = setup_local_store_aby3_players(NetworkType::LocalChannel) + .await + .unwrap(); // Now do the work for the plaintext store let mut plaintext_store = PlaintextStore::default(); let plaintext_preps: Vec<_> = (0..db_dim) @@ -601,31 +754,76 @@ mod tests { for p in plaintext_preps.iter() { plaintext_inserts.push(plaintext_store.insert(p).await); } + + // pairs of indices to compare let it1 = (0..db_dim).combinations(2); let it2 = (0..db_dim).combinations(2); - for comb1 in it1 { + + let mut plain_results = HashMap::new(); + for comb1 in it1.clone() { for comb2 in it2.clone() { - let dist1_aby3 = aby3_store_protocol - .eval_distance(&aby3_inserts[comb1[0]], &aby3_inserts[comb1[1]]) - .await; - let dist2_aby3 = aby3_store_protocol - .eval_distance(&aby3_inserts[comb2[0]], &aby3_inserts[comb2[1]]) - .await; + // compute distances in plaintext let dist1_plain = plaintext_store .eval_distance(&plaintext_inserts[comb1[0]], &plaintext_inserts[comb1[1]]) .await; let dist2_plain = plaintext_store .eval_distance(&plaintext_inserts[comb2[0]], &plaintext_inserts[comb2[1]]) .await; - assert_eq!( - aby3_store_protocol - .less_than(&dist1_aby3, &dist2_aby3) - .await, - plaintext_store.less_than(&dist1_plain, &dist2_plain).await, - "Failed at combo: {:?}, {:?}", - comb1, - comb2 - ) + let bit = plaintext_store.less_than(&dist1_plain, &dist2_plain).await; + plain_results.insert((comb1.clone(), comb2.clone()), bit); + } + } + + let mut aby3_inserts = vec![]; + for store in local_stores.iter_mut() { + let player_index = store.get_owner_index(); + let player_preps: Vec<_> = (0..db_dim) + .map(|id| store.prepare_query(shared_irises[id][player_index].clone())) + .collect(); + let mut player_inserts = vec![]; + for p in player_preps.iter() { + player_inserts.push(store.insert(p).await); + } + aby3_inserts.push(player_inserts); + } + + for comb1 in it1 { + for comb2 in it2.clone() { + let mut jobs = JoinSet::new(); + for store in local_stores.iter() { + let player_index = store.get_owner_index(); + let player_inserts = aby3_inserts[player_index].clone(); + let mut store = store.clone(); + let index10 = comb1[0]; + let index11 = comb1[1]; + let index20 = comb2[0]; + let index21 = comb2[1]; + jobs.spawn(async move { + let dist1_aby3 = store + .eval_distance_vectors( + &player_inserts[index10], + &player_inserts[index11], + ) + .await; + let dist2_aby3 = store + .eval_distance_vectors( + &player_inserts[index20], + &player_inserts[index21], + ) + .await; + store.less_than(&dist1_aby3, &dist2_aby3).await + }); + } + let res = jobs.join_all().await; + for bit in res { + assert_eq!( + bit, + plain_results[&(comb1.clone(), comb2.clone())], + "Failed at combo: {:?}, {:?}", + comb1, + comb2 + ) + } } } } @@ -636,19 +834,30 @@ mod tests { let mut rng = AesRng::seed_from_u64(0_u64); let database_size = 2; let searcher = HawkSearcher::default(); - let (mut vector, mut graph) = ng_create_from_scratch_hawk_searcher(&mut rng, database_size) - .await - .unwrap(); + let mut vectors_and_graphs = LocalNetAby3NgStoreProtocol::shared_random_setup( + &mut rng, + database_size, + NetworkType::LocalChannel, + ) + .await + .unwrap(); for i in 0..database_size { - let secret_neighbors = searcher - .search_to_insert(&mut vector, &mut graph, &i.into()) - .await; - assert!( - searcher.is_match(&mut vector, &secret_neighbors).await, - "Failed at index {:?}", - i - ); + let mut jobs = JoinSet::new(); + for (store, graph) in vectors_and_graphs.iter_mut() { + let mut store = store.clone(); + let mut graph = graph.clone(); + let searcher = searcher.clone(); + let q = store.prepare_query(store.storage.get_vector(&i.into()).clone()); + jobs.spawn(async move { + let secret_neighbors = searcher.search(&mut store, &mut graph, &q, 1).await; + searcher.is_match(&mut store, &[secret_neighbors]).await + }); + } + let res = jobs.join_all().await; + for (party_index, r) in res.iter().enumerate() { + assert!(r, "Failed at index {:?} by party {:?}", i, party_index); + } } } } diff --git a/iris-mpc-cpu/src/hawkers/plaintext_store.rs b/iris-mpc-cpu/src/hawkers/plaintext_store.rs index 7111425b5..80a462a82 100644 --- a/iris-mpc-cpu/src/hawkers/plaintext_store.rs +++ b/iris-mpc-cpu/src/hawkers/plaintext_store.rs @@ -1,13 +1,14 @@ -use hawk_pack::VectorStore; -use iris_mpc_common::iris_db::iris::{IrisCode, MATCH_THRESHOLD_RATIO}; +use aes_prng::AesRng; +use hawk_pack::{graph_store::GraphMem, HawkSearcher, VectorStore}; +use iris_mpc_common::iris_db::{ + db::IrisDB, + iris::{IrisCode, MATCH_THRESHOLD_RATIO}, +}; +use rand::{CryptoRng, RngCore, SeedableRng}; +use serde::{Deserialize, Serialize}; use std::ops::{Index, IndexMut}; -#[derive(Default, Debug, Clone)] -pub struct PlaintextStore { - pub points: Vec, -} - -#[derive(Default, Debug, Clone)] +#[derive(Default, Debug, Clone, Serialize, Deserialize)] pub struct PlaintextIris(pub IrisCode); impl PlaintextIris { @@ -42,17 +43,19 @@ impl PlaintextIris { } } -#[derive(Clone, Default, Debug)] +// TODO refactor away is_persistent flag; should probably be stored in a +// separate buffer instead whenever working with non-persistent iris codes +#[derive(Clone, Default, Debug, Serialize, Deserialize)] pub struct PlaintextPoint { /// Whatever encoding of a vector. - data: PlaintextIris, + pub data: PlaintextIris, /// Distinguish between queries that are pending, and those that were /// ultimately accepted into the vector store. - is_persistent: bool, + pub is_persistent: bool, } -#[derive(Copy, Debug, Clone, PartialEq, Eq, Hash, serde::Deserialize, serde::Serialize)] -pub struct PointId(u32); +#[derive(Copy, Default, Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub struct PointId(pub u32); impl Index for Vec { type Output = T; @@ -80,6 +83,11 @@ impl From for PointId { } } +#[derive(Default, Debug, Clone, Serialize, Deserialize)] +pub struct PlaintextStore { + pub points: Vec, +} + impl PlaintextStore { pub fn prepare_query(&mut self, raw_query: IrisCode) -> ::QueryRef { self.points.push(PlaintextPoint { @@ -129,12 +137,80 @@ impl VectorStore for PlaintextStore { } } +impl PlaintextStore { + pub async fn create_random( + rng: &mut R, + database_size: usize, + ) -> eyre::Result<(Self, GraphMem)> { + // makes sure the searcher produces same graph structure by having the same rng + let mut rng_searcher1 = AesRng::from_rng(rng.clone())?; + let cleartext_database = IrisDB::new_random_rng(database_size, rng).db; + + let mut plaintext_vector_store = PlaintextStore::default(); + let mut plaintext_graph_store = GraphMem::new(); + let searcher = HawkSearcher::default(); + + for raw_query in cleartext_database.iter() { + let query = plaintext_vector_store.prepare_query(raw_query.clone()); + searcher + .insert( + &mut plaintext_vector_store, + &mut plaintext_graph_store, + &query, + &mut rng_searcher1, + ) + .await; + } + + Ok((plaintext_vector_store, plaintext_graph_store)) + } + + pub async fn create_random_store( + rng: &mut R, + database_size: usize, + ) -> eyre::Result { + let cleartext_database = IrisDB::new_random_rng(database_size, rng).db; + + let mut plaintext_vector_store = PlaintextStore::default(); + + for raw_query in cleartext_database.iter() { + let query = plaintext_vector_store.prepare_query(raw_query.clone()); + let _ = plaintext_vector_store.insert(&query).await; + } + + Ok(plaintext_vector_store) + } + + pub async fn create_graph( + &mut self, + rng: &mut R, + graph_size: usize, + ) -> eyre::Result> { + let mut rng_searcher1 = AesRng::from_rng(rng.clone())?; + + let mut plaintext_graph_store = GraphMem::new(); + let searcher = HawkSearcher::default(); + + for i in 0..graph_size { + searcher + .insert( + self, + &mut plaintext_graph_store, + &i.into(), + &mut rng_searcher1, + ) + .await; + } + + Ok(plaintext_graph_store) + } +} + #[cfg(test)] mod tests { use super::*; - use crate::hawkers::galois_store::gr_create_ready_made_hawk_searcher; use aes_prng::AesRng; - use hawk_pack::hnsw_db::HawkSearcher; + use hawk_pack::HawkSearcher; use iris_mpc_common::iris_db::db::IrisDB; use rand::SeedableRng; use tracing_test::traced_test; @@ -217,17 +293,17 @@ mod tests { let mut rng = AesRng::seed_from_u64(0_u64); let database_size = 1; let searcher = HawkSearcher::default(); - let ((mut ptxt_vector, mut ptxt_graph), _) = - gr_create_ready_made_hawk_searcher(&mut rng, database_size) + let (mut ptxt_vector, mut ptxt_graph) = + PlaintextStore::create_random(&mut rng, database_size) .await .unwrap(); for i in 0..database_size { let cleartext_neighbors = searcher - .search_to_insert(&mut ptxt_vector, &mut ptxt_graph, &i.into()) + .search(&mut ptxt_vector, &mut ptxt_graph, &i.into(), 1) .await; assert!( searcher - .is_match(&mut ptxt_vector, &cleartext_neighbors) + .is_match(&mut ptxt_vector, &[cleartext_neighbors]) .await, ); } diff --git a/iris-mpc-cpu/src/lib.rs b/iris-mpc-cpu/src/lib.rs index fb378ddd0..bf4a96011 100644 --- a/iris-mpc-cpu/src/lib.rs +++ b/iris-mpc-cpu/src/lib.rs @@ -2,5 +2,8 @@ pub mod database_generators; pub mod execution; pub mod hawkers; pub(crate) mod network; +#[rustfmt::skip] +pub(crate) mod proto_generated; pub mod protocol; +pub mod py_bindings; pub(crate) mod shares; diff --git a/iris-mpc-cpu/src/network/grpc.rs b/iris-mpc-cpu/src/network/grpc.rs new file mode 100644 index 000000000..bd3185532 --- /dev/null +++ b/iris-mpc-cpu/src/network/grpc.rs @@ -0,0 +1,600 @@ +use super::Networking; +use crate::{ + execution::{local::get_free_local_addresses, player::Identity}, + network::SessionId, + proto_generated::party_node::{ + party_node_client::PartyNodeClient, + party_node_server::{PartyNode, PartyNodeServer}, + SendRequest, SendResponse, + }, +}; +use backoff::{future::retry, ExponentialBackoff}; +use dashmap::DashMap; +use eyre::eyre; +use std::{str::FromStr, sync::Arc, time::Duration}; +use tokio::{ + sync::{ + mpsc::{self, UnboundedSender}, + Mutex, + }, + time::timeout, +}; +use tokio_stream::StreamExt; +use tonic::{ + async_trait, + metadata::AsciiMetadataValue, + transport::{Channel, Server}, + Request, Response, Status, Streaming, +}; + +type TonicResult = Result; + +fn err_to_status(e: eyre::Error) -> Status { + Status::internal(e.to_string()) +} + +struct MessageQueueStore { + queues: DashMap>>, +} + +impl MessageQueueStore { + fn new() -> Self { + MessageQueueStore { + queues: DashMap::new(), + } + } + + fn insert(&self, sender_id: Identity, stream: Streaming) -> eyre::Result<()> { + if self.queues.contains_key(&sender_id) { + return Err(eyre!("Player {:?} already has a message queue", sender_id)); + } + self.queues.insert(sender_id, Mutex::new(stream)); + Ok(()) + } + + async fn pop(&self, sender_id: &Identity) -> eyre::Result> { + let queue = self.queues.get(sender_id).ok_or(eyre!(format!( + "RECEIVE: Sender {sender_id:?} hasn't been found in the message queues" + )))?; + + let mut queue = queue.lock().await; + + let msg = queue.next().await.ok_or(eyre!("No message received"))??; + + Ok(msg.data) + } +} + +struct OutgoingStreams { + streams: DashMap<(SessionId, Identity), Arc>>, +} + +impl OutgoingStreams { + fn new() -> Self { + OutgoingStreams { + streams: DashMap::new(), + } + } + + fn add_session_stream( + &self, + session_id: SessionId, + receiver_id: Identity, + stream: UnboundedSender, + ) { + self.streams + .insert((session_id, receiver_id), Arc::new(stream)); + } + + fn get_stream( + &self, + session_id: SessionId, + receiver_id: Identity, + ) -> eyre::Result>> { + self.streams + .get(&(session_id, receiver_id.clone())) + .ok_or(eyre!( + "Streams for session {session_id:?} and receiver {receiver_id:?} not found" + )) + .map(|s| s.value().clone()) + } +} + +#[derive(Default, Clone)] +pub struct GrpcConfig { + pub timeout_duration: Duration, +} + +// WARNING: this implementation assumes that messages for a specific player +// within one session are sent in order and consecutively. Don't send messages +// to the same player in parallel within the same session. Use batching instead. +#[derive(Clone)] +pub struct GrpcNetworking { + party_id: Identity, + // other party id -> client to call that party + clients: Arc>>, + // other party id -> outgoing streams to send messages to that party in different sessions + outgoing_streams: Arc, + // session id -> incoming message streams + message_queues: Arc>, + + pub config: GrpcConfig, +} + +impl GrpcNetworking { + pub fn new(party_id: Identity, config: GrpcConfig) -> Self { + GrpcNetworking { + party_id, + clients: Arc::new(DashMap::new()), + outgoing_streams: Arc::new(OutgoingStreams::new()), + message_queues: Arc::new(DashMap::new()), + config, + } + } + + pub async fn connect_to_party(&self, party_id: Identity, address: &str) -> eyre::Result<()> { + let client = PartyNodeClient::connect(address.to_string()).await?; + self.clients.insert(party_id.clone(), client); + Ok(()) + } + + pub async fn create_session(&self, session_id: SessionId) -> eyre::Result<()> { + if self.message_queues.contains_key(&session_id) { + return Err(eyre!( + "Player {:?} has already created session {session_id:?}", + self.party_id + )); + } + + for mut client in self.clients.iter_mut() { + let (tx, rx) = mpsc::unbounded_channel(); + self.outgoing_streams + .add_session_stream(session_id, client.key().clone(), tx); + let receiving_stream = tokio_stream::wrappers::UnboundedReceiverStream::new(rx); + let mut request = Request::new(receiving_stream); + request.metadata_mut().insert( + "sender_id", + AsciiMetadataValue::from_str(&self.party_id.0).unwrap(), + ); + request.metadata_mut().insert( + "session_id", + AsciiMetadataValue::from_str(&session_id.0.to_string()).unwrap(), + ); + let _response = client.value_mut().send_message(request).await?; + } + Ok(()) + } +} + +// Server implementation +#[async_trait] +impl PartyNode for GrpcNetworking { + async fn send_message( + &self, + request: Request>, + ) -> TonicResult> { + let sender_id: Identity = request + .metadata() + .get("sender_id") + .ok_or(Status::unauthenticated("Sender ID not found"))? + .to_str() + .map_err(|_| Status::unauthenticated("Sender ID is not a string"))? + .to_string() + .into(); + if sender_id == self.party_id { + return Err(Status::unauthenticated(format!( + "Sender ID coincides with receiver ID: {:?}", + sender_id + ))); + } + let session_id: u64 = request + .metadata() + .get("session_id") + .ok_or(Status::not_found("Session ID not found"))? + .to_str() + .map_err(|_| Status::not_found("Session ID malformed"))? + .parse() + .map_err(|_| Status::invalid_argument("Session ID is not a u64 number"))?; + let session_id = SessionId::from(session_id); + + let incoming_stream = request.into_inner(); + + tracing::trace!( + "Player {:?}. Creating session {:?} for player {:?}", + self.party_id, + session_id, + sender_id + ); + let message_queue = self + .message_queues + .entry(session_id) + .or_insert(MessageQueueStore::new()); + + message_queue + .insert(sender_id, incoming_stream) + .map_err(err_to_status)?; + + Ok(Response::new(SendResponse {})) + } +} + +// Client implementation +#[async_trait] +impl Networking for GrpcNetworking { + async fn send( + &self, + value: Vec, + receiver: &Identity, + session_id: &SessionId, + ) -> eyre::Result<()> { + let backoff = ExponentialBackoff { + max_elapsed_time: Some(std::time::Duration::from_secs(2)), + max_interval: std::time::Duration::from_secs(1), + multiplier: 1.1, + ..Default::default() + }; + let outgoing_stream = self + .outgoing_streams + .get_stream(*session_id, receiver.clone())?; + + // Send message via the outgoing stream + let request = SendRequest { data: value }; + retry(backoff, || async { + tracing::trace!( + "INIT: Sending message {:?} from {:?} to {:?} in session {:?}", + request.data, + self.party_id, + receiver, + session_id + ); + outgoing_stream + .send(request.clone()) + .map_err(|e| eyre!(e.to_string()))?; + tracing::trace!( + "SUCCESS: Sending message {:?} from {:?} to {:?} in session {:?}", + request.data, + self.party_id, + receiver, + session_id + ); + Ok(()) + }) + .await + } + + async fn receive(&self, sender: &Identity, session_id: &SessionId) -> eyre::Result> { + // Just retrieve the first message from the corresponding queue + let queue = self.message_queues.get(session_id).ok_or(eyre!(format!( + "Session {session_id:?} hasn't been added to message queues" + )))?; + + tracing::trace!( + "Player {:?} is receiving message from {:?} in session {:?}", + self.party_id, + sender, + session_id + ); + + match timeout(self.config.timeout_duration, queue.pop(sender)).await { + Ok(res) => res, + Err(_) => Err(eyre!( + "Timeout while waiting for message from {sender:?} in session {session_id:?}" + )), + } + } +} + +pub async fn setup_local_grpc_networking( + parties: Vec, +) -> eyre::Result> { + let config = GrpcConfig { + timeout_duration: Duration::from_secs(1), + }; + + let players = parties + .iter() + .map(|party| GrpcNetworking::new(party.clone(), config.clone())) + .collect::>(); + + let addresses = get_free_local_addresses(players.len()).await?; + + let players_addresses = players + .iter() + .cloned() + .zip(addresses.iter().cloned()) + .collect::>(); + + // Initialize servers + for (player, addr) in &players_addresses { + let player = player.clone(); + let socket = addr.parse().unwrap(); + tokio::spawn(async move { + Server::builder() + .add_service(PartyNodeServer::new(player)) + .serve(socket) + .await + .unwrap(); + }); + } + + tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; + + // Connect to each other + for (player, addr) in &players_addresses { + for (other_player, other_addr) in &players_addresses.clone() { + if addr != other_addr { + let other_addr = format!("http://{}", other_addr); + player + .connect_to_party(other_player.party_id.clone(), &other_addr) + .await + .unwrap(); + } + } + } + + Ok(players) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + execution::{local::generate_local_identities, player::Role}, + hawkers::galois_store::LocalNetAby3NgStoreProtocol, + }; + use aes_prng::AesRng; + use hawk_pack::HawkSearcher; + use rand::SeedableRng; + use tokio::task::JoinSet; + use tracing_test::traced_test; + + async fn create_session_helper( + session_id: SessionId, + players: &[GrpcNetworking], + ) -> eyre::Result<()> { + let mut jobs = JoinSet::new(); + for player in players.iter() { + let player = player.clone(); + jobs.spawn(async move { + player.create_session(session_id).await.unwrap(); + }); + } + jobs.join_all().await; + Ok(()) + } + + #[tokio::test(flavor = "multi_thread")] + #[traced_test] + async fn test_grpc_comms_correct() -> eyre::Result<()> { + let identities = generate_local_identities(); + let players = setup_local_grpc_networking(identities.clone()).await?; + + let mut jobs = JoinSet::new(); + + // Simple session with one message sent from one party to another + { + let players = players.clone(); + + let session_id = SessionId::from(0); + + jobs.spawn(async move { + create_session_helper(session_id, &players).await.unwrap(); + + let alice = players[0].clone(); + let bob = players[1].clone(); + + // Send a message from the first party to the second party + let message = b"Hey, Bob. I'm Alice. Do you copy?".to_vec(); + let message_copy = message.clone(); + + let task1 = tokio::spawn(async move { + alice + .send(message.clone(), &"bob".into(), &session_id) + .await + .unwrap(); + }); + let task2 = tokio::spawn(async move { + let received_message = bob.receive(&"alice".into(), &session_id).await.unwrap(); + assert_eq!(message_copy, received_message); + }); + let _ = tokio::try_join!(task1, task2).unwrap(); + }); + } + + // Each party sending and receiving messages to each other + { + jobs.spawn(async move { + let session_id = SessionId::from(1); + + create_session_helper(session_id, &players).await.unwrap(); + + let mut tasks = JoinSet::new(); + // Send messages + for (player_id, player) in players.iter().enumerate() { + let role = Role::new(player_id); + let next = role.next(3).index(); + let prev = role.prev(3).index(); + + let player = player.clone(); + let next_id = identities[next].clone(); + let prev_id = identities[prev].clone(); + + tasks.spawn(async move { + // Sending + let msg_to_next = + format!("From player {} to player {} with love", player_id, next) + .into_bytes(); + let msg_to_prev = + format!("From player {} to player {} with love", player_id, prev) + .into_bytes(); + player + .send(msg_to_next.clone(), &next_id, &session_id) + .await + .unwrap(); + player + .send(msg_to_prev.clone(), &prev_id, &session_id) + .await + .unwrap(); + + // Receiving + let received_msg_from_prev = + player.receive(&prev_id, &session_id).await.unwrap(); + let expected_msg_from_prev = + format!("From player {} to player {} with love", prev, player_id) + .into_bytes(); + assert_eq!(received_msg_from_prev, expected_msg_from_prev); + let received_msg_from_next = + player.receive(&next_id, &session_id).await.unwrap(); + let expected_msg_from_next = + format!("From player {} to player {} with love", next, player_id) + .into_bytes(); + assert_eq!(received_msg_from_next, expected_msg_from_next); + }); + } + tasks.join_all().await; + }); + } + + jobs.join_all().await; + + Ok(()) + } + + #[tokio::test(flavor = "multi_thread")] + #[traced_test] + async fn test_grpc_comms_fail() -> eyre::Result<()> { + let parties = generate_local_identities(); + + let players = setup_local_grpc_networking(parties.clone()).await?; + + let mut jobs = JoinSet::new(); + + // Send to a non-existing party + { + let players = players.clone(); + jobs.spawn(async move { + let session_id = SessionId::from(0); + create_session_helper(session_id, &players).await.unwrap(); + + let alice = players[0].clone(); + let message = b"Hey, Eve. I'm Alice. Do you copy?".to_vec(); + let res = alice + .send(message.clone(), &Identity::from("eve"), &session_id) + .await; + assert!(res.is_err()); + }); + } + + // Receive from a wrong party + { + let players = players.clone(); + jobs.spawn(async move { + let session_id = SessionId::from(1); + create_session_helper(session_id, &players).await.unwrap(); + + let alice = players[0].clone(); + + let res = alice.receive(&Identity::from("eve"), &session_id).await; + assert!(res.is_err()); + }); + } + + // Send to itself + { + let players = players.clone(); + jobs.spawn(async move { + let session_id = SessionId::from(2); + create_session_helper(session_id, &players).await.unwrap(); + + let alice = players[0].clone(); + + let message = b"Hey, Alice. I'm Alice. Do you copy?".to_vec(); + let res = alice + .send(message.clone(), &Identity::from("alice"), &session_id) + .await; + assert!(res.is_err()); + }); + } + + // Add the same session + { + let players = players.clone(); + jobs.spawn(async move { + let session_id = SessionId::from(3); + create_session_helper(session_id, &players).await.unwrap(); + + let alice = players[0].clone(); + + let res = alice.create_session(session_id).await; + + assert!(res.is_err()); + }); + } + + // Send and retrieve from a non-existing session + { + let alice = players[0].clone(); + jobs.spawn(async move { + let session_id = SessionId::from(50); + + let message = b"Hey, Bob. I'm Alice. Do you copy?".to_vec(); + let res = alice + .send(message.clone(), &Identity::from("bob"), &session_id) + .await; + assert!(res.is_err()); + let res = alice.receive(&Identity::from("bob"), &session_id).await; + assert!(res.is_err()); + }); + } + + // Receive from a party that didn't send a message + { + let alice = players[0].clone(); + let players = players.clone(); + jobs.spawn(async move { + let session_id = SessionId::from(4); + create_session_helper(session_id, &players).await.unwrap(); + + let res = alice.receive(&Identity::from("bob"), &session_id).await; + assert!(res.is_err()); + }); + } + + jobs.join_all().await; + + Ok(()) + } + + #[tokio::test] + #[traced_test] + async fn test_hnsw_local() { + let mut rng = AesRng::seed_from_u64(0_u64); + let database_size = 2; + let searcher = HawkSearcher::default(); + let mut vectors_and_graphs = LocalNetAby3NgStoreProtocol::shared_random_setup( + &mut rng, + database_size, + crate::network::NetworkType::GrpcChannel, + ) + .await + .unwrap(); + + for i in 0..database_size { + let mut jobs = JoinSet::new(); + for (store, graph) in vectors_and_graphs.iter_mut() { + let mut store = store.clone(); + let mut graph = graph.clone(); + let searcher = searcher.clone(); + let q = store.prepare_query(store.storage.get_vector(&i.into()).clone()); + jobs.spawn(async move { + let secret_neighbors = searcher.search(&mut store, &mut graph, &q, 1).await; + searcher.is_match(&mut store, &[secret_neighbors]).await + }); + } + let res = jobs.join_all().await; + for (party_index, r) in res.iter().enumerate() { + assert!(r, "Failed at index {:?} by party {:?}", i, party_index); + } + } + } +} diff --git a/iris-mpc-cpu/src/network/local.rs b/iris-mpc-cpu/src/network/local.rs index 22ad2fcf5..91269795f 100644 --- a/iris-mpc-cpu/src/network/local.rs +++ b/iris-mpc-cpu/src/network/local.rs @@ -51,6 +51,7 @@ impl LocalNetworkingStore { } } +#[derive(Debug)] pub struct LocalNetworking { p2p_channels: P2PChannels, pub owner: Identity, @@ -114,7 +115,7 @@ mod tests { let bob = networking_store.get_local_network("bob".into()); let task1 = tokio::spawn(async move { - let recv = bob.receive(&"alice".into(), &1_u128.into()).await; + let recv = bob.receive(&"alice".into(), &1_u64.into()).await; assert_eq!( NetworkValue::from_network(recv).unwrap(), NetworkValue::Ring16(Wrapping::(777)) @@ -123,7 +124,7 @@ mod tests { let task2 = tokio::spawn(async move { let value = NetworkValue::Ring16(Wrapping::(777)); alice - .send(value.to_network(), &"bob".into(), &1_u128.into()) + .send(value.to_network(), &"bob".into(), &1_u64.into()) .await }); diff --git a/iris-mpc-cpu/src/network/mod.rs b/iris-mpc-cpu/src/network/mod.rs index d0261b0ec..edd362833 100644 --- a/iris-mpc-cpu/src/network/mod.rs +++ b/iris-mpc-cpu/src/network/mod.rs @@ -14,5 +14,12 @@ pub trait Networking { async fn receive(&self, sender: &Identity, session_id: &SessionId) -> eyre::Result>; } +#[derive(Clone)] +pub enum NetworkType { + LocalChannel, + GrpcChannel, +} + +pub mod grpc; pub mod local; pub mod value; diff --git a/iris-mpc-cpu/src/network/value.rs b/iris-mpc-cpu/src/network/value.rs index 93198a691..43be3b76b 100644 --- a/iris-mpc-cpu/src/network/value.rs +++ b/iris-mpc-cpu/src/network/value.rs @@ -23,7 +23,15 @@ impl NetworkValue { } pub fn from_network(serialized: eyre::Result>) -> eyre::Result { - bincode::deserialize::(&serialized?).map_err(|_e| eyre!("failed to parse value")) + bincode::deserialize::(&serialized?).map_err(|_e| eyre!("Failed to parse value")) + } + + pub fn vec_to_network(values: &Vec) -> Vec { + bincode::serialize(&values).unwrap() + } + + pub fn vec_from_network(serialized: eyre::Result>) -> eyre::Result> { + bincode::deserialize::>(&serialized?).map_err(|_e| eyre!("Failed to parse value")) } } @@ -39,8 +47,27 @@ impl TryFrom for Vec> { match value { NetworkValue::VecRing16(x) => Ok(x), _ => Err(eyre!( - "could not convert Network Value into Vec>" + "Could not convert Network Value into Vec>" )), } } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_from_vec() -> eyre::Result<()> { + let values = (0..2).map(RingElement).collect::>(); + let network_values = values + .iter() + .map(|v| NetworkValue::RingElement16(*v)) + .collect::>(); + let serialized = NetworkValue::vec_to_network(&network_values); + let result_vec = NetworkValue::vec_from_network(Ok(serialized))?; + assert_eq!(network_values, result_vec); + + Ok(()) + } +} diff --git a/iris-mpc-cpu/src/proto/party_node.proto b/iris-mpc-cpu/src/proto/party_node.proto new file mode 100644 index 000000000..515e95520 --- /dev/null +++ b/iris-mpc-cpu/src/proto/party_node.proto @@ -0,0 +1,13 @@ +syntax = "proto3"; + +package party_node; + +service PartyNode { + rpc SendMessage (stream SendRequest) returns (SendResponse); +} + +message SendRequest { + bytes data = 1; +} + +message SendResponse {} \ No newline at end of file diff --git a/iris-mpc-cpu/src/proto_generated/mod.rs b/iris-mpc-cpu/src/proto_generated/mod.rs new file mode 100644 index 000000000..8c4bf931d --- /dev/null +++ b/iris-mpc-cpu/src/proto_generated/mod.rs @@ -0,0 +1 @@ +pub mod party_node; diff --git a/iris-mpc-cpu/src/proto_generated/party_node.rs b/iris-mpc-cpu/src/proto_generated/party_node.rs new file mode 100644 index 000000000..c2f51267a --- /dev/null +++ b/iris-mpc-cpu/src/proto_generated/party_node.rs @@ -0,0 +1,299 @@ +// This file is @generated by prost-build. +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct SendRequest { + #[prost(bytes = "vec", tag = "1")] + pub data: ::prost::alloc::vec::Vec, +} +#[derive(Clone, Copy, PartialEq, ::prost::Message)] +pub struct SendResponse {} +/// Generated client implementations. +pub mod party_node_client { + #![allow( + unused_variables, + dead_code, + missing_docs, + clippy::wildcard_imports, + clippy::let_unit_value, + )] + use tonic::codegen::*; + use tonic::codegen::http::Uri; + #[derive(Debug, Clone)] + pub struct PartyNodeClient { + inner: tonic::client::Grpc, + } + impl PartyNodeClient { + /// Attempt to create a new client by connecting to a given endpoint. + pub async fn connect(dst: D) -> Result + where + D: TryInto, + D::Error: Into, + { + let conn = tonic::transport::Endpoint::new(dst)?.connect().await?; + Ok(Self::new(conn)) + } + } + impl PartyNodeClient + where + T: tonic::client::GrpcService, + T::Error: Into, + T::ResponseBody: Body + std::marker::Send + 'static, + ::Error: Into + std::marker::Send, + { + pub fn new(inner: T) -> Self { + let inner = tonic::client::Grpc::new(inner); + Self { inner } + } + pub fn with_origin(inner: T, origin: Uri) -> Self { + let inner = tonic::client::Grpc::with_origin(inner, origin); + Self { inner } + } + pub fn with_interceptor( + inner: T, + interceptor: F, + ) -> PartyNodeClient> + where + F: tonic::service::Interceptor, + T::ResponseBody: Default, + T: tonic::codegen::Service< + http::Request, + Response = http::Response< + >::ResponseBody, + >, + >, + , + >>::Error: Into + std::marker::Send + std::marker::Sync, + { + PartyNodeClient::new(InterceptedService::new(inner, interceptor)) + } + /// Compress requests with the given encoding. + /// + /// This requires the server to support it otherwise it might respond with an + /// error. + #[must_use] + pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self { + self.inner = self.inner.send_compressed(encoding); + self + } + /// Enable decompressing responses. + #[must_use] + pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self { + self.inner = self.inner.accept_compressed(encoding); + self + } + /// Limits the maximum size of a decoded message. + /// + /// Default: `4MB` + #[must_use] + pub fn max_decoding_message_size(mut self, limit: usize) -> Self { + self.inner = self.inner.max_decoding_message_size(limit); + self + } + /// Limits the maximum size of an encoded message. + /// + /// Default: `usize::MAX` + #[must_use] + pub fn max_encoding_message_size(mut self, limit: usize) -> Self { + self.inner = self.inner.max_encoding_message_size(limit); + self + } + pub async fn send_message( + &mut self, + request: impl tonic::IntoStreamingRequest, + ) -> std::result::Result, tonic::Status> { + self.inner + .ready() + .await + .map_err(|e| { + tonic::Status::unknown( + format!("Service was not ready: {}", e.into()), + ) + })?; + let codec = tonic::codec::ProstCodec::default(); + let path = http::uri::PathAndQuery::from_static( + "/party_node.PartyNode/SendMessage", + ); + let mut req = request.into_streaming_request(); + req.extensions_mut() + .insert(GrpcMethod::new("party_node.PartyNode", "SendMessage")); + self.inner.client_streaming(req, path, codec).await + } + } +} +/// Generated server implementations. +pub mod party_node_server { + #![allow( + unused_variables, + dead_code, + missing_docs, + clippy::wildcard_imports, + clippy::let_unit_value, + )] + use tonic::codegen::*; + /// Generated trait containing gRPC methods that should be implemented for use with PartyNodeServer. + #[async_trait] + pub trait PartyNode: std::marker::Send + std::marker::Sync + 'static { + async fn send_message( + &self, + request: tonic::Request>, + ) -> std::result::Result, tonic::Status>; + } + #[derive(Debug)] + pub struct PartyNodeServer { + inner: Arc, + accept_compression_encodings: EnabledCompressionEncodings, + send_compression_encodings: EnabledCompressionEncodings, + max_decoding_message_size: Option, + max_encoding_message_size: Option, + } + impl PartyNodeServer { + pub fn new(inner: T) -> Self { + Self::from_arc(Arc::new(inner)) + } + pub fn from_arc(inner: Arc) -> Self { + Self { + inner, + accept_compression_encodings: Default::default(), + send_compression_encodings: Default::default(), + max_decoding_message_size: None, + max_encoding_message_size: None, + } + } + pub fn with_interceptor( + inner: T, + interceptor: F, + ) -> InterceptedService + where + F: tonic::service::Interceptor, + { + InterceptedService::new(Self::new(inner), interceptor) + } + /// Enable decompressing requests with the given encoding. + #[must_use] + pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self { + self.accept_compression_encodings.enable(encoding); + self + } + /// Compress responses with the given encoding, if the client supports it. + #[must_use] + pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self { + self.send_compression_encodings.enable(encoding); + self + } + /// Limits the maximum size of a decoded message. + /// + /// Default: `4MB` + #[must_use] + pub fn max_decoding_message_size(mut self, limit: usize) -> Self { + self.max_decoding_message_size = Some(limit); + self + } + /// Limits the maximum size of an encoded message. + /// + /// Default: `usize::MAX` + #[must_use] + pub fn max_encoding_message_size(mut self, limit: usize) -> Self { + self.max_encoding_message_size = Some(limit); + self + } + } + impl tonic::codegen::Service> for PartyNodeServer + where + T: PartyNode, + B: Body + std::marker::Send + 'static, + B::Error: Into + std::marker::Send + 'static, + { + type Response = http::Response; + type Error = std::convert::Infallible; + type Future = BoxFuture; + fn poll_ready( + &mut self, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + fn call(&mut self, req: http::Request) -> Self::Future { + match req.uri().path() { + "/party_node.PartyNode/SendMessage" => { + #[allow(non_camel_case_types)] + struct SendMessageSvc(pub Arc); + impl< + T: PartyNode, + > tonic::server::ClientStreamingService + for SendMessageSvc { + type Response = super::SendResponse; + type Future = BoxFuture< + tonic::Response, + tonic::Status, + >; + fn call( + &mut self, + request: tonic::Request>, + ) -> Self::Future { + let inner = Arc::clone(&self.0); + let fut = async move { + ::send_message(&inner, request).await + }; + Box::pin(fut) + } + } + let accept_compression_encodings = self.accept_compression_encodings; + let send_compression_encodings = self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; + let inner = self.inner.clone(); + let fut = async move { + let method = SendMessageSvc(inner); + let codec = tonic::codec::ProstCodec::default(); + let mut grpc = tonic::server::Grpc::new(codec) + .apply_compression_config( + accept_compression_encodings, + send_compression_encodings, + ) + .apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, + ); + let res = grpc.client_streaming(method, req).await; + Ok(res) + }; + Box::pin(fut) + } + _ => { + Box::pin(async move { + let mut response = http::Response::new(empty_body()); + let headers = response.headers_mut(); + headers + .insert( + tonic::Status::GRPC_STATUS, + (tonic::Code::Unimplemented as i32).into(), + ); + headers + .insert( + http::header::CONTENT_TYPE, + tonic::metadata::GRPC_CONTENT_TYPE, + ); + Ok(response) + }) + } + } + } + } + impl Clone for PartyNodeServer { + fn clone(&self) -> Self { + let inner = self.inner.clone(); + Self { + inner, + accept_compression_encodings: self.accept_compression_encodings, + send_compression_encodings: self.send_compression_encodings, + max_decoding_message_size: self.max_decoding_message_size, + max_encoding_message_size: self.max_encoding_message_size, + } + } + } + /// Generated gRPC service name + pub const SERVICE_NAME: &str = "party_node.PartyNode"; + impl tonic::server::NamedService for PartyNodeServer { + const NAME: &'static str = SERVICE_NAME; + } +} diff --git a/iris-mpc-cpu/src/protocol/binary.rs b/iris-mpc-cpu/src/protocol/binary.rs index 5059a828f..49ddccd88 100644 --- a/iris-mpc-cpu/src/protocol/binary.rs +++ b/iris-mpc-cpu/src/protocol/binary.rs @@ -10,6 +10,7 @@ use crate::{ }, }; use eyre::{eyre, Error}; +use itertools::Itertools; use num_traits::{One, Zero}; use rand::{distributions::Standard, prelude::Distribution, Rng}; use std::ops::SubAssign; @@ -29,7 +30,7 @@ pub(crate) fn a2b_pre( let mut x2 = Share::zero(); let mut x3 = Share::zero(); - match session.own_role()?.zero_based() { + match session.own_role()?.index() { 0 => { x1.a = a; x3.b = b; @@ -142,16 +143,26 @@ pub(crate) async fn transposed_pack_and( session: &mut Session, x1: Vec>, x2: Vec>, -) -> Result>, Error> -where - Standard: Distribution, -{ - // TODO(Dragos) this could probably be parallelized even more. - let mut res = Vec::with_capacity(x1.len()); - for (x1, x2) in x1.iter().zip(x2.iter()) { - let shares_a = and_many_send(session, x1.as_slice(), x2.as_slice()).await?; - let shares_b = and_many_receive(session).await?; - res.push(VecShare::from_ab(shares_a, shares_b)) +) -> Result>, Error> { + if x1.len() != x2.len() { + return Err(eyre!("Inputs have different length")); + } + let chunk_sizes = x1.iter().map(VecShare::len).collect::>(); + let chunk_sizes2 = x2.iter().map(VecShare::len).collect::>(); + if chunk_sizes != chunk_sizes2 { + return Err(eyre!("VecShare lengths are not equal")); + } + + let x1 = VecShare::flatten(x1); + let x2 = VecShare::flatten(x2); + let mut shares_a = and_many_send(session, x1.as_slice(), x2.as_slice()).await?; + let mut shares_b = and_many_receive(session).await?; + + let mut res = Vec::with_capacity(chunk_sizes.len()); + for l in chunk_sizes { + let a = shares_a.drain(..l).collect(); + let b = shares_b.drain(..l).collect(); + res.push(VecShare::from_ab(a, b)); } Ok(res) } @@ -282,15 +293,21 @@ async fn bit_inject_ot_2round_receiver( let sid = session.session_id(); let (m0, m1, wc) = tokio::spawn(async move { - let reply_m0 = network.receive(&next_id, &sid).await; - let m0 = match NetworkValue::from_network(reply_m0) { - Ok(NetworkValue::VecRing16(val)) => Ok(val), + let reply_m0_and_m1 = network.receive(&next_id, &sid).await; + let m0_and_m1 = NetworkValue::vec_from_network(reply_m0_and_m1).unwrap(); + assert!( + m0_and_m1.len() == 2, + "Deserialized vec in bit inject is wrong length" + ); + let (m0, m1) = m0_and_m1.into_iter().collect_tuple().unwrap(); + + let m0 = match m0 { + NetworkValue::VecRing16(val) => Ok(val), _ => Err(eyre!("Could not deserialize properly in bit inject")), }; - let reply_m1 = network.receive(&next_id, &sid).await; - let m1 = match NetworkValue::from_network(reply_m1) { - Ok(NetworkValue::VecRing16(val)) => Ok(val), + let m1 = match m1 { + NetworkValue::VecRing16(val) => Ok(val), _ => Err(eyre!("Could not deserialize properly in bit inject")), }; @@ -365,26 +382,27 @@ async fn bit_inject_ot_2round_sender( let prev_id = session.prev_identity()?; let sid = session.session_id(); // TODO(Dragos) Note this can be compressed in a single round. + let m0_and_m1: Vec = [m0, m1] + .into_iter() + .map(NetworkValue::VecRing16) + .collect::>(); // Reshare to Helper tokio::spawn(async move { let _ = network - .send(NetworkValue::VecRing16(m0).to_network(), &prev_id, &sid) - .await; - let _ = network - .send(NetworkValue::VecRing16(m1).to_network(), &prev_id, &sid) + .send(NetworkValue::vec_to_network(&m0_and_m1), &prev_id, &sid) .await; }) .await?; Ok(shares) } -// TODO this is inbalanced, so a real implementation should actually rotate +// TODO this is unbalanced, so a real implementation should actually rotate // parties around pub(crate) async fn bit_inject_ot_2round( session: &mut Session, input: VecShare, ) -> Result, Error> { - let res = match session.own_role()?.zero_based() { + let res = match session.own_role()?.index() { 0 => { // OT Helper bit_inject_ot_2round_helper(session, input).await? diff --git a/iris-mpc-cpu/src/protocol/ops.rs b/iris-mpc-cpu/src/protocol/ops.rs index fefb66ad1..510bf7fad 100644 --- a/iris-mpc-cpu/src/protocol/ops.rs +++ b/iris-mpc-cpu/src/protocol/ops.rs @@ -1,13 +1,18 @@ -use super::binary::single_extract_msb_u32; +use super::binary::{mul_lift_2k, single_extract_msb_u32}; use crate::{ database_generators::GaloisRingSharedIris, execution::session::{BootSession, Session, SessionHandles}, network::value::NetworkValue::{self}, protocol::{ - binary::{lift, mul_lift_2k, open_bin}, + binary::{lift, open_bin}, prf::{Prf, PrfSeed}, }, - shares::{bit::Bit, ring_impl::RingElement, share::Share, vecshare::VecShare}, + shares::{ + bit::Bit, + ring_impl::RingElement, + share::{DistanceShare, Share}, + vecshare::VecShare, + }, }; use eyre::eyre; @@ -15,7 +20,6 @@ pub(crate) const MATCH_THRESHOLD_RATIO: f64 = iris_mpc_common::iris_db::iris::MA pub(crate) const B_BITS: u64 = 16; pub(crate) const B: u64 = 1 << B_BITS; pub(crate) const A: u64 = ((1. - 2. * MATCH_THRESHOLD_RATIO) * B as f64) as u64; -pub(crate) const A_BITS: u32 = u64::BITS - A.leading_zeros(); /// Setup the PRF seeds in the replicated protocol. /// Each party sends to the next party a random seed. @@ -46,30 +50,47 @@ pub async fn setup_replicated_prf(session: &BootSession, my_seed: PrfSeed) -> ey Ok(Prf::new(my_seed, other_seed)) } -/// Takes as input two code and mask dot products between two Irises: i, j. -/// i.e. code_dot = and mask_dot = -/// Then lifts the two dot products to the larger ring (Z_{2^32}), multiplies -/// with some predefined constants B = 2^16 -/// A = ((1. - 2. * MATCH_THRESHOLD_RATIO) * B as f64) -/// and then compares mask_dot * A < code_dot * B. +/// Compares the distance between two iris pairs to a threshold. +/// +/// - Takes as input two code and mask dot products between two Irises: i, j. +/// i.e. code_dot = and mask_dot = . +/// - Lifts the two dot products to the ring Z_{2^32}. +/// - Multiplies with predefined threshold constants B = 2^16 and A = ((1. - 2. +/// * MATCH_THRESHOLD_RATIO) * B as f64). +/// - Compares mask_dot * A < code_dot * B. pub async fn compare_threshold( + session: &mut Session, + code_dot: Share, + mask_dot: Share, +) -> eyre::Result> { + let mut x = mask_dot * A as u32; + let y = code_dot * B as u32; + x -= y; + + single_extract_msb_u32::<32>(session, x).await +} + +/// The same as compare_threshold, but the input shares are 16-bit and lifted to +/// 32-bit before threshold comparison. +/// +/// See compare_threshold for more details. +pub async fn lift_and_compare_threshold( session: &mut Session, code_dot: Share, mask_dot: Share, ) -> eyre::Result> { - debug_assert!(A_BITS as u64 <= B_BITS); - let y = mul_lift_2k::(&code_dot); let mut x = lift::<{ B_BITS as usize }>(session, VecShare::new_vec(vec![mask_dot])).await?; - debug_assert_eq!(x.len(), 1); - let mut x = x.pop().expect("Enough elements present"); + let mut x = x.pop().expect("Expected a single element in the VecShare"); x *= A as u32; x -= y; single_extract_msb_u32::<32>(session, x).await } -pub(crate) async fn batch_signed_lift( +/// Lifts a share of a vector (VecShare) of 16-bit values to a share of a vector +/// (VecShare) of 32-bit values. +pub async fn batch_signed_lift( session: &mut Session, mut pre_lift: VecShare, ) -> eyre::Result> { @@ -87,40 +108,26 @@ pub(crate) async fn batch_signed_lift( Ok(lifted_values) } -/// Computes [D1 * T2; D2 * T1] via lifting -pub(crate) async fn cross_mul_via_lift( +/// Wrapper over batch_signed_lift that lifts a vector (Vec) of 16-bit shares to +/// a vector (Vec) of 32-bit shares. +pub async fn batch_signed_lift_vec( session: &mut Session, - d1: Share, - t1: Share, - d2: Share, - t2: Share, -) -> eyre::Result<(Share, Share)> { - let mut pre_lift = VecShare::::with_capacity(4); - // Do preprocessing to lift all values - pre_lift.push(d1); - pre_lift.push(t2); - pre_lift.push(d2); - pre_lift.push(t1); - - let lifted_values = batch_signed_lift(session, pre_lift).await?; - - // Compute d1 * t2; t2 * d1 - let mut exchanged_shares_a = Vec::with_capacity(2); - let pairs = [ - ( - lifted_values.shares[0].clone(), - lifted_values.shares[1].clone(), - ), - ( - lifted_values.shares[2].clone(), - lifted_values.shares[3].clone(), - ), - ]; - for pair in pairs.iter() { - let (x, y) = pair; - let res = session.prf_as_mut().gen_zero_share() + x * y; - exchanged_shares_a.push(res); - } + pre_lift: Vec>, +) -> eyre::Result>> { + let pre_lift = VecShare::new_vec(pre_lift); + Ok(batch_signed_lift(session, pre_lift).await?.inner()) +} + +/// Computes D2 * T1 - T2 * D1 +/// Assumes that the input shares are originally 16-bit and lifted to u32. +pub(crate) async fn cross_mul( + session: &mut Session, + d1: Share, + t1: Share, + d2: Share, + t2: Share, +) -> eyre::Result> { + let res_a = session.prf_as_mut().gen_zero_share() + &d2 * &t1 - &t2 * &d1; let network = session.network(); let next_role = session.identity(&session.own_role()?.next(3))?; @@ -128,7 +135,7 @@ pub(crate) async fn cross_mul_via_lift( network .send( - NetworkValue::VecRing32(exchanged_shares_a.clone()).to_network(), + NetworkValue::RingElement32(res_a).to_network(), next_role, &session.session_id(), ) @@ -136,43 +143,31 @@ pub(crate) async fn cross_mul_via_lift( let serialized_reply = network.receive(prev_role, &session.session_id()).await; let res_b = match NetworkValue::from_network(serialized_reply) { - Ok(NetworkValue::VecRing32(element)) => element, - _ => return Err(eyre!("Could not deserialize VecRing16")), + Ok(NetworkValue::RingElement32(element)) => element, + _ => return Err(eyre!("Could not deserialize RingElement32")), }; - if exchanged_shares_a.len() != res_b.len() { - return Err(eyre!( - "Expected a VecRing32 with length {:?} but received with length: {:?}", - exchanged_shares_a.len(), - res_b.len() - )); - } - - // vec![D1 * T2; T2 * D1] - let mut res = Vec::with_capacity(2); - for (a_share, b_share) in exchanged_shares_a.into_iter().zip(res_b) { - res.push(Share::new(a_share, b_share)); - } - Ok((res[0].clone(), res[1].clone())) + Ok(Share::new(res_a, res_b)) } -/// Computes (d2*t1 - d1*t2) > 0 by first lifting the values in a batch -/// from Z_{2^16} to a bigger ring Z_{2^32} +/// Computes (d2*t1 - d1*t2) > 0. /// Does the multiplication in Z_{2^32} and computes the MSB, to check the /// comparison result. /// d1, t1 are replicated shares that come from an iris code/mask dot product, /// ie: d1 = dot(c_x, c_y); t1 = dot(m_x, m_y). d2, t2 are replicated shares /// that come from an iris code and mask dot product, ie: /// d2 = dot(c_u, c_w), t2 = dot(m_u, m_w) +/// +/// Input values are assumed to be 16-bit shares that have been lifted to +/// 32-bit. pub async fn cross_compare( session: &mut Session, - d1: Share, - t1: Share, - d2: Share, - t2: Share, + d1: Share, + t1: Share, + d2: Share, + t2: Share, ) -> eyre::Result { - let (d1t2, d2t1) = cross_mul_via_lift(session, d1, t1, d2, t2).await?; - let diff = d2t1 - d1t2; + let diff = cross_mul(session, d1, t1, d2, t2).await?; // Compute bit <- MSB(D2 * T1 - D1 * T2) let bit = single_extract_msb_u32::<32>(session, diff).await?; // Open bit @@ -247,10 +242,10 @@ pub async fn galois_ring_to_rep3( /// Checks whether first Iris entry in the pair matches the Iris in the second /// entry. This is done in the following manner: -/// Compute the dot product between the two Irises. -/// Convert the partial shamir share result to a replicated sharing and then -/// Compare the distance using the MATCH_THRESHOLD_RATIO from the -/// `compare_threshold` function. +/// - Compute the dot product between the two Irises. +/// - Convert the partial Shamir share result to a replicated sharing and then +/// - Compare the distance using the MATCH_THRESHOLD_RATIO from the +/// `lift_and_compare_threshold` function. pub async fn galois_ring_is_match( session: &mut Session, pairs: &[(GaloisRingSharedIris, GaloisRingSharedIris)], @@ -259,18 +254,17 @@ pub async fn galois_ring_is_match( let additive_dots = galois_ring_pairwise_distance(session, pairs).await?; let rep_dots = galois_ring_to_rep3(session, additive_dots).await?; // compute dots[0] - dots[1] - let bit = compare_threshold(session, rep_dots[0].clone(), rep_dots[1].clone()).await?; + let bit = lift_and_compare_threshold(session, rep_dots[0].clone(), rep_dots[1].clone()).await?; let opened = open_bin(session, bit).await?; Ok(opened.convert()) } -/// Checks that the given dot product is zero. -pub async fn is_dot_zero( +/// Compares the given distance to a threshold and reveal the result. +pub async fn compare_threshold_and_open( session: &mut Session, - code_dot: Share, - mask_dot: Share, + distance: DistanceShare, ) -> eyre::Result { - let bit = compare_threshold(session, code_dot, mask_dot).await?; + let bit = compare_threshold(session, distance.code_dot, distance.mask_dot).await?; let opened = open_bin(session, bit).await?; Ok(opened.convert()) } @@ -280,7 +274,10 @@ mod tests { use super::*; use crate::{ database_generators::generate_galois_iris_shares, - execution::{local::LocalRuntime, player::Identity}, + execution::{ + local::{generate_local_identities, LocalRuntime}, + player::Identity, + }, hawkers::plaintext_store::PlaintextIris, protocol::ops::NetworkValue::RingElement32, shares::{int_ring::IntRing2k, ring_impl::RingElement}, @@ -352,15 +349,16 @@ mod tests { #[tokio::test] async fn test_async_prf_setup() { let num_parties = 3; - let identities: Vec = vec!["alice".into(), "bob".into(), "charlie".into()]; + let identities = generate_local_identities(); let mut seeds = Vec::new(); for i in 0..num_parties { let mut seed = [0_u8; 16]; seed[0] = i; seeds.push(seed); } - let local = LocalRuntime::new(identities.clone(), seeds.clone()); - let mut ready_sessions = local.create_player_sessions().await.unwrap(); + let mut runtime = LocalRuntime::new(identities.clone(), seeds.clone()) + .await + .unwrap(); // check whether parties have sent/received the correct seeds. // P0: [seed_0, seed_2] @@ -368,7 +366,8 @@ mod tests { // P2: [seed_2, seed_1] // This is done by calling next() on the PRFs and see whether they match with // the ones created from scratch. - let prf0 = ready_sessions + let prf0 = runtime + .sessions .get_mut(&"alice".into()) .unwrap() .prf_as_mut(); @@ -381,7 +380,11 @@ mod tests { Prf::new(seeds[0], seeds[2]).get_prev_prf().next_u64() ); - let prf1 = ready_sessions.get_mut(&"bob".into()).unwrap().prf_as_mut(); + let prf1 = runtime + .sessions + .get_mut(&"bob".into()) + .unwrap() + .prf_as_mut(); assert_eq!( prf1.get_my_prf().next_u64(), Prf::new(seeds[1], seeds[0]).get_my_prf().next_u64() @@ -391,7 +394,8 @@ mod tests { Prf::new(seeds[1], seeds[0]).get_prev_prf().next_u64() ); - let prf2 = ready_sessions + let prf2 = runtime + .sessions .get_mut(&"charlie".into()) .unwrap() .prf_as_mut(); @@ -464,15 +468,19 @@ mod tests { seed[0] = i; seeds.push(seed); } - let local = LocalRuntime::new(identities.clone(), seeds.clone()); - let ready_sessions = local.create_player_sessions().await.unwrap(); + let runtime = LocalRuntime::new(identities.clone(), seeds.clone()) + .await + .unwrap(); let mut jobs = JoinSet::new(); for player in identities.iter() { - let mut player_session = ready_sessions.get(player).unwrap().clone(); + let mut player_session = runtime.sessions.get(player).unwrap().clone(); let four_shares = four_share_map.get(player).unwrap().clone(); jobs.spawn(async move { - let out_shared = cross_mul_via_lift( + let four_shares = batch_signed_lift_vec(&mut player_session, four_shares) + .await + .unwrap(); + let out_shared = cross_mul( &mut player_session, four_shares[0].clone(), four_shares[1].clone(), @@ -481,16 +489,13 @@ mod tests { ) .await .unwrap(); - ( - open_single(&player_session, out_shared.0).await.unwrap(), - open_single(&player_session, out_shared.1).await.unwrap(), - ) + + open_single(&player_session, out_shared).await.unwrap() }); } // check first party output is equal to the expected result. let t = jobs.join_next().await.unwrap().unwrap(); - assert_eq!(t.0, RingElement(4)); - assert_eq!(t.1, RingElement(6)); + assert_eq!(t, RingElement(2)); } async fn open_additive(session: &Session, x: Vec>) -> eyre::Result> { @@ -537,8 +542,7 @@ mod tests { #[case(1)] #[case(2)] async fn test_galois_ring_to_rep3(#[case] seed: u64) { - let runtime = LocalRuntime::replicated_test_config(); - let ready_sessions = runtime.create_player_sessions().await.unwrap(); + let runtime = LocalRuntime::mock_setup_with_channel().await.unwrap(); let mut rng = AesRng::seed_from_u64(seed); let iris_db = IrisDB::new_random_rng(2, &mut rng).db; @@ -548,7 +552,7 @@ mod tests { let mut jobs = JoinSet::new(); for (index, player) in runtime.identities.iter().cloned().enumerate() { - let mut player_session = ready_sessions.get(&player).unwrap().clone(); + let mut player_session = runtime.sessions.get(&player).unwrap().clone(); let mut own_shares = vec![(first_entry[index].clone(), second_entry[index].clone())]; own_shares.iter_mut().for_each(|(_x, y)| { y.code.preprocess_iris_code_query_share(); diff --git a/iris-mpc-cpu/src/py_bindings/hnsw.rs b/iris-mpc-cpu/src/py_bindings/hnsw.rs new file mode 100644 index 000000000..471de784a --- /dev/null +++ b/iris-mpc-cpu/src/py_bindings/hnsw.rs @@ -0,0 +1,113 @@ +use super::plaintext_store::Base64IrisCode; +use crate::hawkers::plaintext_store::{PlaintextStore, PointId}; +use hawk_pack::{graph_store::GraphMem, HawkSearcher}; +use iris_mpc_common::iris_db::iris::IrisCode; +use rand::rngs::ThreadRng; +use serde_json::{self, Deserializer}; +use std::{fs::File, io::BufReader}; + +pub fn search( + query: IrisCode, + searcher: &HawkSearcher, + vector: &mut PlaintextStore, + graph: &mut GraphMem, +) -> (PointId, f64) { + let rt = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .unwrap(); + + rt.block_on(async move { + let query = vector.prepare_query(query); + let neighbors = searcher.search(vector, graph, &query, 1).await; + let (nearest, (dist_num, dist_denom)) = neighbors.get_nearest().unwrap(); + (*nearest, (*dist_num as f64) / (*dist_denom as f64)) + }) +} + +// TODO could instead take iterator of IrisCodes to make more flexible +pub fn insert( + iris: IrisCode, + searcher: &HawkSearcher, + vector: &mut PlaintextStore, + graph: &mut GraphMem, +) -> PointId { + let rt = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .unwrap(); + + rt.block_on(async move { + let mut rng = ThreadRng::default(); + + let query = vector.prepare_query(iris); + searcher.insert(vector, graph, &query, &mut rng).await + }) +} + +pub fn insert_uniform_random( + searcher: &HawkSearcher, + vector: &mut PlaintextStore, + graph: &mut GraphMem, +) -> PointId { + let mut rng = ThreadRng::default(); + let raw_query = IrisCode::random_rng(&mut rng); + + insert(raw_query, searcher, vector, graph) +} + +pub fn fill_uniform_random( + num: usize, + searcher: &HawkSearcher, + vector: &mut PlaintextStore, + graph: &mut GraphMem, +) { + let rt = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .unwrap(); + + rt.block_on(async move { + let mut rng = ThreadRng::default(); + + for idx in 0..num { + let raw_query = IrisCode::random_rng(&mut rng); + let query = vector.prepare_query(raw_query.clone()); + searcher.insert(vector, graph, &query, &mut rng).await; + if idx % 100 == 99 { + println!("{}", idx + 1); + } + } + }) +} + +pub fn fill_from_ndjson_file( + filename: &str, + limit: Option, + searcher: &HawkSearcher, + vector: &mut PlaintextStore, + graph: &mut GraphMem, +) { + let rt = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .unwrap(); + + rt.block_on(async move { + let mut rng = ThreadRng::default(); + + let file = File::open(filename).unwrap(); + let reader = BufReader::new(file); + + // Create an iterator over deserialized objects + let stream = Deserializer::from_reader(reader).into_iter::(); + let stream = super::limited_iterator(stream, limit); + + // Iterate over each deserialized object + for json_pt in stream { + let raw_query = (&json_pt.unwrap()).into(); + let query = vector.prepare_query(raw_query); + searcher.insert(vector, graph, &query, &mut rng).await; + } + }) +} diff --git a/iris-mpc-cpu/src/py_bindings/io.rs b/iris-mpc-cpu/src/py_bindings/io.rs new file mode 100644 index 000000000..77f2c5b6f --- /dev/null +++ b/iris-mpc-cpu/src/py_bindings/io.rs @@ -0,0 +1,36 @@ +use bincode; +use eyre::Result; +use serde::{de::DeserializeOwned, Serialize}; +use serde_json; +use std::{ + fs::File, + io::{BufReader, BufWriter}, +}; + +pub fn write_bin(data: &T, filename: &str) -> Result<()> { + let file = File::create(filename)?; + let writer = BufWriter::new(file); + bincode::serialize_into(writer, data)?; + Ok(()) +} + +pub fn read_bin(filename: &str) -> Result { + let file = File::open(filename)?; + let reader = BufReader::new(file); + let data: T = bincode::deserialize_from(reader)?; + Ok(data) +} + +pub fn write_json(data: &T, filename: &str) -> Result<()> { + let file = File::create(filename)?; + let writer = BufWriter::new(file); + serde_json::to_writer(writer, &data)?; + Ok(()) +} + +pub fn read_json(filename: &str) -> Result { + let file = File::open(filename)?; + let reader = BufReader::new(file); + let data: T = serde_json::from_reader(reader)?; + Ok(data) +} diff --git a/iris-mpc-cpu/src/py_bindings/mod.rs b/iris-mpc-cpu/src/py_bindings/mod.rs new file mode 100644 index 000000000..b655e05f2 --- /dev/null +++ b/iris-mpc-cpu/src/py_bindings/mod.rs @@ -0,0 +1,13 @@ +pub mod hnsw; +pub mod io; +pub mod plaintext_store; + +pub fn limited_iterator(iter: I, limit: Option) -> Box> +where + I: Iterator + 'static, +{ + match limit { + Some(num) => Box::new(iter.take(num)), + None => Box::new(iter), + } +} diff --git a/iris-mpc-cpu/src/py_bindings/plaintext_store.rs b/iris-mpc-cpu/src/py_bindings/plaintext_store.rs new file mode 100644 index 000000000..7340454e8 --- /dev/null +++ b/iris-mpc-cpu/src/py_bindings/plaintext_store.rs @@ -0,0 +1,79 @@ +use crate::hawkers::plaintext_store::{PlaintextIris, PlaintextPoint, PlaintextStore}; +use iris_mpc_common::iris_db::iris::{IrisCode, IrisCodeArray}; +use serde::{Deserialize, Serialize}; +use std::{ + fs::File, + io::{self, BufReader, BufWriter, Write}, +}; + +/// Iris code representation using base64 encoding compatible with Open IRIS +#[derive(Serialize, Deserialize)] +pub struct Base64IrisCode { + iris_codes: String, + mask_codes: String, +} + +impl From<&IrisCode> for Base64IrisCode { + fn from(value: &IrisCode) -> Self { + Self { + iris_codes: value.code.to_base64().unwrap(), + mask_codes: value.mask.to_base64().unwrap(), + } + } +} + +impl From<&Base64IrisCode> for IrisCode { + fn from(value: &Base64IrisCode) -> Self { + Self { + code: IrisCodeArray::from_base64(&value.iris_codes).unwrap(), + mask: IrisCodeArray::from_base64(&value.mask_codes).unwrap(), + } + } +} + +pub fn from_ndjson_file(filename: &str, len: Option) -> io::Result { + let file = File::open(filename)?; + let reader = BufReader::new(file); + + // Create an iterator over deserialized objects + let stream = serde_json::Deserializer::from_reader(reader).into_iter::(); + let stream = super::limited_iterator(stream, len); + + // Iterate over each deserialized object + let mut vector = PlaintextStore::default(); + for json_pt in stream { + let json_pt = json_pt?; + vector.points.push(PlaintextPoint { + data: PlaintextIris((&json_pt).into()), + is_persistent: true, + }); + } + + if let Some(num) = len { + if vector.points.len() != num { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!( + "File {} contains too few entries; number read: {}", + filename, + vector.points.len() + ), + )); + } + } + + Ok(vector) +} + +pub fn to_ndjson_file(vector: &PlaintextStore, filename: &str) -> std::io::Result<()> { + // Serialize the objects to the file + let file = File::create(filename)?; + let mut writer = BufWriter::new(file); + for pt in &vector.points { + let json_pt: Base64IrisCode = (&pt.data.0).into(); + serde_json::to_writer(&mut writer, &json_pt)?; + writer.write_all(b"\n")?; // Write a newline after each JSON object + } + writer.flush()?; + Ok(()) +} diff --git a/iris-mpc-cpu/src/shares/share.rs b/iris-mpc-cpu/src/shares/share.rs index 4092760b4..2b5a584df 100644 --- a/iris-mpc-cpu/src/shares/share.rs +++ b/iris-mpc-cpu/src/shares/share.rs @@ -29,7 +29,7 @@ impl Share { } pub fn add_assign_const_role(&mut self, other: T, role: Role) { - match role.zero_based() { + match role.index() { 0 => self.a += RingElement(other), 1 => self.b += RingElement(other), 2 => {} @@ -319,3 +319,20 @@ impl Shl for Share { } } } + +// Additive share of a Hamming distance value +#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)] +#[serde(bound = "")] +pub struct DistanceShare { + pub code_dot: Share, + pub mask_dot: Share, +} + +impl DistanceShare +where + T: IntRing2k, +{ + pub fn new(code_dot: Share, mask_dot: Share) -> Self { + DistanceShare { code_dot, mask_dot } + } +} diff --git a/iris-mpc-gpu/Cargo.toml b/iris-mpc-gpu/Cargo.toml index 9ee89c19e..da1c791c4 100644 --- a/iris-mpc-gpu/Cargo.toml +++ b/iris-mpc-gpu/Cargo.toml @@ -31,6 +31,7 @@ iris-mpc-common = { path = "../iris-mpc-common" } base64 = "0.22.1" metrics = "0.22.1" metrics-exporter-statsd = "0.7" +memmap2.workspace = true [dev-dependencies] criterion = "0.5" diff --git a/iris-mpc-gpu/src/dot/distance_comparator.rs b/iris-mpc-gpu/src/dot/distance_comparator.rs index c70fd4a8e..455a9c70a 100644 --- a/iris-mpc-gpu/src/dot/distance_comparator.rs +++ b/iris-mpc-gpu/src/dot/distance_comparator.rs @@ -276,7 +276,7 @@ impl DistanceComparator { pub fn fetch_all_match_ids( &self, - match_counters: Vec>, + match_counters: &[Vec], matches: &[CudaSlice], ) -> Vec> { let mut results = vec![]; @@ -289,6 +289,7 @@ impl DistanceComparator { ); } + let batch_match_idx: u32 = u32::MAX - (self.query_length / ROTATIONS) as u32; // batch matches have an index of u32::MAX - index let mut matches_per_query = vec![vec![]; match_counters[0].len()]; let n_devices = self.device_manager.device_count(); for i in 0..self.device_manager.device_count() { @@ -297,7 +298,14 @@ impl DistanceComparator { let len = match_counters[i][j] as usize; let ids = results[i][offset..offset + min(len, ALL_MATCHES_LEN)] .iter() - .map(|idx| idx * n_devices as u32 + i as u32) + .map(|&idx| { + if idx > batch_match_idx { + idx + } else { + idx * n_devices as u32 + i as u32 + } + }) + .filter(|&idx| idx < batch_match_idx || i == 0) // take all normal matches, but only batch matches from device 0 .collect::>(); matches_per_query[j].extend_from_slice(&ids); offset += ALL_MATCHES_LEN; diff --git a/iris-mpc-gpu/src/dot/kernel.cu b/iris-mpc-gpu/src/dot/kernel.cu index ff9c13633..b523b02be 100644 --- a/iris-mpc-gpu/src/dot/kernel.cu +++ b/iris-mpc-gpu/src/dot/kernel.cu @@ -108,7 +108,7 @@ extern "C" __global__ void mergeBatchResults(unsigned long long *matchResultsSel continue; // Query is already considering rotations, ignore rotated db entries - if ((dbIdx - ROTATIONS) % ALL_ROTATIONS != 0) + if ((dbIdx < ROTATIONS) || ((dbIdx - ROTATIONS) % ALL_ROTATIONS != 0)) continue; // Only consider results above the diagonal diff --git a/iris-mpc-gpu/src/dot/share_db.rs b/iris-mpc-gpu/src/dot/share_db.rs index d1169a761..b6dcd210a 100644 --- a/iris-mpc-gpu/src/dot/share_db.rs +++ b/iris-mpc-gpu/src/dot/share_db.rs @@ -20,18 +20,19 @@ use cudarc::{ CudaBlas, }, driver::{ - result::{self, malloc_async, malloc_managed}, - sys::{CUdeviceptr, CUmemAttach_flags}, + result::{self, malloc_async}, + sys::{CUdeviceptr, CU_MEMHOSTALLOC_PORTABLE}, CudaFunction, CudaSlice, CudaStream, CudaView, DevicePtr, DeviceSlice, LaunchAsync, }, nccl, nvrtc::compile_ptx, }; use itertools::{izip, Itertools}; +use memmap2::MmapMut; use rayon::prelude::*; use std::{ ffi::{c_void, CStr}, - mem, + mem::{self, forget}, sync::Arc, }; @@ -114,6 +115,12 @@ pub struct SlicedProcessedDatabase { pub code_sums_gr: CudaVec2DSlicerU32, } +#[derive(Clone)] +pub struct DBChunkBuffers { + pub limb_0: Vec>, + pub limb_1: Vec>, +} + pub struct ShareDB { peer_id: usize, is_remote: bool, @@ -237,22 +244,23 @@ impl ShareDB { .devices() .iter() .map(|device| unsafe { + let host_mem0 = MmapMut::map_anon(max_size * self.code_length).unwrap(); + let host_mem1 = MmapMut::map_anon(max_size * self.code_length).unwrap(); + + let host_mem0_ptr = host_mem0.as_ptr() as u64; + let host_mem1_ptr = host_mem1.as_ptr() as u64; + + // Make sure to not drop the memory, even though we only use the pointers + // afterwards. This also has the effect that this memory is never freed, which + // is fine for the db. + forget(host_mem0); + forget(host_mem1); + ( StreamAwareCudaSlice::from(device.alloc(max_size).unwrap()), ( StreamAwareCudaSlice::from(device.alloc(max_size).unwrap()), - ( - malloc_managed( - max_size * self.code_length, - CUmemAttach_flags::CU_MEM_ATTACH_GLOBAL, - ) - .unwrap(), - malloc_managed( - max_size * self.code_length, - CUmemAttach_flags::CU_MEM_ATTACH_GLOBAL, - ) - .unwrap(), - ), + (host_mem0_ptr, host_mem1_ptr), ), ) }) @@ -274,6 +282,26 @@ impl ShareDB { } } + pub fn register_host_memory(&self, db: &SlicedProcessedDatabase, max_db_length: usize) { + let max_size = max_db_length / self.device_manager.device_count(); + for (device_index, device) in self.device_manager.devices().iter().enumerate() { + device.bind_to_thread().unwrap(); + unsafe { + let _ = cudarc::driver::sys::lib().cuMemHostRegister_v2( + db.code_gr.limb_0[device_index] as *mut _, + max_size * self.code_length, + CU_MEMHOSTALLOC_PORTABLE, + ); + + let _ = cudarc::driver::sys::lib().cuMemHostRegister_v2( + db.code_gr.limb_1[device_index] as *mut _, + max_size * self.code_length, + CU_MEMHOSTALLOC_PORTABLE, + ); + } + } + } + pub fn load_single_record( index: usize, db: &CudaVec2DSlicerRawPointer, @@ -281,7 +309,7 @@ impl ShareDB { n_shards: usize, code_length: usize, ) { - assert!(record.len() == code_length); + assert_eq!(record.len(), code_length); let a0_host = record .iter() @@ -450,6 +478,64 @@ impl ShareDB { } } + pub fn alloc_db_chunk_buffer(&self, max_chunk_size: usize) -> DBChunkBuffers { + let mut limb_0 = vec![]; + let mut limb_1 = vec![]; + for device in self.device_manager.devices() { + unsafe { + limb_0.push(device.alloc(max_chunk_size * self.code_length).unwrap()); + limb_1.push(device.alloc(max_chunk_size * self.code_length).unwrap()); + } + } + DBChunkBuffers { limb_0, limb_1 } + } + + pub fn prefetch_db_chunk( + &self, + db: &SlicedProcessedDatabase, + buffers: &DBChunkBuffers, + chunk_sizes: &[usize], + offset: &[usize], + db_sizes: &[usize], + streams: &[CudaStream], + ) { + for idx in 0..self.device_manager.device_count() { + let device = self.device_manager.device(idx); + device.bind_to_thread().unwrap(); + + if offset[idx] >= db_sizes[idx] + || offset[idx] + chunk_sizes[idx] > db_sizes[idx] + || chunk_sizes[idx] == 0 + { + continue; + } + + unsafe { + cudarc::driver::sys::lib() + .cuMemcpyHtoDAsync_v2( + *buffers.limb_0[idx].device_ptr(), + (db.code_gr.limb_0[idx] as usize + offset[idx] * self.code_length) + as *mut _, + chunk_sizes[idx] * self.code_length, + streams[idx].stream, + ) + .result() + .unwrap(); + + cudarc::driver::sys::lib() + .cuMemcpyHtoDAsync_v2( + *buffers.limb_1[idx].device_ptr(), + (db.code_gr.limb_1[idx] as usize + offset[idx] * self.code_length) + as *mut _, + chunk_sizes[idx] * self.code_length, + streams[idx].stream, + ) + .result() + .unwrap(); + } + } + } + pub fn dot( &mut self, queries: &CudaVec2DSlicer, @@ -805,6 +891,7 @@ mod tests { .unwrap(); let query_sums = engine.query_sums(&preprocessed_query, &streams, &blass); let mut db_slices = engine.alloc_db(DB_SIZE); + engine.register_host_memory(&db_slices, DB_SIZE); let db_sizes = engine.load_full_db(&mut db_slices, &db); engine.dot( @@ -906,6 +993,7 @@ mod tests { .unwrap(); let query_sums = engine.query_sums(&preprocessed_query, &streams, &blass); let mut db_slices = engine.alloc_db(DB_SIZE); + engine.register_host_memory(&db_slices, DB_SIZE); let db_sizes = engine.load_full_db(&mut db_slices, &codes_db); engine.dot( @@ -1039,6 +1127,8 @@ mod tests { let db_sizes = codes_engine.load_full_db(&mut code_db_slices, &codes_db); let mut mask_db_slices = masks_engine.alloc_db(DB_SIZE); let mask_db_sizes = masks_engine.load_full_db(&mut mask_db_slices, &masks_db); + codes_engine.register_host_memory(&code_db_slices, DB_SIZE); + masks_engine.register_host_memory(&mask_db_slices, DB_SIZE); assert_eq!(db_sizes, mask_db_sizes); diff --git a/iris-mpc-gpu/src/helpers/device_manager.rs b/iris-mpc-gpu/src/helpers/device_manager.rs index f7f7b098e..4053324a4 100644 --- a/iris-mpc-gpu/src/helpers/device_manager.rs +++ b/iris-mpc-gpu/src/helpers/device_manager.rs @@ -102,6 +102,13 @@ impl DeviceManager { events } + pub fn destroy_events(&self, events: Vec) { + for (device_idx, event) in events.iter().enumerate() { + self.device(device_idx).bind_to_thread().unwrap(); + unsafe { event::destroy(*event).unwrap() }; + } + } + pub fn record_event(&self, streams: &[CudaStream], events: &[CUevent]) { for idx in 0..self.devices.len() { unsafe { diff --git a/iris-mpc-gpu/src/helpers/mod.rs b/iris-mpc-gpu/src/helpers/mod.rs index 149fc10cf..e4bd114ee 100644 --- a/iris-mpc-gpu/src/helpers/mod.rs +++ b/iris-mpc-gpu/src/helpers/mod.rs @@ -1,7 +1,7 @@ use crate::threshold_ring::protocol::ChunkShare; use cudarc::driver::{ result::{self, memcpy_dtoh_async, memcpy_htod_async, stream}, - sys::{CUdeviceptr, CUstream, CUstream_st}, + sys::{lib, CUdeviceptr, CUstream, CUstream_st}, CudaDevice, CudaSlice, CudaStream, DevicePtr, DevicePtrMut, DeviceRepr, DriverError, LaunchConfig, }; @@ -104,6 +104,32 @@ pub unsafe fn dtod_at_offset( } } +/// Copy a slice from device to host with respective offsets. +/// # Safety +/// +/// The caller must ensure that the `dst` and `src` pointers are valid +/// with the respective offsets +pub unsafe fn dtoh_at_offset( + dst: u64, + dst_offset: usize, + src: CUdeviceptr, + src_offset: usize, + len: usize, + stream_ptr: CUstream, +) { + unsafe { + lib() + .cuMemcpyDtoHAsync_v2( + (dst + dst_offset as u64) as *mut _, + (src + src_offset as u64) as CUdeviceptr, + len, + stream_ptr, + ) + .result() + .unwrap(); + } +} + pub fn dtoh_on_stream_sync>( input: &U, device: &Arc, diff --git a/iris-mpc-gpu/src/helpers/query_processor.rs b/iris-mpc-gpu/src/helpers/query_processor.rs index a02a3b4bd..de27e4e5d 100644 --- a/iris-mpc-gpu/src/helpers/query_processor.rs +++ b/iris-mpc-gpu/src/helpers/query_processor.rs @@ -1,6 +1,6 @@ use crate::{ dot::{ - share_db::{ShareDB, SlicedProcessedDatabase}, + share_db::{DBChunkBuffers, ShareDB, SlicedProcessedDatabase}, IRIS_CODE_LENGTH, MASK_CODE_LENGTH, }, helpers::device_manager::DeviceManager, @@ -82,6 +82,15 @@ impl From<&CudaVec2DSlicer> for CudaVec2DSlicerRawPointer { } } +impl From<&DBChunkBuffers> for CudaVec2DSlicerRawPointer { + fn from(buffers: &DBChunkBuffers) -> Self { + CudaVec2DSlicerRawPointer { + limb_0: buffers.limb_0.iter().map(|s| *s.device_ptr()).collect(), + limb_1: buffers.limb_1.iter().map(|s| *s.device_ptr()).collect(), + } + } +} + pub struct CudaVec2DSlicer { pub limb_0: Vec>, pub limb_1: Vec>, @@ -193,8 +202,8 @@ impl DeviceCompactQuery { &self, code_engine: &mut ShareDB, mask_engine: &mut ShareDB, - sliced_code_db: &SlicedProcessedDatabase, - sliced_mask_db: &SlicedProcessedDatabase, + sliced_code_db: &CudaVec2DSlicerRawPointer, + sliced_mask_db: &CudaVec2DSlicerRawPointer, database_sizes: &[usize], offset: usize, streams: &[CudaStream], @@ -202,7 +211,7 @@ impl DeviceCompactQuery { ) { code_engine.dot( &self.code_query, - &sliced_code_db.code_gr, + sliced_code_db, database_sizes, offset, streams, @@ -210,7 +219,7 @@ impl DeviceCompactQuery { ); mask_engine.dot( &self.mask_query, - &sliced_mask_db.code_gr, + sliced_mask_db, database_sizes, offset, streams, diff --git a/iris-mpc-gpu/src/server/actor.rs b/iris-mpc-gpu/src/server/actor.rs index df4814f1a..6711c0884 100644 --- a/iris-mpc-gpu/src/server/actor.rs +++ b/iris-mpc-gpu/src/server/actor.rs @@ -2,21 +2,23 @@ use super::{BatchQuery, Eye, ServerJob, ServerJobResult}; use crate::{ dot::{ distance_comparator::DistanceComparator, - share_db::{preprocess_query, ShareDB, SlicedProcessedDatabase}, + share_db::{preprocess_query, DBChunkBuffers, ShareDB, SlicedProcessedDatabase}, IRIS_CODE_LENGTH, MASK_CODE_LENGTH, ROTATIONS, }, helpers::{ self, comm::NcclComm, device_manager::DeviceManager, - query_processor::{CompactQuery, DeviceCompactQuery, DeviceCompactSums}, + query_processor::{ + CompactQuery, CudaVec2DSlicerRawPointer, DeviceCompactQuery, DeviceCompactSums, + }, }, threshold_ring::protocol::{ChunkShare, Circuits}, }; use cudarc::{ cublas::CudaBlas, driver::{ - result::{self, event::elapsed}, + result::{self, event::elapsed, mem_get_info}, sys::CUevent, CudaDevice, CudaSlice, CudaStream, DevicePtr, DeviceSlice, }, @@ -35,14 +37,18 @@ use std::{collections::HashMap, mem, sync::Arc, time::Instant}; use tokio::sync::{mpsc, oneshot}; macro_rules! record_stream_time { - ($manager:expr, $streams:expr, $map:expr, $label:expr, $block:block) => {{ - let evt0 = $manager.create_events(); - let evt1 = $manager.create_events(); - $manager.record_event($streams, &evt0); - let res = $block; - $manager.record_event($streams, &evt1); - $map.entry($label).or_default().extend(vec![evt0, evt1]); - res + ($manager:expr, $streams:expr, $map:expr, $label:expr, $enable_timing:expr, $block:block) => {{ + if $enable_timing { + let evt0 = $manager.create_events(); + let evt1 = $manager.create_events(); + $manager.record_event($streams, &evt0); + let res = $block; + $manager.record_event($streams, &evt1); + $map.entry($label).or_default().extend(vec![evt0, evt1]); + res + } else { + $block + } }}; } @@ -68,6 +74,7 @@ impl ServerActorHandle { const DB_CHUNK_SIZE: usize = 1 << 15; const KDF_SALT: &str = "111a1a93518f670e9bb0c2c68888e2beb9406d4c4ed571dc77b801e676ae3091"; // Random 32 byte salt +const SUPERMATCH_THRESHOLD: usize = 4_000; pub struct ServerActor { job_queue: mpsc::Receiver, @@ -102,6 +109,12 @@ pub struct ServerActor { max_db_size: usize, return_partial_results: bool, disable_persistence: bool, + enable_debug_timing: bool, + code_chunk_buffers: Vec, + mask_chunk_buffers: Vec, + dot_events: Vec>, + exchange_events: Vec>, + phase2_events: Vec>, } const NON_MATCH_ID: u32 = u32::MAX; @@ -116,6 +129,7 @@ impl ServerActor { max_batch_size: usize, return_partial_results: bool, disable_persistence: bool, + enable_debug_timing: bool, ) -> eyre::Result<(Self, ServerActorHandle)> { let device_manager = Arc::new(DeviceManager::init()); Self::new_with_device_manager( @@ -127,6 +141,7 @@ impl ServerActor { max_batch_size, return_partial_results, disable_persistence, + enable_debug_timing, ) } #[allow(clippy::too_many_arguments)] @@ -139,6 +154,7 @@ impl ServerActor { max_batch_size: usize, return_partial_results: bool, disable_persistence: bool, + enable_debug_timing: bool, ) -> eyre::Result<(Self, ServerActorHandle)> { let ids = device_manager.get_ids_from_magic(0); let comms = device_manager.instantiate_network_from_ids(party_id, &ids)?; @@ -152,6 +168,7 @@ impl ServerActor { max_batch_size, return_partial_results, disable_persistence, + enable_debug_timing, ) } @@ -166,6 +183,7 @@ impl ServerActor { max_batch_size: usize, return_partial_results: bool, disable_persistence: bool, + enable_debug_timing: bool, ) -> eyre::Result<(Self, ServerActorHandle)> { let (tx, rx) = mpsc::channel(job_queue_size); let actor = Self::init( @@ -178,6 +196,7 @@ impl ServerActor { max_batch_size, return_partial_results, disable_persistence, + enable_debug_timing, )?; Ok((actor, ServerActorHandle { job_queue: tx })) } @@ -193,6 +212,7 @@ impl ServerActor { max_batch_size: usize, return_partial_results: bool, disable_persistence: bool, + enable_debug_timing: bool, ) -> eyre::Result { assert!(max_batch_size != 0); let mut kdf_nonce = 0; @@ -233,11 +253,15 @@ impl ServerActor { comms.clone(), ); + let now = Instant::now(); + let left_code_db_slices = codes_engine.alloc_db(max_db_size); let left_mask_db_slices = masks_engine.alloc_db(max_db_size); let right_code_db_slices = codes_engine.alloc_db(max_db_size); let right_mask_db_slices = masks_engine.alloc_db(max_db_size); + tracing::info!("Allocated db in {:?}", now.elapsed()); + // Engines for inflight queries let batch_codes_engine = ShareDB::init( party_id, @@ -316,9 +340,16 @@ impl ServerActor { let batch_match_list_right = distance_comparator.prepare_db_match_list(n_queries); let query_db_size = vec![n_queries; device_manager.device_count()]; - let current_db_sizes = vec![0; device_manager.device_count()]; + let code_chunk_buffers = vec![codes_engine.alloc_db_chunk_buffer(DB_CHUNK_SIZE); 2]; + let mask_chunk_buffers = vec![masks_engine.alloc_db_chunk_buffer(DB_CHUNK_SIZE); 2]; + + // Create all needed events + let dot_events = vec![device_manager.create_events(); 2]; + let exchange_events = vec![device_manager.create_events(); 2]; + let phase2_events = vec![device_manager.create_events(); 2]; + for dev in device_manager.devices() { dev.synchronize().unwrap(); } @@ -354,6 +385,12 @@ impl ServerActor { max_db_size, return_partial_results, disable_persistence, + enable_debug_timing, + code_chunk_buffers, + mask_chunk_buffers, + dot_events, + exchange_events, + phase2_events, }) } @@ -460,6 +497,9 @@ impl ServerActor { self.device_manager.device_count(), MASK_CODE_LENGTH, ); + } + + pub fn increment_db_size(&mut self, index: usize) { self.current_db_sizes[index % self.device_manager.device_count()] += 1; } @@ -474,6 +514,17 @@ impl ServerActor { .preprocess_db(&mut self.right_mask_db_slices, &self.current_db_sizes); } + pub fn register_host_memory(&self) { + self.codes_engine + .register_host_memory(&self.left_code_db_slices, self.max_db_size); + self.masks_engine + .register_host_memory(&self.left_mask_db_slices, self.max_db_size); + self.codes_engine + .register_host_memory(&self.right_code_db_slices, self.max_db_size); + self.masks_engine + .register_host_memory(&self.right_mask_db_slices, self.max_db_size); + } + fn process_batch_query( &mut self, batch: BatchQuery, @@ -578,6 +629,7 @@ impl ServerActor { &self.streams[0], events, "query_preprocess", + self.enable_debug_timing, { // This needs to be max_batch_size, even though the query can be shorter to have // enough padding for GEMM @@ -624,6 +676,7 @@ impl ServerActor { &self.streams[0], events, "query_preprocess", + self.enable_debug_timing, { // This needs to be MAX_BATCH_SIZE, even though the query can be shorter to have // enough padding for GEMM @@ -674,6 +727,7 @@ impl ServerActor { ); self.device_manager.await_streams(&self.streams[0]); + self.device_manager.await_streams(&self.streams[1]); // Iterate over a list of tracing payloads, and create logs with mappings to // payloads Log at least a "start" event using a log with trace.id @@ -695,30 +749,6 @@ impl ServerActor { // Truncate the results to the batch size host_results.iter_mut().for_each(|x| x.truncate(batch_size)); - // Evaluate the results across devices - // Format: merged_results[query_index] - let mut merged_results = - get_merged_results(&host_results, self.device_manager.device_count()); - - // List the indices of the queries that did not match. - let insertion_list = merged_results - .iter() - .enumerate() - .filter(|&(_idx, &num)| num == NON_MATCH_ID) - .map(|(idx, _num)| idx) - .collect::>(); - - // Spread the insertions across devices. - let insertion_list = distribute_insertions(&insertion_list, &self.current_db_sizes); - - // Calculate the new indices for the inserted queries - let matches = calculate_insertion_indices( - &mut merged_results, - &insertion_list, - &self.current_db_sizes, - batch_size, - ); - // Fetch and truncate the match counters let match_counters_devices = self .distance_comparator @@ -740,7 +770,7 @@ impl ServerActor { // Transfer all match ids let match_ids = self.distance_comparator.fetch_all_match_ids( - match_counters_devices, + &match_counters_devices, &self.distance_comparator.all_matches, ); @@ -757,34 +787,100 @@ impl ServerActor { } } - let (partial_match_ids_left, partial_match_ids_right) = if self.return_partial_results { + // Fetch the partial matches + let ( + partial_match_ids_left, + partial_match_counters_left, + partial_match_ids_right, + partial_match_counters_right, + ) = if self.return_partial_results { // Transfer the partial results to the host let partial_match_counters_left = self .distance_comparator - .fetch_match_counters(&self.distance_comparator.match_counters_left); + .fetch_match_counters(&self.distance_comparator.match_counters_left) + .into_iter() + .map(|x| x[..batch_size].to_vec()) + .collect::>(); let partial_match_counters_right = self .distance_comparator - .fetch_match_counters(&self.distance_comparator.match_counters_right); + .fetch_match_counters(&self.distance_comparator.match_counters_right) + .into_iter() + .map(|x| x[..batch_size].to_vec()) + .collect::>(); let partial_results_left = self.distance_comparator.fetch_all_match_ids( - partial_match_counters_left, + &partial_match_counters_left, &self.distance_comparator.partial_results_left, ); let partial_results_right = self.distance_comparator.fetch_all_match_ids( - partial_match_counters_right, + &partial_match_counters_right, &self.distance_comparator.partial_results_right, ); - (partial_results_left, partial_results_right) + ( + partial_results_left, + partial_match_counters_left, + partial_results_right, + partial_match_counters_right, + ) } else { - (vec![], vec![]) + (vec![], vec![], vec![], vec![]) }; + let partial_match_counters_left = partial_match_counters_left.iter().fold( + vec![0usize; batch_size], + |mut acc, counters| { + for (i, &value) in counters.iter().enumerate() { + acc[i] += value as usize; + } + acc + }, + ); + + let partial_match_counters_right = partial_match_counters_right.iter().fold( + vec![0usize; batch_size], + |mut acc, counters| { + for (i, &value) in counters.iter().enumerate() { + acc[i] += value as usize; + } + acc + }, + ); + + // Evaluate the results across devices + // Format: merged_results[query_index] + let mut merged_results = + get_merged_results(&host_results, self.device_manager.device_count()); + + // List the indices of the queries that did not match. + let insertion_list = merged_results + .iter() + .enumerate() + .filter(|&(idx, &num)| { + num == NON_MATCH_ID + // Filter-out supermatchers on both sides (TODO: remove this in the future) + && partial_match_counters_left[idx] <= SUPERMATCH_THRESHOLD + && partial_match_counters_right[idx] <= SUPERMATCH_THRESHOLD + }) + .map(|(idx, _num)| idx) + .collect::>(); + + // Spread the insertions across devices. + let insertion_list = distribute_insertions(&insertion_list, &self.current_db_sizes); + + // Calculate the new indices for the inserted queries + let matches = calculate_insertion_indices( + &mut merged_results, + &insertion_list, + &self.current_db_sizes, + batch_size, + ); + // Check for batch matches let matched_batch_request_ids = match_ids .iter() .map(|ids| { ids.iter() - .filter(|&&x| x > (u32::MAX - self.max_batch_size as u32)) + .filter(|&&x| x > (u32::MAX - batch_size as u32)) // ignore matches outside the batch size (dummy matches) .map(|&x| batch.request_ids[(u32::MAX - x) as usize].clone()) .collect::>() }) @@ -823,6 +919,7 @@ impl ServerActor { &self.streams[0], events, "db_write", + self.enable_debug_timing, { for i in 0..self.device_manager.device_count() { self.device_manager.device(i).bind_to_thread().unwrap(); @@ -871,12 +968,8 @@ impl ServerActor { }) .unwrap(); - // Wait for all streams before get timings - self.device_manager.await_streams(&self.streams[0]); - self.device_manager.await_streams(&self.streams[1]); - // Reset the results buffers for reuse - for dst in &[ + for dst in [ &self.db_match_list_left, &self.db_match_list_right, &self.batch_match_list_left, @@ -885,29 +978,24 @@ impl ServerActor { reset_slice(self.device_manager.devices(), dst, 0, &self.streams[0]); } - reset_slice( - self.device_manager.devices(), + for dst in [ + &self.distance_comparator.all_matches, &self.distance_comparator.match_counters, - 0, - &self.streams[0], - ); - - reset_slice( - self.device_manager.devices(), &self.distance_comparator.match_counters_left, - 0, - &self.streams[0], - ); - - reset_slice( - self.device_manager.devices(), &self.distance_comparator.match_counters_right, - 0, - &self.streams[0], - ); + &self.distance_comparator.partial_results_left, + &self.distance_comparator.partial_results_right, + ] { + reset_slice(self.device_manager.devices(), dst, 0, &self.streams[0]); + } + + self.device_manager.await_streams(&self.streams[0]); + self.device_manager.await_streams(&self.streams[1]); // ---- END RESULT PROCESSING ---- - log_timers(events); + if self.enable_debug_timing { + log_timers(events); + } let processed_mil_elements_per_second = (self.max_batch_size * previous_total_db_size) as f64 / now.elapsed().as_secs_f64() @@ -933,6 +1021,21 @@ impl ServerActor { metrics::gauge!("batch_size").set(batch_size as f64); metrics::gauge!("max_batch_size").set(self.max_batch_size as f64); + // Update GPU memory metrics + let mut sum_free = 0; + let mut sum_total = 0; + for i in 0..self.device_manager.device_count() { + let device = self.device_manager.device(i); + unsafe { result::ctx::set_current(*device.cu_primary_ctx()) }.unwrap(); + let (free, total) = mem_get_info()?; + metrics::gauge!(format!("gpu_memory_free_{}", i)).set(free as f64); + metrics::gauge!(format!("gpu_memory_total_{}", i)).set(total as f64); + sum_free += free; + sum_total += total; + } + metrics::gauge!("gpu_memory_free_sum").set(sum_free as f64); + metrics::gauge!("gpu_memory_total_sum").set(sum_total as f64); + Ok(()) } @@ -960,34 +1063,42 @@ impl ServerActor { // ---- START BATCH DEDUP ---- tracing::info!(party_id = self.party_id, "Starting batch deduplication"); - record_stream_time!(&self.device_manager, batch_streams, events, "batch_dot", { - tracing::info!(party_id = self.party_id, "batch_dot start"); + record_stream_time!( + &self.device_manager, + batch_streams, + events, + "batch_dot", + self.enable_debug_timing, + { + tracing::info!(party_id = self.party_id, "batch_dot start"); - compact_device_queries.compute_dot_products( - &mut self.batch_codes_engine, - &mut self.batch_masks_engine, - &self.query_db_size, - 0, - batch_streams, - batch_cublas, - ); - tracing::info!(party_id = self.party_id, "compute_dot_reducers start"); + compact_device_queries.compute_dot_products( + &mut self.batch_codes_engine, + &mut self.batch_masks_engine, + &self.query_db_size, + 0, + batch_streams, + batch_cublas, + ); + tracing::info!(party_id = self.party_id, "compute_dot_reducers start"); - compact_device_sums.compute_dot_reducers( - &mut self.batch_codes_engine, - &mut self.batch_masks_engine, - &self.query_db_size, - 0, - batch_streams, - ); - tracing::info!(party_id = self.party_id, "batch_dot end"); - }); + compact_device_sums.compute_dot_reducers( + &mut self.batch_codes_engine, + &mut self.batch_masks_engine, + &self.query_db_size, + 0, + batch_streams, + ); + tracing::info!(party_id = self.party_id, "batch_dot end"); + } + ); record_stream_time!( &self.device_manager, batch_streams, events, "batch_reshare", + self.enable_debug_timing, { tracing::info!(party_id = self.party_id, "batch_reshare start"); self.batch_codes_engine @@ -1009,6 +1120,7 @@ impl ServerActor { batch_streams, events, "batch_threshold", + self.enable_debug_timing, { tracing::info!(party_id = self.party_id, "batch_threshold start"); self.phase2_batch.compare_threshold_masked_many( @@ -1042,13 +1154,38 @@ impl ServerActor { tracing::info!(party_id = self.party_id, "Finished batch deduplication"); // ---- END BATCH DEDUP ---- - // Create new initial events - let mut current_dot_event = self.device_manager.create_events(); - let mut next_dot_event = self.device_manager.create_events(); - let mut current_exchange_event = self.device_manager.create_events(); - let mut next_exchange_event = self.device_manager.create_events(); - let mut current_phase2_event = self.device_manager.create_events(); - let mut next_phase2_event = self.device_manager.create_events(); + let chunk_sizes = |chunk_idx: usize| { + self.current_db_sizes + .iter() + .map(|s| (s - DB_CHUNK_SIZE * chunk_idx).clamp(0, DB_CHUNK_SIZE)) + .collect::>() + }; + + record_stream_time!( + &self.device_manager, + &self.streams[0], + events, + "prefetch_db_chunk", + self.enable_debug_timing, + { + self.codes_engine.prefetch_db_chunk( + code_db_slices, + &self.code_chunk_buffers[0], + &chunk_sizes(0), + &vec![0; self.device_manager.device_count()], + &self.current_db_sizes, + &self.streams[0], + ); + self.masks_engine.prefetch_db_chunk( + mask_db_slices, + &self.mask_chunk_buffers[0], + &chunk_sizes(0), + &vec![0; self.device_manager.device_count()], + &self.current_db_sizes, + &self.streams[0], + ); + } + ); // ---- START DATABASE DEDUP ---- tracing::info!(party_id = self.party_id, "Start DB deduplication"); @@ -1057,14 +1194,12 @@ impl ServerActor { let mut db_chunk_idx = 0; loop { let request_streams = &self.streams[db_chunk_idx % 2]; + let next_request_streams = &self.streams[(db_chunk_idx + 1) % 2]; let request_cublas_handles = &self.cublas_handles[db_chunk_idx % 2]; let offset = db_chunk_idx * DB_CHUNK_SIZE; - let chunk_size = self - .current_db_sizes - .iter() - .map(|s| (s - DB_CHUNK_SIZE * db_chunk_idx).clamp(1, DB_CHUNK_SIZE)) - .collect::>(); + let chunk_size = chunk_sizes(db_chunk_idx); + let next_chunk_size = chunk_sizes(db_chunk_idx + 1); // We need to pad the chunk size for two reasons: // 1. Chunk size needs to be a multiple of 4, because the underlying @@ -1075,45 +1210,84 @@ impl ServerActor { // later. let dot_chunk_size = chunk_size .iter() - .map(|s| s.div_ceil(64) * 64) + .map(|&s| (s.max(1).div_ceil(64) * 64)) .collect::>(); // First stream doesn't need to wait if db_chunk_idx == 0 { self.device_manager - .record_event(request_streams, ¤t_dot_event); + .record_event(request_streams, &self.dot_events[db_chunk_idx % 2]); self.device_manager - .record_event(request_streams, ¤t_exchange_event); + .record_event(request_streams, &self.exchange_events[db_chunk_idx % 2]); self.device_manager - .record_event(request_streams, ¤t_phase2_event); + .record_event(request_streams, &self.phase2_events[db_chunk_idx % 2]); } + // Prefetch next chunk + record_stream_time!( + &self.device_manager, + next_request_streams, + events, + "prefetch_db_chunk", + self.enable_debug_timing, + { + self.codes_engine.prefetch_db_chunk( + code_db_slices, + &self.code_chunk_buffers[(db_chunk_idx + 1) % 2], + &next_chunk_size, + &chunk_size.iter().map(|s| offset + s).collect::>(), + &self.current_db_sizes, + next_request_streams, + ); + self.masks_engine.prefetch_db_chunk( + mask_db_slices, + &self.mask_chunk_buffers[(db_chunk_idx + 1) % 2], + &next_chunk_size, + &chunk_size.iter().map(|s| offset + s).collect::>(), + &self.current_db_sizes, + next_request_streams, + ); + } + ); + self.device_manager - .await_event(request_streams, ¤t_dot_event); + .await_event(request_streams, &self.dot_events[db_chunk_idx % 2]); // ---- START PHASE 1 ---- - record_stream_time!(&self.device_manager, batch_streams, events, "db_dot", { - compact_device_queries.dot_products_against_db( - &mut self.codes_engine, - &mut self.masks_engine, - code_db_slices, - mask_db_slices, - &dot_chunk_size, - offset, - request_streams, - request_cublas_handles, - ); - }); + record_stream_time!( + &self.device_manager, + batch_streams, + events, + "db_dot", + self.enable_debug_timing, + { + compact_device_queries.dot_products_against_db( + &mut self.codes_engine, + &mut self.masks_engine, + &CudaVec2DSlicerRawPointer::from( + &self.code_chunk_buffers[db_chunk_idx % 2], + ), + &CudaVec2DSlicerRawPointer::from( + &self.mask_chunk_buffers[db_chunk_idx % 2], + ), + &dot_chunk_size, + 0, + request_streams, + request_cublas_handles, + ); + } + ); // wait for the exchange result buffers to be ready self.device_manager - .await_event(request_streams, ¤t_exchange_event); + .await_event(request_streams, &self.exchange_events[db_chunk_idx % 2]); record_stream_time!( &self.device_manager, request_streams, events, "db_reduce", + self.enable_debug_timing, { compact_device_sums.compute_dot_reducer_against_db( &mut self.codes_engine, @@ -1128,13 +1302,14 @@ impl ServerActor { ); self.device_manager - .record_event(request_streams, &next_dot_event); + .record_event(request_streams, &self.dot_events[(db_chunk_idx + 1) % 2]); record_stream_time!( &self.device_manager, request_streams, events, "db_reshare", + self.enable_debug_timing, { self.codes_engine .reshare_results(&dot_chunk_size, request_streams); @@ -1146,7 +1321,7 @@ impl ServerActor { // ---- END PHASE 1 ---- self.device_manager - .await_event(request_streams, ¤t_phase2_event); + .await_event(request_streams, &self.phase2_events[db_chunk_idx % 2]); // ---- START PHASE 2 ---- let max_chunk_size = dot_chunk_size.iter().max().copied().unwrap(); @@ -1167,6 +1342,7 @@ impl ServerActor { request_streams, events, "db_threshold", + self.enable_debug_timing, { self.phase2.compare_threshold_masked_many( &code_dots, @@ -1178,40 +1354,41 @@ impl ServerActor { // we can now record the exchange event since the phase 2 is no longer using the // code_dots/mask_dots which are just reinterpretations of the exchange result // buffers - self.device_manager - .record_event(request_streams, &next_exchange_event); + self.device_manager.record_event( + request_streams, + &self.exchange_events[(db_chunk_idx + 1) % 2], + ); let res = self.phase2.take_result_buffer(); - record_stream_time!(&self.device_manager, request_streams, events, "db_open", { - open( - &mut self.phase2, - &res, - &self.distance_comparator, - db_match_bitmap, - max_chunk_size * self.max_batch_size * ROTATIONS / 64, - &dot_chunk_size, - &chunk_size, - offset, - &self.current_db_sizes, - &ignore_device_results, - request_streams, - ); - self.phase2.return_result_buffer(res); - }); + record_stream_time!( + &self.device_manager, + request_streams, + events, + "db_open", + self.enable_debug_timing, + { + open( + &mut self.phase2, + &res, + &self.distance_comparator, + db_match_bitmap, + max_chunk_size * self.max_batch_size * ROTATIONS / 64, + &dot_chunk_size, + &chunk_size, + offset, + &self.current_db_sizes, + &ignore_device_results, + request_streams, + ); + self.phase2.return_result_buffer(res); + } + ); } self.device_manager - .record_event(request_streams, &next_phase2_event); + .record_event(request_streams, &self.phase2_events[(db_chunk_idx + 1) % 2]); // ---- END PHASE 2 ---- - // Update events for synchronization - current_dot_event = next_dot_event; - current_exchange_event = next_exchange_event; - current_phase2_event = next_phase2_event; - next_dot_event = self.device_manager.create_events(); - next_exchange_event = self.device_manager.create_events(); - next_phase2_event = self.device_manager.create_events(); - // Increment chunk index db_chunk_idx += 1; @@ -1552,7 +1729,7 @@ fn write_db_at_index( ), ] { unsafe { - helpers::dtod_at_offset( + helpers::dtoh_at_offset( db.code_gr.limb_0[device_index], dst_index * code_length, *query.limb_0[device_index].device_ptr(), @@ -1561,7 +1738,7 @@ fn write_db_at_index( streams[device_index].stream, ); - helpers::dtod_at_offset( + helpers::dtoh_at_offset( db.code_gr.limb_1[device_index], dst_index * code_length, *query.limb_1[device_index].device_ptr(), diff --git a/iris-mpc-gpu/tests/e2e.rs b/iris-mpc-gpu/tests/e2e.rs index 3af422b4c..df377fa08 100644 --- a/iris-mpc-gpu/tests/e2e.rs +++ b/iris-mpc-gpu/tests/e2e.rs @@ -129,9 +129,11 @@ mod e2e_test { MAX_BATCH_SIZE, true, false, + false, ) { Ok((mut actor, handle)) => { actor.load_full_db(&(&db0.0, &db0.1), &(&db0.0, &db0.1), DB_SIZE); + actor.register_host_memory(); tx0.send(Ok(handle)).unwrap(); actor } @@ -156,9 +158,11 @@ mod e2e_test { MAX_BATCH_SIZE, true, false, + false, ) { Ok((mut actor, handle)) => { actor.load_full_db(&(&db1.0, &db1.1), &(&db1.0, &db1.1), DB_SIZE); + actor.register_host_memory(); tx1.send(Ok(handle)).unwrap(); actor } @@ -183,9 +187,11 @@ mod e2e_test { MAX_BATCH_SIZE, true, false, + false, ) { Ok((mut actor, handle)) => { actor.load_full_db(&(&db2.0, &db2.1), &(&db2.0, &db2.1), DB_SIZE); + actor.register_host_memory(); tx2.send(Ok(handle)).unwrap(); actor } diff --git a/iris-mpc-py/.gitignore b/iris-mpc-py/.gitignore new file mode 100644 index 000000000..c8f044299 --- /dev/null +++ b/iris-mpc-py/.gitignore @@ -0,0 +1,72 @@ +/target + +# Byte-compiled / optimized / DLL files +__pycache__/ +.pytest_cache/ +*.py[cod] + +# C extensions +*.so + +# Distribution / packaging +.Python +.venv/ +env/ +bin/ +build/ +develop-eggs/ +dist/ +eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +include/ +man/ +venv/ +*.egg-info/ +.installed.cfg +*.egg + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt +pip-selfcheck.json + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.cache +nosetests.xml +coverage.xml + +# Translations +*.mo + +# Mr Developer +.mr.developer.cfg +.project +.pydevproject + +# Rope +.ropeproject + +# Django stuff: +*.log +*.pot + +.DS_Store + +# Sphinx documentation +docs/_build/ + +# PyCharm +.idea/ + +# VSCode +.vscode/ + +# Pyenv +.python-version diff --git a/iris-mpc-py/Cargo.toml b/iris-mpc-py/Cargo.toml new file mode 100644 index 000000000..d3d325935 --- /dev/null +++ b/iris-mpc-py/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "iris-mpc-py" +version = "0.1.0" +publish = false + +edition.workspace = true +license.workspace = true +repository.workspace = true + +[lib] +name = "iris_mpc_py" +crate-type = ["cdylib"] + +[dependencies] +iris-mpc-common = { path = "../iris-mpc-common" } +iris-mpc-cpu = { path = "../iris-mpc-cpu" } +hawk-pack.workspace = true +pyo3 = { version = "0.22.0", features = ["extension-module"] } +rand.workspace = true diff --git a/iris-mpc-py/README.md b/iris-mpc-py/README.md new file mode 100644 index 000000000..aea78cecb --- /dev/null +++ b/iris-mpc-py/README.md @@ -0,0 +1,93 @@ +# Python Bindings + +This package provides Python bindings for some functionalities in the `iris-mpc` workspace, currently focused on execution of the HNSW k-nearest neighbors graph search algorithm over plaintext iris codes for testing and data analysis. For compatibility, compilation of this crate is disabled from the workspace root, but enabled from within the crate subdirectory via the Cargo default feature flag `enable`. + +## Installation + +Installation of Python bindings from the PyO3 library code can be accomplished using the Maturin Python package as follows: + +- Install Maturin in the target Python environment, e.g. the venv used for data analysis, using `pip install maturin` + +- Optionally install `patchelf` library with `pip install patchelf` for support for patching wheel files that link other shared libraries + +- Build and install current bindings as a module in the current Python environment by navigating to the `iris-mpc-py` directory and running `maturin develop --release` + +- Build a wheel file suitable for installation using `pip install` by instead running `maturin build --release`; the `.whl` file is specific to the building architecture and Python version, and can be found in `iris_mpc/target/wheels` directory + +See the [Maturin User Guide Tutorial](https://www.maturin.rs/tutorial#build-and-install-the-module-with-maturin-develop) for additional details. + +## Usage + +Once successfully installed, the native rust module `iris_mpc_py` can be imported in your Python environment as usual with `import iris_mpc_py`. Example usage: + +```python +from iris_mpc_py import PyHawkSearcher, PyPlaintextStore, PyGraphStore, PyIrisCode + +hnsw = PyHawkSearcher(32, 64, 32) # M, ef_constr, ef_search +vector = PyPlaintextStore() +graph = PyGraphStore() + +hnsw.fill_uniform_random(1000, vector, graph) + +iris = PyIrisCode.uniform_random() +iris_id = hnsw.insert(iris, vector, graph) +print("Inserted iris id:", iris_id) + +nearest_id, nearest_dist = hnsw.search(iris, vector, graph) +print("Nearest iris id:", nearest_id) # should be iris_id +print("Nearest iris distance:", nearest_dist) # should be 0.0 +``` + +To write the HNSW vector and graph indices to file and read them back: + +```python +hnsw.write_to_json("searcher.json") +vector.write_to_ndjson("vector.ndjson") +graph.write_to_bin("graph.dat") + +hnsw2 = PyHawkSearcher.read_from_json("searcher.json") +vector2 = PyPlaintextStore.read_from_ndjson("vector.ndjson") +graph2 = PyGraphStore.read_from_bin("graph.dat") +``` + +As an efficiency feature, the data from the vector store is read in a streamed fashion. This means that for a large database of iris codes, the first `num` can be read from file without loading the entire database into memory. This can be used in two ways; first, a vector store can be initialized from the large databse file for use with a previously generated HNSW index: + +```python +# Serialized HNSW graph constructed from the first 10k entries of database file +vector = PyPlaintextStore.read_from_ndjson("large_vector_database.ndjson", 10000) +graph = PyGraphStore.read_from_bin("graph.dat") +``` + +Second, to construct an HNSW index dynamically from streamed database entries: + +```python +hnsw = PyHawkSearcher(32, 64, 32) +vector = PyPlaintextStore() +graph = PyGraphStore() +hnsw.fill_from_ndjson_file("large_vector_database.ndjson", vector, graph, 10000) +``` + +To generate a vector database directly for use in this way: + +```python +# Generate 100k uniform random iris codes +vector_init = PyPlaintextStore() +for i in range(1,100000): + vector_init.insert(PyIrisCode.uniform_random()) +vector_init.write_to_ndjson("vector.ndjson") +``` + +Basic interoperability with Open IRIS iris templates is provided by way of a common base64 encoding scheme, provided by the `iris.io.dataclasses.IrisTemplate` methods `serialize` and `deserialize`. These methods use a base64 encoding of iris code and mask code arrays represented as a Python `dict` with base64-encoded fields `iris_codes`, `mask_codes`, and a version string `iris_code_version` to check for compatibility. The `PyIrisCode` class interacts with this representation as follows: + +```python +serialized_iris_code = { + "iris_codes": "...", + "mask_codes": "...", + "iris_code_version": "1.0", +} + +iris = PyIrisCode.from_open_iris_template_dict(serialized_iris_code) +reserialized_iris_code = iris.to_open_iris_template_dict("1.0") +``` + +Note that the `to_open_iris_template_dict` method takes an optional argument which fills the `iris_code_version` field of the resulting Python `dict` since the `PyIrisCode` object does not preserve this data. diff --git a/iris-mpc-py/examples-py/test_integration.py b/iris-mpc-py/examples-py/test_integration.py new file mode 100644 index 000000000..d22bad8ee --- /dev/null +++ b/iris-mpc-py/examples-py/test_integration.py @@ -0,0 +1,37 @@ +from iris_mpc_py import PyIrisCode, PyPlaintextStore, PyGraphStore, PyHawkSearcher + +print("Generating 100k uniform random iris codes...") +vector_init = PyPlaintextStore() +iris0 = PyIrisCode.uniform_random() +iris_id = vector_init.insert(iris0) +for i in range(1,100000): + vector_init.insert(PyIrisCode.uniform_random()) + +# write vector store to file +print("Writing vector store to file...") +vector_init.write_to_ndjson("vector.ndjson") + +print("Generating HNSW graphs for 10k imported iris codes...") +hnsw = PyHawkSearcher(32, 64, 32) +vector1 = PyPlaintextStore() +graph1 = PyGraphStore() +hnsw.fill_from_ndjson_file("vector.ndjson", vector1, graph1, 10000) + +print("Imported length:", vector1.len()) + +retrieved_iris = vector1.get(iris_id) +print("Retrieved iris0 base64 == original iris0 base64:", iris0.code.to_base64() == retrieved_iris.code.to_base64() and iris0.mask.to_base64() == retrieved_iris.mask.to_base64()) + +query = PyIrisCode.uniform_random() +print("Search for random query iris code:", hnsw.search(query, vector1, graph1)) + +# write graph store to file +print("Writing graph store to file...") +graph1.write_to_bin("graph1.dat") + +# read HNSW graphs from disk +print("Reading vector and graph stores from file...") +vector2 = PyPlaintextStore.read_from_ndjson("vector.ndjson", 10000) +graph2 = PyGraphStore.read_from_bin("graph1.dat") + +print("Search for random query iris code:", hnsw.search(query, vector2, graph2)) diff --git a/iris-mpc-py/pyproject.toml b/iris-mpc-py/pyproject.toml new file mode 100644 index 000000000..8b731d0c3 --- /dev/null +++ b/iris-mpc-py/pyproject.toml @@ -0,0 +1,16 @@ +[build-system] +requires = ["maturin>=1.7,<2.0"] +build-backend = "maturin" + +[project] +name = "iris-mpc-py" +requires-python = ">=3.8" +classifiers = [ + "Programming Language :: Rust", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", +] +dynamic = ["version"] +[tool.maturin] +features = ["pyo3/extension-module"] +module-name = "iris_mpc_py" \ No newline at end of file diff --git a/iris-mpc-py/src/lib.rs b/iris-mpc-py/src/lib.rs new file mode 100644 index 000000000..d8301516c --- /dev/null +++ b/iris-mpc-py/src/lib.rs @@ -0,0 +1 @@ +pub mod py_hnsw; diff --git a/iris-mpc-py/src/py_hnsw/mod.rs b/iris-mpc-py/src/py_hnsw/mod.rs new file mode 100644 index 000000000..d5fe0536c --- /dev/null +++ b/iris-mpc-py/src/py_hnsw/mod.rs @@ -0,0 +1,2 @@ +pub mod pyclasses; +pub mod pymodule; diff --git a/iris-mpc-py/src/py_hnsw/pyclasses/graph_store.rs b/iris-mpc-py/src/py_hnsw/pyclasses/graph_store.rs new file mode 100644 index 000000000..fc6768f3d --- /dev/null +++ b/iris-mpc-py/src/py_hnsw/pyclasses/graph_store.rs @@ -0,0 +1,27 @@ +use hawk_pack::graph_store::GraphMem; +use iris_mpc_cpu::{hawkers::plaintext_store::PlaintextStore, py_bindings}; +use pyo3::{exceptions::PyIOError, prelude::*}; + +#[pyclass] +#[derive(Clone, Default)] +pub struct PyGraphStore(pub GraphMem); + +#[pymethods] +impl PyGraphStore { + #[new] + pub fn new() -> Self { + Self::default() + } + + #[staticmethod] + pub fn read_from_bin(filename: String) -> PyResult { + let result = py_bindings::io::read_bin(&filename) + .map_err(|_| PyIOError::new_err("Unable to read from file"))?; + Ok(Self(result)) + } + + pub fn write_to_bin(&self, filename: String) -> PyResult<()> { + py_bindings::io::write_bin(&self.0, &filename) + .map_err(|_| PyIOError::new_err("Unable to write to file")) + } +} diff --git a/iris-mpc-py/src/py_hnsw/pyclasses/hawk_searcher.rs b/iris-mpc-py/src/py_hnsw/pyclasses/hawk_searcher.rs new file mode 100644 index 000000000..1d154346a --- /dev/null +++ b/iris-mpc-py/src/py_hnsw/pyclasses/hawk_searcher.rs @@ -0,0 +1,124 @@ +use super::{graph_store::PyGraphStore, iris_code::PyIrisCode, plaintext_store::PyPlaintextStore}; +use hawk_pack::{ + hawk_searcher::{HawkerParams, N_PARAM_LAYERS}, + HawkSearcher, +}; +use iris_mpc_cpu::py_bindings; +use pyo3::{exceptions::PyIOError, prelude::*}; + +#[pyclass] +#[derive(Clone, Default)] +pub struct PyHawkSearcher(pub HawkSearcher); + +#[pymethods] +#[allow(non_snake_case)] +impl PyHawkSearcher { + #[new] + pub fn new(M: usize, ef_constr: usize, ef_search: usize) -> Self { + Self::new_standard(ef_constr, ef_search, M) + } + + #[staticmethod] + pub fn new_standard(M: usize, ef_constr: usize, ef_search: usize) -> Self { + let params = HawkerParams::new(ef_constr, ef_search, M); + Self(HawkSearcher { params }) + } + + #[staticmethod] + pub fn new_uniform(M: usize, ef: usize) -> Self { + let params = HawkerParams::new_uniform(ef, M); + Self(HawkSearcher { params }) + } + + /// Construct `HawkSearcher` with fully general parameters, specifying the + /// values of various parameters used during construction and search at + /// different levels of the graph hierarchy. + #[staticmethod] + pub fn new_general( + M: [usize; N_PARAM_LAYERS], + M_max: [usize; N_PARAM_LAYERS], + ef_constr_search: [usize; N_PARAM_LAYERS], + ef_constr_insert: [usize; N_PARAM_LAYERS], + ef_search: [usize; N_PARAM_LAYERS], + layer_probability: f64, + ) -> Self { + let params = HawkerParams { + M, + M_max, + ef_constr_search, + ef_constr_insert, + ef_search, + layer_probability, + }; + Self(HawkSearcher { params }) + } + + pub fn insert( + &self, + iris: PyIrisCode, + vector: &mut PyPlaintextStore, + graph: &mut PyGraphStore, + ) -> u32 { + let id = py_bindings::hnsw::insert(iris.0, &self.0, &mut vector.0, &mut graph.0); + id.0 + } + + pub fn insert_uniform_random( + &self, + vector: &mut PyPlaintextStore, + graph: &mut PyGraphStore, + ) -> u32 { + let id = py_bindings::hnsw::insert_uniform_random(&self.0, &mut vector.0, &mut graph.0); + id.0 + } + + pub fn fill_uniform_random( + &self, + num: usize, + vector: &mut PyPlaintextStore, + graph: &mut PyGraphStore, + ) { + py_bindings::hnsw::fill_uniform_random(num, &self.0, &mut vector.0, &mut graph.0); + } + + #[pyo3(signature = (filename, vector, graph, limit=None))] + pub fn fill_from_ndjson_file( + &self, + filename: String, + vector: &mut PyPlaintextStore, + graph: &mut PyGraphStore, + limit: Option, + ) { + py_bindings::hnsw::fill_from_ndjson_file( + &filename, + limit, + &self.0, + &mut vector.0, + &mut graph.0, + ); + } + + /// Search HNSW index and return nearest ID and its distance from query + pub fn search( + &mut self, + query: &PyIrisCode, + vector: &mut PyPlaintextStore, + graph: &mut PyGraphStore, + ) -> (u32, f64) { + let (id, dist) = + py_bindings::hnsw::search(query.0.clone(), &self.0, &mut vector.0, &mut graph.0); + (id.0, dist) + } + + #[staticmethod] + pub fn read_from_json(filename: String) -> PyResult { + let result = py_bindings::io::read_json(&filename) + .map_err(|_| PyIOError::new_err("Unable to read from file"))?; + Ok(Self(result)) + } + + pub fn write_to_json(&self, filename: String) -> PyResult<()> { + py_bindings::io::write_json(&self.0, &filename) + .map_err(|_| PyIOError::new_err("Unable to write to file")) + } +} diff --git a/iris-mpc-py/src/py_hnsw/pyclasses/iris_code.rs b/iris-mpc-py/src/py_hnsw/pyclasses/iris_code.rs new file mode 100644 index 000000000..c004344ee --- /dev/null +++ b/iris-mpc-py/src/py_hnsw/pyclasses/iris_code.rs @@ -0,0 +1,73 @@ +use super::iris_code_array::PyIrisCodeArray; +use iris_mpc_common::iris_db::iris::IrisCode; +use pyo3::{prelude::*, types::PyDict}; +use rand::rngs::ThreadRng; + +#[pyclass] +#[derive(Clone, Default)] +pub struct PyIrisCode(pub IrisCode); + +#[pymethods] +impl PyIrisCode { + #[new] + pub fn new(code: &PyIrisCodeArray, mask: &PyIrisCodeArray) -> Self { + Self(IrisCode { + code: code.0, + mask: mask.0, + }) + } + + #[getter] + pub fn code(&self) -> PyIrisCodeArray { + PyIrisCodeArray(self.0.code) + } + + #[getter] + pub fn mask(&self) -> PyIrisCodeArray { + PyIrisCodeArray(self.0.mask) + } + + #[staticmethod] + pub fn uniform_random() -> Self { + let mut rng = ThreadRng::default(); + Self(IrisCode::random_rng(&mut rng)) + } + + #[pyo3(signature = (version=None))] + pub fn to_open_iris_template_dict<'py>( + &self, + py: Python<'py>, + version: Option, + ) -> PyResult> { + let dict = PyDict::new_bound(py); + + dict.set_item("iris_codes", self.0.code.to_base64().unwrap())?; + dict.set_item("mask_codes", self.0.mask.to_base64().unwrap())?; + dict.set_item("iris_code_version", version)?; + + Ok(dict) + } + + #[staticmethod] + pub fn from_open_iris_template_dict(dict_obj: &Bound) -> PyResult { + // Extract base64-encoded iris code arrays + let iris_codes_str: String = dict_obj.get_item("iris_codes")?.unwrap().extract()?; + let mask_codes_str: String = dict_obj.get_item("mask_codes")?.unwrap().extract()?; + + // Convert the base64 strings into PyIrisCodeArrays + let code = PyIrisCodeArray::from_base64(iris_codes_str); + let mask = PyIrisCodeArray::from_base64(mask_codes_str); + + // Construct and return PyIrisCode + Ok(Self(IrisCode { + code: code.0, + mask: mask.0, + })) + } +} + +impl From for PyIrisCode { + fn from(value: IrisCode) -> Self { + Self(value) + } +} diff --git a/iris-mpc-py/src/py_hnsw/pyclasses/iris_code_array.rs b/iris-mpc-py/src/py_hnsw/pyclasses/iris_code_array.rs new file mode 100644 index 000000000..7d12fe3e7 --- /dev/null +++ b/iris-mpc-py/src/py_hnsw/pyclasses/iris_code_array.rs @@ -0,0 +1,46 @@ +use iris_mpc_common::iris_db::iris::IrisCodeArray; +use pyo3::prelude::*; +use rand::rngs::ThreadRng; + +#[pyclass] +#[derive(Clone, Default)] +pub struct PyIrisCodeArray(pub IrisCodeArray); + +#[pymethods] +impl PyIrisCodeArray { + #[new] + pub fn new(input: String) -> Self { + Self::from_base64(input) + } + + pub fn to_base64(&self) -> String { + self.0.to_base64().unwrap() + } + + #[staticmethod] + pub fn from_base64(input: String) -> Self { + Self(IrisCodeArray::from_base64(&input).unwrap()) + } + + #[staticmethod] + pub fn zeros() -> Self { + Self(IrisCodeArray::ZERO) + } + + #[staticmethod] + pub fn ones() -> Self { + Self(IrisCodeArray::ONES) + } + + #[staticmethod] + pub fn uniform_random() -> Self { + let mut rng = ThreadRng::default(); + Self(IrisCodeArray::random_rng(&mut rng)) + } +} + +impl From for PyIrisCodeArray { + fn from(value: IrisCodeArray) -> Self { + Self(value) + } +} diff --git a/iris-mpc-py/src/py_hnsw/pyclasses/mod.rs b/iris-mpc-py/src/py_hnsw/pyclasses/mod.rs new file mode 100644 index 000000000..eea66d959 --- /dev/null +++ b/iris-mpc-py/src/py_hnsw/pyclasses/mod.rs @@ -0,0 +1,5 @@ +pub mod graph_store; +pub mod hawk_searcher; +pub mod iris_code; +pub mod iris_code_array; +pub mod plaintext_store; diff --git a/iris-mpc-py/src/py_hnsw/pyclasses/plaintext_store.rs b/iris-mpc-py/src/py_hnsw/pyclasses/plaintext_store.rs new file mode 100644 index 000000000..f1d3fed19 --- /dev/null +++ b/iris-mpc-py/src/py_hnsw/pyclasses/plaintext_store.rs @@ -0,0 +1,52 @@ +use super::iris_code::PyIrisCode; +use iris_mpc_cpu::{ + hawkers::plaintext_store::{PlaintextIris, PlaintextPoint, PlaintextStore}, + py_bindings, +}; +use pyo3::{exceptions::PyIOError, prelude::*}; + +#[pyclass] +#[derive(Clone, Default)] +pub struct PyPlaintextStore(pub PlaintextStore); + +#[pymethods] +impl PyPlaintextStore { + #[new] + pub fn new() -> Self { + Self::default() + } + + pub fn get(&self, id: u32) -> PyIrisCode { + self.0.points[id as usize].data.0.clone().into() + } + + pub fn insert(&mut self, iris: PyIrisCode) -> u32 { + let new_id = self.0.points.len() as u32; + self.0.points.push(PlaintextPoint { + data: PlaintextIris(iris.0), + is_persistent: true, + }); + new_id + } + + pub fn len(&self) -> usize { + self.0.points.len() + } + + pub fn is_empty(&self) -> bool { + self.0.points.is_empty() + } + + #[staticmethod] + #[pyo3(signature = (filename, len=None))] + pub fn read_from_ndjson(filename: String, len: Option) -> PyResult { + let result = py_bindings::plaintext_store::from_ndjson_file(&filename, len) + .map_err(|_| PyIOError::new_err("Unable to read from file"))?; + Ok(Self(result)) + } + + pub fn write_to_ndjson(&self, filename: String) -> PyResult<()> { + py_bindings::plaintext_store::to_ndjson_file(&self.0, &filename) + .map_err(|_| PyIOError::new_err("Unable to write to file")) + } +} diff --git a/iris-mpc-py/src/py_hnsw/pymodule.rs b/iris-mpc-py/src/py_hnsw/pymodule.rs new file mode 100644 index 000000000..b0ceae8e3 --- /dev/null +++ b/iris-mpc-py/src/py_hnsw/pymodule.rs @@ -0,0 +1,15 @@ +use super::pyclasses::{ + graph_store::PyGraphStore, hawk_searcher::PyHawkSearcher, iris_code::PyIrisCode, + iris_code_array::PyIrisCodeArray, plaintext_store::PyPlaintextStore, +}; +use pyo3::prelude::*; + +#[pymodule] +fn iris_mpc_py(m: &Bound<'_, PyModule>) -> PyResult<()> { + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + Ok(()) +} diff --git a/iris-mpc-store/Cargo.toml b/iris-mpc-store/Cargo.toml index 1a084d010..5dec19b00 100644 --- a/iris-mpc-store/Cargo.toml +++ b/iris-mpc-store/Cargo.toml @@ -8,16 +8,23 @@ license.workspace = true repository.workspace = true [dependencies] +aws-sdk-s3.workspace = true +bytes.workspace = true +async-trait.workspace = true iris-mpc-common = { path = "../iris-mpc-common" } bytemuck.workspace = true +csv.workspace = true futures.workspace = true sqlx.workspace = true eyre.workspace = true +hex.workspace = true itertools.workspace = true serde.workspace = true serde_json.workspace = true tracing.workspace = true +tokio.workspace = true rand.workspace = true +rayon.workspace = true [dev-dependencies] rand.workspace = true diff --git a/iris-mpc-store/migrations/20241121084719_remove_sequence.sql b/iris-mpc-store/migrations/20241121084719_remove_sequence.sql new file mode 100644 index 000000000..b78151277 --- /dev/null +++ b/iris-mpc-store/migrations/20241121084719_remove_sequence.sql @@ -0,0 +1 @@ +ALTER TABLE irises ALTER COLUMN id DROP IDENTITY IF EXISTS; \ No newline at end of file diff --git a/iris-mpc-store/migrations/20241206150412_add-modified-at.down.sql b/iris-mpc-store/migrations/20241206150412_add-modified-at.down.sql new file mode 100644 index 000000000..f33c3b008 --- /dev/null +++ b/iris-mpc-store/migrations/20241206150412_add-modified-at.down.sql @@ -0,0 +1,3 @@ +DROP TRIGGER IF EXISTS set_last_modified_at ON irises; +DROP FUNCTION IF EXISTS update_last_modified_at(); +ALTER TABLE irises DROP COLUMN last_modified_at; diff --git a/iris-mpc-store/migrations/20241206150412_add-modified-at.up.sql b/iris-mpc-store/migrations/20241206150412_add-modified-at.up.sql new file mode 100644 index 000000000..2d713a025 --- /dev/null +++ b/iris-mpc-store/migrations/20241206150412_add-modified-at.up.sql @@ -0,0 +1,14 @@ +ALTER TABLE irises ADD COLUMN last_modified_at BIGINT; + +CREATE OR REPLACE FUNCTION update_last_modified_at() +RETURNS TRIGGER AS $$ +BEGIN + NEW.last_modified_at = EXTRACT(EPOCH FROM NOW())::BIGINT; + RETURN NEW; +END; +$$ LANGUAGE plpgsql; + +CREATE TRIGGER set_last_modified_at + BEFORE INSERT OR UPDATE ON irises + FOR EACH ROW + EXECUTE FUNCTION update_last_modified_at(); diff --git a/iris-mpc-store/src/lib.rs b/iris-mpc-store/src/lib.rs index 9a0fdd883..b1c13e923 100644 --- a/iris-mpc-store/src/lib.rs +++ b/iris-mpc-store/src/lib.rs @@ -1,15 +1,21 @@ +#![feature(int_roundings)] + +mod s3_importer; + use bytemuck::cast_slice; use eyre::{eyre, Result}; use futures::{ stream::{self}, - Stream, + Stream, TryStreamExt, }; use iris_mpc_common::{ config::Config, galois_engine::degree4::{GaloisRingIrisCodeShare, GaloisRingTrimmedMaskCodeShare}, iris_db::iris::IrisCode, + IRIS_CODE_LENGTH, MASK_CODE_LENGTH, }; use rand::{rngs::StdRng, Rng, SeedableRng}; +pub use s3_importer::{fetch_and_parse_chunks, last_snapshot_timestamp, ObjectStore, S3Store}; use sqlx::{ migrate::Migrator, postgres::PgPoolOptions, Executor, PgPool, Postgres, Row, Transaction, }; @@ -31,6 +37,12 @@ fn sql_switch_schema(schema_name: &str) -> Result { )) } +// Enum to define the source of the irises +pub enum IrisSource { + S3(StoredIris), + DB(StoredIris), +} + #[derive(sqlx::FromRow, Debug, Default, PartialEq, Eq)] pub struct StoredIris { #[allow(dead_code)] @@ -65,10 +77,48 @@ impl StoredIris { pub fn id(&self) -> i64 { self.id } + + pub fn from_bytes(bytes: &[u8]) -> Result { + let mut cursor = 0; + + // Helper closure to extract a slice of a given size + let extract_slice = + |bytes: &[u8], cursor: &mut usize, size: usize| -> Result, eyre::Error> { + if *cursor + size > bytes.len() { + return Err(eyre!("Exceeded total bytes while extracting slice",)); + } + let slice = &bytes[*cursor..*cursor + size]; + *cursor += size; + Ok(slice.to_vec()) + }; + + // Parse `id` (i64) + let id_bytes = extract_slice(bytes, &mut cursor, 4)?; + let id = u32::from_be_bytes( + id_bytes + .try_into() + .map_err(|_| eyre!("Failed to convert id bytes to i64"))?, + ) as i64; + + // parse codes and masks + let left_code = extract_slice(bytes, &mut cursor, IRIS_CODE_LENGTH * size_of::())?; + let left_mask = extract_slice(bytes, &mut cursor, MASK_CODE_LENGTH * size_of::())?; + let right_code = extract_slice(bytes, &mut cursor, IRIS_CODE_LENGTH * size_of::())?; + let right_mask = extract_slice(bytes, &mut cursor, MASK_CODE_LENGTH * size_of::())?; + + Ok(StoredIris { + id, + left_code, + left_mask, + right_code, + right_mask, + }) + } } #[derive(Clone)] pub struct StoredIrisRef<'a> { + pub id: i64, pub left_code: &'a [u16], pub left_mask: &'a [u16], pub right_code: &'a [u16], @@ -157,8 +207,9 @@ impl Store { /// Stream irises in parallel, without a particular order. pub async fn stream_irises_par( &self, + min_last_modified_at: Option, partitions: usize, - ) -> impl Stream> + '_ { + ) -> impl Stream> + '_ { let count = self.count_irises().await.expect("Failed count_irises"); let partition_size = count.div_ceil(partitions).max(1); @@ -168,14 +219,28 @@ impl Store { let start_id = 1 + partition_size * i; let end_id = start_id + partition_size - 1; - let partition_stream = - sqlx::query_as::<_, StoredIris>("SELECT * FROM irises WHERE id BETWEEN $1 AND $2") - .bind(start_id as i64) - .bind(end_id as i64) - .fetch(&self.pool); + let partition_stream = match min_last_modified_at { + Some(min_last_modified_at) => sqlx::query_as::<_, StoredIris>( + "SELECT id, left_code, left_mask, right_code, right_mask FROM irises WHERE id \ + BETWEEN $1 AND $2 AND last_modified_at >= $3", + ) + .bind(start_id as i64) + .bind(end_id as i64) + .bind(min_last_modified_at) + .fetch(&self.pool) + .map_err(Into::into), + None => sqlx::query_as::<_, StoredIris>( + "SELECT id, left_code, left_mask, right_code, right_mask FROM irises WHERE id \ + BETWEEN $1 AND $2", + ) + .bind(start_id as i64) + .bind(end_id as i64) + .fetch(&self.pool) + .map_err(Into::into), + }; partition_streams.push(Box::pin(partition_stream) - as Pin> + Send>>); + as Pin> + Send>>); } stream::select_all(partition_streams) @@ -190,9 +255,10 @@ impl Store { return Ok(vec![]); } let mut query = sqlx::QueryBuilder::new( - "INSERT INTO irises (left_code, left_mask, right_code, right_mask)", + "INSERT INTO irises (id, left_code, left_mask, right_code, right_mask)", ); query.push_values(codes_and_masks, |mut query, iris| { + query.push_bind(iris.id); query.push_bind(cast_slice::(iris.left_code)); query.push_bind(cast_slice::(iris.left_mask)); query.push_bind(cast_slice::(iris.right_code)); @@ -212,6 +278,36 @@ impl Store { Ok(ids) } + pub async fn insert_irises_overriding( + &self, + tx: &mut Transaction<'_, Postgres>, + codes_and_masks: &[StoredIrisRef<'_>], + ) -> Result<()> { + if codes_and_masks.is_empty() { + return Ok(()); + } + let mut query = sqlx::QueryBuilder::new( + "INSERT INTO irises (id, left_code, left_mask, right_code, right_mask)", + ); + query.push_values(codes_and_masks, |mut query, iris| { + query.push_bind(iris.id); + query.push_bind(cast_slice::(iris.left_code)); + query.push_bind(cast_slice::(iris.left_mask)); + query.push_bind(cast_slice::(iris.right_code)); + query.push_bind(cast_slice::(iris.right_mask)); + }); + query.push( + r#" +ON CONFLICT (id) +DO UPDATE SET left_code = EXCLUDED.left_code, left_mask = EXCLUDED.left_mask, right_code = EXCLUDED.right_code, right_mask = EXCLUDED.right_mask; +"#, + ); + + query.build().execute(tx.deref_mut()).await?; + + Ok(()) + } + /// Update existing iris with given shares. pub async fn update_iris( &self, @@ -290,27 +386,6 @@ DO UPDATE SET right_code = EXCLUDED.right_code, right_mask = EXCLUDED.right_mask Ok(()) } - async fn set_sequence_id( - &self, - id: usize, - executor: impl sqlx::Executor<'_, Database = Postgres>, - ) -> Result<()> { - if id == 0 { - // If requested id is 0 (only used in tests), reset the sequence to 1 with - // advance_nextval set to false. This is because serial id starts from 1. - sqlx::query("SELECT setval(pg_get_serial_sequence('irises', 'id'), 1, false)") - .execute(executor) - .await?; - } else { - sqlx::query("SELECT setval(pg_get_serial_sequence('irises', 'id'), $1, true)") - .bind(id as i64) - .execute(executor) - .await?; - } - - Ok(()) - } - pub async fn rollback(&self, db_len: usize) -> Result<()> { let mut tx = self.pool.begin().await?; @@ -319,18 +394,12 @@ DO UPDATE SET right_code = EXCLUDED.right_code, right_mask = EXCLUDED.right_mask .execute(&mut *tx) .await?; - self.set_sequence_id(db_len, &mut *tx).await?; - tx.commit().await?; Ok(()) } - pub async fn set_irises_sequence_id(&self, id: usize) -> Result<()> { - self.set_sequence_id(id, &self.pool).await - } - - pub async fn get_irises_sequence_id(&self) -> Result { - let id: (i64,) = sqlx::query_as("SELECT last_value FROM irises_id_seq") + pub async fn get_max_serial_id(&self) -> Result { + let id: (i64,) = sqlx::query_as("SELECT MAX(id) FROM irises") .fetch_one(&self.pool) .await?; Ok(id.0 as usize) @@ -353,16 +422,6 @@ DO UPDATE SET right_code = EXCLUDED.right_code, right_mask = EXCLUDED.right_mask Ok(()) } - pub async fn update_iris_id_sequence(&self) -> Result<()> { - sqlx::query( - "SELECT setval(pg_get_serial_sequence('irises', 'id'), COALESCE(MAX(id), 0), true) \ - FROM irises", - ) - .execute(&self.pool) - .await?; - Ok(()) - } - pub async fn last_results(&self, count: usize) -> Result> { let mut result_events: Vec = sqlx::query_scalar("SELECT result_event FROM results ORDER BY id DESC LIMIT $1") @@ -442,6 +501,7 @@ DO UPDATE SET right_code = EXCLUDED.right_code, right_mask = EXCLUDED.right_mask // inserting shares and masks in the db. Reusing the same share and mask for // left and right self.insert_irises(&mut tx, &[StoredIrisRef { + id: (i + 1) as i64, left_code: &share.coefs, left_mask: &mask.coefs, right_code: &share.coefs, @@ -489,7 +549,7 @@ mod tests { use super::*; use futures::TryStreamExt; - use iris_mpc_common::helpers::smpc_request::UniquenessResult; + use iris_mpc_common::helpers::smpc_response::UniquenessResult; #[tokio::test] async fn test_store() -> Result<()> { @@ -500,23 +560,30 @@ mod tests { let got: Vec = store.stream_irises().await.try_collect().await?; assert_eq!(got.len(), 0); - let got: Vec = store.stream_irises_par(2).await.try_collect().await?; + let got: Vec = store + .stream_irises_par(Some(0), 2) + .await + .try_collect() + .await?; assert_eq!(got.len(), 0); let codes_and_masks = &[ StoredIrisRef { + id: 1, left_code: &[1, 2, 3, 4], left_mask: &[5, 6, 7, 8], right_code: &[9, 10, 11, 12], right_mask: &[13, 14, 15, 16], }, StoredIrisRef { + id: 2, left_code: &[1117, 18, 19, 20], left_mask: &[21, 1122, 23, 24], right_code: &[25, 26, 1127, 28], right_mask: &[29, 30, 31, 1132], }, StoredIrisRef { + id: 3, left_code: &[17, 18, 19, 20], left_mask: &[21, 22, 23, 24], // Empty is allowed until stereo is implemented. @@ -532,7 +599,11 @@ mod tests { let got_len = store.count_irises().await?; let got: Vec = store.stream_irises().await.try_collect().await?; - let mut got_par: Vec = store.stream_irises_par(2).await.try_collect().await?; + let mut got_par: Vec = store + .stream_irises_par(Some(0), 2) + .await + .try_collect() + .await?; got_par.sort_by_key(|iris| iris.id); assert_eq!(got, got_par); @@ -568,18 +639,23 @@ mod tests { #[tokio::test] async fn test_insert_many() -> Result<()> { - let count = 1 << 3; + let count: usize = 1 << 3; let schema_name = temporary_name(); let store = Store::new(&test_db_url()?, &schema_name).await?; - let iris = StoredIrisRef { - left_code: &[123_u16; 12800], - left_mask: &[456_u16; 12800], - right_code: &[789_u16; 12800], - right_mask: &[101_u16; 12800], - }; - let codes_and_masks = vec![iris; count]; + let mut codes_and_masks = vec![]; + + for i in 0..count { + let iris = StoredIrisRef { + id: (i + 1) as i64, + left_code: &[123_u16; 12800], + left_mask: &[456_u16; 12800], + right_code: &[789_u16; 12800], + right_mask: &[101_u16; 12800], + }; + codes_and_masks.push(iris); + } let result_event = serde_json::to_string(&UniquenessResult::new( 0, @@ -605,7 +681,7 @@ mod tests { // Compare with the parallel version with several edge-cases. for parallelism in [1, 5, MAX_CONNECTIONS as usize + 1] { let mut got_par: Vec = store - .stream_irises_par(parallelism) + .stream_irises_par(Some(0), parallelism) .await .try_collect() .await?; @@ -641,15 +717,20 @@ mod tests { let schema_name = temporary_name(); let store = Store::new(&test_db_url()?, &schema_name).await?; - let iris = StoredIrisRef { - left_code: &[123_u16; 12800], - left_mask: &[456_u16; 12800], - right_code: &[789_u16; 12800], - right_mask: &[101_u16; 12800], - }; + let mut irises = vec![]; + for i in 0..10 { + let iris = StoredIrisRef { + id: (i + 1) as i64, + left_code: &[123_u16; 12800], + left_mask: &[456_u16; 12800], + right_code: &[789_u16; 12800], + right_mask: &[101_u16; 12800], + }; + irises.push(iris); + } let mut tx = store.tx().await?; - store.insert_irises(&mut tx, &vec![iris; 10]).await?; + store.insert_irises(&mut tx, &irises).await?; tx.commit().await?; store.rollback(5).await?; @@ -779,31 +860,37 @@ mod tests { let store = Store::new(&test_db_url()?, &schema_name).await?; // insert two irises into db - let iris = StoredIrisRef { + let iris1 = StoredIrisRef { + id: 1, left_code: &[123_u16; 12800], left_mask: &[456_u16; 6400], right_code: &[789_u16; 12800], right_mask: &[101_u16; 6400], }; + let mut iris2 = iris1.clone(); + iris2.id = 2; + let mut tx = store.tx().await?; - store.insert_irises(&mut tx, &vec![iris.clone(); 2]).await?; + store + .insert_irises(&mut tx, &[iris1, iris2.clone()]) + .await?; tx.commit().await?; // update iris with id 1 in db let updated_left_code = GaloisRingIrisCodeShare { - id: 0, + id: 1, coefs: [666_u16; 12800], }; let updated_left_mask = GaloisRingTrimmedMaskCodeShare { - id: 0, + id: 1, coefs: [777_u16; 6400], }; let updated_right_code = GaloisRingIrisCodeShare { - id: 0, + id: 1, coefs: [888_u16; 12800], }; let updated_right_mask = GaloisRingTrimmedMaskCodeShare { - id: 0, + id: 1, coefs: [999_u16; 6400], }; store @@ -825,10 +912,10 @@ mod tests { assert_eq!(cast_u8_to_u16(&got[0].right_mask), updated_right_mask.coefs); // assert the other iris in db is not updated - assert_eq!(cast_u8_to_u16(&got[1].left_code), iris.left_code); - assert_eq!(cast_u8_to_u16(&got[1].left_mask), iris.left_mask); - assert_eq!(cast_u8_to_u16(&got[1].right_code), iris.right_code); - assert_eq!(cast_u8_to_u16(&got[1].right_mask), iris.right_mask); + assert_eq!(cast_u8_to_u16(&got[1].left_code), iris2.left_code); + assert_eq!(cast_u8_to_u16(&got[1].left_mask), iris2.left_mask); + assert_eq!(cast_u8_to_u16(&got[1].right_code), iris2.right_code); + assert_eq!(cast_u8_to_u16(&got[1].right_mask), iris2.right_mask); cleanup(&store, &schema_name).await?; Ok(()) diff --git a/iris-mpc-store/src/s3_importer.rs b/iris-mpc-store/src/s3_importer.rs new file mode 100644 index 000000000..75693fafe --- /dev/null +++ b/iris-mpc-store/src/s3_importer.rs @@ -0,0 +1,333 @@ +use crate::StoredIris; +use async_trait::async_trait; +use aws_sdk_s3::{primitives::ByteStream, Client}; +use futures::{stream, Stream, StreamExt}; +use iris_mpc_common::{IRIS_CODE_LENGTH, MASK_CODE_LENGTH}; +use std::{ + mem, + pin::Pin, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, + time::Instant, +}; +use tokio::io::AsyncReadExt; + +const SINGLE_ELEMENT_SIZE: usize = IRIS_CODE_LENGTH * mem::size_of::() * 2 + + MASK_CODE_LENGTH * mem::size_of::() * 2 + + mem::size_of::(); // 75 KB + +const MAX_RANGE_SIZE: usize = 200; // Download chunks in sub-chunks of 200 elements = 15 MB + +#[async_trait] +pub trait ObjectStore: Send + Sync + 'static { + async fn get_object(&self, key: &str, range: (usize, usize)) -> eyre::Result; + async fn list_objects(&self, prefix: &str) -> eyre::Result>; +} + +pub struct S3Store { + client: Arc, + bucket: String, +} + +impl S3Store { + pub fn new(client: Arc, bucket: String) -> Self { + Self { client, bucket } + } +} + +#[async_trait] +impl ObjectStore for S3Store { + async fn get_object(&self, key: &str, range: (usize, usize)) -> eyre::Result { + let res = self + .client + .get_object() + .bucket(&self.bucket) + .key(key) + .range(format!("bytes={}-{}", range.0, range.1 - 1)) + .send() + .await?; + + Ok(res.body) + } + + async fn list_objects(&self, prefix: &str) -> eyre::Result> { + let mut objects = Vec::new(); + let mut continuation_token = None; + + loop { + let mut request = self + .client + .list_objects_v2() + .bucket(&self.bucket) + .prefix(prefix); + + if let Some(token) = continuation_token { + request = request.continuation_token(token); + } + + let response = request.send().await?; + + objects.extend( + response + .contents() + .iter() + .filter_map(|obj| obj.key().map(String::from)), + ); + + match response.next_continuation_token() { + Some(token) => continuation_token = Some(token.to_string()), + None => break, + } + } + + Ok(objects) + } +} + +#[derive(Debug)] +pub struct LastSnapshotDetails { + pub timestamp: i64, + pub last_serial_id: i64, + pub chunk_size: i64, +} + +impl LastSnapshotDetails { + // Parse last snapshot from s3 file name. + // It is in {unixTime}_{batchSize}_{lastSerialId} format. + pub fn new_from_str(last_snapshot_str: &str) -> Option { + let parts: Vec<&str> = last_snapshot_str.split('_').collect(); + match parts.len() { + 3 => Some(Self { + timestamp: parts[0].parse().unwrap(), + chunk_size: parts[1].parse().unwrap(), + last_serial_id: parts[2].parse().unwrap(), + }), + _ => { + tracing::warn!("Invalid export timestamp file name: {}", last_snapshot_str); + None + } + } + } +} + +pub async fn last_snapshot_timestamp( + store: &impl ObjectStore, + prefix_name: String, +) -> eyre::Result { + tracing::info!("Looking for last snapshot time in prefix: {}", prefix_name); + let timestamps_path = format!("{}/timestamps/", prefix_name); + store + .list_objects(timestamps_path.as_str()) + .await? + .into_iter() + .filter_map(|f| match f.split('/').last() { + Some(file_name) => LastSnapshotDetails::new_from_str(file_name), + _ => None, + }) + .max_by_key(|s| s.timestamp) + .ok_or_else(|| eyre::eyre!("No snapshot found")) +} + +pub async fn fetch_and_parse_chunks( + store: &impl ObjectStore, + concurrency: usize, + prefix_name: String, + last_snapshot_details: LastSnapshotDetails, +) -> Pin> + Send + '_>> { + tracing::info!("Generating chunk files using: {:?}", last_snapshot_details); + let range_size = if last_snapshot_details.chunk_size as usize > MAX_RANGE_SIZE { + MAX_RANGE_SIZE + } else { + last_snapshot_details.chunk_size as usize + }; + let total_bytes = Arc::new(AtomicUsize::new(0)); + let now = Instant::now(); + + let result_stream = + stream::iter((1..=last_snapshot_details.last_serial_id).step_by(range_size)) + .map({ + let total_bytes_clone = total_bytes.clone(); + move |chunk| { + let counter = total_bytes_clone.clone(); + let prefix_name = prefix_name.clone(); + async move { + let chunk_id = (chunk / last_snapshot_details.chunk_size) + * last_snapshot_details.chunk_size + + 1; + let offset_within_chunk = (chunk - chunk_id) as usize; + let mut object_stream = store + .get_object( + &format!("{}/{}.bin", prefix_name, chunk_id), + ( + offset_within_chunk * SINGLE_ELEMENT_SIZE, + (offset_within_chunk + range_size) * SINGLE_ELEMENT_SIZE, + ), + ) + .await? + .into_async_read(); + let mut records = Vec::with_capacity(range_size); + let mut buf = vec![0u8; SINGLE_ELEMENT_SIZE]; + loop { + match object_stream.read_exact(&mut buf).await { + Ok(_) => { + let iris = StoredIris::from_bytes(&buf); + records.push(iris); + counter.fetch_add(SINGLE_ELEMENT_SIZE, Ordering::Relaxed); + } + Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break, + Err(e) => return Err(e.into()), + } + } + Ok::<_, eyre::Error>(stream::iter(records)) + } + } + }) + .buffer_unordered(concurrency) + .flat_map(|result| match result { + Ok(stream) => stream.boxed(), + Err(e) => stream::once(async move { Err(e) }).boxed(), + }) + .inspect({ + let counter = Arc::new(AtomicUsize::new(0)); + move |_| { + if counter.fetch_add(1, Ordering::Relaxed) % 1_000_000 == 0 { + let elapsed = now.elapsed().as_secs_f32(); + if elapsed > 0.0 { + let bytes = total_bytes.load(Ordering::Relaxed); + tracing::info!( + "Current download throughput: {:.2} Gbps", + bytes as f32 * 8.0 / 1e9 / elapsed + ); + } + } + } + }) + .boxed(); + + result_stream +} + +#[cfg(test)] +mod tests { + use super::*; + use aws_sdk_s3::primitives::SdkBody; + use rand::Rng; + use std::{cmp::min, collections::HashSet}; + + #[derive(Default, Clone)] + pub struct MockStore { + objects: std::collections::HashMap>, + } + + impl MockStore { + pub fn new() -> Self { + Self::default() + } + + pub fn add_timestamp_file(&mut self, key: &str) { + self.objects.insert(key.to_string(), Vec::new()); + } + + pub fn add_test_data(&mut self, key: &str, records: Vec) { + let mut result = Vec::new(); + for record in records { + result.extend_from_slice(&(record.id as u32).to_be_bytes()); + result.extend_from_slice(&record.left_code); + result.extend_from_slice(&record.left_mask); + result.extend_from_slice(&record.right_code); + result.extend_from_slice(&record.right_mask); + } + self.objects.insert(key.to_string(), result); + } + } + + #[async_trait] + impl ObjectStore for MockStore { + async fn get_object(&self, key: &str, range: (usize, usize)) -> eyre::Result { + let bytes = self + .objects + .get(key) + .cloned() + .ok_or_else(|| eyre::eyre!("Object not found: {}", key))?; + + // Handle the range parameter by slicing the bytes + let start = range.0; + let end = range.1.min(bytes.len()); + let sliced_bytes = bytes[start..end].to_vec(); + + Ok(ByteStream::from(SdkBody::from(sliced_bytes))) + } + + async fn list_objects(&self, _: &str) -> eyre::Result> { + Ok(self.objects.keys().cloned().collect()) + } + } + + fn random_bytes(len: usize) -> Vec { + let mut rng = rand::thread_rng(); + let mut v = vec![0u8; len]; + v.fill_with(|| rng.gen()); + v + } + + fn dummy_entry(id: usize) -> StoredIris { + StoredIris { + id: id as i64, + left_code: random_bytes(IRIS_CODE_LENGTH * mem::size_of::()), + left_mask: random_bytes(MASK_CODE_LENGTH * mem::size_of::()), + right_code: random_bytes(IRIS_CODE_LENGTH * mem::size_of::()), + right_mask: random_bytes(MASK_CODE_LENGTH * mem::size_of::()), + } + } + + #[tokio::test] + async fn test_last_snapshot_timestamp() { + let mut store = MockStore::new(); + store.add_timestamp_file("out/timestamps/123_100_954"); + store.add_timestamp_file("out/timestamps/124_100_958"); + store.add_timestamp_file("out/timestamps/125_100_958"); + + let last_snapshot = last_snapshot_timestamp(&store, "out".to_string()) + .await + .unwrap(); + assert_eq!(last_snapshot.timestamp, 125); + assert_eq!(last_snapshot.last_serial_id, 958); + assert_eq!(last_snapshot.chunk_size, 100); + } + + #[tokio::test] + async fn test_fetch_and_parse_chunks() { + const MOCK_ENTRIES: usize = 107; + const MOCK_CHUNK_SIZE: usize = 10; + let mut store = MockStore::new(); + let n_chunks = MOCK_ENTRIES.div_ceil(MOCK_CHUNK_SIZE); + for i in 0..n_chunks { + let start_serial_id = i * MOCK_CHUNK_SIZE + 1; + let end_serial_id = min((i + 1) * MOCK_CHUNK_SIZE, MOCK_ENTRIES); + store.add_test_data( + &format!("out/{start_serial_id}.bin"), + (start_serial_id..=end_serial_id).map(dummy_entry).collect(), + ); + } + + assert_eq!(store.list_objects("").await.unwrap().len(), n_chunks); + let last_snapshot_details = LastSnapshotDetails { + timestamp: 0, + last_serial_id: MOCK_ENTRIES as i64, + chunk_size: MOCK_CHUNK_SIZE as i64, + }; + let mut chunks = + fetch_and_parse_chunks(&store, 1, "out".to_string(), last_snapshot_details).await; + let mut count = 0; + let mut ids: HashSet = HashSet::from_iter(1..MOCK_ENTRIES); + while let Some(chunk) = chunks.next().await { + let chunk = chunk.unwrap(); + ids.remove(&(chunk.id as usize)); + count += 1; + } + assert_eq!(count, MOCK_ENTRIES); + assert!(ids.is_empty()); + } +} diff --git a/iris-mpc-upgrade/Cargo.toml b/iris-mpc-upgrade/Cargo.toml index cdc5b2a12..fe1c040f8 100644 --- a/iris-mpc-upgrade/Cargo.toml +++ b/iris-mpc-upgrade/Cargo.toml @@ -11,7 +11,7 @@ repository.workspace = true axum.workspace = true iris-mpc-common = { path = "../iris-mpc-common" } iris-mpc-store = { path = "../iris-mpc-store" } -clap.workspace = true +clap = { workspace = true, features = ["env"] } eyre.workspace = true bytemuck.workspace = true sqlx.workspace = true @@ -30,10 +30,23 @@ mpc-uniqueness-check = { package = "mpc", git = "https://github.com/worldcoin/mp indicatif = "0.17.8" rcgen = "0.13.1" tokio-native-tls = "0.3.1" +tonic = { version = "0.12.3", features = [ + "tls", + "tls-native-roots", + "transport", +] } +prost = "0.13.3" +sha2 = "0.10.8" +thiserror.workspace = true +hkdf = "0.12.4" [dev-dependencies] float_eq = "1" + +[build-dependencies] +tonic-build = "0.12.3" + [[bin]] name = "upgrade-checker" path = "src/bin/checker.rs" @@ -49,3 +62,15 @@ path = "src/bin/tcp_ssl_upgrade_client.rs" [[bin]] name = "seed-v1-dbs" path = "src/bin/seed_v1_dbs.rs" + +[[bin]] +name = "seed-v2-dbs" +path = "src/bin/seed_v2_dbs.rs" + +[[bin]] +name = "reshare-server" +path = "src/bin/reshare-server.rs" + +[[bin]] +name = "reshare-client" +path = "src/bin/reshare-client.rs" diff --git a/iris-mpc-upgrade/build.rs b/iris-mpc-upgrade/build.rs new file mode 100644 index 000000000..8a43d83c9 --- /dev/null +++ b/iris-mpc-upgrade/build.rs @@ -0,0 +1,11 @@ +fn main() { + println!("cargo:rerun-if-changed=build.rs"); + println!("cargo:rerun-if-changed=protos/reshare.proto"); + tonic_build::configure() + .out_dir("src/proto/") + .compile_protos( + &["reshare.proto"], // Files in the path + &["protos"], // The include path to search + ) + .unwrap(); +} diff --git a/iris-mpc-upgrade/protos/reshare.proto b/iris-mpc-upgrade/protos/reshare.proto new file mode 100644 index 000000000..3ceef37ab --- /dev/null +++ b/iris-mpc-upgrade/protos/reshare.proto @@ -0,0 +1,35 @@ + +syntax = "proto3"; +package iris_mpc_reshare; + +message IrisCodeReShare { + bytes LeftIrisCodeShare = 1; + bytes LeftMaskShare = 2; + bytes RightIrisCodeShare = 3; + bytes RightMaskShare = 4; +} + +message IrisCodeReShareRequest { + uint64 SenderId = 1; + uint64 OtherId = 2; + uint64 ReceiverId = 3; + int64 IdRangeStartInclusive = 4; + int64 IdRangeEndNonInclusive = 5; + repeated IrisCodeReShare IrisCodeReShares = 6; + bytes ClientCorrelationSanityCheck = 7; +} + +message IrisCodeReShareResponse { + IrisCodeReShareStatus Status = 1; + string Message = 2; +} + +enum IrisCodeReShareStatus { + IRIS_CODE_RE_SHARE_STATUS_OK = 0; + IRIS_CODE_RE_SHARE_STATUS_FULL_QUEUE = 1; + IRIS_CODE_RE_SHARE_STATUS_ERROR = 2; +} + +service IrisCodeReShareService { + rpc ReShare(IrisCodeReShareRequest) returns (IrisCodeReShareResponse); +} diff --git a/iris-mpc-upgrade/src/bin/.gitignore b/iris-mpc-upgrade/src/bin/.gitignore index c3d47bb9d..cc0468a13 100644 --- a/iris-mpc-upgrade/src/bin/.gitignore +++ b/iris-mpc-upgrade/src/bin/.gitignore @@ -1,3 +1,4 @@ out0/ out1/ out2/ +*.log diff --git a/iris-mpc-upgrade/src/bin/README.md b/iris-mpc-upgrade/src/bin/README.md index caa02666e..53e25c6c8 100644 --- a/iris-mpc-upgrade/src/bin/README.md +++ b/iris-mpc-upgrade/src/bin/README.md @@ -19,7 +19,7 @@ cargo run --release --bin seed-v1-dbs -- --side left --shares-db-urls postgres:/ ## Upgrade for left eye -### Run the 3 upgrade servers +### Run the 3 upgrade servers Concurrently run: @@ -71,12 +71,86 @@ Concurrently run: ```bash cargo run --release --bin upgrade-client -- --server1 127.0.0.1:8000 --server2 127.0.0.1:8001 --server3 127.0.0.1:8002 --db-start 0 --db-end 10000 --party-id 0 --eye right --shares-db-url postgres://postgres:postgres@localhost:6100 --masks-db-url postgres://postgres:postgres@localhost:6111 ``` + ```bash cargo run --release --bin upgrade-client -- --server1 127.0.0.1:8000 --server2 127.0.0.1:8001 --server3 127.0.0.1:8002 --db-start 0 --db-end 10000 --party-id 1 --eye right --shares-db-url postgres://postgres:postgres@localhost:6101 --masks-db-url postgres://postgres:postgres@localhost:6111 ``` + ## Check the upgrade was successful ```bash cargo run --release --bin upgrade-checker -- --environment dev --num-elements 10000 --db-urls postgres://postgres:postgres@localhost:6100 --db-urls postgres://postgres:postgres@localhost:6101 --db-urls postgres://postgres:postgres@localhost:6111 --db-urls postgres://postgres:postgres@localhost:6200 --db-urls postgres://postgres:postgres@localhost:6201 --db-urls postgres://postgres:postgres@localhost:6202 ``` + +# Reshare Protocol + +The aim of the reshare protocol is to allow 2 existing parties in SMPCv2 to work together to recover the share of another party using a simple MPC functionality. + +## Internal Server structure + +The current internal structure of this service works as follows: + +* The receiving party hosts a GRPC server to receive reshare batches from the two sending parties. +* The two sending parties send reshare batches via GRPC. +* The GPRC server collects reshare request batches from the two clients and stores it internally. +* Once matching requests from both parties are collected, the server processes the requests and stores them to the DB. + +Currently, the matching is not very robust and requires that both clients send batches for the exact ranges (i.e., client 1 and 2 send batch for ids 1-100, it cannot handle client 1 sending 1-100 and client 2 sending 1-50 and 51-100). + +## Example Protocol run + +In this example we start a reshare process where parties 0 and 1 are the senders (i.e., clients) and party 2 is the receiver (i.e., server). + +### Bring up some DBs and seed them + +Here, the seed-v2-dbs binary just creates fully replicated DB for 3 parties, in DBs with ports 6200,6201,6202. Additionally, there is also another DB at 6203, which we will use as a target for the reshare protocol to fill into. + +```bash +docker-compose up -d +cargo run --release --bin seed-v2-dbs -- --db-url-party1 postgres://postgres:postgres@localhost:6200 --db-url-party2 postgres://postgres:postgres@localhost:6201 --db-url-party3 postgres://postgres:postgres@localhost:6202 --schema-name-party1 SMPC_testing_0 --schema-name-party2 SMPC_testing_1 --schema-name-party3 SMPC_testing_2 --fill-to 10000 --batch-size 100 +``` + +### Start a server for the receiving party + +```bash +cargo run --release --bin reshare-server -- --party-id 2 --sender1-party-id 0 --sender2-party-id 1 --bind-addr 0.0.0.0:7000 --environment testing --db-url postgres://postgres:postgres@localhost:6203 --db-start 1 --db-end 10001 --batch-size 100 +``` + +Short rundown of the parameters: + +* `party-id`: the 0-indexed party id of the receiving party. This corresponds to the (i+1)-th point on the exceptional sequence for Shamir poly evaluation +* `sender1-party-id`: The party id of the first sender, just for sanity checks against received packets. (Order between sender1 and sender2 does not matter here) +* `sender2-party-id`: The party id of the second sender, just for sanity checks against received packets. +* `bind-addr`: Socket addr to bind to for gGRPC server. +* `environment`: Which environment are we running in, used for DB schema name +* `db-url`: Postgres connection string. We save the results in this DB +* `db-start`: Expected range of DB entries to receive, just used for sanity checks. Start is inclusive. +* `db-end`: Expected range of DB entries to receive, just used for sanity checks. End is exclusive. +* `batch-size`: maximum size of received reshare batches + +### Start clients for the sending parties + +```bash +cargo run --release --bin reshare-client -- --party-id 0 --other-party-id 1 --target-party-id 2 --server-url http://localhost:7000 --environment testing --db-url postgres://postgres:postgres@localhost:6200 --db-start 1 --db-end 10001 --batch-size 100 +``` + +```bash +cargo run --release --bin reshare-client -- --party-id 1 --other-party-id 0 --target-party-id 2 --server-url http://localhost:7000 --environment testing --db-url postgres://postgres:postgres@localhost:6201 --db-start 1 --db-end 10001 --batch-size 100 +``` + +Short rundown of the parameters: + +* `party-id`: the 0-indexed party id of our own client party. This corresponds to the (i+1)-th point on the exceptional sequence for Shamir poly evaluation +* `other-party-id`: the 0-indexed party id of the other client party. This needs to be passed for the correct calculation of lagrange interpolation polynomials. +* `target-party-id`: the 0-indexed party id of the receiving party. This needs to be passed for the correct calculation of lagrange interpolation polynomials. +* `server-url`: Url where to reach the GRPC server (can also be https, client supports both). +* `environment`: Which environment are we running in, used for DB schema name +* `db-url`: Postgres connection string. We load our shares from this DB +* `db-start`: Range of DB entries to send. Start is inclusive. +* `db-end`: Range of DB entries to send. End is exclusive. +* `batch-size`: maximum size of sent reshare batches + +### Checking results + +Since the shares on a given shamir poly are deterministic given the party ids, the above upgrade process can be checked by comparing the databases at port 6202 and 6203 for equality. diff --git a/iris-mpc-upgrade/src/bin/docker-compose.yaml b/iris-mpc-upgrade/src/bin/docker-compose.yaml index cab988519..40198e699 100644 --- a/iris-mpc-upgrade/src/bin/docker-compose.yaml +++ b/iris-mpc-upgrade/src/bin/docker-compose.yaml @@ -1,5 +1,4 @@ services: - old-db-shares-1: image: postgres:16 ports: @@ -42,3 +41,18 @@ services: environment: POSTGRES_USER: "postgres" POSTGRES_PASSWORD: "postgres" + new-db-4: + image: postgres:16 + ports: + - "6203:5432" + environment: + POSTGRES_USER: "postgres" + POSTGRES_PASSWORD: "postgres" + localstack: + image: localstack/localstack + ports: + - "127.0.0.1:4566:4566" + - "127.0.0.1:4571:4571" + environment: + - SERVICES=kms + - DEFAULT_REGION=us-east-1 diff --git a/iris-mpc-upgrade/src/bin/reshare-client.rs b/iris-mpc-upgrade/src/bin/reshare-client.rs new file mode 100644 index 000000000..8cd751fee --- /dev/null +++ b/iris-mpc-upgrade/src/bin/reshare-client.rs @@ -0,0 +1,154 @@ +use clap::Parser; +use futures::StreamExt; +use hkdf::Hkdf; +use iris_mpc_common::{ + galois_engine::degree4::{GaloisRingIrisCodeShare, GaloisRingTrimmedMaskCodeShare}, + helpers::kms_dh::derive_shared_secret, +}; +use iris_mpc_store::Store; +use iris_mpc_upgrade::{ + config::ReShareClientConfig, + proto::{ + self, + iris_mpc_reshare::{ + iris_code_re_share_service_client::IrisCodeReShareServiceClient, IrisCodeReShareStatus, + }, + }, + reshare::IrisCodeReshareSenderHelper, + utils::install_tracing, +}; +use sha2::Sha256; + +const APP_NAME: &str = "SMPC"; + +async fn derive_common_seed(config: &ReShareClientConfig) -> eyre::Result<[u8; 32]> { + let shared_secret = if config.environment == "testing" { + // TODO: remove once localstack fixes KMS bug that returns different shared + // secrets + [0u8; 32] + } else { + derive_shared_secret(&config.my_kms_key_arn, &config.other_kms_key_arn).await? + }; + + let hk = Hkdf::::new( + // sesstion id is used as salt + Some(config.reshare_run_session_id.as_bytes()), + &shared_secret, + ); + let mut common_seed = [0u8; 32]; + // expand the common seed bound to the context "ReShare-Protocol-Client" + hk.expand(b"ReShare-Protocol-Client", &mut common_seed) + .map_err(|e| eyre::eyre!("error during HKDF expansion: {}", e))?; + Ok(common_seed) +} + +#[tokio::main] +async fn main() -> eyre::Result<()> { + install_tracing(); + let config = ReShareClientConfig::parse(); + + let common_seed = derive_common_seed(&config).await?; + + let schema_name = format!("{}_{}_{}", APP_NAME, config.environment, config.party_id); + let store = Store::new(&config.db_url, &schema_name).await?; + + let iris_stream = store.stream_irises_in_range(config.db_start..config.db_end); + let mut iris_stream_chunks = iris_stream.chunks(config.batch_size as usize); + + let mut iris_reshare_helper = IrisCodeReshareSenderHelper::new( + config.party_id as usize, + config.other_party_id as usize, + config.target_party_id as usize, + common_seed, + ); + + let encoded_message_size = + proto::get_size_of_reshare_iris_code_share_batch(config.batch_size as usize); + if encoded_message_size > 100 * 1024 * 1024 { + tracing::warn!( + "encoded batch message size is large: {}MB", + encoded_message_size as f64 / 1024.0 / 1024.0 + ); + } + let encoded_message_size_with_buf = (encoded_message_size as f64 * 1.1) as usize; + + let mut grpc_client = IrisCodeReShareServiceClient::connect(config.server_url) + .await? + .max_decoding_message_size(encoded_message_size_with_buf) + .max_encoding_message_size(encoded_message_size_with_buf); + + while let Some(chunk) = iris_stream_chunks.next().await { + let iris_codes = chunk.into_iter().collect::, sqlx::Error>>()?; + if iris_codes.is_empty() { + continue; + } + let db_chunk_start = iris_codes.first().unwrap().id(); + let db_chunk_end = iris_codes.last().unwrap().id(); + + // sanity check + for window in iris_codes.as_slice().windows(2) { + assert_eq!( + window[0].id() + 1, + window[1].id(), + "expect consecutive iris codes" + ); + } + + iris_reshare_helper.start_reshare_batch(db_chunk_start, db_chunk_end + 1); + + for iris_code in iris_codes { + iris_reshare_helper.add_reshare_iris_to_batch( + iris_code.id(), + GaloisRingIrisCodeShare { + id: config.party_id as usize + 1, + coefs: iris_code.left_code().try_into().unwrap(), + }, + GaloisRingTrimmedMaskCodeShare { + id: config.party_id as usize + 1, + coefs: iris_code.left_mask().try_into().unwrap(), + }, + GaloisRingIrisCodeShare { + id: config.party_id as usize + 1, + coefs: iris_code.right_code().try_into().unwrap(), + }, + GaloisRingTrimmedMaskCodeShare { + id: config.party_id as usize + 1, + coefs: iris_code.right_mask().try_into().unwrap(), + }, + ); + } + tracing::info!( + "Submitting reshare request for iris codes {} to {}", + db_chunk_start, + db_chunk_end + ); + + let request = iris_reshare_helper.finalize_reshare_batch(); + let mut timeout = tokio::time::Duration::from_millis(config.retry_backoff_millis); + loop { + let resp = grpc_client.re_share(request.clone()).await?; + let resp = resp.into_inner(); + match resp.status { + x if x == IrisCodeReShareStatus::Ok as i32 => { + break; + } + x if x == IrisCodeReShareStatus::FullQueue as i32 => { + tokio::time::sleep(timeout).await; + timeout += tokio::time::Duration::from_millis(config.retry_backoff_millis); + continue; + } + x if x == IrisCodeReShareStatus::Error as i32 => { + return Err(eyre::eyre!( + "error during reshare request submission: {}", + resp.message + )); + } + _ => { + return Err(eyre::eyre!("unexpected reshare status: {}", resp.status)); + } + } + } + } + + Ok(()) +} diff --git a/iris-mpc-upgrade/src/bin/reshare-protocol-local.sh b/iris-mpc-upgrade/src/bin/reshare-protocol-local.sh new file mode 100755 index 000000000..f41b52e7d --- /dev/null +++ b/iris-mpc-upgrade/src/bin/reshare-protocol-local.sh @@ -0,0 +1,43 @@ +#!/usr/bin/env bash + +rm -rf "*.log" + +docker-compose down --remove-orphans +docker-compose up -d + +sleep 1 + +aws_local() { + AWS_ACCESS_KEY_ID=test AWS_SECRET_ACCESS_KEY=test AWS_DEFAULT_REGION=us-east-1 aws --endpoint-url=http://${LOCALSTACK_HOST:-localhost}:4566 "$@" +} + +key1_metadata=$(aws_local kms create-key --region us-east-1 --description "Key for Party1" --key-spec ECC_NIST_P256 --key-usage KEY_AGREEMENT) +echo "Created key1: $key1_metadata" +key1_arn=$(echo "$key1_metadata" | jq ".KeyMetadata.Arn" -r) +echo "Key1 ARN: $key1_arn" +key2_metadata=$(aws_local kms create-key --region us-east-1 --description "Key for Party2" --key-spec ECC_NIST_P256 --key-usage KEY_AGREEMENT) +echo "Created key2: $key2_metadata" +key2_arn=$(echo "$key2_metadata" | jq ".KeyMetadata.Arn" -r) +echo "Key2 ARN: $key2_arn" + +sleep 1 + +cargo build --release --bin seed-v2-dbs --bin reshare-server --bin reshare-client + + + +TARGET_DIR=$(cargo metadata --format-version 1 | jq ".target_directory" -r) + +$TARGET_DIR/release/seed-v2-dbs --db-url-party1 postgres://postgres:postgres@localhost:6200 --db-url-party2 postgres://postgres:postgres@localhost:6201 --db-url-party3 postgres://postgres:postgres@localhost:6202 --schema-name-party1 SMPC_testing_0 --schema-name-party2 SMPC_testing_1 --schema-name-party3 SMPC_testing_2 --fill-to 10000 --batch-size 100 + +$TARGET_DIR/release/reshare-server --party-id 2 --sender1-party-id 0 --sender2-party-id 1 --bind-addr 0.0.0.0:7000 --environment testing --db-url postgres://postgres:postgres@localhost:6203 --batch-size 100 & > reshare-server.log + +sleep 5 + +AWS_ACCESS_KEY_ID=test AWS_SECRET_ACCESS_KEY=test AWS_DEFAULT_REGION=us-east-1 AWS_ENDPOINT_URL=http://${LOCALSTACK_HOST:-localhost}:4566 $TARGET_DIR/release/reshare-client --party-id 0 --other-party-id 1 --target-party-id 2 --server-url http://localhost:7000 --environment testing --db-url postgres://postgres:postgres@localhost:6200 --db-start 1 --db-end 10001 --batch-size 100 --my-kms-key-arn $key1_arn --other-kms-key-arn $key2_arn --reshare-run-session-id testrun1 & > reshare-client-0.log + +AWS_ACCESS_KEY_ID=test AWS_SECRET_ACCESS_KEY=test AWS_DEFAULT_REGION=us-east-1 AWS_ENDPOINT_URL=http://${LOCALSTACK_HOST:-localhost}:4566 $TARGET_DIR/release/reshare-client --party-id 1 --other-party-id 0 --target-party-id 2 --server-url http://localhost:7000 --environment testing --db-url postgres://postgres:postgres@localhost:6201 --db-start 1 --db-end 10001 --batch-size 100 --my-kms-key-arn $key2_arn --other-kms-key-arn $key1_arn --reshare-run-session-id testrun1 > reshare-client-1.log + +sleep 5 +killall reshare-server + diff --git a/iris-mpc-upgrade/src/bin/reshare-server.rs b/iris-mpc-upgrade/src/bin/reshare-server.rs new file mode 100644 index 000000000..7bd9a9610 --- /dev/null +++ b/iris-mpc-upgrade/src/bin/reshare-server.rs @@ -0,0 +1,73 @@ +use clap::Parser; +use iris_mpc_common::helpers::task_monitor::TaskMonitor; +use iris_mpc_store::Store; +use iris_mpc_upgrade::{ + config::ReShareServerConfig, + proto::{ + self, iris_mpc_reshare::iris_code_re_share_service_server::IrisCodeReShareServiceServer, + }, + reshare::{GrpcReshareServer, IrisCodeReshareReceiverHelper}, + utils::{install_tracing, spawn_healthcheck_server}, +}; +use tonic::transport::Server; + +const APP_NAME: &str = "SMPC"; + +#[tokio::main] +async fn main() -> eyre::Result<()> { + install_tracing(); + let config = ReShareServerConfig::parse(); + + tracing::info!("Starting healthcheck server."); + + let mut background_tasks = TaskMonitor::new(); + let _health_check_abort = background_tasks + .spawn(async move { spawn_healthcheck_server(config.healthcheck_port).await }); + background_tasks.check_tasks(); + tracing::info!( + "Healthcheck server running on port {}.", + config.healthcheck_port.clone() + ); + + tracing::info!( + "Healthcheck server running on port {}.", + config.healthcheck_port + ); + + let schema_name = format!("{}_{}_{}", APP_NAME, config.environment, config.party_id); + let store = Store::new(&config.db_url, &schema_name).await?; + + let receiver_helper = IrisCodeReshareReceiverHelper::new( + config.party_id as usize, + config.sender1_party_id as usize, + config.sender2_party_id as usize, + config.max_buffer_size, + ); + + let encoded_message_size = + proto::get_size_of_reshare_iris_code_share_batch(config.batch_size as usize); + if encoded_message_size > 100 * 1024 * 1024 { + tracing::warn!( + "encoded batch message size is large: {}MB", + encoded_message_size as f64 / 1024.0 / 1024.0 + ); + } + let encoded_message_size_with_buf = (encoded_message_size as f64 * 1.1) as usize; + let grpc_server = + IrisCodeReShareServiceServer::new(GrpcReshareServer::new(store, receiver_helper)) + .max_decoding_message_size(encoded_message_size_with_buf) + .max_encoding_message_size(encoded_message_size_with_buf); + + Server::builder() + .add_service(grpc_server) + .serve_with_shutdown(config.bind_addr, shutdown_signal()) + .await?; + + Ok(()) +} + +async fn shutdown_signal() { + tokio::signal::ctrl_c() + .await + .expect("failed to install CTRL+C signal handler"); +} diff --git a/iris-mpc-upgrade/src/bin/seed_v2_dbs.rs b/iris-mpc-upgrade/src/bin/seed_v2_dbs.rs new file mode 100644 index 000000000..c737c82c0 --- /dev/null +++ b/iris-mpc-upgrade/src/bin/seed_v2_dbs.rs @@ -0,0 +1,145 @@ +use clap::Parser; +use iris_mpc_common::{ + galois_engine::degree4::FullGaloisRingIrisCodeShare, iris_db::iris::IrisCode, +}; +use iris_mpc_store::{Store, StoredIrisRef}; +use itertools::Itertools; +use rand::thread_rng; +use std::cmp::min; + +#[derive(Debug, Clone, Parser)] +struct Args { + #[clap(long)] + db_url_party1: String, + + #[clap(long)] + db_url_party2: String, + + #[clap(long)] + db_url_party3: String, + + #[clap(long)] + fill_to: u64, + + #[clap(long)] + batch_size: usize, + + #[clap(long)] + schema_name_party1: String, + + #[clap(long)] + schema_name_party2: String, + + #[clap(long)] + schema_name_party3: String, + + #[clap(long, value_delimiter = ',', num_args = 1..)] + deleted_identities: Option>, +} + +#[tokio::main] +async fn main() -> eyre::Result<()> { + let args = Args::parse(); + + let store1 = Store::new(&args.db_url_party1, &args.schema_name_party1).await?; + let store2 = Store::new(&args.db_url_party2, &args.schema_name_party2).await?; + let store3 = Store::new(&args.db_url_party3, &args.schema_name_party3).await?; + + let mut rng = rand::thread_rng(); + + let latest_serial_id1 = store1.count_irises().await?; + let latest_serial_id2 = store2.count_irises().await?; + let latest_serial_id3 = store3.count_irises().await?; + let mut latest_serial_id = + min(min(latest_serial_id1, latest_serial_id2), latest_serial_id3) as u64; + + if latest_serial_id == args.fill_to { + return Ok(()); + } + // TODO: Does this make sense? + if latest_serial_id == 0 { + latest_serial_id += 1 + } + + let deleted_serial_ids = args.deleted_identities.unwrap_or_default(); + + for range_chunk in &(latest_serial_id..args.fill_to).chunks(args.batch_size) { + let range_chunk = range_chunk.collect_vec(); + let (party1, party2, party3): (Vec<_>, Vec<_>, Vec<_>) = range_chunk + .iter() + .map(|serial_id| { + let (iris_code_left, iris_code_right) = + if deleted_serial_ids.contains(&(*serial_id as i32)) { + ( + // TODO: set them to the deleted values + IrisCode::random_rng(&mut thread_rng()), + IrisCode::random_rng(&mut thread_rng()), + ) + } else { + ( + IrisCode::random_rng(&mut rng), + IrisCode::random_rng(&mut rng), + ) + }; + let [left1, left2, left3] = + FullGaloisRingIrisCodeShare::encode_iris_code(&iris_code_left, &mut rng); + let [right1, right2, right3] = + FullGaloisRingIrisCodeShare::encode_iris_code(&iris_code_right, &mut rng); + ((left1, right1), (left2, right2), (left3, right3)) + }) + .multiunzip(); + let party1_insert = party1 + .iter() + .zip(range_chunk.iter()) + .map(|((left, right), id)| StoredIrisRef { + id: *id as i64, + left_code: &left.code.coefs, + left_mask: &left.mask.coefs, + right_code: &right.code.coefs, + right_mask: &right.mask.coefs, + }) + .collect_vec(); + + let mut tx = store1.tx().await?; + store1 + .insert_irises_overriding(&mut tx, &party1_insert) + .await?; + tx.commit().await?; + + let party2_insert = party2 + .iter() + .zip(range_chunk.iter()) + .map(|((left, right), id)| StoredIrisRef { + id: *id as i64, + left_code: &left.code.coefs, + left_mask: &left.mask.coefs, + right_code: &right.code.coefs, + right_mask: &right.mask.coefs, + }) + .collect_vec(); + let mut tx = store2.tx().await?; + store2 + .insert_irises_overriding(&mut tx, &party2_insert) + .await?; + tx.commit().await?; + + let party3_insert = party3 + .iter() + .zip(range_chunk.iter()) + .map(|((left, right), id)| StoredIrisRef { + id: *id as i64, + left_code: &left.code.coefs, + left_mask: &left.mask.coefs, + right_code: &right.code.coefs, + right_mask: &right.mask.coefs, + }) + .collect_vec(); + let mut tx = store3.tx().await?; + store3 + .insert_irises_overriding(&mut tx, &party3_insert) + .await?; + tx.commit().await?; + } + + Ok(()) +} diff --git a/iris-mpc-upgrade/src/bin/tcp_upgrade_server.rs b/iris-mpc-upgrade/src/bin/tcp_upgrade_server.rs index ec30ce02e..a602cd5e9 100644 --- a/iris-mpc-upgrade/src/bin/tcp_upgrade_server.rs +++ b/iris-mpc-upgrade/src/bin/tcp_upgrade_server.rs @@ -1,12 +1,12 @@ -use axum::{routing::get, Router}; use clap::Parser; -use eyre::{bail, Context}; +use eyre::bail; use futures_concurrency::future::Join; use iris_mpc_common::helpers::task_monitor::TaskMonitor; use iris_mpc_store::Store; use iris_mpc_upgrade::{ config::{Eye, UpgradeServerConfig, BATCH_SUCCESSFUL_ACK, FINAL_BATCH_SUCCESSFUL_ACK}, packets::{MaskShareMessage, TwoToThreeIrisCodeMessage}, + utils::{install_tracing, spawn_healthcheck_server}, IrisCodeUpgrader, NewIrisShareSink, }; use std::time::Instant; @@ -14,20 +14,6 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader}; const APP_NAME: &str = "SMPC"; -fn install_tracing() { - use tracing_subscriber::{fmt, prelude::*, EnvFilter}; - - let fmt_layer = fmt::layer().with_target(true).with_line_number(true); - let filter_layer = EnvFilter::try_from_default_env() - .or_else(|_| EnvFilter::try_new("info")) - .unwrap(); - - tracing_subscriber::registry() - .with(filter_layer) - .with(fmt_layer) - .init(); -} - struct UpgradeTask { msg1: TwoToThreeIrisCodeMessage, msg2: TwoToThreeIrisCodeMessage, @@ -47,19 +33,13 @@ async fn main() -> eyre::Result<()> { tracing::info!("Starting healthcheck server."); let mut background_tasks = TaskMonitor::new(); - let _health_check_abort = background_tasks.spawn(async move { - let app = Router::new().route("/health", get(|| async {})); // implicit 200 return - let listener = tokio::net::TcpListener::bind("0.0.0.0:3000") - .await - .wrap_err("healthcheck listener bind error")?; - axum::serve(listener, app) - .await - .wrap_err("healthcheck listener server launch error")?; - Ok(()) - }); - + let _health_check_abort = background_tasks + .spawn(async move { spawn_healthcheck_server(args.healthcheck_port).await }); background_tasks.check_tasks(); - tracing::info!("Healthcheck server running on port 3000."); + tracing::info!( + "Healthcheck server running on port {}.", + args.healthcheck_port.clone() + ); let upgrader = IrisCodeUpgrader::new(args.party_id, sink.clone()); @@ -212,10 +192,6 @@ async fn main() -> eyre::Result<()> { client_stream1.write_u8(FINAL_BATCH_SUCCESSFUL_ACK).await?; tracing::info!("Sent final ACK to client1"); - tracing::info!("Updating iris id sequence"); - sink.update_iris_id_sequence().await?; - tracing::info!("Iris id sequence updated"); - Ok(()) } @@ -252,8 +228,4 @@ impl NewIrisShareSink for IrisShareDbSink { } } } - - async fn update_iris_id_sequence(&self) -> eyre::Result<()> { - self.store.update_iris_id_sequence().await - } } diff --git a/iris-mpc-upgrade/src/config.rs b/iris-mpc-upgrade/src/config.rs index 92034a08d..696bee343 100644 --- a/iris-mpc-upgrade/src/config.rs +++ b/iris-mpc-upgrade/src/config.rs @@ -54,6 +54,9 @@ pub struct UpgradeServerConfig { #[clap(long)] pub environment: String, + + #[clap(long)] + pub healthcheck_port: usize, } impl fmt::Debug for UpgradeServerConfig { @@ -118,3 +121,106 @@ impl fmt::Debug for UpgradeClientConfig { .finish() } } + +#[derive(Parser)] +pub struct ReShareClientConfig { + /// The URL of the server to send reshare messages to + #[clap(long, default_value = "http://localhost:8000", env("SERVER_URL"))] + pub server_url: String, + + /// The DB index where we start to send Iris codes from (inclusive) + #[clap(long)] + pub db_start: u64, + + /// The DB index where we stop to send Iris codes (exclusive) + #[clap(long)] + pub db_end: u64, + + /// the 0-indexed party ID of the client party + #[clap(long)] + pub party_id: u8, + + /// the 0-indexed party ID of the other client party + #[clap(long)] + pub other_party_id: u8, + + /// the 0-indexed party ID of the receiving party + #[clap(long)] + pub target_party_id: u8, + + /// The batch size to use when sending reshare messages (i.e., how many iris + /// code DB entries per message) + #[clap(long)] + pub batch_size: u64, + + /// DB connection URL for the reshare client + #[clap(long)] + pub db_url: String, + + /// The amount of time to wait before retrying a batch if the server queue + /// was full, in milliseconds. Does a simple linear backoff strategy + #[clap(long, default_value = "100")] + pub retry_backoff_millis: u64, + + /// The environment in which the reshare protocol is being run (mostly used + /// for the DB schema name) + #[clap(long)] + pub environment: String, + + /// The ARN of the KMS key that will be used to derive the common secret + #[clap(long)] + pub my_kms_key_arn: String, + + /// The ARN of the KMS key of the other client party that will be used to + /// derive the common secret + #[clap(long)] + pub other_kms_key_arn: String, + + /// The session ID of the reshare protocol run, this will be used to salt + /// the common secret derived between the two parties + #[clap(long)] + pub reshare_run_session_id: String, +} + +#[derive(Parser)] +pub struct ReShareServerConfig { + /// The socket to bind the reshare server to + #[clap(long, default_value = "0.0.0.0:8000", env("BIND_ADDR"))] + pub bind_addr: SocketAddr, + + /// The 0-indexed party ID of the server party + #[clap(long)] + pub party_id: u8, + + /// The 0-indexed party ID of the first client party (order of the two + /// client parties does not matter) + #[clap(long)] + pub sender1_party_id: u8, + + /// The 0-indexed party ID of the second client party (order of the two + /// client parties does not matter) + #[clap(long)] + pub sender2_party_id: u8, + + /// The maximum allowed batch size for reshare messages + #[clap(long)] + pub batch_size: u64, + + /// The DB connection URL to store reshared iris codes to + #[clap(long)] + pub db_url: String, + + /// The environment in which the reshare protocol is being run (mostly used + /// for the DB schema name) + #[clap(long)] + pub environment: String, + + /// The maximum buffer size for the reshare server (i.e., how many messages + /// are accepted from one client without receving corresponding messages + /// from the other client) + #[clap(long, default_value = "10")] + pub max_buffer_size: usize, + + #[clap(long, default_value = "3000")] + pub healthcheck_port: usize, +} diff --git a/iris-mpc-upgrade/src/lib.rs b/iris-mpc-upgrade/src/lib.rs index 73676b66f..ba1e9afb7 100644 --- a/iris-mpc-upgrade/src/lib.rs +++ b/iris-mpc-upgrade/src/lib.rs @@ -13,6 +13,8 @@ use std::{ pub mod config; pub mod db; pub mod packets; +pub mod proto; +pub mod reshare; pub mod utils; pub trait OldIrisShareSource { @@ -42,8 +44,6 @@ pub trait NewIrisShareSink { code_share: &[u16; IRIS_CODE_LENGTH], mask_share: &[u16; MASK_CODE_LENGTH], ) -> Result<()>; - - async fn update_iris_id_sequence(&self) -> Result<()>; } #[derive(Debug, Clone)] @@ -83,10 +83,6 @@ impl NewIrisShareSink for IrisShareTestFileSink { file.flush()?; Ok(()) } - - async fn update_iris_id_sequence(&self) -> Result<()> { - Ok(()) - } } #[derive(Clone)] diff --git a/iris-mpc-upgrade/src/proto/iris_mpc_reshare.rs b/iris-mpc-upgrade/src/proto/iris_mpc_reshare.rs new file mode 100644 index 000000000..3c26ac978 --- /dev/null +++ b/iris-mpc-upgrade/src/proto/iris_mpc_reshare.rs @@ -0,0 +1,368 @@ +// This file is @generated by prost-build. +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct IrisCodeReShare { + #[prost(bytes = "vec", tag = "1")] + pub left_iris_code_share: ::prost::alloc::vec::Vec, + #[prost(bytes = "vec", tag = "2")] + pub left_mask_share: ::prost::alloc::vec::Vec, + #[prost(bytes = "vec", tag = "3")] + pub right_iris_code_share: ::prost::alloc::vec::Vec, + #[prost(bytes = "vec", tag = "4")] + pub right_mask_share: ::prost::alloc::vec::Vec, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct IrisCodeReShareRequest { + #[prost(uint64, tag = "1")] + pub sender_id: u64, + #[prost(uint64, tag = "2")] + pub other_id: u64, + #[prost(uint64, tag = "3")] + pub receiver_id: u64, + #[prost(int64, tag = "4")] + pub id_range_start_inclusive: i64, + #[prost(int64, tag = "5")] + pub id_range_end_non_inclusive: i64, + #[prost(message, repeated, tag = "6")] + pub iris_code_re_shares: ::prost::alloc::vec::Vec, + #[prost(bytes = "vec", tag = "7")] + pub client_correlation_sanity_check: ::prost::alloc::vec::Vec, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct IrisCodeReShareResponse { + #[prost(enumeration = "IrisCodeReShareStatus", tag = "1")] + pub status: i32, + #[prost(string, tag = "2")] + pub message: ::prost::alloc::string::String, +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] +pub enum IrisCodeReShareStatus { + Ok = 0, + FullQueue = 1, + Error = 2, +} +impl IrisCodeReShareStatus { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + Self::Ok => "IRIS_CODE_RE_SHARE_STATUS_OK", + Self::FullQueue => "IRIS_CODE_RE_SHARE_STATUS_FULL_QUEUE", + Self::Error => "IRIS_CODE_RE_SHARE_STATUS_ERROR", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "IRIS_CODE_RE_SHARE_STATUS_OK" => Some(Self::Ok), + "IRIS_CODE_RE_SHARE_STATUS_FULL_QUEUE" => Some(Self::FullQueue), + "IRIS_CODE_RE_SHARE_STATUS_ERROR" => Some(Self::Error), + _ => None, + } + } +} +/// Generated client implementations. +pub mod iris_code_re_share_service_client { + #![allow( + unused_variables, + dead_code, + missing_docs, + clippy::wildcard_imports, + clippy::let_unit_value, + )] + use tonic::codegen::*; + use tonic::codegen::http::Uri; + #[derive(Debug, Clone)] + pub struct IrisCodeReShareServiceClient { + inner: tonic::client::Grpc, + } + impl IrisCodeReShareServiceClient { + /// Attempt to create a new client by connecting to a given endpoint. + pub async fn connect(dst: D) -> Result + where + D: TryInto, + D::Error: Into, + { + let conn = tonic::transport::Endpoint::new(dst)?.connect().await?; + Ok(Self::new(conn)) + } + } + impl IrisCodeReShareServiceClient + where + T: tonic::client::GrpcService, + T::Error: Into, + T::ResponseBody: Body + std::marker::Send + 'static, + ::Error: Into + std::marker::Send, + { + pub fn new(inner: T) -> Self { + let inner = tonic::client::Grpc::new(inner); + Self { inner } + } + pub fn with_origin(inner: T, origin: Uri) -> Self { + let inner = tonic::client::Grpc::with_origin(inner, origin); + Self { inner } + } + pub fn with_interceptor( + inner: T, + interceptor: F, + ) -> IrisCodeReShareServiceClient> + where + F: tonic::service::Interceptor, + T::ResponseBody: Default, + T: tonic::codegen::Service< + http::Request, + Response = http::Response< + >::ResponseBody, + >, + >, + , + >>::Error: Into + std::marker::Send + std::marker::Sync, + { + IrisCodeReShareServiceClient::new( + InterceptedService::new(inner, interceptor), + ) + } + /// Compress requests with the given encoding. + /// + /// This requires the server to support it otherwise it might respond with an + /// error. + #[must_use] + pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self { + self.inner = self.inner.send_compressed(encoding); + self + } + /// Enable decompressing responses. + #[must_use] + pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self { + self.inner = self.inner.accept_compressed(encoding); + self + } + /// Limits the maximum size of a decoded message. + /// + /// Default: `4MB` + #[must_use] + pub fn max_decoding_message_size(mut self, limit: usize) -> Self { + self.inner = self.inner.max_decoding_message_size(limit); + self + } + /// Limits the maximum size of an encoded message. + /// + /// Default: `usize::MAX` + #[must_use] + pub fn max_encoding_message_size(mut self, limit: usize) -> Self { + self.inner = self.inner.max_encoding_message_size(limit); + self + } + pub async fn re_share( + &mut self, + request: impl tonic::IntoRequest, + ) -> std::result::Result< + tonic::Response, + tonic::Status, + > { + self.inner + .ready() + .await + .map_err(|e| { + tonic::Status::unknown( + format!("Service was not ready: {}", e.into()), + ) + })?; + let codec = tonic::codec::ProstCodec::default(); + let path = http::uri::PathAndQuery::from_static( + "/iris_mpc_reshare.IrisCodeReShareService/ReShare", + ); + let mut req = request.into_request(); + req.extensions_mut() + .insert( + GrpcMethod::new("iris_mpc_reshare.IrisCodeReShareService", "ReShare"), + ); + self.inner.unary(req, path, codec).await + } + } +} +/// Generated server implementations. +pub mod iris_code_re_share_service_server { + #![allow( + unused_variables, + dead_code, + missing_docs, + clippy::wildcard_imports, + clippy::let_unit_value, + )] + use tonic::codegen::*; + /// Generated trait containing gRPC methods that should be implemented for use with IrisCodeReShareServiceServer. + #[async_trait] + pub trait IrisCodeReShareService: std::marker::Send + std::marker::Sync + 'static { + async fn re_share( + &self, + request: tonic::Request, + ) -> std::result::Result< + tonic::Response, + tonic::Status, + >; + } + #[derive(Debug)] + pub struct IrisCodeReShareServiceServer { + inner: Arc, + accept_compression_encodings: EnabledCompressionEncodings, + send_compression_encodings: EnabledCompressionEncodings, + max_decoding_message_size: Option, + max_encoding_message_size: Option, + } + impl IrisCodeReShareServiceServer { + pub fn new(inner: T) -> Self { + Self::from_arc(Arc::new(inner)) + } + pub fn from_arc(inner: Arc) -> Self { + Self { + inner, + accept_compression_encodings: Default::default(), + send_compression_encodings: Default::default(), + max_decoding_message_size: None, + max_encoding_message_size: None, + } + } + pub fn with_interceptor( + inner: T, + interceptor: F, + ) -> InterceptedService + where + F: tonic::service::Interceptor, + { + InterceptedService::new(Self::new(inner), interceptor) + } + /// Enable decompressing requests with the given encoding. + #[must_use] + pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self { + self.accept_compression_encodings.enable(encoding); + self + } + /// Compress responses with the given encoding, if the client supports it. + #[must_use] + pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self { + self.send_compression_encodings.enable(encoding); + self + } + /// Limits the maximum size of a decoded message. + /// + /// Default: `4MB` + #[must_use] + pub fn max_decoding_message_size(mut self, limit: usize) -> Self { + self.max_decoding_message_size = Some(limit); + self + } + /// Limits the maximum size of an encoded message. + /// + /// Default: `usize::MAX` + #[must_use] + pub fn max_encoding_message_size(mut self, limit: usize) -> Self { + self.max_encoding_message_size = Some(limit); + self + } + } + impl tonic::codegen::Service> + for IrisCodeReShareServiceServer + where + T: IrisCodeReShareService, + B: Body + std::marker::Send + 'static, + B::Error: Into + std::marker::Send + 'static, + { + type Response = http::Response; + type Error = std::convert::Infallible; + type Future = BoxFuture; + fn poll_ready( + &mut self, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + fn call(&mut self, req: http::Request) -> Self::Future { + match req.uri().path() { + "/iris_mpc_reshare.IrisCodeReShareService/ReShare" => { + #[allow(non_camel_case_types)] + struct ReShareSvc(pub Arc); + impl< + T: IrisCodeReShareService, + > tonic::server::UnaryService + for ReShareSvc { + type Response = super::IrisCodeReShareResponse; + type Future = BoxFuture< + tonic::Response, + tonic::Status, + >; + fn call( + &mut self, + request: tonic::Request, + ) -> Self::Future { + let inner = Arc::clone(&self.0); + let fut = async move { + ::re_share(&inner, request) + .await + }; + Box::pin(fut) + } + } + let accept_compression_encodings = self.accept_compression_encodings; + let send_compression_encodings = self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; + let inner = self.inner.clone(); + let fut = async move { + let method = ReShareSvc(inner); + let codec = tonic::codec::ProstCodec::default(); + let mut grpc = tonic::server::Grpc::new(codec) + .apply_compression_config( + accept_compression_encodings, + send_compression_encodings, + ) + .apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, + ); + let res = grpc.unary(method, req).await; + Ok(res) + }; + Box::pin(fut) + } + _ => { + Box::pin(async move { + let mut response = http::Response::new(empty_body()); + let headers = response.headers_mut(); + headers + .insert( + tonic::Status::GRPC_STATUS, + (tonic::Code::Unimplemented as i32).into(), + ); + headers + .insert( + http::header::CONTENT_TYPE, + tonic::metadata::GRPC_CONTENT_TYPE, + ); + Ok(response) + }) + } + } + } + } + impl Clone for IrisCodeReShareServiceServer { + fn clone(&self) -> Self { + let inner = self.inner.clone(); + Self { + inner, + accept_compression_encodings: self.accept_compression_encodings, + send_compression_encodings: self.send_compression_encodings, + max_decoding_message_size: self.max_decoding_message_size, + max_encoding_message_size: self.max_encoding_message_size, + } + } + } + /// Generated gRPC service name + pub const SERVICE_NAME: &str = "iris_mpc_reshare.IrisCodeReShareService"; + impl tonic::server::NamedService for IrisCodeReShareServiceServer { + const NAME: &'static str = SERVICE_NAME; + } +} diff --git a/iris-mpc-upgrade/src/proto/mod.rs b/iris-mpc-upgrade/src/proto/mod.rs new file mode 100644 index 000000000..f230282aa --- /dev/null +++ b/iris-mpc-upgrade/src/proto/mod.rs @@ -0,0 +1,30 @@ +use iris_mpc_common::{IRIS_CODE_LENGTH, MASK_CODE_LENGTH}; +use iris_mpc_reshare::IrisCodeReShare; +use prost::Message; + +// this is generated code so we skip linting it +#[rustfmt::skip] +#[allow(clippy::all)] +pub mod iris_mpc_reshare; + +pub fn get_size_of_reshare_iris_code_share_batch(batch_size: usize) -> usize { + let dummy = iris_mpc_reshare::IrisCodeReShareRequest { + sender_id: 0, + other_id: 1, + receiver_id: 2, + id_range_start_inclusive: 0, + id_range_end_non_inclusive: batch_size as i64, + iris_code_re_shares: vec![ + IrisCodeReShare { + left_iris_code_share: vec![1u8; IRIS_CODE_LENGTH * size_of::()], + left_mask_share: vec![2u8; MASK_CODE_LENGTH * size_of::()], + right_iris_code_share: vec![3u8; IRIS_CODE_LENGTH * size_of::()], + right_mask_share: vec![4u8; MASK_CODE_LENGTH * size_of::()], + }; + batch_size + ], + client_correlation_sanity_check: vec![7u8; 32], + }; + + dummy.encoded_len() +} diff --git a/iris-mpc-upgrade/src/reshare.rs b/iris-mpc-upgrade/src/reshare.rs new file mode 100644 index 000000000..bef4f220e --- /dev/null +++ b/iris-mpc-upgrade/src/reshare.rs @@ -0,0 +1,766 @@ +//! # Iris Code Resharing +//! +//! This module has functionality for resharing a secret shared iris code to a +//! new party, producing a valid share for the new party, without leaking +//! information about the individual shares of the sending parties. + +use crate::proto::{ + self, + iris_mpc_reshare::{ + iris_code_re_share_service_server, IrisCodeReShare, IrisCodeReShareRequest, + IrisCodeReShareStatus, + }, +}; +use iris_mpc_common::{ + galois::degree4::{basis::Monomial, GaloisRingElement, ShamirGaloisRingShare}, + galois_engine::degree4::{GaloisRingIrisCodeShare, GaloisRingTrimmedMaskCodeShare}, + IRIS_CODE_LENGTH, MASK_CODE_LENGTH, +}; +use iris_mpc_store::{Store, StoredIrisRef}; +use itertools::{izip, Itertools}; +use rand::{CryptoRng, Rng, SeedableRng}; +use sha2::{Digest, Sha256}; +use std::{collections::VecDeque, sync::Mutex}; +use tonic::Response; + +pub struct IrisCodeReshareSenderHelper { + my_party_id: usize, + other_party_id: usize, + target_party_id: usize, + lagrange_helper: GaloisRingElement, + common_seed: [u8; 32], + current_packet: Option, +} + +impl IrisCodeReshareSenderHelper { + pub fn new( + my_party_id: usize, + other_party_id: usize, + target_party_id: usize, + common_seed: [u8; 32], + ) -> Self { + let lagrange_helper = ShamirGaloisRingShare::deg_1_lagrange_poly_at_v( + my_party_id, + other_party_id, + target_party_id, + ); + Self { + my_party_id, + other_party_id, + target_party_id, + lagrange_helper, + common_seed, + current_packet: None, + } + } + fn reshare_with_random_additive_zero( + &self, + share: GaloisRingElement, + rng: &mut (impl CryptoRng + Rng), + ) -> GaloisRingElement { + let random_mask = GaloisRingElement::::random(rng); + if self.my_party_id < self.other_party_id { + share + random_mask + } else { + share - random_mask + } + } + + fn reshare_code( + &self, + mut code_share: GaloisRingIrisCodeShare, + rng: &mut (impl CryptoRng + Rng), + ) -> Vec { + for i in (0..IRIS_CODE_LENGTH).step_by(4) { + let mut share = GaloisRingElement::::from_coefs([ + code_share.coefs[i], + code_share.coefs[i + 1], + code_share.coefs[i + 2], + code_share.coefs[i + 3], + ]); + share = share * self.lagrange_helper; + share = self.reshare_with_random_additive_zero(share, rng); + code_share.coefs[i] = share.coefs[0]; + code_share.coefs[i + 1] = share.coefs[1]; + code_share.coefs[i + 2] = share.coefs[2]; + code_share.coefs[i + 3] = share.coefs[3]; + } + code_share + .coefs + .into_iter() + .flat_map(|x| x.to_le_bytes()) + .collect() + } + fn reshare_mask( + &self, + mut mask_share: GaloisRingTrimmedMaskCodeShare, + rng: &mut (impl CryptoRng + Rng), + ) -> Vec { + for i in (0..MASK_CODE_LENGTH).step_by(4) { + let mut share = GaloisRingElement::::from_coefs([ + mask_share.coefs[i], + mask_share.coefs[i + 1], + mask_share.coefs[i + 2], + mask_share.coefs[i + 3], + ]); + share = share * self.lagrange_helper; + share = self.reshare_with_random_additive_zero(share, rng); + mask_share.coefs[i] = share.coefs[0]; + mask_share.coefs[i + 1] = share.coefs[1]; + mask_share.coefs[i + 2] = share.coefs[2]; + mask_share.coefs[i + 3] = share.coefs[3]; + } + mask_share + .coefs + .into_iter() + .flat_map(|x| x.to_le_bytes()) + .collect() + } + + /// Start the production of a new reshare batch request. + /// The batch will contain reshared iris codes for the given range of + /// database indices. The start range is inclusive, the end range is + /// exclusive. + /// + /// # Panics + /// + /// Panics if this is called while a batch is already being built. + pub fn start_reshare_batch(&mut self, start_db_index: i64, end_db_index: i64) { + assert!( + self.current_packet.is_none(), + "We expected no batch to be currently being built, but it is..." + ); + let mut digest = Sha256::new(); + digest.update(self.common_seed); + digest.update(start_db_index.to_le_bytes()); + digest.update(end_db_index.to_le_bytes()); + digest.update(b"ReShareSanityCheck"); + + self.current_packet = Some(IrisCodeReShareRequest { + sender_id: self.my_party_id as u64, + other_id: self.other_party_id as u64, + receiver_id: self.target_party_id as u64, + id_range_start_inclusive: start_db_index, + id_range_end_non_inclusive: end_db_index, + iris_code_re_shares: Vec::new(), + client_correlation_sanity_check: digest.finalize().as_slice().to_vec(), + }); + } + + /// Adds a new iris code to the current reshare batch. + /// + /// # Panics + /// + /// Panics if this is called without [Self::start_reshare_batch] being + /// called beforehand. + /// Panics if this is called with an iris code id that is out of the range + /// of the current batch. + pub fn add_reshare_iris_to_batch( + &mut self, + iris_code_id: i64, + left_code_share: GaloisRingIrisCodeShare, + left_mask_share: GaloisRingTrimmedMaskCodeShare, + right_code_share: GaloisRingIrisCodeShare, + right_mask_share: GaloisRingTrimmedMaskCodeShare, + ) { + assert!( + self.current_packet.is_some(), + "We expect a batch to be currently being built" + ); + assert!( + self.current_packet + .as_ref() + .unwrap() + .id_range_start_inclusive + <= iris_code_id + && self + .current_packet + .as_ref() + .unwrap() + .id_range_end_non_inclusive + > iris_code_id, + "The iris code id is out of the range of the current batch" + ); + let mut digest = Sha256::new(); + digest.update(self.common_seed); + digest.update(iris_code_id.to_le_bytes()); + let mut rng = rand_chacha::ChaChaRng::from_seed(digest.finalize().into()); + let left_reshared_code = self.reshare_code(left_code_share, &mut rng); + let left_reshared_mask = self.reshare_mask(left_mask_share, &mut rng); + let right_reshared_code = self.reshare_code(right_code_share, &mut rng); + let right_reshared_mask = self.reshare_mask(right_mask_share, &mut rng); + + let reshare = IrisCodeReShare { + left_iris_code_share: left_reshared_code, + left_mask_share: left_reshared_mask, + right_iris_code_share: right_reshared_code, + right_mask_share: right_reshared_mask, + }; + self.current_packet + .as_mut() + .expect("There is currently a batch being built") + .iris_code_re_shares + .push(reshare); + } + + /// Finalizes the current reshare batch and returns the reshare request. + /// + /// # Panics + /// + /// Panics if this is called without [Self::start_reshare_batch] being + /// called beforehand. Also panics if this is called without the correct + /// number of iris codes being added to the batch. + pub fn finalize_reshare_batch(&mut self) -> IrisCodeReShareRequest { + assert!(self.current_packet.is_some(), "No batch to finalize"); + let packet = self.current_packet.take().unwrap(); + assert_eq!( + packet.iris_code_re_shares.len(), + (packet.id_range_end_non_inclusive - packet.id_range_start_inclusive) as usize, + "Expected the correct number of iris codes to be added to the batch" + ); + packet + } +} + +#[derive(Debug, thiserror::Error)] +pub enum IrisCodeReShareError { + #[error("Invalid reshare request received: {reason}")] + InvalidRequest { reason: String }, + #[error( + "Too many requests received from this party ({party_id}) without matching request from \ + the other party ({other_party_id}" + )] + TooManyRequests { + party_id: usize, + other_party_id: usize, + }, +} + +#[derive(Debug)] +pub struct IrisCodeReshareReceiverHelper { + my_party_id: usize, + sender1_party_id: usize, + sender2_party_id: usize, + max_buffer_size: usize, + sender_1_buffer: Mutex>, + sender_2_buffer: Mutex>, +} + +impl IrisCodeReshareReceiverHelper { + pub fn new( + my_party_id: usize, + sender1_party_id: usize, + sender2_party_id: usize, + max_buffer_size: usize, + ) -> Self { + Self { + my_party_id, + sender1_party_id, + sender2_party_id, + max_buffer_size, + sender_1_buffer: Mutex::new(VecDeque::new()), + sender_2_buffer: Mutex::new(VecDeque::new()), + } + } + + fn check_valid(&self, request: &IrisCodeReShareRequest) -> Result<(), IrisCodeReShareError> { + if request.sender_id as usize == self.sender1_party_id { + if request.other_id as usize != self.sender2_party_id { + return Err(IrisCodeReShareError::InvalidRequest { + reason: "Received a request from unexpected set of parties".to_string(), + }); + } + } else if request.sender_id as usize == self.sender2_party_id { + if request.other_id as usize != self.sender1_party_id { + return Err(IrisCodeReShareError::InvalidRequest { + reason: "Received a request from unexpected set of parties".to_string(), + }); + } + } else { + return Err(IrisCodeReShareError::InvalidRequest { + reason: "Received a request from unexpected set of parties".to_string(), + }); + } + if request.receiver_id != self.my_party_id as u64 { + return Err(IrisCodeReShareError::InvalidRequest { + reason: "Received a request intended for a different party".to_string(), + }); + } + if request.id_range_start_inclusive >= request.id_range_end_non_inclusive { + return Err(IrisCodeReShareError::InvalidRequest { + reason: "Invalid range of iris codes in received request".to_string(), + }); + } + if request.iris_code_re_shares.len() + != (request.id_range_end_non_inclusive - request.id_range_start_inclusive) as usize + { + return Err(IrisCodeReShareError::InvalidRequest { + reason: "Invalid number of iris codes in received request".to_string(), + }); + } + + // Check that the iris code shares are of the correct length + if !request.iris_code_re_shares.iter().all(|reshare| { + reshare.left_iris_code_share.len() == IRIS_CODE_LENGTH * std::mem::size_of::() + && reshare.left_mask_share.len() == MASK_CODE_LENGTH * std::mem::size_of::() + && reshare.right_iris_code_share.len() + == IRIS_CODE_LENGTH * std::mem::size_of::() + && reshare.right_mask_share.len() == MASK_CODE_LENGTH * std::mem::size_of::() + }) { + return Err(IrisCodeReShareError::InvalidRequest { + reason: "Invalid iris code/mask share length".to_string(), + }); + } + Ok(()) + } + + pub fn add_request_batch( + &self, + request: IrisCodeReShareRequest, + ) -> Result<(), IrisCodeReShareError> { + self.check_valid(&request)?; + if request.sender_id as usize == self.sender1_party_id { + let mut sender_1_buffer = self.sender_1_buffer.lock().unwrap(); + if sender_1_buffer.len() + 1 >= self.max_buffer_size { + return Err(IrisCodeReShareError::TooManyRequests { + party_id: self.sender1_party_id, + other_party_id: self.sender2_party_id, + }); + } + sender_1_buffer.push_back(request); + } else if request.sender_id as usize == self.sender2_party_id { + let mut sender_2_buffer = self.sender_2_buffer.lock().unwrap(); + if sender_2_buffer.len() + 1 >= self.max_buffer_size { + return Err(IrisCodeReShareError::TooManyRequests { + party_id: self.sender2_party_id, + other_party_id: self.sender1_party_id, + }); + } + sender_2_buffer.push_back(request); + } else { + // check valid should have caught this + unreachable!() + } + + Ok(()) + } + + fn check_requests_matching( + &self, + request1: &IrisCodeReShareRequest, + request2: &IrisCodeReShareRequest, + ) -> Result<(), IrisCodeReShareError> { + if request1.id_range_start_inclusive != request2.id_range_start_inclusive + || request1.id_range_end_non_inclusive != request2.id_range_end_non_inclusive + { + return Err(IrisCodeReShareError::InvalidRequest { + reason: format!( + "Received requests with different iris code ranges: {}-{} from {} and {}-{} \ + from {}", + request1.id_range_start_inclusive, + request1.id_range_end_non_inclusive, + request1.sender_id, + request2.id_range_start_inclusive, + request2.id_range_end_non_inclusive, + request2.sender_id, + ), + }); + } + + if request1.client_correlation_sanity_check != request2.client_correlation_sanity_check { + return Err(IrisCodeReShareError::InvalidRequest { + reason: "Received requests with different correlation sanity checks, recheck the \ + used Keys for common secret derivation" + .to_string(), + }); + } + Ok(()) + } + + fn reshare_code_batch( + &self, + request1: IrisCodeReShareRequest, + request2: IrisCodeReShareRequest, + ) -> Result { + let len = request1.iris_code_re_shares.len(); + let mut left_code = Vec::with_capacity(len); + let mut left_mask = Vec::with_capacity(len); + let mut right_code = Vec::with_capacity(len); + let mut right_mask = Vec::with_capacity(len); + + for (reshare1, reshare2) in + izip!(request1.iris_code_re_shares, request2.iris_code_re_shares) + { + // build galois shares from the u8 Vecs + let mut left_code_share1 = GaloisRingIrisCodeShare { + id: self.my_party_id + 1, + coefs: reshare1 + .left_iris_code_share + .chunks_exact(size_of::()) + .map(|x| u16::from_le_bytes(x.try_into().unwrap())) + .collect_vec() + .try_into() + // we checked this beforehand in check_valid + .expect("Invalid iris code share length"), + }; + let mut left_mask_share1 = GaloisRingTrimmedMaskCodeShare { + id: self.my_party_id + 1, + coefs: reshare1 + .left_mask_share + .chunks_exact(size_of::()) + .map(|x| u16::from_le_bytes(x.try_into().unwrap())) + // we checked this beforehand in check_valid + .collect_vec() + .try_into() + .expect("Invalid mask share length"), + }; + let left_code_share2 = GaloisRingIrisCodeShare { + id: self.my_party_id + 1, + coefs: reshare2 + .left_iris_code_share + .chunks_exact(size_of::()) + .map(|x| u16::from_le_bytes(x.try_into().unwrap())) + .collect_vec() + .try_into() + // we checked this beforehand in check_valid + .expect("Invalid iris code share length"), + }; + let left_mask_share2 = GaloisRingTrimmedMaskCodeShare { + id: self.my_party_id + 1, + coefs: reshare2 + .left_mask_share + .chunks_exact(size_of::()) + .map(|x| u16::from_le_bytes(x.try_into().unwrap())) + // we checked this beforehand in check_valid + .collect_vec() + .try_into() + .expect("Invalid mask share length"), + }; + + // add them together + left_code_share1 + .coefs + .iter_mut() + .zip(left_code_share2.coefs.iter()) + .for_each(|(x, y)| { + *x = x.wrapping_add(*y); + }); + left_mask_share1 + .coefs + .iter_mut() + .zip(left_mask_share2.coefs.iter()) + .for_each(|(x, y)| { + *x = x.wrapping_add(*y); + }); + + left_code.push(left_code_share1); + left_mask.push(left_mask_share1); + + // now the right eye + // build galois shares from the u8 Vecs + let mut right_code_share1 = GaloisRingIrisCodeShare { + id: self.my_party_id + 1, + coefs: reshare1 + .right_iris_code_share + .chunks_exact(size_of::()) + .map(|x| u16::from_le_bytes(x.try_into().unwrap())) + .collect_vec() + .try_into() + // we checked this beforehand in check_valid + .expect("Invalid iris code share length"), + }; + let mut right_mask_share1 = GaloisRingTrimmedMaskCodeShare { + id: self.my_party_id + 1, + coefs: reshare1 + .right_mask_share + .chunks_exact(size_of::()) + .map(|x| u16::from_le_bytes(x.try_into().unwrap())) + // we checked this beforehand in check_valid + .collect_vec() + .try_into() + .expect("Invalid mask share length"), + }; + let right_code_share2 = GaloisRingIrisCodeShare { + id: self.my_party_id + 1, + coefs: reshare2 + .right_iris_code_share + .chunks_exact(size_of::()) + .map(|x| u16::from_le_bytes(x.try_into().unwrap())) + .collect_vec() + .try_into() + // we checked this beforehand in check_valid + .expect("Invalid iris code share length"), + }; + let right_mask_share2 = GaloisRingTrimmedMaskCodeShare { + id: self.my_party_id + 1, + coefs: reshare2 + .right_mask_share + .chunks_exact(std::mem::size_of::()) + .map(|x| u16::from_le_bytes(x.try_into().unwrap())) + // we checked this beforehand in check_valid + .collect_vec() + .try_into() + .expect("Invalid mask share length"), + }; + + // add them together + right_code_share1 + .coefs + .iter_mut() + .zip(right_code_share2.coefs.iter()) + .for_each(|(x, y)| { + *x = x.wrapping_add(*y); + }); + right_mask_share1 + .coefs + .iter_mut() + .zip(right_mask_share2.coefs.iter()) + .for_each(|(x, y)| { + *x = x.wrapping_add(*y); + }); + + right_code.push(right_code_share1); + right_mask.push(right_mask_share1); + } + + Ok(RecombinedIrisCodeBatch { + range_start_inclusive: request1.id_range_start_inclusive, + range_end_exclusive: request1.id_range_end_non_inclusive, + left_iris_codes: left_code, + left_masks: left_mask, + right_iris_codes: right_code, + right_masks: right_mask, + }) + } + + pub fn try_handle_batch( + &self, + ) -> Result, IrisCodeReShareError> { + let mut sender_1_buffer = self.sender_1_buffer.lock().unwrap(); + let mut sender_2_buffer = self.sender_2_buffer.lock().unwrap(); + if sender_1_buffer.is_empty() || sender_2_buffer.is_empty() { + return Ok(None); + } + + let sender_1_batch = sender_1_buffer.pop_front().unwrap(); + let sender_2_batch = sender_2_buffer.pop_front().unwrap(); + drop(sender_1_buffer); + drop(sender_2_buffer); + + self.check_requests_matching(&sender_1_batch, &sender_2_batch)?; + + let reshare = self.reshare_code_batch(sender_1_batch, sender_2_batch)?; + + Ok(Some(reshare)) + } +} + +/// A batch of recombined iris codes, produced by resharing iris codes from two +/// other parties. This should be inserted into the database. +pub struct RecombinedIrisCodeBatch { + range_start_inclusive: i64, + #[expect(unused)] + range_end_exclusive: i64, + left_iris_codes: Vec, + left_masks: Vec, + right_iris_codes: Vec, + right_masks: Vec, +} + +impl RecombinedIrisCodeBatch { + pub async fn insert_into_store(self, store: &Store) -> eyre::Result<()> { + let to_be_inserted = izip!( + &self.left_iris_codes, + &self.left_masks, + &self.right_iris_codes, + &self.right_masks + ) + .enumerate() + .map(|(idx, (left_iris, left_mask, right_iris, right_mask))| { + let id = self.range_start_inclusive + idx as i64; + StoredIrisRef { + id, + left_code: &left_iris.coefs, + left_mask: &left_mask.coefs, + right_code: &right_iris.coefs, + right_mask: &right_mask.coefs, + } + }) + .collect::>(); + let mut tx = store.tx().await?; + store + .insert_irises_overriding(&mut tx, &to_be_inserted) + .await?; + tx.commit().await?; + Ok(()) + } +} + +pub struct GrpcReshareServer { + store: Store, + receiver_helper: IrisCodeReshareReceiverHelper, +} + +impl GrpcReshareServer { + pub fn new(store: Store, receiver_helper: IrisCodeReshareReceiverHelper) -> Self { + Self { + store, + receiver_helper, + } + } +} + +#[tonic::async_trait] +impl iris_code_re_share_service_server::IrisCodeReShareService for GrpcReshareServer { + async fn re_share( + &self, + request: tonic::Request, + ) -> Result, tonic::Status> { + match self.receiver_helper.add_request_batch(request.into_inner()) { + Ok(()) => (), + Err(err) => { + tracing::warn!(error = err.to_string(), "Error handling reshare request"); + return match err { + IrisCodeReShareError::InvalidRequest { reason } => Ok(Response::new( + proto::iris_mpc_reshare::IrisCodeReShareResponse { + status: IrisCodeReShareStatus::Error as i32, + message: reason, + }, + )), + IrisCodeReShareError::TooManyRequests { .. } => Ok(Response::new( + proto::iris_mpc_reshare::IrisCodeReShareResponse { + status: IrisCodeReShareStatus::FullQueue as i32, + message: err.to_string(), + }, + )), + }; + } + } + // we received a batch, try to handle it + match self.receiver_helper.try_handle_batch() { + Ok(Some(batch)) => { + // write the reshared iris codes to the database + match batch.insert_into_store(&self.store).await { + Ok(()) => (), + Err(err) => { + tracing::error!( + error = err.to_string(), + "Error inserting reshared iris codes into DB" + ); + } + } + } + Ok(None) => (), + Err(err) => { + tracing::warn!(error = err.to_string(), "Error handling reshare request"); + return Ok(Response::new( + proto::iris_mpc_reshare::IrisCodeReShareResponse { + status: IrisCodeReShareStatus::Error as i32, + message: err.to_string(), + }, + )); + } + } + + Ok(Response::new( + proto::iris_mpc_reshare::IrisCodeReShareResponse { + status: IrisCodeReShareStatus::Ok as i32, + message: Default::default(), + }, + )) + } +} + +#[cfg(test)] +mod tests { + use super::IrisCodeReshareSenderHelper; + use crate::reshare::IrisCodeReshareReceiverHelper; + use iris_mpc_common::{ + galois_engine::degree4::FullGaloisRingIrisCodeShare, iris_db::db::IrisDB, + }; + use itertools::Itertools; + use rand::thread_rng; + + #[test] + fn test_basic_resharing() { + const DB_SIZE: usize = 100; + + let left_db = IrisDB::new_random_rng(DB_SIZE, &mut thread_rng()); + let right_db = IrisDB::new_random_rng(DB_SIZE, &mut thread_rng()); + + let (party0_db_left, party1_db_left, party2_db_left): (Vec<_>, Vec<_>, Vec<_>) = left_db + .db + .iter() + .map(|x| { + let [a, b, c] = FullGaloisRingIrisCodeShare::encode_iris_code(x, &mut thread_rng()); + (a, b, c) + }) + .multiunzip(); + let (party0_db_right, party1_db_right, party2_db_right): (Vec<_>, Vec<_>, Vec<_>) = + right_db + .db + .iter() + .map(|x| { + let [a, b, c] = + FullGaloisRingIrisCodeShare::encode_iris_code(x, &mut thread_rng()); + (a, b, c) + }) + .multiunzip(); + + let mut reshare_helper_0_1_2 = IrisCodeReshareSenderHelper::new(0, 1, 2, [0; 32]); + let mut reshare_helper_1_0_2 = IrisCodeReshareSenderHelper::new(1, 0, 2, [0; 32]); + let reshare_helper_2 = IrisCodeReshareReceiverHelper::new(2, 0, 1, 100); + + reshare_helper_0_1_2.start_reshare_batch(0, DB_SIZE as i64); + for (idx, (left, right)) in party0_db_left + .iter() + .zip(party0_db_right.iter()) + .enumerate() + { + reshare_helper_0_1_2.add_reshare_iris_to_batch( + idx as i64, + left.code.clone(), + left.mask.clone(), + right.code.clone(), + right.mask.clone(), + ); + } + let reshare_request_0_1_2 = reshare_helper_0_1_2.finalize_reshare_batch(); + + reshare_helper_1_0_2.start_reshare_batch(0, DB_SIZE as i64); + for (idx, (left, right)) in party1_db_left + .iter() + .zip(party1_db_right.iter()) + .enumerate() + { + reshare_helper_1_0_2.add_reshare_iris_to_batch( + idx as i64, + left.code.clone(), + left.mask.clone(), + right.code.clone(), + right.mask.clone(), + ); + } + let reshare_request_1_0_2 = reshare_helper_1_0_2.finalize_reshare_batch(); + + reshare_helper_2 + .add_request_batch(reshare_request_0_1_2) + .unwrap(); + reshare_helper_2 + .add_request_batch(reshare_request_1_0_2) + .unwrap(); + + let reshare_batch = reshare_helper_2.try_handle_batch().unwrap().unwrap(); + + for (idx, (left, right)) in party2_db_left + .iter() + .zip(party2_db_right.iter()) + .enumerate() + { + assert_eq!(&left.code, &reshare_batch.left_iris_codes[idx]); + assert_eq!(&left.mask, &reshare_batch.left_masks[idx]); + assert_eq!(&right.code, &reshare_batch.right_iris_codes[idx]); + assert_eq!(&right.mask, &reshare_batch.right_masks[idx]); + } + } +} diff --git a/iris-mpc-upgrade/src/utils.rs b/iris-mpc-upgrade/src/utils.rs index 34b5a70d3..a6c9ca385 100644 --- a/iris-mpc-upgrade/src/utils.rs +++ b/iris-mpc-upgrade/src/utils.rs @@ -3,6 +3,8 @@ use crate::{ packets::{MaskShareMessage, TwoToThreeIrisCodeMessage}, OldIrisShareSource, }; +use axum::{routing::get, Router}; +use eyre::Context; use futures::{Stream, StreamExt}; use iris_mpc_common::galois_engine::degree4::{ GaloisRingIrisCodeShare, GaloisRingTrimmedMaskCodeShare, @@ -132,3 +134,14 @@ impl OldIrisShareSource for V1Database { })) } } + +pub async fn spawn_healthcheck_server(healthcheck_port: usize) -> eyre::Result<()> { + let app = Router::new().route("/health", get(|| async {})); // Implicit 200 response + let listener = tokio::net::TcpListener::bind(format!("0.0.0.0:{}", healthcheck_port)) + .await + .wrap_err("Healthcheck listener bind error")?; + axum::serve(listener, app) + .await + .wrap_err("healthcheck listener server launch error")?; + Ok(()) +} diff --git a/iris-mpc/Cargo.toml b/iris-mpc/Cargo.toml index 5e603a9b8..605a831d8 100644 --- a/iris-mpc/Cargo.toml +++ b/iris-mpc/Cargo.toml @@ -11,6 +11,7 @@ repository.workspace = true aws-config.workspace = true aws-sdk-sns.workspace = true aws-sdk-sqs.workspace = true +aws-sdk-s3.workspace = true axum.workspace = true tokio.workspace = true tracing.workspace = true @@ -34,6 +35,8 @@ iris-mpc-store = { path = "../iris-mpc-store" } sha2 = "0.10.8" metrics = "0.22.1" metrics-exporter-statsd = "0.7" +serde = { version = "1.0.214", features = ["derive"] } +bincode = "1.3.3" [dev-dependencies] criterion = "0.5" diff --git a/iris-mpc/src/bin/client.rs b/iris-mpc/src/bin/client.rs index 682b630d3..cc0cf0529 100644 --- a/iris-mpc/src/bin/client.rs +++ b/iris-mpc/src/bin/client.rs @@ -10,10 +10,8 @@ use iris_mpc_common::{ helpers::{ key_pair::download_public_key, sha256::calculate_sha256, - smpc_request::{ - create_message_type_attribute_map, IrisCodesJSON, UniquenessRequest, UniquenessResult, - UNIQUENESS_MESSAGE_TYPE, - }, + smpc_request::{IrisCodesJSON, UniquenessRequest, UNIQUENESS_MESSAGE_TYPE}, + smpc_response::{create_message_type_attribute_map, UniquenessResult}, sqs_s3_helper::upload_file_and_generate_presigned_url, }, iris_db::{db::IrisDB, iris::IrisCode}, @@ -374,7 +372,7 @@ async fn main() -> eyre::Result<()> { let request_message = UniquenessRequest { batch_size: None, signup_id: request_id.to_string(), - s3_presigned_url: presigned_url, + s3_key: presigned_url, iris_shares_file_hashes, }; diff --git a/iris-mpc/src/bin/server.rs b/iris-mpc/src/bin/server.rs index 09026a092..adf7beddb 100644 --- a/iris-mpc/src/bin/server.rs +++ b/iris-mpc/src/bin/server.rs @@ -1,11 +1,13 @@ #![allow(clippy::needless_range_loop)] +use aws_config::retry::RetryConfig; +use aws_sdk_s3::{config::Builder as S3ConfigBuilder, Client as S3Client}; use aws_sdk_sns::{types::MessageAttributeValue, Client as SNSClient}; use aws_sdk_sqs::{config::Region, Client}; -use axum::{routing::get, Router}; +use axum::{response::IntoResponse, routing::get, Router}; use clap::Parser; use eyre::{eyre, Context}; -use futures::TryStreamExt; +use futures::{stream::select_all, StreamExt, TryStreamExt}; use iris_mpc_common::{ config::{json_wrapper::JsonStrWrapper, Config, Opt}, galois_engine::degree4::{GaloisRingIrisCodeShare, GaloisRingTrimmedMaskCodeShare}, @@ -18,10 +20,13 @@ use iris_mpc_common::{ kms_dh::derive_shared_secret, shutdown_handler::ShutdownHandler, smpc_request::{ - create_message_type_attribute_map, CircuitBreakerRequest, IdentityDeletionRequest, - IdentityDeletionResult, ReceiveRequestError, SQSMessage, UniquenessRequest, - UniquenessResult, CIRCUIT_BREAKER_MESSAGE_TYPE, IDENTITY_DELETION_MESSAGE_TYPE, - SMPC_MESSAGE_TYPE_ATTRIBUTE, UNIQUENESS_MESSAGE_TYPE, + CircuitBreakerRequest, IdentityDeletionRequest, ReceiveRequestError, SQSMessage, + UniquenessRequest, CIRCUIT_BREAKER_MESSAGE_TYPE, IDENTITY_DELETION_MESSAGE_TYPE, + UNIQUENESS_MESSAGE_TYPE, + }, + smpc_response::{ + create_message_type_attribute_map, IdentityDeletionResult, UniquenessResult, + ERROR_FAILED_TO_PROCESS_IRIS_SHARES, SMPC_MESSAGE_TYPE_ATTRIBUTE, }, sync::SyncState, task_monitor::TaskMonitor, @@ -34,13 +39,21 @@ use iris_mpc_gpu::{ BatchQueryEntriesPreprocessed, ServerActor, ServerJobResult, }, }; -use iris_mpc_store::{Store, StoredIrisRef}; +use iris_mpc_store::{ + fetch_and_parse_chunks, last_snapshot_timestamp, IrisSource, S3Store, Store, StoredIrisRef, +}; use metrics_exporter_statsd::StatsdBuilder; +use reqwest::StatusCode; +use serde::{Deserialize, Serialize}; use std::{ backtrace::Backtrace, - collections::HashMap, + collections::{HashMap, HashSet}, mem, panic, - sync::{Arc, LazyLock, Mutex}, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, LazyLock, Mutex, + }, + time, time::{Duration, Instant}, }; use telemetry_batteries::tracing::{datadog::DatadogBattery, TracingShutdownHandle}; @@ -113,13 +126,17 @@ fn preprocess_iris_message_shares( async fn receive_batch( party_id: usize, client: &Client, - queue_url: &String, + sns_client: &SNSClient, + s3_client: &Arc, + config: &Config, store: &Store, skip_request_ids: &[String], shares_encryption_key_pairs: SharesEncryptionKeyPairs, - max_batch_size: usize, shutdown_handler: &ShutdownHandler, + error_result_attributes: &HashMap, ) -> eyre::Result, ReceiveRequestError> { + let max_batch_size = config.clone().max_batch_size; + let queue_url = &config.clone().requests_queue_url; if shutdown_handler.is_shutting_down() { tracing::info!("Stopping batch receive due to shutdown signal..."); return Ok(None); @@ -264,17 +281,21 @@ async fn receive_batch( batch_query.metadata.push(batch_metadata); let semaphore = Arc::clone(&semaphore); + let s3_client_arc = Arc::clone(s3_client); + let bucket_name = config.shares_bucket_name.clone(); let handle = tokio::spawn(async move { let _ = semaphore.acquire().await?; - let base_64_encoded_message_payload = - match smpc_request.get_iris_data_by_party_id(party_id).await { - Ok(iris_message_share) => iris_message_share, - Err(e) => { - tracing::error!("Failed to get iris shares: {:?}", e); - eyre::bail!("Failed to get iris shares: {:?}", e); - } - }; + let base_64_encoded_message_payload = match smpc_request + .get_iris_data_by_party_id(party_id, &bucket_name, &s3_client_arc) + .await + { + Ok(iris_message_share) => iris_message_share, + Err(e) => { + tracing::error!("Failed to get iris shares: {:?}", e); + eyre::bail!("Failed to get iris shares: {:?}", e); + } + }; let iris_message_share = match smpc_request.decrypt_iris_share( base_64_encoded_message_payload, @@ -344,8 +365,7 @@ async fn receive_batch( tokio::time::sleep(SQS_POLLING_INTERVAL).await; } } - - for handle in handles { + for (index, handle) in handles.into_iter().enumerate() { let ( ( ( @@ -373,6 +393,18 @@ async fn receive_batch( Ok(res) => (res, true), Err(e) => { tracing::error!("Failed to process iris shares: {:?}", e); + // Return error message back to the signup-service if failed to process iris + // shares + send_error_results_to_sns( + batch_query.request_ids[index].clone(), + &batch_query.metadata[index], + sns_client, + config, + error_result_attributes, + UNIQUENESS_MESSAGE_TYPE, + ERROR_FAILED_TO_PROCESS_IRIS_SHARES, + ) + .await?; // If we failed to process the iris shares, we include a dummy entry in the // batch in order to keep the same order across nodes let dummy_code_share = GaloisRingIrisCodeShare::default_for_party(party_id); @@ -418,6 +450,8 @@ async fn receive_batch( batch_query.query_right.mask.extend(mask_shares_right); } + tracing::info!("batch signups ids in order: {:?}", batch_query.request_ids); + // Preprocess query shares here already to avoid blocking the actor batch_query.query_left_preprocessed = BatchQueryEntriesPreprocessed::from(batch_query.query_left.clone()); @@ -527,6 +561,43 @@ async fn initialize_chacha_seeds( Ok(chacha_seeds) } +async fn send_error_results_to_sns( + signup_id: String, + metadata: &BatchMetadata, + sns_client: &SNSClient, + config: &Config, + base_message_attributes: &HashMap, + message_type: &str, + error_reason: &str, +) -> eyre::Result<()> { + let message: UniquenessResult = UniquenessResult { + node_id: config.party_id, + serial_id: None, + is_match: false, + signup_id, + matched_serial_ids: None, + matched_serial_ids_left: None, + matched_serial_ids_right: None, + matched_batch_request_ids: None, + error: Some(true), + error_reason: Some(String::from(error_reason)), + }; + let message_serialised = serde_json::to_string(&message)?; + let mut message_attributes = base_message_attributes.clone(); + let trace_attributes = construct_message_attributes(&metadata.trace_id, &metadata.span_id)?; + message_attributes.extend(trace_attributes); + sns_client + .publish() + .topic_arn(&config.results_topic_arn) + .message(message_serialised) + .message_group_id(format!("party-id-{}", config.party_id)) + .set_message_attributes(Some(message_attributes)) + .send() + .await?; + metrics::counter!("result.sent", "type" => message_type.to_owned()+"_error").increment(1); + + Ok(()) +} async fn send_results_to_sns( result_events: Vec, metadata: &[BatchMetadata], @@ -605,6 +676,14 @@ async fn server_main(config: Config) -> eyre::Result<()> { let shared_config = aws_config::from_env().region(region_provider).load().await; let sqs_client = Client::new(&shared_config); let sns_client = SNSClient::new(&shared_config); + + // Increase S3 retries to 5 + let retry_config = RetryConfig::standard().with_max_attempts(5); + let s3_config = S3ConfigBuilder::from(&shared_config) + .retry_config(retry_config) + .build(); + let s3_client = Arc::new(S3Client::from_conf(s3_config)); + let s3_client_clone = Arc::clone(&s3_client); let shares_encryption_key_pair = match SharesEncryptionKeyPairs::from_storage(config.clone()).await { Ok(key_pair) => key_pair, @@ -660,36 +739,19 @@ async fn server_main(config: Config) -> eyre::Result<()> { tracing::info!("Size of the database after init: {}", store_len); // Check if the sequence id is consistent with the number of irises - let iris_sequence_id = store.get_irises_sequence_id().await?; - if iris_sequence_id != store_len { - tracing::warn!( - "Detected inconsistent iris sequence id {} != {}, resetting...", - iris_sequence_id, + let max_serial_id = store.get_max_serial_id().await?; + if max_serial_id != store_len { + tracing::error!( + "Detected inconsistency between max serial id {} and db size {}.", + max_serial_id, store_len ); - // Reset the sequence id - store.set_irises_sequence_id(store_len).await?; - - // Fetch again and check that the sequence id is consistent now - let store_len = store.count_irises().await?; - let iris_sequence_id = store.get_irises_sequence_id().await?; - - // If db is empty, we set the sequence id to 1 with advance_nextval false - let empty_db_sequence_ok = store_len == 0 && iris_sequence_id == 1; - - if iris_sequence_id != store_len && !empty_db_sequence_ok { - tracing::error!( - "Iris sequence id is still inconsistent: {} != {}", - iris_sequence_id, - store_len - ); - eyre::bail!( - "Iris sequence id is still inconsistent: {} != {}", - iris_sequence_id, - store_len - ); - } + eyre::bail!( + "Detected inconsistency between max serial id {} and db size {}.", + max_serial_id, + store_len + ); } if store_len > config.max_db_size { @@ -697,14 +759,147 @@ async fn server_main(config: Config) -> eyre::Result<()> { eyre::bail!("Database size exceeds maximum allowed size: {}", store_len); } + tracing::info!("Preparing task monitor"); + let mut background_tasks = TaskMonitor::new(); + + // -------------------------------------------------------------------------- + // ANCHOR: Starting Healthcheck and Readiness server + // -------------------------------------------------------------------------- + tracing::info!("⚓️ ANCHOR: Starting Healthcheck and Readiness server"); + + let is_ready_flag = Arc::new(AtomicBool::new(false)); + let is_ready_flag_cloned = Arc::clone(&is_ready_flag); + + #[derive(Serialize, Deserialize)] + struct ReadyProbeResponse { + image_name: String, + uuid: String, + } + + let _health_check_abort = background_tasks.spawn({ + let uuid = uuid::Uuid::new_v4().to_string(); + let ready_probe_response = ReadyProbeResponse { + image_name: config.image_name.clone(), + uuid, + }; + let serialized_response = serde_json::to_string(&ready_probe_response) + .expect("Serialization to JSON to probe response failed"); + tracing::info!("Healthcheck probe response: {}", serialized_response); + async move { + // Generate a random UUID for each run. + let app = Router::new() + .route( + "/health", + get(move || async move { serialized_response.clone() }), + ) + .route( + "/ready", + get({ + // We are only ready once this flag is set to true. + let is_ready_flag = Arc::clone(&is_ready_flag); + move || async move { + if is_ready_flag.load(Ordering::SeqCst) { + "ready".into_response() + } else { + StatusCode::SERVICE_UNAVAILABLE.into_response() + } + } + }), + ); + let listener = tokio::net::TcpListener::bind("0.0.0.0:3000") + .await + .wrap_err("healthcheck listener bind error")?; + axum::serve(listener, app) + .await + .wrap_err("healthcheck listener server launch error")?; + + Ok::<(), eyre::Error>(()) + } + }); + + background_tasks.check_tasks(); + tracing::info!("Healthcheck and Readiness server running on port 3000."); + + let (heartbeat_tx, heartbeat_rx) = oneshot::channel(); + let mut heartbeat_tx = Some(heartbeat_tx); + let all_nodes = config.node_hostnames.clone(); + let image_name = config.image_name.clone(); + let _heartbeat = background_tasks.spawn(async move { + let next_node = &all_nodes[(config.party_id + 1) % 3]; + let prev_node = &all_nodes[(config.party_id + 2) % 3]; + let mut last_response = [String::default(), String::default()]; + let mut connected = [false, false]; + let mut retries = [0, 0]; + + loop { + for (i, host) in [next_node, prev_node].iter().enumerate() { + let res = reqwest::get(format!("http://{}:3000/health", host)).await; + if res.is_err() || !res.as_ref().unwrap().status().is_success() { + // If it's the first time after startup, we allow a few retries to let the other + // nodes start up as well. + if last_response[i] == String::default() + && retries[i] < config.heartbeat_initial_retries + { + retries[i] += 1; + tracing::warn!("Node {} did not respond with success, retrying...", host); + continue; + } + // The other node seems to be down or returned an error. + panic!( + "Node {} did not respond with success, killing server...", + host + ); + } + + let probe_response = res + .unwrap() + .json::() + .await + .expect("Deserialization of probe response failed"); + if probe_response.image_name != image_name { + // Do not create a panic as we still can continue to process before its + // updated + tracing::error!( + "Host {} is using image {} which differs from current node image: {}", + host, + probe_response.image_name.clone(), + image_name + ); + } + if last_response[i] == String::default() { + last_response[i] = probe_response.uuid; + connected[i] = true; + + // If all nodes are connected, notify the main thread. + if connected.iter().all(|&c| c) { + if let Some(tx) = heartbeat_tx.take() { + tx.send(()).unwrap(); + } + } + } else if probe_response.uuid != last_response[i] { + // If the UUID response is different, the node has restarted without us + // noticing. Our main NCCL connections cannot recover from + // this, so we panic. + panic!("Node {} seems to have restarted, killing server...", host); + } else { + tracing::info!("Heartbeat: Node {} is healthy", host); + } + } + + tokio::time::sleep(Duration::from_secs(config.heartbeat_interval_secs)).await; + } + }); + + tracing::info!("Heartbeat starting..."); + heartbeat_rx.await?; + tracing::info!("Heartbeat on all nodes started."); + background_tasks.check_tasks(); + let my_state = SyncState { db_len: store_len as u64, deleted_request_ids: store.last_deleted_requests(max_sync_lookback).await?, }; - tracing::info!("Preparing task monitor"); - let mut background_tasks = TaskMonitor::new(); - // Start the actor in separate task. // A bit convoluted, but we need to create the actor on the thread already, // since it blocks a lot and is `!Send`, we get back the handle via the oneshot @@ -715,15 +910,26 @@ async fn server_main(config: Config) -> eyre::Result<()> { .ok_or(eyre!("Missing database config"))? .load_parallelism; + let load_chunks_parallelism = config.load_chunks_parallelism; + let db_chunks_bucket_name = config.db_chunks_bucket_name.clone(); + let db_chunks_folder_name = config.db_chunks_folder_name.clone(); + let (tx, rx) = oneshot::channel(); background_tasks.spawn_blocking(move || { let device_manager = Arc::new(DeviceManager::init()); let ids = device_manager.get_ids_from_magic(0); - tracing::info!("Starting NCCL"); + // -------------------------------------------------------------------------- + // ANCHOR: Starting NCCL + // -------------------------------------------------------------------------- + tracing::info!("⚓️ ANCHOR: Starting NCCL"); let comms = device_manager.instantiate_network_from_ids(config.party_id, &ids)?; + // FYI: If any of the nodes die after this, all connections are broken. - tracing::info!("NCCL: getting sync results"); + // -------------------------------------------------------------------------- + // ANCHOR: Syncing latest node state + // -------------------------------------------------------------------------- + tracing::info!("⚓️ ANCHOR: Syncing latest node state"); let sync_result = match sync_nccl::sync(&comms[0], &my_state) { Ok(res) => res, Err(e) => { @@ -731,6 +937,7 @@ async fn server_main(config: Config) -> eyre::Result<()> { return Ok(()); } }; + tracing::info!("Database store length is: {}", store_len); if let Some(db_len) = sync_result.must_rollback_storage() { tracing::error!("Databases are out-of-sync: {:?}", sync_result); @@ -741,10 +948,20 @@ async fn server_main(config: Config) -> eyre::Result<()> { db_len, )); } + tracing::warn!( + "Rolling back from database length {} to other nodes length {}", + store_len, + db_len + ); tokio::runtime::Handle::current().block_on(async { store.rollback(db_len).await })?; - tracing::error!("Rolled back to db_len={}", db_len); + metrics::counter!("db.sync.rollback").increment(1); } + // -------------------------------------------------------------------------- + // ANCHOR: Load the database + // -------------------------------------------------------------------------- + tracing::info!("⚓️ ANCHOR: Load the database"); + tracing::info!("Starting server actor"); match ServerActor::new_with_device_manager_and_comms( config.party_id, @@ -756,6 +973,7 @@ async fn server_main(config: Config) -> eyre::Result<()> { config.max_batch_size, config.return_partial_results, config.disable_persistence, + config.enable_debug_timing, ) { Ok((mut actor, handle)) => { let res = if config.fake_db_size > 0 { @@ -774,20 +992,102 @@ async fn server_main(config: Config) -> eyre::Result<()> { "Initialize iris db: Loading from DB (parallelism: {})", parallelism ); + let s3_store = S3Store::new(s3_client_clone, db_chunks_bucket_name); tokio::runtime::Handle::current().block_on(async { - let mut stream = store.stream_irises_par(parallelism).await; + let mut stream = match config.enable_s3_importer { + true => { + tracing::info!("S3 importer enabled. Fetching from s3 + db"); + // First fetch last snapshot from S3 + let last_snapshot_details = last_snapshot_timestamp( + &s3_store, + db_chunks_folder_name.clone(), + ) + .await?; + let min_last_modified_at = last_snapshot_details.timestamp + - config.db_load_safety_overlap_seconds; + tracing::info!( + "Last snapshot timestamp: {}, min_last_modified_at: {}", + last_snapshot_details.timestamp, + min_last_modified_at + ); + let stream_s3 = fetch_and_parse_chunks( + &s3_store, + load_chunks_parallelism, + db_chunks_folder_name, + last_snapshot_details, + ) + .await + .map(|result| result.map(IrisSource::S3)) + .boxed(); + + let stream_db = store + .stream_irises_par(Some(min_last_modified_at), parallelism) + .await + .map(|result| result.map(IrisSource::DB)) + .boxed(); + + select_all(vec![stream_s3, stream_db]) + } + false => { + tracing::info!("S3 importer disabled. Fetching only from db"); + let stream_db = store + .stream_irises_par(None, parallelism) + .await + .map(|result| result.map(IrisSource::DB)) + .boxed(); + select_all(vec![stream_db]) + } + }; + + let now = Instant::now(); + let mut now_load_summary = Instant::now(); + let mut time_waiting_for_stream = time::Duration::from_secs(0); + let mut time_loading_into_memory = time::Duration::from_secs(0); let mut record_counter = 0; - while let Some(iris) = stream.try_next().await? { + let mut all_serial_ids: HashSet = + HashSet::from_iter(1..=(store_len as i64)); + let mut serial_ids_from_db: HashSet = HashSet::new(); + let mut n_loaded_from_db = 0; + let mut n_loaded_from_s3 = 0; + while let Some(result) = stream.try_next().await? { + time_waiting_for_stream += now_load_summary.elapsed(); + now_load_summary = Instant::now(); + + let iris = match result { + IrisSource::DB(iris) => { + n_loaded_from_db += 1; + serial_ids_from_db.insert(iris.id()); + iris + } + IrisSource::S3(iris) => { + if serial_ids_from_db.contains(&iris.id()) { + tracing::warn!( + "Skip overriding record already loaded via DB with S3 \ + record: {}", + iris.id() + ); + continue; + } + n_loaded_from_s3 += 1; + iris + } + }; + if record_counter % 100_000 == 0 { + let elapsed = now.elapsed(); tracing::info!( - "Loaded {} records from db into memory", - record_counter + "Loaded {} records into memory in {:?} ({:.2} entries/s)", + record_counter, + elapsed, + record_counter as f64 / elapsed.as_secs_f64() ); } - if iris.index() > store_len { - tracing::error!("Inconsistent iris index {}", iris.index()); - return Err(eyre!("Inconsistent iris index {}", iris.index())); + + if iris.index() == 0 || iris.index() > store_len { + tracing::error!("Invalid iris index {}", iris.index()); + return Err(eyre!("Invalid iris index {}", iris.index())); } + actor.load_single_record( iris.index() - 1, iris.left_code(), @@ -795,17 +1095,47 @@ async fn server_main(config: Config) -> eyre::Result<()> { iris.right_code(), iris.right_mask(), ); + + // if the serial id hasn't been loaded before, count is as unique record + if all_serial_ids.contains(&(iris.index() as i64)) { + actor.increment_db_size(iris.index() - 1); + } + + time_loading_into_memory += now_load_summary.elapsed(); + now_load_summary = Instant::now(); + + all_serial_ids.remove(&(iris.index() as i64)); record_counter += 1; } - assert_eq!( - record_counter, store_len, - "Loaded record count does not match db size" + tracing::info!( + "Loading summary => Loaded {:?} items. {} from DB, {} from S3. Waited \ + for stream: {:?}, Loaded into memory: {:?}", + record_counter, + n_loaded_from_db, + n_loaded_from_s3, + time_waiting_for_stream, + time_loading_into_memory, ); + // Clear the memory allocated by temp HashSet + serial_ids_from_db.clear(); + serial_ids_from_db.shrink_to_fit(); + + if !all_serial_ids.is_empty() { + tracing::error!("Not all serial_ids were loaded: {:?}", all_serial_ids); + return Err(eyre!( + "Not all serial_ids were loaded: {:?}", + all_serial_ids + )); + } + tracing::info!("Preprocessing db"); actor.preprocess_db(); + tracing::info!("Page-lock host memory"); + actor.register_host_memory(); + tracing::info!( "Loaded {} records from db into memory [DB sizes: {:?}]", record_counter, @@ -906,7 +1236,7 @@ async fn server_main(config: Config) -> eyre::Result<()> { .collect::>>()?; // Insert non-matching queries into the persistent store. - let (memory_serial_ids, codes_and_masks): (Vec, Vec) = matches + let (memory_serial_ids, codes_and_masks): (Vec, Vec) = matches .iter() .enumerate() .filter_map( @@ -914,8 +1244,10 @@ async fn server_main(config: Config) -> eyre::Result<()> { |(query_idx, is_match)| if !is_match { Some(query_idx) } else { None }, ) .map(|query_idx| { + let serial_id = (merged_results[query_idx] + 1) as i64; // Get the original vectors from `receive_batch`. - (merged_results[query_idx] + 1, StoredIrisRef { + (serial_id, StoredIrisRef { + id: serial_id, left_code: &store_left.code[query_idx].coefs[..], left_mask: &store_left.mask[query_idx].coefs[..], right_code: &store_right.code[query_idx].coefs[..], @@ -931,13 +1263,7 @@ async fn server_main(config: Config) -> eyre::Result<()> { .await?; if !codes_and_masks.is_empty() && !config_bg.disable_persistence { - let db_serial_ids = store_bg - .insert_irises(&mut tx, &codes_and_masks) - .await - .wrap_err("failed to persist queries")? - .iter() - .map(|&x| x as u32) - .collect::>(); + let db_serial_ids = store_bg.insert_irises(&mut tx, &codes_and_masks).await?; // Check if the serial_ids match between memory and db. if memory_serial_ids != db_serial_ids { @@ -1003,89 +1329,55 @@ async fn server_main(config: Config) -> eyre::Result<()> { }); background_tasks.check_tasks(); - tracing::info!("All systems ready."); - tracing::info!("Starting healthcheck server."); - - let _health_check_abort = background_tasks.spawn(async move { - // Generate a random UUID for each run. - let uuid = uuid::Uuid::new_v4().to_string(); - let app = Router::new().route("/health", get(|| async { uuid })); // implicit 200 return - let listener = tokio::net::TcpListener::bind("0.0.0.0:3000") - .await - .wrap_err("healthcheck listener bind error")?; - axum::serve(listener, app) - .await - .wrap_err("healthcheck listener server launch error")?; - - Ok(()) - }); + // -------------------------------------------------------------------------- + // ANCHOR: Enable readiness and check all nodes + // -------------------------------------------------------------------------- + tracing::info!("⚓️ ANCHOR: Enable readiness and check all nodes"); - background_tasks.check_tasks(); - tracing::info!("Healthcheck server running on port 3000."); + // Set the readiness flag to true, which will make the readiness server return a + // 200 status code. + is_ready_flag_cloned.store(true, std::sync::atomic::Ordering::SeqCst); - let (heartbeat_tx, heartbeat_rx) = oneshot::channel(); - let mut heartbeat_tx = Some(heartbeat_tx); + // Check other nodes and wait until all nodes are ready. + let (readiness_tx, readiness_rx) = oneshot::channel(); + let mut readiness_tx = Some(readiness_tx); let all_nodes = config.node_hostnames.clone(); let _heartbeat = background_tasks.spawn(async move { let next_node = &all_nodes[(config.party_id + 1) % 3]; let prev_node = &all_nodes[(config.party_id + 2) % 3]; - let mut last_response = [String::default(), String::default()]; let mut connected = [false, false]; - let mut retries = [0, 0]; loop { for (i, host) in [next_node, prev_node].iter().enumerate() { - let res = reqwest::get(format!("http://{}:3000/health", host)).await; - if res.is_err() || !res.as_ref().unwrap().status().is_success() { - // If it's the first time after startup, we allow a few retries to let the other - // nodes start up as well. - if last_response[i] == String::default() - && retries[i] < config.heartbeat_initial_retries - { - retries[i] += 1; - tracing::warn!("Node {} did not respond with success, retrying...", host); - continue; - } - // The other node seems to be down or returned an error. - panic!( - "Node {} did not respond with success, killing server...", - host - ); - } + let res = reqwest::get(format!("http://{}:3000/ready", host)).await; - let uuid = res.unwrap().text().await?; - if last_response[i] == String::default() { - last_response[i] = uuid; + if res.is_ok() && res.as_ref().unwrap().status().is_success() { connected[i] = true; - // If all nodes are connected, notify the main thread. if connected.iter().all(|&c| c) { - if let Some(tx) = heartbeat_tx.take() { + if let Some(tx) = readiness_tx.take() { tx.send(()).unwrap(); } } - } else if uuid != last_response[i] { - // If the UUID response is different, the node has restarted without us - // noticing. Our main NCCL connections cannot recover from - // this, so we panic. - panic!("Node {} seems to have restarted, killing server...", host); - } else { - tracing::info!("Heartbeat: Node {} is healthy", host); } } - tokio::time::sleep(Duration::from_secs(config.heartbeat_interval_secs)).await; + tokio::time::sleep(Duration::from_secs(1)).await; } }); - tracing::info!("Heartbeat starting..."); - heartbeat_rx.await?; - tracing::info!("Heartbeat on all nodes started."); + tracing::info!("Waiting for all nodes to be ready..."); + readiness_rx.await?; + tracing::info!("All nodes are ready."); background_tasks.check_tasks(); - let processing_timeout = Duration::from_secs(config.processing_timeout_secs); + // -------------------------------------------------------------------------- + // ANCHOR: Start the main loop + // -------------------------------------------------------------------------- + tracing::info!("⚓️ ANCHOR: Start the main loop"); - // Main loop + let processing_timeout = Duration::from_secs(config.processing_timeout_secs); + let error_result_attribute = create_message_type_attribute_map(UNIQUENESS_MESSAGE_TYPE); let res: eyre::Result<()> = async { tracing::info!("Entering main loop"); // **Tensor format of queries** @@ -1109,12 +1401,14 @@ async fn server_main(config: Config) -> eyre::Result<()> { let mut next_batch = receive_batch( party_id, &sqs_client, - &config.requests_queue_url, + &sns_client, + &s3_client, + &config, &store, &skip_request_ids, shares_encryption_key_pair.clone(), - config.max_batch_size, &shutdown_handler, + &error_result_attribute, ); let dummy_shares_for_deletions = get_dummy_shares_for_deletion(party_id); @@ -1161,12 +1455,14 @@ async fn server_main(config: Config) -> eyre::Result<()> { next_batch = receive_batch( party_id, &sqs_client, - &config.requests_queue_url, + &sns_client, + &s3_client, + &config, &store, &skip_request_ids, shares_encryption_key_pair.clone(), - config.max_batch_size, &shutdown_handler, + &error_result_attribute, ); // await the result