diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 7eeab8ea..19f1b397 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -16,10 +16,11 @@ name: "Aries-Askar" jobs: check: - name: Run Checks + name: Run checks strategy: + fail-fast: false matrix: - os: [macos-11, windows-latest, ubuntu-latest] + os: [ubuntu-latest, macos-11, windows-latest] runs-on: ${{ matrix.os }} steps: @@ -30,18 +31,15 @@ jobs: uses: actions-rs/toolchain@v1 with: profile: minimal - toolchain: stable + toolchain: 1.56 + override: true + components: clippy, rustfmt - name: Cache cargo resources uses: Swatinem/rust-cache@v1 with: - sharedKey: deps - - - name: Cargo check - uses: actions-rs/cargo@v1 - with: - command: check - args: --workspace + sharedKey: check + cache-on-failure: true - name: Cargo fmt uses: actions-rs/cargo@v1 @@ -49,17 +47,35 @@ jobs: command: fmt args: --all -- --check + - name: Cargo check + uses: actions-rs/cargo@v1 + with: + command: check + args: --workspace + - name: Debug build uses: actions-rs/cargo@v1 with: command: build args: --all-targets - - name: Test + - if: "runner.os == 'Linux'" + name: Start postgres (Linux) + run: | + sudo systemctl start postgresql.service + pg_isready + sudo -u postgres psql -c "ALTER USER postgres WITH PASSWORD 'postgres'" + echo "POSTGRES_URL=postgres://postgres:postgres@localhost:5432/test-db" >> $GITHUB_ENV + echo "TEST_FEATURES=pg_test" >> $GITHUB_ENV + + - name: Run tests uses: actions-rs/cargo@v1 with: command: test - args: --workspace + args: --workspace --features "${{ env.TEST_FEATURES }}" -- --nocapture --test-threads 1 --skip contention + env: + RUST_BACKTRACE: full + # RUST_LOG: debug - name: Test askar-crypto no_std uses: actions-rs/cargo@v1 @@ -73,52 +89,12 @@ jobs: command: test args: --manifest-path ./askar-bbs/Cargo.toml --no-default-features - test-postgres: - name: Postgres - runs-on: ubuntu-latest - needs: [check] - - services: - postgres: - image: postgres - env: - POSTGRES_PASSWORD: postgres - options: >- - --health-cmd pg_isready - --health-interval 10s - --health-timeout 5s - --health-retries 5 - ports: - - 5432:5432 - - steps: - - name: Checkout - uses: actions/checkout@v2 - - - name: Install Rust toolchain - uses: actions-rs/toolchain@v1 - with: - profile: minimal - toolchain: stable - - - name: Cache cargo resources - uses: Swatinem/rust-cache@v1 - with: - sharedKey: deps - - - name: Test - uses: actions-rs/cargo@v1 - env: - POSTGRES_URL: postgres://postgres:postgres@localhost:5432/test-db - with: - command: test - args: --features pg_test - build-manylinux: - name: Build Library + name: Build (manylinux) needs: [check] strategy: + fail-fast: false matrix: include: - os: ubuntu-latest @@ -136,16 +112,18 @@ jobs: uses: actions-rs/toolchain@v1 with: profile: minimal - toolchain: stable + toolchain: 1.56 + override: true - name: Cache cargo resources uses: Swatinem/rust-cache@v1 with: - sharedKey: deps + sharedKey: check - name: Build library env: BUILD_TARGET: ${{ matrix.target }} + # LIBSQLITE3_FLAGS: SQLITE_DEBUG SQLITE_MEMDEBUG run: sh ./build.sh - name: Upload library artifacts @@ -154,20 +132,21 @@ jobs: name: library-${{ runner.os }} path: target/release/${{ matrix.lib }} - build-other: - name: Build Library + build-native: + name: Build (native) needs: [check] strategy: + fail-fast: false matrix: include: - os: macos-11 lib: libaries_askar.dylib target: apple-darwin # creates a universal library - toolchain: beta # beta required for aarch64-apple-darwin target + toolchain: 
nightly-2021-10-21 # nightly required for aarch64-apple-darwin target
          - os: windows-latest
            lib: aries_askar.dll
-           toolchain: stable
+           toolchain: 1.56

    runs-on: ${{ matrix.os }}

@@ -180,16 +159,18 @@
        with:
          profile: minimal
          toolchain: ${{ matrix.toolchain }}
+         override: true

      - name: Cache cargo resources
        uses: Swatinem/rust-cache@v1
        with:
-         sharedKey: deps
+         sharedKey: check

      - name: Build library
        env:
          BUILD_TARGET: ${{ matrix.target }}
          BUILD_TOOLCHAIN: ${{ matrix.toolchain }}
+         # LIBSQLITE3_FLAGS: SQLITE_DEBUG SQLITE_MEMDEBUG
        run: sh ./build.sh

      - name: Upload library artifacts
@@ -199,10 +180,11 @@
          path: target/release/${{ matrix.lib }}

  build-py:
-   name: Build Python
-   needs: [build-manylinux, build-other]
+   name: Build Python packages
+   needs: [build-manylinux, build-native]

    strategy:
+     fail-fast: false
      matrix:
        os: [ubuntu-latest, macos-11, windows-latest]
        python-version: [3.7]
@@ -236,19 +218,42 @@
          name: library-${{ runner.os }}
          path: wrappers/python/aries_askar/

-     - name: Build and test python package
+     - if: "runner.os == 'Linux'"
+       name: Start postgres (Linux)
+       run: |
+         sudo systemctl start postgresql.service
+         pg_isready
+         sudo -u postgres psql -c "ALTER USER postgres WITH PASSWORD 'postgres'"
+         echo "POSTGRES_URL=postgres://postgres:postgres@localhost:5432/test-db" >> $GITHUB_ENV
+
+     - name: Build package
        shell: sh
        run: |
          python setup.py bdist_wheel --python-tag=py3 --plat-name=${{ matrix.plat-name }}
-         pip install pytest pytest-asyncio dist/*
-         python -m pytest
-         TEST_STORE_URI=sqlite://test.db python -m pytest
        working-directory: wrappers/python

      - if: "runner.os == 'Linux'"
-       name: Auditwheel
+       name: Audit wheel
        run: auditwheel show wrappers/python/dist/*

+     - name: Test package
+       shell: sh
+       run: |
+         pip install pytest pytest-asyncio dist/*
+         echo "-- Test SQLite in-memory --"
+         python -m pytest --log-cli-level=WARNING -k "not contention"
+         echo "-- Test SQLite file DB --"
+         TEST_STORE_URI=sqlite://test.db python -m pytest --log-cli-level=WARNING -k "not contention"
+         if [ -n "$POSTGRES_URL" ]; then
+           echo "-- Test Postgres DB --"
+           TEST_STORE_URI="$POSTGRES_URL" python -m pytest --log-cli-level=WARNING -k "not contention"
+         fi
+       working-directory: wrappers/python
+       env:
+         no_proxy: "*" # python issue 30385
+         RUST_BACKTRACE: full
+         # RUST_LOG: debug
+
      - name: Upload python package
        uses: actions/upload-artifact@v2
        with:
diff --git a/Cargo.toml b/Cargo.toml
index 7db2ba0e..af2f5141 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -3,7 +3,7 @@ members = ["askar-bbs", "askar-crypto"]

 [package]
 name = "aries-askar"
-version = "0.2.4"
+version = "0.2.5"
 authors = ["Hyperledger Aries Contributors <aries@lists.hyperledger.org>"]
 edition = "2018"
 description = "Hyperledger Aries Askar secure storage"
@@ -27,7 +27,7 @@ rustdoc-args = ["--cfg", "docsrs"]
 default = ["all_backends", "ffi", "logger"]
 all_backends = ["any", "postgres", "sqlite"]
 any = []
-ffi = ["any", "ffi-support", "logger", "option-lock"]
+ffi = ["any", "ffi-support", "logger"]
 jemalloc = ["jemallocator"]
 logger = ["env_logger", "log"]
 postgres = ["sqlx", "sqlx/postgres", "sqlx/tls"]
@@ -38,7 +38,7 @@ pg_test = ["postgres"]
 hex-literal = "0.3"

 [dependencies]
-async-lock = "2.4"
+async-lock = "2.5"
 async-stream = "0.3"
 bs58 = "0.4"
 chrono = "0.4"
@@ -53,7 +53,6 @@ itertools = "0.10"
 jemallocator = { version = "0.3", optional = true }
 log = { version = "0.4", optional = true }
 num_cpus = { version = "1.0", optional = true }
-option-lock = { version = "0.3", optional = true }
 once_cell = "1.5"
 percent-encoding = "2.0"
 serde = { version = "1.0", features = ["derive"] }
features = ["derive"] } @@ -67,18 +66,18 @@ uuid = { version = "0.8", features = ["v4"] } zeroize = "1.4" [dependencies.askar-crypto] -version = "0.2" +version = "0.2.5" path = "./askar-crypto" features = ["all_keys", "any_key", "argon2", "crypto_box", "std"] [dependencies.sqlx] -version = "0.5.11" +version = "0.5.12" default-features = false features = ["chrono", "runtime-tokio-rustls"] optional = true [profile.release] -lto = true + lto = true codegen-units = 1 [[test]] diff --git a/askar-crypto/Cargo.toml b/askar-crypto/Cargo.toml index 7f7cb7f7..1848df01 100644 --- a/askar-crypto/Cargo.toml +++ b/askar-crypto/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "askar-crypto" -version = "0.2.4" +version = "0.2.5" authors = ["Hyperledger Aries Contributors "] edition = "2018" description = "Hyperledger Aries Askar cryptography" diff --git a/askar-crypto/src/alg/any.rs b/askar-crypto/src/alg/any.rs index 1f5dfbf9..e54c85b8 100644 --- a/askar-crypto/src/alg/any.rs +++ b/askar-crypto/src/alg/any.rs @@ -4,6 +4,7 @@ use core::convert::TryFrom; use core::{ any::{Any, TypeId}, fmt::Debug, + panic::{RefUnwindSafe, UnwindSafe}, }; #[cfg(feature = "aes")] @@ -54,10 +55,10 @@ use super::EcCurves; use crate::kdf::{FromKeyDerivation, FromKeyExchange}; #[derive(Debug)] -pub struct KeyT(T); +pub struct KeyT(T); /// The type-erased representation for a concrete key instance -pub type AnyKey = KeyT; +pub type AnyKey = KeyT; impl AnyKey { pub fn algorithm(&self) -> KeyAlg { @@ -79,12 +80,6 @@ impl AnyKey { } } -// key instances are immutable -#[cfg(feature = "std")] -impl std::panic::UnwindSafe for AnyKey {} -#[cfg(feature = "std")] -impl std::panic::RefUnwindSafe for AnyKey {} - /// Create `AnyKey` instances from various sources pub trait AnyKeyCreate: Sized { /// Generate a new key from a key material generator for the given key algorithm. 
diff --git a/askar-crypto/src/alg/any.rs b/askar-crypto/src/alg/any.rs
index 1f5dfbf9..e54c85b8 100644
--- a/askar-crypto/src/alg/any.rs
+++ b/askar-crypto/src/alg/any.rs
@@ -4,6 +4,7 @@ use core::convert::TryFrom;
 use core::{
     any::{Any, TypeId},
     fmt::Debug,
+    panic::{RefUnwindSafe, UnwindSafe},
 };

 #[cfg(feature = "aes")]
@@ -54,10 +55,10 @@ use super::EcCurves;
 use crate::kdf::{FromKeyDerivation, FromKeyExchange};

 #[derive(Debug)]
-pub struct KeyT<T: AnyKeyAlg + Send + Sync + ?Sized>(T);
+pub struct KeyT<T: AnyKeyAlg + Send + Sync + RefUnwindSafe + UnwindSafe + ?Sized>(T);

 /// The type-erased representation for a concrete key instance
-pub type AnyKey = KeyT<dyn AnyKeyAlg + Send + Sync>;
+pub type AnyKey = KeyT<dyn AnyKeyAlg + Send + Sync + RefUnwindSafe + UnwindSafe>;

 impl AnyKey {
     pub fn algorithm(&self) -> KeyAlg {
@@ -79,12 +80,6 @@ impl AnyKey {
     }
 }

-// key instances are immutable
-#[cfg(feature = "std")]
-impl std::panic::UnwindSafe for AnyKey {}
-#[cfg(feature = "std")]
-impl std::panic::RefUnwindSafe for AnyKey {}
-
 /// Create `AnyKey` instances from various sources
 pub trait AnyKeyCreate: Sized {
     /// Generate a new key from a key material generator for the given key algorithm.
@@ -108,7 +103,7 @@ pub trait AnyKeyCreate: Sized {
     fn from_secret_bytes(alg: KeyAlg, secret: &[u8]) -> Result<Self, Error>;

     /// Convert from a concrete key instance
-    fn from_key<K: AnyKeyAlg + Send + Sync>(key: K) -> Self;
+    fn from_key<K: AnyKeyAlg + Send + Sync + RefUnwindSafe + UnwindSafe>(key: K) -> Self;

     /// Create a new key instance from a key exchange
     fn from_key_exchange<Sk, Pk>(alg: KeyAlg, secret: &Sk, public: &Pk) -> Result<Self, Error>
@@ -137,7 +132,7 @@ impl AnyKeyCreate for Box<AnyKey> {
     }

     #[inline(always)]
-    fn from_key<K: AnyKeyAlg + Send + Sync>(key: K) -> Self {
+    fn from_key<K: AnyKeyAlg + Send + Sync + RefUnwindSafe + UnwindSafe>(key: K) -> Self {
         Box::new(KeyT(key))
     }

@@ -172,7 +167,7 @@ impl AnyKeyCreate for Arc<AnyKey> {
     }

     #[inline(always)]
-    fn from_key<K: AnyKeyAlg + Send + Sync>(key: K) -> Self {
+    fn from_key<K: AnyKeyAlg + Send + Sync + RefUnwindSafe + UnwindSafe>(key: K) -> Self {
         Arc::new(KeyT(key))
     }

@@ -819,19 +814,19 @@ impl KeySigVerify for AnyKey {
 // may want to implement in-place initialization to avoid copies
 trait AllocKey {
-    fn alloc_key<K: AnyKeyAlg + Send + Sync>(key: K) -> Self;
+    fn alloc_key<K: AnyKeyAlg + Send + Sync + RefUnwindSafe + UnwindSafe>(key: K) -> Self;
 }

 impl AllocKey for Arc<AnyKey> {
     #[inline(always)]
-    fn alloc_key<K: AnyKeyAlg + Send + Sync>(key: K) -> Self {
+    fn alloc_key<K: AnyKeyAlg + Send + Sync + RefUnwindSafe + UnwindSafe>(key: K) -> Self {
         Self::from_key(key)
     }
 }

 impl AllocKey for Box<AnyKey> {
     #[inline(always)]
-    fn alloc_key<K: AnyKeyAlg + Send + Sync>(key: K) -> Self {
+    fn alloc_key<K: AnyKeyAlg + Send + Sync + RefUnwindSafe + UnwindSafe>(key: K) -> Self {
         Self::from_key(key)
     }
 }
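Note on the change above: `catch_unwind` at an FFI boundary only accepts closures whose captures are unwind safe, so the unwind safety of `AnyKey` has to be provable by the compiler rather than asserted with blanket impls. A minimal standalone sketch of the property this buys (the `Key` type here is hypothetical, not the crate's):

use std::panic::catch_unwind;
use std::sync::Arc;

// Stand-in for AnyKey: immutable key material whose fields are all unwind safe.
#[derive(Debug)]
struct Key([u8; 32]);

// Because Key is structurally RefUnwindSafe (the role the trait bounds on KeyT
// now play), a shared reference may cross a catch_unwind boundary without any
// manual `impl UnwindSafe` or an AssertUnwindSafe wrapper.
fn ffi_safe_read(key: &Arc<Key>) -> Option<u8> {
    catch_unwind(|| key.0[0]).ok()
}

fn main() {
    let key = Arc::new(Key([7u8; 32]));
    assert_eq!(ffi_safe_read(&key), Some(7));
}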
diff --git a/askar-crypto/src/buffer/secret.rs b/askar-crypto/src/buffer/secret.rs
index 6ee40631..81982882 100644
--- a/askar-crypto/src/buffer/secret.rs
+++ b/askar-crypto/src/buffer/secret.rs
@@ -47,6 +47,12 @@ impl SecretBytes {
         Self(v)
     }

+    /// Accessor for the current capacity of the buffer
+    #[inline]
+    pub fn capacity(&self) -> usize {
+        self.0.capacity()
+    }
+
     /// Accessor for the length of the buffer contents
     #[inline]
     pub fn len(&self) -> usize {
@@ -66,7 +72,7 @@ impl SecretBytes {
         } else if cap > 0 && min_cap >= cap {
             // allocate a new buffer and copy the secure data over
             let new_cap = min_cap.max(cap * 2).max(32);
-            let mut buf = SecretBytes::with_capacity(new_cap);
+            let mut buf = Self::with_capacity(new_cap);
             buf.0.extend_from_slice(&self.0[..]);
             mem::swap(&mut buf, self);
             // old buf zeroized on drop
@@ -93,28 +99,31 @@ impl SecretBytes {
         self.ensure_capacity(self.len() + extra)
     }

-    /// Convert this buffer into a boxed slice
-    pub fn into_boxed_slice(mut self) -> Box<[u8]> {
+    /// Shrink the buffer capacity to match the length
+    pub fn shrink_to_fit(&mut self) {
         let len = self.0.len();
         if self.0.capacity() > len {
             // copy to a smaller buffer (capacity is not tracked for boxed slice)
             // and proceed with the normal zeroize on drop
-            let mut v = Vec::with_capacity(len);
-            v.append(&mut self.0);
-            v.into_boxed_slice()
-        } else {
-            // no realloc and copy needed
-            self.into_vec().into_boxed_slice()
+            let mut buf = Self::with_capacity(len);
+            buf.0.extend_from_slice(&self.0[..]);
+            mem::swap(&mut buf, self);
+            // old buf zeroized on drop
         }
     }

+    /// Convert this buffer into a boxed slice
+    pub fn into_boxed_slice(mut self) -> Box<[u8]> {
+        self.shrink_to_fit();
+        self.into_vec().into_boxed_slice()
+    }
+
     /// Unwrap this buffer into a Vec
     #[inline]
     pub fn into_vec(mut self) -> Vec<u8> {
         // FIXME zeroize extra capacity in case it was used previously?
         let mut v = Vec::new(); // note: no heap allocation for empty vec
         mem::swap(&mut v, &mut self.0);
-        mem::forget(self);
         v
     }
diff --git a/src/backend/any.rs b/src/backend/any.rs
index 4c440e79..1d084c6c 100644
--- a/src/backend/any.rs
+++ b/src/backend/any.rs
@@ -270,7 +270,11 @@ impl<'a> ManageBackend<'a> for &'a str {
                 Ok(Store::new(AnyBackend::Sqlite(mgr.into_inner())))
             }

-            _ => Err(err_msg!(Unsupported, "Invalid backend: {}", &opts.schema)),
+            _ => Err(err_msg!(
+                Unsupported,
+                "Unsupported backend: {}",
+                &opts.schema
+            )),
         }
     })
 }
@@ -301,7 +305,11 @@ impl<'a> ManageBackend<'a> for &'a str {
                 Ok(Store::new(AnyBackend::Sqlite(mgr.into_inner())))
             }

-            _ => Err(err_msg!(Unsupported, "Invalid backend: {}", &opts.schema)),
+            _ => Err(err_msg!(
+                Unsupported,
+                "Unsupported backend: {}",
+                &opts.schema
+            )),
         }
     })
 }
@@ -324,7 +332,11 @@ impl<'a> ManageBackend<'a> for &'a str {
                 Ok(opts.remove().await?)
             }

-            _ => Err(err_msg!(Unsupported, "Invalid backend: {}", &opts.schema)),
+            _ => Err(err_msg!(
+                Unsupported,
+                "Unsupported backend: {}",
+                &opts.schema
+            )),
         }
     })
 }
diff --git a/src/backend/db_utils.rs b/src/backend/db_utils.rs
index 97bb94a9..0f043188 100644
--- a/src/backend/db_utils.rs
+++ b/src/backend/db_utils.rs
@@ -551,32 +551,24 @@ pub fn encode_tag_filter(
     }
 }

-// convert a slice of tags into a Vec, when ensuring there is
+// allocate a String while ensuring there is sufficient capacity to reuse during encryption
+fn _prepare_string(value: &str) -> String {
+    let buf = ProfileKey::prepare_input(value.as_bytes()).into_vec();
+    unsafe { String::from_utf8_unchecked(buf) }
+}
+
+// convert a slice of tags into a Vec, while ensuring there is
 // adequate space in the allocations to reuse them during encryption
 pub fn prepare_tags(tags: &[EntryTag]) -> Result<Vec<EntryTag>, Error> {
     let mut result = Vec::with_capacity(tags.len());
     for tag in tags {
         result.push(match tag {
-            EntryTag::Plaintext(name, value) => EntryTag::Plaintext(
-                unsafe {
-                    String::from_utf8_unchecked(
-                        ProfileKey::prepare_input(name.as_bytes()).into_vec(),
-                    )
-                },
-                value.clone(),
-            ),
-            EntryTag::Encrypted(name, value) => EntryTag::Encrypted(
-                unsafe {
-                    String::from_utf8_unchecked(
-                        ProfileKey::prepare_input(name.as_bytes()).into_vec(),
-                    )
-                },
-                unsafe {
-                    String::from_utf8_unchecked(
-                        ProfileKey::prepare_input(value.as_bytes()).into_vec(),
-                    )
-                },
-            ),
+            EntryTag::Plaintext(name, value) => {
+                EntryTag::Plaintext(_prepare_string(name), value.clone())
+            }
+            EntryTag::Encrypted(name, value) => {
+                EntryTag::Encrypted(_prepare_string(name), _prepare_string(value))
+            }
         });
     }
     Ok(result)
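The `_prepare_string` helper relies on `ProfileKey::prepare_input` returning a buffer with spare capacity, and on the fact that bytes taken from a valid `&str` can be reassembled with `String::from_utf8_unchecked`. A self-contained sketch of the same pattern, with a hypothetical `prepare_input` standing in for the crate-internal one:

// Minimal sketch of the buffer-reuse idea, assuming prepare_input's job is to
// copy the input with extra capacity reserved for later in-place encryption
// (the real ProfileKey::prepare_input is internal to the crate).
fn prepare_input(input: &[u8]) -> Vec<u8> {
    let mut buf = Vec::with_capacity(input.len() + 16); // room for a cipher tag
    buf.extend_from_slice(input);
    buf
}

fn prepare_string(value: &str) -> String {
    let buf = prepare_input(value.as_bytes());
    // Sound: buf holds exactly the bytes of a valid &str.
    unsafe { String::from_utf8_unchecked(buf) }
}

fn main() {
    let s = prepare_string("tag-name");
    assert_eq!(s, "tag-name");
    assert!(s.capacity() >= s.len() + 16);
}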
diff --git a/src/backend/postgres/mod.rs b/src/backend/postgres/mod.rs
index 2723da5c..73faaa1f 100644
--- a/src/backend/postgres/mod.rs
+++ b/src/backend/postgres/mod.rs
@@ -49,11 +49,14 @@ const FETCH_QUERY_UPDATE: &'static str = "SELECT id, value,
     FROM items_tags it WHERE it.item_id = i.id) tags
     FROM items i WHERE profile_id = $1 AND kind = $2 AND category = $3 AND name = $4
-    AND (expiry IS NULL OR expiry > CURRENT_TIMESTAMP) FOR UPDATE";
+    AND (expiry IS NULL OR expiry > CURRENT_TIMESTAMP) FOR NO KEY UPDATE";
 const INSERT_QUERY: &'static str = "INSERT INTO items (profile_id, kind, category, name, value, expiry)
     VALUES ($1, $2, $3, $4, $5, $6)
     ON CONFLICT DO NOTHING RETURNING id";
+const UPDATE_QUERY: &'static str = "UPDATE items SET value=$5, expiry=$6
+    WHERE profile_id=$1 AND kind=$2 AND category=$3 AND name=$4
+    RETURNING id";
 const SCAN_QUERY: &'static str = "SELECT id, name, value,
     (SELECT ARRAY_TO_STRING(ARRAY_AGG(it.plaintext || ':' ||
     ENCODE(it.name, 'hex') || ':' || ENCODE(it.value, 'hex')), ',')
@@ -64,6 +67,8 @@ const DELETE_ALL_QUERY: &'static str = "DELETE FROM items i
     WHERE i.profile_id = $1 AND i.kind = $2 AND i.category = $3";
 const TAG_INSERT_QUERY: &'static str = "INSERT INTO items_tags (item_id, name, value, plaintext)
     VALUES ($1, $2, $3, $4)";
+const TAG_DELETE_QUERY: &'static str = "DELETE FROM items_tags
+    WHERE item_id=$1";

 mod provision;
 pub use provision::PostgresStoreOptions;
@@ -445,7 +450,7 @@ impl QueryBackend for DbSession<Postgres> {
         let name = ProfileKey::prepare_input(name.as_bytes());

         match operation {
-            EntryOperation::Insert => {
+            op @ EntryOperation::Insert | op @ EntryOperation::Replace => {
                 let value = ProfileKey::prepare_input(value.unwrap());
                 let tags = tags.map(prepare_tags);
                 Box::pin(async move {
@@ -473,42 +478,7 @@ impl QueryBackend for DbSession<Postgres> {
                         &enc_value,
                         enc_tags,
                         expiry_ms,
-                    )
-                    .await?;
-                    txn.commit().await?;
-                    Ok(())
-                })
-            }
-            EntryOperation::Replace => {
-                let value = ProfileKey::prepare_input(value.unwrap());
-                let tags = tags.map(prepare_tags);
-                Box::pin(async move {
-                    let (_, key) = acquire_key(&mut *self).await?;
-                    let (enc_category, enc_name, enc_value, enc_tags) = unblock(move || {
-                        let enc_value =
-                            key.encrypt_entry_value(category.as_ref(), name.as_ref(), value)?;
-                        Result::<_, Error>::Ok((
-                            key.encrypt_entry_category(category)?,
-                            key.encrypt_entry_name(name)?,
-                            enc_value,
-                            tags.transpose()?
-                                .map(|t| key.encrypt_entry_tags(t))
-                                .transpose()?,
-                        ))
-                    })
-                    .await?;
-
-                    let mut active = acquire_session(&mut *self).await?;
-                    let mut txn = active.as_transaction().await?;
-                    perform_remove(&mut txn, kind, &enc_category, &enc_name, false).await?;
-                    perform_insert(
-                        &mut txn,
-                        kind,
-                        &enc_category,
-                        &enc_name,
-                        &enc_value,
-                        enc_tags,
-                        expiry_ms,
+                        op == EntryOperation::Insert,
                     )
                     .await?;
                     txn.commit().await?;
@@ -613,18 +583,38 @@ async fn perform_insert<'q>(
     enc_value: &[u8],
     enc_tags: Option<Vec<EncEntryTag>>,
     expiry_ms: Option<i64>,
+    new_row: bool,
 ) -> Result<(), Error> {
-    trace!("Insert entry");
-    let row_id: i64 = sqlx::query_scalar(INSERT_QUERY)
-        .bind(active.profile_id)
-        .bind(kind as i16)
-        .bind(enc_category)
-        .bind(enc_name)
-        .bind(enc_value)
-        .bind(expiry_ms.map(expiry_timestamp).transpose()?)
-        .fetch_optional(active.connection_mut())
-        .await?
-        .ok_or_else(|| err_msg!(Duplicate, "Duplicate row"))?;
+    let row_id = if new_row {
+        trace!("Insert entry");
+        sqlx::query_scalar(INSERT_QUERY)
+            .bind(active.profile_id)
+            .bind(kind as i16)
+            .bind(enc_category)
+            .bind(enc_name)
+            .bind(enc_value)
+            .bind(expiry_ms.map(expiry_timestamp).transpose()?)
+            .fetch_optional(active.connection_mut())
+            .await?
+            .ok_or_else(|| err_msg!(Duplicate, "Duplicate row"))?
+    } else {
+        trace!("Update entry");
+        let row_id: i64 = sqlx::query_scalar(UPDATE_QUERY)
+            .bind(active.profile_id)
+            .bind(kind as i16)
+            .bind(enc_category)
+            .bind(enc_name)
+            .bind(enc_value)
+            .bind(expiry_ms.map(expiry_timestamp).transpose()?)
+            .fetch_one(active.connection_mut())
+            .await
+            .map_err(|_| err_msg!(NotFound, "Error updating existing row"))?;
+        sqlx::query(TAG_DELETE_QUERY)
+            .bind(row_id)
+            .execute(active.connection_mut())
+            .await?;
+        row_id
+    };
     if let Some(tags) = enc_tags {
         for tag in tags {
             sqlx::query(TAG_INSERT_QUERY)
@@ -690,7 +680,7 @@ fn perform_scan<'q>(
     params.push(enc_category);
     let mut query = extend_query::<Postgres>(SCAN_QUERY, &mut params, tag_filter, offset, limit)?;
     if for_update {
-        query.push_str(" FOR UPDATE");
+        query.push_str(" FOR NO KEY UPDATE");
     }

     let mut batch = Vec::with_capacity(PAGE_SIZE);
diff --git a/src/backend/postgres/test_db.rs b/src/backend/postgres/test_db.rs
index 24d2f40e..8dccf404 100644
--- a/src/backend/postgres/test_db.rs
+++ b/src/backend/postgres/test_db.rs
@@ -81,6 +81,32 @@ impl TestDB {
             lock_txn: Some(lock_txn),
         })
     }
+
+    async fn close_internal(
+        mut lock_txn: Option<PgConnection>,
+        mut inst: Option<Store<PostgresStore>>,
+    ) -> Result<(), Error> {
+        if let Some(lock_txn) = lock_txn.take() {
+            lock_txn.close().await?;
+        }
+        if let Some(inst) = inst.take() {
+            timeout(Duration::from_secs(30), inst.close())
+                .await
+                .ok_or_else(|| {
+                    err_msg!(
+                        Backend,
+                        "Timed out waiting for the pool connection to close"
+                    )
+                })??;
+        }
+        Ok(())
+    }
+
+    /// Explicitly close the test database
+    pub async fn close(mut self) -> Result<(), Error> {
+        Self::close_internal(self.lock_txn.take(), self.inst.take()).await?;
+        Ok(())
+    }
 }

 impl std::ops::Deref for TestDB {
@@ -93,21 +119,14 @@ impl std::ops::Deref for TestDB {

 impl Drop for TestDB {
     fn drop(&mut self) {
-        if let Some(lock_txn) = self.lock_txn.take() {
+        if self.lock_txn.is_some() || self.inst.is_some() {
+            let lock_txn = self.lock_txn.take();
+            let inst = self.inst.take();
             spawn_ok(async {
-                lock_txn
-                    .close()
+                Self::close_internal(lock_txn, inst)
                     .await
                     .expect("Error closing database connection");
             });
         }
-        if let Some(inst) = self.inst.take() {
-            spawn_ok(async {
-                timeout(Duration::from_secs(30), inst.close())
-                    .await
-                    .expect("Timed out waiting for the pool connection to close")
-                    .expect("Error closing connection pool");
-            });
-        }
     }
 }
diff --git a/src/backend/sqlite/mod.rs b/src/backend/sqlite/mod.rs
index c34b65be..5d5265de 100644
--- a/src/backend/sqlite/mod.rs
+++ b/src/backend/sqlite/mod.rs
@@ -47,6 +47,9 @@ const FETCH_QUERY: &'static str = "SELECT i.id, i.value,
 const INSERT_QUERY: &'static str = "INSERT OR IGNORE INTO items (profile_id, kind, category, name, value, expiry)
     VALUES (?1, ?2, ?3, ?4, ?5, ?6)";
+const UPDATE_QUERY: &'static str =
+    "UPDATE items SET value=?5, expiry=?6 WHERE profile_id=?1 AND kind=?2
+    AND category=?3 AND name=?4 RETURNING id";
 const SCAN_QUERY: &'static str = "SELECT i.id, i.name, i.value,
     (SELECT GROUP_CONCAT(it.plaintext || ':' || HEX(it.name) || ':' || HEX(it.value))
     FROM items_tags it WHERE it.item_id = i.id) AS tags
@@ -56,6 +59,8 @@ const DELETE_ALL_QUERY: &'static str = "DELETE FROM items AS i
     WHERE i.profile_id = ?1 AND i.kind = ?2 AND i.category = ?3";
 const TAG_INSERT_QUERY: &'static str = "INSERT INTO items_tags (item_id, name, value, plaintext)
     VALUES (?1, ?2, ?3, ?4)";
+const TAG_DELETE_QUERY: &'static str = "DELETE FROM items_tags
+    WHERE item_id=?1";

 /// A Sqlite database store
 pub struct SqliteStore {
@@ -435,9 +440,6 @@ impl QueryBackend for DbSession<Sqlite> {
             .await?;
             let mut active = acquire_session(&mut *self).await?;
             let mut txn = active.as_transaction().await?;
-            if op == EntryOperation::Replace {
-                perform_remove(&mut txn, kind, &enc_category, &enc_name, false).await?;
-            }
             perform_insert(
                 &mut txn,
                kind,
                 &enc_category,
                 &enc_name,
                 &enc_value,
                 enc_tags,
                 expiry_ms,
+                op == EntryOperation::Insert,
             )
             .await?;
             txn.commit().await?;
@@ -484,8 +487,10 @@ impl ExtDatabase for Sqlite {
         Box::pin(async move {
             <Sqlite as Database>::TransactionManager::begin(&mut *conn).await?;
             if !nested {
-                sqlx::query("ROLLBACK").execute(&mut *conn).await?;
-                sqlx::query("BEGIN IMMEDIATE").execute(conn).await?;
+                // a no-op write transaction
+                sqlx::query("DELETE FROM config WHERE 0")
+                    .execute(&mut *conn)
+                    .await?;
             }
             Ok(())
         })
@@ -540,21 +545,41 @@ async fn perform_insert<'q>(
     enc_value: &[u8],
     enc_tags: Option<Vec<EncEntryTag>>,
     expiry_ms: Option<i64>,
+    new_row: bool,
 ) -> Result<(), Error> {
-    trace!("Insert entry");
-    let done = sqlx::query(INSERT_QUERY)
-        .bind(active.profile_id)
-        .bind(kind as i16)
-        .bind(enc_category)
-        .bind(enc_name)
-        .bind(enc_value)
-        .bind(expiry_ms.map(expiry_timestamp).transpose()?)
-        .execute(active.connection_mut())
-        .await?;
-    if done.rows_affected() == 0 {
-        return Err(err_msg!(Duplicate, "Duplicate row"));
-    }
-    let row_id = done.last_insert_rowid();
+    let row_id = if new_row {
+        trace!("Insert entry");
+        let done = sqlx::query(INSERT_QUERY)
+            .bind(active.profile_id)
+            .bind(kind as i16)
+            .bind(enc_category)
+            .bind(enc_name)
+            .bind(enc_value)
+            .bind(expiry_ms.map(expiry_timestamp).transpose()?)
+            .execute(active.connection_mut())
+            .await?;
+        if done.rows_affected() == 0 {
+            return Err(err_msg!(Duplicate, "Duplicate row"));
+        }
+        done.last_insert_rowid()
+    } else {
+        trace!("Update entry");
+        let row_id: i64 = sqlx::query_scalar(UPDATE_QUERY)
+            .bind(active.profile_id)
+            .bind(kind as i16)
+            .bind(enc_category)
+            .bind(enc_name)
+            .bind(enc_value)
+            .bind(expiry_ms.map(expiry_timestamp).transpose()?)
+            .fetch_one(active.connection_mut())
+            .await
+            .map_err(|_| err_msg!(NotFound, "Error updating existing row"))?;
+        sqlx::query(TAG_DELETE_QUERY)
+            .bind(row_id)
+            .execute(active.connection_mut())
+            .await?;
+        row_id
+    };
     if let Some(tags) = enc_tags {
         for tag in tags {
             sqlx::query(TAG_INSERT_QUERY)
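Both backends now funnel Insert and Replace through one code path, differing only in the `new_row` flag passed to `perform_insert`: a fresh row uses the conflict-checked INSERT, while a replacement issues an UPDATE ... RETURNING id and then clears the stale tag rows before re-inserting tags. A standalone sketch of that dispatch, with types simplified from the real crate:

// Simplified model of the insert-vs-update dispatch used by both backends.
#[derive(Debug, Clone, Copy, PartialEq)]
enum EntryOperation {
    Insert,
    Replace,
    Remove,
}

fn perform_insert(new_row: bool) -> &'static str {
    if new_row {
        "INSERT ... ON CONFLICT DO NOTHING RETURNING id"
    } else {
        "UPDATE ... RETURNING id, then DELETE stale tags"
    }
}

fn handle(operation: EntryOperation) -> &'static str {
    match operation {
        // Each alternative binds `op`, so one arm serves both operations.
        op @ EntryOperation::Insert | op @ EntryOperation::Replace => {
            perform_insert(op == EntryOperation::Insert)
        }
        EntryOperation::Remove => "DELETE ...",
    }
}

fn main() {
    assert!(handle(EntryOperation::Insert).starts_with("INSERT"));
    assert!(handle(EntryOperation::Replace).starts_with("UPDATE"));
    assert!(handle(EntryOperation::Remove).starts_with("DELETE"));
}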
diff --git a/src/backend/sqlite/provision.rs b/src/backend/sqlite/provision.rs
index 987ca9ff..d57b9e2b 100644
--- a/src/backend/sqlite/provision.rs
+++ b/src/backend/sqlite/provision.rs
@@ -1,10 +1,12 @@
-use std::borrow::Cow;
-use std::fs::remove_file;
-use std::io::ErrorKind as IoErrorKind;
-use std::str::FromStr;
+use std::{
+    borrow::Cow, fs::remove_file, io::ErrorKind as IoErrorKind, str::FromStr, time::Duration,
+};

 use sqlx::{
-    sqlite::{SqliteConnectOptions, SqlitePool, SqlitePoolOptions},
+    sqlite::{
+        SqliteAutoVacuum, SqliteConnectOptions, SqliteJournalMode, SqliteLockingMode, SqlitePool,
+        SqlitePoolOptions, SqliteSynchronous,
+    },
     ConnectOptions, Error as SqlxError, Row,
 };

@@ -20,18 +22,46 @@ use crate::{
     storage::{IntoOptions, Options, Store},
 };

+const DEFAULT_MIN_CONNECTIONS: u32 = 1;
+const DEFAULT_BUSY_TIMEOUT: Duration = Duration::from_secs(5);
+const DEFAULT_JOURNAL_MODE: SqliteJournalMode = SqliteJournalMode::Wal;
+const DEFAULT_LOCKING_MODE: SqliteLockingMode = SqliteLockingMode::Normal;
+const DEFAULT_SHARED_CACHE: bool = true;
+const DEFAULT_SYNCHRONOUS: SqliteSynchronous = SqliteSynchronous::Full;
+
 /// Configuration options for Sqlite stores
 #[derive(Debug)]
 pub struct SqliteStoreOptions {
     pub(crate) in_memory: bool,
     pub(crate) path: String,
+    pub(crate) busy_timeout: Duration,
     pub(crate) max_connections: u32,
+    pub(crate) min_connections: u32,
+    pub(crate) journal_mode: SqliteJournalMode,
+    pub(crate) locking_mode: SqliteLockingMode,
+    pub(crate) shared_cache: bool,
+    pub(crate) synchronous: SqliteSynchronous,
+}
+
+impl Default for SqliteStoreOptions {
+    fn default() -> Self {
+        Self::new(":memory:").expect("Error initializing with default options")
+    }
 }

 impl SqliteStoreOptions {
     /// Initialize `SqliteStoreOptions` from a generic set of options
     pub fn new<'a>(options: impl IntoOptions<'a>) -> Result<Self, Error> {
         let mut opts = options.into_options()?;
+        let busy_timeout = if let Some(timeout) = opts.query.remove("busy_timeout") {
+            Duration::from_millis(
+                timeout
+                    .parse()
+                    .map_err(err_map!(Input, "Error parsing 'busy_timeout' parameter"))?,
+            )
+        } else {
+            DEFAULT_BUSY_TIMEOUT
+        };
         let max_connections = if let Some(max_conn) = opts.query.remove("max_connections") {
             max_conn
                 .parse()
@@ -39,19 +69,62 @@ impl SqliteStoreOptions {
         } else {
             num_cpus::get() as u32
         };
+        let min_connections = if let Some(min_conn) = opts.query.remove("min_connections") {
+            min_conn
+                .parse()
+                .map_err(err_map!(Input, "Error parsing 'min_connections' parameter"))?
+        } else {
+            DEFAULT_MIN_CONNECTIONS
+        };
+        let journal_mode = if let Some(mode) = opts.query.remove("journal_mode") {
+            SqliteJournalMode::from_str(&mode)
+                .map_err(err_map!(Input, "Error parsing 'journal_mode' parameter"))?
+        } else {
+            DEFAULT_JOURNAL_MODE
+        };
+        let locking_mode = if let Some(mode) = opts.query.remove("locking_mode") {
+            SqliteLockingMode::from_str(&mode)
+                .map_err(err_map!(Input, "Error parsing 'locking_mode' parameter"))?
+        } else {
+            DEFAULT_LOCKING_MODE
+        };
+        let shared_cache = if let Some(cache) = opts.query.remove("cache") {
+            cache.eq_ignore_ascii_case("shared")
+        } else {
+            DEFAULT_SHARED_CACHE
+        };
+        let synchronous = if let Some(sync) = opts.query.remove("synchronous") {
+            SqliteSynchronous::from_str(&sync)
+                .map_err(err_map!(Input, "Error parsing 'synchronous' parameter"))?
+        } else {
+            DEFAULT_SYNCHRONOUS
+        };
+
         let mut path = opts.host.to_string();
         path.push_str(&*opts.path);
         Ok(Self {
             in_memory: path == ":memory:",
             path,
+            busy_timeout,
             max_connections,
+            min_connections,
+            journal_mode,
+            locking_mode,
+            shared_cache,
+            synchronous,
         })
     }

     async fn pool(&self, auto_create: bool) -> std::result::Result<SqlitePool, SqlxError> {
         #[allow(unused_mut)]
-        let mut conn_opts =
-            SqliteConnectOptions::from_str(self.path.as_ref())?.create_if_missing(auto_create);
+        let mut conn_opts = SqliteConnectOptions::from_str(self.path.as_ref())?
+            .create_if_missing(auto_create)
+            .auto_vacuum(SqliteAutoVacuum::Incremental)
+            .busy_timeout(self.busy_timeout)
+            .journal_mode(self.journal_mode)
+            .locking_mode(self.locking_mode)
+            .shared_cache(self.shared_cache)
+            .synchronous(self.synchronous);
         #[cfg(feature = "log")]
         {
             conn_opts.log_statements(log::LevelFilter::Debug);
@@ -61,7 +134,7 @@ impl SqliteStoreOptions {
             // maintains at least 1 connection.
             // for an in-memory database this is required to avoid dropping the database,
             // for a file database this signals other instances that the database is in use
-            .min_connections(1)
+            .min_connections(self.min_connections)
             .max_connections(self.max_connections)
             .test_before_acquire(false)
             .connect_with(conn_opts)
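The new parameters ride along on the store URI's query string and are consumed before the path is handed to SQLite. A usage sketch, assuming the `sqlite://` URI form accepted by `IntoOptions` and the module path used by the test suite; the values shown are illustrative (busy_timeout is parsed as milliseconds):

use aries_askar::backend::sqlite::SqliteStoreOptions;

fn main() -> Result<(), aries_askar::Error> {
    let _opts = SqliteStoreOptions::new(
        "sqlite://test.db?busy_timeout=2500&journal_mode=wal&synchronous=normal&cache=shared",
    )?;
    Ok(())
}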
diff --git a/src/error.rs b/src/error.rs
index 1a1d79de..b20aa9f7 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -84,7 +84,10 @@ impl Error {
         self.message.as_ref().map(String::as_str)
     }

-    pub(crate) fn with_cause<T: Into<Box<dyn StdError + Send + Sync>>>(mut self, err: T) -> Self {
+    pub(crate) fn with_cause<T: Into<Box<dyn StdError + Send + Sync>>>(
+        mut self,
+        err: T,
+    ) -> Self {
         self.cause = Some(err.into());
         self
     }
@@ -108,7 +111,7 @@ impl StdError for Error {
     fn source(&self) -> Option<&(dyn StdError + 'static)> {
         self.cause
             .as_ref()
-            .map(|err| unsafe { std::mem::transmute(&**err) })
+            .map(|err| &**err as &(dyn StdError + 'static))
     }
 }
diff --git a/src/ffi/handle.rs b/src/ffi/handle.rs
index 935ddd12..a1da1c79 100644
--- a/src/ffi/handle.rs
+++ b/src/ffi/handle.rs
@@ -1,40 +1,40 @@
-use std::{fmt::Display, marker::PhantomData, mem, sync::Arc};
+use std::{fmt::Display, mem, ptr, sync::Arc};

 use crate::error::Error;

-#[repr(transparent)]
-pub struct ArcHandle<T>(usize, PhantomData<T>);
+#[repr(C)]
+pub struct ArcHandle<T: Send>(*const T);

-impl<T> ArcHandle<T> {
+impl<T: Send> ArcHandle<T> {
     pub fn invalid() -> Self {
-        Self(0, PhantomData)
+        Self(ptr::null())
     }

     pub fn create(value: T) -> Self {
         let results = Arc::into_raw(Arc::new(value));
-        Self(results as usize, PhantomData)
+        Self(results)
     }

     pub fn load(&self) -> Result<Arc<T>, Error> {
         self.validate()?;
-        let slf = unsafe { Arc::from_raw(self.0 as *const T) };
-        let copy = slf.clone();
-        mem::forget(slf); // Arc::increment_strong_count(..) in 1.51
-        Ok(copy)
+        unsafe {
+            let result = mem::ManuallyDrop::new(Arc::from_raw(self.0));
+            Ok((&*result).clone())
+        }
     }

     pub fn remove(&self) {
-        if self.0 != 0 {
-            unsafe {
+        unsafe {
+            if !self.0.is_null() {
                 // Drop the initial reference. There could be others outstanding.
-                Arc::from_raw(self.0 as *const T);
+                Arc::decrement_strong_count(self.0);
             }
         }
     }

     #[inline]
     pub fn validate(&self) -> Result<(), Error> {
-        if self.0 == 0 {
+        if self.0.is_null() {
             Err(err_msg!("Invalid handle"))
         } else {
             Ok(())
@@ -42,9 +42,9 @@ impl<T: Send> ArcHandle<T> {
     }
 }

-impl<T> std::fmt::Display for ArcHandle<T> {
+impl<T: Send> std::fmt::Display for ArcHandle<T> {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        write!(f, "Handle({:p})", self.0 as *const T)
+        write!(f, "Handle({:p})", self.0)
     }
 }
@@ -62,7 +62,7 @@ macro_rules! new_sequence_handle (($newtype:ident, $counter:ident) => (
     static $counter: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(0);

     #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
-    #[repr(transparent)]
+    #[repr(C)]
     pub struct $newtype(pub usize);

     impl $crate::ffi::ResourceHandle for $newtype {
diff --git a/src/ffi/key.rs b/src/ffi/key.rs
index 9b7d973f..49ab97ad 100644
--- a/src/ffi/key.rs
+++ b/src/ffi/key.rs
@@ -365,7 +365,7 @@ pub extern "C" fn askar_key_verify_signature(
         trace!("Verify signature: {}", handle);
         check_useful_c_ptr!(out);
         let key = handle.load()?;
-        let verify = key.verify_signature(message.as_slice(),signature.as_slice(), sig_type.as_opt_str())?;
+        let verify = key.verify_signature(message.as_slice(), signature.as_slice(), sig_type.as_opt_str())?;
         unsafe { *out = verify as i8 };
         Ok(ErrorCode::Success)
     }
diff --git a/src/ffi/log.rs b/src/ffi/log.rs
index 440b4a4b..4b522812 100644
--- a/src/ffi/log.rs
+++ b/src/ffi/log.rs
@@ -1,12 +1,16 @@
 use std::ffi::CString;
 use std::os::raw::{c_char, c_void};
 use std::ptr;
+use std::sync::atomic::{AtomicBool, Ordering};

 use log::{LevelFilter, Metadata, Record};
+use once_cell::sync::OnceCell;

 use super::error::ErrorCode;
 use crate::error::Error;

+static LOGGER: OnceCell<CustomLogger> = OnceCell::new();
+
 pub type EnabledCallback = extern "C" fn(context: *const c_void, level: i32) -> i8;

 pub type LogCallback = extern "C" fn(
@@ -26,6 +30,7 @@ pub struct CustomLogger {
     enabled: Option<EnabledCallback>,
     log: LogCallback,
     flush: Option<FlushCallback>,
+    disabled: AtomicBool,
 }

 impl CustomLogger {
@@ -40,13 +45,20 @@ impl CustomLogger {
             enabled,
             log,
             flush,
+            disabled: AtomicBool::new(false),
         }
     }
+
+    fn disable(&self) {
+        self.disabled.store(true, Ordering::Release);
+    }
 }

 impl log::Log for CustomLogger {
     fn enabled(&self, metadata: &Metadata<'_>) -> bool {
-        if let Some(enabled_cb) = self.enabled {
+        if self.disabled.load(Ordering::Acquire) {
+            false
+        } else if let Some(enabled_cb) = self.enabled {
             enabled_cb(self.context, metadata.level() as i32) != 0
         } else {
             true
@@ -98,18 +110,30 @@ pub extern "C" fn askar_set_custom_logger(
 ) -> ErrorCode {
     catch_err! {
         let max_level = get_level_filter(max_level)?;
-        let logger = CustomLogger::new(context, enabled, log, flush);
-        log::set_boxed_logger(Box::new(logger)).map_err(err_map!(Unexpected))?;
+        if LOGGER.set(CustomLogger::new(context, enabled, log, flush)).is_err() {
+            return Err(err_msg!(Input, "Repeated logger initialization"));
+        }
+        log::set_logger(LOGGER.get().unwrap()).map_err(
+            |_| err_msg!(Input, "Repeated logger initialization"))?;
         log::set_max_level(max_level);
         debug!("Initialized custom logger");
         Ok(ErrorCode::Success)
     }
 }

+#[no_mangle]
+pub extern "C" fn askar_clear_custom_logger() {
+    debug!("Removing custom logger");
+    if let Some(logger) = LOGGER.get() {
+        logger.disable();
+    }
+}
+
 #[no_mangle]
 pub extern "C" fn askar_set_default_logger() -> ErrorCode {
     catch_err! {
-        env_logger::init();
+        env_logger::try_init().map_err(
+            |_| err_msg!(Input, "Repeated logger initialization"))?;
         debug!("Initialized default logger");
         Ok(ErrorCode::Success)
     }
 }
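Background on this change: the `log` crate only permits one logger per process and offers no way to unset it, which is why the FFI now keeps the logger in a static `OnceCell` and `askar_clear_custom_logger` merely flips a disable flag. A minimal model of that lifecycle (the `ToggleLogger` type is illustrative, not the crate's):

use std::sync::atomic::{AtomicBool, Ordering};

use log::{LevelFilter, Log, Metadata, Record};
use once_cell::sync::OnceCell;

// Install once, later "clear" by disabling: set_logger cannot be undone.
struct ToggleLogger {
    disabled: AtomicBool,
}

impl Log for ToggleLogger {
    fn enabled(&self, _metadata: &Metadata<'_>) -> bool {
        !self.disabled.load(Ordering::Acquire)
    }
    fn log(&self, record: &Record<'_>) {
        if self.enabled(record.metadata()) {
            println!("{}: {}", record.level(), record.args());
        }
    }
    fn flush(&self) {}
}

static LOGGER: OnceCell<ToggleLogger> = OnceCell::new();

fn main() {
    LOGGER
        .set(ToggleLogger { disabled: AtomicBool::new(false) })
        .ok()
        .expect("logger already set");
    log::set_logger(LOGGER.get().unwrap()).expect("logger already set");
    log::set_max_level(LevelFilter::Debug);

    log::info!("visible");
    LOGGER.get().unwrap().disabled.store(true, Ordering::Release);
    log::info!("suppressed");
}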
diff --git a/src/ffi/mod.rs b/src/ffi/mod.rs
index 5c67706e..29460c9f 100644
--- a/src/ffi/mod.rs
+++ b/src/ffi/mod.rs
@@ -32,6 +32,7 @@ ffi_support::define_string_destructor!(askar_string_free);

 pub struct EnsureCallback<T, F: Fn(Result<T, Error>)> {
     f: F,
+    resolved: bool,
     _pd: PhantomData<T>,
 }

@@ -39,20 +40,23 @@ impl<T, F: Fn(Result<T, Error>)> EnsureCallback<T, F> {
     pub fn new(f: F) -> Self {
         Self {
             f,
+            resolved: false,
             _pd: PhantomData,
         }
     }

-    pub fn resolve(self, value: Result<T, Error>) {
+    pub fn resolve(mut self, value: Result<T, Error>) {
+        self.resolved = true;
         (self.f)(value);
-        std::mem::forget(self);
     }
 }

 impl<T, F: Fn(Result<T, Error>)> Drop for EnsureCallback<T, F> {
     fn drop(&mut self) {
         // if std::thread::panicking() - capture trace?
-        (self.f)(Err(err_msg!(Unexpected)));
+        if !self.resolved {
+            (self.f)(Err(err_msg!(Unexpected)));
+        }
     }
 }
diff --git a/src/ffi/secret.rs b/src/ffi/secret.rs
index d6c15267..e24b0fb7 100644
--- a/src/ffi/secret.rs
+++ b/src/ffi/secret.rs
@@ -30,10 +30,12 @@ impl Default for SecretBuffer {

 impl SecretBuffer {
     pub fn from_secret(buffer: impl Into<SecretBytes>) -> Self {
-        let mut buf = buffer.into().into_boxed_slice();
+        let mut buf = buffer.into();
+        buf.shrink_to_fit();
+        debug_assert_eq!(buf.len(), buf.capacity());
+        let mut buf = mem::ManuallyDrop::new(buf.into_vec());
         let len = i64::try_from(buf.len()).expect("secret length exceeds i64::MAX");
         let data = buf.as_mut_ptr();
-        mem::forget(buf);
         Self { len, data }
     }
diff --git a/src/ffi/store.rs b/src/ffi/store.rs
index c8f8cedb..6d43ac03 100644
--- a/src/ffi/store.rs
+++ b/src/ffi/store.rs
@@ -1,9 +1,8 @@
 use std::{collections::BTreeMap, os::raw::c_char, ptr, str::FromStr, sync::Arc};

-use async_lock::RwLock;
+use async_lock::{Mutex as TryMutex, MutexGuardArc as TryMutexGuard, RwLock};
 use ffi_support::{rust_string_to_c, ByteBuffer, FfiStr};
 use once_cell::sync::Lazy;
-use option_lock::{Mutex as TryMutex, MutexGuardArc as TryMutexGuard};

 use super::{
     error::set_last_error,
@@ -87,7 +86,7 @@ where
     pub async fn remove(&self, handle: K) -> Option<Result<V, Error>> {
         self.map.write().await.remove(&handle).map(|(_s, v)| {
             Arc::try_unwrap(v)
-                .map(|item| item.into_inner().unwrap())
+                .map(|item| item.into_inner())
                 .map_err(|_| err_msg!(Busy, "Resource handle in use"))
         })
     }
@@ -100,7 +99,7 @@ where
             .ok_or_else(|| err_msg!("Invalid resource handle"))?
             .1
             .try_lock_arc()
-            .map_err(|_| err_msg!(Busy, "Resource handle in use"))
+            .ok_or_else(|| err_msg!(Busy, "Resource handle in use"))
     }

     pub async fn remove_all(&self, store: StoreHandle) -> Result<(), Error> {
diff --git a/src/future.rs b/src/future.rs
index 0aace01d..b87544ac 100644
--- a/src/future.rs
+++ b/src/future.rs
@@ -7,10 +7,12 @@ pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
 static RUNTIME: Lazy<Runtime> = Lazy::new(|| Runtime::new().expect("Error creating tokio runtime"));

+/// Block the current thread on an async task, when not running inside the scheduler.
 pub fn block_on<R>(f: impl Future<Output = R>) -> R {
     RUNTIME.block_on(f)
 }

+/// Run a blocking task without interrupting the async scheduler.
 #[inline]
 pub async fn unblock<F, T>(f: F) -> T
 where
@@ -23,16 +25,21 @@ where
         .expect("Error running blocking task")
 }

+/// Spawn an async task into the runtime.
 #[inline]
 pub fn spawn_ok(fut: impl Future<Output = ()> + Send + 'static) {
     RUNTIME.spawn(fut);
 }

+/// Wait until a specific duration has passed (used in tests).
+#[doc(hidden)] pub async fn sleep(dur: Duration) { let _rt = RUNTIME.enter(); tokio::time::sleep(dur).await } +/// Cancel an async task if it does not complete after a timeout (used in tests). +#[doc(hidden)] pub async fn timeout(dur: Duration, f: impl Future) -> Option { let _rt = RUNTIME.enter(); tokio::time::timeout(dur, f).await.ok() diff --git a/tests/backends.rs b/tests/backends.rs index a7fe96e1..ca77a734 100644 --- a/tests/backends.rs +++ b/tests/backends.rs @@ -1,19 +1,27 @@ mod utils; +const ERR_CLOSE: &'static str = "Error closing database"; + macro_rules! backend_tests { ($init:expr) => { use aries_askar::future::block_on; + use std::sync::Arc; + use $crate::utils::TestStore; #[test] fn init() { - block_on($init); + block_on(async { + let db = $init.await; + db.close().await.expect(ERR_CLOSE); + }); } #[test] fn create_remove_profile() { block_on(async { let db = $init.await; - super::utils::db_create_remove_profile(&db).await; + super::utils::db_create_remove_profile(db.clone()).await; + db.close().await.expect(ERR_CLOSE); }) } @@ -21,7 +29,8 @@ macro_rules! backend_tests { fn fetch_fail() { block_on(async { let db = $init.await; - super::utils::db_fetch_fail(&db).await; + super::utils::db_fetch_fail(db.clone()).await; + db.close().await.expect(ERR_CLOSE); }) } @@ -29,7 +38,8 @@ macro_rules! backend_tests { fn insert_fetch() { block_on(async { let db = $init.await; - super::utils::db_insert_fetch(&db).await; + super::utils::db_insert_fetch(db.clone()).await; + db.close().await.expect(ERR_CLOSE); }) } @@ -37,7 +47,8 @@ macro_rules! backend_tests { fn insert_duplicate() { block_on(async { let db = $init.await; - super::utils::db_insert_duplicate(&db).await; + super::utils::db_insert_duplicate(db.clone()).await; + db.close().await.expect(ERR_CLOSE); }) } @@ -45,7 +56,8 @@ macro_rules! backend_tests { fn insert_remove() { block_on(async { let db = $init.await; - super::utils::db_insert_remove(&db).await; + super::utils::db_insert_remove(db.clone()).await; + db.close().await.expect(ERR_CLOSE); }) } @@ -53,7 +65,8 @@ macro_rules! backend_tests { fn remove_missing() { block_on(async { let db = $init.await; - super::utils::db_remove_missing(&db).await; + super::utils::db_remove_missing(db.clone()).await; + db.close().await.expect(ERR_CLOSE); }) } @@ -61,7 +74,8 @@ macro_rules! backend_tests { fn replace_fetch() { block_on(async { let db = $init.await; - super::utils::db_replace_fetch(&db).await; + super::utils::db_replace_fetch(db.clone()).await; + db.close().await.expect(ERR_CLOSE); }) } @@ -69,7 +83,8 @@ macro_rules! backend_tests { fn replace_missing() { block_on(async { let db = $init.await; - super::utils::db_replace_missing(&db).await; + super::utils::db_replace_missing(db.clone()).await; + db.close().await.expect(ERR_CLOSE); }) } @@ -77,7 +92,8 @@ macro_rules! backend_tests { fn count() { block_on(async { let db = $init.await; - super::utils::db_count(&db).await; + super::utils::db_count(db.clone()).await; + db.close().await.expect(ERR_CLOSE); }) } @@ -85,7 +101,8 @@ macro_rules! backend_tests { fn count_exist() { block_on(async { let db = $init.await; - super::utils::db_count_exist(&db).await; + super::utils::db_count_exist(db.clone()).await; + db.close().await.expect(ERR_CLOSE); }) } @@ -93,7 +110,8 @@ macro_rules! backend_tests { fn scan() { block_on(async { let db = $init.await; - super::utils::db_scan(&db).await; + super::utils::db_scan(db.clone()).await; + db.close().await.expect(ERR_CLOSE); }) } @@ -101,23 +119,26 @@ macro_rules! 
backend_tests { fn remove_all() { block_on(async { let db = $init.await; - super::utils::db_remove_all(&db).await; + super::utils::db_remove_all(db.clone()).await; + db.close().await.expect(ERR_CLOSE); }) } - // #[test] - // fn keypair_create_fetch() { - // block_on(async { - // let db = $init.await; - // super::utils::db_keypair_create_fetch(&db).await; - // }) - // } + #[test] + fn keypair_create_fetch() { + block_on(async { + let db = $init.await; + super::utils::db_keypair_insert_fetch(db.clone()).await; + db.close().await.expect(ERR_CLOSE); + }) + } // #[test] // fn keypair_sign_verify() { // block_on(async { // let db = $init.await; - // super::utils::db_keypair_sign_verify(&db).await; + // super::utils::db_keypair_sign_verify(db.clone()).await; + // db.close().await.expect(ERR_CLOSE); // }) // } @@ -125,7 +146,8 @@ macro_rules! backend_tests { // fn keypair_pack_unpack_anon() { // block_on(async { // let db = $init.await; - // super::utils::db_keypair_pack_unpack_anon(&db).await; + // super::utils::db_keypair_pack_unpack_anon(db.clone()).await; + // db.close().await.expect(ERR_CLOSE); // }) // } @@ -133,7 +155,8 @@ macro_rules! backend_tests { // fn keypair_pack_unpack_auth() { // block_on(async { // let db = $init.await; - // super::utils::db_keypair_pack_unpack_auth(&db).await; + // super::utils::db_keypair_pack_unpack_auth(db).await; + // db.close().await.expect(ERR_CLOSE); // }) // } @@ -141,7 +164,8 @@ macro_rules! backend_tests { fn txn_rollback() { block_on(async { let db = $init.await; - super::utils::db_txn_rollback(&db).await; + super::utils::db_txn_rollback(db.clone()).await; + db.close().await.expect(ERR_CLOSE); }) } @@ -149,7 +173,8 @@ macro_rules! backend_tests { fn txn_drop() { block_on(async { let db = $init.await; - super::utils::db_txn_drop(&db).await; + super::utils::db_txn_drop(db.clone()).await; + db.close().await.expect(ERR_CLOSE); }) } @@ -157,7 +182,8 @@ macro_rules! backend_tests { fn session_drop() { block_on(async { let db = $init.await; - super::utils::db_session_drop(&db).await; + super::utils::db_session_drop(db.clone()).await; + db.close().await.expect(ERR_CLOSE); }) } @@ -165,7 +191,8 @@ macro_rules! backend_tests { fn txn_commit() { block_on(async { let db = $init.await; - super::utils::db_txn_commit(&db).await; + super::utils::db_txn_commit(db.clone()).await; + db.close().await.expect(ERR_CLOSE); }) } @@ -173,7 +200,17 @@ macro_rules! 
backend_tests { fn txn_fetch_for_update() { block_on(async { let db = $init.await; - super::utils::db_txn_fetch_for_update(&db).await; + super::utils::db_txn_fetch_for_update(db.clone()).await; + db.close().await.expect(ERR_CLOSE); + }) + } + + #[test] + fn txn_contention() { + block_on(async { + let db = $init.await; + super::utils::db_txn_contention(db.clone()).await; + db.close().await.expect(ERR_CLOSE); }) } }; @@ -244,7 +281,7 @@ mod sqlite { #[test] fn rekey_db() { log_init(); - let fname = format!("sqlite-test-{}.db", uuid::Uuid::new_v4().to_string()); + let fname = format!("sqlite-rekey-{}.db", uuid::Uuid::new_v4().to_string()); let key1 = generate_raw_store_key(None).expect("Error creating raw key"); let key2 = generate_raw_store_key(None).expect("Error creating raw key"); assert_ne!(key1, key2); @@ -280,13 +317,109 @@ mod sqlite { }) } - async fn init_db() -> Store { + #[test] + fn txn_contention_file() { + log_init(); + let fname = format!("sqlite-contention-{}.db", uuid::Uuid::new_v4().to_string()); + let key = generate_raw_store_key(None).expect("Error creating raw key"); + + block_on(async move { + let store = SqliteStoreOptions::new(fname.as_str()) + .expect("Error initializing sqlite store options") + .provision_backend(StoreKeyMethod::RawKey, key.as_ref(), None, true) + .await + .expect("Error provisioning sqlite store"); + + let db = std::sync::Arc::new(store); + super::utils::db_txn_contention(db.clone()).await; + db.close().await.expect("Error closing sqlite store"); + + SqliteStoreOptions::new(fname.as_str()) + .expect("Error initializing sqlite store options") + .remove_backend() + .await + .expect("Error removing sqlite store"); + }); + } + + #[cfg(feature = "stress_test")] + #[test] + fn stress_test() { + log_init(); + use sqlx::sqlite::{SqliteConnectOptions, SqlitePool, SqlitePoolOptions}; + use std::str::FromStr; + let conn_opts = SqliteConnectOptions::from_str("sqlite:test.db") + .unwrap() + .create_if_missing(true); + // .shared_cache(true); + block_on(async move { + let pool = SqlitePoolOptions::default() + // maintains at least 1 connection. 
+                // for an in-memory database this is required to avoid dropping the database,
+                // for a file database this signals other instances that the database is in use
+                .min_connections(1)
+                .max_connections(5)
+                .test_before_acquire(false)
+                .connect_with(conn_opts)
+                .await
+                .unwrap();
+
+            let mut conn = pool.begin().await.unwrap();
+            sqlx::query("CREATE TABLE test (name TEXT)")
+                .execute(&mut conn)
+                .await
+                .unwrap();
+            sqlx::query("INSERT INTO test (name) VALUES ('test')")
+                .execute(&mut conn)
+                .await
+                .unwrap();
+            conn.commit().await.unwrap();
+
+            const TASKS: usize = 25;
+            const COUNT: usize = 1000;
+
+            async fn fetch(pool: SqlitePool) -> Result<(), &'static str> {
+                // try to avoid panics in this section, as they will be raised on a tokio worker thread
+                for _ in 0..COUNT {
+                    let mut txn = pool.acquire().await.expect("Acquire error");
+                    sqlx::query("BEGIN IMMEDIATE")
+                        .execute(&mut txn)
+                        .await
+                        .expect("Transaction error");
+                    let _ = sqlx::query("SELECT * FROM test")
+                        .fetch_one(&mut txn)
+                        .await
+                        .expect("Error fetching row");
+                    sqlx::query("COMMIT")
+                        .execute(&mut txn)
+                        .await
+                        .expect("Commit error");
+                }
+                Ok(())
+            }
+
+            let mut tasks = vec![];
+            for _ in 0..TASKS {
+                tasks.push(tokio::spawn(fetch(pool.clone())));
+            }
+
+            for task in tasks {
+                if let Err(s) = task.await.unwrap() {
+                    panic!("Error in concurrent update task: {}", s);
+                }
+            }
+        });
+    }
+
-    async fn init_db() -> Store<SqliteStore> {
+    async fn init_db() -> Arc<Store<SqliteStore>> {
         log_init();
         let key = generate_raw_store_key(None).expect("Error creating raw key");
-        SqliteStoreOptions::in_memory()
-            .provision(StoreKeyMethod::RawKey, key, None, false)
-            .await
-            .expect("Error provisioning sqlite store")
+        Arc::new(
+            SqliteStoreOptions::in_memory()
+                .provision(StoreKeyMethod::RawKey, key, None, false)
+                .await
+                .expect("Error provisioning sqlite store"),
+        )
     }

     backend_tests!(init_db());
@@ -315,15 +448,38 @@

 #[cfg(feature = "pg_test")]
 mod postgres {
-    use aries_askar::backend::postgres::test_db::TestDB;
+    use aries_askar::{backend::postgres::test_db::TestDB, postgres::PostgresStore, Store};
+    use std::{future::Future, ops::Deref, pin::Pin};

     use super::*;

-    async fn init_db() -> TestDB {
+    #[derive(Clone, Debug)]
+    struct Wrap(Arc<TestDB>);
+
+    impl Deref for Wrap {
+        type Target = Store<PostgresStore>;
+
+        fn deref(&self) -> &Self::Target {
+            &**self.0
+        }
+    }
+
+    impl TestStore for Wrap {
+        type DB = PostgresStore;
+
+        fn close(self) -> Pin<Box<dyn Future<Output = Result<(), Error>>>> {
+            let db = Arc::try_unwrap(self.0).unwrap();
+            Box::pin(db.close())
+        }
+    }
+
+    async fn init_db() -> Wrap {
         log_init();
-        TestDB::provision()
-            .await
-            .expect("Error provisioning postgres test database")
+        Wrap(Arc::new(
+            TestDB::provision()
+                .await
+                .expect("Error provisioning postgres test database"),
+        ))
     }

     backend_tests!(init_db());
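The `Wrap` newtype above gives `TestDB` the same clone-plus-`Deref`-to-`Store` shape the `backend_tests!` macro expects. A generic sketch of that newtype pattern, runnable with dummy types:

use std::ops::Deref;
use std::sync::Arc;

// A cheaply cloneable handle that derefs to the inner resource, so code
// written against &Resource keeps working unchanged.
struct Wrap<T>(Arc<T>);

impl<T> Clone for Wrap<T> {
    fn clone(&self) -> Self {
        Wrap(Arc::clone(&self.0))
    }
}

impl<T> Deref for Wrap<T> {
    type Target = T;
    fn deref(&self) -> &T {
        &self.0
    }
}

struct Resource {
    name: &'static str,
}

fn main() {
    let db = Wrap(Arc::new(Resource { name: "store" }));
    let db2 = db.clone(); // same underlying resource
    assert_eq!(db.name, "store"); // field access through Deref
    assert!(Arc::ptr_eq(&db.0, &db2.0));
}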
diff --git a/tests/local_key.rs b/tests/local_key.rs
new file mode 100644
index 00000000..47d0da53
--- /dev/null
+++ b/tests/local_key.rs
@@ -0,0 +1,49 @@
+use aries_askar::kms::{KeyAlg, LocalKey};
+
+const ERR_CREATE_KEYPAIR: &'static str = "Error creating keypair";
+const ERR_SIGN: &'static str = "Error signing message";
+const ERR_VERIFY: &'static str = "Error verifying signature";
+
+pub async fn localkey_sign_verify() {
+    let keypair = LocalKey::generate(KeyAlg::Ed25519, true).expect(ERR_CREATE_KEYPAIR);
+
+    let message = b"message".to_vec();
+    let sig = keypair.sign_message(&message, None).expect(ERR_SIGN);
+
+    assert_eq!(
+        keypair
+            .verify_signature(&message, &sig, None)
+            .expect(ERR_VERIFY),
+        true
+    );
+
+    assert_eq!(
+        keypair
+            .verify_signature(b"bad input", &sig, None)
+            .expect(ERR_VERIFY),
+        false
+    );
+
+    assert_eq!(
+        keypair.verify_signature(
+            // [0u8; 64]
+            b"xt19s1sp2UZCGhy9rNyb1FtxdKiDGZZPNFnc1KiM9jYYEuHxuwNeFf1oQKsn8zv6yvYBGhXa83288eF4MqN1oDq",
+            &sig, None
+        ).expect(ERR_VERIFY),
+        false
+    );
+
+    assert_eq!(
+        keypair
+            .verify_signature(&message, b"bad sig", None)
+            .is_err(),
+        true
+    );
+
+    assert_eq!(
+        keypair
+            .verify_signature(&message, &sig, Some("invalid type"))
+            .is_err(),
+        true
+    );
+}
diff --git a/tests/utils/mod.rs b/tests/utils/mod.rs
index 7eae053d..f6da69d2 100644
--- a/tests/utils/mod.rs
+++ b/tests/utils/mod.rs
@@ -1,8 +1,16 @@
-use aries_askar::{Backend, Entry, EntryTag, ErrorKind, Store, TagFilter};
+use std::{fmt::Debug, future::Future, ops::Deref, pin::Pin, sync::Arc};
+
+use aries_askar::{
+    kms::{KeyAlg, LocalKey},
+    Backend, Entry, EntryTag, Error, ErrorKind, Store, TagFilter,
+};
+
+use tokio::task::spawn;

 const ERR_PROFILE: &'static str = "Error creating profile";
 const ERR_SESSION: &'static str = "Error starting session";
 const ERR_TRANSACTION: &'static str = "Error starting transaction";
+const ERR_COMMIT: &'static str = "Error committing transaction";
 const ERR_COUNT: &'static str = "Error performing count";
 const ERR_FETCH: &'static str = "Error fetching test row";
 const ERR_FETCH_ALL: &'static str = "Error fetching all test rows";
@@ -13,12 +21,27 @@ const ERR_REPLACE: &'static str = "Error replacing test row";
 const ERR_REMOVE_ALL: &'static str = "Error removing test rows";
 const ERR_SCAN: &'static str = "Error starting scan";
 const ERR_SCAN_NEXT: &'static str = "Error fetching scan rows";
-// const ERR_CREATE_KEYPAIR: &'static str = "Error creating keypair";
-// const ERR_FETCH_KEY: &'static str = "Error fetching key";
-// const ERR_SIGN: &'static str = "Error signing message";
-// const ERR_VERIFY: &'static str = "Error verifying signature";
+const ERR_CREATE_KEYPAIR: &'static str = "Error creating keypair";
+const ERR_INSERT_KEY: &'static str = "Error inserting key";
+const ERR_FETCH_KEY: &'static str = "Error fetching key";
+const ERR_LOAD_KEY: &'static str = "Error loading key";
+
+pub trait TestStore: Clone + Deref<Target = Store<Self::DB>> + Send + Sync {
+    type DB: Backend + Debug + 'static;
+
+    fn close(self) -> Pin<Box<dyn Future<Output = Result<(), Error>>>>;
+}
+
+impl<B: Backend + Debug + 'static> TestStore for Arc<Store<B>> {
+    type DB = B;
+
+    fn close(self) -> Pin<Box<dyn Future<Output = Result<(), Error>>>> {
+        let db = Arc::try_unwrap(self).unwrap();
+        Box::pin(db.close())
+    }
+}

-pub async fn db_create_remove_profile(db: &Store) {
+pub async fn db_create_remove_profile(db: impl TestStore) {
     let profile = db.create_profile(None).await.expect(ERR_PROFILE);
     assert_eq!(
         db.remove_profile(profile)
             .await
             .expect(ERR_PROFILE),
         true
     );
     assert_eq!(
         db.remove_profile("not a profile".to_string())
             .await
             .expect(ERR_PROFILE),
         false
     );
 }

-pub async fn db_fetch_fail(db: &Store) {
+pub async fn db_fetch_fail(db: impl TestStore) {
     let mut conn = db.session(None).await.expect(ERR_SESSION);
     let result = conn.fetch("cat", "name", false).await.expect(ERR_FETCH);
     assert_eq!(result.is_none(), true);
 }

-pub async fn db_insert_fetch(db: &Store) {
+pub async fn db_insert_fetch(db: impl TestStore) {
     let test_row = Entry::new(
         "category",
         "name",
@@ -78,7 +101,7 @@ pub async fn db_insert_fetch(db: &Store) {
     assert_eq!(rows[0], test_row);
 }

-pub async fn db_insert_duplicate(db: &Store) {
+pub async fn db_insert_duplicate(db: impl TestStore) {
     let test_row = Entry::new("category", "name", "value", Vec::new());

     let mut conn = db.session(None).await.expect(ERR_SESSION);
@@ -106,7 +129,7 @@ pub async fn db_insert_duplicate(db: &Store) {
     assert_eq!(err.kind(), ErrorKind::Duplicate);
 }

-pub async fn db_insert_remove(db: &Store)
{ +pub async fn db_insert_remove(db: impl TestStore) { let test_row = Entry::new("category", "name", "value", Vec::new()); let mut conn = db.session(None).await.expect(ERR_SESSION); @@ -126,14 +149,14 @@ pub async fn db_insert_remove(db: &Store) { .expect(ERR_REQ_ROW); } -pub async fn db_remove_missing(db: &Store) { +pub async fn db_remove_missing(db: impl TestStore) { let mut conn = db.session(None).await.expect(ERR_SESSION); let err = conn.remove("cat", "name").await.expect_err(ERR_REQ_ERR); assert_eq!(err.kind(), ErrorKind::NotFound); } -pub async fn db_replace_fetch(db: &Store) { +pub async fn db_replace_fetch(db: impl TestStore) { let test_row = Entry::new("category", "name", "value", Vec::new()); let mut conn = db.session(None).await.expect(ERR_SESSION); @@ -168,7 +191,7 @@ pub async fn db_replace_fetch(db: &Store) { assert_eq!(row, replace_row); } -pub async fn db_replace_missing(db: &Store) { +pub async fn db_replace_missing(db: impl TestStore) { let test_row = Entry::new("category", "name", "value", Vec::new()); let mut conn = db.session(None).await.expect(ERR_SESSION); @@ -186,7 +209,7 @@ pub async fn db_replace_missing(db: &Store) { assert_eq!(err.kind(), ErrorKind::NotFound); } -pub async fn db_count(db: &Store) { +pub async fn db_count(db: impl TestStore) { let category = "category".to_string(); let test_rows = vec![Entry::new(&category, "name", "value", Vec::new())]; @@ -213,7 +236,7 @@ pub async fn db_count(db: &Store) { assert_eq!(count, 0); } -pub async fn db_count_exist(db: &Store) { +pub async fn db_count_exist(db: impl TestStore) { let test_row = Entry::new( "category", "name", @@ -352,7 +375,7 @@ pub async fn db_count_exist(db: &Store) { ); } -pub async fn db_scan(db: &Store) { +pub async fn db_scan(db: impl TestStore) { let category = "category".to_string(); let test_rows = vec![Entry::new( &category, @@ -400,7 +423,7 @@ pub async fn db_scan(db: &Store) { assert_eq!(rows, None); } -pub async fn db_remove_all(db: &Store) { +pub async fn db_remove_all(db: impl TestStore) { let test_rows = vec![ Entry::new( "category", @@ -460,117 +483,30 @@ pub async fn db_remove_all(db: &Store) { assert_eq!(removed, 2); } -// pub async fn db_keypair_create_fetch(db: &Store) { -// let mut conn = db.session(None).await.expect(ERR_SESSION); - -// let metadata = "meta".to_owned(); -// let key_info = conn -// .create_keypair(KeyAlg::Ed25519, Some(&metadata), None, None) -// .await -// .expect(ERR_CREATE_KEYPAIR); -// assert_eq!(key_info.params.metadata, Some(metadata)); - -// let found = conn -// .fetch_key(key_info.category.clone(), &key_info.ident, false) -// .await -// .expect(ERR_FETCH_KEY); -// assert_eq!(Some(key_info), found); -// } - -// pub async fn db_keypair_sign_verify(db: &Store) { -// let mut conn = db.session(None).await.expect(ERR_SESSION); - -// let key_info = conn -// .create_keypair(KeyAlg::Ed25519, None, None, None) -// .await -// .expect(ERR_CREATE_KEYPAIR); - -// let message = b"message".to_vec(); -// let sig = conn -// .sign_message(&key_info.ident, &message) -// .await -// .expect(ERR_SIGN); - -// assert_eq!( -// verify_signature(&key_info.ident, &message, &sig).expect(ERR_VERIFY), -// true -// ); - -// assert_eq!( -// verify_signature(&key_info.ident, b"bad input", &sig).expect(ERR_VERIFY), -// false -// ); - -// assert_eq!( -// verify_signature( -// &key_info.ident, -// // [0u8; 64] -// b"xt19s1sp2UZCGhy9rNyb1FtxdKiDGZZPNFnc1KiM9jYYEuHxuwNeFf1oQKsn8zv6yvYBGhXa83288eF4MqN1oDq", -// &sig -// ).expect(ERR_VERIFY), -// false -// ); - -// assert_eq!( -// 
verify_signature(&key_info.ident, &message, b"bad sig").is_err(), -// true -// ); - -// let err = verify_signature("not a key", &message, &sig).expect_err(ERR_REQ_ERR); -// assert_eq!(err.kind(), ErrorKind::Input); -// } - -// pub async fn db_keypair_pack_unpack_anon(db: &Store) { -// let mut conn = db.session(None).await.expect(ERR_SESSION); - -// let recip_key = conn -// .create_keypair(KeyAlg::Ed25519, None, None, None) -// .await -// .expect(ERR_CREATE_KEYPAIR); - -// let msg = b"message".to_vec(); - -// let packed = conn -// .pack_message(vec![recip_key.ident.as_str()], None, &msg) -// .await -// .expect(ERR_PACK); - -// let (unpacked, p_recip, p_send) = conn.unpack_message(&packed).await.expect(ERR_UNPACK); -// assert_eq!(unpacked, msg); -// assert_eq!(p_recip.to_string(), recip_key.ident); -// assert_eq!(p_send, None); -// } - -// pub async fn db_keypair_pack_unpack_auth(db: &Store) { -// let mut conn = db.session(None).await.expect(ERR_SESSION); - -// let sender_key = conn -// .create_keypair(KeyAlg::Ed25519, None, None, None) -// .await -// .expect(ERR_CREATE_KEYPAIR); -// let recip_key = conn -// .create_keypair(KeyAlg::Ed25519, None, None, None) -// .await -// .expect(ERR_CREATE_KEYPAIR); - -// let msg = b"message".to_vec(); - -// let packed = conn -// .pack_message( -// vec![recip_key.ident.as_str()], -// Some(&sender_key.ident), -// &msg, -// ) -// .await -// .expect(ERR_PACK); - -// let (unpacked, p_recip, p_send) = conn.unpack_message(&packed).await.expect(ERR_UNPACK); -// assert_eq!(unpacked, msg); -// assert_eq!(p_recip.to_string(), recip_key.ident); -// assert_eq!(p_send.map(|k| k.to_string()), Some(sender_key.ident)); -// } - -pub async fn db_txn_rollback(db: &Store) { +pub async fn db_keypair_insert_fetch(db: impl TestStore) { + let keypair = LocalKey::generate(KeyAlg::Ed25519, false).expect(ERR_CREATE_KEYPAIR); + + let mut conn = db.session(None).await.expect(ERR_SESSION); + + let key_name = "testkey"; + let metadata = "meta"; + conn.insert_key(&key_name, &keypair, Some(metadata), None, None) + .await + .expect(ERR_INSERT_KEY); + + let found = conn + .fetch_key(&key_name, false) + .await + .expect(ERR_FETCH_KEY) + .expect(ERR_REQ_ROW); + assert_eq!(found.algorithm(), Some(KeyAlg::Ed25519.as_str())); + assert_eq!(found.name(), key_name); + assert_eq!(found.metadata(), Some(metadata)); + assert_eq!(found.is_local(), true); + found.load_local_key().expect(ERR_LOAD_KEY); +} + +pub async fn db_txn_rollback(db: impl TestStore) { let test_row = Entry::new("category", "name", "value", Vec::new()); let mut conn = db.transaction(None).await.expect(ERR_TRANSACTION); @@ -598,7 +534,7 @@ pub async fn db_txn_rollback(db: &Store) { assert_eq!(row, None); } -pub async fn db_txn_drop(db: &Store) { +pub async fn db_txn_drop(db: impl TestStore) { let test_row = Entry::new("category", "name", "value", Vec::new()); let mut conn = db @@ -628,7 +564,7 @@ pub async fn db_txn_drop(db: &Store) { } // test that session does NOT have transaction rollback behaviour -pub async fn db_session_drop(db: &Store) { +pub async fn db_session_drop(db: impl TestStore) { let test_row = Entry::new("category", "name", "value", Vec::new()); let mut conn = db.session(None).await.expect(ERR_SESSION); @@ -654,7 +590,7 @@ pub async fn db_session_drop(db: &Store) { assert_eq!(row, Some(test_row)); } -pub async fn db_txn_commit(db: &Store) { +pub async fn db_txn_commit(db: impl TestStore) { let test_row = Entry::new("category", "name", "value", Vec::new()); let mut conn = 
db.transaction(None).await.expect(ERR_TRANSACTION); @@ -669,7 +605,7 @@ pub async fn db_txn_commit(db: &Store) { .await .expect(ERR_INSERT); - conn.commit().await.expect("Error committing transaction"); + conn.commit().await.expect(ERR_COMMIT); let mut conn = db.session(None).await.expect(ERR_SESSION); @@ -680,7 +616,7 @@ pub async fn db_txn_commit(db: &Store) { assert_eq!(row, Some(test_row)); } -pub async fn db_txn_fetch_for_update(db: &Store) { +pub async fn db_txn_fetch_for_update(db: impl TestStore) { let test_row = Entry::new("category", "name", "value", Vec::new()); let mut conn = db.transaction(None).await.expect(ERR_TRANSACTION); @@ -711,5 +647,90 @@ pub async fn db_txn_fetch_for_update(db: &Store) { assert_eq!(rows.len(), 1); assert_eq!(rows[0], test_row); - conn.commit().await.expect("Error committing transaction"); + conn.commit().await.expect(ERR_COMMIT); +} + +pub async fn db_txn_contention(db: impl TestStore + 'static) { + let test_row = Entry::new( + "category", + "count", + "0", + vec![ + EntryTag::Encrypted("t1".to_string(), "v1".to_string()), + EntryTag::Plaintext("t2".to_string(), "v2".to_string()), + ], + ); + + let mut conn = db.transaction(None).await.expect(ERR_TRANSACTION); + + conn.insert( + &test_row.category, + &test_row.name, + &test_row.value, + Some(test_row.tags.as_slice()), + None, + ) + .await + .expect(ERR_INSERT); + + conn.commit().await.expect(ERR_COMMIT); + + const TASKS: usize = 10; + const INC: usize = 1000; + + async fn inc(db: impl TestStore, category: String, name: String) -> Result<(), &'static str> { + // try to avoid panics in this section, as they will be raised on a tokio worker thread + for _ in 0..INC { + let mut conn = db.transaction(None).await.expect(ERR_TRANSACTION); + let row = conn + .fetch(&category, &name, true) + .await + .map_err(|e| { + log::error!("{:?}", e); + ERR_FETCH + })? + .ok_or(ERR_REQ_ROW)?; + let val: usize = str::parse(row.value.as_opt_str().ok_or("Non-string counter value")?) 
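            // the fetch above passed `for_update = true`, so this transaction
            // holds an update lock on the row where the backend supports it,
            // serializing these concurrent read-modify-write increments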
+ .map_err(|_| "Error parsing counter value")?; + conn.replace( + &category, + &name, + &format!("{}", val + 1).as_bytes(), + Some(row.tags.as_slice()), + None, + ) + .await + .map_err(|e| { + log::error!("{:?}", e); + ERR_REPLACE + })?; + conn.commit().await.map_err(|_| ERR_COMMIT)?; + } + Ok(()) + } + + let mut tasks = vec![]; + for _ in 0..TASKS { + tasks.push(spawn(inc( + db.clone(), + test_row.category.clone(), + test_row.name.clone(), + ))); + } + + // JoinSet is not stable yet, just await all the tasks + for task in tasks { + if let Err(s) = task.await.unwrap() { + panic!("Error in concurrent update task: {}", s); + } + } + + // check the total + let mut conn = db.session(None).await.expect(ERR_SESSION); + let row = conn + .fetch(&test_row.category, &test_row.name, false) + .await + .expect(ERR_FETCH) + .expect(ERR_REQ_ROW); + assert_eq!(row.value, format!("{}", TASKS * INC).as_bytes()); } diff --git a/wrappers/python/aries_askar/__init__.py b/wrappers/python/aries_askar/__init__.py index 09a6f9eb..608e045b 100644 --- a/wrappers/python/aries_askar/__init__.py +++ b/wrappers/python/aries_askar/__init__.py @@ -1,6 +1,6 @@ """aries-askar Python wrapper library""" -from .bindings import version, Encrypted +from .bindings import Encrypted, version from .error import AskarError, AskarErrorCode from .key import Key from .store import Entry, EntryList, KeyEntry, KeyEntryList, Session, Store diff --git a/wrappers/python/aries_askar/bindings.py b/wrappers/python/aries_askar/bindings.py deleted file mode 100644 index 85a8b4c7..00000000 --- a/wrappers/python/aries_askar/bindings.py +++ /dev/null @@ -1,1368 +0,0 @@ -"""Low-level interaction with the aries-askar library.""" - -import asyncio -import json -import logging -import os -import sys -from ctypes import ( - Array, - CDLL, - CFUNCTYPE, - POINTER, - Structure, - byref, - c_char_p, - c_int8, - c_int32, - c_int64, - c_size_t, - c_void_p, - c_ubyte, -) -from ctypes.util import find_library -from typing import Optional, Tuple, Union - -from .error import AskarError, AskarErrorCode -from .types import EntryOperation, KeyAlg, SeedMethod - - -CALLBACKS = {} -LIB: CDLL = None -LOGGER = logging.getLogger(__name__) -LOG_LEVELS = { - 1: logging.ERROR, - 2: logging.WARNING, - 3: logging.INFO, - 4: logging.DEBUG, -} -MODULE_NAME = __name__.split(".")[0] - - -class StoreHandle(c_size_t): - """Index of an active Store instance.""" - - def __repr__(self) -> str: - """Format store handle as a string.""" - return f"{self.__class__.__name__}({self.value})" - - async def close(self): - """Close the store, waiting for any active connections.""" - if not getattr(self, "_closed", False): - await do_call_async("askar_store_close", self) - setattr(self, "_closed", True) - - def __del__(self): - """Close the store when there are no more references to this object.""" - if not getattr(self, "_closed", False) and self: - do_call("askar_store_close", self, c_void_p()) - - -class SessionHandle(c_size_t): - """Index of an active Session/Transaction instance.""" - - def __repr__(self) -> str: - """Format session handle as a string.""" - return f"{self.__class__.__name__}({self.value})" - - async def close(self, commit: bool = False): - """Close the session.""" - if not getattr(self, "_closed", False) and self: - await do_call_async( - "askar_session_close", - self, - c_int8(commit), - ) - setattr(self, "_closed", True) - - def __del__(self): - """Close the session when there are no more references to this object.""" - if not getattr(self, "_closed", False) and self: - 
do_call("askar_session_close", self, c_int8(0), c_void_p()) - - -class ScanHandle(c_size_t): - """Index of an active Store scan instance.""" - - def __repr__(self) -> str: - """Format scan handle as a string.""" - return f"{self.__class__.__name__}({self.value})" - - def __del__(self): - """Close the scan when there are no more references to this object.""" - if self: - get_library().askar_scan_free(self) - - -class EntryListHandle(c_size_t): - """Pointer to an active EntryList instance.""" - - def get_category(self, index: int) -> str: - """Get the entry category.""" - cat = StrBuffer() - do_call( - "askar_entry_list_get_category", - self, - c_int32(index), - byref(cat), - ) - return str(cat) - - def get_name(self, index: int) -> str: - """Get the entry name.""" - name = StrBuffer() - do_call( - "askar_entry_list_get_name", - self, - c_int32(index), - byref(name), - ) - return str(name) - - def get_value(self, index: int) -> memoryview: - """Get the entry value.""" - val = ByteBuffer() - do_call("askar_entry_list_get_value", self, c_int32(index), byref(val)) - return memoryview(val.raw) - - def get_tags(self, index: int) -> dict: - """Get the entry tags.""" - tags = StrBuffer() - do_call( - "askar_entry_list_get_tags", - self, - c_int32(index), - byref(tags), - ) - if tags: - tags = json.loads(tags.value) - for t in tags: - if isinstance(tags[t], list): - tags[t] = set(tags[t]) - else: - tags = dict() - return tags - - def __repr__(self) -> str: - """Format entry list handle as a string.""" - return f"{self.__class__.__name__}({self.value})" - - def __del__(self): - """Free the entry set when there are no more references.""" - if self: - get_library().askar_entry_list_free(self) - - -class KeyEntryListHandle(c_size_t): - """Pointer to an active KeyEntryList instance.""" - - def get_algorithm(self, index: int) -> str: - """Get the key algorithm.""" - name = StrBuffer() - do_call( - "askar_key_entry_list_get_algorithm", - self, - c_int32(index), - byref(name), - ) - return str(name) - - def get_name(self, index: int) -> str: - """Get the key name.""" - name = StrBuffer() - do_call( - "askar_key_entry_list_get_name", - self, - c_int32(index), - byref(name), - ) - return str(name) - - def get_metadata(self, index: int) -> str: - """Get for the key metadata.""" - metadata = StrBuffer() - do_call( - "askar_key_entry_list_get_metadata", - self, - c_int32(index), - byref(metadata), - ) - return str(metadata) - - def get_tags(self, index: int) -> dict: - """Get the key tags.""" - tags = StrBuffer() - do_call( - "askar_key_entry_list_get_tags", - self, - c_int32(index), - byref(tags), - ) - return json.loads(tags.value) if tags else None - - def load_key(self, index: int) -> "LocalKeyHandle": - """Load the key instance.""" - handle = LocalKeyHandle() - do_call( - "askar_key_entry_list_load_local", - self, - c_int32(index), - byref(handle), - ) - return handle - - def __repr__(self) -> str: - """Format key entry list handle as a string.""" - return f"{self.__class__.__name__}({self.value})" - - def __del__(self): - """Free the key entry set when there are no more references.""" - if self: - get_library().askar_key_entry_list_free(self) - - -class LocalKeyHandle(c_size_t): - """Pointer to an active LocalKey instance.""" - - def __repr__(self) -> str: - """Format key handle as a string.""" - return f"{self.__class__.__name__}({self.value})" - - def __del__(self): - """Free the key when there are no more references.""" - if self: - get_library().askar_key_free(self) - - -class FfiByteBuffer(Structure): 
- """A byte buffer allocated by python.""" - - _fields_ = [ - ("len", c_int64), - ("value", POINTER(c_ubyte)), - ] - - -class RawBuffer(Structure): - """A byte buffer allocated by the library.""" - - _fields_ = [ - ("len", c_int64), - ("data", c_void_p), - ] - - -class ByteBuffer(Structure): - """A managed byte buffer allocated by the library.""" - - _fields_ = [("buffer", RawBuffer)] - - @property - def raw(self) -> Array: - ret = (c_ubyte * self.buffer.len).from_address(self.buffer.data) - setattr(ret, "_ref_", self) # ensure buffer is not dropped - return ret - - def __bytes__(self) -> bytes: - return bytes(self.raw) - - def __repr__(self) -> str: - """Format byte buffer as a string.""" - return repr(bytes(self)) - - def __del__(self): - """Call the byte buffer destructor when this instance is released.""" - get_library().askar_buffer_free(self.buffer) - - -class StrBuffer(c_char_p): - """A string allocated by the library.""" - - @classmethod - def from_param(cls): - """Returns the type ctypes should use for loading the result.""" - return c_void_p - - def is_none(self) -> bool: - """Check if the returned string pointer is null.""" - return self.value is None - - def opt_str(self) -> Optional[str]: - """Convert to an optional string.""" - val = self.value - return val.decode("utf-8") if val is not None else None - - def __bytes__(self) -> bytes: - """Convert to bytes.""" - return self.value - - def __str__(self): - """Convert to a string.""" - # not allowed to return None - val = self.opt_str() - return val if val is not None else "" - - def __del__(self): - """Call the string destructor when this instance is released.""" - get_library().askar_string_free(self) - - -class AeadParams(Structure): - """A byte buffer allocated by the library.""" - - _fields_ = [ - ("nonce_length", c_int32), - ("tag_length", c_int32), - ] - - def __repr__(self) -> str: - """Format AEAD params as a string.""" - return ( - f"" - ) - - -class Encrypted(Structure): - """The result of an AEAD encryption operation.""" - - _fields_ = [ - ("buffer", RawBuffer), - ("tag_pos", c_int64), - ("nonce_pos", c_int64), - ] - - def __getitem__(self, idx) -> bytes: - arr = (c_ubyte * self.buffer.len).from_address(self.buffer.data) - return bytes(arr[idx]) - - def __bytes__(self) -> bytes: - """Convert to bytes.""" - return self.ciphertext_tag - - @property - def ciphertext_tag(self) -> bytes: - """Accessor for the combined ciphertext and tag.""" - p = self.nonce_pos - return self[:p] - - @property - def ciphertext(self) -> bytes: - """Accessor for the ciphertext.""" - p = self.tag_pos - return self[:p] - - @property - def nonce(self) -> bytes: - """Accessor for the nonce.""" - p = self.nonce_pos - return self[p:] - - @property - def tag(self) -> bytes: - """Accessor for the authentication tag.""" - p1 = self.tag_pos - p2 = self.nonce_pos - return self[p1:p2] - - @property - def parts(self) -> Tuple[bytes, bytes, bytes]: - """Accessor for the ciphertext, tag, and nonce.""" - p1 = self.tag_pos - p2 = self.nonce_pos - return self[:p1], self[p1:p2], self[p2:] - - def __repr__(self) -> str: - """Format encrypted value as a string.""" - return ( - f"" - ) - - def __del__(self): - """Call the byte buffer destructor when this instance is released.""" - get_library().askar_buffer_free(self.buffer) - - -def get_library() -> CDLL: - """Return the CDLL instance, loading it if necessary.""" - global LIB - if LIB is None: - LIB = _load_library("aries_askar") - _init_logger() - return LIB - - -def _load_library(lib_name: str) -> CDLL: - 
"""Load the CDLL library. - The python module directory is searched first, followed by the usual - library resolution for the current system. - """ - lib_prefix_mapping = {"win32": ""} - lib_suffix_mapping = {"darwin": ".dylib", "win32": ".dll"} - try: - os_name = sys.platform - lib_prefix = lib_prefix_mapping.get(os_name, "lib") - lib_suffix = lib_suffix_mapping.get(os_name, ".so") - lib_path = os.path.join( - os.path.dirname(__file__), f"{lib_prefix}{lib_name}{lib_suffix}" - ) - return CDLL(lib_path) - except KeyError: - LOGGER.debug("Unknown platform for shared library") - except OSError: - LOGGER.warning("Library not loaded from python package") - - lib_path = find_library(lib_name) - if not lib_path: - raise AskarError( - AskarErrorCode.WRAPPER, f"Library not found in path: {lib_path}" - ) - try: - return CDLL(lib_path) - except OSError as e: - raise AskarError( - AskarErrorCode.WRAPPER, f"Error loading library: {lib_path}" - ) from e - - -def _init_logger(): - logger = logging.getLogger(MODULE_NAME) - if logging.getLevelName("TRACE") == "Level TRACE": - # avoid redefining TRACE if another library has added it - logging.addLevelName(5, "TRACE") - - def _enabled(_context, level: int) -> bool: - return logger.isEnabledFor(LOG_LEVELS.get(level, level)) - - def _log( - _context, - level: int, - target: c_char_p, - message: c_char_p, - module_path: c_char_p, - file_name: c_char_p, - line: int, - ): - logger.getChild("native." + target.decode().replace("::", ".")).log( - LOG_LEVELS.get(level, level), - "\t%s:%d | %s", - file_name.decode() if file_name else None, - line, - message.decode(), - ) - - _init_logger.enabled_cb = CFUNCTYPE(c_int8, c_void_p, c_int32)(_enabled) - - _init_logger.log_cb = CFUNCTYPE( - None, c_void_p, c_int32, c_char_p, c_char_p, c_char_p, c_char_p, c_int32 - )(_log) - - if os.getenv("RUST_LOG"): - # level from environment - level = -1 - else: - # inherit current level from logger - level = _convert_log_level(logger.level or logger.parent.level) - - do_call( - "askar_set_custom_logger", - c_void_p(), # context - _init_logger.log_cb, - _init_logger.enabled_cb, - c_void_p(), # flush - c_int32(level), - ) - - -def set_max_log_level(level: Union[str, int, None]): - get_library() # ensure logger is initialized - set_level = _convert_log_level(level) - do_call("askar_set_max_log_level", c_int32(set_level)) - - -def _convert_log_level(level: Union[str, int, None]): - if level is None or level == "-1": - return -1 - else: - if isinstance(level, str): - level = level.upper() - name = logging.getLevelName(level) - for k, v in LOG_LEVELS.items(): - if logging.getLevelName(v) == name: - return k - return 0 - - -def _fulfill_future(fut: asyncio.Future, result, err: Exception = None): - """Resolve a callback future given the result and exception, if any.""" - if fut.cancelled(): - LOGGER.debug("callback previously cancelled") - elif err: - fut.set_exception(err) - else: - fut.set_result(result) - - -def _create_callback(cb_type: CFUNCTYPE, fut: asyncio.Future, post_process=None): - """Create a callback to handle the response from an async library method.""" - - def _cb(id: int, err: int, result=None): - """Callback function passed to the CFUNCTYPE for invocation.""" - if post_process: - result = post_process(result) - exc = get_current_error() if err else None - try: - (loop, _cb) = CALLBACKS.pop(fut) - except KeyError: - LOGGER.debug("callback already fulfilled") - return - loop.call_soon_threadsafe(lambda: _fulfill_future(fut, result, exc)) - - res = cb_type(_cb) - return res - - 
-def do_call(fn_name, *args): - """Perform a synchronous library function call.""" - lib_fn = getattr(get_library(), fn_name) - lib_fn.restype = c_int64 - result = lib_fn(*args) - if result: - raise get_current_error(True) - - -def do_call_async( - fn_name, *args, return_type=None, post_process=None -) -> asyncio.Future: - """Perform an asynchronous library function call.""" - lib_fn = getattr(get_library(), fn_name) - lib_fn.restype = c_int64 - loop = asyncio.get_event_loop() - fut = loop.create_future() - cf_args = [None, c_int64, c_int64] - if return_type: - cf_args.append(return_type) - cb_type = CFUNCTYPE(*cf_args) # could be cached - cb_res = _create_callback(cb_type, fut, post_process) - # keep a reference to the callback function to avoid it being freed - CALLBACKS[fut] = (loop, cb_res) - result = lib_fn(*args, cb_res, c_void_p()) # not making use of callback ID - if result: - # callback will not be executed - if CALLBACKS.pop(fut): - fut.set_exception(get_current_error()) - return fut - - -def encode_str(arg: Optional[Union[str, bytes]]) -> c_char_p: - """ - Encode an optional input argument as a string. - - Returns: None if the argument is None, otherwise the value encoded utf-8. - """ - if arg is None: - return c_char_p() - if isinstance(arg, str): - arg = arg.encode("utf-8") - return c_char_p(arg) - - -def encode_bytes( - arg: Optional[Union[str, bytes, ByteBuffer, FfiByteBuffer]] -) -> Union[FfiByteBuffer, ByteBuffer]: - if isinstance(arg, ByteBuffer) or isinstance(arg, FfiByteBuffer): - return arg - buf = FfiByteBuffer() - if isinstance(arg, memoryview): - buf.len = arg.nbytes - if arg.contiguous and not arg.readonly: - buf.value = (c_ubyte * buf.len).from_buffer(arg.obj) - else: - buf.value = (c_ubyte * buf.len).from_buffer_copy(arg.obj) - elif isinstance(arg, bytearray): - buf.len = len(arg) - if buf.len > 0: - buf.value = (c_ubyte * buf.len).from_buffer(arg) - elif arg is not None: - if isinstance(arg, str): - arg = arg.encode("utf-8") - buf.len = len(arg) - if buf.len > 0: - buf.value = (c_ubyte * buf.len).from_buffer_copy(arg) - return buf - - -def encode_tags(tags: Optional[dict]) -> c_char_p: - """Encode the tags as a JSON string.""" - if tags: - tags = json.dumps( - { - name: (list(value) if isinstance(value, set) else value) - for name, value in tags.items() - } - ) - else: - tags = None - return encode_str(tags) - - -def get_current_error(expect: bool = False) -> Optional[AskarError]: - """ - Get the error result from the previous failed API method. 
- - Args: - expect: Return a default error message if none is found - """ - err_json = StrBuffer() - if not get_library().askar_get_current_error(byref(err_json)): - try: - msg = json.loads(err_json.value) - except json.JSONDecodeError: - LOGGER.warning("JSON decode error for askar_get_current_error") - msg = None - if msg and "message" in msg and "code" in msg: - return AskarError( - AskarErrorCode(msg["code"]), msg["message"], msg.get("extra") - ) - if not expect: - return None - return AskarError(AskarErrorCode.WRAPPER, "Unknown error") - - -def generate_raw_key(seed: Union[str, bytes] = None) -> str: - """Generate a new raw store wrapping key.""" - key = StrBuffer() - do_call("askar_store_generate_raw_key", encode_bytes(seed), byref(key)) - return str(key) - - -def version() -> str: - """Get the version of the installed aries-askar library.""" - lib = get_library() - lib.askar_version.restype = c_void_p - return str(StrBuffer(lib.askar_version())) - - -async def store_open( - uri: str, key_method: str = None, pass_key: str = None, profile: str = None -) -> StoreHandle: - """Open an existing Store and return the open handle.""" - return await do_call_async( - "askar_store_open", - encode_str(uri), - encode_str(key_method and key_method.lower()), - encode_str(pass_key), - encode_str(profile), - return_type=StoreHandle, - ) - - -async def store_provision( - uri: str, - key_method: str = None, - pass_key: str = None, - profile: str = None, - recreate: bool = False, -) -> StoreHandle: - """Provision a new Store and return the open handle.""" - return await do_call_async( - "askar_store_provision", - encode_str(uri), - encode_str(key_method and key_method.lower()), - encode_str(pass_key), - encode_str(profile), - c_int8(recreate), - return_type=StoreHandle, - ) - - -async def store_create_profile(handle: StoreHandle, name: str = None) -> str: - """Create a new profile in a Store.""" - return str( - await do_call_async( - "askar_store_create_profile", - handle, - encode_str(name), - return_type=StrBuffer, - ) - ) - - -async def store_get_profile_name(handle: StoreHandle) -> str: - """Get the name of the default Store instance profile.""" - return str( - await do_call_async( - "askar_store_get_profile_name", - handle, - return_type=StrBuffer, - ) - ) - - -async def store_remove_profile(handle: StoreHandle, name: str) -> bool: - """Remove an existing profile from a Store.""" - return ( - await do_call_async( - "askar_store_remove_profile", - handle, - encode_str(name), - return_type=c_int8, - ) - != 0 - ) - - -async def store_rekey( - handle: StoreHandle, - key_method: str = None, - pass_key: str = None, -) -> StoreHandle: - """Replace the store key on a Store.""" - return await do_call_async( - "askar_store_rekey", - handle, - encode_str(key_method and key_method.lower()), - encode_str(pass_key), - ) - - -async def store_remove(uri: str) -> bool: - """Remove an existing Store, if any.""" - return ( - await do_call_async( - "askar_store_remove", - encode_str(uri), - return_type=c_int8, - ) - != 0 - ) - - -async def session_start( - handle: StoreHandle, profile: Optional[str] = None, as_transaction: bool = False -) -> SessionHandle: - """Start a new session with an open Store.""" - return await do_call_async( - "askar_session_start", - handle, - encode_str(profile), - c_int8(as_transaction), - return_type=SessionHandle, - ) - - -async def session_count( - handle: SessionHandle, category: str, tag_filter: Union[str, dict] = None -) -> int: - """Count rows in the Store.""" - category = 
encode_str(category) - if isinstance(tag_filter, dict): - tag_filter = json.dumps(tag_filter) - tag_filter = encode_str(tag_filter) - return int( - await do_call_async( - "askar_session_count", handle, category, tag_filter, return_type=c_int64 - ) - ) - - -async def session_fetch( - handle: SessionHandle, category: str, name: str, for_update: bool = False -) -> EntryListHandle: - """Fetch a row from the Store.""" - category = encode_str(category) - name = encode_str(name) - return await do_call_async( - "askar_session_fetch", - handle, - category, - name, - c_int8(for_update), - return_type=EntryListHandle, - ) - - -async def session_fetch_all( - handle: SessionHandle, - category: str, - tag_filter: Union[str, dict] = None, - limit: int = None, - for_update: bool = False, -) -> EntryListHandle: - """Fetch all matching rows in the Store.""" - if isinstance(tag_filter, dict): - tag_filter = json.dumps(tag_filter) - return await do_call_async( - "askar_session_fetch_all", - handle, - encode_str(category), - encode_str(tag_filter), - c_int64(limit if limit is not None else -1), - c_int8(for_update), - return_type=EntryListHandle, - ) - - -async def session_remove_all( - handle: SessionHandle, - category: str, - tag_filter: Union[str, dict] = None, -) -> int: - """Remove all matching rows in the Store.""" - if isinstance(tag_filter, dict): - tag_filter = json.dumps(tag_filter) - return int( - await do_call_async( - "askar_session_remove_all", - handle, - encode_str(category), - encode_str(tag_filter), - return_type=c_int64, - ) - ) - - -async def session_update( - handle: SessionHandle, - operation: EntryOperation, - category: str, - name: str, - value: Union[str, bytes] = None, - tags: dict = None, - expiry_ms: Optional[int] = None, -): - """Update a Store by inserting, updating, or removing a record.""" - - return await do_call_async( - "askar_session_update", - handle, - c_int8(operation.value), - encode_str(category), - encode_str(name), - encode_bytes(value), - encode_tags(tags), - c_int64(-1 if expiry_ms is None else expiry_ms), - ) - - -async def session_insert_key( - handle: SessionHandle, - key_handle: LocalKeyHandle, - name: str, - metadata: str = None, - tags: dict = None, - expiry_ms: Optional[int] = None, -): - await do_call_async( - "askar_session_insert_key", - handle, - key_handle, - encode_str(name), - encode_str(metadata), - encode_tags(tags), - c_int64(-1 if expiry_ms is None else expiry_ms), - return_type=c_void_p, - ) - - -async def session_fetch_key( - handle: SessionHandle, name: str, for_update: bool = False -) -> Optional[KeyEntryListHandle]: - ptr = await do_call_async( - "askar_session_fetch_key", - handle, - encode_str(name), - c_int8(for_update), - return_type=c_void_p, - ) - if ptr: - return KeyEntryListHandle(ptr) - - -async def session_fetch_all_keys( - handle: SessionHandle, - alg: Union[str, KeyAlg] = None, - thumbprint: str = None, - tag_filter: Union[str, dict] = None, - limit: int = None, - for_update: bool = False, -) -> EntryListHandle: - """Fetch all matching keys in the Store.""" - if isinstance(alg, KeyAlg): - alg = alg.value - if isinstance(tag_filter, dict): - tag_filter = json.dumps(tag_filter) - return await do_call_async( - "askar_session_fetch_all_keys", - handle, - encode_str(alg), - encode_str(thumbprint), - encode_str(tag_filter), - c_int64(limit if limit is not None else -1), - c_int8(for_update), - return_type=KeyEntryListHandle, - ) - - -async def session_update_key( - handle: SessionHandle, - name: str, - metadata: str = None, - tags: dict 
= None, - expiry_ms: Optional[int] = None, -): - await do_call_async( - "askar_session_update_key", - handle, - encode_str(name), - encode_str(metadata), - encode_tags(tags), - c_int64(-1 if expiry_ms is None else expiry_ms), - ) - - -async def session_remove_key(handle: SessionHandle, name: str): - await do_call_async( - "askar_session_remove_key", - handle, - encode_str(name), - ) - - -async def scan_start( - handle: StoreHandle, - profile: Optional[str], - category: str, - tag_filter: Union[str, dict] = None, - offset: int = None, - limit: int = None, -) -> ScanHandle: - """Create a new Scan against the Store.""" - if isinstance(tag_filter, dict): - tag_filter = json.dumps(tag_filter) - tag_filter = encode_str(tag_filter) - return await do_call_async( - "askar_scan_start", - handle, - encode_str(profile), - encode_str(category), - tag_filter, - c_int64(offset or 0), - c_int64(limit if limit is not None else -1), - return_type=ScanHandle, - ) - - -async def scan_next(handle: StoreHandle) -> Optional[EntryListHandle]: - handle = await do_call_async("askar_scan_next", handle, return_type=EntryListHandle) - return handle or None - - -def entry_list_count(handle: EntryListHandle) -> int: - len = c_int32() - do_call("askar_entry_list_count", handle, byref(len)) - return len.value - - -def key_entry_list_count(handle: EntryListHandle) -> int: - len = c_int32() - do_call("askar_key_entry_list_count", handle, byref(len)) - return len.value - - -def key_generate(alg: Union[str, KeyAlg], ephemeral: bool = False) -> LocalKeyHandle: - handle = LocalKeyHandle() - if isinstance(alg, KeyAlg): - alg = alg.value - do_call("askar_key_generate", encode_str(alg), c_int8(ephemeral), byref(handle)) - return handle - - -def key_from_seed( - alg: Union[str, KeyAlg], - seed: Union[str, bytes, ByteBuffer], - method: Union[str, SeedMethod] = None, -) -> LocalKeyHandle: - handle = LocalKeyHandle() - if isinstance(alg, KeyAlg): - alg = alg.value - if isinstance(method, SeedMethod): - method = method.value - do_call( - "askar_key_from_seed", - encode_str(alg), - encode_bytes(seed), - encode_str(method), - byref(handle), - ) - return handle - - -def key_from_public_bytes( - alg: Union[str, KeyAlg], public: Union[bytes, ByteBuffer] -) -> LocalKeyHandle: - handle = LocalKeyHandle() - if isinstance(alg, KeyAlg): - alg = alg.value - do_call( - "askar_key_from_public_bytes", - encode_str(alg), - encode_bytes(public), - byref(handle), - ) - return handle - - -def key_get_public_bytes(handle: LocalKeyHandle) -> ByteBuffer: - buf = ByteBuffer() - do_call( - "askar_key_get_public_bytes", - handle, - byref(buf), - ) - return buf - - -def key_from_secret_bytes( - alg: Union[str, KeyAlg], secret: Union[bytes, ByteBuffer] -) -> LocalKeyHandle: - handle = LocalKeyHandle() - if isinstance(alg, KeyAlg): - alg = alg.value - do_call( - "askar_key_from_secret_bytes", - encode_str(alg), - encode_bytes(secret), - byref(handle), - ) - return handle - - -def key_get_secret_bytes(handle: LocalKeyHandle) -> ByteBuffer: - buf = ByteBuffer() - do_call( - "askar_key_get_secret_bytes", - handle, - byref(buf), - ) - return buf - - -def key_from_jwk(jwk: Union[dict, str, bytes]) -> LocalKeyHandle: - handle = LocalKeyHandle() - if isinstance(jwk, dict): - jwk = json.dumps(jwk) - do_call("askar_key_from_jwk", encode_bytes(jwk), byref(handle)) - return handle - - -def key_convert(handle: LocalKeyHandle, alg: Union[str, KeyAlg]) -> LocalKeyHandle: - key = LocalKeyHandle() - if isinstance(alg, KeyAlg): - alg = alg.value - do_call("askar_key_convert", 
handle, encode_str(alg), byref(key)) - return key - - -def key_exchange( - alg: Union[str, KeyAlg], sk_handle: LocalKeyHandle, pk_handle: LocalKeyHandle -) -> LocalKeyHandle: - key = LocalKeyHandle() - if isinstance(alg, KeyAlg): - alg = alg.value - do_call( - "askar_key_from_key_exchange", encode_str(alg), sk_handle, pk_handle, byref(key) - ) - return key - - -def key_get_algorithm(handle: LocalKeyHandle) -> str: - alg = StrBuffer() - do_call("askar_key_get_algorithm", handle, byref(alg)) - return str(alg) - - -def key_get_ephemeral(handle: LocalKeyHandle) -> bool: - eph = c_int8() - do_call("askar_key_get_ephemeral", handle, byref(eph)) - return eph.value != 0 - - -def key_get_jwk_public(handle: LocalKeyHandle, alg: Union[str, KeyAlg] = None) -> str: - jwk = StrBuffer() - if isinstance(alg, KeyAlg): - alg = alg.value - do_call("askar_key_get_jwk_public", handle, encode_str(alg), byref(jwk)) - return str(jwk) - - -def key_get_jwk_secret(handle: LocalKeyHandle) -> ByteBuffer: - sec = ByteBuffer() - do_call("askar_key_get_jwk_secret", handle, byref(sec)) - return sec - - -def key_get_jwk_thumbprint( - handle: LocalKeyHandle, alg: Union[str, KeyAlg] = None -) -> str: - thumb = StrBuffer() - if isinstance(alg, KeyAlg): - alg = alg.value - do_call("askar_key_get_jwk_thumbprint", handle, encode_str(alg), byref(thumb)) - return str(thumb) - - -def key_aead_get_params(handle: LocalKeyHandle) -> AeadParams: - params = AeadParams() - do_call("askar_key_aead_get_params", handle, byref(params)) - return params - - -def key_aead_random_nonce(handle: LocalKeyHandle) -> ByteBuffer: - nonce = ByteBuffer() - do_call("askar_key_aead_random_nonce", handle, byref(nonce)) - return nonce - - -def key_aead_encrypt( - handle: LocalKeyHandle, - input: Union[bytes, str, ByteBuffer], - nonce: Union[bytes, ByteBuffer], - aad: Optional[Union[bytes, ByteBuffer]], -) -> Encrypted: - enc = Encrypted() - do_call( - "askar_key_aead_encrypt", - handle, - encode_bytes(input), - encode_bytes(nonce), - encode_bytes(aad), - byref(enc), - ) - return enc - - -def key_aead_decrypt( - handle: LocalKeyHandle, - ciphertext: Union[bytes, ByteBuffer, Encrypted], - nonce: Union[bytes, ByteBuffer], - tag: Optional[Union[bytes, ByteBuffer]], - aad: Optional[Union[bytes, ByteBuffer]], -) -> ByteBuffer: - dec = ByteBuffer() - if isinstance(ciphertext, Encrypted): - ciphertext = ciphertext.ciphertext_tag - do_call( - "askar_key_aead_decrypt", - handle, - encode_bytes(ciphertext), - encode_bytes(nonce), - encode_bytes(tag), - encode_bytes(aad), - byref(dec), - ) - return dec - - -def key_sign_message( - handle: LocalKeyHandle, - message: Union[bytes, str, ByteBuffer], - sig_type: Optional[str], -) -> ByteBuffer: - sig = ByteBuffer() - do_call( - "askar_key_sign_message", - handle, - encode_bytes(message), - encode_str(sig_type), - byref(sig), - ) - return sig - - -def key_verify_signature( - handle: LocalKeyHandle, - message: Union[bytes, str, ByteBuffer], - signature: Union[bytes, ByteBuffer], - sig_type: Optional[str], -) -> bool: - verify = c_int8() - do_call( - "askar_key_verify_signature", - handle, - encode_bytes(message), - encode_bytes(signature), - encode_str(sig_type), - byref(verify), - ) - return verify.value != 0 - - -def key_wrap_key( - handle: LocalKeyHandle, - other: LocalKeyHandle, - nonce: Optional[Union[bytes, ByteBuffer]], -) -> Encrypted: - wrapped = Encrypted() - do_call( - "askar_key_wrap_key", - handle, - other, - encode_bytes(nonce), - byref(wrapped), - ) - return wrapped - - -def key_unwrap_key( - handle: 
LocalKeyHandle, - alg: Union[str, KeyAlg], - ciphertext: Union[bytes, ByteBuffer, Encrypted], - nonce: Union[bytes, ByteBuffer], - tag: Optional[Union[bytes, ByteBuffer]], -) -> LocalKeyHandle: - result = LocalKeyHandle() - if isinstance(alg, KeyAlg): - alg = alg.value - if isinstance(ciphertext, Encrypted): - ciphertext = ciphertext.ciphertext_tag - do_call( - "askar_key_unwrap_key", - handle, - encode_str(alg), - encode_bytes(ciphertext), - encode_bytes(nonce), - encode_bytes(tag), - byref(result), - ) - return result - - -def key_crypto_box_random_nonce() -> ByteBuffer: - buf = ByteBuffer() - do_call( - "askar_key_crypto_box_random_nonce", - byref(buf), - ) - return buf - - -def key_crypto_box( - recip_handle: LocalKeyHandle, - sender_handle: LocalKeyHandle, - message: Union[bytes, str, ByteBuffer], - nonce: Union[bytes, ByteBuffer], -) -> ByteBuffer: - buf = ByteBuffer() - do_call( - "askar_key_crypto_box", - recip_handle, - sender_handle, - encode_bytes(message), - encode_bytes(nonce), - byref(buf), - ) - return buf - - -def key_crypto_box_open( - recip_handle: LocalKeyHandle, - sender_handle: LocalKeyHandle, - message: Union[bytes, str, ByteBuffer], - nonce: Union[bytes, ByteBuffer], -) -> ByteBuffer: - buf = ByteBuffer() - do_call( - "askar_key_crypto_box_open", - recip_handle, - sender_handle, - encode_bytes(message), - encode_bytes(nonce), - byref(buf), - ) - return buf - - -def key_crypto_box_seal( - handle: LocalKeyHandle, - message: Union[bytes, str, ByteBuffer], -) -> ByteBuffer: - buf = ByteBuffer() - do_call( - "askar_key_crypto_box_seal", - handle, - encode_bytes(message), - byref(buf), - ) - return buf - - -def key_crypto_box_seal_open( - handle: LocalKeyHandle, - ciphertext: Union[bytes, ByteBuffer], -) -> ByteBuffer: - buf = ByteBuffer() - do_call( - "askar_key_crypto_box_seal_open", - handle, - encode_bytes(ciphertext), - byref(buf), - ) - return buf - - -def key_derive_ecdh_es( - key_alg: Union[str, KeyAlg], - ephem_key: LocalKeyHandle, - receiver_key: LocalKeyHandle, - alg_id: Union[bytes, str, ByteBuffer], - apu: Union[bytes, str, ByteBuffer], - apv: Union[bytes, str, ByteBuffer], - receive: bool, -) -> LocalKeyHandle: - key = LocalKeyHandle() - if isinstance(key_alg, KeyAlg): - key_alg = key_alg.value - do_call( - "askar_key_derive_ecdh_es", - encode_str(key_alg), - ephem_key, - receiver_key, - encode_bytes(alg_id), - encode_bytes(apu), - encode_bytes(apv), - c_int8(receive), - byref(key), - ) - return key - - -def key_derive_ecdh_1pu( - key_alg: Union[str, KeyAlg], - ephem_key: LocalKeyHandle, - sender_key: LocalKeyHandle, - receiver_key: LocalKeyHandle, - alg_id: Union[bytes, str, ByteBuffer], - apu: Union[bytes, str, ByteBuffer], - apv: Union[bytes, str, ByteBuffer], - cc_tag: Optional[Union[bytes, ByteBuffer]], - receive: bool, -) -> LocalKeyHandle: - key = LocalKeyHandle() - if isinstance(key_alg, KeyAlg): - key_alg = key_alg.value - do_call( - "askar_key_derive_ecdh_1pu", - encode_str(key_alg), - ephem_key, - sender_key, - receiver_key, - encode_bytes(alg_id), - encode_bytes(apu), - encode_bytes(apv), - encode_bytes(cc_tag), - c_int8(receive), - byref(key), - ) - return key diff --git a/wrappers/python/aries_askar/bindings/__init__.py b/wrappers/python/aries_askar/bindings/__init__.py new file mode 100644 index 00000000..4b6f6684 --- /dev/null +++ b/wrappers/python/aries_askar/bindings/__init__.py @@ -0,0 +1,948 @@ +"""Low-level interaction with the aries-askar library.""" + +import asyncio +import json +import logging +import sys + +from ctypes import 
POINTER, byref, c_char_p, c_int8, c_int32, c_int64 +from typing import Optional, Union + +from ..types import EntryOperation, KeyAlg, SeedMethod + +from .lib import ( + AeadParams, + ByteBuffer, + Encrypted, + FfiByteBuffer, + FfiJson, + FfiStr, + FfiTagsJson, + Lib, + StrBuffer, +) +from .handle import ( + EntryListHandle, + KeyEntryListHandle, + LocalKeyHandle, + ScanHandle, + SessionHandle, + StoreHandle, +) + + +LIB = Lib() +LOGGER = logging.getLogger(__name__) +MODULE_NAME = __name__.split(".")[0] + + +def get_library(init: bool = True) -> Lib: + """Return the library instance, loading it if necessary.""" + global LIB + if LIB and init: + # preload library - required to create handle instances + LIB.loaded + return LIB + + +def set_max_log_level(level: Union[str, int, None]): + """Set the maximum logging level.""" + get_library().set_max_log_level(level) + + +def invoke(name, argtypes, *args): + """Perform a synchronous library function call.""" + get_library().invoke(name, argtypes, *args) + + +def invoke_async(name: str, argtypes, *args, return_type=None) -> asyncio.Future: + """Perform an asynchronous library function call.""" + return get_library().invoke_async(name, argtypes, *args, return_type=return_type) + + +def generate_raw_key(seed: Union[str, bytes] = None) -> str: + """Generate a new raw store wrapping key.""" + key = StrBuffer() + invoke( + "askar_store_generate_raw_key", + (FfiByteBuffer, POINTER(StrBuffer)), + seed, + byref(key), + ) + return str(key) + + +def version() -> str: + """Get the version of the installed library.""" + return get_library().version() + + +async def store_open( + uri: str, key_method: str = None, pass_key: str = None, profile: str = None +) -> StoreHandle: + """Open an existing Store and return the open handle.""" + return await invoke_async( + "askar_store_open", + (FfiStr, FfiStr, FfiStr, FfiStr), + uri, + key_method and key_method.lower(), + pass_key, + profile, + return_type=StoreHandle, + ) + + +async def store_provision( + uri: str, + key_method: str = None, + pass_key: str = None, + profile: str = None, + recreate: bool = False, +) -> StoreHandle: + """Provision a new Store and return the open handle.""" + return await invoke_async( + "askar_store_provision", + (FfiStr, FfiStr, FfiStr, FfiStr, c_int8), + uri, + key_method and key_method.lower(), + pass_key, + profile, + recreate, + return_type=StoreHandle, + ) + + +async def store_create_profile(handle: StoreHandle, name: str = None) -> str: + """Create a new profile in a Store.""" + return str( + await invoke_async( + "askar_store_create_profile", + (StoreHandle, FfiStr), + handle, + name, + return_type=StrBuffer, + ) + ) + + +async def store_get_profile_name(handle: StoreHandle) -> str: + """Get the name of the default Store instance profile.""" + return str( + await invoke_async( + "askar_store_get_profile_name", + (StoreHandle,), + handle, + return_type=StrBuffer, + ) + ) + + +async def store_remove_profile(handle: StoreHandle, name: str) -> bool: + """Remove an existing profile from a Store.""" + return ( + await invoke_async( + "askar_store_remove_profile", + (StoreHandle, FfiStr), + handle, + name, + return_type=c_int8, + ) + != 0 + ) + + +async def store_rekey( + handle: StoreHandle, + key_method: str = None, + pass_key: str = None, +) -> StoreHandle: + """Replace the store key on a Store.""" + return await invoke_async( + "askar_store_rekey", + (StoreHandle, FfiStr, FfiStr), + handle, + key_method and key_method.lower(), + pass_key, + return_type=c_int8, + ) + + +async def 
store_remove(uri: str) -> bool: + """Remove an existing Store, if any.""" + return ( + await invoke_async( + "askar_store_remove", + (FfiStr,), + uri, + return_type=c_int8, + ) + != 0 + ) + + +async def session_start( + handle: StoreHandle, profile: Optional[str] = None, as_transaction: bool = False +) -> SessionHandle: + """Start a new session with an open Store.""" + handle = await invoke_async( + "askar_session_start", + (StoreHandle, FfiStr, c_int8), + handle, + profile, + as_transaction, + return_type=SessionHandle, + ) + return handle + + +async def session_count( + handle: SessionHandle, category: str, tag_filter: Union[str, dict] = None +) -> int: + """Count rows in the Store.""" + return int( + await invoke_async( + "askar_session_count", + (SessionHandle, FfiStr, FfiJson), + handle, + category, + tag_filter, + return_type=c_int64, + ) + ) + + +async def session_fetch( + handle: SessionHandle, category: str, name: str, for_update: bool = False +) -> EntryListHandle: + """Fetch a row from the Store.""" + return await invoke_async( + "askar_session_fetch", + (SessionHandle, FfiStr, FfiStr, c_int8), + handle, + category, + name, + for_update, + return_type=EntryListHandle, + ) + + +async def session_fetch_all( + handle: SessionHandle, + category: str, + tag_filter: Union[str, dict] = None, + limit: int = None, + for_update: bool = False, +) -> EntryListHandle: + """Fetch all matching rows in the Store.""" + return await invoke_async( + "askar_session_fetch_all", + (SessionHandle, FfiStr, FfiJson, c_int64, c_int8), + handle, + category, + tag_filter, + limit if limit is not None else -1, + for_update, + return_type=EntryListHandle, + ) + + +async def session_remove_all( + handle: SessionHandle, + category: str, + tag_filter: Union[str, dict] = None, +) -> int: + """Remove all matching rows in the Store.""" + return int( + await invoke_async( + "askar_session_remove_all", + (SessionHandle, FfiStr, FfiJson), + handle, + category, + tag_filter, + return_type=c_int64, + ) + ) + + +async def session_update( + handle: SessionHandle, + operation: EntryOperation, + category: str, + name: str, + value: Union[str, bytes] = None, + tags: dict = None, + expiry_ms: Optional[int] = None, +): + """Update a Store by inserting, updating, or removing a record.""" + return await invoke_async( + "askar_session_update", + (SessionHandle, c_int8, FfiStr, FfiStr, FfiByteBuffer, FfiTagsJson, c_int64), + handle, + operation.value, + category, + name, + value, + tags, + -1 if expiry_ms is None else expiry_ms, + ) + + +async def session_insert_key( + handle: SessionHandle, + key_handle: LocalKeyHandle, + name: str, + metadata: str = None, + tags: dict = None, + expiry_ms: Optional[int] = None, +): + return await invoke_async( + "askar_session_insert_key", + (SessionHandle, LocalKeyHandle, FfiStr, FfiStr, FfiTagsJson, c_int64), + handle, + key_handle, + name, + metadata, + tags, + -1 if expiry_ms is None else expiry_ms, + ) + + +async def session_fetch_key( + handle: SessionHandle, name: str, for_update: bool = False +) -> KeyEntryListHandle: + return await invoke_async( + "askar_session_fetch_key", + (SessionHandle, FfiStr, c_int8), + handle, + name, + for_update, + return_type=KeyEntryListHandle, + ) + + +async def session_fetch_all_keys( + handle: SessionHandle, + alg: Union[str, KeyAlg] = None, + thumbprint: str = None, + tag_filter: Union[str, dict] = None, + limit: int = None, + for_update: bool = False, +) -> KeyEntryListHandle: + """Fetch all matching keys in the Store.""" + if isinstance(alg, KeyAlg): + 
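        # accept either a KeyAlg enum member or its raw string identifier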
alg = alg.value + return await invoke_async( + "askar_session_fetch_all_keys", + (SessionHandle, FfiStr, FfiStr, FfiJson, c_int64, c_int8), + handle, + alg, + thumbprint, + tag_filter, + limit if limit is not None else -1, + for_update, + return_type=KeyEntryListHandle, + ) + + +async def session_update_key( + handle: SessionHandle, + name: str, + metadata: str = None, + tags: dict = None, + expiry_ms: Optional[int] = None, +): + await invoke_async( + "askar_session_update_key", + (SessionHandle, FfiStr, FfiStr, FfiTagsJson, c_int64), + handle, + name, + metadata, + tags, + -1 if expiry_ms is None else expiry_ms, + ) + + +async def session_remove_key(handle: SessionHandle, name: str): + await invoke_async( + "askar_session_remove_key", + (SessionHandle, FfiStr), + handle, + name, + ) + + +async def scan_start( + handle: StoreHandle, + profile: Optional[str], + category: str, + tag_filter: Union[str, dict] = None, + offset: int = None, + limit: int = None, +) -> ScanHandle: + """Create a new Scan against the Store.""" + return await invoke_async( + "askar_scan_start", + (StoreHandle, FfiStr, FfiStr, FfiJson, c_int64, c_int64), + handle, + profile, + category, + tag_filter, + offset or 0, + limit if limit is not None else -1, + return_type=ScanHandle, + ) + + +async def scan_next(handle: ScanHandle) -> EntryListHandle: + return await invoke_async( + "askar_scan_next", (ScanHandle,), handle, return_type=EntryListHandle + ) + + +def entry_list_count(handle: EntryListHandle) -> int: + len = c_int32() + invoke( + "askar_entry_list_count", + (EntryListHandle, POINTER(c_int32)), + handle, + byref(len), + ) + return len.value + + +def key_entry_list_count(handle: KeyEntryListHandle) -> int: + len = c_int32() + invoke( + "askar_key_entry_list_count", + (KeyEntryListHandle, POINTER(c_int32)), + handle, + byref(len), + ) + return len.value + + +def key_generate(alg: Union[str, KeyAlg], ephemeral: bool = False) -> LocalKeyHandle: + handle = LocalKeyHandle() + if isinstance(alg, KeyAlg): + alg = alg.value + invoke( + "askar_key_generate", + (FfiStr, c_int8, POINTER(LocalKeyHandle)), + alg, + ephemeral, + byref(handle), + ) + return handle + + +def key_from_seed( + alg: Union[str, KeyAlg], + seed: Union[str, bytes, ByteBuffer], + method: Union[str, SeedMethod] = None, +) -> LocalKeyHandle: + handle = LocalKeyHandle() + if isinstance(alg, KeyAlg): + alg = alg.value + if isinstance(method, SeedMethod): + method = method.value + invoke( + "askar_key_from_seed", + (FfiStr, FfiByteBuffer, FfiStr, POINTER(LocalKeyHandle)), + alg, + seed, + method, + byref(handle), + ) + return handle + + +def key_from_public_bytes( + alg: Union[str, KeyAlg], public: Union[bytes, ByteBuffer] +) -> LocalKeyHandle: + handle = LocalKeyHandle() + if isinstance(alg, KeyAlg): + alg = alg.value + invoke( + "askar_key_from_public_bytes", + (FfiStr, FfiByteBuffer, POINTER(LocalKeyHandle)), + alg, + public, + byref(handle), + ) + return handle + + +def key_get_public_bytes(handle: LocalKeyHandle) -> ByteBuffer: + buf = ByteBuffer() + invoke( + "askar_key_get_public_bytes", + (LocalKeyHandle, POINTER(ByteBuffer)), + handle, + byref(buf), + ) + return buf + + +def key_from_secret_bytes( + alg: Union[str, KeyAlg], secret: Union[bytes, ByteBuffer] +) -> LocalKeyHandle: + handle = LocalKeyHandle() + if isinstance(alg, KeyAlg): + alg = alg.value + invoke( + "askar_key_from_secret_bytes", + (FfiStr, FfiByteBuffer, POINTER(LocalKeyHandle)), + alg, + secret, + byref(handle), + ) + return handle + + +def key_get_secret_bytes(handle: 
LocalKeyHandle) -> ByteBuffer: + buf = ByteBuffer() + invoke( + "askar_key_get_secret_bytes", + (LocalKeyHandle, POINTER(ByteBuffer)), + handle, + byref(buf), + ) + return buf + + +def key_from_jwk(jwk: Union[dict, str, bytes]) -> LocalKeyHandle: + handle = LocalKeyHandle() + if isinstance(jwk, dict): + jwk = json.dumps(jwk) + invoke( + "askar_key_from_jwk", + (FfiByteBuffer, POINTER(LocalKeyHandle)), + jwk, + byref(handle), + ) + return handle + + +def key_convert(handle: LocalKeyHandle, alg: Union[str, KeyAlg]) -> LocalKeyHandle: + key = LocalKeyHandle() + if isinstance(alg, KeyAlg): + alg = alg.value + invoke( + "askar_key_convert", + (LocalKeyHandle, FfiStr, POINTER(LocalKeyHandle)), + handle, + alg, + byref(key), + ) + return key + + +def key_exchange( + alg: Union[str, KeyAlg], sk_handle: LocalKeyHandle, pk_handle: LocalKeyHandle +) -> LocalKeyHandle: + key = LocalKeyHandle() + if isinstance(alg, KeyAlg): + alg = alg.value + invoke( + "askar_key_from_key_exchange", + (FfiStr, LocalKeyHandle, LocalKeyHandle, POINTER(LocalKeyHandle)), + alg, + sk_handle, + pk_handle, + byref(key), + ) + return key + + +def key_get_algorithm(handle: LocalKeyHandle) -> str: + alg = StrBuffer() + invoke( + "askar_key_get_algorithm", + (LocalKeyHandle, POINTER(StrBuffer)), + handle, + byref(alg), + ) + return str(alg) + + +def key_get_ephemeral(handle: LocalKeyHandle) -> bool: + eph = c_int8() + invoke( + "askar_key_get_ephemeral", + (LocalKeyHandle, POINTER(c_int8)), + handle, + byref(eph), + ) + return eph.value != 0 + + +def key_get_jwk_public(handle: LocalKeyHandle, alg: Union[str, KeyAlg] = None) -> str: + jwk = StrBuffer() + if isinstance(alg, KeyAlg): + alg = alg.value + invoke( + "askar_key_get_jwk_public", + (LocalKeyHandle, FfiStr, POINTER(StrBuffer)), + handle, + alg, + byref(jwk), + ) + return str(jwk) + + +def key_get_jwk_secret(handle: LocalKeyHandle) -> ByteBuffer: + sec = ByteBuffer() + invoke( + "askar_key_get_jwk_secret", + (LocalKeyHandle, POINTER(ByteBuffer)), + handle, + byref(sec), + ) + return sec + + +def key_get_jwk_thumbprint( + handle: LocalKeyHandle, alg: Union[str, KeyAlg] = None +) -> str: + thumb = StrBuffer() + if isinstance(alg, KeyAlg): + alg = alg.value + invoke( + "askar_key_get_jwk_thumbprint", + (LocalKeyHandle, FfiStr, POINTER(StrBuffer)), + handle, + alg, + byref(thumb), + ) + return str(thumb) + + +def key_aead_get_params(handle: LocalKeyHandle) -> AeadParams: + params = AeadParams() + invoke( + "askar_key_aead_get_params", + (LocalKeyHandle, POINTER(AeadParams)), + handle, + byref(params), + ) + return params + + +def key_aead_random_nonce(handle: LocalKeyHandle) -> ByteBuffer: + nonce = ByteBuffer() + invoke( + "askar_key_aead_random_nonce", + (LocalKeyHandle, POINTER(ByteBuffer)), + handle, + byref(nonce), + ) + return nonce + + +def key_aead_encrypt( + handle: LocalKeyHandle, + input: Union[bytes, str, ByteBuffer], + nonce: Union[bytes, ByteBuffer], + aad: Optional[Union[bytes, ByteBuffer]], +) -> Encrypted: + enc = Encrypted() + invoke( + "askar_key_aead_encrypt", + ( + LocalKeyHandle, + FfiByteBuffer, + FfiByteBuffer, + FfiByteBuffer, + POINTER(Encrypted), + ), + handle, + input, + nonce, + aad, + byref(enc), + ) + return enc + + +def key_aead_decrypt( + handle: LocalKeyHandle, + ciphertext: Union[bytes, ByteBuffer, Encrypted], + nonce: Union[bytes, ByteBuffer], + tag: Optional[Union[bytes, ByteBuffer]], + aad: Optional[Union[bytes, ByteBuffer]], +) -> ByteBuffer: + dec = ByteBuffer() + if isinstance(ciphertext, Encrypted): + nonce = ciphertext.nonce + 
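        # an Encrypted value packs ciphertext || tag || nonce into one buffer,
        # so the nonce is recovered here and the combined ciphertext+tag is
        # passed on for decryption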
ciphertext = ciphertext.ciphertext_tag + invoke( + "askar_key_aead_decrypt", + ( + LocalKeyHandle, + FfiByteBuffer, + FfiByteBuffer, + FfiByteBuffer, + FfiByteBuffer, + POINTER(ByteBuffer), + ), + handle, + ciphertext, + nonce, + tag, + aad, + byref(dec), + ) + return dec + + +def key_sign_message( + handle: LocalKeyHandle, + message: Union[bytes, str, ByteBuffer], + sig_type: Optional[str], +) -> ByteBuffer: + sig = ByteBuffer() + invoke( + "askar_key_sign_message", + (LocalKeyHandle, FfiByteBuffer, FfiStr, POINTER(ByteBuffer)), + handle, + message, + sig_type, + byref(sig), + ) + return sig + + +def key_verify_signature( + handle: LocalKeyHandle, + message: Union[bytes, str, ByteBuffer], + signature: Union[bytes, ByteBuffer], + sig_type: Optional[str], +) -> bool: + verify = c_int8() + invoke( + "askar_key_verify_signature", + (LocalKeyHandle, FfiByteBuffer, FfiByteBuffer, FfiStr, POINTER(c_int8)), + handle, + message, + signature, + sig_type, + byref(verify), + ) + return verify.value != 0 + + +def key_wrap_key( + handle: LocalKeyHandle, + other: LocalKeyHandle, + nonce: Optional[Union[bytes, ByteBuffer]], +) -> Encrypted: + wrapped = Encrypted() + invoke( + "askar_key_wrap_key", + (LocalKeyHandle, LocalKeyHandle, FfiByteBuffer, POINTER(Encrypted)), + handle, + other, + nonce, + byref(wrapped), + ) + return wrapped + + +def key_unwrap_key( + handle: LocalKeyHandle, + alg: Union[str, KeyAlg], + ciphertext: Union[bytes, ByteBuffer, Encrypted], + nonce: Union[bytes, ByteBuffer], + tag: Optional[Union[bytes, ByteBuffer]], +) -> LocalKeyHandle: + result = LocalKeyHandle() + if isinstance(alg, KeyAlg): + alg = alg.value + if isinstance(ciphertext, Encrypted): + ciphertext = ciphertext.ciphertext_tag + invoke( + "askar_key_unwrap_key", + ( + LocalKeyHandle, + FfiStr, + FfiByteBuffer, + FfiByteBuffer, + FfiByteBuffer, + POINTER(LocalKeyHandle), + ), + handle, + alg, + ciphertext, + nonce, + tag, + byref(result), + ) + return result + + +def key_crypto_box_random_nonce() -> ByteBuffer: + buf = ByteBuffer() + invoke( + "askar_key_crypto_box_random_nonce", + (POINTER(ByteBuffer),), + byref(buf), + ) + return buf + + +def key_crypto_box( + recip_handle: LocalKeyHandle, + sender_handle: LocalKeyHandle, + message: Union[bytes, str, ByteBuffer], + nonce: Union[bytes, ByteBuffer], +) -> ByteBuffer: + buf = ByteBuffer() + invoke( + "askar_key_crypto_box", + ( + LocalKeyHandle, + LocalKeyHandle, + FfiByteBuffer, + FfiByteBuffer, + POINTER(ByteBuffer), + ), + recip_handle, + sender_handle, + message, + nonce, + byref(buf), + ) + return buf + + +def key_crypto_box_open( + recip_handle: LocalKeyHandle, + sender_handle: LocalKeyHandle, + message: Union[bytes, str, ByteBuffer], + nonce: Union[bytes, ByteBuffer], +) -> ByteBuffer: + buf = ByteBuffer() + invoke( + "askar_key_crypto_box_open", + ( + LocalKeyHandle, + LocalKeyHandle, + FfiByteBuffer, + FfiByteBuffer, + POINTER(ByteBuffer), + ), + recip_handle, + sender_handle, + message, + nonce, + byref(buf), + ) + return buf + + +def key_crypto_box_seal( + handle: LocalKeyHandle, + message: Union[bytes, str, ByteBuffer], +) -> ByteBuffer: + buf = ByteBuffer() + invoke( + "askar_key_crypto_box_seal", + (LocalKeyHandle, FfiByteBuffer, POINTER(ByteBuffer)), + handle, + message, + byref(buf), + ) + return buf + + +def key_crypto_box_seal_open( + handle: LocalKeyHandle, + ciphertext: Union[bytes, ByteBuffer], +) -> ByteBuffer: + buf = ByteBuffer() + invoke( + "askar_key_crypto_box_seal_open", + (LocalKeyHandle, FfiByteBuffer, POINTER(ByteBuffer)), + handle, + 
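        # sealed boxes give anonymous public-key encryption, after libsodium's
        # crypto_box_seal: only the recipient key is needed to open them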
ciphertext, + byref(buf), + ) + return buf + + +def key_derive_ecdh_es( + key_alg: Union[str, KeyAlg], + ephem_key: LocalKeyHandle, + receiver_key: LocalKeyHandle, + alg_id: Union[bytes, str, ByteBuffer], + apu: Union[bytes, str, ByteBuffer], + apv: Union[bytes, str, ByteBuffer], + receive: bool, +) -> LocalKeyHandle: + key = LocalKeyHandle() + if isinstance(key_alg, KeyAlg): + key_alg = key_alg.value + invoke( + "askar_key_derive_ecdh_es", + ( + FfiStr, + LocalKeyHandle, + LocalKeyHandle, + FfiByteBuffer, + FfiByteBuffer, + FfiByteBuffer, + c_int8, + POINTER(LocalKeyHandle), + ), + key_alg, + ephem_key, + receiver_key, + alg_id, + apu, + apv, + receive, + byref(key), + ) + return key + + +def key_derive_ecdh_1pu( + key_alg: Union[str, KeyAlg], + ephem_key: LocalKeyHandle, + sender_key: LocalKeyHandle, + receiver_key: LocalKeyHandle, + alg_id: Union[bytes, str, ByteBuffer], + apu: Union[bytes, str, ByteBuffer], + apv: Union[bytes, str, ByteBuffer], + cc_tag: Optional[Union[bytes, ByteBuffer]], + receive: bool, +) -> LocalKeyHandle: + key = LocalKeyHandle() + if isinstance(key_alg, KeyAlg): + key_alg = key_alg.value + invoke( + "askar_key_derive_ecdh_1pu", + ( + FfiStr, + LocalKeyHandle, + LocalKeyHandle, + LocalKeyHandle, + FfiByteBuffer, + FfiByteBuffer, + FfiByteBuffer, + FfiByteBuffer, + c_int8, + POINTER(LocalKeyHandle), + ), + key_alg, + ephem_key, + sender_key, + receiver_key, + alg_id, + apu, + apv, + cc_tag, + receive, + byref(key), + ) + return key diff --git a/wrappers/python/aries_askar/bindings/handle.py b/wrappers/python/aries_askar/bindings/handle.py new file mode 100644 index 00000000..8f6b0c25 --- /dev/null +++ b/wrappers/python/aries_askar/bindings/handle.py @@ -0,0 +1,248 @@ +"""Handles for allocated resources.""" + +import json +import logging + +from ctypes import ( + POINTER, + Structure, + byref, + c_int8, + c_int32, + c_int64, + c_size_t, + c_void_p, +) + +from .lib import ByteBuffer, Lib, StrBuffer, finalize_struct + + +LOGGER = logging.getLogger(__name__) + + +class ArcHandle(Structure): + """Base class for handle instances.""" + + _fields_ = [ + ("value", c_size_t), + ] + _dtor_: str = None + + def __init__(self, value=0): + """Initializer.""" + if isinstance(value, c_size_t): + value = value.value + if not isinstance(value, int): + raise ValueError("Invalid handle") + super().__init__(value=value) + finalize_struct(self, c_size_t) + + @classmethod + def from_param(cls, param): + """Create from an input to a library method invocation.""" + if isinstance(param, cls): + return param + return cls(param) + + def __bool__(self) -> bool: + """Convert to a boolean value.""" + return bool(self.value) + + def __repr__(self) -> str: + """Format handle as a string.""" + return f"{self.__class__.__name__}({self.value})" + + @classmethod + def _cleanup(cls, value: c_size_t): + """Destructor.""" + if cls._dtor_: + Lib().invoke_dtor(cls._dtor_, value) + + +class StoreHandle(ArcHandle): + """Handle for an active Store instance.""" + + async def close(self): + """Manually close the store, waiting for any active connections.""" + if self.value: + await Lib().invoke_async("askar_store_close", (c_size_t,), self.value) + self.value = 0 + + @classmethod + def _cleanup(cls, value: c_size_t): + """Close the store when there are no more references to this object.""" + Lib().invoke_dtor( + "askar_store_close", + value, + None, + 0, + argtypes=(c_size_t, c_void_p, c_int64), + restype=c_int64, + ) + + +class SessionHandle(ArcHandle): + """Handle for an active Session/Transaction instance.""" 
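    # Cleanup mirrors StoreHandle: an explicit `await close(...)` zeroes the
    # handle, while the weakref finalizer below falls back to a native close
    # with commit disabled if the handle was never closed explicitly.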
diff --git a/wrappers/python/aries_askar/bindings/handle.py b/wrappers/python/aries_askar/bindings/handle.py
new file mode 100644
index 00000000..8f6b0c25
--- /dev/null
+++ b/wrappers/python/aries_askar/bindings/handle.py
@@ -0,0 +1,248 @@
+"""Handles for allocated resources."""
+
+import json
+import logging
+
+from ctypes import (
+    POINTER,
+    Structure,
+    byref,
+    c_int8,
+    c_int32,
+    c_int64,
+    c_size_t,
+    c_void_p,
+)
+
+from .lib import ByteBuffer, Lib, StrBuffer, finalize_struct
+
+
+LOGGER = logging.getLogger(__name__)
+
+
+class ArcHandle(Structure):
+    """Base class for handle instances."""
+
+    _fields_ = [
+        ("value", c_size_t),
+    ]
+    _dtor_: str = None
+
+    def __init__(self, value=0):
+        """Initializer."""
+        if isinstance(value, c_size_t):
+            value = value.value
+        if not isinstance(value, int):
+            raise ValueError("Invalid handle")
+        super().__init__(value=value)
+        finalize_struct(self, c_size_t)
+
+    @classmethod
+    def from_param(cls, param):
+        """Create from an input to a library method invocation."""
+        if isinstance(param, cls):
+            return param
+        return cls(param)
+
+    def __bool__(self) -> bool:
+        """Convert to a boolean value."""
+        return bool(self.value)
+
+    def __repr__(self) -> str:
+        """Format handle as a string."""
+        return f"{self.__class__.__name__}({self.value})"
+
+    @classmethod
+    def _cleanup(cls, value: c_size_t):
+        """Destructor."""
+        if cls._dtor_:
+            Lib().invoke_dtor(cls._dtor_, value)
+
+
+class StoreHandle(ArcHandle):
+    """Handle for an active Store instance."""
+
+    async def close(self):
+        """Manually close the store, waiting for any active connections."""
+        if self.value:
+            await Lib().invoke_async("askar_store_close", (c_size_t,), self.value)
+            self.value = 0
+
+    @classmethod
+    def _cleanup(cls, value: c_size_t):
+        """Close the store when there are no more references to this object."""
+        Lib().invoke_dtor(
+            "askar_store_close",
+            value,
+            None,
+            0,
+            argtypes=(c_size_t, c_void_p, c_int64),
+            restype=c_int64,
+        )
+
+
+class SessionHandle(ArcHandle):
+    """Handle for an active Session/Transaction instance."""
+
+    async def close(self, commit: bool = False):
+        """Manually close the session."""
+        if self.value:
+            await Lib().invoke_async(
+                "askar_session_close",
+                (c_size_t, c_int8),
+                self.value,
+                commit,
+            )
+            self.value = 0
+
+    @classmethod
+    def _cleanup(cls, value: c_size_t):
+        """Close the session when there are no more references to this object."""
+        Lib().invoke_dtor(
+            "askar_session_close",
+            value,
+            0,
+            None,
+            0,
+            argtypes=(c_size_t, c_int8, c_void_p, c_int64),
+            restype=c_int64,
+        )
+
+
+class ScanHandle(ArcHandle):
+    """Handle for an active Store scan instance."""
+
+    _dtor_ = "askar_scan_free"
+
+
+class EntryListHandle(ArcHandle):
+    """Handle for an active EntryList instance."""
+
+    _dtor_ = "askar_entry_list_free"
+
+    def get_category(self, index: int) -> str:
+        """Get the entry category."""
+        cat = StrBuffer()
+        Lib().invoke(
+            "askar_entry_list_get_category",
+            (EntryListHandle, c_int32, POINTER(StrBuffer)),
+            self,
+            index,
+            byref(cat),
+        )
+        return str(cat)
+
+    def get_name(self, index: int) -> str:
+        """Get the entry name."""
+        name = StrBuffer()
+        Lib().invoke(
+            "askar_entry_list_get_name",
+            (EntryListHandle, c_int32, POINTER(StrBuffer)),
+            self,
+            index,
+            byref(name),
+        )
+        return str(name)
+
+    def get_value(self, index: int) -> memoryview:
+        """Get the entry value."""
+        val = ByteBuffer()
+        Lib().invoke(
+            "askar_entry_list_get_value",
+            (EntryListHandle, c_int32, POINTER(ByteBuffer)),
+            self,
+            index,
+            byref(val),
+        )
+        return val.view
+
+    def get_tags(self, index: int) -> dict:
+        """Get the entry tags."""
+        tags = StrBuffer()
+        Lib().invoke(
+            "askar_entry_list_get_tags",
+            (EntryListHandle, c_int32, POINTER(StrBuffer)),
+            self,
+            index,
+            byref(tags),
+        )
+        if tags:
+            tags = json.loads(tags.value)
+            for t in tags:
+                if isinstance(tags[t], list):
+                    tags[t] = set(tags[t])
+        else:
+            tags = dict()
+        return tags
+
+
+class KeyEntryListHandle(ArcHandle):
+    """Handle for an active KeyEntryList instance."""
+
+    _dtor_ = "askar_key_entry_list_free"
+
+    def get_algorithm(self, index: int) -> str:
+        """Get the key algorithm."""
+        name = StrBuffer()
+        Lib().invoke(
+            "askar_key_entry_list_get_algorithm",
+            (KeyEntryListHandle, c_int32, POINTER(StrBuffer)),
+            self,
+            index,
+            byref(name),
+        )
+        return str(name)
+
+    def get_name(self, index: int) -> str:
+        """Get the key name."""
+        name = StrBuffer()
+        Lib().invoke(
+            "askar_key_entry_list_get_name",
+            (KeyEntryListHandle, c_int32, POINTER(StrBuffer)),
+            self,
+            index,
+            byref(name),
+        )
+        return str(name)
+
+    def get_metadata(self, index: int) -> str:
+        """Get the key metadata."""
+        metadata = StrBuffer()
+        Lib().invoke(
+            "askar_key_entry_list_get_metadata",
+            (KeyEntryListHandle, c_int32, POINTER(StrBuffer)),
+            self,
+            index,
+            byref(metadata),
+        )
+        return str(metadata)
+
+    def get_tags(self, index: int) -> dict:
+        """Get the key tags."""
+        tags = StrBuffer()
+        Lib().invoke(
+            "askar_key_entry_list_get_tags",
+            (KeyEntryListHandle, c_int32, POINTER(StrBuffer)),
+            self,
+            index,
+            byref(tags),
+        )
+        return json.loads(tags.value) if tags else None
+
+    def load_key(self, index: int) -> "LocalKeyHandle":
+        """Load the key instance."""
+        handle = LocalKeyHandle()
+        Lib().invoke(
+            "askar_key_entry_list_load_local",
+            (KeyEntryListHandle, c_int32, POINTER(LocalKeyHandle)),
+            self,
+            index,
+            byref(handle),
+        )
+        return handle
+
+
+class LocalKeyHandle(ArcHandle):
+    """Handle for an active LocalKey instance."""
+
+    _dtor_ = "askar_key_free"
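[Example (editorial sketch, not part of the diff): the tag conversion that
EntryListHandle.get_tags() applies to the JSON returned by the library; the
sample JSON string is invented for the illustration.]

    import json

    raw = '{"~plaintag": "a", "enctag": ["b", "c"]}'
    tags = json.loads(raw)
    # multi-valued tags arrive as JSON lists and are exposed as Python sets
    for t in tags:
        if isinstance(tags[t], list):
            tags[t] = set(tags[t])
    assert tags == {"~plaintag": "a", "enctag": {"b", "c"}}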
diff --git a/wrappers/python/aries_askar/bindings/lib.py b/wrappers/python/aries_askar/bindings/lib.py
new file mode 100644
index 00000000..8eb64aab
--- /dev/null
+++ b/wrappers/python/aries_askar/bindings/lib.py
@@ -0,0 +1,672 @@
+"""Library instance and allocated buffer handling."""
+
+import asyncio
+import json
+import itertools
+import logging
+import os
+import sys
+import threading
+import time
+
+from ctypes import (
+    Array,
+    CDLL,
+    CFUNCTYPE,
+    POINTER,
+    Structure,
+    addressof,
+    byref,
+    cast,
+    c_char,
+    c_char_p,
+    c_int8,
+    c_int32,
+    c_int64,
+    c_ubyte,
+    c_void_p,
+)
+from ctypes.util import find_library
+from typing import Callable, Optional, Tuple, Union
+from weakref import finalize, ref
+
+from ..error import AskarError, AskarErrorCode
+
+
+LOGGER = logging.getLogger(__name__)
+MODULE_NAME = __name__.split(".")[0]
+
+LOG_LEVELS = {
+    1: logging.ERROR,
+    2: logging.WARNING,
+    3: logging.INFO,
+    4: logging.DEBUG,
+}
+
+
+def _convert_log_level(level: Union[str, int, None]):
+    if level is None or level == "-1":
+        return -1
+    else:
+        if isinstance(level, str):
+            level = level.upper()
+        name = logging.getLevelName(level)
+        for k, v in LOG_LEVELS.items():
+            if logging.getLevelName(v) == name:
+                return k
+    return 0
+
+
+def _load_method_arguments(name, argtypes, args):
+    """Preload argument values to avoid freeing any intermediate data."""
+    if not argtypes:
+        return args
+    if len(args) != len(argtypes):
+        raise ValueError(f"{name}: Arguments length does not match argtypes length")
+    return [
+        arg if hasattr(argtype, "_type_") else argtype.from_param(arg)
+        for (arg, argtype) in zip(args, argtypes)
+    ]
+
+
+def _struct_dtor(ctype: type, address: int, dtor: Callable):
+    value = ctype.from_address(address)
+    if value:
+        dtor(value)
+
+
+def finalize_struct(instance, ctype):
+    finalize(
+        instance, _struct_dtor, ctype, addressof(instance), instance.__class__._cleanup
+    )
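+
+
+# finalize_struct() registers the owning class's _cleanup as a weakref
+# finalizer: when the wrapping instance is garbage collected, _struct_dtor
+# re-reads the value at the structure's address and, if it is non-null,
+# passes it to _cleanup so the library resource is freed exactly once.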
+ """ + self._cdll = None + self._callbacks = {} + self._cb_id = itertools.count(0) + self._cfuncs = {} + self._lib_name = lib_name + self._log_cb = None + self._log_enabled_cb = None + self._methods = {} + + self._load_library() + self._init_logger() + + def _load_library(self): + lib_name = self._lib_name + lib_prefix_mapping = {"win32": ""} + lib_suffix_mapping = {"darwin": ".dylib", "win32": ".dll"} + try: + os_name = sys.platform + lib_prefix = lib_prefix_mapping.get(os_name, "lib") + lib_suffix = lib_suffix_mapping.get(os_name, ".so") + lib_path = os.path.join( + os.path.dirname(__file__), "..", f"{lib_prefix}{lib_name}{lib_suffix}" + ) + self._cdll = CDLL(lib_path) + return + except KeyError: + LOGGER.debug("Unknown platform for shared library") + except OSError: + LOGGER.warning("Library not loaded from python package") + + lib_path = find_library(lib_name) + if not lib_path: + raise AskarError( + AskarErrorCode.WRAPPER, f"Library not found in path: {lib_name}" + ) + try: + self._cdll = CDLL(lib_path) + except OSError as e: + raise AskarError( + AskarErrorCode.WRAPPER, f"Error loading library: {lib_path}" + ) from e + + def _init_logger(self): + if self._log_cb: + return + + logger = logging.getLogger(MODULE_NAME) + if logging.getLevelName("TRACE") == "Level TRACE": + # avoid redefining TRACE if another library has added it + logging.addLevelName(5, "TRACE") + + self._log_cb_t = CFUNCTYPE( + None, c_void_p, c_int32, c_char_p, c_char_p, c_char_p, c_char_p, c_int32 + ) + + def _log_cb( + _context, + level: int, + target: c_char_p, + message: c_char_p, + _module_path: c_char_p, + file_name: c_char_p, + line: int, + ): + logger.getChild("native." + target.decode().replace("::", ".")).log( + LOG_LEVELS.get(level, level), + "\t%s:%d | %s", + file_name.decode() if file_name else None, + line, + message.decode(), + ) + + self._log_cb = self._log_cb_t(_log_cb) + + self._log_enabled_cb_t = CFUNCTYPE(c_int8, c_void_p, c_int32) + + def _enabled_cb(_context, level: int) -> bool: + return self._cdll and logger.isEnabledFor(LOG_LEVELS.get(level, level)) + + self._log_enabled_cb = self._log_enabled_cb_t(_enabled_cb) + + if os.getenv("RUST_LOG"): + # level from environment + level = -1 + else: + # inherit current level from logger + level = _convert_log_level(logger.level or logger.parent.level) + + set_logger = self.method( + "askar_set_custom_logger", + (c_void_p, c_void_p, c_void_p, c_void_p, c_int32), + restype=c_int64, + ) + if set_logger( + None, # context + self._log_cb, + self._log_enabled_cb, + None, # flush + level, + ): + raise self.get_current_error(True) + + try: + finalize(self, self.method("askar_clear_custom_logger", None, restype=None)) + except AttributeError: + # method is new as of 0.2.5 + pass + + def invoke(self, name, argtypes, *args): + """Perform a synchronous library function call.""" + method = self.method(name, argtypes, restype=c_int64) + args = _load_method_arguments(name, argtypes, args) + result = method(*args) + if result: + raise self.get_current_error(True) + + def invoke_async(self, name: str, argtypes, *args, return_type=None): + """Perform an asynchronous library function call.""" + method = self.method(name, (*argtypes, c_void_p, c_int64), restype=c_int64) + loop = asyncio.get_event_loop() + fut = loop.create_future() + cb_info = self._cfuncs.get(name) + if cb_info: + cb = cb_info[1] + else: + cb_args = [c_int64, c_int64] + if return_type: + cb_args.append(return_type) + cb_type = CFUNCTYPE(None, *cb_args) + cb = cb_type(self._handle_callback) + # must 
+
+    def invoke(self, name, argtypes, *args):
+        """Perform a synchronous library function call."""
+        method = self.method(name, argtypes, restype=c_int64)
+        args = _load_method_arguments(name, argtypes, args)
+        result = method(*args)
+        if result:
+            raise self.get_current_error(True)
+
+    def invoke_async(self, name: str, argtypes, *args, return_type=None):
+        """Perform an asynchronous library function call."""
+        method = self.method(name, (*argtypes, c_void_p, c_int64), restype=c_int64)
+        loop = asyncio.get_event_loop()
+        fut = loop.create_future()
+        cb_info = self._cfuncs.get(name)
+        if cb_info:
+            cb = cb_info[1]
+        else:
+            cb_args = [c_int64, c_int64]
+            if return_type:
+                cb_args.append(return_type)
+            cb_type = CFUNCTYPE(None, *cb_args)
+            cb = cb_type(self._handle_callback)
+            # must maintain a reference to cb_type, otherwise
+            # it may be freed, resulting in memory errors.
+            self._cfuncs[name] = (cb_type, cb)
+        args = _load_method_arguments(name, argtypes, args)
+        cb_id = next(self._cb_id)
+        self._callbacks[cb_id] = (loop, fut, name)
+        result = method(*args, cb, cb_id)
+        if result:
+            # FFI must not execute the callback if an error is returned
+            err = self.get_current_error(True)
+            if self._callbacks.pop(cb_id, None):
+                self._fulfill_future(fut, None, err)
+        return fut
+
+    def invoke_dtor(self, name: str, *values, argtypes=None, restype=None):
+        method = self.method(name, argtypes, restype=restype)
+        if method:
+            method(*values)
+
+    def _handle_callback(self, cb_id: int, err: int, result=None):
+        exc = self.get_current_error(True) if err else None
+        cb = self._callbacks.pop(cb_id, None)
+        if not cb:
+            LOGGER.info("Callback already fulfilled: %s", cb_id)
+            return
+        (loop, fut, _name) = cb
+        loop.call_soon_threadsafe(self._fulfill_future, fut, result, exc)
+
+    def _fulfill_future(self, fut: asyncio.Future, result, err: Exception = None):
+        """Resolve a callback future given the result and exception, if any."""
+        if fut.cancelled():
+            LOGGER.debug("callback previously cancelled")
+        elif err:
+            fut.set_exception(err)
+        else:
+            fut.set_result(result)
+
+    def get_current_error(self, expect: bool = False) -> Optional[AskarError]:
+        """
+        Get the error result from the previous failed API method.
+
+        Args:
+            expect: Return a default error message if none is found
+        """
+        err_json = StrBuffer()
+        method = self.method(
+            "askar_get_current_error", (POINTER(StrBuffer),), restype=c_int64
+        )
+        if not method(byref(err_json)):
+            try:
+                msg = json.loads(err_json.value)
+            except json.JSONDecodeError:
+                LOGGER.warning("JSON decode error for askar_get_current_error")
+                msg = None
+            if msg and "message" in msg and "code" in msg:
+                return AskarError(
+                    AskarErrorCode(msg["code"]), msg["message"], msg.get("extra")
+                )
+        if not expect:
+            return None
+        return AskarError(AskarErrorCode.WRAPPER, "Unknown error")
+
+    def method(self, name, argtypes, *, restype=None):
+        """Access a method of the library."""
+        method = self._methods.get(name)
+        if not method:
+            method = getattr(self._cdll, name, None)
+            if not method:
+                return None
+            if argtypes:
+                method.argtypes = argtypes
+            method.restype = restype
+            self._methods[name] = method
+        return method
+
+    def _cleanup(self):
+        """Destructor."""
+        if self._callbacks:
+
+            def _wait_callbacks(cb):
+                while cb:
+                    time.sleep(0.01)
+
+            th = threading.Thread(target=_wait_callbacks, args=(self._callbacks,))
+            th.start()
+            th.join(timeout=1.0)
+            if th.is_alive():
+                LOGGER.error(
+                    "%s: Timed out waiting for callbacks to complete",
+                    self._lib_name,
+                )
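+
+
+# Lib below is a lazy facade over a single LibLoad instance: it keeps only
+# a weak reference to itself and registers its cleanup finalizer before any
+# derived objects exist, so the shared library is torn down last.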
+
+
+class Lib:
+    """The loaded library instance."""
+
+    INSTANCE = None
+    LIB_NAME = "aries_askar"
+
+    def __new__(cls, *args):
+        """Class initializer."""
+        inst = cls.INSTANCE and cls.INSTANCE()
+        if inst is None:
+            inst = super().__new__(cls, *args)
+            inst._lib = None
+            inst._objs = []
+            # Keep a weak reference to the instance. This assumes that
+            # at least one instance is assigned to a persistent variable.
+            cls.INSTANCE = ref(inst)
+            # Register finalizer to be called later than any derived objects.
+            finalize(inst, cls._cleanup, inst._objs)
+        return inst
+
+    @property
+    def loaded(self) -> LibLoad:
+        """Determine if the library has been loaded."""
+        if not self._lib:
+            self._lib = LibLoad(self.__class__.LIB_NAME)
+            self._objs.append(self._lib)
+        return self._lib
+
+    def invoke(self, name, argtypes, *args):
+        """Perform a synchronous library function call."""
+        self.loaded.invoke(name, argtypes, *args)
+
+    async def invoke_async(self, name: str, argtypes, *args, return_type=None):
+        """Perform an asynchronous library function call."""
+        return await self.loaded.invoke_async(
+            name, argtypes, *args, return_type=return_type
+        )
+
+    def invoke_dtor(self, name: str, *args, argtypes=None, restype=None):
+        """Call a destructor method."""
+        if self._lib:
+            self._lib.invoke_dtor(name, *args, argtypes=argtypes, restype=restype)
+
+    def set_max_log_level(self, level: Union[str, int, None]):
+        """Set the maximum log level for the library."""
+        set_level = _convert_log_level(level)
+        self.invoke("askar_set_max_log_level", (c_int32,), set_level)
+
+    def version(self) -> str:
+        """Get the version of the installed library."""
+        return str(
+            self.loaded.method(
+                "askar_version",
+                None,
+                restype=StrBuffer,
+            )()
+        )
+
+    def __repr__(self) -> str:
+        loaded = self._lib is not None
+        return f"<Lib loaded={loaded}>"
+
+    @classmethod
+    def _cleanup(cls, objs):
+        for obj in objs:
+            obj._cleanup()
+
+
+class RawBuffer(Structure):
+    """A byte buffer allocated by the library."""
+
+    _fields_ = [
+        ("len", c_int64),
+        ("data", POINTER(c_ubyte)),
+    ]
+
+    def __bool__(self) -> bool:
+        return bool(self.data)
+
+    def __bytes__(self) -> bytes:
+        if not self.len:
+            return b""
+        return bytes(self.array)
+
+    def __len__(self) -> int:
+        return int(self.len)
+
+    @property
+    def array(self) -> Array:
+        return cast(self.data, POINTER(c_ubyte * self.len)).contents
+
+    def __repr__(self) -> str:
+        return f"<RawBuffer(len={self.len})>"
+
+
+class FfiByteBuffer:
+    """A byte buffer allocated by Python."""
+
+    def __init__(self, value):
+        if isinstance(value, str):
+            value = value.encode("utf-8")
+
+        if value is None:
+            dlen = 0
+            data = c_char_p()
+        elif isinstance(value, memoryview):
+            dlen = value.nbytes
+            data = c_char_p(value.tobytes())
+        elif isinstance(value, bytes):
+            dlen = len(value)
+            data = c_char_p(value)
+        else:
+            raise TypeError(f"Expected str or bytes value, got {type(value)}")
+        self._dlen = dlen
+        self._data = data
+
+    def __bytes__(self) -> bytes:
+        if not self._data:
+            return b""
+        return self._data.value
+
+    def __len__(self) -> int:
+        return self._dlen
+
+    @property
+    def _as_parameter_(self) -> RawBuffer:
+        buf = RawBuffer(len=self._dlen, data=cast(self._data, POINTER(c_ubyte)))
+        return buf
+
+    @classmethod
+    def from_param(cls, value):
+        if isinstance(value, (ByteBuffer, FfiByteBuffer)):
+            return value
+        return cls(value)
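+
+
+# FfiByteBuffer above wraps memory owned by Python and needs no destructor;
+# ByteBuffer below owns a RawBuffer allocated by the library, so its
+# finalizer must hand the memory back via askar_buffer_free.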
+
+
+class ByteBuffer(Structure):
+    """A managed byte buffer allocated by the library."""
+
+    _fields_ = [("buffer", RawBuffer)]
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._a = addressof(self)
+        finalize_struct(self, RawBuffer)
+
+    @property
+    def _as_parameter_(self):
+        return self.buffer
+
+    @property
+    def array(self) -> Array:
+        return self.buffer.array
+
+    @property
+    def view(self) -> memoryview:
+        m = memoryview(self.array)
+        # ensure self stays alive until the view is dropped
+        finalize(m, lambda _: (), self)
+        return m
+
+    def __bytes__(self) -> bytes:
+        return bytes(self.buffer)
+
+    def __len__(self) -> int:
+        return len(self.buffer)
+
+    def __getitem__(self, idx) -> bytes:
+        return bytes(self.buffer.array[idx])
+
+    def __repr__(self) -> str:
+        """Format byte buffer as a string."""
+        return f"{self.__class__.__name__}({bytes(self)})"
+
+    @classmethod
+    def _cleanup(cls, buffer: RawBuffer):
+        """Call the byte buffer destructor when this instance is released."""
+        Lib().invoke_dtor("askar_buffer_free", buffer)
+
+
+class FfiStr:
+    """A string value allocated by Python."""
+
+    def __init__(self, value=None):
+        if value is None:
+            value = c_char_p()
+        elif isinstance(value, c_char_p):
+            pass
+        else:
+            if isinstance(value, str):
+                value = value.encode("utf-8")
+            if not isinstance(value, bytes):
+                raise TypeError(f"Expected string value, got {type(value)}")
+            value = c_char_p(value)
+        self.value = value
+
+    @classmethod
+    def from_param(cls, value):
+        if isinstance(value, cls):
+            return value
+        return cls(value)
+
+    @property
+    def _as_parameter_(self):
+        return self.value
+
+    def __repr__(self) -> str:
+        """Format FFI string as a string."""
+        return f"{self.__class__.__name__}({self.value})"
+
+
+class FfiJson:
+    @classmethod
+    def from_param(cls, value):
+        if isinstance(value, FfiStr):
+            return value
+        if isinstance(value, dict):
+            value = json.dumps(value)
+        return FfiStr(value)
+
+
+class FfiTagsJson:
+    @classmethod
+    def from_param(cls, tags):
+        if isinstance(tags, FfiStr):
+            return tags
+        if tags:
+            tags = json.dumps(
+                {
+                    name: (list(value) if isinstance(value, set) else value)
+                    for name, value in tags.items()
+                }
+            )
+        else:
+            tags = None
+        return FfiStr(tags)
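+
+
+# FfiTagsJson is the outbound mirror of EntryListHandle.get_tags() in
+# handle.py: tag values held as Python sets are serialized back to JSON
+# lists before crossing the FFI boundary.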
+
+
+class StrBuffer(Structure):
+    """A string allocated by the library."""
+
+    _fields_ = [("buffer", POINTER(c_char))]
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        finalize_struct(self, c_char_p)
+
+    def is_none(self) -> bool:
+        """Check if the returned string pointer is null."""
+        return not self.buffer
+
+    def opt_str(self) -> Optional[str]:
+        """Convert to an optional string."""
+        val = self.value
+        return val.decode("utf-8") if val is not None else None
+
+    def __bool__(self) -> bool:
+        return bool(self.buffer)
+
+    def __bytes__(self) -> bytes:
+        """Convert to bytes."""
+        bval = self.value
+        return bval if bval is not None else bytes()
+
+    def __str__(self):
+        """Convert to a string."""
+        # not allowed to return None
+        val = self.opt_str()
+        return val if val is not None else ""
+
+    @property
+    def value(self) -> bytes:
+        return cast(self.buffer, c_char_p).value
+
+    @classmethod
+    def _cleanup(cls, buffer: c_char_p):
+        """Call the string destructor when this instance is released."""
+        Lib().invoke_dtor("askar_string_free", buffer)
+
+
+class AeadParams(Structure):
+    """AEAD parameters returned by the library."""
+
+    _fields_ = [
+        ("nonce_length", c_int32),
+        ("tag_length", c_int32),
+    ]
+
+    def __repr__(self) -> str:
+        """Format AEAD params as a string."""
+        return (
+            f"<AeadParams(nonce_length={self.nonce_length}, "
+            f"tag_length={self.tag_length})>"
+        )
+
+
+class Encrypted(Structure):
+    """The result of an AEAD encryption operation."""
+
+    _fields_ = [
+        ("buffer", RawBuffer),
+        ("tag_pos", c_int64),
+        ("nonce_pos", c_int64),
+    ]
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        finalize_struct(self, RawBuffer)
+
+    def __getitem__(self, idx) -> bytes:
+        return bytes(self.buffer.array[idx])
+
+    def __bytes__(self) -> bytes:
+        """Convert to bytes."""
+        return self.ciphertext_tag
+
+    @property
+    def ciphertext_tag(self) -> bytes:
+        """Accessor for the combined ciphertext and tag."""
+        p = self.nonce_pos
+        return self[:p]
+
+    @property
+    def ciphertext(self) -> bytes:
+        """Accessor for the ciphertext."""
+        p = self.tag_pos
+        return self[:p]
+
+    @property
+    def nonce(self) -> bytes:
+        """Accessor for the nonce."""
+        p = self.nonce_pos
+        return self[p:]
+
+    @property
+    def tag(self) -> bytes:
+        """Accessor for the authentication tag."""
+        p1 = self.tag_pos
+        p2 = self.nonce_pos
+        return self[p1:p2]
+
+    @property
+    def parts(self) -> Tuple[bytes, bytes, bytes]:
+        """Accessor for the ciphertext, tag, and nonce."""
+        p1 = self.tag_pos
+        p2 = self.nonce_pos
+        return self[:p1], self[p1:p2], self[p2:]
+
+    def __repr__(self) -> str:
+        """Format encrypted value as a string."""
+        return (
+            f"<Encrypted(buffer={self.buffer}, tag_pos={self.tag_pos}, "
+            f"nonce_pos={self.nonce_pos})>"
+        )
+
+    @classmethod
+    def _cleanup(cls, buffer: RawBuffer):
+        """Call the byte buffer destructor when this instance is released."""
+        Lib().invoke_dtor("askar_buffer_free", buffer)
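[Example (editorial sketch, not part of the diff): the buffer layout behind
the Encrypted accessors above. Key.generate and aead_encrypt are assumptions
about the wider wrapper API; the sizes are typical AES-256-GCM values.]

    from aries_askar import Key, KeyAlg

    key = Key.generate(KeyAlg.A256GCM)
    enc = key.aead_encrypt(b"hello")  # returns an Encrypted instance

    # The single allocation is laid out as [ ciphertext | tag | nonce ],
    # split at tag_pos and nonce_pos: a 5-byte message with a 16-byte tag
    # and 12-byte nonce gives tag_pos == 5 and nonce_pos == 21.
    ciphertext, tag, nonce = enc.parts
    assert bytes(enc) == ciphertext + tag  # equivalent to enc.ciphertext_tag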
diff --git a/wrappers/python/aries_askar/ecdh.py b/wrappers/python/aries_askar/ecdh.py
index d95dffed..bac928d7 100644
--- a/wrappers/python/aries_askar/ecdh.py
+++ b/wrappers/python/aries_askar/ecdh.py
@@ -5,10 +5,10 @@
 from .types import KeyAlg
 
 
-def _load_key(key: Union[dict, str, Key]) -> Key:
-    if isinstance(key, (str, dict)):
-        key = Key.from_jwk(key)
-    return key
+def _load_key(key: Union[dict, str, bytes, Key]) -> Key:
+    if isinstance(key, Key):
+        return key
+    return Key.from_jwk(key)
 
 
 class EcdhEs:
diff --git a/wrappers/python/aries_askar/key.py b/wrappers/python/aries_askar/key.py
index bdd2e657..39d1a09c 100644
--- a/wrappers/python/aries_askar/key.py
+++ b/wrappers/python/aries_askar/key.py
@@ -3,15 +3,14 @@
 from typing import Union
 
 from . import bindings
-
-from .bindings import Encrypted
+from .bindings import AeadParams, Encrypted, LocalKeyHandle
 from .types import KeyAlg, SeedMethod
 
 
 class Key:
     """An active key or keypair instance."""
 
-    def __init__(self, handle: bindings.LocalKeyHandle):
+    def __init__(self, handle: LocalKeyHandle):
         """Initialize the Key instance."""
         self._handle = handle
 
@@ -42,7 +41,7 @@ def from_jwk(cls, jwk: Union[dict, str, bytes]) -> "Key":
         return cls(bindings.key_from_jwk(jwk))
 
     @property
-    def handle(self) -> bindings.LocalKeyHandle:
+    def handle(self) -> LocalKeyHandle:
         """Accessor for the key handle."""
         return self._handle
 
@@ -52,7 +51,7 @@ def algorithm(self) -> KeyAlg:
         return KeyAlg.from_key_alg(alg)
 
     @property
-    def ephemeral(self) -> "Key":
+    def ephemeral(self) -> bool:
         return bindings.key_get_ephemeral(self._handle)
 
     def convert_key(self, alg: Union[str, KeyAlg]) -> "Key":
@@ -76,7 +75,7 @@ def get_jwk_secret(self) -> str:
     def get_jwk_thumbprint(self, alg: Union[str, KeyAlg] = None) -> str:
         return bindings.key_get_jwk_thumbprint(self._handle, alg)
 
-    def aead_params(self) -> bindings.AeadParams:
+    def aead_params(self) -> AeadParams:
         return bindings.key_aead_get_params(self._handle)
 
     def aead_random_nonce(self) -> bytes:
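[Example (editorial sketch, not part of the diff): inputs accepted by the
ECDH helpers after the _load_key() change above. Key.generate and
get_jwk_public are assumptions about the wider wrapper API.]

    from aries_askar import Key, KeyAlg

    recip = Key.generate(KeyAlg.X25519)
    jwk = recip.get_jwk_public()

    # a Key instance is returned unchanged; dict, str, or bytes JWK input
    # is routed through Key.from_jwk()
    from_str = Key.from_jwk(jwk)
    from_bytes = Key.from_jwk(jwk.encode("utf-8"))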
diff --git a/wrappers/python/aries_askar/store.py b/wrappers/python/aries_askar/store.py
index d5f86832..67b0565a 100644
--- a/wrappers/python/aries_askar/store.py
+++ b/wrappers/python/aries_askar/store.py
@@ -3,12 +3,13 @@
 import json
 from typing import Optional, Sequence, Union
+from weakref import ref
 
 from cached_property import cached_property
 
 from . import bindings
-
 from .bindings import (
+    ByteBuffer,
     EntryListHandle,
     KeyEntryListHandle,
     ScanHandle,
@@ -101,7 +102,20 @@ def __getitem__(self, index) -> Entry:
         return Entry(self._handle, index)
 
     def __iter__(self):
-        return self
+        return IterEntryList(self)
+
+    def __len__(self) -> int:
+        return self._len
+
+    def __repr__(self) -> str:
+        return f"<EntryList(handle={self._handle}, len={self._len})>"
+
+
+class IterEntryList:
+    def __init__(self, list: EntryList):
+        self._handle = list._handle
+        self._len = list._len
+        self._pos = 0
 
     def __next__(self):
         if self._pos < self._len:
@@ -111,12 +125,6 @@ def __next__(self):
             entry = Entry(self._handle, self._pos)
             self._pos += 1
             return entry
         else:
             raise StopIteration
 
-    def __len__(self) -> int:
-        return self._len
-
-    def __repr__(self) -> str:
-        return f"<EntryList(handle={self._handle}, pos={self._pos}, len={self._len})>"
-
 
 class KeyEntry:
     """Pointer to one result of a KeyEntryList instance."""
@@ -184,15 +192,7 @@ def __getitem__(self, index) -> KeyEntry:
         return KeyEntry(self._handle, index)
 
     def __iter__(self):
-        return self
-
-    def __next__(self):
-        if self._pos < self._len:
-            entry = KeyEntry(self._handle, self._pos)
-            self._pos += 1
-            return entry
-        else:
-            raise StopIteration
+        return IterKeyEntryList(self)
 
     def __len__(self) -> int:
         return self._len
@@ -203,6 +203,21 @@ def __repr__(self) -> str:
         )
 
 
+class IterKeyEntryList:
+    def __init__(self, list: KeyEntryList):
+        self._handle = list._handle
+        self._len = list._len
+        self._pos = 0
+
+    def __next__(self):
+        if self._pos < self._len:
+            entry = KeyEntry(self._handle, self._pos)
+            self._pos += 1
+            return entry
+        else:
+            raise StopIteration
+
+
 class Scan:
     """A scan of the Store."""
 
@@ -216,9 +231,9 @@ def __init__(
         self,
         limit: int = None,
     ):
         """Initialize the Scan instance."""
-        self.params = (store, profile, category, tag_filter, offset, limit)
+        self._params = (store, profile, category, tag_filter, offset, limit)
         self._handle: ScanHandle = None
-        self._buffer: EntryList = None
+        self._buffer: IterEntryList = None
 
     @property
     def handle(self) -> ScanHandle:
@@ -230,7 +245,8 @@ def __aiter__(self):
         return self
 
     async def __anext__(self):
         if self._handle is None:
-            (store, profile, category, tag_filter, offset, limit) = self.params
+            (store, profile, category, tag_filter, offset, limit) = self._params
+            self._params = None
             if not store.handle:
                 raise AskarError(
                     AskarErrorCode.WRAPPER, "Cannot scan from closed store"
                 )
             self._handle = await bindings.scan_start(
                 store.handle, profile, category, tag_filter, offset, limit
             )
             list_handle = await bindings.scan_next(self._handle)
-            self._buffer = EntryList(list_handle) if list_handle else None
+            self._buffer = iter(EntryList(list_handle)) if list_handle else None
         while True:
             if not self._buffer:
                 raise StopAsyncIteration
@@ -247,7 +263,7 @@ async def __anext__(self):
             if row:
                 return row
             list_handle = await bindings.scan_next(self._handle)
-            self._buffer = EntryList(list_handle) if list_handle else None
+            self._buffer = iter(EntryList(list_handle)) if list_handle else None
 
     async def fetch_all(self) -> Sequence[Entry]:
         rows = []
@@ -317,11 +333,13 @@ async def remove(cls, uri: str) -> bool:
 
     async def __aenter__(self) -> "Session":
         if not self._opener:
-            self._opener = OpenSession(self, None, False)
+            self._opener = OpenSession(self._handle, None, False)
         return await self._opener.__aenter__()
 
     async def __aexit__(self, exc_type, exc, tb):
-        return await self._opener.__aexit__(exc_type, exc, tb)
+        opener = self._opener
+        self._opener = None
+        return await opener.__aexit__(exc_type, exc, tb)
 
     async def create_profile(self, name: str = None) -> str:
         return await bindings.store_create_profile(self._handle, name)
@@ -350,10 +368,10 @@ def scan(
         return Scan(self, profile, category, tag_filter, offset, limit)
 
     def session(self, profile: str = None) -> "OpenSession":
-        return OpenSession(self, profile, False)
+        return OpenSession(self._handle, profile, False)
 
     def transaction(self, profile: str = None) -> "OpenSession":
-        return OpenSession(self, profile, True)
+        return OpenSession(self._handle, profile, True)
 
     async def close(self, *, remove: bool = False) -> bool:
         """Close and free the pool instance."""
@@ -373,7 +391,7 @@ def __repr__(self) -> str:
 class Session:
     """An opened Session instance."""
 
-    def __init__(self, store: Store, handle: SessionHandle, is_txn: bool):
+    def __init__(self, store: StoreHandle, handle: SessionHandle, is_txn: bool):
         """Initialize the Session instance."""
         self._store = store
         self._handle = handle
@@ -389,11 +407,6 @@ def handle(self) -> SessionHandle:
         """Accessor for the SessionHandle instance."""
         return self._handle
 
-    @property
-    def store(self) -> Store:
-        """Accessor for the Store instance."""
-        return self._store
-
     async def count(self, category: str, tag_filter: Union[str, dict] = None) -> int:
         if not self._handle:
             raise AskarError(AskarErrorCode.WRAPPER, "Cannot count from closed session")
@@ -407,7 +420,7 @@ async def fetch(
         result_handle = await bindings.session_fetch(
             self._handle, category, name, for_update
         )
-        return next(EntryList(result_handle, 1), None) if result_handle else None
+        return next(iter(EntryList(result_handle, 1)), None) if result_handle else None
 
     async def fetch_all(
         self,
@@ -508,7 +521,9 @@ async def fetch_key(
                 AskarErrorCode.WRAPPER, "Cannot fetch key from closed session"
             )
         result_handle = await bindings.session_fetch_key(self._handle, name, for_update)
-        return next(KeyEntryList(result_handle, 1)) if result_handle else None
+        return (
+            next(iter(KeyEntryList(result_handle, 1)), None) if result_handle else None
+        )
 
     async def fetch_all_keys(
         self,
@@ -577,39 +592,38 @@ def __repr__(self) -> str:
 
 class OpenSession:
-    def __init__(self, store: Store, profile: Optional[str], is_txn: bool):
+    def __init__(self, store: StoreHandle, profile: Optional[str], is_txn: bool):
         """Initialize the OpenSession instance."""
         self._store = store
         self._profile = profile
         self._is_txn = is_txn
-        self._session = None
+        self._session: Session = None
 
     @property
     def is_transaction(self) -> bool:
         return self._is_txn
 
     async def _open(self) -> Session:
-        if not self._store.handle:
+        if not self._store:
             raise AskarError(
                 AskarErrorCode.WRAPPER, "Cannot start session from closed store"
             )
         if self._session:
             raise AskarError(AskarErrorCode.WRAPPER, "Session already opened")
-        self._session = Session(
+        return Session(
            self._store,
-            await bindings.session_start(
-                self._store.handle, self._profile, self._is_txn
-            ),
+            await bindings.session_start(self._store, self._profile, self._is_txn),
            self._is_txn,
        )
-        return self._session
 
     def __await__(self) -> Session:
         return self._open().__await__()
 
     async def __aenter__(self) -> Session:
-        return await self._open()
+        self._session = await self._open()
+        return self._session
 
     async def __aexit__(self, exc_type, exc, tb):
-        await self._session.close()
+        session = self._session
         self._session = None
+        await session.close()
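[Example (editorial sketch, not part of the diff): the OpenSession paths
changed above, run inside an async function with a provisioned Store
instance named store.]

    # A transaction is a session opened with is_txn=True; anything not
    # committed is rolled back when the context exits.
    async with store.transaction() as txn:
        row = await txn.fetch("category", "name", for_update=True)
        if row:
            await txn.replace("category", "name", b"new-value")
        await txn.commit()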
diff --git a/wrappers/python/aries_askar/version.py b/wrappers/python/aries_askar/version.py
index 9cc67882..b4d49e71 100644
--- a/wrappers/python/aries_askar/version.py
+++ b/wrappers/python/aries_askar/version.py
@@ -1,3 +1,3 @@
 """aries_askar library wrapper version."""
 
-__version__ = "0.2.4"
+__version__ = "0.2.5"
diff --git a/wrappers/python/demo/perf.py b/wrappers/python/demo/perf.py
index aaa80b4c..a3fde18b 100644
--- a/wrappers/python/demo/perf.py
+++ b/wrappers/python/demo/perf.py
@@ -31,22 +31,37 @@ async def perf_test():
 
     store = await Store.provision(REPO_URI, "raw", key, recreate=True)
 
+    insert_start = time.perf_counter()
+    async with store.session() as session:
+        for idx in range(PERF_ROWS):
+            await session.insert(
+                "seq",
+                f"name-{idx}",
+                b"value",
+                {"~plaintag": "a", "enctag": "b"},
+            )
+    dur = time.perf_counter() - insert_start
+    print(f"sequential insert duration ({PERF_ROWS} rows): {dur:0.2f}s")
+
     insert_start = time.perf_counter()
     async with store.transaction() as txn:
-        # ^ faster within a transaction
+        # ^ should be faster within a transaction
         for idx in range(PERF_ROWS):
             await txn.insert(
-                "category", f"name-{idx}", b"value", {"~plaintag": "a", "enctag": "b"}
+                "txn",
+                f"name-{idx}",
+                b"value",
+                {"~plaintag": "a", "enctag": "b"},
             )
         await txn.commit()
     dur = time.perf_counter() - insert_start
-    print(f"insert duration ({PERF_ROWS} rows): {dur:0.2f}s")
+    print(f"transaction batch insert duration ({PERF_ROWS} rows): {dur:0.2f}s")
 
+    tags = 0
     fetch_start = time.perf_counter()
     async with store as session:
-        tags = 0
         for idx in range(PERF_ROWS):
-            entry = await session.fetch("category", f"name-{idx}")
+            entry = await session.fetch("seq", f"name-{idx}")
             tags += len(entry.tags)
     dur = time.perf_counter() - fetch_start
     print(f"fetch duration ({PERF_ROWS} rows, {tags} tags): {dur:0.2f}s")
@@ -54,12 +69,36 @@ async def perf_test():
     rc = 0
     tags = 0
     scan_start = time.perf_counter()
-    async for row in store.scan("category", {"~plaintag": "a", "enctag": "b"}):
+    async for row in store.scan("seq", {"~plaintag": "a", "enctag": "b"}):
         rc += 1
         tags += len(row.tags)
     dur = time.perf_counter() - scan_start
     print(f"scan duration ({rc} rows, {tags} tags): {dur:0.2f}s")
 
+    async with store as session:
+        await session.insert("seq", "count", "0", {"~plaintag": "a", "enctag": "b"})
+        update_start = time.perf_counter()
+        count = 0
+        for idx in range(PERF_ROWS):
+            count += 1
+            await session.replace(
+                "seq", "count", str(count), {"~plaintag": "a", "enctag": "b"}
+            )
+        dur = time.perf_counter() - update_start
+        print(f"unchecked update duration ({PERF_ROWS} rows): {dur:0.2f}s")
+
+    async with store as session:
+        await session.insert("txn", "count", "0", {"~plaintag": "a", "enctag": "b"})
+    update_start = time.perf_counter()
+    for idx in range(PERF_ROWS):
+        async with store.transaction() as txn:
+            row = await txn.fetch("txn", "count", for_update=True)
+            count = str(int(row.value) + 1)
+            await txn.replace("txn", "count", count, {"~plaintag": "a", "enctag": "b"})
+            await txn.commit()
+    dur = time.perf_counter() - update_start
+    print(f"transactional update duration ({PERF_ROWS} rows): {dur:0.2f}s")
+
     await store.close()
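[Example (editorial sketch, not part of the diff): a minimal end-to-end run
mirroring the demo above. Store.generate_raw_key() and the in-memory SQLite
URI are assumptions; insert and count follow the calls used in perf.py.]

    import asyncio
    from aries_askar import Store

    async def main():
        key = Store.generate_raw_key()
        store = await Store.provision("sqlite://:memory:", "raw", key, recreate=True)
        async with store.session() as session:
            await session.insert("seq", "name-0", b"value", {"~plaintag": "a"})
            print(await session.count("seq"))
        await store.close()

    asyncio.run(main())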
diff --git a/wrappers/python/tests/test_cleanup.py b/wrappers/python/tests/test_cleanup.py
new file mode 100644
index 00000000..4a1d799f
--- /dev/null
+++ b/wrappers/python/tests/test_cleanup.py
@@ -0,0 +1,60 @@
+from ctypes import c_char, c_char_p, c_size_t, c_ubyte, pointer
+from unittest import mock
+
+from aries_askar.bindings.handle import ArcHandle
+from aries_askar.bindings.lib import ByteBuffer, RawBuffer, StrBuffer
+
+
+def test_cleanup_handle():
+    logged = []
+
+    class Handle(ArcHandle):
+        @classmethod
+        def _cleanup(cls, handle: c_size_t):
+            logged.append(handle.value)
+
+    h = Handle()
+    assert not h.value
+    del h
+    assert not logged
+
+    h = Handle()
+    h.value = 99
+    del h
+    assert logged == [99]
+
+
+def test_cleanup_bytebuffer():
+    logged = []
+
+    def cleanup(buffer: RawBuffer):
+        logged.append((buffer.len, buffer.data.contents.value if buffer.data else None))
+
+    with mock.patch.object(ByteBuffer, "_cleanup", cleanup):
+        b = ByteBuffer()
+        del b
+        assert not logged
+
+        c = c_ubyte(99)
+        b = ByteBuffer()
+        b.buffer = RawBuffer(len=1, data=pointer(c))
+        del b
+        assert logged == [(1, 99)]
+
+
+def test_cleanup_strbuffer():
+    logged = []
+
+    def cleanup(buffer: c_char_p):
+        logged.append(buffer.value)
+
+    with mock.patch.object(StrBuffer, "_cleanup", cleanup):
+        s = StrBuffer()
+        del s
+        assert not logged
+
+        s = StrBuffer()
+        c = c_char(ord("a"))
+        s.buffer = pointer(c)
+        del s
+        assert logged == [b"a"]
diff --git a/wrappers/python/tests/test_store.py b/wrappers/python/tests/test_store.py
index 03ef787e..684b9d45 100644
--- a/wrappers/python/tests/test_store.py
+++ b/wrappers/python/tests/test_store.py
@@ -1,3 +1,4 @@
+import asyncio
 import os
 
 from pytest import mark, raises
@@ -25,6 +26,7 @@ def raw_key() -> str:
 
 
 @pytest_asyncio.fixture
+@mark.asyncio
 async def store() -> Store:
     key = raw_key()
     store = await Store.provision(TEST_STORE_URI, "raw", key, recreate=True)
@@ -110,7 +112,7 @@ async def test_scan(store: Store):
 
 
 @mark.asyncio
-async def test_transaction(store: Store):
+async def test_txn_basic(store: Store):
     async with store.transaction() as txn:
 
         # Insert a new entry
@@ -144,6 +146,43 @@ async def test_transaction(store: Store):
         assert dict(found) == TEST_ENTRY
 
 
+@mark.asyncio
+async def test_txn_contention(store: Store):
+    async with store.transaction() as txn:
+        await txn.insert(
+            TEST_ENTRY["category"],
+            TEST_ENTRY["name"],
+            "0",
+        )
+        await txn.commit()
+
+    INC_COUNT = 1000
+    TASKS = 10
+
+    async def inc():
+        for _ in range(INC_COUNT):
+            async with store.transaction() as txn:
+                row = await txn.fetch(
+                    TEST_ENTRY["category"], TEST_ENTRY["name"], for_update=True
+                )
+                if not row:
+                    raise Exception("Row not found")
+                new_value = str(int(row.value) + 1)
+                await txn.replace(TEST_ENTRY["category"], TEST_ENTRY["name"], new_value)
+                await txn.commit()
+
+    tasks = [asyncio.create_task(inc()) for _ in range(TASKS)]
+    await asyncio.gather(*tasks)
+
+    # Check all the updates completed
+    async with store.session() as session:
+        result = await session.fetch(
+            TEST_ENTRY["category"],
+            TEST_ENTRY["name"],
+        )
+        assert int(result.value) == INC_COUNT * TASKS
+
+
 @mark.asyncio
 async def test_key_store(store: Store):
     # test key operations in a new session
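[Example (editorial sketch, not part of the diff): pointing the suite at
another backend. This assumes TEST_STORE_URI is read from the environment
when the test module is imported; the Postgres URL is a placeholder.]

    import os

    os.environ["TEST_STORE_URI"] = "postgres://user:pass@localhost:5432/askar-test"
    # with this set before pytest collects tests/test_store.py, the store
    # fixture provisions the Postgres backend and the same tests (including
    # test_txn_contention) run against it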