Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: DTOSS-6585 get single by filter #545

Merged
merged 7 commits into from
Jan 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,15 @@ public static PropertyInfo GetKey<TEntity>() where TEntity : class
return type.GetProperties().FirstOrDefault(p =>
p.CustomAttributes.Any(attr => attr.AttributeType == typeof(KeyAttribute)));
}


public static Type GetPropertyType(Type type, string property)
{
return type.GetProperty(property).PropertyType;
}

public static bool PropertyExists(Type type, string property) =>
Array.Exists(type.GetProperties(), p => p.Name == property);


}
Original file line number Diff line number Diff line change
Expand Up @@ -46,33 +46,15 @@ public async Task<IEnumerable<TEntity>> GetAll()
public async Task<IEnumerable<TEntity>> GetByFilter(Expression<Func<TEntity,bool>> predicate)
{

try
{
//Resolves the constants
var expr = new ClosureResolver().Visit(predicate);
_logger.LogWarning(expr.ToString());


var jsonString = await _callFunction.SendGet(_baseUrl,new Dictionary<string, string>{{"query",expr.ToString()}});
if(string.IsNullOrEmpty(jsonString))
{
return null;
}
IEnumerable<TEntity> result = JsonSerializer.Deserialize<IEnumerable<TEntity>>(jsonString);
return result;

}
catch(WebException wex)
var jsonString = await GetJsonStringByFilter(predicate);
if(string.IsNullOrEmpty(jsonString))
{
HttpWebResponse response = (HttpWebResponse)wex.Response;
if(response.StatusCode! == HttpStatusCode.NotFound)
{
return null;
}

_logger.LogError(wex,"An Exception Happened while calling data service API");
throw;
return null;
}
IEnumerable<TEntity> result = JsonSerializer.Deserialize<IEnumerable<TEntity>>(jsonString);
return result;

}

public virtual async Task<TEntity> GetSingle(string id)
Expand Down Expand Up @@ -102,15 +84,27 @@ public virtual async Task<TEntity> GetSingle(string id)
throw;
}
}
public async Task<TEntity> GetSingleByFilter(Expression<Func<TEntity, bool>> predicate)
{

var jsonString = await GetJsonStringByFilter(predicate,true);
if(string.IsNullOrEmpty(jsonString))
{
return null;
}
TEntity result = JsonSerializer.Deserialize<TEntity>(jsonString);
return result;

}

public async Task<bool> Delete(string id)
{
var result = await _callFunction.SendDelete(UrlBuilder(_baseUrl,id));
return result;
}
public async Task<bool> AddRange(IEnumerable<TEntity> entity)
public async Task<bool> AddRange(IEnumerable<TEntity> entities)
{
var jsonString = JsonSerializer.Serialize<IEnumerable<TEntity>>(entity);
var jsonString = JsonSerializer.Serialize<IEnumerable<TEntity>>(entities);

if(string.IsNullOrEmpty(jsonString))
{
Expand Down Expand Up @@ -164,11 +158,44 @@ public async Task<bool> Update(TEntity entity)
return true;
}

private async Task<string> GetJsonStringByFilter(Expression<Func<TEntity,bool>> predicate, bool returnOneRecord = false)
{
try
{

//Resolves the constants
var expr = new ClosureResolver().Visit(predicate);

var queryItems = new Dictionary<string,string>{{"query",expr.ToString()}};

if(returnOneRecord)
{
queryItems.Add("single","true");
}

var jsonString = await _callFunction.SendGet(_baseUrl,queryItems);
return jsonString;
}
catch(WebException wex)
{
HttpWebResponse response = (HttpWebResponse)wex.Response;
if(response.StatusCode! == HttpStatusCode.NotFound)
{
return null;
}

_logger.LogError(wex,"An Exception Happened while calling data service API");
throw;
}

}

private static string UrlBuilder(string baseUrl, string argument)
{
baseUrl = baseUrl.TrimEnd('/');
argument = argument.TrimStart('/');
return string.Format("{0}/{1}", baseUrl, argument);
}


}
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,40 @@ namespace DataServices.Client;

