Skip to content

Commit

Permalink
Improved deserializer injection
Browse files Browse the repository at this point in the history
  • Loading branch information
pranavr12 committed Aug 9, 2024
1 parent a0ff43f commit af4d295
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.module.SimpleModule;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Multimaps;
import com.google.inject.Inject;
import io.airlift.http.client.FullJsonResponseHandler.JsonResponse;
import io.airlift.http.client.HttpClient;
Expand All @@ -28,6 +30,7 @@
import jakarta.ws.rs.core.UriBuilder;

import java.net.URI;
import java.util.Map;
import java.util.Optional;

import static io.airlift.http.client.FullJsonResponseHandler.createFullJsonResponseHandler;
Expand All @@ -40,6 +43,7 @@ public class HttpCredentialsProvider
private final HttpClient httpClient;
private final JsonCodec<Credentials> jsonCodec;
private final URI httpCredentialsProviderEndpoint;
private final Map<String, String> httpHeaders;

@Inject
public HttpCredentialsProvider(@ForHttpCredentialsProvider HttpClient httpClient, HttpCredentialsProviderConfig config, ObjectMapper objectMapper, Class<? extends Identity> identityClass)
Expand All @@ -49,17 +53,18 @@ public HttpCredentialsProvider(@ForHttpCredentialsProvider HttpClient httpClient
this.httpCredentialsProviderEndpoint = config.getEndpoint();
ObjectMapper adjustedObjectMapper = objectMapper.registerModule(new SimpleModule().addAbstractTypeMapping(Identity.class, identityClass));
this.jsonCodec = new JsonCodecFactory(() -> adjustedObjectMapper).jsonCodec(Credentials.class);
this.httpHeaders = ImmutableMap.copyOf(config.getHttpHeaders());
}

