diff --git a/Cargo.toml b/Cargo.toml index feab0a4..0455ca7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,7 +10,7 @@ license = "Apache-2.0" [features] default = [ "std" ] alloc = [ "serde/alloc", "serde_json/alloc" ] -std = [ "serde/std", "serde_json/std" ] +std = ["serde/std", "serde_json/std", "thiserror/std"] tee-sev = [ "sev" ] tee-snp = [ "sev" ] @@ -19,6 +19,7 @@ base64 = "0.22.1" serde = { version = "1.0", default-features = false, features = ["derive"] } serde_json = { version = "1.0", default-features = false } sev = { version = "3.2.0", features = ["openssl"], optional = true } +thiserror = { version = "2.0.3", default-features = false } [dev-dependencies] codicon = "3.0.0" diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..1b9e584 --- /dev/null +++ b/src/error.rs @@ -0,0 +1,9 @@ +use thiserror::Error; + +pub type Result = core::result::Result; + +#[derive(Error, Debug)] +pub enum KbsTypesError { + #[error("Serialize/Deserialize error")] + Serde, +} diff --git a/src/lib.rs b/src/lib.rs index bcb5f9f..20c8dd6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,6 +5,9 @@ #[cfg(feature = "alloc")] extern crate alloc; +mod error; +pub use error::{KbsTypesError, Result}; + #[cfg(all(feature = "alloc", not(feature = "std")))] use alloc::{collections::btree_map::BTreeMap, string::String, vec::Vec}; use base64::{prelude::BASE64_URL_SAFE_NO_PAD, Engine}; @@ -100,18 +103,17 @@ pub struct ProtectedHeader { impl ProtectedHeader { /// The generation of AAD for JWE follows [A.3.5 RFC7516](https://www.rfc-editor.org/rfc/rfc7516#appendix-A.3.5) - pub fn generate_aad(&self) -> Vec { - let protected_utf8 = - serde_json::to_string(&self).expect("unexpected OOM when serializing ProtectedHeader"); + pub fn generate_aad(&self) -> Result> { + let protected_utf8 = serde_json::to_string(&self).map_err(|_| KbsTypesError::Serde)?; let aad = BASE64_URL_SAFE_NO_PAD.encode(protected_utf8); - aad.into_bytes() + Ok(aad.into_bytes()) } } fn serialize_base64_protected_header( sub: &ProtectedHeader, serializer: S, -) -> Result +) -> core::result::Result where S: serde::Serializer, { @@ -120,7 +122,9 @@ where serializer.serialize_str(&encoded) } -fn deserialize_base64_protected_header<'de, D>(deserializer: D) -> Result +fn deserialize_base64_protected_header<'de, D>( + deserializer: D, +) -> core::result::Result where D: serde::Deserializer<'de>, { @@ -133,7 +137,7 @@ where Ok(protected_header) } -fn serialize_base64(sub: &Vec, serializer: S) -> Result +fn serialize_base64(sub: &Vec, serializer: S) -> core::result::Result where S: serde::Serializer, { @@ -141,7 +145,7 @@ where serializer.serialize_str(&encoded) } -fn deserialize_base64<'de, D>(deserializer: D) -> Result, D::Error> +fn deserialize_base64<'de, D>(deserializer: D) -> core::result::Result, D::Error> where D: serde::Deserializer<'de>, { @@ -153,7 +157,10 @@ where Ok(decoded) } -fn serialize_base64_option(sub: &Option>, serializer: S) -> Result +fn serialize_base64_option( + sub: &Option>, + serializer: S, +) -> core::result::Result where S: serde::Serializer, { @@ -166,7 +173,9 @@ where } } -fn deserialize_base64_option<'de, D>(deserializer: D) -> Result>, D::Error> +fn deserialize_base64_option<'de, D>( + deserializer: D, +) -> core::result::Result>, D::Error> where D: serde::Deserializer<'de>, { @@ -273,7 +282,7 @@ mod tests { other_fields: BTreeMap::new(), }; - let aad = protected_header.generate_aad(); + let aad = protected_header.generate_aad().unwrap(); assert_eq!( aad,