diff --git a/src/lib.rs b/src/lib.rs index 20c8dd6..272f606 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -9,11 +9,11 @@ mod error; pub use error::{KbsTypesError, Result}; #[cfg(all(feature = "alloc", not(feature = "std")))] -use alloc::{collections::btree_map::BTreeMap, string::String, vec::Vec}; +use alloc::{string::String, vec::Vec}; use base64::{prelude::BASE64_URL_SAFE_NO_PAD, Engine}; -use serde_json::Value; +use serde_json::{Map, Value}; #[cfg(all(feature = "std", not(feature = "alloc")))] -use std::{collections::BTreeMap, string::String}; +use std::string::String; use serde::{Deserialize, Serialize}; @@ -97,8 +97,8 @@ pub struct ProtectedHeader { pub enc: String, /// Other fields of Protected Header - #[serde(skip_serializing_if = "BTreeMap::is_empty", flatten)] - pub other_fields: BTreeMap, + #[serde(skip_serializing_if = "Map::is_empty", flatten)] + pub other_fields: Map, } impl ProtectedHeader { @@ -157,7 +157,7 @@ where Ok(decoded) } -fn serialize_base64_option( +fn serialize_base64_vec( sub: &Option>, serializer: S, ) -> core::result::Result @@ -166,25 +166,23 @@ where { match sub { Some(value) => { - let encoded = BASE64_URL_SAFE_NO_PAD.encode(value); + let encoded = String::from_utf8(value.clone()).map_err(serde::ser::Error::custom)?; serializer.serialize_str(&encoded) } None => serializer.serialize_none(), } } -fn deserialize_base64_option<'de, D>( +fn deserialize_base64_vec<'de, D>( deserializer: D, ) -> core::result::Result>, D::Error> where D: serde::Deserializer<'de>, { - let encoded = String::deserialize(deserializer)?; - let decoded = BASE64_URL_SAFE_NO_PAD - .decode(encoded) - .map_err(serde::de::Error::custom)?; + let string = String::deserialize(deserializer)?; + let bytes = string.into_bytes(); - Ok(Some(decoded)) + Ok(Some(bytes)) } #[derive(Clone, Serialize, Deserialize, Debug)] @@ -202,10 +200,10 @@ pub struct Response { pub encrypted_key: Vec, #[serde( - deserialize_with = "deserialize_base64_option", skip_serializing_if = "Option::is_none", - serialize_with = "serialize_base64_option", - default = "Option::default" + default = "Option::default", + serialize_with = "serialize_base64_vec", + deserialize_with = "deserialize_base64_vec" )] pub aad: Option>, @@ -279,7 +277,7 @@ mod tests { let protected_header = ProtectedHeader { alg: "fakealg".to_string(), enc: "fakeenc".to_string(), - other_fields: BTreeMap::new(), + other_fields: Map::new(), }; let aad = protected_header.generate_aad().unwrap(); @@ -313,6 +311,41 @@ mod tests { assert_eq!(response.aad, None); } + #[test] + fn parse_response_nested_protected_header() { + let data = r#" + { + "protected": "eyJhbGciOiJmYWtlYWxnIiwiZW5jIjoiZmFrZWVuYyIsImVwayI6eyJrdHkiOiJPS1AiLCJjcnYiOiJYMjU1MTkiLCJ4IjoiaFNEd0NZa3dwMVIwaTMzY3RENzNXZzJfT2cwbU9CcjA2NlNwanFxYlRtbyJ9fQo", + "encrypted_key": "ZmFrZWtleQ", + "iv": "cmFuZG9tZGF0YQ", + "ciphertext": "ZmFrZWVuY291dHB1dA", + "tag": "ZmFrZXRhZw" + }"#; + + let response: Response = serde_json::from_str(data).unwrap(); + + assert_eq!(response.protected.alg, "fakealg"); + assert_eq!(response.protected.enc, "fakeenc"); + + let expected_other_fields = json!({ + "epk": { + "kty" : "OKP", + "crv": "X25519", + "x": "hSDwCYkwp1R0i33ctD73Wg2_Og0mOBr066SpjqqbTmo" + } + }) + .as_object() + .unwrap() + .clone(); + + assert_eq!(response.protected.other_fields, expected_other_fields); + assert_eq!(response.encrypted_key, "fakekey".as_bytes()); + assert_eq!(response.iv, "randomdata".as_bytes()); + assert_eq!(response.ciphertext, "fakeencoutput".as_bytes()); + assert_eq!(response.tag, "faketag".as_bytes()); + assert_eq!(response.aad, None); + } + #[test] fn parse_response_with_aad() { let data = r#" @@ -320,7 +353,7 @@ mod tests { "protected": "eyJhbGciOiJmYWtlYWxnIiwiZW5jIjoiZmFrZWVuYyJ9Cg", "encrypted_key": "ZmFrZWtleQ", "iv": "cmFuZG9tZGF0YQ", - "aad": "ZmFrZWFhZA", + "aad": "fakeaad", "ciphertext": "ZmFrZWVuY291dHB1dA", "tag": "ZmFrZXRhZw" }"#; @@ -344,7 +377,7 @@ mod tests { "protected": "eyJhbGciOiJmYWtlYWxnIiwiZW5jIjoiZmFrZWVuYyIsImZha2VmaWVsZCI6ImZha2V2YWx1ZSJ9", "encrypted_key": "ZmFrZWtleQ", "iv": "cmFuZG9tZGF0YQ", - "aad": "ZmFrZWFhZA", + "aad": "fakeaad", "ciphertext": "ZmFrZWVuY291dHB1dA", "tag": "ZmFrZXRhZw" }"#; @@ -382,7 +415,7 @@ mod tests { "protected": "eyJhbGciOiJmYWtlYWxnIiwiZW5jIjoiZmFrZWVuYyIsImZha2VmaWVsZCI6ImZha2V2YWx1ZSJ9", "encrypted_key": "ZmFrZWtleQ", "iv": "cmFuZG9tZGF0YQ", - "aad": "ZmFrZWFhZA", + "aad": "fakeaad", "ciphertext": "ZmFrZWVuY291dHB1dA", "tag": "ZmFrZXRhZw" });