@Override
public Optional<Credentials> credentials(String emulatedAccessKey, Optional<String> session)
{
UriBuilder uriBuilder = UriBuilder.fromUri(httpCredentialsProviderEndpoint).path(emulatedAccessKey);
session.ifPresent(sessionToken -> uriBuilder.queryParam("sessionToken", sessionToken));
Request request = prepareGet()
.setUri(uriBuilder.build())
.build();
JsonResponse<Credentials> response = httpClient.execute(request, createFullJsonResponseHandler(jsonCodec));
Request.Builder requestBuilder = prepareGet()
.addHeaders(Multimaps.forMap(httpHeaders))
.setUri(uriBuilder.build());
JsonResponse<Credentials> response = httpClient.execute(requestBuilder.build(), createFullJsonResponseHandler(jsonCodec));
if (response.getStatusCode() == HttpStatus.NOT_FOUND.code() || !response.hasValue()) {
return Optional.empty();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,15 @@
import jakarta.validation.constraints.NotNull;

import java.net.URI;
import java.util.List;
import java.util.Map;

import static com.google.common.collect.ImmutableMap.toImmutableMap;

public class HttpCredentialsProviderConfig
{
private URI endpoint;
private Map<String, String> httpHeaders = Map.of();

@NotNull
public URI getEndpoint()
Expand All @@ -34,4 +39,25 @@ public HttpCredentialsProviderConfig setEndpoint(String endpoint)
this.endpoint = URI.create(endpoint);
return this;
}

public Map<String, String> getHttpHeaders()
{
return httpHeaders;
}

@Config("credentials-provider.http.headers")
public HttpCredentialsProviderConfig setHttpHeaders(List<String> httpHeaderList)
{
try {
this.httpHeaders = httpHeaderList.stream()
.map(s -> s.split(":", 2))
.collect(toImmutableMap(
a -> a[0].trim(),
a -> a[1].trim()));
}
catch (IndexOutOfBoundsException e) {
throw new IllegalArgumentException("Invalid HTTP header list: %s" + String.join(",", httpHeaderList));
}
return this;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,12 @@
*/
package io.trino.aws.proxy.server.credentials.http;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.inject.Inject;
import io.airlift.http.client.HttpClient;
import io.airlift.http.server.HttpServerConfig;
import io.airlift.http.server.HttpServerInfo;
import io.airlift.http.server.testing.TestingHttpServer;
Expand All @@ -34,6 +37,7 @@
import org.junit.jupiter.api.Test;

import java.io.IOException;
import java.util.List;
import java.util.Optional;

import static io.trino.aws.proxy.server.credentials.http.HttpCredentialsModule.HTTP_CREDENTIALS_PROVIDER_IDENTIFIER;
Expand All @@ -46,17 +50,22 @@
public class TestHttpCredentialsProvider
{
private final CredentialsProvider credentialsProvider;
private final HttpClient httpClient;
private final ObjectMapper objectMapper;

public static class Filter
implements BuilderFilter
{
private static String httpEndpointUri;

@Override
public TestingTrinoAwsProxyServer.Builder filter(TestingTrinoAwsProxyServer.Builder builder)
{
TestingHttpServer httpCredentialsServer;
try {
httpCredentialsServer = createTestingHttpCredentialsServer();
httpCredentialsServer.start();
httpEndpointUri = httpCredentialsServer.getBaseUrl().toString();
}
catch (Exception e) {
throw new RuntimeException("Failed to start test http credentials provider server", e);
Expand All @@ -65,14 +74,17 @@ public TestingTrinoAwsProxyServer.Builder filter(TestingTrinoAwsProxyServer.Buil
.addModule(new HttpCredentialsModule())
.addModule(binder -> bindIdentityType(binder, TestingIdentity.class))
.withProperty("credentials-provider.type", HTTP_CREDENTIALS_PROVIDER_IDENTIFIER)
.withProperty("credentials-provider.http.endpoint", httpCredentialsServer.getBaseUrl().toString());
.withProperty("credentials-provider.http.endpoint", httpEndpointUri)
.withProperty("credentials-provider.http.headers", "Authorization: auth, Content-Type: application/json");
}
}

@Inject
public TestHttpCredentialsProvider(CredentialsProvider credentialsProvider)
public TestHttpCredentialsProvider(CredentialsProvider credentialsProvider, @ForHttpCredentialsProvider HttpClient httpClient, ObjectMapper objectMapper)
{
this.credentialsProvider = requireNonNull(credentialsProvider, "credentialsProvider is null");
this.httpClient = requireNonNull(httpClient, "httpClient is null");
this.objectMapper = requireNonNull(objectMapper, "objectMapper is null");
}

@Test
Expand Down Expand Up @@ -123,6 +135,14 @@ public void testIncorrectResponseFromServer()
assertThat(actual).isEmpty();
}

@Test void testValidCredentialsIncorrectHeader()
{
HttpCredentialsProviderConfig config = new HttpCredentialsProviderConfig().setEndpoint(Filter.httpEndpointUri).setHttpHeaders(List.of("incorrect-header: incorrect-value"));
HttpCredentialsProvider customHttpCredentialsProvider = new HttpCredentialsProvider(httpClient, config, objectMapper, TestingIdentity.class);
Optional<Credentials> actual = customHttpCredentialsProvider.credentials("test-emulated-access-key", Optional.of("test-emulated-access-key"));
assertThat(actual).isEmpty();
}

private static TestingHttpServer createTestingHttpCredentialsServer()
throws IOException
{
Expand All @@ -139,6 +159,10 @@ private static class HttpCredentialsServlet
protected void doGet(HttpServletRequest request, HttpServletResponse response)
throws IOException
{
if (Strings.isNullOrEmpty(request.getHeader("Authorization")) || Strings.isNullOrEmpty("Content-Type")) {
response.setStatus(HttpServletResponse.SC_BAD_REQUEST);
return;
}
Optional<String> sessionToken = Optional.ofNullable(request.getParameter("sessionToken"));
String emulatedAccessKey = request.getPathInfo().substring(1);
String credentialsIdentifier = "";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,42 @@
import com.google.common.collect.ImmutableMap;
import org.junit.jupiter.api.Test;

import java.io.IOException;
import java.util.List;
import java.util.Map;

import static io.airlift.configuration.testing.ConfigAssertions.assertFullMapping;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertThrows;

public class TestHttpCredentialsProviderConfig
{
@Test
public void testExplicitPropertyMappings()
throws IOException
{
Map<String, String> properties = ImmutableMap.of(
"credentials-provider.http.endpoint", "http://usersvc:9000/api/v1/users");
"credentials-provider.http.endpoint", "http://usersvc:9000/api/v1/users",
"credentials-provider.http.headers", "x-api-key: xyz123, Content-Type: application/json");
HttpCredentialsProviderConfig expected = new HttpCredentialsProviderConfig()
.setEndpoint("http://usersvc:9000/api/v1/users");
.setEndpoint("http://usersvc:9000/api/v1/users")
.setHttpHeaders(List.of("x-api-key: xyz123", "Content-Type: application/json"));
assertFullMapping(properties, expected);
}

@Test
public void testValidHttpHeaderVariation1()
{
HttpCredentialsProviderConfig config = new HttpCredentialsProviderConfig()
.setEndpoint("http://usersvc:9000/api/v1/users")
.setHttpHeaders(List.of("x-api-key: Authorization: xyz123"));
Map<String, String> httpHeaders = config.getHttpHeaders();
assertThat(httpHeaders.get("x-api-key")).isEqualTo("Authorization: xyz123");
}

@Test
public void testIncorrectHttpHeader()
{
assertThrows(IllegalArgumentException.class, () -> new HttpCredentialsProviderConfig()
.setEndpoint("http://usersvc:9000/api/v1/users")
.setHttpHeaders(List.of("malformed-header")));
}
}

0 comments on commit af4d295

Please sign in to comment.