Skip to content

Commit

Permalink
✨ Update ApiClient to latest; implements deserializing enums; drops o…
Browse files Browse the repository at this point in the history
…ur custom handling for booleans ... let's see if it still works 🚧
  • Loading branch information
ff137 committed Apr 4, 2024
1 parent 37cf736 commit 66ceb76
Showing 1 changed file with 71 additions and 76 deletions.
147 changes: 71 additions & 76 deletions aries_cloudcontroller/api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,24 +18,20 @@
import os
import re
import tempfile
from typing import List, Optional, Tuple
from enum import Enum
from typing import Dict, List, Optional, Tuple
from urllib.parse import quote

from dateutil.parser import parse

import aries_cloudcontroller.models
from aries_cloudcontroller import rest
from aries_cloudcontroller.api_response import ApiResponse
from aries_cloudcontroller.api_response import T as ApiResponseT
from aries_cloudcontroller.configuration import Configuration
from aries_cloudcontroller.exceptions import (
ApiException,
ApiValueError,
BadRequestException,
ForbiddenException,
NotFoundException,
ServiceException,
UnauthorizedException,
)
from aries_cloudcontroller.exceptions import ApiException, ApiValueError

RequestSerialized = Tuple[str, str, Dict[str, str], Optional[str], List[str]]


