diff --git a/application/CohortManager/src/Functions/Shared/Common/Utilities/ReflectionUtilities.cs b/application/CohortManager/src/Functions/Shared/Common/Utilities/ReflectionUtilities.cs index f7ed1b15a..fbd5a985d 100644 --- a/application/CohortManager/src/Functions/Shared/Common/Utilities/ReflectionUtilities.cs +++ b/application/CohortManager/src/Functions/Shared/Common/Utilities/ReflectionUtilities.cs @@ -16,4 +16,15 @@ public static PropertyInfo GetKey() where TEntity : class return type.GetProperties().FirstOrDefault(p => p.CustomAttributes.Any(attr => attr.AttributeType == typeof(KeyAttribute))); } + + + public static Type GetPropertyType(Type type, string property) + { + return type.GetProperty(property).PropertyType; + } + + public static bool PropertyExists(Type type, string property) => + Array.Exists(type.GetProperties(), p => p.Name == property); + + } diff --git a/application/CohortManager/src/Functions/Shared/DataServices.Client/DataServiceClient.cs b/application/CohortManager/src/Functions/Shared/DataServices.Client/DataServiceClient.cs index 06d51d91b..3a8416ae2 100644 --- a/application/CohortManager/src/Functions/Shared/DataServices.Client/DataServiceClient.cs +++ b/application/CohortManager/src/Functions/Shared/DataServices.Client/DataServiceClient.cs @@ -46,33 +46,15 @@ public async Task> GetAll() public async Task> GetByFilter(Expression> predicate) { - try - { - //Resolves the constants - var expr = new ClosureResolver().Visit(predicate); - _logger.LogWarning(expr.ToString()); - - var jsonString = await _callFunction.SendGet(_baseUrl,new Dictionary{{"query",expr.ToString()}}); - if(string.IsNullOrEmpty(jsonString)) - { - return null; - } - IEnumerable result = JsonSerializer.Deserialize>(jsonString); - return result; - - } - catch(WebException wex) + var jsonString = await GetJsonStringByFilter(predicate); + if(string.IsNullOrEmpty(jsonString)) { - HttpWebResponse response = (HttpWebResponse)wex.Response; - if(response.StatusCode! == HttpStatusCode.NotFound) - { - return null; - } - - _logger.LogError(wex,"An Exception Happened while calling data service API"); - throw; + return null; } + IEnumerable result = JsonSerializer.Deserialize>(jsonString); + return result; + } public virtual async Task GetSingle(string id) @@ -102,15 +84,27 @@ public virtual async Task GetSingle(string id) throw; } } + public async Task GetSingleByFilter(Expression> predicate) + { + + var jsonString = await GetJsonStringByFilter(predicate,true); + if(string.IsNullOrEmpty(jsonString)) + { + return null; + } + TEntity result = JsonSerializer.Deserialize(jsonString); + return result; + + } public async Task Delete(string id) { var result = await _callFunction.SendDelete(UrlBuilder(_baseUrl,id)); return result; } - public async Task AddRange(IEnumerable entity) + public async Task AddRange(IEnumerable entities) { - var jsonString = JsonSerializer.Serialize>(entity); + var jsonString = JsonSerializer.Serialize>(entities); if(string.IsNullOrEmpty(jsonString)) { @@ -164,6 +158,38 @@ public async Task Update(TEntity entity) return true; } + private async Task GetJsonStringByFilter(Expression> predicate, bool returnOneRecord = false) + { + try + { + + //Resolves the constants + var expr = new ClosureResolver().Visit(predicate); + + var queryItems = new Dictionary{{"query",expr.ToString()}}; + + if(returnOneRecord) + { + queryItems.Add("single","true"); + } + + var jsonString = await _callFunction.SendGet(_baseUrl,queryItems); + return jsonString; + } + catch(WebException wex) + { + HttpWebResponse response = (HttpWebResponse)wex.Response; + if(response.StatusCode! == HttpStatusCode.NotFound) + { + return null; + } + + _logger.LogError(wex,"An Exception Happened while calling data service API"); + throw; + } + + } + private static string UrlBuilder(string baseUrl, string argument) { baseUrl = baseUrl.TrimEnd('/'); @@ -171,4 +197,5 @@ private static string UrlBuilder(string baseUrl, string argument) return string.Format("{0}/{1}", baseUrl, argument); } + } diff --git a/application/CohortManager/src/Functions/Shared/DataServices.Client/IDataServiceClient.cs b/application/CohortManager/src/Functions/Shared/DataServices.Client/IDataServiceClient.cs index 09499411b..76b5878a9 100644 --- a/application/CohortManager/src/Functions/Shared/DataServices.Client/IDataServiceClient.cs +++ b/application/CohortManager/src/Functions/Shared/DataServices.Client/IDataServiceClient.cs @@ -4,10 +4,40 @@ namespace DataServices.Client; public interface IDataServiceClient { + /// + /// Gets all items + /// + /// Returns a task with a result type of TEntity Task> GetAll(); + /// + /// Gets a single item given the primary key of the table given as an argument + /// + /// Primary key of table being queried + /// Returns a task with a result type of TEntity Task GetSingle(string id); + /// + /// Gets a single by an expression such as i => i.item == "This item" + /// + /// linq query defining the filter on the table + /// Returns a task with the result type of TEntity + Task GetSingleByFilter(Expression> predicate); + /// + /// Get a list of items where they meet the given predicate + /// + /// linq query defining the filter on the table + /// Returns a task with the result type of IEnumerable Task> GetByFilter(Expression> predicate); + /// + /// Adds a given records to the database + /// + /// object of type TEntity to be inserted in the database + /// a boolean representing if the record was inserted successfully Task Add(TEntity entity); - Task AddRange(IEnumerable entity); + /// + /// Adds an IEnumerable of type TEntity to the database + /// + /// IEnumerable of items to be added to the database + /// a boolean representing if the record was inserted successfully + Task AddRange(IEnumerable entities); } diff --git a/application/CohortManager/src/Functions/Shared/DataServices.Core/DataServiceAccessor.cs b/application/CohortManager/src/Functions/Shared/DataServices.Core/DataServiceAccessor.cs index 3fca5c58d..1cd2c304d 100644 --- a/application/CohortManager/src/Functions/Shared/DataServices.Core/DataServiceAccessor.cs +++ b/application/CohortManager/src/Functions/Shared/DataServices.Core/DataServiceAccessor.cs @@ -19,9 +19,12 @@ public DataServiceAccessor(DataServicesContext context, ILogger GetSingle(Expression> predicate) { - var result = _context.Set().SingleOrDefault(predicate); + var result = _context.Set().Where(predicate).ToList(); + if(result.Count > 1){ + throw new MultipleRecordsFoundException("Multiple Records where found for filter expression when only one was expected"); + } await Task.CompletedTask; - return result; + return result.Single(); } public async Task> GetRange(Expression> predicates) @@ -71,6 +74,8 @@ public async Task Update(TEntity entity, Expression using var transaction = await _context.Database.BeginTransactionAsync(); TEntity updatedEntity = _context.Entry(existingEntity).CurrentValues.SetValues(entity); var rowsEffected = await _context.SaveChangesAsync(); + + if(rowsEffected == 1) { return updatedEntity; @@ -78,11 +83,12 @@ public async Task Update(TEntity entity, Expression else if(rowsEffected > 1) { await transaction.RollbackAsync(); - _logger.LogError("Multiple Records were updated by PUT request, Changes have been Rolledback"); - throw new Exception("Multiple Records were updated by PUT request, Changes have been Rolledback"); + _logger.LogError("Multiple Records were updated by PUT request, Changes have been Rolled-back"); + throw new MultipleRecordsFoundException("Multiple Records were updated by PUT request, Changes have been Rolled-back"); } + _logger.LogError("No records were updated despite a record being found"); - throw new Exception("Multiple Records were updated by PUT request, Changes have been Rolledback"); + throw new MultipleRecordsFoundException("Multiple Records were updated by PUT request, Changes have been Rolled-back"); } diff --git a/application/CohortManager/src/Functions/Shared/DataServices.Core/Exceptions/MultipleRecordsFoundException.cs b/application/CohortManager/src/Functions/Shared/DataServices.Core/Exceptions/MultipleRecordsFoundException.cs new file mode 100644 index 000000000..9960ccd19 --- /dev/null +++ b/application/CohortManager/src/Functions/Shared/DataServices.Core/Exceptions/MultipleRecordsFoundException.cs @@ -0,0 +1,17 @@ +namespace DataServices.Core; + +public class MultipleRecordsFoundException : Exception +{ + public MultipleRecordsFoundException() + { + } + public MultipleRecordsFoundException(string message) + : base(message) + { + } + + public MultipleRecordsFoundException(string message, Exception inner) + : base(message, inner) + { + } +} diff --git a/application/CohortManager/src/Functions/Shared/DataServices.Core/RequestHandler.cs b/application/CohortManager/src/Functions/Shared/DataServices.Core/RequestHandler.cs index 1103357d2..d580b33bd 100644 --- a/application/CohortManager/src/Functions/Shared/DataServices.Core/RequestHandler.cs +++ b/application/CohortManager/src/Functions/Shared/DataServices.Core/RequestHandler.cs @@ -11,6 +11,7 @@ namespace DataServices.Core; using System.Text.Json.Nodes; using System.Text.Json.Serialization; using Common; +using FluentValidation.Validators; using Microsoft.Azure.Functions.Worker.Http; using Microsoft.EntityFrameworkCore.Infrastructure; using Microsoft.Extensions.Logging; @@ -21,12 +22,10 @@ public class RequestHandler : IRequestHandler where TEntity : private readonly IDataServiceAccessor _dataServiceAccessor; private readonly ILogger> _logger; - private readonly AuthenticationConfiguration _authConfig; - private readonly PropertyInfo _keyInfo; - - private static JsonSerializerOptions jsonSerializerOptions = new JsonSerializerOptions{ + private static JsonSerializerOptions jsonSerializerOptions = new JsonSerializerOptions + { UnmappedMemberHandling = JsonUnmappedMemberHandling.Disallow }; @@ -38,7 +37,6 @@ public RequestHandler(IDataServiceAccessor dataServiceAccessor, ILogger _logger = logger; _authConfig = authenticationConfiguration; _keyInfo = ReflectionUtilities.GetKey(); - } public async Task HandleRequest(HttpRequestData req, string? key = null) @@ -93,17 +91,37 @@ private async Task Get(HttpRequestData req) _logger.LogWarning("Unauthorized Method was called"); return CreateErrorResponse(req,UnauthorizedErrorMessage,HttpStatusCode.Unauthorized); } - try{ + + try + { var predicate = CreateFilterExpression(req); - var result = await _dataServiceAccessor.GetRange(predicate); - if(result == null || !result.Any()) + object result; + + if (GetBooleanQueryItem(req,"single")) + { + result = await _dataServiceAccessor.GetSingle(predicate); + } + else + { + result = await _dataServiceAccessor.GetRange(predicate); + } + + if (!ResultHasContent(result)) { return CreateErrorResponse(req,"No Data Found",HttpStatusCode.NoContent); } + return CreateHttpResponse(req,new DataServiceResponse { JsonData = JsonSerializer.Serialize(result) }); + } + catch(MultipleRecordsFoundException mre) + { + _logger.LogWarning(mre,"Multiple Records were returned from filter expression when only one was expected: {message}",mre.Message); + return CreateErrorResponse(req,"Multiple rows met filter condition when only one row was expected",HttpStatusCode.BadRequest); + + } catch(Exception ex) { @@ -256,7 +274,7 @@ private Expression> CreateGetByKeyExpression(string filter) { var entityParameter = Expression.Parameter(typeof(TEntity)); var entityKey = Expression.Property(entityParameter, _keyInfo.Name); - var filterConstant = Expression.Constant(Convert.ChangeType(filter, GetPropertyType(typeof(TEntity), _keyInfo.Name))); + var filterConstant = Expression.Constant(Convert.ChangeType(filter, ReflectionUtilities.GetPropertyType(typeof(TEntity), _keyInfo.Name))); var expr = Expression.Equal(entityKey, filterConstant); return Expression.Lambda>(expr, entityParameter); @@ -278,13 +296,13 @@ private Expression> CreateFilterExpression(HttpRequestData r return DynamicExpressionParser.ParseLambda(new ParsingConfig(),true, req.Query[item]); } - if (!PropertyExists(typeof(TEntity), item)) + if (!ReflectionUtilities.PropertyExists(typeof(TEntity), item)) { _logger.LogWarning("Query Item: '{item}' does not exist in TEntity: '{entityName}'", item, typeof(TEntity).Name); continue; } var entityKey = Expression.Property(entityParameter, item); - var filterConstant = Expression.Constant(Convert.ChangeType(req.Query[item], GetPropertyType(typeof(TEntity), item))); + var filterConstant = Expression.Constant(Convert.ChangeType(req.Query[item], ReflectionUtilities.GetPropertyType(typeof(TEntity), item))); var comparison = Expression.Equal(entityKey, filterConstant); if (expr == null) { @@ -298,17 +316,34 @@ private Expression> CreateFilterExpression(HttpRequestData r } - private static Type GetPropertyType(Type type, string property) + private static bool GetBooleanQueryItem(HttpRequestData req, string headerKey, bool defaultValue = false) { - return type.GetProperty(property).PropertyType; + if(req.Query[headerKey] == null){ + return defaultValue; + } + if(bool.TryParse(req.Query[headerKey],out var result)){ + return result; + } + return defaultValue; } - private static bool PropertyExists(Type type, string property) => - Array.Exists(type.GetProperties(), p => p.Name == property); - + private static bool ResultHasContent(Object obj) + { + if(obj == null) + { + return true; + } + if(obj is not IEnumerable) // Object isnt null and isnt IEnumerable so will have data + { + return true; + } + var data = (IEnumerable)obj; + var result = data.Any(); + return result; + } private HttpResponseData CreateErrorResponse(HttpRequestData req, string message, HttpStatusCode statusCode) { diff --git a/tests/UnitTests/DataServiceTests/BsSelectGpPracticeDataServiceTests.cs b/tests/UnitTests/DataServiceTests/BsSelectGpPracticeDataServiceTests.cs index fdd5abc98..1467612f0 100644 --- a/tests/UnitTests/DataServiceTests/BsSelectGpPracticeDataServiceTests.cs +++ b/tests/UnitTests/DataServiceTests/BsSelectGpPracticeDataServiceTests.cs @@ -228,6 +228,23 @@ public async Task RunAsync_GetItemByQueryBadQuery_ReturnsBadRequest(string query //assert Assert.AreEqual(HttpStatusCode.BadRequest, result.StatusCode); } + [DataRow("i => i.CountryCategory = \"ENGLAND\"")] + [TestMethod] + public async Task RunAsyncGetItemByQueryExpectsSingle_ReturnsBadRequest(string query) + { + //arrange + _authenticationConfiguration = DataServiceTestHelper.AllowAllAccessConfig; + var _requestHandler = new RequestHandler(_dataServiceAccessor, _mockRequestHandlerLogger.Object, _authenticationConfiguration); + BsSelectGpPracticeDataService function = new BsSelectGpPracticeDataService(_mockFunctionLogger.Object, _requestHandler, _createResponse); + var req = new MockHttpRequestData(_context.Object, "", "GET"); + req.AddQuery("query", query); + req.AddQuery("single","true"); + //act + var result = await function.Run(req, null); + + //assert + Assert.AreEqual(HttpStatusCode.BadRequest, result.StatusCode); + } #endregion #region Deletes diff --git a/tests/UnitTests/DataServiceTests/DataServiceClientTests.cs b/tests/UnitTests/DataServiceTests/DataServiceClientTests.cs index 6184e4a4c..cc85d690a 100644 --- a/tests/UnitTests/DataServiceTests/DataServiceClientTests.cs +++ b/tests/UnitTests/DataServiceTests/DataServiceClientTests.cs @@ -102,6 +102,26 @@ public async Task GetByFilter_SendsGetByFilterRequest_ReturnsArray() _mockCallFunction.VerifyNoOtherCalls(); } [TestMethod] + public async Task GetSingleByFilter_SendsValidRequest_ReturnsParticipantDemographic() + { + //arrange + DataServiceClient dataServiceClient = new DataServiceClient(_mockLogger.Object, _dataServiceResolver, _mockCallFunction.Object); + _mockCallFunction.Setup(i => i.SendGet(It.IsAny(), It.IsAny>())).ReturnsAsync("{}"); + + var participant = new ParticipantDemographic + { + ParticipantId = 123 + }; + + //act + var result = await dataServiceClient.GetSingleByFilter(i => i.ParticipantId == 123); + + //assert + result.Should().BeAssignableTo(); + _mockCallFunction.Verify(i => i.SendGet(baseUrl, It.IsAny>()), Times.Once); + _mockCallFunction.VerifyNoOtherCalls(); + } + [TestMethod] public async Task GetSingle_SendsGetSingleRequest_ReturnsParticipantDemographic() { //arrange