From 5c2c3339717298438e5f59d367d6f04749a0a63a Mon Sep 17 00:00:00 2001 From: Andrew Whitehead Date: Thu, 24 Mar 2022 20:34:34 -0700 Subject: [PATCH 01/43] perform replace in postgres without removing original row Signed-off-by: Andrew Whitehead --- src/backend/postgres/mod.rs | 50 ++++++++++++++++++++++++++++++++++--- 1 file changed, 46 insertions(+), 4 deletions(-) diff --git a/src/backend/postgres/mod.rs b/src/backend/postgres/mod.rs index 2723da5c..a548c47d 100644 --- a/src/backend/postgres/mod.rs +++ b/src/backend/postgres/mod.rs @@ -49,11 +49,14 @@ const FETCH_QUERY_UPDATE: &'static str = "SELECT id, value, FROM items_tags it WHERE it.item_id = i.id) tags FROM items i WHERE profile_id = $1 AND kind = $2 AND category = $3 AND name = $4 - AND (expiry IS NULL OR expiry > CURRENT_TIMESTAMP) FOR UPDATE"; + AND (expiry IS NULL OR expiry > CURRENT_TIMESTAMP) FOR NO KEY UPDATE"; const INSERT_QUERY: &'static str = "INSERT INTO items (profile_id, kind, category, name, value, expiry) VALUES ($1, $2, $3, $4, $5, $6) ON CONFLICT DO NOTHING RETURNING id"; +const UPDATE_QUERY: &'static str = "UPDATE items SET value=$5, expiry=$6 + WHERE profile_id=$1 AND kind=$2 AND category=$3 AND name=$4 + RETURNING id"; const SCAN_QUERY: &'static str = "SELECT id, name, value, (SELECT ARRAY_TO_STRING(ARRAY_AGG(it.plaintext || ':' || ENCODE(it.name, 'hex') || ':' || ENCODE(it.value, 'hex')), ',') @@ -64,6 +67,8 @@ const DELETE_ALL_QUERY: &'static str = "DELETE FROM items i WHERE i.profile_id = $1 AND i.kind = $2 AND i.category = $3"; const TAG_INSERT_QUERY: &'static str = "INSERT INTO items_tags (item_id, name, value, plaintext) VALUES ($1, $2, $3, $4)"; +const TAG_DELETE_QUERY: &'static str = "DELETE FROM items_tags + WHERE item_id=$1"; mod provision; pub use provision::PostgresStoreOptions; @@ -500,8 +505,7 @@ impl QueryBackend for DbSession { let mut active = acquire_session(&mut *self).await?; let mut txn = active.as_transaction().await?; - perform_remove(&mut txn, kind, &enc_category, &enc_name, false).await?; - perform_insert( + perform_update( &mut txn, kind, &enc_category, @@ -639,6 +643,44 @@ async fn perform_insert<'q>( Ok(()) } +async fn perform_update<'q>( + active: &mut DbSessionActive<'q, Postgres>, + kind: EntryKind, + enc_category: &[u8], + enc_name: &[u8], + enc_value: &[u8], + enc_tags: Option>, + expiry_ms: Option, +) -> Result<(), Error> { + trace!("Update entry"); + let row_id: i64 = sqlx::query_scalar(UPDATE_QUERY) + .bind(active.profile_id) + .bind(kind as i16) + .bind(enc_category) + .bind(enc_name) + .bind(enc_value) + .bind(expiry_ms.map(expiry_timestamp).transpose()?) + .fetch_one(active.connection_mut()) + .await + .map_err(|_| err_msg!(NotFound, "Error updating existing row"))?; + sqlx::query(TAG_DELETE_QUERY) + .bind(row_id) + .execute(active.connection_mut()) + .await?; + if let Some(tags) = enc_tags { + for tag in tags { + sqlx::query(TAG_INSERT_QUERY) + .bind(row_id) + .bind(&tag.name) + .bind(&tag.value) + .bind(tag.plaintext as i16) + .execute(active.connection_mut()) + .await?; + } + } + Ok(()) +} + async fn perform_remove<'q>( active: &mut DbSessionActive<'q, Postgres>, kind: EntryKind, @@ -690,7 +732,7 @@ fn perform_scan<'q>( params.push(enc_category); let mut query = extend_query::(SCAN_QUERY, &mut params, tag_filter, offset, limit)?; if for_update { - query.push_str(" FOR UPDATE"); + query.push_str(" FOR NO KEY UPDATE"); } let mut batch = Vec::with_capacity(PAGE_SIZE); From 743f594d9e6a9a713d4c39cb25284bbc9083177a Mon Sep 17 00:00:00 2001 From: Andrew Whitehead Date: Thu, 24 Mar 2022 20:34:48 -0700 Subject: [PATCH 02/43] additional performance tests Signed-off-by: Andrew Whitehead --- wrappers/python/demo/perf.py | 51 +++++++++++++++++++++++++++++++----- 1 file changed, 45 insertions(+), 6 deletions(-) diff --git a/wrappers/python/demo/perf.py b/wrappers/python/demo/perf.py index aaa80b4c..a3fde18b 100644 --- a/wrappers/python/demo/perf.py +++ b/wrappers/python/demo/perf.py @@ -31,22 +31,37 @@ async def perf_test(): store = await Store.provision(REPO_URI, "raw", key, recreate=True) + insert_start = time.perf_counter() + async with store.session() as session: + for idx in range(PERF_ROWS): + await session.insert( + "seq", + f"name-{idx}", + b"value", + {"~plaintag": "a", "enctag": "b"}, + ) + dur = time.perf_counter() - insert_start + print(f"sequential insert duration ({PERF_ROWS} rows): {dur:0.2f}s") + insert_start = time.perf_counter() async with store.transaction() as txn: - # ^ faster within a transaction + # ^ should be faster within a transaction for idx in range(PERF_ROWS): await txn.insert( - "category", f"name-{idx}", b"value", {"~plaintag": "a", "enctag": "b"} + "txn", + f"name-{idx}", + b"value", + {"~plaintag": "a", "enctag": "b"}, ) await txn.commit() dur = time.perf_counter() - insert_start - print(f"insert duration ({PERF_ROWS} rows): {dur:0.2f}s") + print(f"transaction batch insert duration ({PERF_ROWS} rows): {dur:0.2f}s") + tags = 0 fetch_start = time.perf_counter() async with store as session: - tags = 0 for idx in range(PERF_ROWS): - entry = await session.fetch("category", f"name-{idx}") + entry = await session.fetch("seq", f"name-{idx}") tags += len(entry.tags) dur = time.perf_counter() - fetch_start print(f"fetch duration ({PERF_ROWS} rows, {tags} tags): {dur:0.2f}s") @@ -54,12 +69,36 @@ async def perf_test(): rc = 0 tags = 0 scan_start = time.perf_counter() - async for row in store.scan("category", {"~plaintag": "a", "enctag": "b"}): + async for row in store.scan("seq", {"~plaintag": "a", "enctag": "b"}): rc += 1 tags += len(row.tags) dur = time.perf_counter() - scan_start print(f"scan duration ({rc} rows, {tags} tags): {dur:0.2f}s") + async with store as session: + await session.insert("seq", "count", "0", {"~plaintag": "a", "enctag": "b"}) + update_start = time.perf_counter() + count = 0 + for idx in range(PERF_ROWS): + count += 1 + await session.replace( + "seq", "count", str(count), {"~plaintag": "a", "enctag": "b"} + ) + dur = time.perf_counter() - update_start + print(f"unchecked update duration ({PERF_ROWS} rows): {dur:0.2f}s") + + async with store as session: + await session.insert("txn", "count", "0", {"~plaintag": "a", "enctag": "b"}) + update_start = time.perf_counter() + for idx in range(PERF_ROWS): + async with store.transaction() as txn: + row = await txn.fetch("txn", "count", for_update=True) + count = str(int(row.value) + 1) + await txn.replace("txn", "count", count, {"~plaintag": "a", "enctag": "b"}) + await txn.commit() + dur = time.perf_counter() - update_start + print(f"transactional update duration ({PERF_ROWS} rows): {dur:0.2f}s") + await store.close() From 44a05dd5d1a7325b6125beff75e8d8cf28a7756b Mon Sep 17 00:00:00 2001 From: Andrew Whitehead Date: Thu, 24 Mar 2022 20:35:19 -0700 Subject: [PATCH 03/43] test multi-threaded row updates Signed-off-by: Andrew Whitehead --- wrappers/python/tests/test_store.py | 89 ++++++++++++++++++++++++++++- 1 file changed, 88 insertions(+), 1 deletion(-) diff --git a/wrappers/python/tests/test_store.py b/wrappers/python/tests/test_store.py index 03ef787e..c886aafb 100644 --- a/wrappers/python/tests/test_store.py +++ b/wrappers/python/tests/test_store.py @@ -1,4 +1,6 @@ +import asyncio import os +import threading from pytest import mark, raises import pytest_asyncio @@ -110,7 +112,7 @@ async def test_scan(store: Store): @mark.asyncio -async def test_transaction(store: Store): +async def test_transaction_basic(store: Store): async with store.transaction() as txn: # Insert a new entry @@ -144,6 +146,91 @@ async def test_transaction(store: Store): assert dict(found) == TEST_ENTRY +@mark.asyncio +async def test_transaction_conflict(store: Store): + async with store.transaction() as txn: + await txn.insert( + TEST_ENTRY["category"], + TEST_ENTRY["name"], + "0", + ) + await txn.commit() + + INC_COUNT = 500 + TASKS = 50 + + async def inc(): + for _ in range(INC_COUNT): + async with store.transaction() as txn: + row = await txn.fetch( + TEST_ENTRY["category"], TEST_ENTRY["name"], for_update=True + ) + if not row: + raise Exception("Row not found") + new_value = str(int(row.value) + 1) + await txn.replace(TEST_ENTRY["category"], TEST_ENTRY["name"], new_value) + await txn.commit() + + tasks = [asyncio.create_task(inc()) for _ in range(TASKS)] + await asyncio.gather(*tasks) + + # Check all the updates completed + async with store.session() as session: + result = await session.fetch( + TEST_ENTRY["category"], + TEST_ENTRY["name"], + ) + assert int(result.value) == INC_COUNT * TASKS + + +@mark.asyncio +async def test_transaction_conflict_threaded(store: Store): + async with store.transaction() as txn: + await txn.insert( + TEST_ENTRY["category"], + TEST_ENTRY["name"], + "0", + ) + await txn.commit() + + INC_COUNT = 500 + TASKS = 50 + + async def inc(): + for _ in range(INC_COUNT): + async with store.transaction() as txn: + row = await txn.fetch( + TEST_ENTRY["category"], TEST_ENTRY["name"], for_update=True + ) + if not row: + raise Exception("Row not found") + new_value = str(int(row.value) + 1) + await txn.replace(TEST_ENTRY["category"], TEST_ENTRY["name"], new_value) + await txn.commit() + + def proc(): + loop = asyncio.new_event_loop() + loop.run_until_complete(inc()) + + tasks = [] + for _ in range(TASKS): + th = threading.Thread(target=proc) + th.start() + tasks.append(th) + + # This will pause the current event loop, but that shouldn't be a problem + for task in tasks: + task.join() + + # Check all the updates completed + async with store.session() as session: + result = await session.fetch( + TEST_ENTRY["category"], + TEST_ENTRY["name"], + ) + assert int(result.value) == INC_COUNT * TASKS + + @mark.asyncio async def test_key_store(store: Store): # test key operations in a new session From e85de6fedaa66f48a0f5351f8779c02f1a79174a Mon Sep 17 00:00:00 2001 From: Andrew Whitehead Date: Thu, 24 Mar 2022 20:38:10 -0700 Subject: [PATCH 04/43] update version to 0.2.5 Signed-off-by: Andrew Whitehead --- Cargo.toml | 2 +- wrappers/python/aries_askar/version.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 7db2ba0e..ee5e1a29 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,7 +3,7 @@ members = ["askar-bbs", "askar-crypto"] [package] name = "aries-askar" -version = "0.2.4" +version = "0.2.5" authors = ["Hyperledger Aries Contributors "] edition = "2018" description = "Hyperledger Aries Askar secure storage" diff --git a/wrappers/python/aries_askar/version.py b/wrappers/python/aries_askar/version.py index 9cc67882..b4d49e71 100644 --- a/wrappers/python/aries_askar/version.py +++ b/wrappers/python/aries_askar/version.py @@ -1,3 +1,3 @@ """aries_askar library wrapper version.""" -__version__ = "0.2.4" +__version__ = "0.2.5" From 9db0d1ef6d3ec9a47d11bbc63c9236a3911e4b57 Mon Sep 17 00:00:00 2001 From: Andrew Whitehead Date: Thu, 24 Mar 2022 22:24:28 -0700 Subject: [PATCH 05/43] try disabling lto Signed-off-by: Andrew Whitehead --- Cargo.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index ee5e1a29..fdc4ba24 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -77,9 +77,9 @@ default-features = false features = ["chrono", "runtime-tokio-rustls"] optional = true -[profile.release] -lto = true -codegen-units = 1 +# [profile.release] +# lto = true +# codegen-units = 1 [[test]] name = "backends" From 53520053cd88b8ebbee185b1e70ab61ef06f9e44 Mon Sep 17 00:00:00 2001 From: Andrew Whitehead Date: Fri, 25 Mar 2022 11:42:50 -0700 Subject: [PATCH 06/43] always enable shared cache in sqlite Signed-off-by: Andrew Whitehead --- src/backend/sqlite/provision.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/backend/sqlite/provision.rs b/src/backend/sqlite/provision.rs index 987ca9ff..480eafbc 100644 --- a/src/backend/sqlite/provision.rs +++ b/src/backend/sqlite/provision.rs @@ -50,8 +50,9 @@ impl SqliteStoreOptions { async fn pool(&self, auto_create: bool) -> std::result::Result { #[allow(unused_mut)] - let mut conn_opts = - SqliteConnectOptions::from_str(self.path.as_ref())?.create_if_missing(auto_create); + let mut conn_opts = SqliteConnectOptions::from_str(self.path.as_ref())? + .create_if_missing(auto_create) + .shared_cache(true); #[cfg(feature = "log")] { conn_opts.log_statements(log::LevelFilter::Debug); From 61db34102fe9a182dde8520eb6c666ec9738e457 Mon Sep 17 00:00:00 2001 From: Andrew Whitehead Date: Fri, 25 Mar 2022 11:59:38 -0700 Subject: [PATCH 07/43] reduce tested threads Signed-off-by: Andrew Whitehead --- wrappers/python/tests/test_store.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/wrappers/python/tests/test_store.py b/wrappers/python/tests/test_store.py index c886aafb..8563aeed 100644 --- a/wrappers/python/tests/test_store.py +++ b/wrappers/python/tests/test_store.py @@ -157,7 +157,7 @@ async def test_transaction_conflict(store: Store): await txn.commit() INC_COUNT = 500 - TASKS = 50 + TASKS = 20 async def inc(): for _ in range(INC_COUNT): @@ -194,7 +194,7 @@ async def test_transaction_conflict_threaded(store: Store): await txn.commit() INC_COUNT = 500 - TASKS = 50 + TASKS = 20 async def inc(): for _ in range(INC_COUNT): From 8a1253fcd52ed16b74c8e0beeca577fc11f16d62 Mon Sep 17 00:00:00 2001 From: Andrew Whitehead Date: Fri, 25 Mar 2022 14:21:06 -0700 Subject: [PATCH 08/43] remove multi-threaded test Signed-off-by: Andrew Whitehead --- wrappers/python/tests/test_store.py | 48 ----------------------------- 1 file changed, 48 deletions(-) diff --git a/wrappers/python/tests/test_store.py b/wrappers/python/tests/test_store.py index 8563aeed..c0e034ba 100644 --- a/wrappers/python/tests/test_store.py +++ b/wrappers/python/tests/test_store.py @@ -183,54 +183,6 @@ async def inc(): assert int(result.value) == INC_COUNT * TASKS -@mark.asyncio -async def test_transaction_conflict_threaded(store: Store): - async with store.transaction() as txn: - await txn.insert( - TEST_ENTRY["category"], - TEST_ENTRY["name"], - "0", - ) - await txn.commit() - - INC_COUNT = 500 - TASKS = 20 - - async def inc(): - for _ in range(INC_COUNT): - async with store.transaction() as txn: - row = await txn.fetch( - TEST_ENTRY["category"], TEST_ENTRY["name"], for_update=True - ) - if not row: - raise Exception("Row not found") - new_value = str(int(row.value) + 1) - await txn.replace(TEST_ENTRY["category"], TEST_ENTRY["name"], new_value) - await txn.commit() - - def proc(): - loop = asyncio.new_event_loop() - loop.run_until_complete(inc()) - - tasks = [] - for _ in range(TASKS): - th = threading.Thread(target=proc) - th.start() - tasks.append(th) - - # This will pause the current event loop, but that shouldn't be a problem - for task in tasks: - task.join() - - # Check all the updates completed - async with store.session() as session: - result = await session.fetch( - TEST_ENTRY["category"], - TEST_ENTRY["name"], - ) - assert int(result.value) == INC_COUNT * TASKS - - @mark.asyncio async def test_key_store(store: Store): # test key operations in a new session From 9a0d98f0f1e331c9da0d9b24d455a7be6371847d Mon Sep 17 00:00:00 2001 From: Andrew Whitehead Date: Fri, 25 Mar 2022 14:21:35 -0700 Subject: [PATCH 09/43] add contention test in rust; test python wrapper on postgres; disable fail-fast Signed-off-by: Andrew Whitehead --- .github/workflows/build.yml | 87 ++++++++---------- src/backend/postgres/test_db.rs | 41 ++++++--- tests/backends.rs | 153 +++++++++++++++++++++++++------- tests/utils/mod.rs | 141 ++++++++++++++++++++++++----- 4 files changed, 304 insertions(+), 118 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 7eeab8ea..5e3e047e 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -19,7 +19,7 @@ jobs: name: Run Checks strategy: matrix: - os: [macos-11, windows-latest, ubuntu-latest] + os: [ubuntu-latest, macos-11, windows-latest] runs-on: ${{ matrix.os }} steps: @@ -30,12 +30,12 @@ jobs: uses: actions-rs/toolchain@v1 with: profile: minimal - toolchain: stable + toolchain: 1.56 - name: Cache cargo resources uses: Swatinem/rust-cache@v1 with: - sharedKey: deps + sharedKey: check - name: Cargo check uses: actions-rs/cargo@v1 @@ -55,7 +55,17 @@ jobs: command: build args: --all-targets - - name: Test + - if: "runner.os == 'Linux'" + name: Test with postgres + run: | + sudo systemctl start postgresql.service + pg_isready + sudo -u postgres psql -c "ALTER USER postgres WITH PASSWORD 'postgres'" + echo "-- Test Postgres DB --" + POSTGRES_URL=postgres://postgres:postgres@localhost:5432/test-db cargo test --workspace --features pg_test + + - if: "runner.os != 'Linux'" + name: Test without postgres uses: actions-rs/cargo@v1 with: command: test @@ -73,52 +83,12 @@ jobs: command: test args: --manifest-path ./askar-bbs/Cargo.toml --no-default-features - test-postgres: - name: Postgres - runs-on: ubuntu-latest - needs: [check] - - services: - postgres: - image: postgres - env: - POSTGRES_PASSWORD: postgres - options: >- - --health-cmd pg_isready - --health-interval 10s - --health-timeout 5s - --health-retries 5 - ports: - - 5432:5432 - - steps: - - name: Checkout - uses: actions/checkout@v2 - - - name: Install Rust toolchain - uses: actions-rs/toolchain@v1 - with: - profile: minimal - toolchain: stable - - - name: Cache cargo resources - uses: Swatinem/rust-cache@v1 - with: - sharedKey: deps - - - name: Test - uses: actions-rs/cargo@v1 - env: - POSTGRES_URL: postgres://postgres:postgres@localhost:5432/test-db - with: - command: test - args: --features pg_test - build-manylinux: name: Build Library needs: [check] strategy: + fail-fast: false matrix: include: - os: ubuntu-latest @@ -136,12 +106,12 @@ jobs: uses: actions-rs/toolchain@v1 with: profile: minimal - toolchain: stable + toolchain: 1.56 - name: Cache cargo resources uses: Swatinem/rust-cache@v1 with: - sharedKey: deps + sharedKey: check - name: Build library env: @@ -159,6 +129,7 @@ jobs: needs: [check] strategy: + fail-fast: false matrix: include: - os: macos-11 @@ -167,7 +138,7 @@ jobs: toolchain: beta # beta required for aarch64-apple-darwin target - os: windows-latest lib: aries_askar.dll - toolchain: stable + toolchain: 1.56 runs-on: ${{ matrix.os }} @@ -184,7 +155,7 @@ jobs: - name: Cache cargo resources uses: Swatinem/rust-cache@v1 with: - sharedKey: deps + sharedKey: check - name: Build library env: @@ -203,6 +174,7 @@ jobs: needs: [build-manylinux, build-other] strategy: + fail-fast: false matrix: os: [ubuntu-latest, macos-11, windows-latest] python-version: [3.7] @@ -241,12 +213,27 @@ jobs: run: | python setup.py bdist_wheel --python-tag=py3 --plat-name=${{ matrix.plat-name }} pip install pytest pytest-asyncio dist/* + echo "-- Test SQLite in-memory --" python -m pytest + echo "-- Test SQLite file DB --" TEST_STORE_URI=sqlite://test.db python -m pytest working-directory: wrappers/python + env: + no_proxy: "*" # python issue 30385 + + - if: "runner.os == 'Linux'" + name: Test postgres + run: | + sudo systemctl start postgresql.service + pg_isready + sudo -u postgres psql -c "ALTER USER postgres WITH PASSWORD 'postgres'" + TEST_STORE_URI=postgres://postgres:postgres@localhost:5432/test-db python -m pytest + working-directory: wrappers/python + env: + no_proxy: "*" # python issue 30385 - if: "runner.os == 'Linux'" - name: Auditwheel + name: Audit wheel run: auditwheel show wrappers/python/dist/* - name: Upload python package diff --git a/src/backend/postgres/test_db.rs b/src/backend/postgres/test_db.rs index 24d2f40e..8dccf404 100644 --- a/src/backend/postgres/test_db.rs +++ b/src/backend/postgres/test_db.rs @@ -81,6 +81,32 @@ impl TestDB { lock_txn: Some(lock_txn), }) } + + async fn close_internal( + mut lock_txn: Option, + mut inst: Option>, + ) -> Result<(), Error> { + if let Some(lock_txn) = lock_txn.take() { + lock_txn.close().await?; + } + if let Some(inst) = inst.take() { + timeout(Duration::from_secs(30), inst.close()) + .await + .ok_or_else(|| { + err_msg!( + Backend, + "Timed out waiting for the pool connection to close" + ) + })??; + } + Ok(()) + } + + /// Explicitly close the test database + pub async fn close(mut self) -> Result<(), Error> { + Self::close_internal(self.lock_txn.take(), self.inst.take()).await?; + Ok(()) + } } impl std::ops::Deref for TestDB { @@ -93,21 +119,14 @@ impl std::ops::Deref for TestDB { impl Drop for TestDB { fn drop(&mut self) { - if let Some(lock_txn) = self.lock_txn.take() { + if self.lock_txn.is_some() || self.inst.is_some() { + let lock_txn = self.lock_txn.take(); + let inst = self.inst.take(); spawn_ok(async { - lock_txn - .close() + Self::close_internal(lock_txn, inst) .await .expect("Error closing database connection"); }); } - if let Some(inst) = self.inst.take() { - spawn_ok(async { - timeout(Duration::from_secs(30), inst.close()) - .await - .expect("Timed out waiting for the pool connection to close") - .expect("Error closing connection pool"); - }); - } } } diff --git a/tests/backends.rs b/tests/backends.rs index a7fe96e1..f8cdd259 100644 --- a/tests/backends.rs +++ b/tests/backends.rs @@ -1,19 +1,27 @@ mod utils; +const ERR_CLOSE: &'static str = "Error closing database"; + macro_rules! backend_tests { ($init:expr) => { use aries_askar::future::block_on; + use std::sync::Arc; + use $crate::utils::TestStore; #[test] fn init() { - block_on($init); + block_on(async { + let db = $init.await; + db.close().await.expect(ERR_CLOSE); + }); } #[test] fn create_remove_profile() { block_on(async { let db = $init.await; - super::utils::db_create_remove_profile(&db).await; + super::utils::db_create_remove_profile(db.clone()).await; + db.close().await.expect(ERR_CLOSE); }) } @@ -21,7 +29,8 @@ macro_rules! backend_tests { fn fetch_fail() { block_on(async { let db = $init.await; - super::utils::db_fetch_fail(&db).await; + super::utils::db_fetch_fail(db.clone()).await; + db.close().await.expect(ERR_CLOSE); }) } @@ -29,7 +38,8 @@ macro_rules! backend_tests { fn insert_fetch() { block_on(async { let db = $init.await; - super::utils::db_insert_fetch(&db).await; + super::utils::db_insert_fetch(db.clone()).await; + db.close().await.expect(ERR_CLOSE); }) } @@ -37,7 +47,8 @@ macro_rules! backend_tests { fn insert_duplicate() { block_on(async { let db = $init.await; - super::utils::db_insert_duplicate(&db).await; + super::utils::db_insert_duplicate(db.clone()).await; + db.close().await.expect(ERR_CLOSE); }) } @@ -45,7 +56,8 @@ macro_rules! backend_tests { fn insert_remove() { block_on(async { let db = $init.await; - super::utils::db_insert_remove(&db).await; + super::utils::db_insert_remove(db.clone()).await; + db.close().await.expect(ERR_CLOSE); }) } @@ -53,7 +65,8 @@ macro_rules! backend_tests { fn remove_missing() { block_on(async { let db = $init.await; - super::utils::db_remove_missing(&db).await; + super::utils::db_remove_missing(db.clone()).await; + db.close().await.expect(ERR_CLOSE); }) } @@ -61,7 +74,8 @@ macro_rules! backend_tests { fn replace_fetch() { block_on(async { let db = $init.await; - super::utils::db_replace_fetch(&db).await; + super::utils::db_replace_fetch(db.clone()).await; + db.close().await.expect(ERR_CLOSE); }) } @@ -69,7 +83,8 @@ macro_rules! backend_tests { fn replace_missing() { block_on(async { let db = $init.await; - super::utils::db_replace_missing(&db).await; + super::utils::db_replace_missing(db.clone()).await; + db.close().await.expect(ERR_CLOSE); }) } @@ -77,7 +92,8 @@ macro_rules! backend_tests { fn count() { block_on(async { let db = $init.await; - super::utils::db_count(&db).await; + super::utils::db_count(db.clone()).await; + db.close().await.expect(ERR_CLOSE); }) } @@ -85,7 +101,8 @@ macro_rules! backend_tests { fn count_exist() { block_on(async { let db = $init.await; - super::utils::db_count_exist(&db).await; + super::utils::db_count_exist(db.clone()).await; + db.close().await.expect(ERR_CLOSE); }) } @@ -93,7 +110,8 @@ macro_rules! backend_tests { fn scan() { block_on(async { let db = $init.await; - super::utils::db_scan(&db).await; + super::utils::db_scan(db.clone()).await; + db.close().await.expect(ERR_CLOSE); }) } @@ -101,7 +119,8 @@ macro_rules! backend_tests { fn remove_all() { block_on(async { let db = $init.await; - super::utils::db_remove_all(&db).await; + super::utils::db_remove_all(db.clone()).await; + db.close().await.expect(ERR_CLOSE); }) } @@ -109,7 +128,8 @@ macro_rules! backend_tests { // fn keypair_create_fetch() { // block_on(async { // let db = $init.await; - // super::utils::db_keypair_create_fetch(&db).await; + // super::utils::db_keypair_create_fetch(db.clone()).await; + // db.close().await.expect(ERR_CLOSE); // }) // } @@ -117,7 +137,8 @@ macro_rules! backend_tests { // fn keypair_sign_verify() { // block_on(async { // let db = $init.await; - // super::utils::db_keypair_sign_verify(&db).await; + // super::utils::db_keypair_sign_verify(db.clone()).await; + // db.close().await.expect(ERR_CLOSE); // }) // } @@ -125,7 +146,8 @@ macro_rules! backend_tests { // fn keypair_pack_unpack_anon() { // block_on(async { // let db = $init.await; - // super::utils::db_keypair_pack_unpack_anon(&db).await; + // super::utils::db_keypair_pack_unpack_anon(db.clone()).await; + // db.close().await.expect(ERR_CLOSE); // }) // } @@ -133,7 +155,8 @@ macro_rules! backend_tests { // fn keypair_pack_unpack_auth() { // block_on(async { // let db = $init.await; - // super::utils::db_keypair_pack_unpack_auth(&db).await; + // super::utils::db_keypair_pack_unpack_auth(db).await; + // db.close().await.expect(ERR_CLOSE); // }) // } @@ -141,7 +164,8 @@ macro_rules! backend_tests { fn txn_rollback() { block_on(async { let db = $init.await; - super::utils::db_txn_rollback(&db).await; + super::utils::db_txn_rollback(db.clone()).await; + db.close().await.expect(ERR_CLOSE); }) } @@ -149,7 +173,8 @@ macro_rules! backend_tests { fn txn_drop() { block_on(async { let db = $init.await; - super::utils::db_txn_drop(&db).await; + super::utils::db_txn_drop(db.clone()).await; + db.close().await.expect(ERR_CLOSE); }) } @@ -157,7 +182,8 @@ macro_rules! backend_tests { fn session_drop() { block_on(async { let db = $init.await; - super::utils::db_session_drop(&db).await; + super::utils::db_session_drop(db.clone()).await; + db.close().await.expect(ERR_CLOSE); }) } @@ -165,7 +191,8 @@ macro_rules! backend_tests { fn txn_commit() { block_on(async { let db = $init.await; - super::utils::db_txn_commit(&db).await; + super::utils::db_txn_commit(db.clone()).await; + db.close().await.expect(ERR_CLOSE); }) } @@ -173,7 +200,17 @@ macro_rules! backend_tests { fn txn_fetch_for_update() { block_on(async { let db = $init.await; - super::utils::db_txn_fetch_for_update(&db).await; + super::utils::db_txn_fetch_for_update(db.clone()).await; + db.close().await.expect(ERR_CLOSE); + }) + } + + #[test] + fn txn_contention() { + block_on(async { + let db = $init.await; + super::utils::db_txn_contention(db.clone()).await; + db.close().await.expect(ERR_CLOSE); }) } }; @@ -244,7 +281,7 @@ mod sqlite { #[test] fn rekey_db() { log_init(); - let fname = format!("sqlite-test-{}.db", uuid::Uuid::new_v4().to_string()); + let fname = format!("sqlite-rekey-{}.db", uuid::Uuid::new_v4().to_string()); let key1 = generate_raw_store_key(None).expect("Error creating raw key"); let key2 = generate_raw_store_key(None).expect("Error creating raw key"); assert_ne!(key1, key2); @@ -280,13 +317,40 @@ mod sqlite { }) } - async fn init_db() -> Store { + #[test] + fn file_db_contention() { + log_init(); + let fname = format!("sqlite-contend-{}.db", uuid::Uuid::new_v4().to_string()); + let key = generate_raw_store_key(None).expect("Error creating raw key"); + + block_on(async move { + let store = SqliteStoreOptions::new(fname.as_str()) + .expect("Error initializing sqlite store options") + .provision_backend(StoreKeyMethod::RawKey, key.as_ref(), None, true) + .await + .expect("Error provisioning sqlite store"); + + let db = std::sync::Arc::new(store); + super::utils::db_txn_contention(db.clone()).await; + db.close().await.expect("Error closing sqlite store"); + + SqliteStoreOptions::new(fname.as_str()) + .expect("Error initializing sqlite store options") + .remove_backend() + .await + .expect("Error removing sqlite store"); + }); + } + + async fn init_db() -> Arc> { log_init(); let key = generate_raw_store_key(None).expect("Error creating raw key"); - SqliteStoreOptions::in_memory() - .provision(StoreKeyMethod::RawKey, key, None, false) - .await - .expect("Error provisioning sqlite store") + Arc::new( + SqliteStoreOptions::in_memory() + .provision(StoreKeyMethod::RawKey, key, None, false) + .await + .expect("Error provisioning sqlite store"), + ) } backend_tests!(init_db()); @@ -315,15 +379,38 @@ mod sqlite { #[cfg(feature = "pg_test")] mod postgres { - use aries_askar::backend::postgres::test_db::TestDB; + use aries_askar::{backend::postgres::test_db::TestDB, postgres::PostgresStore, Store}; + use std::{future::Future, ops::Deref, pin::Pin}; use super::*; - async fn init_db() -> TestDB { + #[derive(Clone, Debug)] + struct Wrap(Arc); + + impl Deref for Wrap { + type Target = Store; + + fn deref(&self) -> &Self::Target { + &**self.0 + } + } + + impl TestStore for Wrap { + type DB = PostgresStore; + + fn close(self) -> Pin>>> { + let db = Arc::try_unwrap(self.0).unwrap(); + Box::pin(db.close()) + } + } + + async fn init_db() -> Wrap { log_init(); - TestDB::provision() - .await - .expect("Error provisioning postgres test database") + Wrap(Arc::new( + TestDB::provision() + .await + .expect("Error provisioning postgres test database"), + )) } backend_tests!(init_db()); diff --git a/tests/utils/mod.rs b/tests/utils/mod.rs index 7eae053d..abbb4ea3 100644 --- a/tests/utils/mod.rs +++ b/tests/utils/mod.rs @@ -1,8 +1,13 @@ -use aries_askar::{Backend, Entry, EntryTag, ErrorKind, Store, TagFilter}; +use std::{fmt::Debug, future::Future, ops::Deref, pin::Pin, sync::Arc}; + +use aries_askar::{Backend, Entry, EntryTag, Error, ErrorKind, Store, TagFilter}; + +use tokio::task::spawn; const ERR_PROFILE: &'static str = "Error creating profile"; const ERR_SESSION: &'static str = "Error starting session"; const ERR_TRANSACTION: &'static str = "Error starting transaction"; +const ERR_COMMIT: &'static str = "Error committing transaction"; const ERR_COUNT: &'static str = "Error performing count"; const ERR_FETCH: &'static str = "Error fetching test row"; const ERR_FETCH_ALL: &'static str = "Error fetching all test rows"; @@ -18,7 +23,22 @@ const ERR_SCAN_NEXT: &'static str = "Error fetching scan rows"; // const ERR_SIGN: &'static str = "Error signing message"; // const ERR_VERIFY: &'static str = "Error verifying signature"; -pub async fn db_create_remove_profile(db: &Store) { +pub trait TestStore: Clone + Deref> + Send + Sync { + type DB: Backend + Debug + 'static; + + fn close(self) -> Pin>>>; +} + +impl TestStore for Arc> { + type DB = B; + + fn close(self) -> Pin>>> { + let db = Arc::try_unwrap(self).unwrap(); + Box::pin(db.close()) + } +} + +pub async fn db_create_remove_profile(db: impl TestStore) { let profile = db.create_profile(None).await.expect(ERR_PROFILE); assert_eq!( db.remove_profile(profile) @@ -34,13 +54,13 @@ pub async fn db_create_remove_profile(db: &Store) { ); } -pub async fn db_fetch_fail(db: &Store) { +pub async fn db_fetch_fail(db: impl TestStore) { let mut conn = db.session(None).await.expect(ERR_SESSION); let result = conn.fetch("cat", "name", false).await.expect(ERR_FETCH); assert_eq!(result.is_none(), true); } -pub async fn db_insert_fetch(db: &Store) { +pub async fn db_insert_fetch(db: impl TestStore) { let test_row = Entry::new( "category", "name", @@ -78,7 +98,7 @@ pub async fn db_insert_fetch(db: &Store) { assert_eq!(rows[0], test_row); } -pub async fn db_insert_duplicate(db: &Store) { +pub async fn db_insert_duplicate(db: impl TestStore) { let test_row = Entry::new("category", "name", "value", Vec::new()); let mut conn = db.session(None).await.expect(ERR_SESSION); @@ -106,7 +126,7 @@ pub async fn db_insert_duplicate(db: &Store) { assert_eq!(err.kind(), ErrorKind::Duplicate); } -pub async fn db_insert_remove(db: &Store) { +pub async fn db_insert_remove(db: impl TestStore) { let test_row = Entry::new("category", "name", "value", Vec::new()); let mut conn = db.session(None).await.expect(ERR_SESSION); @@ -126,14 +146,14 @@ pub async fn db_insert_remove(db: &Store) { .expect(ERR_REQ_ROW); } -pub async fn db_remove_missing(db: &Store) { +pub async fn db_remove_missing(db: impl TestStore) { let mut conn = db.session(None).await.expect(ERR_SESSION); let err = conn.remove("cat", "name").await.expect_err(ERR_REQ_ERR); assert_eq!(err.kind(), ErrorKind::NotFound); } -pub async fn db_replace_fetch(db: &Store) { +pub async fn db_replace_fetch(db: impl TestStore) { let test_row = Entry::new("category", "name", "value", Vec::new()); let mut conn = db.session(None).await.expect(ERR_SESSION); @@ -168,7 +188,7 @@ pub async fn db_replace_fetch(db: &Store) { assert_eq!(row, replace_row); } -pub async fn db_replace_missing(db: &Store) { +pub async fn db_replace_missing(db: impl TestStore) { let test_row = Entry::new("category", "name", "value", Vec::new()); let mut conn = db.session(None).await.expect(ERR_SESSION); @@ -186,7 +206,7 @@ pub async fn db_replace_missing(db: &Store) { assert_eq!(err.kind(), ErrorKind::NotFound); } -pub async fn db_count(db: &Store) { +pub async fn db_count(db: impl TestStore) { let category = "category".to_string(); let test_rows = vec![Entry::new(&category, "name", "value", Vec::new())]; @@ -213,7 +233,7 @@ pub async fn db_count(db: &Store) { assert_eq!(count, 0); } -pub async fn db_count_exist(db: &Store) { +pub async fn db_count_exist(db: impl TestStore) { let test_row = Entry::new( "category", "name", @@ -352,7 +372,7 @@ pub async fn db_count_exist(db: &Store) { ); } -pub async fn db_scan(db: &Store) { +pub async fn db_scan(db: impl TestStore) { let category = "category".to_string(); let test_rows = vec![Entry::new( &category, @@ -400,7 +420,7 @@ pub async fn db_scan(db: &Store) { assert_eq!(rows, None); } -pub async fn db_remove_all(db: &Store) { +pub async fn db_remove_all(db: impl TestStore) { let test_rows = vec![ Entry::new( "category", @@ -460,7 +480,7 @@ pub async fn db_remove_all(db: &Store) { assert_eq!(removed, 2); } -// pub async fn db_keypair_create_fetch(db: &Store) { +// pub async fn db_keypair_create_fetch(db: impl TestStore) { // let mut conn = db.session(None).await.expect(ERR_SESSION); // let metadata = "meta".to_owned(); @@ -477,7 +497,7 @@ pub async fn db_remove_all(db: &Store) { // assert_eq!(Some(key_info), found); // } -// pub async fn db_keypair_sign_verify(db: &Store) { +// pub async fn db_keypair_sign_verify(db: impl TestStore) { // let mut conn = db.session(None).await.expect(ERR_SESSION); // let key_info = conn @@ -520,7 +540,7 @@ pub async fn db_remove_all(db: &Store) { // assert_eq!(err.kind(), ErrorKind::Input); // } -// pub async fn db_keypair_pack_unpack_anon(db: &Store) { +// pub async fn db_keypair_pack_unpack_anon(db: impl TestStore) { // let mut conn = db.session(None).await.expect(ERR_SESSION); // let recip_key = conn @@ -541,7 +561,7 @@ pub async fn db_remove_all(db: &Store) { // assert_eq!(p_send, None); // } -// pub async fn db_keypair_pack_unpack_auth(db: &Store) { +// pub async fn db_keypair_pack_unpack_auth(db: impl TestStore) { // let mut conn = db.session(None).await.expect(ERR_SESSION); // let sender_key = conn @@ -570,7 +590,7 @@ pub async fn db_remove_all(db: &Store) { // assert_eq!(p_send.map(|k| k.to_string()), Some(sender_key.ident)); // } -pub async fn db_txn_rollback(db: &Store) { +pub async fn db_txn_rollback(db: impl TestStore) { let test_row = Entry::new("category", "name", "value", Vec::new()); let mut conn = db.transaction(None).await.expect(ERR_TRANSACTION); @@ -598,7 +618,7 @@ pub async fn db_txn_rollback(db: &Store) { assert_eq!(row, None); } -pub async fn db_txn_drop(db: &Store) { +pub async fn db_txn_drop(db: impl TestStore) { let test_row = Entry::new("category", "name", "value", Vec::new()); let mut conn = db @@ -628,7 +648,7 @@ pub async fn db_txn_drop(db: &Store) { } // test that session does NOT have transaction rollback behaviour -pub async fn db_session_drop(db: &Store) { +pub async fn db_session_drop(db: impl TestStore) { let test_row = Entry::new("category", "name", "value", Vec::new()); let mut conn = db.session(None).await.expect(ERR_SESSION); @@ -654,7 +674,7 @@ pub async fn db_session_drop(db: &Store) { assert_eq!(row, Some(test_row)); } -pub async fn db_txn_commit(db: &Store) { +pub async fn db_txn_commit(db: impl TestStore) { let test_row = Entry::new("category", "name", "value", Vec::new()); let mut conn = db.transaction(None).await.expect(ERR_TRANSACTION); @@ -669,7 +689,7 @@ pub async fn db_txn_commit(db: &Store) { .await .expect(ERR_INSERT); - conn.commit().await.expect("Error committing transaction"); + conn.commit().await.expect(ERR_COMMIT); let mut conn = db.session(None).await.expect(ERR_SESSION); @@ -680,7 +700,7 @@ pub async fn db_txn_commit(db: &Store) { assert_eq!(row, Some(test_row)); } -pub async fn db_txn_fetch_for_update(db: &Store) { +pub async fn db_txn_fetch_for_update(db: impl TestStore) { let test_row = Entry::new("category", "name", "value", Vec::new()); let mut conn = db.transaction(None).await.expect(ERR_TRANSACTION); @@ -711,5 +731,78 @@ pub async fn db_txn_fetch_for_update(db: &Store) { assert_eq!(rows.len(), 1); assert_eq!(rows[0], test_row); - conn.commit().await.expect("Error committing transaction"); + conn.commit().await.expect(ERR_COMMIT); +} + +pub async fn db_txn_contention(db: impl TestStore + 'static) { + let test_row = Entry::new( + "category", + "count", + "0", + vec![ + EntryTag::Encrypted("t1".to_string(), "v1".to_string()), + EntryTag::Plaintext("t2".to_string(), "v2".to_string()), + ], + ); + + let mut conn = db.transaction(None).await.expect(ERR_TRANSACTION); + + conn.insert( + &test_row.category, + &test_row.name, + &test_row.value, + Some(test_row.tags.as_slice()), + None, + ) + .await + .expect(ERR_INSERT); + + conn.commit().await.expect(ERR_COMMIT); + + const TASKS: usize = 25; + const INC: usize = 500; + + async fn inc(db: impl TestStore, category: String, name: String) { + for _ in 0..INC { + let mut conn = db.transaction(None).await.expect(ERR_TRANSACTION); + let row = conn + .fetch(&category, &name, true) + .await + .expect(ERR_FETCH) + .expect(ERR_REQ_ROW); + let val: usize = str::parse(row.value.as_opt_str().unwrap()).unwrap(); + conn.replace( + &category, + &name, + &format!("{}", val + 1).as_bytes(), + Some(row.tags.as_slice()), + None, + ) + .await + .expect(ERR_REPLACE); + conn.commit().await.expect(ERR_COMMIT); + } + } + + let mut tasks = vec![]; + for _ in 0..TASKS { + tasks.push(spawn(inc( + db.clone(), + test_row.category.clone(), + test_row.name.clone(), + ))); + } + // JoinSet is not stable yet, just await all the tasks + for task in tasks { + task.await.unwrap(); + } + + // check the total + let mut conn = db.session(None).await.expect(ERR_SESSION); + let row = conn + .fetch(&test_row.category, &test_row.name, false) + .await + .expect(ERR_FETCH) + .expect(ERR_REQ_ROW); + assert_eq!(row.value, format!("{}", TASKS * INC).as_bytes()); } From 5f5908c89664402f6308710f02faf7602497875a Mon Sep 17 00:00:00 2001 From: Andrew Whitehead Date: Tue, 5 Apr 2022 14:18:10 -0700 Subject: [PATCH 10/43] (postgres) use update instead of delete & insert Signed-off-by: Andrew Whitehead --- src/backend/postgres/mod.rs | 118 ++++++++++-------------------------- 1 file changed, 33 insertions(+), 85 deletions(-) diff --git a/src/backend/postgres/mod.rs b/src/backend/postgres/mod.rs index a548c47d..73faaa1f 100644 --- a/src/backend/postgres/mod.rs +++ b/src/backend/postgres/mod.rs @@ -450,7 +450,7 @@ impl QueryBackend for DbSession { let name = ProfileKey::prepare_input(name.as_bytes()); match operation { - EntryOperation::Insert => { + op @ EntryOperation::Insert | op @ EntryOperation::Replace => { let value = ProfileKey::prepare_input(value.unwrap()); let tags = tags.map(prepare_tags); Box::pin(async move { @@ -478,41 +478,7 @@ impl QueryBackend for DbSession { &enc_value, enc_tags, expiry_ms, - ) - .await?; - txn.commit().await?; - Ok(()) - }) - } - EntryOperation::Replace => { - let value = ProfileKey::prepare_input(value.unwrap()); - let tags = tags.map(prepare_tags); - Box::pin(async move { - let (_, key) = acquire_key(&mut *self).await?; - let (enc_category, enc_name, enc_value, enc_tags) = unblock(move || { - let enc_value = - key.encrypt_entry_value(category.as_ref(), name.as_ref(), value)?; - Result::<_, Error>::Ok(( - key.encrypt_entry_category(category)?, - key.encrypt_entry_name(name)?, - enc_value, - tags.transpose()? - .map(|t| key.encrypt_entry_tags(t)) - .transpose()?, - )) - }) - .await?; - - let mut active = acquire_session(&mut *self).await?; - let mut txn = active.as_transaction().await?; - perform_update( - &mut txn, - kind, - &enc_category, - &enc_name, - &enc_value, - enc_tags, - expiry_ms, + op == EntryOperation::Insert, ) .await?; txn.commit().await?; @@ -617,56 +583,38 @@ async fn perform_insert<'q>( enc_value: &[u8], enc_tags: Option>, expiry_ms: Option, + new_row: bool, ) -> Result<(), Error> { - trace!("Insert entry"); - let row_id: i64 = sqlx::query_scalar(INSERT_QUERY) - .bind(active.profile_id) - .bind(kind as i16) - .bind(enc_category) - .bind(enc_name) - .bind(enc_value) - .bind(expiry_ms.map(expiry_timestamp).transpose()?) - .fetch_optional(active.connection_mut()) - .await? - .ok_or_else(|| err_msg!(Duplicate, "Duplicate row"))?; - if let Some(tags) = enc_tags { - for tag in tags { - sqlx::query(TAG_INSERT_QUERY) - .bind(row_id) - .bind(&tag.name) - .bind(&tag.value) - .bind(tag.plaintext as i16) - .execute(active.connection_mut()) - .await?; - } - } - Ok(()) -} - -async fn perform_update<'q>( - active: &mut DbSessionActive<'q, Postgres>, - kind: EntryKind, - enc_category: &[u8], - enc_name: &[u8], - enc_value: &[u8], - enc_tags: Option>, - expiry_ms: Option, -) -> Result<(), Error> { - trace!("Update entry"); - let row_id: i64 = sqlx::query_scalar(UPDATE_QUERY) - .bind(active.profile_id) - .bind(kind as i16) - .bind(enc_category) - .bind(enc_name) - .bind(enc_value) - .bind(expiry_ms.map(expiry_timestamp).transpose()?) - .fetch_one(active.connection_mut()) - .await - .map_err(|_| err_msg!(NotFound, "Error updating existing row"))?; - sqlx::query(TAG_DELETE_QUERY) - .bind(row_id) - .execute(active.connection_mut()) - .await?; + let row_id = if new_row { + trace!("Insert entry"); + sqlx::query_scalar(INSERT_QUERY) + .bind(active.profile_id) + .bind(kind as i16) + .bind(enc_category) + .bind(enc_name) + .bind(enc_value) + .bind(expiry_ms.map(expiry_timestamp).transpose()?) + .fetch_optional(active.connection_mut()) + .await? + .ok_or_else(|| err_msg!(Duplicate, "Duplicate row"))? + } else { + trace!("Update entry"); + let row_id: i64 = sqlx::query_scalar(UPDATE_QUERY) + .bind(active.profile_id) + .bind(kind as i16) + .bind(enc_category) + .bind(enc_name) + .bind(enc_value) + .bind(expiry_ms.map(expiry_timestamp).transpose()?) + .fetch_one(active.connection_mut()) + .await + .map_err(|_| err_msg!(NotFound, "Error updating existing row"))?; + sqlx::query(TAG_DELETE_QUERY) + .bind(row_id) + .execute(active.connection_mut()) + .await?; + row_id + }; if let Some(tags) = enc_tags { for tag in tags { sqlx::query(TAG_INSERT_QUERY) From 1caf494f5ee6f1018f28ab8e4e8dcf97bfcc7861 Mon Sep 17 00:00:00 2001 From: Andrew Whitehead Date: Tue, 5 Apr 2022 14:18:52 -0700 Subject: [PATCH 11/43] implement additional connection options for sqlite Signed-off-by: Andrew Whitehead --- src/backend/sqlite/provision.rs | 101 +++++++++++++++++++++++++++++--- 1 file changed, 94 insertions(+), 7 deletions(-) diff --git a/src/backend/sqlite/provision.rs b/src/backend/sqlite/provision.rs index 480eafbc..8e34298b 100644 --- a/src/backend/sqlite/provision.rs +++ b/src/backend/sqlite/provision.rs @@ -1,10 +1,12 @@ -use std::borrow::Cow; -use std::fs::remove_file; -use std::io::ErrorKind as IoErrorKind; -use std::str::FromStr; +use std::{ + borrow::Cow, fs::remove_file, io::ErrorKind as IoErrorKind, str::FromStr, time::Duration, +}; use sqlx::{ - sqlite::{SqliteConnectOptions, SqlitePool, SqlitePoolOptions}, + sqlite::{ + SqliteAutoVacuum, SqliteConnectOptions, SqliteJournalMode, SqliteLockingMode, SqlitePool, + SqlitePoolOptions, SqliteSynchronous, + }, ConnectOptions, Error as SqlxError, Row, }; @@ -20,18 +22,56 @@ use crate::{ storage::{IntoOptions, Options, Store}, }; +const DEFAULT_MIN_CONNECTIONS: u32 = 1; +const DEFAULT_BUSY_TIMEOUT: Duration = Duration::from_secs(5); +const DEFAULT_JOURNAL_MODE: SqliteJournalMode = SqliteJournalMode::Wal; +const DEFAULT_LOCKING_MODE: SqliteLockingMode = SqliteLockingMode::Normal; +const DEFAULT_SHARED_CACHE: bool = true; +const DEFAULT_SYNCHRONOUS: SqliteSynchronous = SqliteSynchronous::Full; + /// Configuration options for Sqlite stores #[derive(Debug)] pub struct SqliteStoreOptions { pub(crate) in_memory: bool, pub(crate) path: String, + pub(crate) busy_timeout: Duration, pub(crate) max_connections: u32, + pub(crate) min_connections: u32, + pub(crate) journal_mode: SqliteJournalMode, + pub(crate) locking_mode: SqliteLockingMode, + pub(crate) shared_cache: bool, + pub(crate) synchronous: SqliteSynchronous, +} + +impl Default for SqliteStoreOptions { + fn default() -> Self { + Self { + in_memory: true, + path: ":memory:".into(), + busy_timeout: DEFAULT_BUSY_TIMEOUT, + max_connections: num_cpus::get() as u32, + min_connections: DEFAULT_MIN_CONNECTIONS, + journal_mode: DEFAULT_JOURNAL_MODE, + locking_mode: DEFAULT_LOCKING_MODE, + shared_cache: DEFAULT_SHARED_CACHE, + synchronous: DEFAULT_SYNCHRONOUS, + } + } } impl SqliteStoreOptions { /// Initialize `SqliteStoreOptions` from a generic set of options pub fn new<'a>(options: impl IntoOptions<'a>) -> Result { let mut opts = options.into_options()?; + let busy_timeout = if let Some(timeout) = opts.query.remove("busy_timeout") { + Duration::from_millis( + timeout + .parse() + .map_err(err_map!(Input, "Error parsing 'busy_timeout' parameter"))?, + ) + } else { + DEFAULT_BUSY_TIMEOUT + }; let max_connections = if let Some(max_conn) = opts.query.remove("max_connections") { max_conn .parse() @@ -39,12 +79,54 @@ impl SqliteStoreOptions { } else { num_cpus::get() as u32 }; + let min_connections = if let Some(min_conn) = opts.query.remove("min_connections") { + min_conn + .parse() + .map_err(err_map!(Input, "Error parsing 'min_connections' parameter"))? + } else { + DEFAULT_MIN_CONNECTIONS + }; + let journal_mode = if let Some(mode) = opts.query.remove("journal_mode") { + SqliteJournalMode::from_str(&mode) + .map_err(err_map!(Input, "Error parsing 'journal_mode' parameter"))? + } else { + DEFAULT_JOURNAL_MODE + }; + let locking_mode = if let Some(mode) = opts.query.remove("locking_mode") { + if mode.eq_ignore_ascii_case("exclusive") { + SqliteLockingMode::Exclusive + } else if mode.eq_ignore_ascii_case("normal") { + SqliteLockingMode::Normal + } else { + return Err(err_msg!(Input, "Error parsing 'locking_mode' parameter")); + } + } else { + DEFAULT_LOCKING_MODE + }; + let shared_cache = if let Some(cache) = opts.query.remove("cache") { + cache.eq_ignore_ascii_case("shared") + } else { + DEFAULT_SHARED_CACHE + }; + let synchronous = if let Some(sync) = opts.query.remove("synchronous") { + SqliteSynchronous::from_str(&sync) + .map_err(err_map!(Input, "Error parsing 'synchronous' parameter"))? + } else { + DEFAULT_SYNCHRONOUS + }; + let mut path = opts.host.to_string(); path.push_str(&*opts.path); Ok(Self { in_memory: path == ":memory:", path, + busy_timeout, max_connections, + min_connections, + journal_mode, + locking_mode, + shared_cache, + synchronous, }) } @@ -52,7 +134,12 @@ impl SqliteStoreOptions { #[allow(unused_mut)] let mut conn_opts = SqliteConnectOptions::from_str(self.path.as_ref())? .create_if_missing(auto_create) - .shared_cache(true); + .auto_vacuum(SqliteAutoVacuum::Incremental) + .busy_timeout(self.busy_timeout) + .journal_mode(self.journal_mode.clone()) + .locking_mode(self.locking_mode.clone()) + .shared_cache(self.shared_cache) + .synchronous(self.synchronous.clone()); #[cfg(feature = "log")] { conn_opts.log_statements(log::LevelFilter::Debug); @@ -62,7 +149,7 @@ impl SqliteStoreOptions { // maintains at least 1 connection. // for an in-memory database this is required to avoid dropping the database, // for a file database this signals other instances that the database is in use - .min_connections(1) + .min_connections(self.min_connections) .max_connections(self.max_connections) .test_before_acquire(false) .connect_with(conn_opts) From 78639a5196eeee0234e100dd406897a8c9bb9ddb Mon Sep 17 00:00:00 2001 From: Andrew Whitehead Date: Tue, 5 Apr 2022 14:19:26 -0700 Subject: [PATCH 12/43] adjust error message Signed-off-by: Andrew Whitehead --- src/backend/any.rs | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/src/backend/any.rs b/src/backend/any.rs index 4c440e79..1d084c6c 100644 --- a/src/backend/any.rs +++ b/src/backend/any.rs @@ -270,7 +270,11 @@ impl<'a> ManageBackend<'a> for &'a str { Ok(Store::new(AnyBackend::Sqlite(mgr.into_inner()))) } - _ => Err(err_msg!(Unsupported, "Invalid backend: {}", &opts.schema)), + _ => Err(err_msg!( + Unsupported, + "Unsupported backend: {}", + &opts.schema + )), } }) } @@ -301,7 +305,11 @@ impl<'a> ManageBackend<'a> for &'a str { Ok(Store::new(AnyBackend::Sqlite(mgr.into_inner()))) } - _ => Err(err_msg!(Unsupported, "Invalid backend: {}", &opts.schema)), + _ => Err(err_msg!( + Unsupported, + "Unsupported backend: {}", + &opts.schema + )), } }) } @@ -324,7 +332,11 @@ impl<'a> ManageBackend<'a> for &'a str { Ok(opts.remove().await?) } - _ => Err(err_msg!(Unsupported, "Invalid backend: {}", &opts.schema)), + _ => Err(err_msg!( + Unsupported, + "Unsupported backend: {}", + &opts.schema + )), } }) } From a16054ccac869f311a417003db17abea4aeb25bf Mon Sep 17 00:00:00 2001 From: Andrew Whitehead Date: Tue, 5 Apr 2022 14:20:23 -0700 Subject: [PATCH 13/43] use update instead of delete & insert Signed-off-by: Andrew Whitehead --- src/backend/sqlite/mod.rs | 63 +++++++++++++++++++++++++++------------ 1 file changed, 44 insertions(+), 19 deletions(-) diff --git a/src/backend/sqlite/mod.rs b/src/backend/sqlite/mod.rs index c34b65be..5d5265de 100644 --- a/src/backend/sqlite/mod.rs +++ b/src/backend/sqlite/mod.rs @@ -47,6 +47,9 @@ const FETCH_QUERY: &'static str = "SELECT i.id, i.value, const INSERT_QUERY: &'static str = "INSERT OR IGNORE INTO items (profile_id, kind, category, name, value, expiry) VALUES (?1, ?2, ?3, ?4, ?5, ?6)"; +const UPDATE_QUERY: &'static str = + "UPDATE items SET value=?5, expiry=?6 WHERE profile_id=?1 AND kind=?2 + AND category=?3 AND name=?4 RETURNING id"; const SCAN_QUERY: &'static str = "SELECT i.id, i.name, i.value, (SELECT GROUP_CONCAT(it.plaintext || ':' || HEX(it.name) || ':' || HEX(it.value)) FROM items_tags it WHERE it.item_id = i.id) AS tags @@ -56,6 +59,8 @@ const DELETE_ALL_QUERY: &'static str = "DELETE FROM items AS i WHERE i.profile_id = ?1 AND i.kind = ?2 AND i.category = ?3"; const TAG_INSERT_QUERY: &'static str = "INSERT INTO items_tags (item_id, name, value, plaintext) VALUES (?1, ?2, ?3, ?4)"; +const TAG_DELETE_QUERY: &'static str = "DELETE FROM items_tags + WHERE item_id=?1"; /// A Sqlite database store pub struct SqliteStore { @@ -435,9 +440,6 @@ impl QueryBackend for DbSession { .await?; let mut active = acquire_session(&mut *self).await?; let mut txn = active.as_transaction().await?; - if op == EntryOperation::Replace { - perform_remove(&mut txn, kind, &enc_category, &enc_name, false).await?; - } perform_insert( &mut txn, kind, @@ -446,6 +448,7 @@ impl QueryBackend for DbSession { &enc_value, enc_tags, expiry_ms, + op == EntryOperation::Insert, ) .await?; txn.commit().await?; @@ -484,8 +487,10 @@ impl ExtDatabase for Sqlite { Box::pin(async move { ::TransactionManager::begin(&mut *conn).await?; if !nested { - sqlx::query("ROLLBACK").execute(&mut *conn).await?; - sqlx::query("BEGIN IMMEDIATE").execute(conn).await?; + // a no-op write transaction + sqlx::query("DELETE FROM config WHERE 0") + .execute(&mut *conn) + .await?; } Ok(()) }) @@ -540,21 +545,41 @@ async fn perform_insert<'q>( enc_value: &[u8], enc_tags: Option>, expiry_ms: Option, + new_row: bool, ) -> Result<(), Error> { - trace!("Insert entry"); - let done = sqlx::query(INSERT_QUERY) - .bind(active.profile_id) - .bind(kind as i16) - .bind(enc_category) - .bind(enc_name) - .bind(enc_value) - .bind(expiry_ms.map(expiry_timestamp).transpose()?) - .execute(active.connection_mut()) - .await?; - if done.rows_affected() == 0 { - return Err(err_msg!(Duplicate, "Duplicate row")); - } - let row_id = done.last_insert_rowid(); + let row_id = if new_row { + trace!("Insert entry"); + let done = sqlx::query(INSERT_QUERY) + .bind(active.profile_id) + .bind(kind as i16) + .bind(enc_category) + .bind(enc_name) + .bind(enc_value) + .bind(expiry_ms.map(expiry_timestamp).transpose()?) + .execute(active.connection_mut()) + .await?; + if done.rows_affected() == 0 { + return Err(err_msg!(Duplicate, "Duplicate row")); + } + done.last_insert_rowid() + } else { + trace!("Update entry"); + let row_id: i64 = sqlx::query_scalar(UPDATE_QUERY) + .bind(active.profile_id) + .bind(kind as i16) + .bind(enc_category) + .bind(enc_name) + .bind(enc_value) + .bind(expiry_ms.map(expiry_timestamp).transpose()?) + .fetch_one(active.connection_mut()) + .await + .map_err(|_| err_msg!(NotFound, "Error updating existing row"))?; + sqlx::query(TAG_DELETE_QUERY) + .bind(row_id) + .execute(active.connection_mut()) + .await?; + row_id + }; if let Some(tags) = enc_tags { for tag in tags { sqlx::query(TAG_INSERT_QUERY) From 29537dff0d4e5884638bdb0a5d3b2929f9d5bb07 Mon Sep 17 00:00:00 2001 From: Andrew Whitehead Date: Tue, 5 Apr 2022 14:20:58 -0700 Subject: [PATCH 14/43] use stabilized Arc::increment_strong_count Signed-off-by: Andrew Whitehead --- src/ffi/handle.rs | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/ffi/handle.rs b/src/ffi/handle.rs index 935ddd12..59ca81f1 100644 --- a/src/ffi/handle.rs +++ b/src/ffi/handle.rs @@ -1,4 +1,4 @@ -use std::{fmt::Display, marker::PhantomData, mem, sync::Arc}; +use std::{fmt::Display, marker::PhantomData, sync::Arc}; use crate::error::Error; @@ -18,16 +18,15 @@ impl ArcHandle { pub fn load(&self) -> Result, Error> { self.validate()?; let slf = unsafe { Arc::from_raw(self.0 as *const T) }; - let copy = slf.clone(); - mem::forget(slf); // Arc::increment_strong_count(..) in 1.51 - Ok(copy) + unsafe { Arc::increment_strong_count(self.0 as *const T) }; + Ok(slf) } pub fn remove(&self) { if self.0 != 0 { unsafe { // Drop the initial reference. There could be others outstanding. - Arc::from_raw(self.0 as *const T); + Arc::decrement_strong_count(self.0 as *const T); } } } From 9bb57e632a7c3f30ed75e5fe72384903c757fcc0 Mon Sep 17 00:00:00 2001 From: Andrew Whitehead Date: Tue, 5 Apr 2022 14:21:46 -0700 Subject: [PATCH 15/43] updates to transaction stress tests Signed-off-by: Andrew Whitehead --- tests/backends.rs | 72 ++++++++++++++++++++++++++++++++++++++++++++-- tests/utils/mod.rs | 30 +++++++++++++------ 2 files changed, 91 insertions(+), 11 deletions(-) diff --git a/tests/backends.rs b/tests/backends.rs index f8cdd259..3155b175 100644 --- a/tests/backends.rs +++ b/tests/backends.rs @@ -318,9 +318,9 @@ mod sqlite { } #[test] - fn file_db_contention() { + fn txn_contention_file() { log_init(); - let fname = format!("sqlite-contend-{}.db", uuid::Uuid::new_v4().to_string()); + let fname = format!("sqlite-contention-{}.db", uuid::Uuid::new_v4().to_string()); let key = generate_raw_store_key(None).expect("Error creating raw key"); block_on(async move { @@ -342,6 +342,74 @@ mod sqlite { }); } + // #[test] + // fn stress_test() { + // log_init(); + // use sqlx::sqlite::{SqliteConnectOptions, SqlitePool, SqlitePoolOptions}; + // use std::str::FromStr; + // let conn_opts = SqliteConnectOptions::from_str("sqlite:test.db") + // .unwrap() + // .create_if_missing(true); + // // .shared_cache(true); + // block_on(async move { + // let pool = SqlitePoolOptions::default() + // // maintains at least 1 connection. + // // for an in-memory database this is required to avoid dropping the database, + // // for a file database this signals other instances that the database is in use + // .min_connections(1) + // .max_connections(5) + // .test_before_acquire(false) + // .connect_with(conn_opts) + // .await + // .unwrap(); + + // let mut conn = pool.begin().await.unwrap(); + // sqlx::query("CREATE TABLE test (name TEXT)") + // .execute(&mut conn) + // .await + // .unwrap(); + // sqlx::query("INSERT INTO test (name) VALUES ('test')") + // .execute(&mut conn) + // .await + // .unwrap(); + // conn.commit().await.unwrap(); + + // const TASKS: usize = 25; + // const COUNT: usize = 1000; + + // async fn fetch(pool: SqlitePool) -> Result<(), &'static str> { + // // try to avoid panics in this section, as they will be raised on a tokio worker thread + // for _ in 0..COUNT { + // let mut txn = pool.acquire().await.expect("Acquire error"); + // sqlx::query("BEGIN IMMEDIATE") + // .execute(&mut txn) + // .await + // .expect("Transaction error"); + // let _ = sqlx::query("SELECT * FROM test") + // .fetch_one(&mut txn) + // .await + // .expect("Error fetching row"); + // sqlx::query("COMMIT") + // .execute(&mut txn) + // .await + // .expect("Commit error"); + // } + // Ok(()) + // } + + // let mut tasks = vec![]; + // for _ in 0..TASKS { + // tasks.push(tokio::spawn(fetch(pool.clone()))); + // } + + // for task in tasks { + // if let Err(s) = task.await.unwrap() { + // panic!("Error in concurrent update task: {}", s); + // } + // } + // }); + // } + async fn init_db() -> Arc> { log_init(); let key = generate_raw_store_key(None).expect("Error creating raw key"); diff --git a/tests/utils/mod.rs b/tests/utils/mod.rs index abbb4ea3..777307c9 100644 --- a/tests/utils/mod.rs +++ b/tests/utils/mod.rs @@ -759,18 +759,23 @@ pub async fn db_txn_contention(db: impl TestStore + 'static) { conn.commit().await.expect(ERR_COMMIT); - const TASKS: usize = 25; - const INC: usize = 500; + const TASKS: usize = 10; + const INC: usize = 1000; - async fn inc(db: impl TestStore, category: String, name: String) { + async fn inc(db: impl TestStore, category: String, name: String) -> Result<(), &'static str> { + // try to avoid panics in this section, as they will be raised on a tokio worker thread for _ in 0..INC { let mut conn = db.transaction(None).await.expect(ERR_TRANSACTION); let row = conn .fetch(&category, &name, true) .await - .expect(ERR_FETCH) - .expect(ERR_REQ_ROW); - let val: usize = str::parse(row.value.as_opt_str().unwrap()).unwrap(); + .map_err(|e| { + log::error!("{:?}", e); + ERR_FETCH + })? + .ok_or(ERR_REQ_ROW)?; + let val: usize = str::parse(row.value.as_opt_str().ok_or("Non-string counter value")?) + .map_err(|_| "Error parsing counter value")?; conn.replace( &category, &name, @@ -779,9 +784,13 @@ pub async fn db_txn_contention(db: impl TestStore + 'static) { None, ) .await - .expect(ERR_REPLACE); - conn.commit().await.expect(ERR_COMMIT); + .map_err(|e| { + log::error!("{:?}", e); + ERR_REPLACE + })?; + conn.commit().await.map_err(|_| ERR_COMMIT)?; } + Ok(()) } let mut tasks = vec![]; @@ -792,9 +801,12 @@ pub async fn db_txn_contention(db: impl TestStore + 'static) { test_row.name.clone(), ))); } + // JoinSet is not stable yet, just await all the tasks for task in tasks { - task.await.unwrap(); + if let Err(s) = task.await.unwrap() { + panic!("Error in concurrent update task: {}", s); + } } // check the total From 0cb834c360541236f3fe496b00886f0f6583ae3c Mon Sep 17 00:00:00 2001 From: Andrew Whitehead Date: Tue, 5 Apr 2022 14:45:41 -0700 Subject: [PATCH 16/43] update localkey test Signed-off-by: Andrew Whitehead --- tests/backends.rs | 151 +++++++++++++++++++++++---------------------- tests/local_key.rs | 49 +++++++++++++++ tests/utils/mod.rs | 144 +++++++++--------------------------------- 3 files changed, 155 insertions(+), 189 deletions(-) create mode 100644 tests/local_key.rs diff --git a/tests/backends.rs b/tests/backends.rs index 3155b175..ca77a734 100644 --- a/tests/backends.rs +++ b/tests/backends.rs @@ -124,14 +124,14 @@ macro_rules! backend_tests { }) } - // #[test] - // fn keypair_create_fetch() { - // block_on(async { - // let db = $init.await; - // super::utils::db_keypair_create_fetch(db.clone()).await; - // db.close().await.expect(ERR_CLOSE); - // }) - // } + #[test] + fn keypair_create_fetch() { + block_on(async { + let db = $init.await; + super::utils::db_keypair_insert_fetch(db.clone()).await; + db.close().await.expect(ERR_CLOSE); + }) + } // #[test] // fn keypair_sign_verify() { @@ -342,73 +342,74 @@ mod sqlite { }); } - // #[test] - // fn stress_test() { - // log_init(); - // use sqlx::sqlite::{SqliteConnectOptions, SqlitePool, SqlitePoolOptions}; - // use std::str::FromStr; - // let conn_opts = SqliteConnectOptions::from_str("sqlite:test.db") - // .unwrap() - // .create_if_missing(true); - // // .shared_cache(true); - // block_on(async move { - // let pool = SqlitePoolOptions::default() - // // maintains at least 1 connection. - // // for an in-memory database this is required to avoid dropping the database, - // // for a file database this signals other instances that the database is in use - // .min_connections(1) - // .max_connections(5) - // .test_before_acquire(false) - // .connect_with(conn_opts) - // .await - // .unwrap(); - - // let mut conn = pool.begin().await.unwrap(); - // sqlx::query("CREATE TABLE test (name TEXT)") - // .execute(&mut conn) - // .await - // .unwrap(); - // sqlx::query("INSERT INTO test (name) VALUES ('test')") - // .execute(&mut conn) - // .await - // .unwrap(); - // conn.commit().await.unwrap(); - - // const TASKS: usize = 25; - // const COUNT: usize = 1000; - - // async fn fetch(pool: SqlitePool) -> Result<(), &'static str> { - // // try to avoid panics in this section, as they will be raised on a tokio worker thread - // for _ in 0..COUNT { - // let mut txn = pool.acquire().await.expect("Acquire error"); - // sqlx::query("BEGIN IMMEDIATE") - // .execute(&mut txn) - // .await - // .expect("Transaction error"); - // let _ = sqlx::query("SELECT * FROM test") - // .fetch_one(&mut txn) - // .await - // .expect("Error fetching row"); - // sqlx::query("COMMIT") - // .execute(&mut txn) - // .await - // .expect("Commit error"); - // } - // Ok(()) - // } - - // let mut tasks = vec![]; - // for _ in 0..TASKS { - // tasks.push(tokio::spawn(fetch(pool.clone()))); - // } - - // for task in tasks { - // if let Err(s) = task.await.unwrap() { - // panic!("Error in concurrent update task: {}", s); - // } - // } - // }); - // } + #[cfg(feature = "stress_test")] + #[test] + fn stress_test() { + log_init(); + use sqlx::sqlite::{SqliteConnectOptions, SqlitePool, SqlitePoolOptions}; + use std::str::FromStr; + let conn_opts = SqliteConnectOptions::from_str("sqlite:test.db") + .unwrap() + .create_if_missing(true); + // .shared_cache(true); + block_on(async move { + let pool = SqlitePoolOptions::default() + // maintains at least 1 connection. + // for an in-memory database this is required to avoid dropping the database, + // for a file database this signals other instances that the database is in use + .min_connections(1) + .max_connections(5) + .test_before_acquire(false) + .connect_with(conn_opts) + .await + .unwrap(); + + let mut conn = pool.begin().await.unwrap(); + sqlx::query("CREATE TABLE test (name TEXT)") + .execute(&mut conn) + .await + .unwrap(); + sqlx::query("INSERT INTO test (name) VALUES ('test')") + .execute(&mut conn) + .await + .unwrap(); + conn.commit().await.unwrap(); + + const TASKS: usize = 25; + const COUNT: usize = 1000; + + async fn fetch(pool: SqlitePool) -> Result<(), &'static str> { + // try to avoid panics in this section, as they will be raised on a tokio worker thread + for _ in 0..COUNT { + let mut txn = pool.acquire().await.expect("Acquire error"); + sqlx::query("BEGIN IMMEDIATE") + .execute(&mut txn) + .await + .expect("Transaction error"); + let _ = sqlx::query("SELECT * FROM test") + .fetch_one(&mut txn) + .await + .expect("Error fetching row"); + sqlx::query("COMMIT") + .execute(&mut txn) + .await + .expect("Commit error"); + } + Ok(()) + } + + let mut tasks = vec![]; + for _ in 0..TASKS { + tasks.push(tokio::spawn(fetch(pool.clone()))); + } + + for task in tasks { + if let Err(s) = task.await.unwrap() { + panic!("Error in concurrent update task: {}", s); + } + } + }); + } async fn init_db() -> Arc> { log_init(); diff --git a/tests/local_key.rs b/tests/local_key.rs new file mode 100644 index 00000000..47d0da53 --- /dev/null +++ b/tests/local_key.rs @@ -0,0 +1,49 @@ +use aries_askar::kms::{KeyAlg, LocalKey}; + +const ERR_CREATE_KEYPAIR: &'static str = "Error creating keypair"; +const ERR_SIGN: &'static str = "Error signing message"; +const ERR_VERIFY: &'static str = "Error verifying signature"; + +pub async fn localkey_sign_verify() { + let keypair = LocalKey::generate(KeyAlg::Ed25519, true).expect(ERR_CREATE_KEYPAIR); + + let message = b"message".to_vec(); + let sig = keypair.sign_message(&message, None).expect(ERR_SIGN); + + assert_eq!( + keypair + .verify_signature(&message, &sig, None) + .expect(ERR_VERIFY), + true + ); + + assert_eq!( + keypair + .verify_signature(b"bad input", &sig, None) + .expect(ERR_VERIFY), + false + ); + + assert_eq!( + keypair.verify_signature( + // [0u8; 64] + b"xt19s1sp2UZCGhy9rNyb1FtxdKiDGZZPNFnc1KiM9jYYEuHxuwNeFf1oQKsn8zv6yvYBGhXa83288eF4MqN1oDq", + &sig,None + ).expect(ERR_VERIFY), + false + ); + + assert_eq!( + keypair + .verify_signature(&message, b"bad sig", None) + .is_err(), + true + ); + + assert_eq!( + keypair + .verify_signature(&message, &sig, Some("invalid type")) + .is_err(), + true + ); +} diff --git a/tests/utils/mod.rs b/tests/utils/mod.rs index 777307c9..f6da69d2 100644 --- a/tests/utils/mod.rs +++ b/tests/utils/mod.rs @@ -1,6 +1,9 @@ use std::{fmt::Debug, future::Future, ops::Deref, pin::Pin, sync::Arc}; -use aries_askar::{Backend, Entry, EntryTag, Error, ErrorKind, Store, TagFilter}; +use aries_askar::{ + kms::{KeyAlg, LocalKey}, + Backend, Entry, EntryTag, Error, ErrorKind, Store, TagFilter, +}; use tokio::task::spawn; @@ -18,10 +21,10 @@ const ERR_REPLACE: &'static str = "Error replacing test row"; const ERR_REMOVE_ALL: &'static str = "Error removing test rows"; const ERR_SCAN: &'static str = "Error starting scan"; const ERR_SCAN_NEXT: &'static str = "Error fetching scan rows"; -// const ERR_CREATE_KEYPAIR: &'static str = "Error creating keypair"; -// const ERR_FETCH_KEY: &'static str = "Error fetching key"; -// const ERR_SIGN: &'static str = "Error signing message"; -// const ERR_VERIFY: &'static str = "Error verifying signature"; +const ERR_CREATE_KEYPAIR: &'static str = "Error creating keypair"; +const ERR_INSERT_KEY: &'static str = "Error inserting key"; +const ERR_FETCH_KEY: &'static str = "Error fetching key"; +const ERR_LOAD_KEY: &'static str = "Error loading key"; pub trait TestStore: Clone + Deref> + Send + Sync { type DB: Backend + Debug + 'static; @@ -480,115 +483,28 @@ pub async fn db_remove_all(db: impl TestStore) { assert_eq!(removed, 2); } -// pub async fn db_keypair_create_fetch(db: impl TestStore) { -// let mut conn = db.session(None).await.expect(ERR_SESSION); - -// let metadata = "meta".to_owned(); -// let key_info = conn -// .create_keypair(KeyAlg::Ed25519, Some(&metadata), None, None) -// .await -// .expect(ERR_CREATE_KEYPAIR); -// assert_eq!(key_info.params.metadata, Some(metadata)); - -// let found = conn -// .fetch_key(key_info.category.clone(), &key_info.ident, false) -// .await -// .expect(ERR_FETCH_KEY); -// assert_eq!(Some(key_info), found); -// } - -// pub async fn db_keypair_sign_verify(db: impl TestStore) { -// let mut conn = db.session(None).await.expect(ERR_SESSION); - -// let key_info = conn -// .create_keypair(KeyAlg::Ed25519, None, None, None) -// .await -// .expect(ERR_CREATE_KEYPAIR); - -// let message = b"message".to_vec(); -// let sig = conn -// .sign_message(&key_info.ident, &message) -// .await -// .expect(ERR_SIGN); - -// assert_eq!( -// verify_signature(&key_info.ident, &message, &sig).expect(ERR_VERIFY), -// true -// ); - -// assert_eq!( -// verify_signature(&key_info.ident, b"bad input", &sig).expect(ERR_VERIFY), -// false -// ); - -// assert_eq!( -// verify_signature( -// &key_info.ident, -// // [0u8; 64] -// b"xt19s1sp2UZCGhy9rNyb1FtxdKiDGZZPNFnc1KiM9jYYEuHxuwNeFf1oQKsn8zv6yvYBGhXa83288eF4MqN1oDq", -// &sig -// ).expect(ERR_VERIFY), -// false -// ); - -// assert_eq!( -// verify_signature(&key_info.ident, &message, b"bad sig").is_err(), -// true -// ); - -// let err = verify_signature("not a key", &message, &sig).expect_err(ERR_REQ_ERR); -// assert_eq!(err.kind(), ErrorKind::Input); -// } - -// pub async fn db_keypair_pack_unpack_anon(db: impl TestStore) { -// let mut conn = db.session(None).await.expect(ERR_SESSION); - -// let recip_key = conn -// .create_keypair(KeyAlg::Ed25519, None, None, None) -// .await -// .expect(ERR_CREATE_KEYPAIR); - -// let msg = b"message".to_vec(); - -// let packed = conn -// .pack_message(vec![recip_key.ident.as_str()], None, &msg) -// .await -// .expect(ERR_PACK); - -// let (unpacked, p_recip, p_send) = conn.unpack_message(&packed).await.expect(ERR_UNPACK); -// assert_eq!(unpacked, msg); -// assert_eq!(p_recip.to_string(), recip_key.ident); -// assert_eq!(p_send, None); -// } - -// pub async fn db_keypair_pack_unpack_auth(db: impl TestStore) { -// let mut conn = db.session(None).await.expect(ERR_SESSION); - -// let sender_key = conn -// .create_keypair(KeyAlg::Ed25519, None, None, None) -// .await -// .expect(ERR_CREATE_KEYPAIR); -// let recip_key = conn -// .create_keypair(KeyAlg::Ed25519, None, None, None) -// .await -// .expect(ERR_CREATE_KEYPAIR); - -// let msg = b"message".to_vec(); - -// let packed = conn -// .pack_message( -// vec![recip_key.ident.as_str()], -// Some(&sender_key.ident), -// &msg, -// ) -// .await -// .expect(ERR_PACK); - -// let (unpacked, p_recip, p_send) = conn.unpack_message(&packed).await.expect(ERR_UNPACK); -// assert_eq!(unpacked, msg); -// assert_eq!(p_recip.to_string(), recip_key.ident); -// assert_eq!(p_send.map(|k| k.to_string()), Some(sender_key.ident)); -// } +pub async fn db_keypair_insert_fetch(db: impl TestStore) { + let keypair = LocalKey::generate(KeyAlg::Ed25519, false).expect(ERR_CREATE_KEYPAIR); + + let mut conn = db.session(None).await.expect(ERR_SESSION); + + let key_name = "testkey"; + let metadata = "meta"; + conn.insert_key(&key_name, &keypair, Some(metadata), None, None) + .await + .expect(ERR_INSERT_KEY); + + let found = conn + .fetch_key(&key_name, false) + .await + .expect(ERR_FETCH_KEY) + .expect(ERR_REQ_ROW); + assert_eq!(found.algorithm(), Some(KeyAlg::Ed25519.as_str())); + assert_eq!(found.name(), key_name); + assert_eq!(found.metadata(), Some(metadata)); + assert_eq!(found.is_local(), true); + found.load_local_key().expect(ERR_LOAD_KEY); +} pub async fn db_txn_rollback(db: impl TestStore) { let test_row = Entry::new("category", "name", "value", Vec::new()); From 61ad3c4c910d9268c2baddaa05c2c0ab4f1b2d08 Mon Sep 17 00:00:00 2001 From: Andrew Whitehead Date: Tue, 5 Apr 2022 14:46:22 -0700 Subject: [PATCH 17/43] reduce count Signed-off-by: Andrew Whitehead --- wrappers/python/tests/test_store.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/wrappers/python/tests/test_store.py b/wrappers/python/tests/test_store.py index c0e034ba..8ac3bde0 100644 --- a/wrappers/python/tests/test_store.py +++ b/wrappers/python/tests/test_store.py @@ -1,6 +1,5 @@ import asyncio import os -import threading from pytest import mark, raises import pytest_asyncio @@ -27,6 +26,7 @@ def raw_key() -> str: @pytest_asyncio.fixture +@mark.asyncio async def store() -> Store: key = raw_key() store = await Store.provision(TEST_STORE_URI, "raw", key, recreate=True) @@ -156,8 +156,8 @@ async def test_transaction_conflict(store: Store): ) await txn.commit() - INC_COUNT = 500 - TASKS = 20 + INC_COUNT = 1000 + TASKS = 10 async def inc(): for _ in range(INC_COUNT): From f68c84023b3108be12fdea7c66a00a8874dc92e0 Mon Sep 17 00:00:00 2001 From: Andrew Whitehead Date: Tue, 5 Apr 2022 14:47:52 -0700 Subject: [PATCH 18/43] workflow updates Signed-off-by: Andrew Whitehead --- .github/workflows/build.yml | 69 +++++++++++++++++++++++++------------ 1 file changed, 47 insertions(+), 22 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 5e3e047e..c0dbc06c 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -18,6 +18,7 @@ jobs: check: name: Run Checks strategy: + fail-fast: false matrix: os: [ubuntu-latest, macos-11, windows-latest] runs-on: ${{ matrix.os }} @@ -31,17 +32,14 @@ jobs: with: profile: minimal toolchain: 1.56 + override: true + components: clippy, rustfmt - name: Cache cargo resources uses: Swatinem/rust-cache@v1 with: sharedKey: check - - - name: Cargo check - uses: actions-rs/cargo@v1 - with: - command: check - args: --workspace + cache-on-failure: true - name: Cargo fmt uses: actions-rs/cargo@v1 @@ -49,6 +47,12 @@ jobs: command: fmt args: --all -- --check + - name: Cargo check + uses: actions-rs/cargo@v1 + with: + command: check + args: --workspace + - name: Debug build uses: actions-rs/cargo@v1 with: @@ -56,20 +60,32 @@ jobs: args: --all-targets - if: "runner.os == 'Linux'" - name: Test with postgres + name: Start postgres (Linux) run: | sudo systemctl start postgresql.service pg_isready sudo -u postgres psql -c "ALTER USER postgres WITH PASSWORD 'postgres'" - echo "-- Test Postgres DB --" - POSTGRES_URL=postgres://postgres:postgres@localhost:5432/test-db cargo test --workspace --features pg_test + + - if: "runner.os == 'Linux'" + name: Test with postgres + uses: actions-rs/cargo@v1 + with: + command: test + args: --workspace --features pg_test -- --nocapture --test-threads 1 + env: + POSTGRES_URL: postgres://postgres:postgres@localhost:5432/test-db + RUST_BACKTRACE: full + # RUST_LOG: debug - if: "runner.os != 'Linux'" name: Test without postgres uses: actions-rs/cargo@v1 with: command: test - args: --workspace + args: --workspace -- --nocapture --test-threads 1 + env: + RUST_BACKTRACE: full + # RUST_LOG: debug - name: Test askar-crypto no_std uses: actions-rs/cargo@v1 @@ -107,15 +123,17 @@ jobs: with: profile: minimal toolchain: 1.56 + override: true - - name: Cache cargo resources - uses: Swatinem/rust-cache@v1 - with: - sharedKey: check + # - name: Cache cargo resources + # uses: Swatinem/rust-cache@v1 + # with: + # sharedKey: check - name: Build library env: BUILD_TARGET: ${{ matrix.target }} + # LIBSQLITE3_FLAGS: SQLITE_DEBUG SQLITE_MEMDEBUG run: sh ./build.sh - name: Upload library artifacts @@ -135,7 +153,7 @@ jobs: - os: macos-11 lib: libaries_askar.dylib target: apple-darwin # creates a universal library - toolchain: beta # beta required for aarch64-apple-darwin target + toolchain: nightly-2021-10-21 # beta required for aarch64-apple-darwin target - os: windows-latest lib: aries_askar.dll toolchain: 1.56 @@ -151,16 +169,18 @@ jobs: with: profile: minimal toolchain: ${{ matrix.toolchain }} + override: true - - name: Cache cargo resources - uses: Swatinem/rust-cache@v1 - with: - sharedKey: check + # - name: Cache cargo resources + # uses: Swatinem/rust-cache@v1 + # with: + # sharedKey: check - name: Build library env: BUILD_TARGET: ${{ matrix.target }} BUILD_TOOLCHAIN: ${{ matrix.toolchain }} + # LIBSQLITE3_FLAGS: SQLITE_DEBUG SQLITE_MEMDEBUG run: sh ./build.sh - name: Upload library artifacts @@ -214,12 +234,14 @@ jobs: python setup.py bdist_wheel --python-tag=py3 --plat-name=${{ matrix.plat-name }} pip install pytest pytest-asyncio dist/* echo "-- Test SQLite in-memory --" - python -m pytest + python -m pytest --log-cli-level=DEBUG echo "-- Test SQLite file DB --" - TEST_STORE_URI=sqlite://test.db python -m pytest + TEST_STORE_URI=sqlite://test.db python -m pytest --log-cli-level=DEBUG working-directory: wrappers/python env: no_proxy: "*" # python issue 30385 + RUST_BACKTRACE: full + # RUST_LOG: debug - if: "runner.os == 'Linux'" name: Test postgres @@ -227,10 +249,13 @@ jobs: sudo systemctl start postgresql.service pg_isready sudo -u postgres psql -c "ALTER USER postgres WITH PASSWORD 'postgres'" - TEST_STORE_URI=postgres://postgres:postgres@localhost:5432/test-db python -m pytest + python -m pytest --log-cli-level=DEBUG working-directory: wrappers/python env: no_proxy: "*" # python issue 30385 + RUST_BACKTRACE: full + # RUST_LOG: debug + TEST_STORE_URI: postgres://postgres:postgres@localhost:5432/test-db - if: "runner.os == 'Linux'" name: Audit wheel From 6b1a8496e4d02814cd0f4eb6cb42ae549dfdd303 Mon Sep 17 00:00:00 2001 From: Andrew Whitehead Date: Tue, 5 Apr 2022 15:21:58 -0700 Subject: [PATCH 19/43] use async-lock instead of option-lock Signed-off-by: Andrew Whitehead --- Cargo.toml | 5 ++--- src/ffi/store.rs | 7 +++---- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index fdc4ba24..e22aea7d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,7 +27,7 @@ rustdoc-args = ["--cfg", "docsrs"] default = ["all_backends", "ffi", "logger"] all_backends = ["any", "postgres", "sqlite"] any = [] -ffi = ["any", "ffi-support", "logger", "option-lock"] +ffi = ["any", "ffi-support", "logger"] jemalloc = ["jemallocator"] logger = ["env_logger", "log"] postgres = ["sqlx", "sqlx/postgres", "sqlx/tls"] @@ -38,7 +38,7 @@ pg_test = ["postgres"] hex-literal = "0.3" [dependencies] -async-lock = "2.4" +async-lock = "2.5" async-stream = "0.3" bs58 = "0.4" chrono = "0.4" @@ -53,7 +53,6 @@ itertools = "0.10" jemallocator = { version = "0.3", optional = true } log = { version = "0.4", optional = true } num_cpus = { version = "1.0", optional = true } -option-lock = { version = "0.3", optional = true } once_cell = "1.5" percent-encoding = "2.0" serde = { version = "1.0", features = ["derive"] } diff --git a/src/ffi/store.rs b/src/ffi/store.rs index c8f8cedb..6d43ac03 100644 --- a/src/ffi/store.rs +++ b/src/ffi/store.rs @@ -1,9 +1,8 @@ use std::{collections::BTreeMap, os::raw::c_char, ptr, str::FromStr, sync::Arc}; -use async_lock::RwLock; +use async_lock::{Mutex as TryMutex, MutexGuardArc as TryMutexGuard, RwLock}; use ffi_support::{rust_string_to_c, ByteBuffer, FfiStr}; use once_cell::sync::Lazy; -use option_lock::{Mutex as TryMutex, MutexGuardArc as TryMutexGuard}; use super::{ error::set_last_error, @@ -87,7 +86,7 @@ where pub async fn remove(&self, handle: K) -> Option> { self.map.write().await.remove(&handle).map(|(_s, v)| { Arc::try_unwrap(v) - .map(|item| item.into_inner().unwrap()) + .map(|item| item.into_inner()) .map_err(|_| err_msg!(Busy, "Resource handle in use")) }) } @@ -100,7 +99,7 @@ where .ok_or_else(|| err_msg!("Invalid resource handle"))? .1 .try_lock_arc() - .map_err(|_| err_msg!(Busy, "Resource handle in use")) + .ok_or_else(|| err_msg!(Busy, "Resource handle in use")) } pub async fn remove_all(&self, store: StoreHandle) -> Result<(), Error> { From cdd3341c112aabdb7f43bb5c60f6163d41ffcb24 Mon Sep 17 00:00:00 2001 From: Andrew Whitehead Date: Tue, 5 Apr 2022 15:58:17 -0700 Subject: [PATCH 20/43] adjust log level Signed-off-by: Andrew Whitehead --- .github/workflows/build.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index c0dbc06c..c27950d7 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -234,9 +234,9 @@ jobs: python setup.py bdist_wheel --python-tag=py3 --plat-name=${{ matrix.plat-name }} pip install pytest pytest-asyncio dist/* echo "-- Test SQLite in-memory --" - python -m pytest --log-cli-level=DEBUG + python -m pytest --log-cli-level=WARNING echo "-- Test SQLite file DB --" - TEST_STORE_URI=sqlite://test.db python -m pytest --log-cli-level=DEBUG + TEST_STORE_URI=sqlite://test.db python -m pytest --log-cli-level=WARNING working-directory: wrappers/python env: no_proxy: "*" # python issue 30385 @@ -249,7 +249,7 @@ jobs: sudo systemctl start postgresql.service pg_isready sudo -u postgres psql -c "ALTER USER postgres WITH PASSWORD 'postgres'" - python -m pytest --log-cli-level=DEBUG + python -m pytest --log-cli-level=WARNING working-directory: wrappers/python env: no_proxy: "*" # python issue 30385 From 2b947a62e84760f371cdab3679acfd6b9ddfca57 Mon Sep 17 00:00:00 2001 From: Andrew Whitehead Date: Tue, 12 Apr 2022 19:14:10 -0700 Subject: [PATCH 21/43] use *const T in ArcHandle Signed-off-by: Andrew Whitehead --- src/ffi/handle.rs | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/src/ffi/handle.rs b/src/ffi/handle.rs index 59ca81f1..c4b9dfef 100644 --- a/src/ffi/handle.rs +++ b/src/ffi/handle.rs @@ -1,39 +1,41 @@ -use std::{fmt::Display, marker::PhantomData, sync::Arc}; +use std::{fmt::Display, marker::PhantomData, ptr, sync::Arc}; use crate::error::Error; -#[repr(transparent)] -pub struct ArcHandle(usize, PhantomData); +#[repr(C)] +pub struct ArcHandle(*const T, PhantomData); impl ArcHandle { pub fn invalid() -> Self { - Self(0, PhantomData) + Self(ptr::null(), PhantomData) } pub fn create(value: T) -> Self { let results = Arc::into_raw(Arc::new(value)); - Self(results as usize, PhantomData) + Self(results, PhantomData) } pub fn load(&self) -> Result, Error> { self.validate()?; - let slf = unsafe { Arc::from_raw(self.0 as *const T) }; - unsafe { Arc::increment_strong_count(self.0 as *const T) }; - Ok(slf) + unsafe { + let result = Arc::from_raw(self.0); + Arc::increment_strong_count(self.0); + Ok(result) + } } pub fn remove(&self) { - if self.0 != 0 { - unsafe { + unsafe { + if !self.0.is_null() { // Drop the initial reference. There could be others outstanding. - Arc::decrement_strong_count(self.0 as *const T); + Arc::decrement_strong_count(self.0); } } } #[inline] pub fn validate(&self) -> Result<(), Error> { - if self.0 == 0 { + if self.0.is_null() { Err(err_msg!("Invalid handle")) } else { Ok(()) @@ -43,7 +45,7 @@ impl ArcHandle { impl std::fmt::Display for ArcHandle { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "Handle({:p})", self.0 as *const T) + write!(f, "Handle({:p})", self.0) } } @@ -61,7 +63,7 @@ macro_rules! new_sequence_handle (($newtype:ident, $counter:ident) => ( static $counter: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(0); #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] - #[repr(transparent)] + #[repr(C)] pub struct $newtype(pub usize); impl $crate::ffi::ResourceHandle for $newtype { From 2326873b35210d8d04dc618d1da4b85731304a12 Mon Sep 17 00:00:00 2001 From: Andrew Whitehead Date: Tue, 12 Apr 2022 19:14:48 -0700 Subject: [PATCH 22/43] use flag instead of mem::forget Signed-off-by: Andrew Whitehead --- src/ffi/key.rs | 2 +- src/ffi/mod.rs | 10 +++++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/ffi/key.rs b/src/ffi/key.rs index 9b7d973f..49ab97ad 100644 --- a/src/ffi/key.rs +++ b/src/ffi/key.rs @@ -365,7 +365,7 @@ pub extern "C" fn askar_key_verify_signature( trace!("Verify signature: {}", handle); check_useful_c_ptr!(out); let key = handle.load()?; - let verify = key.verify_signature(message.as_slice(),signature.as_slice(), sig_type.as_opt_str())?; + let verify = key.verify_signature(message.as_slice(), signature.as_slice(), sig_type.as_opt_str())?; unsafe { *out = verify as i8 }; Ok(ErrorCode::Success) } diff --git a/src/ffi/mod.rs b/src/ffi/mod.rs index 5c67706e..29460c9f 100644 --- a/src/ffi/mod.rs +++ b/src/ffi/mod.rs @@ -32,6 +32,7 @@ ffi_support::define_string_destructor!(askar_string_free); pub struct EnsureCallback)> { f: F, + resolved: bool, _pd: PhantomData, } @@ -39,20 +40,23 @@ impl)> EnsureCallback { pub fn new(f: F) -> Self { Self { f, + resolved: false, _pd: PhantomData, } } - pub fn resolve(self, value: Result) { + pub fn resolve(mut self, value: Result) { + self.resolved = true; (self.f)(value); - std::mem::forget(self); } } impl)> Drop for EnsureCallback { fn drop(&mut self) { // if std::thread::panicking() - capture trace? - (self.f)(Err(err_msg!(Unexpected))); + if !self.resolved { + (self.f)(Err(err_msg!(Unexpected))); + } } } From bba2ad19f137fa3712405976988b8801185363f9 Mon Sep 17 00:00:00 2001 From: Andrew Whitehead Date: Tue, 12 Apr 2022 19:17:18 -0700 Subject: [PATCH 23/43] python wrapper refactoring (ensure method argument types) Signed-off-by: Andrew Whitehead --- wrappers/python/aries_askar/bindings.py | 1247 ++++++++++++++--------- wrappers/python/aries_askar/ecdh.py | 8 +- wrappers/python/aries_askar/key.py | 8 +- wrappers/python/aries_askar/store.py | 66 +- wrappers/python/tests/test_jose_ecdh.py | 2 + 5 files changed, 813 insertions(+), 518 deletions(-) diff --git a/wrappers/python/aries_askar/bindings.py b/wrappers/python/aries_askar/bindings.py index 85a8b4c7..f4147c7c 100644 --- a/wrappers/python/aries_askar/bindings.py +++ b/wrappers/python/aries_askar/bindings.py @@ -6,12 +6,14 @@ import os import sys from ctypes import ( + _SimpleCData, Array, CDLL, CFUNCTYPE, POINTER, Structure, byref, + cast, c_char_p, c_int8, c_int32, @@ -28,6 +30,7 @@ CALLBACKS = {} +INVOKE = {} LIB: CDLL = None LOGGER = logging.getLogger(__name__) LOG_LEVELS = { @@ -39,250 +42,160 @@ MODULE_NAME = __name__.split(".")[0] -class StoreHandle(c_size_t): - """Index of an active Store instance.""" - - def __repr__(self) -> str: - """Format store handle as a string.""" - return f"{self.__class__.__name__}({self.value})" - - async def close(self): - """Close the store, waiting for any active connections.""" - if not getattr(self, "_closed", False): - await do_call_async("askar_store_close", self) - setattr(self, "_closed", True) +class RawBuffer(Structure): + """A byte buffer allocated by the library.""" - def __del__(self): - """Close the store when there are no more references to this object.""" - if not getattr(self, "_closed", False) and self: - do_call("askar_store_close", self, c_void_p()) + _fields_ = [ + ("len", c_int64), + ("data", POINTER(c_ubyte)), + ] + def __bytes__(self) -> bytes: + if not self.len: + return b"" + return bytes(self.array) -class SessionHandle(c_size_t): - """Index of an active Session/Transaction instance.""" + def __len__(self) -> int: + return int(self.len) - def __repr__(self) -> str: - """Format session handle as a string.""" - return f"{self.__class__.__name__}({self.value})" + @property + def array(self) -> Array: + return cast(self.data, POINTER(c_ubyte * self.len)).contents + + +class FfiByteBuffer: + """A byte buffer allocated by Python.""" + + def __init__(self, value): + if isinstance(value, str): + value = value.encode("utf-8") + + if value is None: + dlen = 0 + data = c_char_p() + elif isinstance(value, memoryview): + dlen = value.nbytes + data = c_char_p(value.tobytes()) + elif isinstance(value, bytes): + dlen = len(value) + data = c_char_p(value) + b = c_void_p.from_buffer(data) + del b + else: + raise TypeError(f"Expected str or bytes value, got {type(value)}") + self._dlen = dlen + self._data = data - async def close(self, commit: bool = False): - """Close the session.""" - if not getattr(self, "_closed", False) and self: - await do_call_async( - "askar_session_close", - self, - c_int8(commit), - ) - setattr(self, "_closed", True) + def __bytes__(self) -> bytes: + if not self._data: + return b"" + return self._data.value - def __del__(self): - """Close the session when there are no more references to this object.""" - if not getattr(self, "_closed", False) and self: - do_call("askar_session_close", self, c_int8(0), c_void_p()) + def __len__(self) -> int: + return self._dlen + @property + def _as_parameter_(self) -> RawBuffer: + buf = RawBuffer(len=self._dlen, data=cast(self._data, POINTER(c_ubyte))) + return buf -class ScanHandle(c_size_t): - """Index of an active Store scan instance.""" + @classmethod + def from_param(cls, value): + if isinstance(value, (ByteBuffer, FfiByteBuffer)): + return value + return cls(value) - def __repr__(self) -> str: - """Format scan handle as a string.""" - return f"{self.__class__.__name__}({self.value})" - def __del__(self): - """Close the scan when there are no more references to this object.""" - if self: - get_library().askar_scan_free(self) +class ByteBuffer(Structure): + """A managed byte buffer allocated by the library.""" + _fields_ = [("buffer", RawBuffer)] -class EntryListHandle(c_size_t): - """Pointer to an active EntryList instance.""" + @property + def array(self) -> Array: + return self.buffer.array - def get_category(self, index: int) -> str: - """Get the entry category.""" - cat = StrBuffer() - do_call( - "askar_entry_list_get_category", - self, - c_int32(index), - byref(cat), - ) - return str(cat) + @property + def view(self) -> memoryview: + return memoryview(self.array) - def get_name(self, index: int) -> str: - """Get the entry name.""" - name = StrBuffer() - do_call( - "askar_entry_list_get_name", - self, - c_int32(index), - byref(name), - ) - return str(name) + def __bytes__(self) -> bytes: + return bytes(self.buffer) - def get_value(self, index: int) -> memoryview: - """Get the entry value.""" - val = ByteBuffer() - do_call("askar_entry_list_get_value", self, c_int32(index), byref(val)) - return memoryview(val.raw) + def __len__(self) -> int: + return len(self.buffer) - def get_tags(self, index: int) -> dict: - """Get the entry tags.""" - tags = StrBuffer() - do_call( - "askar_entry_list_get_tags", - self, - c_int32(index), - byref(tags), - ) - if tags: - tags = json.loads(tags.value) - for t in tags: - if isinstance(tags[t], list): - tags[t] = set(tags[t]) - else: - tags = dict() - return tags + def __getitem__(self, idx) -> bytes: + return bytes(self.buffer.array[idx]) def __repr__(self) -> str: - """Format entry list handle as a string.""" - return f"{self.__class__.__name__}({self.value})" + """Format byte buffer as a string.""" + return f"{self.__class__.__name__}({bytes(self)})" def __del__(self): - """Free the entry set when there are no more references.""" - if self: - get_library().askar_entry_list_free(self) - - -class KeyEntryListHandle(c_size_t): - """Pointer to an active KeyEntryList instance.""" - - def get_algorithm(self, index: int) -> str: - """Get the key algorithm.""" - name = StrBuffer() - do_call( - "askar_key_entry_list_get_algorithm", - self, - c_int32(index), - byref(name), - ) - return str(name) - - def get_name(self, index: int) -> str: - """Get the key name.""" - name = StrBuffer() - do_call( - "askar_key_entry_list_get_name", - self, - c_int32(index), - byref(name), - ) - return str(name) - - def get_metadata(self, index: int) -> str: - """Get for the key metadata.""" - metadata = StrBuffer() - do_call( - "askar_key_entry_list_get_metadata", - self, - c_int32(index), - byref(metadata), - ) - return str(metadata) - - def get_tags(self, index: int) -> dict: - """Get the key tags.""" - tags = StrBuffer() - do_call( - "askar_key_entry_list_get_tags", - self, - c_int32(index), - byref(tags), - ) - return json.loads(tags.value) if tags else None - - def load_key(self, index: int) -> "LocalKeyHandle": - """Load the key instance.""" - handle = LocalKeyHandle() - do_call( - "askar_key_entry_list_load_local", - self, - c_int32(index), - byref(handle), - ) - return handle + """Call the byte buffer destructor when this instance is released.""" + invoke_dtor("askar_buffer_free", self.buffer) - def __repr__(self) -> str: - """Format key entry list handle as a string.""" - return f"{self.__class__.__name__}({self.value})" - def __del__(self): - """Free the key entry set when there are no more references.""" - if self: - get_library().askar_key_entry_list_free(self) +class FfiStr: + def __init__(self, value=None): + if value is None: + value = c_char_p() + elif isinstance(value, c_char_p): + pass + else: + if isinstance(value, str): + value = value.encode("utf-8") + if not isinstance(value, bytes): + raise TypeError(f"Expected string value, got {type(value)}") + value = c_char_p(value) + self.value = value + @classmethod + def from_param(cls, value): + if isinstance(value, cls): + return value + return cls(value) -class LocalKeyHandle(c_size_t): - """Pointer to an active LocalKey instance.""" + @property + def _as_parameter_(self): + return self.value def __repr__(self) -> str: - """Format key handle as a string.""" + """Format handle as a string.""" return f"{self.__class__.__name__}({self.value})" - def __del__(self): - """Free the key when there are no more references.""" - if self: - get_library().askar_key_free(self) - - -class FfiByteBuffer(Structure): - """A byte buffer allocated by python.""" - - _fields_ = [ - ("len", c_int64), - ("value", POINTER(c_ubyte)), - ] - - -class RawBuffer(Structure): - """A byte buffer allocated by the library.""" - - _fields_ = [ - ("len", c_int64), - ("data", c_void_p), - ] - - -class ByteBuffer(Structure): - """A managed byte buffer allocated by the library.""" - - _fields_ = [("buffer", RawBuffer)] - - @property - def raw(self) -> Array: - ret = (c_ubyte * self.buffer.len).from_address(self.buffer.data) - setattr(ret, "_ref_", self) # ensure buffer is not dropped - return ret - def __bytes__(self) -> bytes: - return bytes(self.raw) +class FfiJson: + @classmethod + def from_param(cls, value): + if isinstance(value, FfiStr): + return value + if isinstance(value, dict): + value = json.dumps(value) + return FfiStr(value) - def __repr__(self) -> str: - """Format byte buffer as a string.""" - return repr(bytes(self)) - def __del__(self): - """Call the byte buffer destructor when this instance is released.""" - get_library().askar_buffer_free(self.buffer) +class FfiTagsJson: + @classmethod + def from_param(cls, tags): + if isinstance(tags, FfiStr): + return tags + if tags: + tags = json.dumps( + { + name: (list(value) if isinstance(value, set) else value) + for name, value in tags.items() + } + ) + else: + tags = None + return FfiStr(tags) class StrBuffer(c_char_p): """A string allocated by the library.""" - @classmethod - def from_param(cls): - """Returns the type ctypes should use for loading the result.""" - return c_void_p - def is_none(self) -> bool: """Check if the returned string pointer is null.""" return self.value is None @@ -304,7 +217,8 @@ def __str__(self): def __del__(self): """Call the string destructor when this instance is released.""" - get_library().askar_string_free(self) + if self: + invoke_dtor("askar_string_free", self) class AeadParams(Structure): @@ -333,8 +247,7 @@ class Encrypted(Structure): ] def __getitem__(self, idx) -> bytes: - arr = (c_ubyte * self.buffer.len).from_address(self.buffer.data) - return bytes(arr[idx]) + return bytes(self.buffer.array[idx]) def __bytes__(self) -> bytes: """Convert to bytes.""" @@ -381,7 +294,230 @@ def __repr__(self) -> str: def __del__(self): """Call the byte buffer destructor when this instance is released.""" - get_library().askar_buffer_free(self.buffer) + invoke_dtor("askar_buffer_free", self.buffer) + + +class ArcHandle(Structure): + """Base class for handle instances.""" + + _fields_ = [ + ("value", c_size_t), + ] + + def __init__(self, value=0): + if isinstance(value, c_size_t): + value = value.value + if not isinstance(value, int): + raise ValueError("Invalid handle") + super().__init__(value) + + @classmethod + def from_param(cls, param): + if isinstance(param, cls): + return param + return cls(param) + + def __bool__(self): + return bool(self.value) + + def __repr__(self) -> str: + """Format handle as a string.""" + return f"{self.__class__.__name__}({self.value})" + + +class StoreHandle(ArcHandle): + """Handle for an active Store instance.""" + + async def close(self): + """Close the store, waiting for any active connections.""" + if self: + await invoke_async("askar_store_close", (StoreHandle,), self) + self.value = 0 + + def __del__(self): + """Close the store when there are no more references to this object.""" + if self: + invoke_dtor( + "askar_store_close", + self, + None, + 0, + argtypes=(StoreHandle, c_void_p, c_int64), + ) + + +class SessionHandle(ArcHandle): + """Handle for an active Session/Transaction instance.""" + + async def close(self, commit: bool = False): + """Close the session.""" + if self: + await invoke_async( + "askar_session_close", + (SessionHandle, c_int8), + self, + commit, + ) + self.value = 0 + + def __del__(self): + """Close the session when there are no more references to this object.""" + if self: + invoke_dtor( + "askar_session_close", + self, + 0, + None, + 0, + argtypes=(SessionHandle, c_int8, c_void_p, c_int64), + ) + + +class ScanHandle(ArcHandle): + """Handle for an active Store scan instance.""" + + def __del__(self): + """Close the scan when there are no more references to this object.""" + invoke_dtor("askar_scan_free", self) + + +class EntryListHandle(ArcHandle): + """Handle for an active EntryList instance.""" + + def get_category(self, index: int) -> str: + """Get the entry category.""" + cat = StrBuffer() + invoke( + "askar_entry_list_get_category", + (EntryListHandle, c_int32, POINTER(c_char_p)), + self, + index, + byref(cat), + ) + return str(cat) + + def get_name(self, index: int) -> str: + """Get the entry name.""" + name = StrBuffer() + invoke( + "askar_entry_list_get_name", + (EntryListHandle, c_int32, POINTER(c_char_p)), + self, + index, + byref(name), + ) + return str(name) + + def get_value(self, index: int) -> ByteBuffer: + """Get the entry value.""" + val = ByteBuffer() + invoke( + "askar_entry_list_get_value", + (EntryListHandle, c_int32, POINTER(ByteBuffer)), + self, + index, + byref(val), + ) + return val + + def get_tags(self, index: int) -> dict: + """Get the entry tags.""" + tags = StrBuffer() + invoke( + "askar_entry_list_get_tags", + (EntryListHandle, c_int32, POINTER(c_char_p)), + self, + index, + byref(tags), + ) + if tags: + tags = json.loads(tags.value) + for t in tags: + if isinstance(tags[t], list): + tags[t] = set(tags[t]) + else: + tags = dict() + return tags + + def __del__(self): + """Free the entry set when there are no more references.""" + invoke_dtor("askar_entry_list_free", self) + + +class KeyEntryListHandle(ArcHandle): + """Handle for an active KeyEntryList instance.""" + + def get_algorithm(self, index: int) -> str: + """Get the key algorithm.""" + name = StrBuffer() + invoke( + "askar_key_entry_list_get_algorithm", + (KeyEntryListHandle, c_int32, POINTER(c_char_p)), + self, + index, + byref(name), + ) + return str(name) + + def get_name(self, index: int) -> str: + """Get the key name.""" + name = StrBuffer() + invoke( + "askar_key_entry_list_get_name", + (KeyEntryListHandle, c_int32, POINTER(c_char_p)), + self, + index, + byref(name), + ) + return str(name) + + def get_metadata(self, index: int) -> str: + """Get for the key metadata.""" + metadata = StrBuffer() + invoke( + "askar_key_entry_list_get_metadata", + (KeyEntryListHandle, c_int32, POINTER(c_char_p)), + self, + index, + byref(metadata), + ) + return str(metadata) + + def get_tags(self, index: int) -> dict: + """Get the key tags.""" + tags = StrBuffer() + invoke( + "askar_key_entry_list_get_tags", + (KeyEntryListHandle, c_int32, POINTER(c_char_p)), + self, + index, + byref(tags), + ) + return json.loads(tags.value) if tags else None + + def load_key(self, index: int) -> "LocalKeyHandle": + """Load the key instance.""" + handle = LocalKeyHandle() + invoke( + "askar_key_entry_list_load_local", + (KeyEntryListHandle, c_int32, POINTER(LocalKeyHandle)), + self, + index, + byref(handle), + ) + return handle + + def __del__(self): + """Free the key entry set when there are no more references.""" + invoke_dtor("askar_key_entry_list_free", self) + + +class LocalKeyHandle(ArcHandle): + """Handle for an active LocalKey instance.""" + + def __del__(self): + """Free the key when there are no more references.""" + invoke_dtor("askar_key_free", self) def get_library() -> CDLL: @@ -432,31 +568,35 @@ def _init_logger(): # avoid redefining TRACE if another library has added it logging.addLevelName(5, "TRACE") - def _enabled(_context, level: int) -> bool: - return logger.isEnabledFor(LOG_LEVELS.get(level, level)) - - def _log( - _context, - level: int, - target: c_char_p, - message: c_char_p, - module_path: c_char_p, - file_name: c_char_p, - line: int, - ): - logger.getChild("native." + target.decode().replace("::", ".")).log( - LOG_LEVELS.get(level, level), - "\t%s:%d | %s", - file_name.decode() if file_name else None, - line, - message.decode(), + if not hasattr(_init_logger, "log_cb"): + + @CFUNCTYPE( + None, c_void_p, c_int32, c_char_p, c_char_p, c_char_p, c_char_p, c_int32 ) + def _log_cb( + _context, + level: int, + target: c_char_p, + message: c_char_p, + module_path: c_char_p, + file_name: c_char_p, + line: int, + ): + logger.getChild("native." + target.decode().replace("::", ".")).log( + LOG_LEVELS.get(level, level), + "\t%s:%d | %s", + file_name.decode() if file_name else None, + line, + message.decode(), + ) + + _init_logger.log_cb = _log_cb - _init_logger.enabled_cb = CFUNCTYPE(c_int8, c_void_p, c_int32)(_enabled) + @CFUNCTYPE(c_int8, c_void_p, c_int32) + def _enabled_cb(_context, level: int) -> bool: + return logger.isEnabledFor(LOG_LEVELS.get(level, level)) - _init_logger.log_cb = CFUNCTYPE( - None, c_void_p, c_int32, c_char_p, c_char_p, c_char_p, c_char_p, c_int32 - )(_log) + _init_logger.enabled_cb = _enabled_cb if os.getenv("RUST_LOG"): # level from environment @@ -465,20 +605,21 @@ def _log( # inherit current level from logger level = _convert_log_level(logger.level or logger.parent.level) - do_call( + invoke( "askar_set_custom_logger", - c_void_p(), # context + (c_void_p, c_void_p, c_void_p, c_void_p, c_int32), + None, # context _init_logger.log_cb, _init_logger.enabled_cb, - c_void_p(), # flush - c_int32(level), + None, # flush + level, ) def set_max_log_level(level: Union[str, int, None]): get_library() # ensure logger is initialized set_level = _convert_log_level(level) - do_call("askar_set_max_log_level", c_int32(set_level)) + invoke("askar_set_max_log_level", (c_int32,), set_level) def _convert_log_level(level: Union[str, int, None]): @@ -496,6 +637,9 @@ def _convert_log_level(level: Union[str, int, None]): def _fulfill_future(fut: asyncio.Future, result, err: Exception = None): """Resolve a callback future given the result and exception, if any.""" + if not CALLBACKS.pop(fut, None): + LOGGER.info("callback already fulfilled") + return if fut.cancelled(): LOGGER.debug("callback previously cancelled") elif err: @@ -504,107 +648,86 @@ def _fulfill_future(fut: asyncio.Future, result, err: Exception = None): fut.set_result(result) -def _create_callback(cb_type: CFUNCTYPE, fut: asyncio.Future, post_process=None): +def _create_callback( + cb_type: CFUNCTYPE, + loop: asyncio.AbstractEventLoop, + fut: asyncio.Future, +): """Create a callback to handle the response from an async library method.""" - def _cb(id: int, err: int, result=None): + def _cb(_id: int, err: int, result=None): """Callback function passed to the CFUNCTYPE for invocation.""" - if post_process: - result = post_process(result) exc = get_current_error() if err else None - try: - (loop, _cb) = CALLBACKS.pop(fut) - except KeyError: - LOGGER.debug("callback already fulfilled") - return - loop.call_soon_threadsafe(lambda: _fulfill_future(fut, result, exc)) + loop.call_soon_threadsafe(_fulfill_future, fut, result, exc) res = cb_type(_cb) return res -def do_call(fn_name, *args): +def _get_library_method(name: str, argtypes, *, restype=c_int64): + method = INVOKE.get(name) + if not method: + method = getattr(get_library(), name) + method.argtypes = argtypes + method.restype = restype + INVOKE[name] = method + return method + + +def _load_method_arguments(name, argtypes, args): + """Preload argument values to avoid freeing any intermediate data.""" + if not argtypes: + return args + if len(args) != len(argtypes): + raise ValueError(f"{name}: Arguments length does not match argtypes length") + return [ + arg if issubclass(argtype, _SimpleCData) else argtype.from_param(arg) + for (arg, argtype) in zip(args, argtypes) + ] + + +def invoke(name, argtypes, *args): """Perform a synchronous library function call.""" - lib_fn = getattr(get_library(), fn_name) - lib_fn.restype = c_int64 - result = lib_fn(*args) + method = _get_library_method(name, argtypes) + args = _load_method_arguments(name, argtypes, args) + result = method(*args) if result: raise get_current_error(True) -def do_call_async( - fn_name, *args, return_type=None, post_process=None -) -> asyncio.Future: +def invoke_async(name: str, argtypes, *args, return_type=None): """Perform an asynchronous library function call.""" - lib_fn = getattr(get_library(), fn_name) - lib_fn.restype = c_int64 + method = _get_library_method(name, (*argtypes, c_void_p, c_int64)) loop = asyncio.get_event_loop() fut = loop.create_future() - cf_args = [None, c_int64, c_int64] + cf_args = [c_int64, c_int64] if return_type: cf_args.append(return_type) - cb_type = CFUNCTYPE(*cf_args) # could be cached - cb_res = _create_callback(cb_type, fut, post_process) - # keep a reference to the callback function to avoid it being freed - CALLBACKS[fut] = (loop, cb_res) - result = lib_fn(*args, cb_res, c_void_p()) # not making use of callback ID + cb_type = CFUNCTYPE(None, *cf_args) # could be cached + cb_res = _create_callback(cb_type, loop, fut) + args = _load_method_arguments(name, argtypes, args) + # save a reference to the callback function and arguments to avoid GC + CALLBACKS[fut] = (cb_res, args) + result = method(*args, cb_res, 0) # not making use of callback ID if result: - # callback will not be executed - if CALLBACKS.pop(fut): - fut.set_exception(get_current_error()) + # FFI must not execute the callback if an error is returned + err = get_current_error(True) + _fulfill_future(fut, None, err) return fut -def encode_str(arg: Optional[Union[str, bytes]]) -> c_char_p: - """ - Encode an optional input argument as a string. - - Returns: None if the argument is None, otherwise the value encoded utf-8. - """ - if arg is None: - return c_char_p() - if isinstance(arg, str): - arg = arg.encode("utf-8") - return c_char_p(arg) - - -def encode_bytes( - arg: Optional[Union[str, bytes, ByteBuffer, FfiByteBuffer]] -) -> Union[FfiByteBuffer, ByteBuffer]: - if isinstance(arg, ByteBuffer) or isinstance(arg, FfiByteBuffer): - return arg - buf = FfiByteBuffer() - if isinstance(arg, memoryview): - buf.len = arg.nbytes - if arg.contiguous and not arg.readonly: - buf.value = (c_ubyte * buf.len).from_buffer(arg.obj) - else: - buf.value = (c_ubyte * buf.len).from_buffer_copy(arg.obj) - elif isinstance(arg, bytearray): - buf.len = len(arg) - if buf.len > 0: - buf.value = (c_ubyte * buf.len).from_buffer(arg) - elif arg is not None: - if isinstance(arg, str): - arg = arg.encode("utf-8") - buf.len = len(arg) - if buf.len > 0: - buf.value = (c_ubyte * buf.len).from_buffer_copy(arg) - return buf - - -def encode_tags(tags: Optional[dict]) -> c_char_p: - """Encode the tags as a JSON string.""" - if tags: - tags = json.dumps( - { - name: (list(value) if isinstance(value, set) else value) - for name, value in tags.items() - } - ) - else: - tags = None - return encode_str(tags) +def invoke_dtor(name: str, *values, argtypes=None): + method = INVOKE.get(name) + if not method: + lib = get_library() + if not lib: + return + method = getattr(lib, name) + if argtypes: + method.argtypes = argtypes + method.restype = None + INVOKE[name] = method + method(*values) def get_current_error(expect: bool = False) -> Optional[AskarError]: @@ -615,7 +738,7 @@ def get_current_error(expect: bool = False) -> Optional[AskarError]: expect: Return a default error message if none is found """ err_json = StrBuffer() - if not get_library().askar_get_current_error(byref(err_json)): + if not LIB or not LIB.askar_get_current_error(byref(err_json)): try: msg = json.loads(err_json.value) except json.JSONDecodeError: @@ -633,7 +756,12 @@ def get_current_error(expect: bool = False) -> Optional[AskarError]: def generate_raw_key(seed: Union[str, bytes] = None) -> str: """Generate a new raw store wrapping key.""" key = StrBuffer() - do_call("askar_store_generate_raw_key", encode_bytes(seed), byref(key)) + invoke( + "askar_store_generate_raw_key", + (FfiByteBuffer, POINTER(c_char_p)), + seed, + byref(key), + ) return str(key) @@ -648,12 +776,13 @@ async def store_open( uri: str, key_method: str = None, pass_key: str = None, profile: str = None ) -> StoreHandle: """Open an existing Store and return the open handle.""" - return await do_call_async( + return await invoke_async( "askar_store_open", - encode_str(uri), - encode_str(key_method and key_method.lower()), - encode_str(pass_key), - encode_str(profile), + (FfiStr, FfiStr, FfiStr, FfiStr), + uri, + key_method and key_method.lower(), + pass_key, + profile, return_type=StoreHandle, ) @@ -666,13 +795,14 @@ async def store_provision( recreate: bool = False, ) -> StoreHandle: """Provision a new Store and return the open handle.""" - return await do_call_async( + return await invoke_async( "askar_store_provision", - encode_str(uri), - encode_str(key_method and key_method.lower()), - encode_str(pass_key), - encode_str(profile), - c_int8(recreate), + (FfiStr, FfiStr, FfiStr, FfiStr, c_int8), + uri, + key_method and key_method.lower(), + pass_key, + profile, + recreate, return_type=StoreHandle, ) @@ -680,10 +810,11 @@ async def store_provision( async def store_create_profile(handle: StoreHandle, name: str = None) -> str: """Create a new profile in a Store.""" return str( - await do_call_async( + await invoke_async( "askar_store_create_profile", + (StoreHandle, FfiStr), handle, - encode_str(name), + name, return_type=StrBuffer, ) ) @@ -692,8 +823,9 @@ async def store_create_profile(handle: StoreHandle, name: str = None) -> str: async def store_get_profile_name(handle: StoreHandle) -> str: """Get the name of the default Store instance profile.""" return str( - await do_call_async( + await invoke_async( "askar_store_get_profile_name", + (StoreHandle,), handle, return_type=StrBuffer, ) @@ -703,10 +835,11 @@ async def store_get_profile_name(handle: StoreHandle) -> str: async def store_remove_profile(handle: StoreHandle, name: str) -> bool: """Remove an existing profile from a Store.""" return ( - await do_call_async( + await invoke_async( "askar_store_remove_profile", + (StoreHandle, FfiStr), handle, - encode_str(name), + name, return_type=c_int8, ) != 0 @@ -719,20 +852,23 @@ async def store_rekey( pass_key: str = None, ) -> StoreHandle: """Replace the store key on a Store.""" - return await do_call_async( + return await invoke_async( "askar_store_rekey", + (StoreHandle, FfiStr, FfiStr), handle, - encode_str(key_method and key_method.lower()), - encode_str(pass_key), + key_method and key_method.lower(), + pass_key, + return_type=c_int8, ) async def store_remove(uri: str) -> bool: """Remove an existing Store, if any.""" return ( - await do_call_async( + await invoke_async( "askar_store_remove", - encode_str(uri), + (FfiStr,), + uri, return_type=c_int8, ) != 0 @@ -743,11 +879,12 @@ async def session_start( handle: StoreHandle, profile: Optional[str] = None, as_transaction: bool = False ) -> SessionHandle: """Start a new session with an open Store.""" - return await do_call_async( + return await invoke_async( "askar_session_start", + (StoreHandle, FfiStr, c_int8), handle, - encode_str(profile), - c_int8(as_transaction), + profile, + as_transaction, return_type=SessionHandle, ) @@ -756,13 +893,14 @@ async def session_count( handle: SessionHandle, category: str, tag_filter: Union[str, dict] = None ) -> int: """Count rows in the Store.""" - category = encode_str(category) - if isinstance(tag_filter, dict): - tag_filter = json.dumps(tag_filter) - tag_filter = encode_str(tag_filter) return int( - await do_call_async( - "askar_session_count", handle, category, tag_filter, return_type=c_int64 + await invoke_async( + "askar_session_count", + (SessionHandle, FfiStr, FfiJson), + handle, + category, + tag_filter, + return_type=c_int64, ) ) @@ -771,14 +909,13 @@ async def session_fetch( handle: SessionHandle, category: str, name: str, for_update: bool = False ) -> EntryListHandle: """Fetch a row from the Store.""" - category = encode_str(category) - name = encode_str(name) - return await do_call_async( + return await invoke_async( "askar_session_fetch", + (SessionHandle, FfiStr, FfiStr, c_int8), handle, category, name, - c_int8(for_update), + for_update, return_type=EntryListHandle, ) @@ -791,15 +928,14 @@ async def session_fetch_all( for_update: bool = False, ) -> EntryListHandle: """Fetch all matching rows in the Store.""" - if isinstance(tag_filter, dict): - tag_filter = json.dumps(tag_filter) - return await do_call_async( + return await invoke_async( "askar_session_fetch_all", + (SessionHandle, FfiStr, FfiJson, c_int64, c_int8), handle, - encode_str(category), - encode_str(tag_filter), - c_int64(limit if limit is not None else -1), - c_int8(for_update), + category, + tag_filter, + limit if limit is not None else -1, + for_update, return_type=EntryListHandle, ) @@ -810,14 +946,13 @@ async def session_remove_all( tag_filter: Union[str, dict] = None, ) -> int: """Remove all matching rows in the Store.""" - if isinstance(tag_filter, dict): - tag_filter = json.dumps(tag_filter) return int( - await do_call_async( + await invoke_async( "askar_session_remove_all", + (SessionHandle, FfiStr, FfiJson), handle, - encode_str(category), - encode_str(tag_filter), + category, + tag_filter, return_type=c_int64, ) ) @@ -834,15 +969,16 @@ async def session_update( ): """Update a Store by inserting, updating, or removing a record.""" - return await do_call_async( + return await invoke_async( "askar_session_update", + (SessionHandle, c_int8, FfiStr, FfiStr, FfiByteBuffer, FfiTagsJson, c_int64), handle, - c_int8(operation.value), - encode_str(category), - encode_str(name), - encode_bytes(value), - encode_tags(tags), - c_int64(-1 if expiry_ms is None else expiry_ms), + operation.value, + category, + name, + value, + tags, + -1 if expiry_ms is None else expiry_ms, ) @@ -854,30 +990,29 @@ async def session_insert_key( tags: dict = None, expiry_ms: Optional[int] = None, ): - await do_call_async( + return await invoke_async( "askar_session_insert_key", + (SessionHandle, LocalKeyHandle, FfiStr, FfiStr, FfiTagsJson, c_int64), handle, key_handle, - encode_str(name), - encode_str(metadata), - encode_tags(tags), - c_int64(-1 if expiry_ms is None else expiry_ms), - return_type=c_void_p, + name, + metadata, + tags, + -1 if expiry_ms is None else expiry_ms, ) async def session_fetch_key( handle: SessionHandle, name: str, for_update: bool = False -) -> Optional[KeyEntryListHandle]: - ptr = await do_call_async( +) -> KeyEntryListHandle: + return await invoke_async( "askar_session_fetch_key", + (SessionHandle, FfiStr, c_int8), handle, - encode_str(name), - c_int8(for_update), - return_type=c_void_p, + name, + for_update, + return_type=KeyEntryListHandle, ) - if ptr: - return KeyEntryListHandle(ptr) async def session_fetch_all_keys( @@ -887,20 +1022,19 @@ async def session_fetch_all_keys( tag_filter: Union[str, dict] = None, limit: int = None, for_update: bool = False, -) -> EntryListHandle: +) -> KeyEntryListHandle: """Fetch all matching keys in the Store.""" if isinstance(alg, KeyAlg): alg = alg.value - if isinstance(tag_filter, dict): - tag_filter = json.dumps(tag_filter) - return await do_call_async( + return await invoke_async( "askar_session_fetch_all_keys", + (SessionHandle, FfiStr, FfiStr, FfiJson, c_int64, c_int8), handle, - encode_str(alg), - encode_str(thumbprint), - encode_str(tag_filter), - c_int64(limit if limit is not None else -1), - c_int8(for_update), + alg, + thumbprint, + tag_filter, + limit if limit is not None else -1, + for_update, return_type=KeyEntryListHandle, ) @@ -912,21 +1046,23 @@ async def session_update_key( tags: dict = None, expiry_ms: Optional[int] = None, ): - await do_call_async( + await invoke_async( "askar_session_update_key", + (SessionHandle, FfiStr, FfiStr, FfiTagsJson, c_int64), handle, - encode_str(name), - encode_str(metadata), - encode_tags(tags), - c_int64(-1 if expiry_ms is None else expiry_ms), + name, + metadata, + tags, + -1 if expiry_ms is None else expiry_ms, ) async def session_remove_key(handle: SessionHandle, name: str): - await do_call_async( + await invoke_async( "askar_session_remove_key", + (SessionHandle, FfiStr), handle, - encode_str(name), + name, ) @@ -939,35 +1075,44 @@ async def scan_start( limit: int = None, ) -> ScanHandle: """Create a new Scan against the Store.""" - if isinstance(tag_filter, dict): - tag_filter = json.dumps(tag_filter) - tag_filter = encode_str(tag_filter) - return await do_call_async( + return await invoke_async( "askar_scan_start", + (StoreHandle, FfiStr, FfiStr, FfiJson, c_int64, c_int64), handle, - encode_str(profile), - encode_str(category), + profile, + category, tag_filter, - c_int64(offset or 0), - c_int64(limit if limit is not None else -1), + offset or 0, + limit if limit is not None else -1, return_type=ScanHandle, ) -async def scan_next(handle: StoreHandle) -> Optional[EntryListHandle]: - handle = await do_call_async("askar_scan_next", handle, return_type=EntryListHandle) - return handle or None +async def scan_next(handle: ScanHandle) -> EntryListHandle: + return await invoke_async( + "askar_scan_next", (ScanHandle,), handle, return_type=EntryListHandle + ) def entry_list_count(handle: EntryListHandle) -> int: len = c_int32() - do_call("askar_entry_list_count", handle, byref(len)) + invoke( + "askar_entry_list_count", + (EntryListHandle, POINTER(c_int32)), + handle, + byref(len), + ) return len.value -def key_entry_list_count(handle: EntryListHandle) -> int: +def key_entry_list_count(handle: KeyEntryListHandle) -> int: len = c_int32() - do_call("askar_key_entry_list_count", handle, byref(len)) + invoke( + "askar_key_entry_list_count", + (KeyEntryListHandle, POINTER(c_int32)), + handle, + byref(len), + ) return len.value @@ -975,7 +1120,13 @@ def key_generate(alg: Union[str, KeyAlg], ephemeral: bool = False) -> LocalKeyHa handle = LocalKeyHandle() if isinstance(alg, KeyAlg): alg = alg.value - do_call("askar_key_generate", encode_str(alg), c_int8(ephemeral), byref(handle)) + invoke( + "askar_key_generate", + (FfiStr, c_int8, POINTER(LocalKeyHandle)), + alg, + ephemeral, + byref(handle), + ) return handle @@ -989,11 +1140,12 @@ def key_from_seed( alg = alg.value if isinstance(method, SeedMethod): method = method.value - do_call( + invoke( "askar_key_from_seed", - encode_str(alg), - encode_bytes(seed), - encode_str(method), + (FfiStr, FfiByteBuffer, FfiStr, POINTER(LocalKeyHandle)), + alg, + seed, + method, byref(handle), ) return handle @@ -1005,10 +1157,11 @@ def key_from_public_bytes( handle = LocalKeyHandle() if isinstance(alg, KeyAlg): alg = alg.value - do_call( + invoke( "askar_key_from_public_bytes", - encode_str(alg), - encode_bytes(public), + (FfiStr, FfiByteBuffer, POINTER(LocalKeyHandle)), + alg, + public, byref(handle), ) return handle @@ -1016,8 +1169,9 @@ def key_from_public_bytes( def key_get_public_bytes(handle: LocalKeyHandle) -> ByteBuffer: buf = ByteBuffer() - do_call( + invoke( "askar_key_get_public_bytes", + (LocalKeyHandle, POINTER(ByteBuffer)), handle, byref(buf), ) @@ -1030,10 +1184,11 @@ def key_from_secret_bytes( handle = LocalKeyHandle() if isinstance(alg, KeyAlg): alg = alg.value - do_call( + invoke( "askar_key_from_secret_bytes", - encode_str(alg), - encode_bytes(secret), + (FfiStr, FfiByteBuffer, POINTER(LocalKeyHandle)), + alg, + secret, byref(handle), ) return handle @@ -1041,8 +1196,9 @@ def key_from_secret_bytes( def key_get_secret_bytes(handle: LocalKeyHandle) -> ByteBuffer: buf = ByteBuffer() - do_call( + invoke( "askar_key_get_secret_bytes", + (LocalKeyHandle, POINTER(ByteBuffer)), handle, byref(buf), ) @@ -1053,7 +1209,12 @@ def key_from_jwk(jwk: Union[dict, str, bytes]) -> LocalKeyHandle: handle = LocalKeyHandle() if isinstance(jwk, dict): jwk = json.dumps(jwk) - do_call("askar_key_from_jwk", encode_bytes(jwk), byref(handle)) + invoke( + "askar_key_from_jwk", + (FfiByteBuffer, POINTER(LocalKeyHandle)), + jwk, + byref(handle), + ) return handle @@ -1061,7 +1222,13 @@ def key_convert(handle: LocalKeyHandle, alg: Union[str, KeyAlg]) -> LocalKeyHand key = LocalKeyHandle() if isinstance(alg, KeyAlg): alg = alg.value - do_call("askar_key_convert", handle, encode_str(alg), byref(key)) + invoke( + "askar_key_convert", + (LocalKeyHandle, FfiStr, POINTER(LocalKeyHandle)), + handle, + alg, + byref(key), + ) return key @@ -1071,21 +1238,36 @@ def key_exchange( key = LocalKeyHandle() if isinstance(alg, KeyAlg): alg = alg.value - do_call( - "askar_key_from_key_exchange", encode_str(alg), sk_handle, pk_handle, byref(key) + invoke( + "askar_key_from_key_exchange", + (FfiStr, LocalKeyHandle, LocalKeyHandle, POINTER(LocalKeyHandle)), + alg, + sk_handle, + pk_handle, + byref(key), ) return key def key_get_algorithm(handle: LocalKeyHandle) -> str: alg = StrBuffer() - do_call("askar_key_get_algorithm", handle, byref(alg)) + invoke( + "askar_key_get_algorithm", + (LocalKeyHandle, POINTER(c_char_p)), + handle, + byref(alg), + ) return str(alg) def key_get_ephemeral(handle: LocalKeyHandle) -> bool: eph = c_int8() - do_call("askar_key_get_ephemeral", handle, byref(eph)) + invoke( + "askar_key_get_ephemeral", + (LocalKeyHandle, POINTER(c_int8)), + handle, + byref(eph), + ) return eph.value != 0 @@ -1093,13 +1275,24 @@ def key_get_jwk_public(handle: LocalKeyHandle, alg: Union[str, KeyAlg] = None) - jwk = StrBuffer() if isinstance(alg, KeyAlg): alg = alg.value - do_call("askar_key_get_jwk_public", handle, encode_str(alg), byref(jwk)) + invoke( + "askar_key_get_jwk_public", + (LocalKeyHandle, FfiStr, POINTER(c_char_p)), + handle, + alg, + byref(jwk), + ) return str(jwk) def key_get_jwk_secret(handle: LocalKeyHandle) -> ByteBuffer: sec = ByteBuffer() - do_call("askar_key_get_jwk_secret", handle, byref(sec)) + invoke( + "askar_key_get_jwk_secret", + (LocalKeyHandle, POINTER(ByteBuffer)), + handle, + byref(sec), + ) return sec @@ -1109,19 +1302,35 @@ def key_get_jwk_thumbprint( thumb = StrBuffer() if isinstance(alg, KeyAlg): alg = alg.value - do_call("askar_key_get_jwk_thumbprint", handle, encode_str(alg), byref(thumb)) + invoke( + "askar_key_get_jwk_thumbprint", + (LocalKeyHandle, FfiStr, POINTER(c_char_p)), + handle, + alg, + byref(thumb), + ) return str(thumb) def key_aead_get_params(handle: LocalKeyHandle) -> AeadParams: params = AeadParams() - do_call("askar_key_aead_get_params", handle, byref(params)) + invoke( + "askar_key_aead_get_params", + (LocalKeyHandle, POINTER(AeadParams)), + handle, + byref(params), + ) return params def key_aead_random_nonce(handle: LocalKeyHandle) -> ByteBuffer: nonce = ByteBuffer() - do_call("askar_key_aead_random_nonce", handle, byref(nonce)) + invoke( + "askar_key_aead_random_nonce", + (LocalKeyHandle, POINTER(ByteBuffer)), + handle, + byref(nonce), + ) return nonce @@ -1132,12 +1341,19 @@ def key_aead_encrypt( aad: Optional[Union[bytes, ByteBuffer]], ) -> Encrypted: enc = Encrypted() - do_call( + invoke( "askar_key_aead_encrypt", + ( + LocalKeyHandle, + FfiByteBuffer, + FfiByteBuffer, + FfiByteBuffer, + POINTER(Encrypted), + ), handle, - encode_bytes(input), - encode_bytes(nonce), - encode_bytes(aad), + input, + nonce, + aad, byref(enc), ) return enc @@ -1152,14 +1368,23 @@ def key_aead_decrypt( ) -> ByteBuffer: dec = ByteBuffer() if isinstance(ciphertext, Encrypted): + nonce = ciphertext.nonce ciphertext = ciphertext.ciphertext_tag - do_call( + invoke( "askar_key_aead_decrypt", + ( + LocalKeyHandle, + FfiByteBuffer, + FfiByteBuffer, + FfiByteBuffer, + FfiByteBuffer, + POINTER(ByteBuffer), + ), handle, - encode_bytes(ciphertext), - encode_bytes(nonce), - encode_bytes(tag), - encode_bytes(aad), + ciphertext, + nonce, + tag, + aad, byref(dec), ) return dec @@ -1171,11 +1396,12 @@ def key_sign_message( sig_type: Optional[str], ) -> ByteBuffer: sig = ByteBuffer() - do_call( + invoke( "askar_key_sign_message", + (LocalKeyHandle, FfiByteBuffer, FfiStr, POINTER(ByteBuffer)), handle, - encode_bytes(message), - encode_str(sig_type), + message, + sig_type, byref(sig), ) return sig @@ -1188,12 +1414,13 @@ def key_verify_signature( sig_type: Optional[str], ) -> bool: verify = c_int8() - do_call( + invoke( "askar_key_verify_signature", + (LocalKeyHandle, FfiByteBuffer, FfiByteBuffer, FfiStr, POINTER(c_int8)), handle, - encode_bytes(message), - encode_bytes(signature), - encode_str(sig_type), + message, + signature, + sig_type, byref(verify), ) return verify.value != 0 @@ -1205,11 +1432,12 @@ def key_wrap_key( nonce: Optional[Union[bytes, ByteBuffer]], ) -> Encrypted: wrapped = Encrypted() - do_call( + invoke( "askar_key_wrap_key", + (LocalKeyHandle, LocalKeyHandle, FfiByteBuffer, POINTER(Encrypted)), handle, other, - encode_bytes(nonce), + nonce, byref(wrapped), ) return wrapped @@ -1227,13 +1455,21 @@ def key_unwrap_key( alg = alg.value if isinstance(ciphertext, Encrypted): ciphertext = ciphertext.ciphertext_tag - do_call( + invoke( "askar_key_unwrap_key", + ( + LocalKeyHandle, + FfiStr, + FfiByteBuffer, + FfiByteBuffer, + FfiByteBuffer, + POINTER(LocalKeyHandle), + ), handle, - encode_str(alg), - encode_bytes(ciphertext), - encode_bytes(nonce), - encode_bytes(tag), + alg, + ciphertext, + nonce, + tag, byref(result), ) return result @@ -1241,8 +1477,9 @@ def key_unwrap_key( def key_crypto_box_random_nonce() -> ByteBuffer: buf = ByteBuffer() - do_call( + invoke( "askar_key_crypto_box_random_nonce", + (POINTER(ByteBuffer),), byref(buf), ) return buf @@ -1255,12 +1492,19 @@ def key_crypto_box( nonce: Union[bytes, ByteBuffer], ) -> ByteBuffer: buf = ByteBuffer() - do_call( + invoke( "askar_key_crypto_box", + ( + LocalKeyHandle, + LocalKeyHandle, + FfiByteBuffer, + FfiByteBuffer, + POINTER(ByteBuffer), + ), recip_handle, sender_handle, - encode_bytes(message), - encode_bytes(nonce), + message, + nonce, byref(buf), ) return buf @@ -1273,12 +1517,19 @@ def key_crypto_box_open( nonce: Union[bytes, ByteBuffer], ) -> ByteBuffer: buf = ByteBuffer() - do_call( + invoke( "askar_key_crypto_box_open", + ( + LocalKeyHandle, + LocalKeyHandle, + FfiByteBuffer, + FfiByteBuffer, + POINTER(ByteBuffer), + ), recip_handle, sender_handle, - encode_bytes(message), - encode_bytes(nonce), + message, + nonce, byref(buf), ) return buf @@ -1289,10 +1540,11 @@ def key_crypto_box_seal( message: Union[bytes, str, ByteBuffer], ) -> ByteBuffer: buf = ByteBuffer() - do_call( + invoke( "askar_key_crypto_box_seal", + (LocalKeyHandle, FfiByteBuffer, POINTER(ByteBuffer)), handle, - encode_bytes(message), + message, byref(buf), ) return buf @@ -1303,10 +1555,11 @@ def key_crypto_box_seal_open( ciphertext: Union[bytes, ByteBuffer], ) -> ByteBuffer: buf = ByteBuffer() - do_call( + invoke( "askar_key_crypto_box_seal_open", + (LocalKeyHandle, FfiByteBuffer, POINTER(ByteBuffer)), handle, - encode_bytes(ciphertext), + ciphertext, byref(buf), ) return buf @@ -1324,15 +1577,25 @@ def key_derive_ecdh_es( key = LocalKeyHandle() if isinstance(key_alg, KeyAlg): key_alg = key_alg.value - do_call( + invoke( "askar_key_derive_ecdh_es", - encode_str(key_alg), + ( + FfiStr, + LocalKeyHandle, + LocalKeyHandle, + FfiByteBuffer, + FfiByteBuffer, + FfiByteBuffer, + c_int8, + POINTER(LocalKeyHandle), + ), + key_alg, ephem_key, receiver_key, - encode_bytes(alg_id), - encode_bytes(apu), - encode_bytes(apv), - c_int8(receive), + alg_id, + apu, + apv, + receive, byref(key), ) return key @@ -1352,17 +1615,29 @@ def key_derive_ecdh_1pu( key = LocalKeyHandle() if isinstance(key_alg, KeyAlg): key_alg = key_alg.value - do_call( + invoke( "askar_key_derive_ecdh_1pu", - encode_str(key_alg), + ( + FfiStr, + LocalKeyHandle, + LocalKeyHandle, + LocalKeyHandle, + FfiByteBuffer, + FfiByteBuffer, + FfiByteBuffer, + FfiByteBuffer, + c_int8, + POINTER(LocalKeyHandle), + ), + key_alg, ephem_key, sender_key, receiver_key, - encode_bytes(alg_id), - encode_bytes(apu), - encode_bytes(apv), - encode_bytes(cc_tag), - c_int8(receive), + alg_id, + apu, + apv, + cc_tag, + receive, byref(key), ) return key diff --git a/wrappers/python/aries_askar/ecdh.py b/wrappers/python/aries_askar/ecdh.py index d95dffed..bac928d7 100644 --- a/wrappers/python/aries_askar/ecdh.py +++ b/wrappers/python/aries_askar/ecdh.py @@ -5,10 +5,10 @@ from .types import KeyAlg -def _load_key(key: Union[dict, str, Key]) -> Key: - if isinstance(key, (str, dict)): - key = Key.from_jwk(key) - return key +def _load_key(key: Union[dict, str, bytes, Key]) -> Key: + if isinstance(key, Key): + return key + return Key.from_jwk(key) class EcdhEs: diff --git a/wrappers/python/aries_askar/key.py b/wrappers/python/aries_askar/key.py index bdd2e657..e1602f79 100644 --- a/wrappers/python/aries_askar/key.py +++ b/wrappers/python/aries_askar/key.py @@ -4,14 +4,14 @@ from . import bindings -from .bindings import Encrypted +from .bindings import Encrypted, LocalKeyHandle from .types import KeyAlg, SeedMethod class Key: """An active key or keypair instance.""" - def __init__(self, handle: bindings.LocalKeyHandle): + def __init__(self, handle: LocalKeyHandle): """Initialize the Key instance.""" self._handle = handle @@ -42,7 +42,7 @@ def from_jwk(cls, jwk: Union[dict, str, bytes]) -> "Key": return cls(bindings.key_from_jwk(jwk)) @property - def handle(self) -> bindings.LocalKeyHandle: + def handle(self) -> LocalKeyHandle: """Accessor for the key handle.""" return self._handle @@ -52,7 +52,7 @@ def algorithm(self) -> KeyAlg: return KeyAlg.from_key_alg(alg) @property - def ephemeral(self) -> "Key": + def ephemeral(self) -> bool: return bindings.key_get_ephemeral(self._handle) def convert_key(self, alg: Union[str, KeyAlg]) -> "Key": diff --git a/wrappers/python/aries_askar/store.py b/wrappers/python/aries_askar/store.py index d5f86832..5c969fc7 100644 --- a/wrappers/python/aries_askar/store.py +++ b/wrappers/python/aries_askar/store.py @@ -9,6 +9,7 @@ from . import bindings from .bindings import ( + ByteBuffer, EntryListHandle, KeyEntryListHandle, ScanHandle, @@ -46,7 +47,7 @@ def value(self) -> bytes: return bytes(self.raw_value) @cached_property - def raw_value(self) -> memoryview: + def raw_value(self) -> ByteBuffer: """Accessor for the entry raw value.""" return self._list.get_value(self._pos) @@ -101,7 +102,20 @@ def __getitem__(self, index) -> Entry: return Entry(self._handle, index) def __iter__(self): - return self + return IterEntryList(self) + + def __len__(self) -> int: + return self._len + + def __repr__(self) -> str: + return f"" + + +class IterEntryList: + def __init__(self, list: EntryList): + self._handle = list._handle + self._len = list._len + self._pos = 0 def __next__(self): if self._pos < self._len: @@ -111,12 +125,6 @@ def __next__(self): else: raise StopIteration - def __len__(self) -> int: - return self._len - - def __repr__(self) -> str: - return f"" - class KeyEntry: """Pointer to one result of a KeyEntryList instance.""" @@ -184,15 +192,7 @@ def __getitem__(self, index) -> KeyEntry: return KeyEntry(self._handle, index) def __iter__(self): - return self - - def __next__(self): - if self._pos < self._len: - entry = KeyEntry(self._handle, self._pos) - self._pos += 1 - return entry - else: - raise StopIteration + return IterKeyEntryList(self) def __len__(self) -> int: return self._len @@ -203,6 +203,21 @@ def __repr__(self) -> str: ) +class IterKeyEntryList: + def __init__(self, list: KeyEntryList): + self._handle = list._handle + self._len = list._len + self._pos = 0 + + def __next__(self): + if self._pos < self._len: + entry = KeyEntry(self._handle, self._pos) + self._pos += 1 + return entry + else: + raise StopIteration + + class Scan: """A scan of the Store.""" @@ -216,9 +231,9 @@ def __init__( limit: int = None, ): """Initialize the Scan instance.""" - self.params = (store, profile, category, tag_filter, offset, limit) + self._params = (store, profile, category, tag_filter, offset, limit) self._handle: ScanHandle = None - self._buffer: EntryList = None + self._buffer: IterEntryList = None @property def handle(self) -> ScanHandle: @@ -230,7 +245,8 @@ def __aiter__(self): async def __anext__(self): if self._handle is None: - (store, profile, category, tag_filter, offset, limit) = self.params + (store, profile, category, tag_filter, offset, limit) = self._params + self._params = None if not store.handle: raise AskarError( AskarErrorCode.WRAPPER, "Cannot scan from closed store" @@ -239,7 +255,7 @@ async def __anext__(self): store.handle, profile, category, tag_filter, offset, limit ) list_handle = await bindings.scan_next(self._handle) - self._buffer = EntryList(list_handle) if list_handle else None + self._buffer = iter(EntryList(list_handle)) if list_handle else None while True: if not self._buffer: raise StopAsyncIteration @@ -247,7 +263,7 @@ async def __anext__(self): if row: return row list_handle = await bindings.scan_next(self._handle) - self._buffer = EntryList(list_handle) if list_handle else None + self._buffer = iter(EntryList(list_handle)) if list_handle else None async def fetch_all(self) -> Sequence[Entry]: rows = [] @@ -407,7 +423,7 @@ async def fetch( result_handle = await bindings.session_fetch( self._handle, category, name, for_update ) - return next(EntryList(result_handle, 1), None) if result_handle else None + return next(iter(EntryList(result_handle, 1)), None) if result_handle else None async def fetch_all( self, @@ -508,7 +524,9 @@ async def fetch_key( AskarErrorCode.WRAPPER, "Cannot fetch key from closed session" ) result_handle = await bindings.session_fetch_key(self._handle, name, for_update) - return next(KeyEntryList(result_handle, 1)) if result_handle else None + return ( + next(iter(KeyEntryList(result_handle, 1)), None) if result_handle else None + ) async def fetch_all_keys( self, diff --git a/wrappers/python/tests/test_jose_ecdh.py b/wrappers/python/tests/test_jose_ecdh.py index 43a155a8..5cc2fdb2 100644 --- a/wrappers/python/tests/test_jose_ecdh.py +++ b/wrappers/python/tests/test_jose_ecdh.py @@ -113,6 +113,8 @@ def test_ecdh_1pu_direct(): KeyAlg.A256GCM, ephem_key, alice_key, bob_jwk, message, aad=protected_b64 ) ciphertext, tag, nonce = encrypted_msg.parts + print("enc", *encrypted_msg.parts) + print("enc", *encrypted_msg.parts) # switch to receiver From b20cf02f56014e9cab771ec46df41d621ff5594d Mon Sep 17 00:00:00 2001 From: Andrew Whitehead Date: Tue, 12 Apr 2022 19:18:57 -0700 Subject: [PATCH 24/43] require UnwindSafe instead of guaranteeing it Signed-off-by: Andrew Whitehead --- askar-crypto/src/alg/any.rs | 23 +++++++++-------------- 1 file changed, 9 insertions(+), 14 deletions(-) diff --git a/askar-crypto/src/alg/any.rs b/askar-crypto/src/alg/any.rs index 1f5dfbf9..e54c85b8 100644 --- a/askar-crypto/src/alg/any.rs +++ b/askar-crypto/src/alg/any.rs @@ -4,6 +4,7 @@ use core::convert::TryFrom; use core::{ any::{Any, TypeId}, fmt::Debug, + panic::{RefUnwindSafe, UnwindSafe}, }; #[cfg(feature = "aes")] @@ -54,10 +55,10 @@ use super::EcCurves; use crate::kdf::{FromKeyDerivation, FromKeyExchange}; #[derive(Debug)] -pub struct KeyT(T); +pub struct KeyT(T); /// The type-erased representation for a concrete key instance -pub type AnyKey = KeyT; +pub type AnyKey = KeyT; impl AnyKey { pub fn algorithm(&self) -> KeyAlg { @@ -79,12 +80,6 @@ impl AnyKey { } } -// key instances are immutable -#[cfg(feature = "std")] -impl std::panic::UnwindSafe for AnyKey {} -#[cfg(feature = "std")] -impl std::panic::RefUnwindSafe for AnyKey {} - /// Create `AnyKey` instances from various sources pub trait AnyKeyCreate: Sized { /// Generate a new key from a key material generator for the given key algorithm. @@ -108,7 +103,7 @@ pub trait AnyKeyCreate: Sized { fn from_secret_bytes(alg: KeyAlg, secret: &[u8]) -> Result; /// Convert from a concrete key instance - fn from_key(key: K) -> Self; + fn from_key(key: K) -> Self; /// Create a new key instance from a key exchange fn from_key_exchange(alg: KeyAlg, secret: &Sk, public: &Pk) -> Result @@ -137,7 +132,7 @@ impl AnyKeyCreate for Box { } #[inline(always)] - fn from_key(key: K) -> Self { + fn from_key(key: K) -> Self { Box::new(KeyT(key)) } @@ -172,7 +167,7 @@ impl AnyKeyCreate for Arc { } #[inline(always)] - fn from_key(key: K) -> Self { + fn from_key(key: K) -> Self { Arc::new(KeyT(key)) } @@ -819,19 +814,19 @@ impl KeySigVerify for AnyKey { // may want to implement in-place initialization to avoid copies trait AllocKey { - fn alloc_key(key: K) -> Self; + fn alloc_key(key: K) -> Self; } impl AllocKey for Arc { #[inline(always)] - fn alloc_key(key: K) -> Self { + fn alloc_key(key: K) -> Self { Self::from_key(key) } } impl AllocKey for Box { #[inline(always)] - fn alloc_key(key: K) -> Self { + fn alloc_key(key: K) -> Self { Self::from_key(key) } } From b2bac7f82a23c3dc76710856fea678dd725b1fa2 Mon Sep 17 00:00:00 2001 From: Andrew Whitehead Date: Tue, 12 Apr 2022 19:19:41 -0700 Subject: [PATCH 25/43] SecretBytes api updates Signed-off-by: Andrew Whitehead --- askar-crypto/src/buffer/secret.rs | 29 +++++++++++++++++++---------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/askar-crypto/src/buffer/secret.rs b/askar-crypto/src/buffer/secret.rs index 6ee40631..81982882 100644 --- a/askar-crypto/src/buffer/secret.rs +++ b/askar-crypto/src/buffer/secret.rs @@ -47,6 +47,12 @@ impl SecretBytes { Self(v) } + /// Accessor for the current capacity of the buffer + #[inline] + pub fn capacity(&self) -> usize { + self.0.capacity() + } + /// Accessor for the length of the buffer contents #[inline] pub fn len(&self) -> usize { @@ -66,7 +72,7 @@ impl SecretBytes { } else if cap > 0 && min_cap >= cap { // allocate a new buffer and copy the secure data over let new_cap = min_cap.max(cap * 2).max(32); - let mut buf = SecretBytes::with_capacity(new_cap); + let mut buf = Self::with_capacity(new_cap); buf.0.extend_from_slice(&self.0[..]); mem::swap(&mut buf, self); // old buf zeroized on drop @@ -93,28 +99,31 @@ impl SecretBytes { self.ensure_capacity(self.len() + extra) } - /// Convert this buffer into a boxed slice - pub fn into_boxed_slice(mut self) -> Box<[u8]> { + /// Shrink the buffer capacity to match the length + pub fn shrink_to_fit(&mut self) { let len = self.0.len(); if self.0.capacity() > len { // copy to a smaller buffer (capacity is not tracked for boxed slice) // and proceed with the normal zeroize on drop - let mut v = Vec::with_capacity(len); - v.append(&mut self.0); - v.into_boxed_slice() - } else { - // no realloc and copy needed - self.into_vec().into_boxed_slice() + let mut buf = Self::with_capacity(len); + buf.0.extend_from_slice(&self.0[..]); + mem::swap(&mut buf, self); + // old buf zeroized on drop } } + /// Convert this buffer into a boxed slice + pub fn into_boxed_slice(mut self) -> Box<[u8]> { + self.shrink_to_fit(); + self.into_vec().into_boxed_slice() + } + /// Unwrap this buffer into a Vec #[inline] pub fn into_vec(mut self) -> Vec { // FIXME zeroize extra capacity in case it was used previously? let mut v = Vec::new(); // note: no heap allocation for empty vec mem::swap(&mut v, &mut self.0); - mem::forget(self); v } From 94440450c86cf4a8990f5b9b93e36534f8f5b4cd Mon Sep 17 00:00:00 2001 From: Andrew Whitehead Date: Tue, 12 Apr 2022 19:20:32 -0700 Subject: [PATCH 26/43] reduce unsafe Signed-off-by: Andrew Whitehead --- src/backend/db_utils.rs | 34 +++++++++++++--------------------- src/error.rs | 7 +++++-- 2 files changed, 18 insertions(+), 23 deletions(-) diff --git a/src/backend/db_utils.rs b/src/backend/db_utils.rs index 97bb94a9..0f043188 100644 --- a/src/backend/db_utils.rs +++ b/src/backend/db_utils.rs @@ -551,32 +551,24 @@ pub fn encode_tag_filter( } } -// convert a slice of tags into a Vec, when ensuring there is +// allocate a String while ensuring there is sufficient capacity to reuse during encryption +fn _prepare_string(value: &str) -> String { + let buf = ProfileKey::prepare_input(value.as_bytes()).into_vec(); + unsafe { String::from_utf8_unchecked(buf) } +} + +// convert a slice of tags into a Vec, while ensuring there is // adequate space in the allocations to reuse them during encryption pub fn prepare_tags(tags: &[EntryTag]) -> Result, Error> { let mut result = Vec::with_capacity(tags.len()); for tag in tags { result.push(match tag { - EntryTag::Plaintext(name, value) => EntryTag::Plaintext( - unsafe { - String::from_utf8_unchecked( - ProfileKey::prepare_input(name.as_bytes()).into_vec(), - ) - }, - value.clone(), - ), - EntryTag::Encrypted(name, value) => EntryTag::Encrypted( - unsafe { - String::from_utf8_unchecked( - ProfileKey::prepare_input(name.as_bytes()).into_vec(), - ) - }, - unsafe { - String::from_utf8_unchecked( - ProfileKey::prepare_input(value.as_bytes()).into_vec(), - ) - }, - ), + EntryTag::Plaintext(name, value) => { + EntryTag::Plaintext(_prepare_string(name), value.clone()) + } + EntryTag::Encrypted(name, value) => { + EntryTag::Encrypted(_prepare_string(name), _prepare_string(value)) + } }); } Ok(result) diff --git a/src/error.rs b/src/error.rs index 1a1d79de..b20aa9f7 100644 --- a/src/error.rs +++ b/src/error.rs @@ -84,7 +84,10 @@ impl Error { self.message.as_ref().map(String::as_str) } - pub(crate) fn with_cause>>(mut self, err: T) -> Self { + pub(crate) fn with_cause>>( + mut self, + err: T, + ) -> Self { self.cause = Some(err.into()); self } @@ -108,7 +111,7 @@ impl StdError for Error { fn source(&self) -> Option<&(dyn StdError + 'static)> { self.cause .as_ref() - .map(|err| unsafe { std::mem::transmute(&**err) }) + .map(|err| &**err as &(dyn StdError + 'static)) } } From 50f57a6b856e67863963979fdc63c0ff4d700e17 Mon Sep 17 00:00:00 2001 From: Andrew Whitehead Date: Tue, 12 Apr 2022 19:25:31 -0700 Subject: [PATCH 27/43] update ArcHandle bounds Signed-off-by: Andrew Whitehead --- src/ffi/handle.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/ffi/handle.rs b/src/ffi/handle.rs index c4b9dfef..d0d8cf6b 100644 --- a/src/ffi/handle.rs +++ b/src/ffi/handle.rs @@ -1,18 +1,18 @@ -use std::{fmt::Display, marker::PhantomData, ptr, sync::Arc}; +use std::{fmt::Display, ptr, sync::Arc}; use crate::error::Error; #[repr(C)] -pub struct ArcHandle(*const T, PhantomData); +pub struct ArcHandle(*const T); -impl ArcHandle { +impl ArcHandle { pub fn invalid() -> Self { - Self(ptr::null(), PhantomData) + Self(ptr::null()) } pub fn create(value: T) -> Self { let results = Arc::into_raw(Arc::new(value)); - Self(results, PhantomData) + Self(results) } pub fn load(&self) -> Result, Error> { @@ -43,7 +43,7 @@ impl ArcHandle { } } -impl std::fmt::Display for ArcHandle { +impl std::fmt::Display for ArcHandle { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "Handle({:p})", self.0) } From 95d9a0be101ffa96f55645971d56d056f98a1a20 Mon Sep 17 00:00:00 2001 From: Andrew Whitehead Date: Tue, 12 Apr 2022 19:25:53 -0700 Subject: [PATCH 28/43] adjust SecretBytes construction Signed-off-by: Andrew Whitehead --- src/ffi/secret.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/ffi/secret.rs b/src/ffi/secret.rs index d6c15267..e24b0fb7 100644 --- a/src/ffi/secret.rs +++ b/src/ffi/secret.rs @@ -30,10 +30,12 @@ impl Default for SecretBuffer { impl SecretBuffer { pub fn from_secret(buffer: impl Into) -> Self { - let mut buf = buffer.into().into_boxed_slice(); + let mut buf = buffer.into(); + buf.shrink_to_fit(); + debug_assert_eq!(buf.len(), buf.capacity()); + let mut buf = mem::ManuallyDrop::new(buf.into_vec()); let len = i64::try_from(buf.len()).expect("secret length exceeds i64::MAX"); let data = buf.as_mut_ptr(); - mem::forget(buf); Self { len, data } } From cfd722a87c045960686fb172d71323fddd15815a Mon Sep 17 00:00:00 2001 From: Andrew Whitehead Date: Tue, 12 Apr 2022 19:28:31 -0700 Subject: [PATCH 29/43] rename for consistency Signed-off-by: Andrew Whitehead --- wrappers/python/tests/test_store.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/wrappers/python/tests/test_store.py b/wrappers/python/tests/test_store.py index 8ac3bde0..684b9d45 100644 --- a/wrappers/python/tests/test_store.py +++ b/wrappers/python/tests/test_store.py @@ -112,7 +112,7 @@ async def test_scan(store: Store): @mark.asyncio -async def test_transaction_basic(store: Store): +async def test_txn_basic(store: Store): async with store.transaction() as txn: # Insert a new entry @@ -147,7 +147,7 @@ async def test_transaction_basic(store: Store): @mark.asyncio -async def test_transaction_conflict(store: Store): +async def test_txn_contention(store: Store): async with store.transaction() as txn: await txn.insert( TEST_ENTRY["category"], From 7485ee23b859c3ce9d09d97aae6c109043bced27 Mon Sep 17 00:00:00 2001 From: Andrew Whitehead Date: Tue, 12 Apr 2022 19:28:49 -0700 Subject: [PATCH 30/43] minor cleanup Signed-off-by: Andrew Whitehead --- wrappers/python/aries_askar/bindings.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/wrappers/python/aries_askar/bindings.py b/wrappers/python/aries_askar/bindings.py index f4147c7c..cec2e0bb 100644 --- a/wrappers/python/aries_askar/bindings.py +++ b/wrappers/python/aries_askar/bindings.py @@ -79,8 +79,6 @@ def __init__(self, value): elif isinstance(value, bytes): dlen = len(value) data = c_char_p(value) - b = c_void_p.from_buffer(data) - del b else: raise TypeError(f"Expected str or bytes value, got {type(value)}") self._dlen = dlen @@ -655,8 +653,9 @@ def _create_callback( ): """Create a callback to handle the response from an async library method.""" - def _cb(_id: int, err: int, result=None): + def _cb(cb_id: int, err: int, result=None): """Callback function passed to the CFUNCTYPE for invocation.""" + assert cb_id == 0 exc = get_current_error() if err else None loop.call_soon_threadsafe(_fulfill_future, fut, result, exc) From 7e6a563aa75ab34cf3c6234d027112f4f74ec186 Mon Sep 17 00:00:00 2001 From: Andrew Whitehead Date: Tue, 12 Apr 2022 19:29:10 -0700 Subject: [PATCH 31/43] skip contention tests on github Signed-off-by: Andrew Whitehead --- .github/workflows/build.yml | 38 ++++++++++++++++++------------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index c27950d7..9dc0456a 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -16,7 +16,7 @@ name: "Aries-Askar" jobs: check: - name: Run Checks + name: Run checks strategy: fail-fast: false matrix: @@ -71,7 +71,7 @@ jobs: uses: actions-rs/cargo@v1 with: command: test - args: --workspace --features pg_test -- --nocapture --test-threads 1 + args: --workspace --features pg_test -- --nocapture --test-threads 1 --skip contention env: POSTGRES_URL: postgres://postgres:postgres@localhost:5432/test-db RUST_BACKTRACE: full @@ -82,7 +82,7 @@ jobs: uses: actions-rs/cargo@v1 with: command: test - args: --workspace -- --nocapture --test-threads 1 + args: --workspace -- --nocapture --test-threads 1 --skip contention env: RUST_BACKTRACE: full # RUST_LOG: debug @@ -100,7 +100,7 @@ jobs: args: --manifest-path ./askar-bbs/Cargo.toml --no-default-features build-manylinux: - name: Build Library + name: Build (manylinux) needs: [check] strategy: @@ -125,10 +125,10 @@ jobs: toolchain: 1.56 override: true - # - name: Cache cargo resources - # uses: Swatinem/rust-cache@v1 - # with: - # sharedKey: check + - name: Cache cargo resources + uses: Swatinem/rust-cache@v1 + with: + sharedKey: check - name: Build library env: @@ -142,8 +142,8 @@ jobs: name: library-${{ runner.os }} path: target/release/${{ matrix.lib }} - build-other: - name: Build Library + build-native: + name: Build (native) needs: [check] strategy: @@ -171,10 +171,10 @@ jobs: toolchain: ${{ matrix.toolchain }} override: true - # - name: Cache cargo resources - # uses: Swatinem/rust-cache@v1 - # with: - # sharedKey: check + - name: Cache cargo resources + uses: Swatinem/rust-cache@v1 + with: + sharedKey: check - name: Build library env: @@ -190,8 +190,8 @@ jobs: path: target/release/${{ matrix.lib }} build-py: - name: Build Python - needs: [build-manylinux, build-other] + name: Build Python packages + needs: [build-manylinux, build-native] strategy: fail-fast: false @@ -234,9 +234,9 @@ jobs: python setup.py bdist_wheel --python-tag=py3 --plat-name=${{ matrix.plat-name }} pip install pytest pytest-asyncio dist/* echo "-- Test SQLite in-memory --" - python -m pytest --log-cli-level=WARNING + python -m pytest --log-cli-level=WARNING -k "not contention" echo "-- Test SQLite file DB --" - TEST_STORE_URI=sqlite://test.db python -m pytest --log-cli-level=WARNING + TEST_STORE_URI=sqlite://test.db python -m pytest --log-cli-level=WARNING -k "not contention" working-directory: wrappers/python env: no_proxy: "*" # python issue 30385 @@ -249,7 +249,7 @@ jobs: sudo systemctl start postgresql.service pg_isready sudo -u postgres psql -c "ALTER USER postgres WITH PASSWORD 'postgres'" - python -m pytest --log-cli-level=WARNING + python -m pytest --log-cli-level=WARNING -k "not contention" working-directory: wrappers/python env: no_proxy: "*" # python issue 30385 From f4dd69da0403e635dd370ca83976e40a9a8b5fe6 Mon Sep 17 00:00:00 2001 From: Andrew Whitehead Date: Wed, 13 Apr 2022 15:55:18 -0700 Subject: [PATCH 32/43] allow custom logger to be disabled after initialization Signed-off-by: Andrew Whitehead --- src/ffi/log.rs | 32 ++++++++++++++++++++++++++++---- 1 file changed, 28 insertions(+), 4 deletions(-) diff --git a/src/ffi/log.rs b/src/ffi/log.rs index 440b4a4b..4b522812 100644 --- a/src/ffi/log.rs +++ b/src/ffi/log.rs @@ -1,12 +1,16 @@ use std::ffi::CString; use std::os::raw::{c_char, c_void}; use std::ptr; +use std::sync::atomic::{AtomicBool, Ordering}; use log::{LevelFilter, Metadata, Record}; +use once_cell::sync::OnceCell; use super::error::ErrorCode; use crate::error::Error; +static LOGGER: OnceCell = OnceCell::new(); + pub type EnabledCallback = extern "C" fn(context: *const c_void, level: i32) -> i8; pub type LogCallback = extern "C" fn( @@ -26,6 +30,7 @@ pub struct CustomLogger { enabled: Option, log: LogCallback, flush: Option, + disabled: AtomicBool, } impl CustomLogger { @@ -40,13 +45,20 @@ impl CustomLogger { enabled, log, flush, + disabled: AtomicBool::new(false), } } + + fn disable(&self) { + self.disabled.store(false, Ordering::Release); + } } impl log::Log for CustomLogger { fn enabled(&self, metadata: &Metadata<'_>) -> bool { - if let Some(enabled_cb) = self.enabled { + if !self.disabled.load(Ordering::Acquire) { + false + } else if let Some(enabled_cb) = self.enabled { enabled_cb(self.context, metadata.level() as i32) != 0 } else { true @@ -98,18 +110,30 @@ pub extern "C" fn askar_set_custom_logger( ) -> ErrorCode { catch_err! { let max_level = get_level_filter(max_level)?; - let logger = CustomLogger::new(context, enabled, log, flush); - log::set_boxed_logger(Box::new(logger)).map_err(err_map!(Unexpected))?; + if LOGGER.set(CustomLogger::new(context, enabled, log, flush)).is_err() { + return Err(err_msg!(Input, "Repeated logger initialization")); + } + log::set_logger(LOGGER.get().unwrap()).map_err( + |_| err_msg!(Input, "Repeated logger initialization"))?; log::set_max_level(max_level); debug!("Initialized custom logger"); Ok(ErrorCode::Success) } } +#[no_mangle] +pub extern "C" fn askar_clear_custom_logger() { + debug!("Removing custom logger"); + if let Some(logger) = LOGGER.get() { + logger.disable(); + } +} + #[no_mangle] pub extern "C" fn askar_set_default_logger() -> ErrorCode { catch_err! { - env_logger::init(); + env_logger::try_init().map_err( + |_| err_msg!(Input, "Repeated logger initialization"))?; debug!("Initialized default logger"); Ok(ErrorCode::Success) } From d5565208cce3f7e6cea017b256933d3a271a3dde Mon Sep 17 00:00:00 2001 From: Andrew Whitehead Date: Wed, 13 Apr 2022 16:42:32 -0700 Subject: [PATCH 33/43] move code into Lib class, apply finalizers Signed-off-by: Andrew Whitehead --- wrappers/python/aries_askar/bindings.py | 493 +++++++++++++----------- 1 file changed, 269 insertions(+), 224 deletions(-) diff --git a/wrappers/python/aries_askar/bindings.py b/wrappers/python/aries_askar/bindings.py index cec2e0bb..b2358b61 100644 --- a/wrappers/python/aries_askar/bindings.py +++ b/wrappers/python/aries_askar/bindings.py @@ -5,6 +5,7 @@ import logging import os import sys + from ctypes import ( _SimpleCData, Array, @@ -24,21 +25,14 @@ ) from ctypes.util import find_library from typing import Optional, Tuple, Union +from weakref import finalize from .error import AskarError, AskarErrorCode from .types import EntryOperation, KeyAlg, SeedMethod -CALLBACKS = {} -INVOKE = {} -LIB: CDLL = None +LIB: "Lib" = None LOGGER = logging.getLogger(__name__) -LOG_LEVELS = { - 1: logging.ERROR, - 2: logging.WARNING, - 3: logging.INFO, - 4: logging.DEBUG, -} MODULE_NAME = __name__.split(".")[0] @@ -136,6 +130,8 @@ def __del__(self): class FfiStr: + """A string value allocated by Python.""" + def __init__(self, value=None): if value is None: value = c_char_p() @@ -301,6 +297,7 @@ class ArcHandle(Structure): _fields_ = [ ("value", c_size_t), ] + _dtor_: str = None def __init__(self, value=0): if isinstance(value, c_size_t): @@ -308,6 +305,7 @@ def __init__(self, value=0): if not isinstance(value, int): raise ValueError("Invalid handle") super().__init__(value) + finalize(self, self._cleanup) @classmethod def from_param(cls, param): @@ -315,13 +313,17 @@ def from_param(cls, param): return param return cls(param) - def __bool__(self): + def __bool__(self) -> bool: return bool(self.value) def __repr__(self) -> str: """Format handle as a string.""" return f"{self.__class__.__name__}({self.value})" + def _cleanup(self): + if self.value and self.__class__._dtor_: + invoke_dtor(self.__class__._dtor_, self) + class StoreHandle(ArcHandle): """Handle for an active Store instance.""" @@ -332,7 +334,7 @@ async def close(self): await invoke_async("askar_store_close", (StoreHandle,), self) self.value = 0 - def __del__(self): + def _cleanup(self): """Close the store when there are no more references to this object.""" if self: invoke_dtor( @@ -358,7 +360,7 @@ async def close(self, commit: bool = False): ) self.value = 0 - def __del__(self): + def _cleanup(self): """Close the session when there are no more references to this object.""" if self: invoke_dtor( @@ -374,14 +376,14 @@ def __del__(self): class ScanHandle(ArcHandle): """Handle for an active Store scan instance.""" - def __del__(self): - """Close the scan when there are no more references to this object.""" - invoke_dtor("askar_scan_free", self) + _dtor_ = "askar_scan_free" class EntryListHandle(ArcHandle): """Handle for an active EntryList instance.""" + _dtor_ = "askar_entry_list_free" + def get_category(self, index: int) -> str: """Get the entry category.""" cat = StrBuffer() @@ -437,14 +439,12 @@ def get_tags(self, index: int) -> dict: tags = dict() return tags - def __del__(self): - """Free the entry set when there are no more references.""" - invoke_dtor("askar_entry_list_free", self) - class KeyEntryListHandle(ArcHandle): """Handle for an active KeyEntryList instance.""" + _dtor_ = "askar_key_entry_list_free" + def get_algorithm(self, index: int) -> str: """Get the key algorithm.""" name = StrBuffer() @@ -505,68 +505,77 @@ def load_key(self, index: int) -> "LocalKeyHandle": ) return handle - def __del__(self): - """Free the key entry set when there are no more references.""" - invoke_dtor("askar_key_entry_list_free", self) - class LocalKeyHandle(ArcHandle): """Handle for an active LocalKey instance.""" - def __del__(self): - """Free the key when there are no more references.""" - invoke_dtor("askar_key_free", self) - - -def get_library() -> CDLL: - """Return the CDLL instance, loading it if necessary.""" - global LIB - if LIB is None: - LIB = _load_library("aries_askar") - _init_logger() - return LIB - - -def _load_library(lib_name: str) -> CDLL: - """Load the CDLL library. - The python module directory is searched first, followed by the usual - library resolution for the current system. - """ - lib_prefix_mapping = {"win32": ""} - lib_suffix_mapping = {"darwin": ".dylib", "win32": ".dll"} - try: - os_name = sys.platform - lib_prefix = lib_prefix_mapping.get(os_name, "lib") - lib_suffix = lib_suffix_mapping.get(os_name, ".so") - lib_path = os.path.join( - os.path.dirname(__file__), f"{lib_prefix}{lib_name}{lib_suffix}" - ) - return CDLL(lib_path) - except KeyError: - LOGGER.debug("Unknown platform for shared library") - except OSError: - LOGGER.warning("Library not loaded from python package") - - lib_path = find_library(lib_name) - if not lib_path: - raise AskarError( - AskarErrorCode.WRAPPER, f"Library not found in path: {lib_path}" - ) - try: - return CDLL(lib_path) - except OSError as e: - raise AskarError( - AskarErrorCode.WRAPPER, f"Error loading library: {lib_path}" - ) from e - - -def _init_logger(): - logger = logging.getLogger(MODULE_NAME) - if logging.getLevelName("TRACE") == "Level TRACE": - # avoid redefining TRACE if another library has added it - logging.addLevelName(5, "TRACE") + _dtor_ = "askar_key_free" + + +class Lib: + """Aries-Askar library instance.""" + + LOG_LEVELS = { + 1: logging.ERROR, + 2: logging.WARNING, + 3: logging.INFO, + 4: logging.DEBUG, + } + + def __init__(self): + """Initializer.""" + self._cdll = None + self._callbacks = {} + self._methods = {} + self._dtor = None + self._log_cb = None + self._log_enabled_cb = None + self._load_library("aries_askar") + self._init_logger() + finalize(self, self._cleanup) + + def _load_library(self, lib_name: str): + """Load the CDLL library. + + The python module directory is searched first, followed by the usual + library resolution for the current system. + """ + lib_prefix_mapping = {"win32": ""} + lib_suffix_mapping = {"darwin": ".dylib", "win32": ".dll"} + try: + os_name = sys.platform + lib_prefix = lib_prefix_mapping.get(os_name, "lib") + lib_suffix = lib_suffix_mapping.get(os_name, ".so") + lib_path = os.path.join( + os.path.dirname(__file__), f"{lib_prefix}{lib_name}{lib_suffix}" + ) + self._cdll = CDLL(lib_path) + return + except KeyError: + LOGGER.debug("Unknown platform for shared library") + except OSError: + LOGGER.warning("Library not loaded from python package") + + lib_path = find_library(lib_name) + if not lib_path: + raise AskarError( + AskarErrorCode.WRAPPER, f"Library not found in path: {lib_path}" + ) + try: + self._cdll = CDLL(lib_path) + except OSError as e: + raise AskarError( + AskarErrorCode.WRAPPER, f"Error loading library: {lib_path}" + ) from e + + def _init_logger(self): + if self._log_cb: + return - if not hasattr(_init_logger, "log_cb"): + logger = logging.getLogger(MODULE_NAME) + if logging.getLevelName("TRACE") == "Level TRACE": + # avoid redefining TRACE if another library has added it + logging.addLevelName(5, "TRACE") @CFUNCTYPE( None, c_void_p, c_int32, c_char_p, c_char_p, c_char_p, c_char_p, c_int32 @@ -576,180 +585,218 @@ def _log_cb( level: int, target: c_char_p, message: c_char_p, - module_path: c_char_p, + _module_path: c_char_p, file_name: c_char_p, line: int, ): logger.getChild("native." + target.decode().replace("::", ".")).log( - LOG_LEVELS.get(level, level), + Lib.LOG_LEVELS.get(level, level), "\t%s:%d | %s", file_name.decode() if file_name else None, line, message.decode(), ) - _init_logger.log_cb = _log_cb + self._log_cb = _log_cb @CFUNCTYPE(c_int8, c_void_p, c_int32) def _enabled_cb(_context, level: int) -> bool: - return logger.isEnabledFor(LOG_LEVELS.get(level, level)) + return self._cdll and logger.isEnabledFor(Lib.LOG_LEVELS.get(level, level)) - _init_logger.enabled_cb = _enabled_cb + self._log_enabled_cb = _enabled_cb - if os.getenv("RUST_LOG"): - # level from environment - level = -1 - else: - # inherit current level from logger - level = _convert_log_level(logger.level or logger.parent.level) + if os.getenv("RUST_LOG"): + # level from environment + level = -1 + else: + # inherit current level from logger + level = Lib._convert_log_level(logger.level or logger.parent.level) - invoke( - "askar_set_custom_logger", - (c_void_p, c_void_p, c_void_p, c_void_p, c_int32), - None, # context - _init_logger.log_cb, - _init_logger.enabled_cb, - None, # flush - level, - ) + set_logger = self._method( + "askar_set_custom_logger", (c_void_p, c_void_p, c_void_p, c_void_p, c_int32) + ) + if set_logger( + None, # context + self._log_cb, + self._log_enabled_cb, + None, # flush + level, + ): + raise self._get_current_error(True) + + try: + self._dtor = self._method("askar_clear_custom_logger", None, restype=None) + except AttributeError: + # method is new as of 0.2.5 + pass + + def invoke(self, name, argtypes, *args): + """Perform a synchronous library function call.""" + method = self._method(name, argtypes) + args = Lib._load_method_arguments(name, argtypes, args) + result = method(*args) + if result: + raise self._get_current_error(True) + + def invoke_async(self, name: str, argtypes, *args, return_type=None): + """Perform an asynchronous library function call.""" + method = self._method(name, (*argtypes, c_void_p, c_int64)) + loop = asyncio.get_event_loop() + fut = loop.create_future() + cf_args = [c_int64, c_int64] + if return_type: + cf_args.append(return_type) + cb_type = CFUNCTYPE(None, *cf_args) # could be cached + cb_res = self._create_callback(cb_type, loop, fut) + args = Lib._load_method_arguments(name, argtypes, args) + # save a reference to the callback function and arguments to avoid GC + self._callbacks[fut] = (cb_res, args) + result = method(*args, cb_res, 0) # not making use of callback ID + if result: + # FFI must not execute the callback if an error is returned + err = self._get_current_error(True) + self._fulfill_future(fut, None, err) + return fut + + def set_max_log_level(self, level: Union[str, int, None]): + set_level = Lib._convert_log_level(level) + self.invoke("askar_set_max_log_level", (c_int32,), set_level) + + @classmethod + def _convert_log_level(cls, level: Union[str, int, None]): + if level is None or level == "-1": + return -1 + else: + if isinstance(level, str): + level = level.upper() + name = logging.getLevelName(level) + for k, v in cls.LOG_LEVELS.items(): + if logging.getLevelName(v) == name: + return k + return 0 + + def version(self) -> str: + """Get the version of the installed aries-askar library.""" + return str( + self._method( + "askar_version", + None, + restype=StrBuffer, + )() + ) + + def _method(self, name, argtypes, *, restype=c_int64): + method = self._methods.get(name) + if not method: + method = getattr(self._cdll, name) + if argtypes: + method.argtypes = argtypes + method.restype = restype + self._methods[name] = method + return method + + @classmethod + def _load_method_arguments(cls, name, argtypes, args): + """Preload argument values to avoid freeing any intermediate data.""" + if not argtypes: + return args + if len(args) != len(argtypes): + raise ValueError(f"{name}: Arguments length does not match argtypes length") + return [ + arg if issubclass(argtype, _SimpleCData) else argtype.from_param(arg) + for (arg, argtype) in zip(args, argtypes) + ] + + def _create_callback( + self, + cb_type: CFUNCTYPE, + loop: asyncio.AbstractEventLoop, + fut: asyncio.Future, + ): + """Create a callback to handle the response from an async library method.""" + + def _cb(cb_id: int, err: int, result=None): + """Callback function passed to the CFUNCTYPE for invocation.""" + assert cb_id == 0 + exc = self._get_current_error(True) if err else None + loop.call_soon_threadsafe(self._fulfill_future, fut, result, exc) + + res = cb_type(_cb) + return res + + def _fulfill_future(self, fut: asyncio.Future, result, err: Exception = None): + """Resolve a callback future given the result and exception, if any.""" + if not self._callbacks.pop(fut, None): + LOGGER.info("callback already fulfilled") + return + if fut.cancelled(): + LOGGER.debug("callback previously cancelled") + elif err: + fut.set_exception(err) + else: + fut.set_result(result) + + def _get_current_error(self, expect: bool = False) -> Optional[AskarError]: + """ + Get the error result from the previous failed API method. + + Args: + expect: Return a default error message if none is found + """ + err_json = StrBuffer() + method = self._method("askar_get_current_error", (POINTER(c_char_p),)) + if not method(byref(err_json)): + try: + msg = json.loads(err_json.value) + except json.JSONDecodeError: + LOGGER.warning("JSON decode error for askar_get_current_error") + msg = None + if msg and "message" in msg and "code" in msg: + return AskarError( + AskarErrorCode(msg["code"]), msg["message"], msg.get("extra") + ) + if not expect: + return None + return AskarError(AskarErrorCode.WRAPPER, "Unknown error") + + def _cleanup(self): + if self._cdll and self._dtor: + self._dtor() + self._dtor = None + self._cdll = None + + def __del__(self): + self._cleanup() + + +def get_library(init: bool = True) -> Lib: + """Return the library instance, loading it if necessary.""" + global LIB + if LIB is None and init: + LIB = Lib() + return LIB def set_max_log_level(level: Union[str, int, None]): - get_library() # ensure logger is initialized - set_level = _convert_log_level(level) - invoke("askar_set_max_log_level", (c_int32,), set_level) - - -def _convert_log_level(level: Union[str, int, None]): - if level is None or level == "-1": - return -1 - else: - if isinstance(level, str): - level = level.upper() - name = logging.getLevelName(level) - for k, v in LOG_LEVELS.items(): - if logging.getLevelName(v) == name: - return k - return 0 - - -def _fulfill_future(fut: asyncio.Future, result, err: Exception = None): - """Resolve a callback future given the result and exception, if any.""" - if not CALLBACKS.pop(fut, None): - LOGGER.info("callback already fulfilled") - return - if fut.cancelled(): - LOGGER.debug("callback previously cancelled") - elif err: - fut.set_exception(err) - else: - fut.set_result(result) - - -def _create_callback( - cb_type: CFUNCTYPE, - loop: asyncio.AbstractEventLoop, - fut: asyncio.Future, -): - """Create a callback to handle the response from an async library method.""" - - def _cb(cb_id: int, err: int, result=None): - """Callback function passed to the CFUNCTYPE for invocation.""" - assert cb_id == 0 - exc = get_current_error() if err else None - loop.call_soon_threadsafe(_fulfill_future, fut, result, exc) - - res = cb_type(_cb) - return res - - -def _get_library_method(name: str, argtypes, *, restype=c_int64): - method = INVOKE.get(name) - if not method: - method = getattr(get_library(), name) - method.argtypes = argtypes - method.restype = restype - INVOKE[name] = method - return method - - -def _load_method_arguments(name, argtypes, args): - """Preload argument values to avoid freeing any intermediate data.""" - if not argtypes: - return args - if len(args) != len(argtypes): - raise ValueError(f"{name}: Arguments length does not match argtypes length") - return [ - arg if issubclass(argtype, _SimpleCData) else argtype.from_param(arg) - for (arg, argtype) in zip(args, argtypes) - ] + """Set the maximum logging level.""" + get_library().set_max_log_level(level) def invoke(name, argtypes, *args): """Perform a synchronous library function call.""" - method = _get_library_method(name, argtypes) - args = _load_method_arguments(name, argtypes, args) - result = method(*args) - if result: - raise get_current_error(True) + get_library().invoke(name, argtypes, *args) -def invoke_async(name: str, argtypes, *args, return_type=None): +def invoke_async(name: str, argtypes, *args, return_type=None) -> asyncio.Future: """Perform an asynchronous library function call.""" - method = _get_library_method(name, (*argtypes, c_void_p, c_int64)) - loop = asyncio.get_event_loop() - fut = loop.create_future() - cf_args = [c_int64, c_int64] - if return_type: - cf_args.append(return_type) - cb_type = CFUNCTYPE(None, *cf_args) # could be cached - cb_res = _create_callback(cb_type, loop, fut) - args = _load_method_arguments(name, argtypes, args) - # save a reference to the callback function and arguments to avoid GC - CALLBACKS[fut] = (cb_res, args) - result = method(*args, cb_res, 0) # not making use of callback ID - if result: - # FFI must not execute the callback if an error is returned - err = get_current_error(True) - _fulfill_future(fut, None, err) - return fut + return get_library().invoke_async(name, argtypes, *args, return_type=return_type) def invoke_dtor(name: str, *values, argtypes=None): - method = INVOKE.get(name) - if not method: - lib = get_library() - if not lib: - return - method = getattr(lib, name) - if argtypes: - method.argtypes = argtypes - method.restype = None - INVOKE[name] = method - method(*values) - - -def get_current_error(expect: bool = False) -> Optional[AskarError]: - """ - Get the error result from the previous failed API method. - - Args: - expect: Return a default error message if none is found - """ - err_json = StrBuffer() - if not LIB or not LIB.askar_get_current_error(byref(err_json)): - try: - msg = json.loads(err_json.value) - except json.JSONDecodeError: - LOGGER.warning("JSON decode error for askar_get_current_error") - msg = None - if msg and "message" in msg and "code" in msg: - return AskarError( - AskarErrorCode(msg["code"]), msg["message"], msg.get("extra") - ) - if not expect: - return None - return AskarError(AskarErrorCode.WRAPPER, "Unknown error") + lib = get_library(False) + if lib: + method = lib._method(name, argtypes, restype=None) + method(*values) def generate_raw_key(seed: Union[str, bytes] = None) -> str: @@ -766,9 +813,7 @@ def generate_raw_key(seed: Union[str, bytes] = None) -> str: def version() -> str: """Get the version of the installed aries-askar library.""" - lib = get_library() - lib.askar_version.restype = c_void_p - return str(StrBuffer(lib.askar_version())) + return get_library().version() async def store_open( From e9b3cf7b062b81caf295327914e0bda4e1cfc1ca Mon Sep 17 00:00:00 2001 From: Andrew Whitehead Date: Wed, 13 Apr 2022 18:00:59 -0700 Subject: [PATCH 34/43] add docstrings Signed-off-by: Andrew Whitehead --- src/future.rs | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/future.rs b/src/future.rs index 0aace01d..b87544ac 100644 --- a/src/future.rs +++ b/src/future.rs @@ -7,10 +7,12 @@ pub type BoxFuture<'a, T> = Pin + Send + 'a>>; static RUNTIME: Lazy = Lazy::new(|| Runtime::new().expect("Error creating tokio runtime")); +/// Block the current thread on an async task, when not running inside the scheduler. pub fn block_on(f: impl Future) -> R { RUNTIME.block_on(f) } +/// Run a blocking task without interrupting the async scheduler. #[inline] pub async fn unblock(f: F) -> T where @@ -23,16 +25,21 @@ where .expect("Error running blocking task") } +/// Spawn an async task into the runtime. #[inline] pub fn spawn_ok(fut: impl Future + Send + 'static) { RUNTIME.spawn(fut); } +/// Wait until a specific duration has passed (used in tests). +#[doc(hidden)] pub async fn sleep(dur: Duration) { let _rt = RUNTIME.enter(); tokio::time::sleep(dur).await } +/// Cancel an async task if it does not complete after a timeout (used in tests). +#[doc(hidden)] pub async fn timeout(dur: Duration, f: impl Future) -> Option { let _rt = RUNTIME.enter(); tokio::time::timeout(dur, f).await.ok() From 3c9d2b5c3469d169c187897035d6ae3bda343f90 Mon Sep 17 00:00:00 2001 From: Andrew Whitehead Date: Wed, 13 Apr 2022 18:01:38 -0700 Subject: [PATCH 35/43] wait for callbacks on shutdown Signed-off-by: Andrew Whitehead --- wrappers/python/aries_askar/bindings.py | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/wrappers/python/aries_askar/bindings.py b/wrappers/python/aries_askar/bindings.py index b2358b61..cf5e81ca 100644 --- a/wrappers/python/aries_askar/bindings.py +++ b/wrappers/python/aries_askar/bindings.py @@ -5,6 +5,8 @@ import logging import os import sys +import threading +import time from ctypes import ( _SimpleCData, @@ -323,6 +325,7 @@ def __repr__(self) -> str: def _cleanup(self): if self.value and self.__class__._dtor_: invoke_dtor(self.__class__._dtor_, self) + self.value = 0 class StoreHandle(ArcHandle): @@ -513,8 +516,9 @@ class LocalKeyHandle(ArcHandle): class Lib: - """Aries-Askar library instance.""" + """The loaded library instance.""" + LIB_NAME = "aries_askar" LOG_LEVELS = { 1: logging.ERROR, 2: logging.WARNING, @@ -530,7 +534,7 @@ def __init__(self): self._dtor = None self._log_cb = None self._log_enabled_cb = None - self._load_library("aries_askar") + self._load_library(self.__class__.LIB_NAME) self._init_logger() finalize(self, self._cleanup) @@ -676,7 +680,7 @@ def _convert_log_level(cls, level: Union[str, int, None]): return 0 def version(self) -> str: - """Get the version of the installed aries-askar library.""" + """Get the version of the installed library.""" return str( self._method( "askar_version", @@ -759,7 +763,20 @@ def _get_current_error(self, expect: bool = False) -> Optional[AskarError]: return None return AskarError(AskarErrorCode.WRAPPER, "Unknown error") + def _wait_callbacks(self): + while self._callbacks: + time.sleep(0.01) + def _cleanup(self): + if self._callbacks: + th = threading.Thread(target=self._wait_callbacks) + th.start() + th.join(timeout=1.0) + if th.is_alive(): + LOGGER.error( + "%s: Timed out waiting for callbacks to complete", + self.__class__.LIB_NAME, + ) if self._cdll and self._dtor: self._dtor() self._dtor = None @@ -812,7 +829,7 @@ def generate_raw_key(seed: Union[str, bytes] = None) -> str: def version() -> str: - """Get the version of the installed aries-askar library.""" + """Get the version of the installed library.""" return get_library().version() From 049eec90d63de5818ca64ef19262202590618825 Mon Sep 17 00:00:00 2001 From: Andrew Whitehead Date: Wed, 13 Apr 2022 21:55:21 -0700 Subject: [PATCH 36/43] refactor callback handling Signed-off-by: Andrew Whitehead --- wrappers/python/aries_askar/bindings.py | 101 ++++++++++++++---------- 1 file changed, 60 insertions(+), 41 deletions(-) diff --git a/wrappers/python/aries_askar/bindings.py b/wrappers/python/aries_askar/bindings.py index cf5e81ca..a7502848 100644 --- a/wrappers/python/aries_askar/bindings.py +++ b/wrappers/python/aries_askar/bindings.py @@ -2,6 +2,7 @@ import asyncio import json +import itertools import logging import os import sys @@ -33,7 +34,6 @@ from .types import EntryOperation, KeyAlg, SeedMethod -LIB: "Lib" = None LOGGER = logging.getLogger(__name__) MODULE_NAME = __name__.split(".")[0] @@ -105,6 +105,10 @@ class ByteBuffer(Structure): _fields_ = [("buffer", RawBuffer)] + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + finalize(self, self._cleanup) + @property def array(self) -> Array: return self.buffer.array @@ -126,7 +130,7 @@ def __repr__(self) -> str: """Format byte buffer as a string.""" return f"{self.__class__.__name__}({bytes(self)})" - def __del__(self): + def _cleanup(self): """Call the byte buffer destructor when this instance is released.""" invoke_dtor("askar_buffer_free", self.buffer) @@ -192,6 +196,10 @@ def from_param(cls, tags): class StrBuffer(c_char_p): """A string allocated by the library.""" + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + finalize(self, self._cleanup) + def is_none(self) -> bool: """Check if the returned string pointer is null.""" return self.value is None @@ -211,7 +219,7 @@ def __str__(self): val = self.opt_str() return val if val is not None else "" - def __del__(self): + def _cleanup(self): """Call the string destructor when this instance is released.""" if self: invoke_dtor("askar_string_free", self) @@ -242,6 +250,10 @@ class Encrypted(Structure): ("nonce_pos", c_int64), ] + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + finalize(self, self._cleanup) + def __getitem__(self, idx) -> bytes: return bytes(self.buffer.array[idx]) @@ -288,7 +300,7 @@ def __repr__(self) -> str: f" nonce={self.nonce})>" ) - def __del__(self): + def _cleanup(self): """Call the byte buffer destructor when this instance is released.""" invoke_dtor("askar_buffer_free", self.buffer) @@ -530,14 +542,22 @@ def __init__(self): """Initializer.""" self._cdll = None self._callbacks = {} + self._cb_id = itertools.count(0) + self._cfuncs = {} self._methods = {} self._dtor = None self._log_cb = None self._log_enabled_cb = None - self._load_library(self.__class__.LIB_NAME) - self._init_logger() + # This is called prior to any related object finalizers, + # and before the library itself is loaded. finalize(self, self._cleanup) + def load(self): + """Load the library.""" + if not self._cdll: + self._load_library(self.__class__.LIB_NAME) + self._init_logger() + def _load_library(self, lib_name: str): """Load the CDLL library. @@ -581,9 +601,10 @@ def _init_logger(self): # avoid redefining TRACE if another library has added it logging.addLevelName(5, "TRACE") - @CFUNCTYPE( + self._log_cb_t = CFUNCTYPE( None, c_void_p, c_int32, c_char_p, c_char_p, c_char_p, c_char_p, c_int32 ) + def _log_cb( _context, level: int, @@ -601,13 +622,14 @@ def _log_cb( message.decode(), ) - self._log_cb = _log_cb + self._log_cb = self._log_cb_t(_log_cb) + + self._log_enabled_cb_t = CFUNCTYPE(c_int8, c_void_p, c_int32) - @CFUNCTYPE(c_int8, c_void_p, c_int32) def _enabled_cb(_context, level: int) -> bool: return self._cdll and logger.isEnabledFor(Lib.LOG_LEVELS.get(level, level)) - self._log_enabled_cb = _enabled_cb + self._log_enabled_cb = self._log_enabled_cb_t(_enabled_cb) if os.getenv("RUST_LOG"): # level from environment @@ -647,18 +669,26 @@ def invoke_async(self, name: str, argtypes, *args, return_type=None): method = self._method(name, (*argtypes, c_void_p, c_int64)) loop = asyncio.get_event_loop() fut = loop.create_future() - cf_args = [c_int64, c_int64] - if return_type: - cf_args.append(return_type) - cb_type = CFUNCTYPE(None, *cf_args) # could be cached - cb_res = self._create_callback(cb_type, loop, fut) + cfunc = self._cfuncs.get(name) + if cfunc: + cb_res = cfunc[1] + else: + cf_args = [c_int64, c_int64] + if return_type: + cf_args.append(return_type) + cb_type = CFUNCTYPE(None, *cf_args) + cb_res = cb_type(self._handle_callback) + # must maintain a reference to cb_type as well, otherwise + # it may be freed, resulting in memory errors. + self._cfuncs[name] = (cb_type, cb_res) args = Lib._load_method_arguments(name, argtypes, args) - # save a reference to the callback function and arguments to avoid GC - self._callbacks[fut] = (cb_res, args) - result = method(*args, cb_res, 0) # not making use of callback ID + cb_id = next(self._cb_id) + self._callbacks[cb_id] = (loop, fut, args) + result = method(*args, cb_res, cb_id) if result: # FFI must not execute the callback if an error is returned err = self._get_current_error(True) + self._callbacks.pop(cb_id) self._fulfill_future(fut, None, err) return fut @@ -711,28 +741,17 @@ def _load_method_arguments(cls, name, argtypes, args): for (arg, argtype) in zip(args, argtypes) ] - def _create_callback( - self, - cb_type: CFUNCTYPE, - loop: asyncio.AbstractEventLoop, - fut: asyncio.Future, - ): - """Create a callback to handle the response from an async library method.""" - - def _cb(cb_id: int, err: int, result=None): - """Callback function passed to the CFUNCTYPE for invocation.""" - assert cb_id == 0 - exc = self._get_current_error(True) if err else None - loop.call_soon_threadsafe(self._fulfill_future, fut, result, exc) - - res = cb_type(_cb) - return res + def _handle_callback(self, cb_id: int, err: int, result=None): + exc = self._get_current_error(True) if err else None + cb = self._callbacks.pop(cb_id, None) + if not cb: + LOGGER.info("Callback already fulfilled: %s", cb_id) + return + (loop, fut, _) = cb + loop.call_soon_threadsafe(self._fulfill_future, fut, result, exc) def _fulfill_future(self, fut: asyncio.Future, result, err: Exception = None): """Resolve a callback future given the result and exception, if any.""" - if not self._callbacks.pop(fut, None): - LOGGER.info("callback already fulfilled") - return if fut.cancelled(): LOGGER.debug("callback previously cancelled") elif err: @@ -782,15 +801,15 @@ def _cleanup(self): self._dtor = None self._cdll = None - def __del__(self): - self._cleanup() + +LIB = Lib() def get_library(init: bool = True) -> Lib: """Return the library instance, loading it if necessary.""" global LIB - if LIB is None and init: - LIB = Lib() + if LIB and init: + LIB.load() return LIB From 479f5ef990aad2be6788f10fe5eea502b61240ec Mon Sep 17 00:00:00 2001 From: Andrew Whitehead Date: Thu, 14 Apr 2022 16:17:11 -0700 Subject: [PATCH 37/43] simpler ArcHandle load Signed-off-by: Andrew Whitehead --- src/ffi/handle.rs | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/ffi/handle.rs b/src/ffi/handle.rs index d0d8cf6b..a1da1c79 100644 --- a/src/ffi/handle.rs +++ b/src/ffi/handle.rs @@ -1,4 +1,4 @@ -use std::{fmt::Display, ptr, sync::Arc}; +use std::{fmt::Display, mem, ptr, sync::Arc}; use crate::error::Error; @@ -18,9 +18,8 @@ impl ArcHandle { pub fn load(&self) -> Result, Error> { self.validate()?; unsafe { - let result = Arc::from_raw(self.0); - Arc::increment_strong_count(self.0); - Ok(result) + let result = mem::ManuallyDrop::new(Arc::from_raw(self.0)); + Ok((&*result).clone()) } } From b7c5f5c72ce209b231767d01f6601a1138921498 Mon Sep 17 00:00:00 2001 From: Andrew Whitehead Date: Thu, 14 Apr 2022 16:17:53 -0700 Subject: [PATCH 38/43] update to sqlx 0.5.12 Signed-off-by: Andrew Whitehead --- Cargo.toml | 2 +- src/backend/sqlite/provision.rs | 15 +++++---------- 2 files changed, 6 insertions(+), 11 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index e22aea7d..cd63abd1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -71,7 +71,7 @@ path = "./askar-crypto" features = ["all_keys", "any_key", "argon2", "crypto_box", "std"] [dependencies.sqlx] -version = "0.5.11" +version = "0.5.12" default-features = false features = ["chrono", "runtime-tokio-rustls"] optional = true diff --git a/src/backend/sqlite/provision.rs b/src/backend/sqlite/provision.rs index 8e34298b..81dc2e92 100644 --- a/src/backend/sqlite/provision.rs +++ b/src/backend/sqlite/provision.rs @@ -93,13 +93,8 @@ impl SqliteStoreOptions { DEFAULT_JOURNAL_MODE }; let locking_mode = if let Some(mode) = opts.query.remove("locking_mode") { - if mode.eq_ignore_ascii_case("exclusive") { - SqliteLockingMode::Exclusive - } else if mode.eq_ignore_ascii_case("normal") { - SqliteLockingMode::Normal - } else { - return Err(err_msg!(Input, "Error parsing 'locking_mode' parameter")); - } + SqliteLockingMode::from_str(&mode) + .map_err(err_map!(Input, "Error parsing 'locking_mode' parameter"))? } else { DEFAULT_LOCKING_MODE }; @@ -136,10 +131,10 @@ impl SqliteStoreOptions { .create_if_missing(auto_create) .auto_vacuum(SqliteAutoVacuum::Incremental) .busy_timeout(self.busy_timeout) - .journal_mode(self.journal_mode.clone()) - .locking_mode(self.locking_mode.clone()) + .journal_mode(self.journal_mode) + .locking_mode(self.locking_mode) .shared_cache(self.shared_cache) - .synchronous(self.synchronous.clone()); + .synchronous(self.synchronous); #[cfg(feature = "log")] { conn_opts.log_statements(log::LevelFilter::Debug); From a001ddac403350596acea08f52618586aa8e7179 Mon Sep 17 00:00:00 2001 From: Andrew Whitehead Date: Thu, 14 Apr 2022 16:20:10 -0700 Subject: [PATCH 39/43] misc build updates Signed-off-by: Andrew Whitehead --- .github/workflows/build.yml | 59 ++++++++++++++++--------------------- 1 file changed, 26 insertions(+), 33 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 9dc0456a..19f1b397 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -65,24 +65,14 @@ jobs: sudo systemctl start postgresql.service pg_isready sudo -u postgres psql -c "ALTER USER postgres WITH PASSWORD 'postgres'" + echo "POSTGRES_URL=postgres://postgres:postgres@localhost:5432/test-db" >> $GITHUB_ENV + echo "TEST_FEATURES=pg_test" >> $GITHUB_ENV - - if: "runner.os == 'Linux'" - name: Test with postgres - uses: actions-rs/cargo@v1 - with: - command: test - args: --workspace --features pg_test -- --nocapture --test-threads 1 --skip contention - env: - POSTGRES_URL: postgres://postgres:postgres@localhost:5432/test-db - RUST_BACKTRACE: full - # RUST_LOG: debug - - - if: "runner.os != 'Linux'" - name: Test without postgres + - name: Run tests uses: actions-rs/cargo@v1 with: command: test - args: --workspace -- --nocapture --test-threads 1 --skip contention + args: --workspace --features "${{ env.TEST_FEATURES }}" -- --nocapture --test-threads 1 --skip contention env: RUST_BACKTRACE: full # RUST_LOG: debug @@ -228,38 +218,41 @@ jobs: name: library-${{ runner.os }} path: wrappers/python/aries_askar/ - - name: Build and test python package + - if: "runner.os == 'Linux'" + name: Start postgres (Linux) + run: | + sudo systemctl start postgresql.service + pg_isready + sudo -u postgres psql -c "ALTER USER postgres WITH PASSWORD 'postgres'" + echo "POSTGRES_URL=postgres://postgres:postgres@localhost:5432/test-db" >> $GITHUB_ENV + + - name: Build package shell: sh run: | python setup.py bdist_wheel --python-tag=py3 --plat-name=${{ matrix.plat-name }} - pip install pytest pytest-asyncio dist/* - echo "-- Test SQLite in-memory --" - python -m pytest --log-cli-level=WARNING -k "not contention" - echo "-- Test SQLite file DB --" - TEST_STORE_URI=sqlite://test.db python -m pytest --log-cli-level=WARNING -k "not contention" working-directory: wrappers/python - env: - no_proxy: "*" # python issue 30385 - RUST_BACKTRACE: full - # RUST_LOG: debug - if: "runner.os == 'Linux'" - name: Test postgres + name: Audit wheel + run: auditwheel show wrappers/python/dist/* + + - name: Test package + shell: sh run: | - sudo systemctl start postgresql.service - pg_isready - sudo -u postgres psql -c "ALTER USER postgres WITH PASSWORD 'postgres'" + pip install pytest pytest-asyncio dist/* + echo "-- Test SQLite in-memory --" python -m pytest --log-cli-level=WARNING -k "not contention" + echo "-- Test SQLite file DB --" + TEST_STORE_URI=sqlite://test.db python -m pytest --log-cli-level=WARNING -k "not contention" + if [ -n "$POSTGRES_URL" ]; then + echo "-- Test Postgres DB --" + TEST_STORE_URI="$POSTGRES_URL" python -m pytest --log-cli-level=WARNING -k "not contention" + fi working-directory: wrappers/python env: no_proxy: "*" # python issue 30385 RUST_BACKTRACE: full # RUST_LOG: debug - TEST_STORE_URI: postgres://postgres:postgres@localhost:5432/test-db - - - if: "runner.os == 'Linux'" - name: Audit wheel - run: auditwheel show wrappers/python/dist/* - name: Upload python package uses: actions/upload-artifact@v2 From e7639dfc3815a809c83cd7befcefe2ce3da058ad Mon Sep 17 00:00:00 2001 From: Andrew Whitehead Date: Thu, 14 Apr 2022 16:28:47 -0700 Subject: [PATCH 40/43] re-enable lto Signed-off-by: Andrew Whitehead --- Cargo.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index cd63abd1..eb2146dc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -76,9 +76,9 @@ default-features = false features = ["chrono", "runtime-tokio-rustls"] optional = true -# [profile.release] -# lto = true -# codegen-units = 1 +[profile.release] + lto = true +codegen-units = 1 [[test]] name = "backends" From 012bc9dc531a3b3008dc62ac373051eb1d01ae79 Mon Sep 17 00:00:00 2001 From: Andrew Whitehead Date: Thu, 14 Apr 2022 17:42:40 -0700 Subject: [PATCH 41/43] update askar-crypto to 0.2.5 Signed-off-by: Andrew Whitehead --- Cargo.toml | 2 +- askar-crypto/Cargo.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index eb2146dc..af2f5141 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -66,7 +66,7 @@ uuid = { version = "0.8", features = ["v4"] } zeroize = "1.4" [dependencies.askar-crypto] -version = "0.2" +version = "0.2.5" path = "./askar-crypto" features = ["all_keys", "any_key", "argon2", "crypto_box", "std"] diff --git a/askar-crypto/Cargo.toml b/askar-crypto/Cargo.toml index 7f7cb7f7..1848df01 100644 --- a/askar-crypto/Cargo.toml +++ b/askar-crypto/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "askar-crypto" -version = "0.2.4" +version = "0.2.5" authors = ["Hyperledger Aries Contributors "] edition = "2018" description = "Hyperledger Aries Askar cryptography" From ce308ebf7c133f269e72a884564c4372b39d2bbc Mon Sep 17 00:00:00 2001 From: Andrew Whitehead Date: Mon, 18 Apr 2022 13:36:45 -0700 Subject: [PATCH 42/43] simplify Default for SqliteOptions Signed-off-by: Andrew Whitehead --- src/backend/sqlite/provision.rs | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/src/backend/sqlite/provision.rs b/src/backend/sqlite/provision.rs index 81dc2e92..d57b9e2b 100644 --- a/src/backend/sqlite/provision.rs +++ b/src/backend/sqlite/provision.rs @@ -45,17 +45,7 @@ pub struct SqliteStoreOptions { impl Default for SqliteStoreOptions { fn default() -> Self { - Self { - in_memory: true, - path: ":memory:".into(), - busy_timeout: DEFAULT_BUSY_TIMEOUT, - max_connections: num_cpus::get() as u32, - min_connections: DEFAULT_MIN_CONNECTIONS, - journal_mode: DEFAULT_JOURNAL_MODE, - locking_mode: DEFAULT_LOCKING_MODE, - shared_cache: DEFAULT_SHARED_CACHE, - synchronous: DEFAULT_SYNCHRONOUS, - } + Self::new(":memory:").expect("Error initializing with default options") } } From 3a9135aa2d2d8a6218d9389afd2ccf18f936cba0 Mon Sep 17 00:00:00 2001 From: Andrew Whitehead Date: Mon, 18 Apr 2022 13:37:21 -0700 Subject: [PATCH 43/43] split up wrapper binding classes; fix circular references in finalizers; add unit tests Signed-off-by: Andrew Whitehead --- wrappers/python/aries_askar/__init__.py | 2 +- wrappers/python/aries_askar/bindings.py | 1723 ----------------- .../python/aries_askar/bindings/__init__.py | 948 +++++++++ .../python/aries_askar/bindings/handle.py | 248 +++ wrappers/python/aries_askar/bindings/lib.py | 672 +++++++ wrappers/python/aries_askar/key.py | 5 +- wrappers/python/aries_askar/store.py | 40 +- wrappers/python/tests/test_cleanup.py | 60 + wrappers/python/tests/test_jose_ecdh.py | 2 - 9 files changed, 1949 insertions(+), 1751 deletions(-) delete mode 100644 wrappers/python/aries_askar/bindings.py create mode 100644 wrappers/python/aries_askar/bindings/__init__.py create mode 100644 wrappers/python/aries_askar/bindings/handle.py create mode 100644 wrappers/python/aries_askar/bindings/lib.py create mode 100644 wrappers/python/tests/test_cleanup.py diff --git a/wrappers/python/aries_askar/__init__.py b/wrappers/python/aries_askar/__init__.py index 09a6f9eb..608e045b 100644 --- a/wrappers/python/aries_askar/__init__.py +++ b/wrappers/python/aries_askar/__init__.py @@ -1,6 +1,6 @@ """aries-askar Python wrapper library""" -from .bindings import version, Encrypted +from .bindings import Encrypted, version from .error import AskarError, AskarErrorCode from .key import Key from .store import Entry, EntryList, KeyEntry, KeyEntryList, Session, Store diff --git a/wrappers/python/aries_askar/bindings.py b/wrappers/python/aries_askar/bindings.py deleted file mode 100644 index a7502848..00000000 --- a/wrappers/python/aries_askar/bindings.py +++ /dev/null @@ -1,1723 +0,0 @@ -"""Low-level interaction with the aries-askar library.""" - -import asyncio -import json -import itertools -import logging -import os -import sys -import threading -import time - -from ctypes import ( - _SimpleCData, - Array, - CDLL, - CFUNCTYPE, - POINTER, - Structure, - byref, - cast, - c_char_p, - c_int8, - c_int32, - c_int64, - c_size_t, - c_void_p, - c_ubyte, -) -from ctypes.util import find_library -from typing import Optional, Tuple, Union -from weakref import finalize - -from .error import AskarError, AskarErrorCode -from .types import EntryOperation, KeyAlg, SeedMethod - - -LOGGER = logging.getLogger(__name__) -MODULE_NAME = __name__.split(".")[0] - - -class RawBuffer(Structure): - """A byte buffer allocated by the library.""" - - _fields_ = [ - ("len", c_int64), - ("data", POINTER(c_ubyte)), - ] - - def __bytes__(self) -> bytes: - if not self.len: - return b"" - return bytes(self.array) - - def __len__(self) -> int: - return int(self.len) - - @property - def array(self) -> Array: - return cast(self.data, POINTER(c_ubyte * self.len)).contents - - -class FfiByteBuffer: - """A byte buffer allocated by Python.""" - - def __init__(self, value): - if isinstance(value, str): - value = value.encode("utf-8") - - if value is None: - dlen = 0 - data = c_char_p() - elif isinstance(value, memoryview): - dlen = value.nbytes - data = c_char_p(value.tobytes()) - elif isinstance(value, bytes): - dlen = len(value) - data = c_char_p(value) - else: - raise TypeError(f"Expected str or bytes value, got {type(value)}") - self._dlen = dlen - self._data = data - - def __bytes__(self) -> bytes: - if not self._data: - return b"" - return self._data.value - - def __len__(self) -> int: - return self._dlen - - @property - def _as_parameter_(self) -> RawBuffer: - buf = RawBuffer(len=self._dlen, data=cast(self._data, POINTER(c_ubyte))) - return buf - - @classmethod - def from_param(cls, value): - if isinstance(value, (ByteBuffer, FfiByteBuffer)): - return value - return cls(value) - - -class ByteBuffer(Structure): - """A managed byte buffer allocated by the library.""" - - _fields_ = [("buffer", RawBuffer)] - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - finalize(self, self._cleanup) - - @property - def array(self) -> Array: - return self.buffer.array - - @property - def view(self) -> memoryview: - return memoryview(self.array) - - def __bytes__(self) -> bytes: - return bytes(self.buffer) - - def __len__(self) -> int: - return len(self.buffer) - - def __getitem__(self, idx) -> bytes: - return bytes(self.buffer.array[idx]) - - def __repr__(self) -> str: - """Format byte buffer as a string.""" - return f"{self.__class__.__name__}({bytes(self)})" - - def _cleanup(self): - """Call the byte buffer destructor when this instance is released.""" - invoke_dtor("askar_buffer_free", self.buffer) - - -class FfiStr: - """A string value allocated by Python.""" - - def __init__(self, value=None): - if value is None: - value = c_char_p() - elif isinstance(value, c_char_p): - pass - else: - if isinstance(value, str): - value = value.encode("utf-8") - if not isinstance(value, bytes): - raise TypeError(f"Expected string value, got {type(value)}") - value = c_char_p(value) - self.value = value - - @classmethod - def from_param(cls, value): - if isinstance(value, cls): - return value - return cls(value) - - @property - def _as_parameter_(self): - return self.value - - def __repr__(self) -> str: - """Format handle as a string.""" - return f"{self.__class__.__name__}({self.value})" - - -class FfiJson: - @classmethod - def from_param(cls, value): - if isinstance(value, FfiStr): - return value - if isinstance(value, dict): - value = json.dumps(value) - return FfiStr(value) - - -class FfiTagsJson: - @classmethod - def from_param(cls, tags): - if isinstance(tags, FfiStr): - return tags - if tags: - tags = json.dumps( - { - name: (list(value) if isinstance(value, set) else value) - for name, value in tags.items() - } - ) - else: - tags = None - return FfiStr(tags) - - -class StrBuffer(c_char_p): - """A string allocated by the library.""" - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - finalize(self, self._cleanup) - - def is_none(self) -> bool: - """Check if the returned string pointer is null.""" - return self.value is None - - def opt_str(self) -> Optional[str]: - """Convert to an optional string.""" - val = self.value - return val.decode("utf-8") if val is not None else None - - def __bytes__(self) -> bytes: - """Convert to bytes.""" - return self.value - - def __str__(self): - """Convert to a string.""" - # not allowed to return None - val = self.opt_str() - return val if val is not None else "" - - def _cleanup(self): - """Call the string destructor when this instance is released.""" - if self: - invoke_dtor("askar_string_free", self) - - -class AeadParams(Structure): - """A byte buffer allocated by the library.""" - - _fields_ = [ - ("nonce_length", c_int32), - ("tag_length", c_int32), - ] - - def __repr__(self) -> str: - """Format AEAD params as a string.""" - return ( - f"" - ) - - -class Encrypted(Structure): - """The result of an AEAD encryption operation.""" - - _fields_ = [ - ("buffer", RawBuffer), - ("tag_pos", c_int64), - ("nonce_pos", c_int64), - ] - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - finalize(self, self._cleanup) - - def __getitem__(self, idx) -> bytes: - return bytes(self.buffer.array[idx]) - - def __bytes__(self) -> bytes: - """Convert to bytes.""" - return self.ciphertext_tag - - @property - def ciphertext_tag(self) -> bytes: - """Accessor for the combined ciphertext and tag.""" - p = self.nonce_pos - return self[:p] - - @property - def ciphertext(self) -> bytes: - """Accessor for the ciphertext.""" - p = self.tag_pos - return self[:p] - - @property - def nonce(self) -> bytes: - """Accessor for the nonce.""" - p = self.nonce_pos - return self[p:] - - @property - def tag(self) -> bytes: - """Accessor for the authentication tag.""" - p1 = self.tag_pos - p2 = self.nonce_pos - return self[p1:p2] - - @property - def parts(self) -> Tuple[bytes, bytes, bytes]: - """Accessor for the ciphertext, tag, and nonce.""" - p1 = self.tag_pos - p2 = self.nonce_pos - return self[:p1], self[p1:p2], self[p2:] - - def __repr__(self) -> str: - """Format encrypted value as a string.""" - return ( - f"" - ) - - def _cleanup(self): - """Call the byte buffer destructor when this instance is released.""" - invoke_dtor("askar_buffer_free", self.buffer) - - -class ArcHandle(Structure): - """Base class for handle instances.""" - - _fields_ = [ - ("value", c_size_t), - ] - _dtor_: str = None - - def __init__(self, value=0): - if isinstance(value, c_size_t): - value = value.value - if not isinstance(value, int): - raise ValueError("Invalid handle") - super().__init__(value) - finalize(self, self._cleanup) - - @classmethod - def from_param(cls, param): - if isinstance(param, cls): - return param - return cls(param) - - def __bool__(self) -> bool: - return bool(self.value) - - def __repr__(self) -> str: - """Format handle as a string.""" - return f"{self.__class__.__name__}({self.value})" - - def _cleanup(self): - if self.value and self.__class__._dtor_: - invoke_dtor(self.__class__._dtor_, self) - self.value = 0 - - -class StoreHandle(ArcHandle): - """Handle for an active Store instance.""" - - async def close(self): - """Close the store, waiting for any active connections.""" - if self: - await invoke_async("askar_store_close", (StoreHandle,), self) - self.value = 0 - - def _cleanup(self): - """Close the store when there are no more references to this object.""" - if self: - invoke_dtor( - "askar_store_close", - self, - None, - 0, - argtypes=(StoreHandle, c_void_p, c_int64), - ) - - -class SessionHandle(ArcHandle): - """Handle for an active Session/Transaction instance.""" - - async def close(self, commit: bool = False): - """Close the session.""" - if self: - await invoke_async( - "askar_session_close", - (SessionHandle, c_int8), - self, - commit, - ) - self.value = 0 - - def _cleanup(self): - """Close the session when there are no more references to this object.""" - if self: - invoke_dtor( - "askar_session_close", - self, - 0, - None, - 0, - argtypes=(SessionHandle, c_int8, c_void_p, c_int64), - ) - - -class ScanHandle(ArcHandle): - """Handle for an active Store scan instance.""" - - _dtor_ = "askar_scan_free" - - -class EntryListHandle(ArcHandle): - """Handle for an active EntryList instance.""" - - _dtor_ = "askar_entry_list_free" - - def get_category(self, index: int) -> str: - """Get the entry category.""" - cat = StrBuffer() - invoke( - "askar_entry_list_get_category", - (EntryListHandle, c_int32, POINTER(c_char_p)), - self, - index, - byref(cat), - ) - return str(cat) - - def get_name(self, index: int) -> str: - """Get the entry name.""" - name = StrBuffer() - invoke( - "askar_entry_list_get_name", - (EntryListHandle, c_int32, POINTER(c_char_p)), - self, - index, - byref(name), - ) - return str(name) - - def get_value(self, index: int) -> ByteBuffer: - """Get the entry value.""" - val = ByteBuffer() - invoke( - "askar_entry_list_get_value", - (EntryListHandle, c_int32, POINTER(ByteBuffer)), - self, - index, - byref(val), - ) - return val - - def get_tags(self, index: int) -> dict: - """Get the entry tags.""" - tags = StrBuffer() - invoke( - "askar_entry_list_get_tags", - (EntryListHandle, c_int32, POINTER(c_char_p)), - self, - index, - byref(tags), - ) - if tags: - tags = json.loads(tags.value) - for t in tags: - if isinstance(tags[t], list): - tags[t] = set(tags[t]) - else: - tags = dict() - return tags - - -class KeyEntryListHandle(ArcHandle): - """Handle for an active KeyEntryList instance.""" - - _dtor_ = "askar_key_entry_list_free" - - def get_algorithm(self, index: int) -> str: - """Get the key algorithm.""" - name = StrBuffer() - invoke( - "askar_key_entry_list_get_algorithm", - (KeyEntryListHandle, c_int32, POINTER(c_char_p)), - self, - index, - byref(name), - ) - return str(name) - - def get_name(self, index: int) -> str: - """Get the key name.""" - name = StrBuffer() - invoke( - "askar_key_entry_list_get_name", - (KeyEntryListHandle, c_int32, POINTER(c_char_p)), - self, - index, - byref(name), - ) - return str(name) - - def get_metadata(self, index: int) -> str: - """Get for the key metadata.""" - metadata = StrBuffer() - invoke( - "askar_key_entry_list_get_metadata", - (KeyEntryListHandle, c_int32, POINTER(c_char_p)), - self, - index, - byref(metadata), - ) - return str(metadata) - - def get_tags(self, index: int) -> dict: - """Get the key tags.""" - tags = StrBuffer() - invoke( - "askar_key_entry_list_get_tags", - (KeyEntryListHandle, c_int32, POINTER(c_char_p)), - self, - index, - byref(tags), - ) - return json.loads(tags.value) if tags else None - - def load_key(self, index: int) -> "LocalKeyHandle": - """Load the key instance.""" - handle = LocalKeyHandle() - invoke( - "askar_key_entry_list_load_local", - (KeyEntryListHandle, c_int32, POINTER(LocalKeyHandle)), - self, - index, - byref(handle), - ) - return handle - - -class LocalKeyHandle(ArcHandle): - """Handle for an active LocalKey instance.""" - - _dtor_ = "askar_key_free" - - -class Lib: - """The loaded library instance.""" - - LIB_NAME = "aries_askar" - LOG_LEVELS = { - 1: logging.ERROR, - 2: logging.WARNING, - 3: logging.INFO, - 4: logging.DEBUG, - } - - def __init__(self): - """Initializer.""" - self._cdll = None - self._callbacks = {} - self._cb_id = itertools.count(0) - self._cfuncs = {} - self._methods = {} - self._dtor = None - self._log_cb = None - self._log_enabled_cb = None - # This is called prior to any related object finalizers, - # and before the library itself is loaded. - finalize(self, self._cleanup) - - def load(self): - """Load the library.""" - if not self._cdll: - self._load_library(self.__class__.LIB_NAME) - self._init_logger() - - def _load_library(self, lib_name: str): - """Load the CDLL library. - - The python module directory is searched first, followed by the usual - library resolution for the current system. - """ - lib_prefix_mapping = {"win32": ""} - lib_suffix_mapping = {"darwin": ".dylib", "win32": ".dll"} - try: - os_name = sys.platform - lib_prefix = lib_prefix_mapping.get(os_name, "lib") - lib_suffix = lib_suffix_mapping.get(os_name, ".so") - lib_path = os.path.join( - os.path.dirname(__file__), f"{lib_prefix}{lib_name}{lib_suffix}" - ) - self._cdll = CDLL(lib_path) - return - except KeyError: - LOGGER.debug("Unknown platform for shared library") - except OSError: - LOGGER.warning("Library not loaded from python package") - - lib_path = find_library(lib_name) - if not lib_path: - raise AskarError( - AskarErrorCode.WRAPPER, f"Library not found in path: {lib_path}" - ) - try: - self._cdll = CDLL(lib_path) - except OSError as e: - raise AskarError( - AskarErrorCode.WRAPPER, f"Error loading library: {lib_path}" - ) from e - - def _init_logger(self): - if self._log_cb: - return - - logger = logging.getLogger(MODULE_NAME) - if logging.getLevelName("TRACE") == "Level TRACE": - # avoid redefining TRACE if another library has added it - logging.addLevelName(5, "TRACE") - - self._log_cb_t = CFUNCTYPE( - None, c_void_p, c_int32, c_char_p, c_char_p, c_char_p, c_char_p, c_int32 - ) - - def _log_cb( - _context, - level: int, - target: c_char_p, - message: c_char_p, - _module_path: c_char_p, - file_name: c_char_p, - line: int, - ): - logger.getChild("native." + target.decode().replace("::", ".")).log( - Lib.LOG_LEVELS.get(level, level), - "\t%s:%d | %s", - file_name.decode() if file_name else None, - line, - message.decode(), - ) - - self._log_cb = self._log_cb_t(_log_cb) - - self._log_enabled_cb_t = CFUNCTYPE(c_int8, c_void_p, c_int32) - - def _enabled_cb(_context, level: int) -> bool: - return self._cdll and logger.isEnabledFor(Lib.LOG_LEVELS.get(level, level)) - - self._log_enabled_cb = self._log_enabled_cb_t(_enabled_cb) - - if os.getenv("RUST_LOG"): - # level from environment - level = -1 - else: - # inherit current level from logger - level = Lib._convert_log_level(logger.level or logger.parent.level) - - set_logger = self._method( - "askar_set_custom_logger", (c_void_p, c_void_p, c_void_p, c_void_p, c_int32) - ) - if set_logger( - None, # context - self._log_cb, - self._log_enabled_cb, - None, # flush - level, - ): - raise self._get_current_error(True) - - try: - self._dtor = self._method("askar_clear_custom_logger", None, restype=None) - except AttributeError: - # method is new as of 0.2.5 - pass - - def invoke(self, name, argtypes, *args): - """Perform a synchronous library function call.""" - method = self._method(name, argtypes) - args = Lib._load_method_arguments(name, argtypes, args) - result = method(*args) - if result: - raise self._get_current_error(True) - - def invoke_async(self, name: str, argtypes, *args, return_type=None): - """Perform an asynchronous library function call.""" - method = self._method(name, (*argtypes, c_void_p, c_int64)) - loop = asyncio.get_event_loop() - fut = loop.create_future() - cfunc = self._cfuncs.get(name) - if cfunc: - cb_res = cfunc[1] - else: - cf_args = [c_int64, c_int64] - if return_type: - cf_args.append(return_type) - cb_type = CFUNCTYPE(None, *cf_args) - cb_res = cb_type(self._handle_callback) - # must maintain a reference to cb_type as well, otherwise - # it may be freed, resulting in memory errors. - self._cfuncs[name] = (cb_type, cb_res) - args = Lib._load_method_arguments(name, argtypes, args) - cb_id = next(self._cb_id) - self._callbacks[cb_id] = (loop, fut, args) - result = method(*args, cb_res, cb_id) - if result: - # FFI must not execute the callback if an error is returned - err = self._get_current_error(True) - self._callbacks.pop(cb_id) - self._fulfill_future(fut, None, err) - return fut - - def set_max_log_level(self, level: Union[str, int, None]): - set_level = Lib._convert_log_level(level) - self.invoke("askar_set_max_log_level", (c_int32,), set_level) - - @classmethod - def _convert_log_level(cls, level: Union[str, int, None]): - if level is None or level == "-1": - return -1 - else: - if isinstance(level, str): - level = level.upper() - name = logging.getLevelName(level) - for k, v in cls.LOG_LEVELS.items(): - if logging.getLevelName(v) == name: - return k - return 0 - - def version(self) -> str: - """Get the version of the installed library.""" - return str( - self._method( - "askar_version", - None, - restype=StrBuffer, - )() - ) - - def _method(self, name, argtypes, *, restype=c_int64): - method = self._methods.get(name) - if not method: - method = getattr(self._cdll, name) - if argtypes: - method.argtypes = argtypes - method.restype = restype - self._methods[name] = method - return method - - @classmethod - def _load_method_arguments(cls, name, argtypes, args): - """Preload argument values to avoid freeing any intermediate data.""" - if not argtypes: - return args - if len(args) != len(argtypes): - raise ValueError(f"{name}: Arguments length does not match argtypes length") - return [ - arg if issubclass(argtype, _SimpleCData) else argtype.from_param(arg) - for (arg, argtype) in zip(args, argtypes) - ] - - def _handle_callback(self, cb_id: int, err: int, result=None): - exc = self._get_current_error(True) if err else None - cb = self._callbacks.pop(cb_id, None) - if not cb: - LOGGER.info("Callback already fulfilled: %s", cb_id) - return - (loop, fut, _) = cb - loop.call_soon_threadsafe(self._fulfill_future, fut, result, exc) - - def _fulfill_future(self, fut: asyncio.Future, result, err: Exception = None): - """Resolve a callback future given the result and exception, if any.""" - if fut.cancelled(): - LOGGER.debug("callback previously cancelled") - elif err: - fut.set_exception(err) - else: - fut.set_result(result) - - def _get_current_error(self, expect: bool = False) -> Optional[AskarError]: - """ - Get the error result from the previous failed API method. - - Args: - expect: Return a default error message if none is found - """ - err_json = StrBuffer() - method = self._method("askar_get_current_error", (POINTER(c_char_p),)) - if not method(byref(err_json)): - try: - msg = json.loads(err_json.value) - except json.JSONDecodeError: - LOGGER.warning("JSON decode error for askar_get_current_error") - msg = None - if msg and "message" in msg and "code" in msg: - return AskarError( - AskarErrorCode(msg["code"]), msg["message"], msg.get("extra") - ) - if not expect: - return None - return AskarError(AskarErrorCode.WRAPPER, "Unknown error") - - def _wait_callbacks(self): - while self._callbacks: - time.sleep(0.01) - - def _cleanup(self): - if self._callbacks: - th = threading.Thread(target=self._wait_callbacks) - th.start() - th.join(timeout=1.0) - if th.is_alive(): - LOGGER.error( - "%s: Timed out waiting for callbacks to complete", - self.__class__.LIB_NAME, - ) - if self._cdll and self._dtor: - self._dtor() - self._dtor = None - self._cdll = None - - -LIB = Lib() - - -def get_library(init: bool = True) -> Lib: - """Return the library instance, loading it if necessary.""" - global LIB - if LIB and init: - LIB.load() - return LIB - - -def set_max_log_level(level: Union[str, int, None]): - """Set the maximum logging level.""" - get_library().set_max_log_level(level) - - -def invoke(name, argtypes, *args): - """Perform a synchronous library function call.""" - get_library().invoke(name, argtypes, *args) - - -def invoke_async(name: str, argtypes, *args, return_type=None) -> asyncio.Future: - """Perform an asynchronous library function call.""" - return get_library().invoke_async(name, argtypes, *args, return_type=return_type) - - -def invoke_dtor(name: str, *values, argtypes=None): - lib = get_library(False) - if lib: - method = lib._method(name, argtypes, restype=None) - method(*values) - - -def generate_raw_key(seed: Union[str, bytes] = None) -> str: - """Generate a new raw store wrapping key.""" - key = StrBuffer() - invoke( - "askar_store_generate_raw_key", - (FfiByteBuffer, POINTER(c_char_p)), - seed, - byref(key), - ) - return str(key) - - -def version() -> str: - """Get the version of the installed library.""" - return get_library().version() - - -async def store_open( - uri: str, key_method: str = None, pass_key: str = None, profile: str = None -) -> StoreHandle: - """Open an existing Store and return the open handle.""" - return await invoke_async( - "askar_store_open", - (FfiStr, FfiStr, FfiStr, FfiStr), - uri, - key_method and key_method.lower(), - pass_key, - profile, - return_type=StoreHandle, - ) - - -async def store_provision( - uri: str, - key_method: str = None, - pass_key: str = None, - profile: str = None, - recreate: bool = False, -) -> StoreHandle: - """Provision a new Store and return the open handle.""" - return await invoke_async( - "askar_store_provision", - (FfiStr, FfiStr, FfiStr, FfiStr, c_int8), - uri, - key_method and key_method.lower(), - pass_key, - profile, - recreate, - return_type=StoreHandle, - ) - - -async def store_create_profile(handle: StoreHandle, name: str = None) -> str: - """Create a new profile in a Store.""" - return str( - await invoke_async( - "askar_store_create_profile", - (StoreHandle, FfiStr), - handle, - name, - return_type=StrBuffer, - ) - ) - - -async def store_get_profile_name(handle: StoreHandle) -> str: - """Get the name of the default Store instance profile.""" - return str( - await invoke_async( - "askar_store_get_profile_name", - (StoreHandle,), - handle, - return_type=StrBuffer, - ) - ) - - -async def store_remove_profile(handle: StoreHandle, name: str) -> bool: - """Remove an existing profile from a Store.""" - return ( - await invoke_async( - "askar_store_remove_profile", - (StoreHandle, FfiStr), - handle, - name, - return_type=c_int8, - ) - != 0 - ) - - -async def store_rekey( - handle: StoreHandle, - key_method: str = None, - pass_key: str = None, -) -> StoreHandle: - """Replace the store key on a Store.""" - return await invoke_async( - "askar_store_rekey", - (StoreHandle, FfiStr, FfiStr), - handle, - key_method and key_method.lower(), - pass_key, - return_type=c_int8, - ) - - -async def store_remove(uri: str) -> bool: - """Remove an existing Store, if any.""" - return ( - await invoke_async( - "askar_store_remove", - (FfiStr,), - uri, - return_type=c_int8, - ) - != 0 - ) - - -async def session_start( - handle: StoreHandle, profile: Optional[str] = None, as_transaction: bool = False -) -> SessionHandle: - """Start a new session with an open Store.""" - return await invoke_async( - "askar_session_start", - (StoreHandle, FfiStr, c_int8), - handle, - profile, - as_transaction, - return_type=SessionHandle, - ) - - -async def session_count( - handle: SessionHandle, category: str, tag_filter: Union[str, dict] = None -) -> int: - """Count rows in the Store.""" - return int( - await invoke_async( - "askar_session_count", - (SessionHandle, FfiStr, FfiJson), - handle, - category, - tag_filter, - return_type=c_int64, - ) - ) - - -async def session_fetch( - handle: SessionHandle, category: str, name: str, for_update: bool = False -) -> EntryListHandle: - """Fetch a row from the Store.""" - return await invoke_async( - "askar_session_fetch", - (SessionHandle, FfiStr, FfiStr, c_int8), - handle, - category, - name, - for_update, - return_type=EntryListHandle, - ) - - -async def session_fetch_all( - handle: SessionHandle, - category: str, - tag_filter: Union[str, dict] = None, - limit: int = None, - for_update: bool = False, -) -> EntryListHandle: - """Fetch all matching rows in the Store.""" - return await invoke_async( - "askar_session_fetch_all", - (SessionHandle, FfiStr, FfiJson, c_int64, c_int8), - handle, - category, - tag_filter, - limit if limit is not None else -1, - for_update, - return_type=EntryListHandle, - ) - - -async def session_remove_all( - handle: SessionHandle, - category: str, - tag_filter: Union[str, dict] = None, -) -> int: - """Remove all matching rows in the Store.""" - return int( - await invoke_async( - "askar_session_remove_all", - (SessionHandle, FfiStr, FfiJson), - handle, - category, - tag_filter, - return_type=c_int64, - ) - ) - - -async def session_update( - handle: SessionHandle, - operation: EntryOperation, - category: str, - name: str, - value: Union[str, bytes] = None, - tags: dict = None, - expiry_ms: Optional[int] = None, -): - """Update a Store by inserting, updating, or removing a record.""" - - return await invoke_async( - "askar_session_update", - (SessionHandle, c_int8, FfiStr, FfiStr, FfiByteBuffer, FfiTagsJson, c_int64), - handle, - operation.value, - category, - name, - value, - tags, - -1 if expiry_ms is None else expiry_ms, - ) - - -async def session_insert_key( - handle: SessionHandle, - key_handle: LocalKeyHandle, - name: str, - metadata: str = None, - tags: dict = None, - expiry_ms: Optional[int] = None, -): - return await invoke_async( - "askar_session_insert_key", - (SessionHandle, LocalKeyHandle, FfiStr, FfiStr, FfiTagsJson, c_int64), - handle, - key_handle, - name, - metadata, - tags, - -1 if expiry_ms is None else expiry_ms, - ) - - -async def session_fetch_key( - handle: SessionHandle, name: str, for_update: bool = False -) -> KeyEntryListHandle: - return await invoke_async( - "askar_session_fetch_key", - (SessionHandle, FfiStr, c_int8), - handle, - name, - for_update, - return_type=KeyEntryListHandle, - ) - - -async def session_fetch_all_keys( - handle: SessionHandle, - alg: Union[str, KeyAlg] = None, - thumbprint: str = None, - tag_filter: Union[str, dict] = None, - limit: int = None, - for_update: bool = False, -) -> KeyEntryListHandle: - """Fetch all matching keys in the Store.""" - if isinstance(alg, KeyAlg): - alg = alg.value - return await invoke_async( - "askar_session_fetch_all_keys", - (SessionHandle, FfiStr, FfiStr, FfiJson, c_int64, c_int8), - handle, - alg, - thumbprint, - tag_filter, - limit if limit is not None else -1, - for_update, - return_type=KeyEntryListHandle, - ) - - -async def session_update_key( - handle: SessionHandle, - name: str, - metadata: str = None, - tags: dict = None, - expiry_ms: Optional[int] = None, -): - await invoke_async( - "askar_session_update_key", - (SessionHandle, FfiStr, FfiStr, FfiTagsJson, c_int64), - handle, - name, - metadata, - tags, - -1 if expiry_ms is None else expiry_ms, - ) - - -async def session_remove_key(handle: SessionHandle, name: str): - await invoke_async( - "askar_session_remove_key", - (SessionHandle, FfiStr), - handle, - name, - ) - - -async def scan_start( - handle: StoreHandle, - profile: Optional[str], - category: str, - tag_filter: Union[str, dict] = None, - offset: int = None, - limit: int = None, -) -> ScanHandle: - """Create a new Scan against the Store.""" - return await invoke_async( - "askar_scan_start", - (StoreHandle, FfiStr, FfiStr, FfiJson, c_int64, c_int64), - handle, - profile, - category, - tag_filter, - offset or 0, - limit if limit is not None else -1, - return_type=ScanHandle, - ) - - -async def scan_next(handle: ScanHandle) -> EntryListHandle: - return await invoke_async( - "askar_scan_next", (ScanHandle,), handle, return_type=EntryListHandle - ) - - -def entry_list_count(handle: EntryListHandle) -> int: - len = c_int32() - invoke( - "askar_entry_list_count", - (EntryListHandle, POINTER(c_int32)), - handle, - byref(len), - ) - return len.value - - -def key_entry_list_count(handle: KeyEntryListHandle) -> int: - len = c_int32() - invoke( - "askar_key_entry_list_count", - (KeyEntryListHandle, POINTER(c_int32)), - handle, - byref(len), - ) - return len.value - - -def key_generate(alg: Union[str, KeyAlg], ephemeral: bool = False) -> LocalKeyHandle: - handle = LocalKeyHandle() - if isinstance(alg, KeyAlg): - alg = alg.value - invoke( - "askar_key_generate", - (FfiStr, c_int8, POINTER(LocalKeyHandle)), - alg, - ephemeral, - byref(handle), - ) - return handle - - -def key_from_seed( - alg: Union[str, KeyAlg], - seed: Union[str, bytes, ByteBuffer], - method: Union[str, SeedMethod] = None, -) -> LocalKeyHandle: - handle = LocalKeyHandle() - if isinstance(alg, KeyAlg): - alg = alg.value - if isinstance(method, SeedMethod): - method = method.value - invoke( - "askar_key_from_seed", - (FfiStr, FfiByteBuffer, FfiStr, POINTER(LocalKeyHandle)), - alg, - seed, - method, - byref(handle), - ) - return handle - - -def key_from_public_bytes( - alg: Union[str, KeyAlg], public: Union[bytes, ByteBuffer] -) -> LocalKeyHandle: - handle = LocalKeyHandle() - if isinstance(alg, KeyAlg): - alg = alg.value - invoke( - "askar_key_from_public_bytes", - (FfiStr, FfiByteBuffer, POINTER(LocalKeyHandle)), - alg, - public, - byref(handle), - ) - return handle - - -def key_get_public_bytes(handle: LocalKeyHandle) -> ByteBuffer: - buf = ByteBuffer() - invoke( - "askar_key_get_public_bytes", - (LocalKeyHandle, POINTER(ByteBuffer)), - handle, - byref(buf), - ) - return buf - - -def key_from_secret_bytes( - alg: Union[str, KeyAlg], secret: Union[bytes, ByteBuffer] -) -> LocalKeyHandle: - handle = LocalKeyHandle() - if isinstance(alg, KeyAlg): - alg = alg.value - invoke( - "askar_key_from_secret_bytes", - (FfiStr, FfiByteBuffer, POINTER(LocalKeyHandle)), - alg, - secret, - byref(handle), - ) - return handle - - -def key_get_secret_bytes(handle: LocalKeyHandle) -> ByteBuffer: - buf = ByteBuffer() - invoke( - "askar_key_get_secret_bytes", - (LocalKeyHandle, POINTER(ByteBuffer)), - handle, - byref(buf), - ) - return buf - - -def key_from_jwk(jwk: Union[dict, str, bytes]) -> LocalKeyHandle: - handle = LocalKeyHandle() - if isinstance(jwk, dict): - jwk = json.dumps(jwk) - invoke( - "askar_key_from_jwk", - (FfiByteBuffer, POINTER(LocalKeyHandle)), - jwk, - byref(handle), - ) - return handle - - -def key_convert(handle: LocalKeyHandle, alg: Union[str, KeyAlg]) -> LocalKeyHandle: - key = LocalKeyHandle() - if isinstance(alg, KeyAlg): - alg = alg.value - invoke( - "askar_key_convert", - (LocalKeyHandle, FfiStr, POINTER(LocalKeyHandle)), - handle, - alg, - byref(key), - ) - return key - - -def key_exchange( - alg: Union[str, KeyAlg], sk_handle: LocalKeyHandle, pk_handle: LocalKeyHandle -) -> LocalKeyHandle: - key = LocalKeyHandle() - if isinstance(alg, KeyAlg): - alg = alg.value - invoke( - "askar_key_from_key_exchange", - (FfiStr, LocalKeyHandle, LocalKeyHandle, POINTER(LocalKeyHandle)), - alg, - sk_handle, - pk_handle, - byref(key), - ) - return key - - -def key_get_algorithm(handle: LocalKeyHandle) -> str: - alg = StrBuffer() - invoke( - "askar_key_get_algorithm", - (LocalKeyHandle, POINTER(c_char_p)), - handle, - byref(alg), - ) - return str(alg) - - -def key_get_ephemeral(handle: LocalKeyHandle) -> bool: - eph = c_int8() - invoke( - "askar_key_get_ephemeral", - (LocalKeyHandle, POINTER(c_int8)), - handle, - byref(eph), - ) - return eph.value != 0 - - -def key_get_jwk_public(handle: LocalKeyHandle, alg: Union[str, KeyAlg] = None) -> str: - jwk = StrBuffer() - if isinstance(alg, KeyAlg): - alg = alg.value - invoke( - "askar_key_get_jwk_public", - (LocalKeyHandle, FfiStr, POINTER(c_char_p)), - handle, - alg, - byref(jwk), - ) - return str(jwk) - - -def key_get_jwk_secret(handle: LocalKeyHandle) -> ByteBuffer: - sec = ByteBuffer() - invoke( - "askar_key_get_jwk_secret", - (LocalKeyHandle, POINTER(ByteBuffer)), - handle, - byref(sec), - ) - return sec - - -def key_get_jwk_thumbprint( - handle: LocalKeyHandle, alg: Union[str, KeyAlg] = None -) -> str: - thumb = StrBuffer() - if isinstance(alg, KeyAlg): - alg = alg.value - invoke( - "askar_key_get_jwk_thumbprint", - (LocalKeyHandle, FfiStr, POINTER(c_char_p)), - handle, - alg, - byref(thumb), - ) - return str(thumb) - - -def key_aead_get_params(handle: LocalKeyHandle) -> AeadParams: - params = AeadParams() - invoke( - "askar_key_aead_get_params", - (LocalKeyHandle, POINTER(AeadParams)), - handle, - byref(params), - ) - return params - - -def key_aead_random_nonce(handle: LocalKeyHandle) -> ByteBuffer: - nonce = ByteBuffer() - invoke( - "askar_key_aead_random_nonce", - (LocalKeyHandle, POINTER(ByteBuffer)), - handle, - byref(nonce), - ) - return nonce - - -def key_aead_encrypt( - handle: LocalKeyHandle, - input: Union[bytes, str, ByteBuffer], - nonce: Union[bytes, ByteBuffer], - aad: Optional[Union[bytes, ByteBuffer]], -) -> Encrypted: - enc = Encrypted() - invoke( - "askar_key_aead_encrypt", - ( - LocalKeyHandle, - FfiByteBuffer, - FfiByteBuffer, - FfiByteBuffer, - POINTER(Encrypted), - ), - handle, - input, - nonce, - aad, - byref(enc), - ) - return enc - - -def key_aead_decrypt( - handle: LocalKeyHandle, - ciphertext: Union[bytes, ByteBuffer, Encrypted], - nonce: Union[bytes, ByteBuffer], - tag: Optional[Union[bytes, ByteBuffer]], - aad: Optional[Union[bytes, ByteBuffer]], -) -> ByteBuffer: - dec = ByteBuffer() - if isinstance(ciphertext, Encrypted): - nonce = ciphertext.nonce - ciphertext = ciphertext.ciphertext_tag - invoke( - "askar_key_aead_decrypt", - ( - LocalKeyHandle, - FfiByteBuffer, - FfiByteBuffer, - FfiByteBuffer, - FfiByteBuffer, - POINTER(ByteBuffer), - ), - handle, - ciphertext, - nonce, - tag, - aad, - byref(dec), - ) - return dec - - -def key_sign_message( - handle: LocalKeyHandle, - message: Union[bytes, str, ByteBuffer], - sig_type: Optional[str], -) -> ByteBuffer: - sig = ByteBuffer() - invoke( - "askar_key_sign_message", - (LocalKeyHandle, FfiByteBuffer, FfiStr, POINTER(ByteBuffer)), - handle, - message, - sig_type, - byref(sig), - ) - return sig - - -def key_verify_signature( - handle: LocalKeyHandle, - message: Union[bytes, str, ByteBuffer], - signature: Union[bytes, ByteBuffer], - sig_type: Optional[str], -) -> bool: - verify = c_int8() - invoke( - "askar_key_verify_signature", - (LocalKeyHandle, FfiByteBuffer, FfiByteBuffer, FfiStr, POINTER(c_int8)), - handle, - message, - signature, - sig_type, - byref(verify), - ) - return verify.value != 0 - - -def key_wrap_key( - handle: LocalKeyHandle, - other: LocalKeyHandle, - nonce: Optional[Union[bytes, ByteBuffer]], -) -> Encrypted: - wrapped = Encrypted() - invoke( - "askar_key_wrap_key", - (LocalKeyHandle, LocalKeyHandle, FfiByteBuffer, POINTER(Encrypted)), - handle, - other, - nonce, - byref(wrapped), - ) - return wrapped - - -def key_unwrap_key( - handle: LocalKeyHandle, - alg: Union[str, KeyAlg], - ciphertext: Union[bytes, ByteBuffer, Encrypted], - nonce: Union[bytes, ByteBuffer], - tag: Optional[Union[bytes, ByteBuffer]], -) -> LocalKeyHandle: - result = LocalKeyHandle() - if isinstance(alg, KeyAlg): - alg = alg.value - if isinstance(ciphertext, Encrypted): - ciphertext = ciphertext.ciphertext_tag - invoke( - "askar_key_unwrap_key", - ( - LocalKeyHandle, - FfiStr, - FfiByteBuffer, - FfiByteBuffer, - FfiByteBuffer, - POINTER(LocalKeyHandle), - ), - handle, - alg, - ciphertext, - nonce, - tag, - byref(result), - ) - return result - - -def key_crypto_box_random_nonce() -> ByteBuffer: - buf = ByteBuffer() - invoke( - "askar_key_crypto_box_random_nonce", - (POINTER(ByteBuffer),), - byref(buf), - ) - return buf - - -def key_crypto_box( - recip_handle: LocalKeyHandle, - sender_handle: LocalKeyHandle, - message: Union[bytes, str, ByteBuffer], - nonce: Union[bytes, ByteBuffer], -) -> ByteBuffer: - buf = ByteBuffer() - invoke( - "askar_key_crypto_box", - ( - LocalKeyHandle, - LocalKeyHandle, - FfiByteBuffer, - FfiByteBuffer, - POINTER(ByteBuffer), - ), - recip_handle, - sender_handle, - message, - nonce, - byref(buf), - ) - return buf - - -def key_crypto_box_open( - recip_handle: LocalKeyHandle, - sender_handle: LocalKeyHandle, - message: Union[bytes, str, ByteBuffer], - nonce: Union[bytes, ByteBuffer], -) -> ByteBuffer: - buf = ByteBuffer() - invoke( - "askar_key_crypto_box_open", - ( - LocalKeyHandle, - LocalKeyHandle, - FfiByteBuffer, - FfiByteBuffer, - POINTER(ByteBuffer), - ), - recip_handle, - sender_handle, - message, - nonce, - byref(buf), - ) - return buf - - -def key_crypto_box_seal( - handle: LocalKeyHandle, - message: Union[bytes, str, ByteBuffer], -) -> ByteBuffer: - buf = ByteBuffer() - invoke( - "askar_key_crypto_box_seal", - (LocalKeyHandle, FfiByteBuffer, POINTER(ByteBuffer)), - handle, - message, - byref(buf), - ) - return buf - - -def key_crypto_box_seal_open( - handle: LocalKeyHandle, - ciphertext: Union[bytes, ByteBuffer], -) -> ByteBuffer: - buf = ByteBuffer() - invoke( - "askar_key_crypto_box_seal_open", - (LocalKeyHandle, FfiByteBuffer, POINTER(ByteBuffer)), - handle, - ciphertext, - byref(buf), - ) - return buf - - -def key_derive_ecdh_es( - key_alg: Union[str, KeyAlg], - ephem_key: LocalKeyHandle, - receiver_key: LocalKeyHandle, - alg_id: Union[bytes, str, ByteBuffer], - apu: Union[bytes, str, ByteBuffer], - apv: Union[bytes, str, ByteBuffer], - receive: bool, -) -> LocalKeyHandle: - key = LocalKeyHandle() - if isinstance(key_alg, KeyAlg): - key_alg = key_alg.value - invoke( - "askar_key_derive_ecdh_es", - ( - FfiStr, - LocalKeyHandle, - LocalKeyHandle, - FfiByteBuffer, - FfiByteBuffer, - FfiByteBuffer, - c_int8, - POINTER(LocalKeyHandle), - ), - key_alg, - ephem_key, - receiver_key, - alg_id, - apu, - apv, - receive, - byref(key), - ) - return key - - -def key_derive_ecdh_1pu( - key_alg: Union[str, KeyAlg], - ephem_key: LocalKeyHandle, - sender_key: LocalKeyHandle, - receiver_key: LocalKeyHandle, - alg_id: Union[bytes, str, ByteBuffer], - apu: Union[bytes, str, ByteBuffer], - apv: Union[bytes, str, ByteBuffer], - cc_tag: Optional[Union[bytes, ByteBuffer]], - receive: bool, -) -> LocalKeyHandle: - key = LocalKeyHandle() - if isinstance(key_alg, KeyAlg): - key_alg = key_alg.value - invoke( - "askar_key_derive_ecdh_1pu", - ( - FfiStr, - LocalKeyHandle, - LocalKeyHandle, - LocalKeyHandle, - FfiByteBuffer, - FfiByteBuffer, - FfiByteBuffer, - FfiByteBuffer, - c_int8, - POINTER(LocalKeyHandle), - ), - key_alg, - ephem_key, - sender_key, - receiver_key, - alg_id, - apu, - apv, - cc_tag, - receive, - byref(key), - ) - return key diff --git a/wrappers/python/aries_askar/bindings/__init__.py b/wrappers/python/aries_askar/bindings/__init__.py new file mode 100644 index 00000000..4b6f6684 --- /dev/null +++ b/wrappers/python/aries_askar/bindings/__init__.py @@ -0,0 +1,948 @@ +"""Low-level interaction with the aries-askar library.""" + +import asyncio +import json +import logging +import sys + +from ctypes import POINTER, byref, c_char_p, c_int8, c_int32, c_int64 +from typing import Optional, Union + +from ..types import EntryOperation, KeyAlg, SeedMethod + +from .lib import ( + AeadParams, + ByteBuffer, + Encrypted, + FfiByteBuffer, + FfiJson, + FfiStr, + FfiTagsJson, + Lib, + StrBuffer, +) +from .handle import ( + EntryListHandle, + KeyEntryListHandle, + LocalKeyHandle, + ScanHandle, + SessionHandle, + StoreHandle, +) + + +LIB = Lib() +LOGGER = logging.getLogger(__name__) +MODULE_NAME = __name__.split(".")[0] + + +def get_library(init: bool = True) -> Lib: + """Return the library instance, loading it if necessary.""" + global LIB + if LIB and init: + # preload library - required to create handle instances + LIB.loaded + return LIB + + +def set_max_log_level(level: Union[str, int, None]): + """Set the maximum logging level.""" + get_library().set_max_log_level(level) + + +def invoke(name, argtypes, *args): + """Perform a synchronous library function call.""" + get_library().invoke(name, argtypes, *args) + + +def invoke_async(name: str, argtypes, *args, return_type=None) -> asyncio.Future: + """Perform an asynchronous library function call.""" + return get_library().invoke_async(name, argtypes, *args, return_type=return_type) + + +def generate_raw_key(seed: Union[str, bytes] = None) -> str: + """Generate a new raw store wrapping key.""" + key = StrBuffer() + invoke( + "askar_store_generate_raw_key", + (FfiByteBuffer, POINTER(StrBuffer)), + seed, + byref(key), + ) + return str(key) + + +def version() -> str: + """Get the version of the installed library.""" + return get_library().version() + + +async def store_open( + uri: str, key_method: str = None, pass_key: str = None, profile: str = None +) -> StoreHandle: + """Open an existing Store and return the open handle.""" + return await invoke_async( + "askar_store_open", + (FfiStr, FfiStr, FfiStr, FfiStr), + uri, + key_method and key_method.lower(), + pass_key, + profile, + return_type=StoreHandle, + ) + + +async def store_provision( + uri: str, + key_method: str = None, + pass_key: str = None, + profile: str = None, + recreate: bool = False, +) -> StoreHandle: + """Provision a new Store and return the open handle.""" + return await invoke_async( + "askar_store_provision", + (FfiStr, FfiStr, FfiStr, FfiStr, c_int8), + uri, + key_method and key_method.lower(), + pass_key, + profile, + recreate, + return_type=StoreHandle, + ) + + +async def store_create_profile(handle: StoreHandle, name: str = None) -> str: + """Create a new profile in a Store.""" + return str( + await invoke_async( + "askar_store_create_profile", + (StoreHandle, FfiStr), + handle, + name, + return_type=StrBuffer, + ) + ) + + +async def store_get_profile_name(handle: StoreHandle) -> str: + """Get the name of the default Store instance profile.""" + return str( + await invoke_async( + "askar_store_get_profile_name", + (StoreHandle,), + handle, + return_type=StrBuffer, + ) + ) + + +async def store_remove_profile(handle: StoreHandle, name: str) -> bool: + """Remove an existing profile from a Store.""" + return ( + await invoke_async( + "askar_store_remove_profile", + (StoreHandle, FfiStr), + handle, + name, + return_type=c_int8, + ) + != 0 + ) + + +async def store_rekey( + handle: StoreHandle, + key_method: str = None, + pass_key: str = None, +) -> StoreHandle: + """Replace the store key on a Store.""" + return await invoke_async( + "askar_store_rekey", + (StoreHandle, FfiStr, FfiStr), + handle, + key_method and key_method.lower(), + pass_key, + return_type=c_int8, + ) + + +async def store_remove(uri: str) -> bool: + """Remove an existing Store, if any.""" + return ( + await invoke_async( + "askar_store_remove", + (FfiStr,), + uri, + return_type=c_int8, + ) + != 0 + ) + + +async def session_start( + handle: StoreHandle, profile: Optional[str] = None, as_transaction: bool = False +) -> SessionHandle: + """Start a new session with an open Store.""" + handle = await invoke_async( + "askar_session_start", + (StoreHandle, FfiStr, c_int8), + handle, + profile, + as_transaction, + return_type=SessionHandle, + ) + return handle + + +async def session_count( + handle: SessionHandle, category: str, tag_filter: Union[str, dict] = None +) -> int: + """Count rows in the Store.""" + return int( + await invoke_async( + "askar_session_count", + (SessionHandle, FfiStr, FfiJson), + handle, + category, + tag_filter, + return_type=c_int64, + ) + ) + + +async def session_fetch( + handle: SessionHandle, category: str, name: str, for_update: bool = False +) -> EntryListHandle: + """Fetch a row from the Store.""" + return await invoke_async( + "askar_session_fetch", + (SessionHandle, FfiStr, FfiStr, c_int8), + handle, + category, + name, + for_update, + return_type=EntryListHandle, + ) + + +async def session_fetch_all( + handle: SessionHandle, + category: str, + tag_filter: Union[str, dict] = None, + limit: int = None, + for_update: bool = False, +) -> EntryListHandle: + """Fetch all matching rows in the Store.""" + return await invoke_async( + "askar_session_fetch_all", + (SessionHandle, FfiStr, FfiJson, c_int64, c_int8), + handle, + category, + tag_filter, + limit if limit is not None else -1, + for_update, + return_type=EntryListHandle, + ) + + +async def session_remove_all( + handle: SessionHandle, + category: str, + tag_filter: Union[str, dict] = None, +) -> int: + """Remove all matching rows in the Store.""" + return int( + await invoke_async( + "askar_session_remove_all", + (SessionHandle, FfiStr, FfiJson), + handle, + category, + tag_filter, + return_type=c_int64, + ) + ) + + +async def session_update( + handle: SessionHandle, + operation: EntryOperation, + category: str, + name: str, + value: Union[str, bytes] = None, + tags: dict = None, + expiry_ms: Optional[int] = None, +): + """Update a Store by inserting, updating, or removing a record.""" + return await invoke_async( + "askar_session_update", + (SessionHandle, c_int8, FfiStr, FfiStr, FfiByteBuffer, FfiTagsJson, c_int64), + handle, + operation.value, + category, + name, + value, + tags, + -1 if expiry_ms is None else expiry_ms, + ) + + +async def session_insert_key( + handle: SessionHandle, + key_handle: LocalKeyHandle, + name: str, + metadata: str = None, + tags: dict = None, + expiry_ms: Optional[int] = None, +): + return await invoke_async( + "askar_session_insert_key", + (SessionHandle, LocalKeyHandle, FfiStr, FfiStr, FfiTagsJson, c_int64), + handle, + key_handle, + name, + metadata, + tags, + -1 if expiry_ms is None else expiry_ms, + ) + + +async def session_fetch_key( + handle: SessionHandle, name: str, for_update: bool = False +) -> KeyEntryListHandle: + return await invoke_async( + "askar_session_fetch_key", + (SessionHandle, FfiStr, c_int8), + handle, + name, + for_update, + return_type=KeyEntryListHandle, + ) + + +async def session_fetch_all_keys( + handle: SessionHandle, + alg: Union[str, KeyAlg] = None, + thumbprint: str = None, + tag_filter: Union[str, dict] = None, + limit: int = None, + for_update: bool = False, +) -> KeyEntryListHandle: + """Fetch all matching keys in the Store.""" + if isinstance(alg, KeyAlg): + alg = alg.value + return await invoke_async( + "askar_session_fetch_all_keys", + (SessionHandle, FfiStr, FfiStr, FfiJson, c_int64, c_int8), + handle, + alg, + thumbprint, + tag_filter, + limit if limit is not None else -1, + for_update, + return_type=KeyEntryListHandle, + ) + + +async def session_update_key( + handle: SessionHandle, + name: str, + metadata: str = None, + tags: dict = None, + expiry_ms: Optional[int] = None, +): + await invoke_async( + "askar_session_update_key", + (SessionHandle, FfiStr, FfiStr, FfiTagsJson, c_int64), + handle, + name, + metadata, + tags, + -1 if expiry_ms is None else expiry_ms, + ) + + +async def session_remove_key(handle: SessionHandle, name: str): + await invoke_async( + "askar_session_remove_key", + (SessionHandle, FfiStr), + handle, + name, + ) + + +async def scan_start( + handle: StoreHandle, + profile: Optional[str], + category: str, + tag_filter: Union[str, dict] = None, + offset: int = None, + limit: int = None, +) -> ScanHandle: + """Create a new Scan against the Store.""" + return await invoke_async( + "askar_scan_start", + (StoreHandle, FfiStr, FfiStr, FfiJson, c_int64, c_int64), + handle, + profile, + category, + tag_filter, + offset or 0, + limit if limit is not None else -1, + return_type=ScanHandle, + ) + + +async def scan_next(handle: ScanHandle) -> EntryListHandle: + return await invoke_async( + "askar_scan_next", (ScanHandle,), handle, return_type=EntryListHandle + ) + + +def entry_list_count(handle: EntryListHandle) -> int: + len = c_int32() + invoke( + "askar_entry_list_count", + (EntryListHandle, POINTER(c_int32)), + handle, + byref(len), + ) + return len.value + + +def key_entry_list_count(handle: KeyEntryListHandle) -> int: + len = c_int32() + invoke( + "askar_key_entry_list_count", + (KeyEntryListHandle, POINTER(c_int32)), + handle, + byref(len), + ) + return len.value + + +def key_generate(alg: Union[str, KeyAlg], ephemeral: bool = False) -> LocalKeyHandle: + handle = LocalKeyHandle() + if isinstance(alg, KeyAlg): + alg = alg.value + invoke( + "askar_key_generate", + (FfiStr, c_int8, POINTER(LocalKeyHandle)), + alg, + ephemeral, + byref(handle), + ) + return handle + + +def key_from_seed( + alg: Union[str, KeyAlg], + seed: Union[str, bytes, ByteBuffer], + method: Union[str, SeedMethod] = None, +) -> LocalKeyHandle: + handle = LocalKeyHandle() + if isinstance(alg, KeyAlg): + alg = alg.value + if isinstance(method, SeedMethod): + method = method.value + invoke( + "askar_key_from_seed", + (FfiStr, FfiByteBuffer, FfiStr, POINTER(LocalKeyHandle)), + alg, + seed, + method, + byref(handle), + ) + return handle + + +def key_from_public_bytes( + alg: Union[str, KeyAlg], public: Union[bytes, ByteBuffer] +) -> LocalKeyHandle: + handle = LocalKeyHandle() + if isinstance(alg, KeyAlg): + alg = alg.value + invoke( + "askar_key_from_public_bytes", + (FfiStr, FfiByteBuffer, POINTER(LocalKeyHandle)), + alg, + public, + byref(handle), + ) + return handle + + +def key_get_public_bytes(handle: LocalKeyHandle) -> ByteBuffer: + buf = ByteBuffer() + invoke( + "askar_key_get_public_bytes", + (LocalKeyHandle, POINTER(ByteBuffer)), + handle, + byref(buf), + ) + return buf + + +def key_from_secret_bytes( + alg: Union[str, KeyAlg], secret: Union[bytes, ByteBuffer] +) -> LocalKeyHandle: + handle = LocalKeyHandle() + if isinstance(alg, KeyAlg): + alg = alg.value + invoke( + "askar_key_from_secret_bytes", + (FfiStr, FfiByteBuffer, POINTER(LocalKeyHandle)), + alg, + secret, + byref(handle), + ) + return handle + + +def key_get_secret_bytes(handle: LocalKeyHandle) -> ByteBuffer: + buf = ByteBuffer() + invoke( + "askar_key_get_secret_bytes", + (LocalKeyHandle, POINTER(ByteBuffer)), + handle, + byref(buf), + ) + return buf + + +def key_from_jwk(jwk: Union[dict, str, bytes]) -> LocalKeyHandle: + handle = LocalKeyHandle() + if isinstance(jwk, dict): + jwk = json.dumps(jwk) + invoke( + "askar_key_from_jwk", + (FfiByteBuffer, POINTER(LocalKeyHandle)), + jwk, + byref(handle), + ) + return handle + + +def key_convert(handle: LocalKeyHandle, alg: Union[str, KeyAlg]) -> LocalKeyHandle: + key = LocalKeyHandle() + if isinstance(alg, KeyAlg): + alg = alg.value + invoke( + "askar_key_convert", + (LocalKeyHandle, FfiStr, POINTER(LocalKeyHandle)), + handle, + alg, + byref(key), + ) + return key + + +def key_exchange( + alg: Union[str, KeyAlg], sk_handle: LocalKeyHandle, pk_handle: LocalKeyHandle +) -> LocalKeyHandle: + key = LocalKeyHandle() + if isinstance(alg, KeyAlg): + alg = alg.value + invoke( + "askar_key_from_key_exchange", + (FfiStr, LocalKeyHandle, LocalKeyHandle, POINTER(LocalKeyHandle)), + alg, + sk_handle, + pk_handle, + byref(key), + ) + return key + + +def key_get_algorithm(handle: LocalKeyHandle) -> str: + alg = StrBuffer() + invoke( + "askar_key_get_algorithm", + (LocalKeyHandle, POINTER(StrBuffer)), + handle, + byref(alg), + ) + return str(alg) + + +def key_get_ephemeral(handle: LocalKeyHandle) -> bool: + eph = c_int8() + invoke( + "askar_key_get_ephemeral", + (LocalKeyHandle, POINTER(c_int8)), + handle, + byref(eph), + ) + return eph.value != 0 + + +def key_get_jwk_public(handle: LocalKeyHandle, alg: Union[str, KeyAlg] = None) -> str: + jwk = StrBuffer() + if isinstance(alg, KeyAlg): + alg = alg.value + invoke( + "askar_key_get_jwk_public", + (LocalKeyHandle, FfiStr, POINTER(StrBuffer)), + handle, + alg, + byref(jwk), + ) + return str(jwk) + + +def key_get_jwk_secret(handle: LocalKeyHandle) -> ByteBuffer: + sec = ByteBuffer() + invoke( + "askar_key_get_jwk_secret", + (LocalKeyHandle, POINTER(ByteBuffer)), + handle, + byref(sec), + ) + return sec + + +def key_get_jwk_thumbprint( + handle: LocalKeyHandle, alg: Union[str, KeyAlg] = None +) -> str: + thumb = StrBuffer() + if isinstance(alg, KeyAlg): + alg = alg.value + invoke( + "askar_key_get_jwk_thumbprint", + (LocalKeyHandle, FfiStr, POINTER(StrBuffer)), + handle, + alg, + byref(thumb), + ) + return str(thumb) + + +def key_aead_get_params(handle: LocalKeyHandle) -> AeadParams: + params = AeadParams() + invoke( + "askar_key_aead_get_params", + (LocalKeyHandle, POINTER(AeadParams)), + handle, + byref(params), + ) + return params + + +def key_aead_random_nonce(handle: LocalKeyHandle) -> ByteBuffer: + nonce = ByteBuffer() + invoke( + "askar_key_aead_random_nonce", + (LocalKeyHandle, POINTER(ByteBuffer)), + handle, + byref(nonce), + ) + return nonce + + +def key_aead_encrypt( + handle: LocalKeyHandle, + input: Union[bytes, str, ByteBuffer], + nonce: Union[bytes, ByteBuffer], + aad: Optional[Union[bytes, ByteBuffer]], +) -> Encrypted: + enc = Encrypted() + invoke( + "askar_key_aead_encrypt", + ( + LocalKeyHandle, + FfiByteBuffer, + FfiByteBuffer, + FfiByteBuffer, + POINTER(Encrypted), + ), + handle, + input, + nonce, + aad, + byref(enc), + ) + return enc + + +def key_aead_decrypt( + handle: LocalKeyHandle, + ciphertext: Union[bytes, ByteBuffer, Encrypted], + nonce: Union[bytes, ByteBuffer], + tag: Optional[Union[bytes, ByteBuffer]], + aad: Optional[Union[bytes, ByteBuffer]], +) -> ByteBuffer: + dec = ByteBuffer() + if isinstance(ciphertext, Encrypted): + nonce = ciphertext.nonce + ciphertext = ciphertext.ciphertext_tag + invoke( + "askar_key_aead_decrypt", + ( + LocalKeyHandle, + FfiByteBuffer, + FfiByteBuffer, + FfiByteBuffer, + FfiByteBuffer, + POINTER(ByteBuffer), + ), + handle, + ciphertext, + nonce, + tag, + aad, + byref(dec), + ) + return dec + + +def key_sign_message( + handle: LocalKeyHandle, + message: Union[bytes, str, ByteBuffer], + sig_type: Optional[str], +) -> ByteBuffer: + sig = ByteBuffer() + invoke( + "askar_key_sign_message", + (LocalKeyHandle, FfiByteBuffer, FfiStr, POINTER(ByteBuffer)), + handle, + message, + sig_type, + byref(sig), + ) + return sig + + +def key_verify_signature( + handle: LocalKeyHandle, + message: Union[bytes, str, ByteBuffer], + signature: Union[bytes, ByteBuffer], + sig_type: Optional[str], +) -> bool: + verify = c_int8() + invoke( + "askar_key_verify_signature", + (LocalKeyHandle, FfiByteBuffer, FfiByteBuffer, FfiStr, POINTER(c_int8)), + handle, + message, + signature, + sig_type, + byref(verify), + ) + return verify.value != 0 + + +def key_wrap_key( + handle: LocalKeyHandle, + other: LocalKeyHandle, + nonce: Optional[Union[bytes, ByteBuffer]], +) -> Encrypted: + wrapped = Encrypted() + invoke( + "askar_key_wrap_key", + (LocalKeyHandle, LocalKeyHandle, FfiByteBuffer, POINTER(Encrypted)), + handle, + other, + nonce, + byref(wrapped), + ) + return wrapped + + +def key_unwrap_key( + handle: LocalKeyHandle, + alg: Union[str, KeyAlg], + ciphertext: Union[bytes, ByteBuffer, Encrypted], + nonce: Union[bytes, ByteBuffer], + tag: Optional[Union[bytes, ByteBuffer]], +) -> LocalKeyHandle: + result = LocalKeyHandle() + if isinstance(alg, KeyAlg): + alg = alg.value + if isinstance(ciphertext, Encrypted): + ciphertext = ciphertext.ciphertext_tag + invoke( + "askar_key_unwrap_key", + ( + LocalKeyHandle, + FfiStr, + FfiByteBuffer, + FfiByteBuffer, + FfiByteBuffer, + POINTER(LocalKeyHandle), + ), + handle, + alg, + ciphertext, + nonce, + tag, + byref(result), + ) + return result + + +def key_crypto_box_random_nonce() -> ByteBuffer: + buf = ByteBuffer() + invoke( + "askar_key_crypto_box_random_nonce", + (POINTER(ByteBuffer),), + byref(buf), + ) + return buf + + +def key_crypto_box( + recip_handle: LocalKeyHandle, + sender_handle: LocalKeyHandle, + message: Union[bytes, str, ByteBuffer], + nonce: Union[bytes, ByteBuffer], +) -> ByteBuffer: + buf = ByteBuffer() + invoke( + "askar_key_crypto_box", + ( + LocalKeyHandle, + LocalKeyHandle, + FfiByteBuffer, + FfiByteBuffer, + POINTER(ByteBuffer), + ), + recip_handle, + sender_handle, + message, + nonce, + byref(buf), + ) + return buf + + +def key_crypto_box_open( + recip_handle: LocalKeyHandle, + sender_handle: LocalKeyHandle, + message: Union[bytes, str, ByteBuffer], + nonce: Union[bytes, ByteBuffer], +) -> ByteBuffer: + buf = ByteBuffer() + invoke( + "askar_key_crypto_box_open", + ( + LocalKeyHandle, + LocalKeyHandle, + FfiByteBuffer, + FfiByteBuffer, + POINTER(ByteBuffer), + ), + recip_handle, + sender_handle, + message, + nonce, + byref(buf), + ) + return buf + + +def key_crypto_box_seal( + handle: LocalKeyHandle, + message: Union[bytes, str, ByteBuffer], +) -> ByteBuffer: + buf = ByteBuffer() + invoke( + "askar_key_crypto_box_seal", + (LocalKeyHandle, FfiByteBuffer, POINTER(ByteBuffer)), + handle, + message, + byref(buf), + ) + return buf + + +def key_crypto_box_seal_open( + handle: LocalKeyHandle, + ciphertext: Union[bytes, ByteBuffer], +) -> ByteBuffer: + buf = ByteBuffer() + invoke( + "askar_key_crypto_box_seal_open", + (LocalKeyHandle, FfiByteBuffer, POINTER(ByteBuffer)), + handle, + ciphertext, + byref(buf), + ) + return buf + + +def key_derive_ecdh_es( + key_alg: Union[str, KeyAlg], + ephem_key: LocalKeyHandle, + receiver_key: LocalKeyHandle, + alg_id: Union[bytes, str, ByteBuffer], + apu: Union[bytes, str, ByteBuffer], + apv: Union[bytes, str, ByteBuffer], + receive: bool, +) -> LocalKeyHandle: + key = LocalKeyHandle() + if isinstance(key_alg, KeyAlg): + key_alg = key_alg.value + invoke( + "askar_key_derive_ecdh_es", + ( + FfiStr, + LocalKeyHandle, + LocalKeyHandle, + FfiByteBuffer, + FfiByteBuffer, + FfiByteBuffer, + c_int8, + POINTER(LocalKeyHandle), + ), + key_alg, + ephem_key, + receiver_key, + alg_id, + apu, + apv, + receive, + byref(key), + ) + return key + + +def key_derive_ecdh_1pu( + key_alg: Union[str, KeyAlg], + ephem_key: LocalKeyHandle, + sender_key: LocalKeyHandle, + receiver_key: LocalKeyHandle, + alg_id: Union[bytes, str, ByteBuffer], + apu: Union[bytes, str, ByteBuffer], + apv: Union[bytes, str, ByteBuffer], + cc_tag: Optional[Union[bytes, ByteBuffer]], + receive: bool, +) -> LocalKeyHandle: + key = LocalKeyHandle() + if isinstance(key_alg, KeyAlg): + key_alg = key_alg.value + invoke( + "askar_key_derive_ecdh_1pu", + ( + FfiStr, + LocalKeyHandle, + LocalKeyHandle, + LocalKeyHandle, + FfiByteBuffer, + FfiByteBuffer, + FfiByteBuffer, + FfiByteBuffer, + c_int8, + POINTER(LocalKeyHandle), + ), + key_alg, + ephem_key, + sender_key, + receiver_key, + alg_id, + apu, + apv, + cc_tag, + receive, + byref(key), + ) + return key diff --git a/wrappers/python/aries_askar/bindings/handle.py b/wrappers/python/aries_askar/bindings/handle.py new file mode 100644 index 00000000..8f6b0c25 --- /dev/null +++ b/wrappers/python/aries_askar/bindings/handle.py @@ -0,0 +1,248 @@ +"""Handles for allocated resources.""" + +import json +import logging + +from ctypes import ( + POINTER, + Structure, + byref, + c_int8, + c_int32, + c_int64, + c_size_t, + c_void_p, +) + +from .lib import ByteBuffer, Lib, StrBuffer, finalize_struct + + +LOGGER = logging.getLogger(__name__) + + +class ArcHandle(Structure): + """Base class for handle instances.""" + + _fields_ = [ + ("value", c_size_t), + ] + _dtor_: str = None + + def __init__(self, value=0): + """Initializer.""" + if isinstance(value, c_size_t): + value = value.value + if not isinstance(value, int): + raise ValueError("Invalid handle") + super().__init__(value=value) + finalize_struct(self, c_size_t) + + @classmethod + def from_param(cls, param): + """Create from an input to a library method invocation.""" + if isinstance(param, cls): + return param + return cls(param) + + def __bool__(self) -> bool: + """Convert to a boolean value.""" + return bool(self.value) + + def __repr__(self) -> str: + """Format handle as a string.""" + return f"{self.__class__.__name__}({self.value})" + + @classmethod + def _cleanup(cls, value: c_size_t): + """Destructor.""" + if cls._dtor_: + Lib().invoke_dtor(cls._dtor_, value) + + +class StoreHandle(ArcHandle): + """Handle for an active Store instance.""" + + async def close(self): + """Manually close the store, waiting for any active connections.""" + if self.value: + await Lib().invoke_async("askar_store_close", (c_size_t,), self.value) + self.value = 0 + + @classmethod + def _cleanup(cls, value: c_size_t): + """Close the store when there are no more references to this object.""" + Lib().invoke_dtor( + "askar_store_close", + value, + None, + 0, + argtypes=(c_size_t, c_void_p, c_int64), + restype=c_int64, + ) + + +class SessionHandle(ArcHandle): + """Handle for an active Session/Transaction instance.""" + + async def close(self, commit: bool = False): + """Manually close the session.""" + if self.value: + await Lib().invoke_async( + "askar_session_close", + (c_size_t, c_int8), + self.value, + commit, + ) + self.value = 0 + + @classmethod + def _cleanup(cls, value: c_size_t): + """Close the session when there are no more references to this object.""" + Lib().invoke_dtor( + "askar_session_close", + value, + 0, + None, + 0, + argtypes=(c_size_t, c_int8, c_void_p, c_int64), + restype=c_int64, + ) + + +class ScanHandle(ArcHandle): + """Handle for an active Store scan instance.""" + + _dtor_ = "askar_scan_free" + + +class EntryListHandle(ArcHandle): + """Handle for an active EntryList instance.""" + + _dtor_ = "askar_entry_list_free" + + def get_category(self, index: int) -> str: + """Get the entry category.""" + cat = StrBuffer() + Lib().invoke( + "askar_entry_list_get_category", + (EntryListHandle, c_int32, POINTER(StrBuffer)), + self, + index, + byref(cat), + ) + return str(cat) + + def get_name(self, index: int) -> str: + """Get the entry name.""" + name = StrBuffer() + Lib().invoke( + "askar_entry_list_get_name", + (EntryListHandle, c_int32, POINTER(StrBuffer)), + self, + index, + byref(name), + ) + return str(name) + + def get_value(self, index: int) -> memoryview: + """Get the entry value.""" + val = ByteBuffer() + Lib().invoke( + "askar_entry_list_get_value", + (EntryListHandle, c_int32, POINTER(ByteBuffer)), + self, + index, + byref(val), + ) + return val.view + + def get_tags(self, index: int) -> dict: + """Get the entry tags.""" + tags = StrBuffer() + Lib().invoke( + "askar_entry_list_get_tags", + (EntryListHandle, c_int32, POINTER(StrBuffer)), + self, + index, + byref(tags), + ) + if tags: + tags = json.loads(tags.value) + for t in tags: + if isinstance(tags[t], list): + tags[t] = set(tags[t]) + else: + tags = dict() + return tags + + +class KeyEntryListHandle(ArcHandle): + """Handle for an active KeyEntryList instance.""" + + _dtor_ = "askar_key_entry_list_free" + + def get_algorithm(self, index: int) -> str: + """Get the key algorithm.""" + name = StrBuffer() + Lib().invoke( + "askar_key_entry_list_get_algorithm", + (KeyEntryListHandle, c_int32, POINTER(StrBuffer)), + self, + index, + byref(name), + ) + return str(name) + + def get_name(self, index: int) -> str: + """Get the key name.""" + name = StrBuffer() + Lib().invoke( + "askar_key_entry_list_get_name", + (KeyEntryListHandle, c_int32, POINTER(StrBuffer)), + self, + index, + byref(name), + ) + return str(name) + + def get_metadata(self, index: int) -> str: + """Get for the key metadata.""" + metadata = StrBuffer() + Lib().invoke( + "askar_key_entry_list_get_metadata", + (KeyEntryListHandle, c_int32, POINTER(StrBuffer)), + self, + index, + byref(metadata), + ) + return str(metadata) + + def get_tags(self, index: int) -> dict: + """Get the key tags.""" + tags = StrBuffer() + Lib().invoke( + "askar_key_entry_list_get_tags", + (KeyEntryListHandle, c_int32, POINTER(StrBuffer)), + self, + index, + byref(tags), + ) + return json.loads(tags.value) if tags else None + + def load_key(self, index: int) -> "LocalKeyHandle": + """Load the key instance.""" + handle = LocalKeyHandle() + Lib().invoke( + "askar_key_entry_list_load_local", + (KeyEntryListHandle, c_int32, POINTER(LocalKeyHandle)), + self, + index, + byref(handle), + ) + return handle + + +class LocalKeyHandle(ArcHandle): + """Handle for an active LocalKey instance.""" + + _dtor_ = "askar_key_free" diff --git a/wrappers/python/aries_askar/bindings/lib.py b/wrappers/python/aries_askar/bindings/lib.py new file mode 100644 index 00000000..8eb64aab --- /dev/null +++ b/wrappers/python/aries_askar/bindings/lib.py @@ -0,0 +1,672 @@ +"""Library instance and allocated buffer handling.""" + +import asyncio +import json +import itertools +import logging +import os +import sys +import threading +import time + +from ctypes import ( + Array, + CDLL, + CFUNCTYPE, + POINTER, + Structure, + addressof, + byref, + cast, + c_char, + c_char_p, + c_int8, + c_int32, + c_int64, + c_ubyte, + c_void_p, +) +from ctypes.util import find_library +from typing import Callable, Optional, Tuple, Union +from weakref import finalize, ref + +from ..error import AskarError, AskarErrorCode + + +LOGGER = logging.getLogger(__name__) +MODULE_NAME = __name__.split(".")[0] + +LOG_LEVELS = { + 1: logging.ERROR, + 2: logging.WARNING, + 3: logging.INFO, + 4: logging.DEBUG, +} + + +def _convert_log_level(level: Union[str, int, None]): + if level is None or level == "-1": + return -1 + else: + if isinstance(level, str): + level = level.upper() + name = logging.getLevelName(level) + for k, v in LOG_LEVELS.items(): + if logging.getLevelName(v) == name: + return k + return 0 + + +def _load_method_arguments(name, argtypes, args): + """Preload argument values to avoid freeing any intermediate data.""" + if not argtypes: + return args + if len(args) != len(argtypes): + raise ValueError(f"{name}: Arguments length does not match argtypes length") + return [ + arg if hasattr(argtype, "_type_") else argtype.from_param(arg) + for (arg, argtype) in zip(args, argtypes) + ] + + +def _struct_dtor(ctype: type, address: int, dtor: Callable): + value = ctype.from_address(address) + if value: + dtor(value) + + +def finalize_struct(instance, ctype): + finalize( + instance, _struct_dtor, ctype, addressof(instance), instance.__class__._cleanup + ) + + +class LibLoad: + def __init__(self, lib_name: str): + """Load the CDLL library. + + The python module directory is searched first, followed by the usual + library resolution for the current system. + """ + self._cdll = None + self._callbacks = {} + self._cb_id = itertools.count(0) + self._cfuncs = {} + self._lib_name = lib_name + self._log_cb = None + self._log_enabled_cb = None + self._methods = {} + + self._load_library() + self._init_logger() + + def _load_library(self): + lib_name = self._lib_name + lib_prefix_mapping = {"win32": ""} + lib_suffix_mapping = {"darwin": ".dylib", "win32": ".dll"} + try: + os_name = sys.platform + lib_prefix = lib_prefix_mapping.get(os_name, "lib") + lib_suffix = lib_suffix_mapping.get(os_name, ".so") + lib_path = os.path.join( + os.path.dirname(__file__), "..", f"{lib_prefix}{lib_name}{lib_suffix}" + ) + self._cdll = CDLL(lib_path) + return + except KeyError: + LOGGER.debug("Unknown platform for shared library") + except OSError: + LOGGER.warning("Library not loaded from python package") + + lib_path = find_library(lib_name) + if not lib_path: + raise AskarError( + AskarErrorCode.WRAPPER, f"Library not found in path: {lib_name}" + ) + try: + self._cdll = CDLL(lib_path) + except OSError as e: + raise AskarError( + AskarErrorCode.WRAPPER, f"Error loading library: {lib_path}" + ) from e + + def _init_logger(self): + if self._log_cb: + return + + logger = logging.getLogger(MODULE_NAME) + if logging.getLevelName("TRACE") == "Level TRACE": + # avoid redefining TRACE if another library has added it + logging.addLevelName(5, "TRACE") + + self._log_cb_t = CFUNCTYPE( + None, c_void_p, c_int32, c_char_p, c_char_p, c_char_p, c_char_p, c_int32 + ) + + def _log_cb( + _context, + level: int, + target: c_char_p, + message: c_char_p, + _module_path: c_char_p, + file_name: c_char_p, + line: int, + ): + logger.getChild("native." + target.decode().replace("::", ".")).log( + LOG_LEVELS.get(level, level), + "\t%s:%d | %s", + file_name.decode() if file_name else None, + line, + message.decode(), + ) + + self._log_cb = self._log_cb_t(_log_cb) + + self._log_enabled_cb_t = CFUNCTYPE(c_int8, c_void_p, c_int32) + + def _enabled_cb(_context, level: int) -> bool: + return self._cdll and logger.isEnabledFor(LOG_LEVELS.get(level, level)) + + self._log_enabled_cb = self._log_enabled_cb_t(_enabled_cb) + + if os.getenv("RUST_LOG"): + # level from environment + level = -1 + else: + # inherit current level from logger + level = _convert_log_level(logger.level or logger.parent.level) + + set_logger = self.method( + "askar_set_custom_logger", + (c_void_p, c_void_p, c_void_p, c_void_p, c_int32), + restype=c_int64, + ) + if set_logger( + None, # context + self._log_cb, + self._log_enabled_cb, + None, # flush + level, + ): + raise self.get_current_error(True) + + try: + finalize(self, self.method("askar_clear_custom_logger", None, restype=None)) + except AttributeError: + # method is new as of 0.2.5 + pass + + def invoke(self, name, argtypes, *args): + """Perform a synchronous library function call.""" + method = self.method(name, argtypes, restype=c_int64) + args = _load_method_arguments(name, argtypes, args) + result = method(*args) + if result: + raise self.get_current_error(True) + + def invoke_async(self, name: str, argtypes, *args, return_type=None): + """Perform an asynchronous library function call.""" + method = self.method(name, (*argtypes, c_void_p, c_int64), restype=c_int64) + loop = asyncio.get_event_loop() + fut = loop.create_future() + cb_info = self._cfuncs.get(name) + if cb_info: + cb = cb_info[1] + else: + cb_args = [c_int64, c_int64] + if return_type: + cb_args.append(return_type) + cb_type = CFUNCTYPE(None, *cb_args) + cb = cb_type(self._handle_callback) + # must maintain a reference to cb_type, otherwise + # it may be freed, resulting in memory errors. + self._cfuncs[name] = (cb_type, cb) + args = _load_method_arguments(name, argtypes, args) + cb_id = next(self._cb_id) + self._callbacks[cb_id] = (loop, fut, name) + result = method(*args, cb, cb_id) + if result: + # FFI must not execute the callback if an error is returned + err = self.get_current_error(True) + if self._callbacks.pop(cb_id, None): + self._fulfill_future(fut, None, err) + return fut + + def invoke_dtor(self, name: str, *values, argtypes=None, restype=None): + method = self.method(name, argtypes, restype=restype) + if method: + method(*values) + + def _handle_callback(self, cb_id: int, err: int, result=None): + exc = self.get_current_error(True) if err else None + cb = self._callbacks.pop(cb_id, None) + if not cb: + LOGGER.info("Callback already fulfilled: %s", cb_id) + return + (loop, fut, _name) = cb + loop.call_soon_threadsafe(self._fulfill_future, fut, result, exc) + + def _fulfill_future(self, fut: asyncio.Future, result, err: Exception = None): + """Resolve a callback future given the result and exception, if any.""" + if fut.cancelled(): + LOGGER.debug("callback previously cancelled") + elif err: + fut.set_exception(err) + else: + fut.set_result(result) + + def get_current_error(self, expect: bool = False) -> Optional[AskarError]: + """ + Get the error result from the previous failed API method. + + Args: + expect: Return a default error message if none is found + """ + err_json = StrBuffer() + method = self.method( + "askar_get_current_error", (POINTER(StrBuffer),), restype=c_int64 + ) + if not method(byref(err_json)): + try: + msg = json.loads(err_json.value) + except json.JSONDecodeError: + LOGGER.warning("JSON decode error for askar_get_current_error") + msg = None + if msg and "message" in msg and "code" in msg: + return AskarError( + AskarErrorCode(msg["code"]), msg["message"], msg.get("extra") + ) + if not expect: + return None + return AskarError(AskarErrorCode.WRAPPER, "Unknown error") + + def method(self, name, argtypes, *, restype=None): + """Access a method of the library.""" + method = self._methods.get(name) + if not method: + method = getattr(self._cdll, name, None) + if not method: + return None + if argtypes: + method.argtypes = argtypes + method.restype = restype + self._methods[name] = method + return method + + def _cleanup(self): + """Destructor.""" + if self._callbacks: + + def _wait_callbacks(cb): + while cb: + time.sleep(0.01) + + th = threading.Thread(target=_wait_callbacks, args=(self._callbacks,)) + th.start() + th.join(timeout=1.0) + if th.is_alive(): + LOGGER.error( + "%s: Timed out waiting for callbacks to complete", + self._lib_name, + ) + + +class Lib: + """The loaded library instance.""" + + INSTANCE = None + LIB_NAME = "aries_askar" + + def __new__(cls, *args): + """Class initializer.""" + inst = cls.INSTANCE and cls.INSTANCE() + if inst is None: + inst = super().__new__(cls, *args) + inst._lib = None + inst._objs = [] + # Keep a weak reference to the instance. This assumes that + # at least one instance is assigned to a persistent variable. + cls.INSTANCE = ref(inst) + # Register finalizer to be called later than any derived objects. + finalize(inst, cls._cleanup, inst._objs) + return inst + + @property + def loaded(self) -> LibLoad: + """Determine if the library has been loaded.""" + if not self._lib: + self._lib = LibLoad(self.__class__.LIB_NAME) + self._objs.append(self._lib) + return self._lib + + def invoke(self, name, argtypes, *args): + """Perform a synchronous library function call.""" + self.loaded.invoke(name, argtypes, *args) + + async def invoke_async(self, name: str, argtypes, *args, return_type=None): + """Perform an asynchronous library function call.""" + return await self.loaded.invoke_async( + name, argtypes, *args, return_type=return_type + ) + + def invoke_dtor(self, name: str, *args, argtypes=None, restype=None): + """Call a destructor method.""" + if self._lib: + self._lib.invoke_dtor(name, *args, argtypes=argtypes, restype=restype) + + def set_max_log_level(self, level: Union[str, int, None]): + """Set the maximum log level for the library.""" + set_level = _convert_log_level(level) + self.invoke("askar_set_max_log_level", (c_int32,), set_level) + + def version(self) -> str: + """Get the version of the installed library.""" + return str( + self.loaded._method( + "askar_version", + None, + restype=StrBuffer, + )() + ) + + def __repr__(self) -> str: + loaded = self._lib is not None + return f"" + + @classmethod + def _cleanup(cls, objs): + for obj in objs: + obj._cleanup() + + +class RawBuffer(Structure): + """A byte buffer allocated by the library.""" + + _fields_ = [ + ("len", c_int64), + ("data", POINTER(c_ubyte)), + ] + + def __bool__(self) -> bool: + return bool(self.data) + + def __bytes__(self) -> bytes: + if not self.len: + return b"" + return bytes(self.array) + + def __len__(self) -> int: + return int(self.len) + + @property + def array(self) -> Array: + return cast(self.data, POINTER(c_ubyte * self.len)).contents + + def __repr__(self) -> str: + return f"" + + +class FfiByteBuffer: + """A byte buffer allocated by Python.""" + + def __init__(self, value): + if isinstance(value, str): + value = value.encode("utf-8") + + if value is None: + dlen = 0 + data = c_char_p() + elif isinstance(value, memoryview): + dlen = value.nbytes + data = c_char_p(value.tobytes()) + elif isinstance(value, bytes): + dlen = len(value) + data = c_char_p(value) + else: + raise TypeError(f"Expected str or bytes value, got {type(value)}") + self._dlen = dlen + self._data = data + + def __bytes__(self) -> bytes: + if not self._data: + return b"" + return self._data.value + + def __len__(self) -> int: + return self._dlen + + @property + def _as_parameter_(self) -> RawBuffer: + buf = RawBuffer(len=self._dlen, data=cast(self._data, POINTER(c_ubyte))) + return buf + + @classmethod + def from_param(cls, value): + if isinstance(value, (ByteBuffer, FfiByteBuffer)): + return value + return cls(value) + + +class ByteBuffer(Structure): + """A managed byte buffer allocated by the library.""" + + _fields_ = [("buffer", RawBuffer)] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._a = addressof(self) + finalize_struct(self, RawBuffer) + + @property + def _as_parameter_(self): + return self.buffer + + @property + def array(self) -> Array: + return self.buffer.array + + @property + def view(self) -> memoryview: + m = memoryview(self.array) + # ensure self stays alive until the view is dropped + finalize(m, lambda _: (), self) + return m + + def __bytes__(self) -> bytes: + return bytes(self.buffer) + + def __len__(self) -> int: + return len(self.buffer) + + def __getitem__(self, idx) -> bytes: + return bytes(self.buffer.array[idx]) + + def __repr__(self) -> str: + """Format byte buffer as a string.""" + return f"{self.__class__.__name__}({bytes(self)})" + + @classmethod + def _cleanup(cls, buffer: RawBuffer): + """Call the byte buffer destructor when this instance is released.""" + Lib().invoke_dtor("askar_buffer_free", buffer) + + +class FfiStr: + """A string value allocated by Python.""" + + def __init__(self, value=None): + if value is None: + value = c_char_p() + elif isinstance(value, c_char_p): + pass + else: + if isinstance(value, str): + value = value.encode("utf-8") + if not isinstance(value, bytes): + raise TypeError(f"Expected string value, got {type(value)}") + value = c_char_p(value) + self.value = value + + @classmethod + def from_param(cls, value): + if isinstance(value, cls): + return value + return cls(value) + + @property + def _as_parameter_(self): + return self.value + + def __repr__(self) -> str: + """Format handle as a string.""" + return f"{self.__class__.__name__}({self.value})" + + +class FfiJson: + @classmethod + def from_param(cls, value): + if isinstance(value, FfiStr): + return value + if isinstance(value, dict): + value = json.dumps(value) + return FfiStr(value) + + +class FfiTagsJson: + @classmethod + def from_param(cls, tags): + if isinstance(tags, FfiStr): + return tags + if tags: + tags = json.dumps( + { + name: (list(value) if isinstance(value, set) else value) + for name, value in tags.items() + } + ) + else: + tags = None + return FfiStr(tags) + + +class StrBuffer(Structure): + """A string allocated by the library.""" + + _fields_ = [("buffer", POINTER(c_char))] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + finalize_struct(self, c_char_p) + + def is_none(self) -> bool: + """Check if the returned string pointer is null.""" + return not self.buffer + + def opt_str(self) -> Optional[str]: + """Convert to an optional string.""" + val = self.value + return val.decode("utf-8") if val is not None else None + + def __bool__(self) -> bool: + return bool(self.buffer) + + def __bytes__(self) -> bytes: + """Convert to bytes.""" + bval = self.value + return bval if bval is not None else bytes() + + def __str__(self): + """Convert to a string.""" + # not allowed to return None + val = self.opt_str() + return val if val is not None else "" + + @property + def value(self) -> bytes: + return cast(self.buffer, c_char_p).value + + @classmethod + def _cleanup(cls, buffer: c_char_p): + """Call the string destructor when this instance is released.""" + Lib().invoke_dtor("askar_string_free", buffer) + + +class AeadParams(Structure): + """A byte buffer allocated by the library.""" + + _fields_ = [ + ("nonce_length", c_int32), + ("tag_length", c_int32), + ] + + def __repr__(self) -> str: + """Format AEAD params as a string.""" + return ( + f"" + ) + + +class Encrypted(Structure): + """The result of an AEAD encryption operation.""" + + _fields_ = [ + ("buffer", RawBuffer), + ("tag_pos", c_int64), + ("nonce_pos", c_int64), + ] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + finalize_struct(self, RawBuffer) + + def __getitem__(self, idx) -> bytes: + return bytes(self.buffer.array[idx]) + + def __bytes__(self) -> bytes: + """Convert to bytes.""" + return self.ciphertext_tag + + @property + def ciphertext_tag(self) -> bytes: + """Accessor for the combined ciphertext and tag.""" + p = self.nonce_pos + return self[:p] + + @property + def ciphertext(self) -> bytes: + """Accessor for the ciphertext.""" + p = self.tag_pos + return self[:p] + + @property + def nonce(self) -> bytes: + """Accessor for the nonce.""" + p = self.nonce_pos + return self[p:] + + @property + def tag(self) -> bytes: + """Accessor for the authentication tag.""" + p1 = self.tag_pos + p2 = self.nonce_pos + return self[p1:p2] + + @property + def parts(self) -> Tuple[bytes, bytes, bytes]: + """Accessor for the ciphertext, tag, and nonce.""" + p1 = self.tag_pos + p2 = self.nonce_pos + return self[:p1], self[p1:p2], self[p2:] + + def __repr__(self) -> str: + """Format encrypted value as a string.""" + return ( + f"" + ) + + @classmethod + def _cleanup(cls, buffer: RawBuffer): + """Call the byte buffer destructor when this instance is released.""" + Lib().invoke_dtor("askar_buffer_free", buffer) diff --git a/wrappers/python/aries_askar/key.py b/wrappers/python/aries_askar/key.py index e1602f79..39d1a09c 100644 --- a/wrappers/python/aries_askar/key.py +++ b/wrappers/python/aries_askar/key.py @@ -3,8 +3,7 @@ from typing import Union from . import bindings - -from .bindings import Encrypted, LocalKeyHandle +from .bindings import AeadParams, Encrypted, LocalKeyHandle from .types import KeyAlg, SeedMethod @@ -76,7 +75,7 @@ def get_jwk_secret(self) -> str: def get_jwk_thumbprint(self, alg: Union[str, KeyAlg] = None) -> str: return bindings.key_get_jwk_thumbprint(self._handle, alg) - def aead_params(self) -> bindings.AeadParams: + def aead_params(self) -> AeadParams: return bindings.key_aead_get_params(self._handle) def aead_random_nonce(self) -> bytes: diff --git a/wrappers/python/aries_askar/store.py b/wrappers/python/aries_askar/store.py index 5c969fc7..67b0565a 100644 --- a/wrappers/python/aries_askar/store.py +++ b/wrappers/python/aries_askar/store.py @@ -3,11 +3,11 @@ import json from typing import Optional, Sequence, Union +from weakref import ref from cached_property import cached_property from . import bindings - from .bindings import ( ByteBuffer, EntryListHandle, @@ -47,7 +47,7 @@ def value(self) -> bytes: return bytes(self.raw_value) @cached_property - def raw_value(self) -> ByteBuffer: + def raw_value(self) -> memoryview: """Accessor for the entry raw value.""" return self._list.get_value(self._pos) @@ -333,11 +333,13 @@ async def remove(cls, uri: str) -> bool: async def __aenter__(self) -> "Session": if not self._opener: - self._opener = OpenSession(self, None, False) + self._opener = OpenSession(self._handle, None, False) return await self._opener.__aenter__() async def __aexit__(self, exc_type, exc, tb): - return await self._opener.__aexit__(exc_type, exc, tb) + opener = self._opener + self._opener = None + return await opener.__aexit__(exc_type, exc, tb) async def create_profile(self, name: str = None) -> str: return await bindings.store_create_profile(self._handle, name) @@ -366,10 +368,10 @@ def scan( return Scan(self, profile, category, tag_filter, offset, limit) def session(self, profile: str = None) -> "OpenSession": - return OpenSession(self, profile, False) + return OpenSession(self._handle, profile, False) def transaction(self, profile: str = None) -> "OpenSession": - return OpenSession(self, profile, True) + return OpenSession(self._handle, profile, True) async def close(self, *, remove: bool = False) -> bool: """Close and free the pool instance.""" @@ -389,7 +391,7 @@ def __repr__(self) -> str: class Session: """An opened Session instance.""" - def __init__(self, store: Store, handle: SessionHandle, is_txn: bool): + def __init__(self, store: StoreHandle, handle: SessionHandle, is_txn: bool): """Initialize the Session instance.""" self._store = store self._handle = handle @@ -405,11 +407,6 @@ def handle(self) -> SessionHandle: """Accessor for the SessionHandle instance.""" return self._handle - @property - def store(self) -> Store: - """Accessor for the Store instance.""" - return self._store - async def count(self, category: str, tag_filter: Union[str, dict] = None) -> int: if not self._handle: raise AskarError(AskarErrorCode.WRAPPER, "Cannot count from closed session") @@ -595,39 +592,38 @@ def __repr__(self) -> str: class OpenSession: - def __init__(self, store: Store, profile: Optional[str], is_txn: bool): + def __init__(self, store: StoreHandle, profile: Optional[str], is_txn: bool): """Initialize the OpenSession instance.""" self._store = store self._profile = profile self._is_txn = is_txn - self._session = None + self._session: Session = None @property def is_transaction(self) -> bool: return self._is_txn async def _open(self) -> Session: - if not self._store.handle: + if not self._store: raise AskarError( AskarErrorCode.WRAPPER, "Cannot start session from closed store" ) if self._session: raise AskarError(AskarErrorCode.WRAPPER, "Session already opened") - self._session = Session( + return Session( self._store, - await bindings.session_start( - self._store.handle, self._profile, self._is_txn - ), + await bindings.session_start(self._store, self._profile, self._is_txn), self._is_txn, ) - return self._session def __await__(self) -> Session: return self._open().__await__() async def __aenter__(self) -> Session: - return await self._open() + self._session = await self._open() + return self._session async def __aexit__(self, exc_type, exc, tb): - await self._session.close() + session = self._session self._session = None + await session.close() diff --git a/wrappers/python/tests/test_cleanup.py b/wrappers/python/tests/test_cleanup.py new file mode 100644 index 00000000..4a1d799f --- /dev/null +++ b/wrappers/python/tests/test_cleanup.py @@ -0,0 +1,60 @@ +from ctypes import c_char, c_char_p, c_size_t, c_ubyte, pointer +from unittest import mock + +from aries_askar.bindings.handle import ArcHandle +from aries_askar.bindings.lib import ByteBuffer, RawBuffer, StrBuffer + + +def test_cleanup_handle(): + logged = [] + + class Handle(ArcHandle): + @classmethod + def _cleanup(cls, handle: c_size_t): + logged.append(handle.value) + + h = Handle() + assert not h.value + del h + assert not logged + + h = Handle() + h.value = 99 + del h + assert logged == [(99)] + + +def test_cleanup_bytebuffer(): + logged = [] + + def cleanup(buffer: RawBuffer): + logged.append((buffer.len, buffer.data.contents.value if buffer.data else None)) + + with mock.patch.object(ByteBuffer, "_cleanup", cleanup): + b = ByteBuffer() + del b + assert not logged + + c = c_ubyte(99) + b = ByteBuffer() + b.buffer = RawBuffer(len=1, data=pointer(c)) + del b + assert logged == [(1, 99)] + + +def test_cleanup_strbuffer(): + logged = [] + + def cleanup(buffer: c_char_p): + logged.append(buffer.value) + + with mock.patch.object(StrBuffer, "_cleanup", cleanup): + s = StrBuffer() + del s + assert not logged + + s = StrBuffer() + c = c_char(ord("a")) + s.buffer = pointer(c) + del s + assert logged == [b"a"] diff --git a/wrappers/python/tests/test_jose_ecdh.py b/wrappers/python/tests/test_jose_ecdh.py index 5cc2fdb2..43a155a8 100644 --- a/wrappers/python/tests/test_jose_ecdh.py +++ b/wrappers/python/tests/test_jose_ecdh.py @@ -113,8 +113,6 @@ def test_ecdh_1pu_direct(): KeyAlg.A256GCM, ephem_key, alice_key, bob_jwk, message, aad=protected_b64 ) ciphertext, tag, nonce = encrypted_msg.parts - print("enc", *encrypted_msg.parts) - print("enc", *encrypted_msg.parts) # switch to receiver