class ApiClient:
Expand Down Expand Up @@ -145,7 +141,7 @@ def param_serialize(
collection_formats=None,
_host=None,
_request_auth=None,
) -> Tuple:
) -> RequestSerialized:
"""Builds the HTTP request params needed by the request.
:param method: Method to call.
:param resource_path: Path to method endpoint.
Expand Down Expand Up @@ -222,9 +218,7 @@ def param_serialize(

# query parameters
if query_params:
query_params = self.sanitize_for_serialization(
query_params, query_params=True
)
query_params = self.sanitize_for_serialization(query_params)
url_query = self.parameters_to_url_query(query_params, collection_formats)
url += "?" + url_query

Expand Down Expand Up @@ -263,21 +257,24 @@ async def call_api(
)

except ApiException as e:
if e.body:
e.body = e.body.decode("utf-8")
raise e

return response_data

def response_deserialize(
self, response_data=None, response_types_map=None
) -> ApiResponse:
self,
response_data: rest.RESTResponse,
response_types_map: Optional[Dict[str, ApiResponseT]] = None,
) -> ApiResponse[ApiResponseT]:
"""Deserializes response into an object.
:param response_data: RESTResponse object to be deserialized.
:param response_types_map: dict of response types.
:return: ApiResponse
"""

msg = "RESTResponse.read() must be called before passing it to response_deserialize()"
assert response_data.data is not None, msg

response_type = response_types_map.get(str(response_data.status), None)
if (
not response_type
Expand All @@ -289,39 +286,29 @@ def response_deserialize(
str(response_data.status)[0] + "XX", None
)

if not 200 <= response_data.status <= 299:
if response_data.status == 400:
raise BadRequestException(http_resp=response_data)

if response_data.status == 401:
raise UnauthorizedException(http_resp=response_data)

if response_data.status == 403:
raise ForbiddenException(http_resp=response_data)

if response_data.status == 404:
raise NotFoundException(http_resp=response_data)

if 500 <= response_data.status <= 599:
raise ServiceException(http_resp=response_data)
raise ApiException(http_resp=response_data)

# deserialize response data

if response_type == "bytearray":
return_data = response_data.data
elif response_type is None:
return_data = None
elif response_type == "file":
return_data = self.__deserialize_file(response_data)
else:
match = None
content_type = response_data.getheader("content-type")
if content_type is not None:
match = re.search(r"charset=([a-zA-Z\-\d]+)[\s;]?", content_type)
encoding = match.group(1) if match else "utf-8"
response_text = response_data.data.decode(encoding)
return_data = self.deserialize(response_text, response_type)
response_text = None
return_data = None
try:
if response_type == "bytearray":
return_data = response_data.data
elif response_type == "file":
return_data = self.__deserialize_file(response_data)
elif response_type is not None:
match = None
content_type = response_data.getheader("content-type")
if content_type is not None:
match = re.search(r"charset=([a-zA-Z\-\d]+)[\s;]?", content_type)
encoding = match.group(1) if match else "utf-8"
response_text = response_data.data.decode(encoding)
return_data = self.deserialize(response_text, response_type)
finally:
if not 200 <= response_data.status <= 299:
raise ApiException.from_response(
http_resp=response_data,
body=response_text,
data=return_data,
)

return ApiResponse(
status_code=response_data.status,
Expand All @@ -330,7 +317,7 @@ def response_deserialize(
raw_data=response_data.data,
)

def sanitize_for_serialization(self, obj, query_params=False):
def sanitize_for_serialization(self, obj):
"""Builds a JSON POST object.
If obj is None, return None.
Expand All @@ -342,31 +329,20 @@ def sanitize_for_serialization(self, obj, query_params=False):
If obj is OpenAPI model, return the properties dict.
:param obj: The data to serialize.
:param query_params: If these are query params or not
:return: The serialized form of data.
"""
if query_params and isinstance(obj, bool):
# Custom conversion: convert boolean query parameters to their string equivalents
return str(obj).lower()

if obj is None:
return None
elif isinstance(obj, self.PRIMITIVE_TYPES):
return obj
elif isinstance(obj, list):
return [
self.sanitize_for_serialization(sub_obj, query_params=query_params)
for sub_obj in obj
]
return [self.sanitize_for_serialization(sub_obj) for sub_obj in obj]
elif isinstance(obj, tuple):
return tuple(
self.sanitize_for_serialization(sub_obj, query_params=query_params)
for sub_obj in obj
)
return tuple(self.sanitize_for_serialization(sub_obj) for sub_obj in obj)
elif isinstance(obj, (datetime.datetime, datetime.date)):
return obj.isoformat()

if isinstance(obj, dict):
elif isinstance(obj, dict):
obj_dict = obj
else:
# Convert model obj to dict except
Expand All @@ -377,8 +353,7 @@ def sanitize_for_serialization(self, obj, query_params=False):
obj_dict = obj.to_dict()

return {
key: self.sanitize_for_serialization(val, query_params=query_params)
for key, val in obj_dict.items()
key: self.sanitize_for_serialization(val) for key, val in obj_dict.items()
}

def deserialize(self, response_text, response_type):
Expand Down Expand Up @@ -412,11 +387,15 @@ def __deserialize(self, data, klass):

if isinstance(klass, str):
if klass.startswith("List["):
sub_kls = re.match(r"List\[(.*)]", klass).group(1)
m = re.match(r"List\[(.*)]", klass)
assert m is not None, "Malformed List type definition"
sub_kls = m.group(1)
return [self.__deserialize(sub_data, sub_kls) for sub_data in data]

if klass.startswith("Dict["):
sub_kls = re.match(r"Dict\[([^,]*), (.*)]", klass).group(2)
m = re.match(r"Dict\[([^,]*), (.*)]", klass)
assert m is not None, "Malformed Dict type definition"
sub_kls = m.group(2)
return {k: self.__deserialize(v, sub_kls) for k, v in data.items()}

# convert str to class
Expand All @@ -433,6 +412,8 @@ def __deserialize(self, data, klass):
return self.__deserialize_date(data)
elif klass == datetime.datetime:
return self.__deserialize_datetime(data)
elif issubclass(klass, Enum):
return self.__deserialize_enum(data, klass)
else:
return self.__deserialize_model(data, klass)

Expand All @@ -443,7 +424,7 @@ def parameters_to_tuples(self, params, collection_formats):
:param dict collection_formats: Parameter collection formats
:return: Parameters as list of tuples, collections formatted
"""
new_params = []
new_params: List[Tuple[str, str]] = []
if collection_formats is None:
collection_formats = {}
for k, v in params.items() if isinstance(params, dict) else params:
Expand Down Expand Up @@ -472,7 +453,7 @@ def parameters_to_url_query(self, params, collection_formats):
:param dict collection_formats: Parameter collection formats
:return: URL query string (e.g. a=Hello%20World&b=123)
"""
new_params = []
new_params: List[Tuple[str, str]] = []
if collection_formats is None:
collection_formats = {}
for k, v in params.items() if isinstance(params, dict) else params:
Expand All @@ -486,7 +467,7 @@ def parameters_to_url_query(self, params, collection_formats):
if k in collection_formats:
collection_format = collection_formats[k]
if collection_format == "multi":
new_params.extend((k, value) for value in v)
new_params.extend((k, str(value)) for value in v)
else:
if collection_format == "ssv":
delimiter = " "
Expand All @@ -502,7 +483,7 @@ def parameters_to_url_query(self, params, collection_formats):
else:
new_params.append((k, quote(str(v))))

return "&".join(["=".join(item) for item in new_params])
return "&".join(["=".join(map(str, item)) for item in new_params])

def files_parameters(self, files=None):
"""Builds form parameters.
Expand Down Expand Up @@ -637,9 +618,9 @@ def __deserialize_file(self, response):

content_disposition = response.getheader("Content-Disposition")
if content_disposition:
filename = re.search(
r'filename=[\'"]?([^\'"\s]+)[\'"]?', content_disposition
).group(1)
m = re.search(r'filename=[\'"]?([^\'"\s]+)[\'"]?', content_disposition)
assert m is not None, "Unexpected 'content-disposition' header value"
filename = m.group(1)
path = os.path.join(os.path.dirname(path), filename)

with open(path, "wb") as f:
Expand Down Expand Up @@ -702,6 +683,20 @@ def __deserialize_datetime(self, string):
reason=("Failed to parse `{0}` as datetime object".format(string)),
)

def __deserialize_enum(self, data, klass):
"""Deserializes primitive type to enum.
:param data: primitive type.
:param klass: class literal.
:return: enum value.
"""
try:
return klass(data)
except ValueError:
raise rest.ApiException(
status=0, reason=("Failed to parse `{0}` as `{1}`".format(data, klass))
)

def __deserialize_model(self, data, klass):
"""Deserializes list or dict to model.
Expand Down

0 comments on commit 66ceb76

Please sign in to comment.