Skip to content

Commit

Permalink
Refactor HttpClient trait to match http-std crate
Browse files Browse the repository at this point in the history
  • Loading branch information
Diane Huxley committed Oct 3, 2024
1 parent 30d475e commit e004bd0
Show file tree
Hide file tree
Showing 5 changed files with 147 additions and 107 deletions.
10 changes: 7 additions & 3 deletions crates/web5/src/credentials/credential_schema.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use std::collections::HashMap;

use super::verifiable_credential_1_1::VerifiableCredential;
use crate::{errors::{Result, Web5Error}, http::get_http_client};
use crate::{
errors::{Result, Web5Error},
http::get_http_client,
};
use jsonschema::{Draft, JSONSchema};
use serde::{Deserialize, Serialize};

Expand Down Expand Up @@ -33,9 +36,10 @@ pub(crate) async fn validate_credential_schema(
let headers: HashMap<String, String> = HashMap::from([
("Host".to_string(), "{}".to_string()),
("Connection".to_string(), "close".to_string()),
("Accept".to_string(), "application/json".to_string())
("Accept".to_string(), "application/json".to_string()),
]);
let response = get_http_client().get(url, Some(headers))
let response = get_http_client()
.get(url, Some(headers))
.await
.map_err(|e| Web5Error::Network(format!("Failed to fetch credential schema: {}", e)))?;

Expand Down
18 changes: 13 additions & 5 deletions crates/web5/src/dids/methods/did_dht/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ use crate::{
resolution_metadata::ResolutionMetadataError, resolution_result::ResolutionResult,
},
},
errors::{Result, Web5Error}, http::get_http_client,
errors::{Result, Web5Error},
http::get_http_client,
};
use std::{collections::HashMap, sync::Arc};

Expand Down Expand Up @@ -194,10 +195,14 @@ impl DidDht {
("Host".to_string(), "{}".to_string()),
("Connection".to_string(), "close".to_string()),
("Content-Length".to_string(), "{}".to_string()),
("Content-Type".to_string(), "application/octet-stream".to_string())
(
"Content-Type".to_string(),
"application/octet-stream".to_string(),
),
]);