public interface IDataServiceClient<TEntity>
{
/// <summary>
/// Gets all items
/// </summary>
/// <returns>Returns a task with a result type of TEntity</returns>
Task<IEnumerable<TEntity>> GetAll();
/// <summary>
/// Gets a single item given the primary key of the table given as an argument
/// </summary>
/// <param name="id">Primary key of table being queried</param>
/// <returns>Returns a task with a result type of TEntity</returns>
Task<TEntity> GetSingle(string id);
/// <summary>
/// Gets a single by an expression such as i => i.item == "This item"
/// </summary>
/// <param name="predicate">linq query defining the filter on the table</param>
/// <returns>Returns a task with the result type of TEntity</returns>
Task<TEntity> GetSingleByFilter(Expression<Func<TEntity,bool>> predicate);
/// <summary>
/// Get a list of items where they meet the given predicate
/// </summary>
/// <param name="predicate">linq query defining the filter on the table</param>
/// <returns>Returns a task with the result type of IEnumerable<TEntity></returns>
Task<IEnumerable<TEntity>> GetByFilter(Expression<Func<TEntity,bool>> predicate);
/// <summary>
/// Adds a given records to the database
/// </summary>
/// <param name="entity">object of type TEntity to be inserted in the database</param>
/// <returns>a boolean representing if the record was inserted successfully</returns>
Task<bool> Add(TEntity entity);
Task<bool> AddRange(IEnumerable<TEntity> entity);
/// <summary>
/// Adds an IEnumerable of type TEntity to the database
/// </summary>
/// <param name="entities">IEnumerable<TEntity> of items to be added to the database</param>
/// <returns>a boolean representing if the record was inserted successfully</returns>
Task<bool> AddRange(IEnumerable<TEntity> entities);

}
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,12 @@ public DataServiceAccessor(DataServicesContext context, ILogger<DataServiceAcces

public async Task<TEntity> GetSingle(Expression<Func<TEntity, bool>> predicate)
{
var result = _context.Set<TEntity>().SingleOrDefault(predicate);
var result = _context.Set<TEntity>().Where(predicate).ToList();
if(result.Count > 1){
throw new MultipleRecordsFoundException("Multiple Records where found for filter expression when only one was expected");
}
await Task.CompletedTask;
return result;
return result.Single();
}

public async Task<List<TEntity>> GetRange(Expression<Func<TEntity, bool>> predicates)
Expand Down Expand Up @@ -71,18 +74,21 @@ public async Task<TEntity> Update(TEntity entity, Expression<Func<TEntity, bool>
using var transaction = await _context.Database.BeginTransactionAsync();
TEntity updatedEntity = _context.Entry(existingEntity).CurrentValues.SetValues(entity);
var rowsEffected = await _context.SaveChangesAsync();


if(rowsEffected == 1)
{
return updatedEntity;
}
else if(rowsEffected > 1)
{
await transaction.RollbackAsync();
_logger.LogError("Multiple Records were updated by PUT request, Changes have been Rolledback");
throw new Exception("Multiple Records were updated by PUT request, Changes have been Rolledback");
_logger.LogError("Multiple Records were updated by PUT request, Changes have been Rolled-back");
throw new MultipleRecordsFoundException("Multiple Records were updated by PUT request, Changes have been Rolled-back");
}

_logger.LogError("No records were updated despite a record being found");
throw new Exception("Multiple Records were updated by PUT request, Changes have been Rolledback");
throw new MultipleRecordsFoundException("Multiple Records were updated by PUT request, Changes have been Rolled-back");

}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
namespace DataServices.Core;

public class MultipleRecordsFoundException : Exception
{
public MultipleRecordsFoundException()
{
}
public MultipleRecordsFoundException(string message)
: base(message)
{
}

public MultipleRecordsFoundException(string message, Exception inner)
: base(message, inner)
{
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ namespace DataServices.Core;
using System.Text.Json.Nodes;
using System.Text.Json.Serialization;
using Common;
using FluentValidation.Validators;
using Microsoft.Azure.Functions.Worker.Http;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.Extensions.Logging;
Expand All @@ -21,12 +22,10 @@ public class RequestHandler<TEntity> : IRequestHandler<TEntity> where TEntity :

private readonly IDataServiceAccessor<TEntity> _dataServiceAccessor;
private readonly ILogger<RequestHandler<TEntity>> _logger;

private readonly AuthenticationConfiguration _authConfig;

private readonly PropertyInfo _keyInfo;

private static JsonSerializerOptions jsonSerializerOptions = new JsonSerializerOptions{
private static JsonSerializerOptions jsonSerializerOptions = new JsonSerializerOptions
{
UnmappedMemberHandling = JsonUnmappedMemberHandling.Disallow
};

Expand All @@ -38,7 +37,6 @@ public RequestHandler(IDataServiceAccessor<TEntity> dataServiceAccessor, ILogger
_logger = logger;
_authConfig = authenticationConfiguration;
_keyInfo = ReflectionUtilities.GetKey<TEntity>();

}

public async Task<HttpResponseData> HandleRequest(HttpRequestData req, string? key = null)
Expand Down Expand Up @@ -93,17 +91,37 @@ private async Task<HttpResponseData> Get(HttpRequestData req)
_logger.LogWarning("Unauthorized Method was called");
return CreateErrorResponse(req,UnauthorizedErrorMessage,HttpStatusCode.Unauthorized);
}
try{

try
{
var predicate = CreateFilterExpression(req);
var result = await _dataServiceAccessor.GetRange(predicate);
if(result == null || !result.Any())
object result;

if (GetBooleanQueryItem(req,"single"))
{
result = await _dataServiceAccessor.GetSingle(predicate);
}
else
{
result = await _dataServiceAccessor.GetRange(predicate);
}

if (!ResultHasContent(result))
{
return CreateErrorResponse(req,"No Data Found",HttpStatusCode.NoContent);
}

return CreateHttpResponse(req,new DataServiceResponse<string>
{
JsonData = JsonSerializer.Serialize(result)
});
}
catch(MultipleRecordsFoundException mre)
{
_logger.LogWarning(mre,"Multiple Records were returned from filter expression when only one was expected: {message}",mre.Message);
return CreateErrorResponse(req,"Multiple rows met filter condition when only one row was expected",HttpStatusCode.BadRequest);


}
catch(Exception ex)
{
Expand Down Expand Up @@ -256,7 +274,7 @@ private Expression<Func<TEntity, bool>> CreateGetByKeyExpression(string filter)
{
var entityParameter = Expression.Parameter(typeof(TEntity));
var entityKey = Expression.Property(entityParameter, _keyInfo.Name);
var filterConstant = Expression.Constant(Convert.ChangeType(filter, GetPropertyType(typeof(TEntity), _keyInfo.Name)));
var filterConstant = Expression.Constant(Convert.ChangeType(filter, ReflectionUtilities.GetPropertyType(typeof(TEntity), _keyInfo.Name)));
var expr = Expression.Equal(entityKey, filterConstant);

return Expression.Lambda<Func<TEntity, bool>>(expr, entityParameter);
Expand All @@ -278,13 +296,13 @@ private Expression<Func<TEntity, bool>> CreateFilterExpression(HttpRequestData r
return DynamicExpressionParser.ParseLambda<TEntity,bool>(new ParsingConfig(),true, req.Query[item]);
}

if (!PropertyExists(typeof(TEntity), item))
if (!ReflectionUtilities.PropertyExists(typeof(TEntity), item))
{
_logger.LogWarning("Query Item: '{item}' does not exist in TEntity: '{entityName}'", item, typeof(TEntity).Name);
continue;
}
var entityKey = Expression.Property(entityParameter, item);
var filterConstant = Expression.Constant(Convert.ChangeType(req.Query[item], GetPropertyType(typeof(TEntity), item)));
var filterConstant = Expression.Constant(Convert.ChangeType(req.Query[item], ReflectionUtilities.GetPropertyType(typeof(TEntity), item)));
var comparison = Expression.Equal(entityKey, filterConstant);
if (expr == null)
{
Expand All @@ -298,17 +316,34 @@ private Expression<Func<TEntity, bool>> CreateFilterExpression(HttpRequestData r

}

private static Type GetPropertyType(Type type, string property)
private static bool GetBooleanQueryItem(HttpRequestData req, string headerKey, bool defaultValue = false)
{
return type.GetProperty(property).PropertyType;
if(req.Query[headerKey] == null){
return defaultValue;
}
if(bool.TryParse(req.Query[headerKey],out var result)){
return result;
}
return defaultValue;
}

private static bool PropertyExists(Type type, string property) =>
Array.Exists(type.GetProperties(), p => p.Name == property);

private static bool ResultHasContent(Object obj)
{
if(obj == null)
{
return true;
}

if(obj is not IEnumerable<TEntity>) // Object isnt null and isnt IEnumerable so will have data
{
return true;
}

var data = (IEnumerable<TEntity>)obj;
var result = data.Any();

return result;
}

private HttpResponseData CreateErrorResponse(HttpRequestData req, string message, HttpStatusCode statusCode)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,23 @@ public async Task RunAsync_GetItemByQueryBadQuery_ReturnsBadRequest(string query
//assert
Assert.AreEqual(HttpStatusCode.BadRequest, result.StatusCode);
}
[DataRow("i => i.CountryCategory = \"ENGLAND\"")]
[TestMethod]
public async Task RunAsyncGetItemByQueryExpectsSingle_ReturnsBadRequest(string query)
{
//arrange
_authenticationConfiguration = DataServiceTestHelper.AllowAllAccessConfig;
var _requestHandler = new RequestHandler<BsSelectGpPractice>(_dataServiceAccessor, _mockRequestHandlerLogger.Object, _authenticationConfiguration);
BsSelectGpPracticeDataService function = new BsSelectGpPracticeDataService(_mockFunctionLogger.Object, _requestHandler, _createResponse);
var req = new MockHttpRequestData(_context.Object, "", "GET");
req.AddQuery("query", query);
req.AddQuery("single","true");
//act
var result = await function.Run(req, null);

//assert
Assert.AreEqual(HttpStatusCode.BadRequest, result.StatusCode);
}
#endregion
#region Deletes

Expand Down
Loading
Loading