let response = get_http_client().put(&url, Some(headers), &body)
let response = get_http_client()
.put(&url, Some(headers), &body)
.await
.map_err(|e| Web5Error::Network(format!("Failed to PUT did:dht: {}", e)))?;
if response.status_code != 200 {
Expand Down Expand Up @@ -259,9 +264,12 @@ impl DidDht {
let headers: HashMap<String, String> = HashMap::from([
("Host".to_string(), "{}".to_string()),
("Connection".to_string(), "close".to_string()),
("Accept".to_string(), "application/octet-stream".to_string())
("Accept".to_string(), "application/octet-stream".to_string()),
]);
let response = get_http_client().get(&url, Some(headers)).await.map_err(|_| ResolutionMetadataError::InternalError)?;
let response = get_http_client()
.get(&url, Some(headers))
.await
.map_err(|_| ResolutionMetadataError::InternalError)?;

if response.status_code == 404 {
return Err(ResolutionMetadataError::NotFound);
Expand Down
20 changes: 12 additions & 8 deletions crates/web5/src/dids/methods/did_web/resolver.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
use std::collections::HashMap;

use crate::{dids::{
data_model::document::Document,
did::Did,
resolution::{
resolution_metadata::ResolutionMetadataError, resolution_result::ResolutionResult,
use crate::{
dids::{
data_model::document::Document,
did::Did,
resolution::{
resolution_metadata::ResolutionMetadataError, resolution_result::ResolutionResult,
},
},
}, http::get_http_client};
http::get_http_client,
};
use url::Url;

// PORT_SEP is the : character that separates the domain from the port in a URI.
Expand Down Expand Up @@ -48,9 +51,10 @@ impl Resolver {
let headers: HashMap<String, String> = HashMap::from([
("Host".to_string(), "{}".to_string()),
("Connection".to_string(), "close".to_string()),
("Accept".to_string(), "application/json".to_string())
("Accept".to_string(), "application/json".to_string()),
]);
let response = get_http_client().get(&self.http_url, Some(headers))
let response = get_http_client()
.get(&self.http_url, Some(headers))
.await
.map_err(|_| ResolutionMetadataError::InternalError)?;

Expand Down
204 changes: 114 additions & 90 deletions crates/web5/src/http.rs
Original file line number Diff line number Diff line change
@@ -1,34 +1,83 @@

use std::{collections::HashMap, sync::Arc};
use async_trait::async_trait;
use once_cell::sync::OnceCell;
use serde::{Deserialize, Serialize};
use std::{collections::HashMap, sync::Arc};

#[async_trait]
pub trait HttpClient: Send + Sync {
async fn get(
&self,
url: &str,
headers: Option<HashMap<String, String>>
async fn fetch(
&self,
url: &str,
options: Option<FetchOptions>,
) -> std::result::Result<HttpResponse, Box<dyn std::error::Error>>;

async fn get(
&self,
url: &str,
headers: Option<HashMap<String, String>>,
) -> std::result::Result<HttpResponse, Box<dyn std::error::Error>> {
self.fetch(
url,
Some(FetchOptions {
method: Some(Method::Get),
headers,
body: None,
}),
)
.await
}

async fn post(
&self,
url: &str,
headers: Option<HashMap<String, String>>,
body: &[u8]
) -> std::result::Result<HttpResponse, Box<dyn std::error::Error>>;
&self,
url: &str,
headers: Option<HashMap<String, String>>,
body: &[u8],
) -> std::result::Result<HttpResponse, Box<dyn std::error::Error>> {
self.fetch(
url,
Some(FetchOptions {
method: Some(Method::Post),
headers,
body: Some(body.to_vec()),
}),
)
.await
}

async fn put(
&self,
url: &str,
headers: Option<HashMap<String, String>>,
body: &[u8]
) -> std::result::Result<HttpResponse, Box<dyn std::error::Error>>;
&self,
url: &str,
headers: Option<HashMap<String, String>>,
body: &[u8],
) -> std::result::Result<HttpResponse, Box<dyn std::error::Error>> {
self.fetch(
url,
Some(FetchOptions {
method: Some(Method::Put),
headers,
body: Some(body.to_vec()),
}),
)
.await
}
}

#[derive(Default, Serialize, Deserialize)]
pub struct FetchOptions {
pub method: Option<Method>,
pub headers: Option<HashMap<String, String>>,
pub body: Option<Vec<u8>>,
}

#[derive(Serialize, Deserialize)]
pub enum Method {
Get,
Post,
Put,
}

pub struct HttpResponse {
pub status_code: u16,
#[allow(dead_code)]
pub headers: HashMap<String, String>,
pub body: Vec<u8>,
}
Expand All @@ -37,14 +86,17 @@ static HTTP_CLIENT: OnceCell<Arc<dyn HttpClient>> = OnceCell::new();

#[cfg(feature = "http_reqwest")]
pub fn get_http_client() -> &'static dyn HttpClient {
HTTP_CLIENT.get_or_init(|| {
Arc::new(reqwest_http_client::ReqwestHttpClient::new())
}).as_ref()
HTTP_CLIENT
.get_or_init(|| Arc::new(reqwest_http_client::ReqwestHttpClient::new()))
.as_ref()
}

#[cfg(not(feature = "http_reqwest"))]
pub fn get_http_client() -> &'static dyn HttpClient {
HTTP_CLIENT.get().expect("HttpClient has not been set. Please call set_http_client().").as_ref()
HTTP_CLIENT
.get()
.expect("HttpClient has not been set. Please call set_http_client().")
.as_ref()
}

#[cfg(feature = "http_reqwest")]
Expand All @@ -54,74 +106,64 @@ pub fn set_http_client(_: Arc<dyn HttpClient>) {

#[cfg(not(feature = "http_reqwest"))]
pub fn set_http_client(client: Arc<dyn HttpClient>) {
HTTP_CLIENT.set(client).unwrap_or_else(|_| panic!("HttpClient has already been set."));
HTTP_CLIENT
.set(client)
.unwrap_or_else(|_| panic!("HttpClient has already been set."));
}

#[cfg(feature = "http_reqwest")]
mod reqwest_http_client {
use super::*;
use reqwest::Client;
use std::collections::HashMap;
use reqwest::{Client as ReqwestClient, Method as ReqwestMethod, Response as ReqwestResponse};
use std::error::Error;

pub struct ReqwestHttpClient {
client: Client,
client: ReqwestClient,
}

impl ReqwestHttpClient {
pub fn new() -> Self {
ReqwestHttpClient {
client: Client::new(),
client: ReqwestClient::new(),
}
}
}

#[async_trait]
impl HttpClient for ReqwestHttpClient {
async fn get(
&self,
url: &str,
headers: Option<HashMap<String, String>>,
) -> Result<HttpResponse, Box<dyn std::error::Error>> {
let mut req = self.client.get(url);

if let Some(headers) = headers {
for (key, value) in headers {
req = req.header(&key, &value);
}
fn map_method(method: Option<Method>) -> ReqwestMethod {
match method {
Some(Method::Post) => ReqwestMethod::POST,
Some(Method::Put) => ReqwestMethod::PUT,
_ => ReqwestMethod::GET,
}

let response = req.send().await?;
let status_code = response.status().as_u16();
let headers = response
.headers()
.iter()
.map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
.collect();

let body = response.bytes().await?.to_vec();

Ok(HttpResponse {
status_code,
headers,
body,
})
}

async fn post(
async fn build_request(
&self,
url: &str,
headers: Option<HashMap<String, String>>,
body: &[u8],
) -> Result<HttpResponse, Box<dyn std::error::Error>> {
let mut req = self.client.post(url).body(body.to_vec());
options: Option<FetchOptions>,
) -> Result<reqwest::RequestBuilder, Box<dyn Error>> {
let FetchOptions {
method,
headers,
body,
} = options.unwrap_or_default();

let req_method = Self::map_method(method);
let mut req = self.client.request(req_method, url);

if let Some(headers) = headers {
for (key, value) in headers {
req = req.header(&key, &value);
}
}

let response = req.send().await?;
if let Some(body) = body {
req = req.body(body);
}

Ok(req)
}

async fn parse_response(response: ReqwestResponse) -> Result<HttpResponse, Box<dyn Error>> {
let status_code = response.status().as_u16();
let headers = response
.headers()
Expand All @@ -137,36 +179,18 @@ mod reqwest_http_client {
body,
})
}
}

async fn put(
#[async_trait]
impl HttpClient for ReqwestHttpClient {
async fn fetch(
&self,
url: &str,
headers: Option<HashMap<String, String>>,
body: &[u8],
) -> Result<HttpResponse, Box<dyn std::error::Error>> {
let mut req = self.client.put(url).body(body.to_vec());

if let Some(headers) = headers {
for (key, value) in headers {
req = req.header(&key, &value);
}
}

let response = req.send().await?;
let status_code = response.status().as_u16();
let headers = response
.headers()
.iter()
.map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
.collect();

let body = response.bytes().await?.to_vec();

Ok(HttpResponse {
status_code,
headers,
body,
})
options: Option<FetchOptions>,
) -> Result<HttpResponse, Box<dyn Error>> {
let req = self.build_request(url, options).await?;
let res = req.send().await?;
Self::parse_response(res).await
}
}
}
}
2 changes: 1 addition & 1 deletion crates/web5/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,4 @@ mod tests {
fn test_without_reqwest_feature() {
println!("http_reqwest feature is NOT enabled!");
}
}
}

0 comments on commit e004bd0

Please sign in to comment.