diff --git a/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/credentials/http/HttpCredentialsProvider.java b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/credentials/http/HttpCredentialsProvider.java index 3255b9e7..c3d2848e 100644 --- a/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/credentials/http/HttpCredentialsProvider.java +++ b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/credentials/http/HttpCredentialsProvider.java @@ -15,6 +15,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.module.SimpleModule; +import com.google.common.collect.Multimaps; import com.google.inject.Inject; import io.airlift.http.client.FullJsonResponseHandler.JsonResponse; import io.airlift.http.client.HttpClient; @@ -28,6 +29,7 @@ import jakarta.ws.rs.core.UriBuilder; import java.net.URI; +import java.util.Map; import java.util.Optional; import static io.airlift.http.client.FullJsonResponseHandler.createFullJsonResponseHandler; @@ -40,6 +42,7 @@ public class HttpCredentialsProvider private final HttpClient httpClient; private final JsonCodec jsonCodec; private final URI httpCredentialsProviderEndpoint; + private final Optional> httpHeaders; @Inject public HttpCredentialsProvider(@ForHttpCredentialsProvider HttpClient httpClient, HttpCredentialsProviderConfig config, ObjectMapper objectMapper, Class identityClass) @@ -49,6 +52,7 @@ public HttpCredentialsProvider(@ForHttpCredentialsProvider HttpClient httpClient this.httpCredentialsProviderEndpoint = config.getEndpoint(); ObjectMapper adjustedObjectMapper = objectMapper.registerModule(new SimpleModule().addAbstractTypeMapping(Identity.class, identityClass)); this.jsonCodec = new JsonCodecFactory(() -> adjustedObjectMapper).jsonCodec(Credentials.class); + this.httpHeaders = config.getHttpHeaders(); } @Override @@ -56,10 +60,10 @@ public Optional credentials(String emulatedAccessKey, Optional uriBuilder.queryParam("sessionToken", sessionToken)); - Request request = prepareGet() - .setUri(uriBuilder.build()) - .build(); - JsonResponse response = httpClient.execute(request, createFullJsonResponseHandler(jsonCodec)); + Request.Builder requestBuilder = prepareGet() + .setUri(uriBuilder.build()); + httpHeaders.ifPresent(stringStringMap -> requestBuilder.addHeaders(Multimaps.forMap(stringStringMap))); + JsonResponse response = httpClient.execute(requestBuilder.build(), createFullJsonResponseHandler(jsonCodec)); if (response.getStatusCode() == HttpStatus.NOT_FOUND.code() || !response.hasValue()) { return Optional.empty(); } diff --git a/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/credentials/http/HttpCredentialsProviderConfig.java b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/credentials/http/HttpCredentialsProviderConfig.java index 1188b61a..cc534a8a 100644 --- a/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/credentials/http/HttpCredentialsProviderConfig.java +++ b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/credentials/http/HttpCredentialsProviderConfig.java @@ -17,10 +17,15 @@ import jakarta.validation.constraints.NotNull; import java.net.URI; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.stream.Collectors; public class HttpCredentialsProviderConfig { private URI endpoint; + private Optional> httpHeaders = Optional.ofNullable(null); @NotNull public URI getEndpoint() @@ -34,4 +39,25 @@ public HttpCredentialsProviderConfig setEndpoint(String endpoint) this.endpoint = URI.create(endpoint); return this; } + + public Optional> getHttpHeaders() + { + return httpHeaders; + } + + @Config("credentials-provider.http.headers") + public HttpCredentialsProviderConfig setHttpHeaders(List httpHeaderList) + { + try { + this.httpHeaders = Optional.of(httpHeaderList.stream() + .map(s -> s.split(":", 2)) + .collect(Collectors.toUnmodifiableMap( + a -> a[0].trim(), + a -> a[1].trim()))); + } + catch (IndexOutOfBoundsException e) { + throw new IllegalArgumentException("Invalid HTTP header list: %s" + String.join(",", httpHeaderList)); + } + return this; + } } diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/credentials/http/TestHttpCredentialsProvider.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/credentials/http/TestHttpCredentialsProvider.java index 6788a59f..51872708 100644 --- a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/credentials/http/TestHttpCredentialsProvider.java +++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/credentials/http/TestHttpCredentialsProvider.java @@ -13,9 +13,12 @@ */ package io.trino.aws.proxy.server.credentials.http; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.base.Strings; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.inject.Inject; +import io.airlift.http.client.HttpClient; import io.airlift.http.server.HttpServerConfig; import io.airlift.http.server.HttpServerInfo; import io.airlift.http.server.testing.TestingHttpServer; @@ -34,6 +37,7 @@ import org.junit.jupiter.api.Test; import java.io.IOException; +import java.util.List; import java.util.Optional; import static io.trino.aws.proxy.server.credentials.http.HttpCredentialsModule.HTTP_CREDENTIALS_PROVIDER_IDENTIFIER; @@ -46,10 +50,14 @@ public class TestHttpCredentialsProvider { private final CredentialsProvider credentialsProvider; + private final HttpClient httpClient; + private final ObjectMapper objectMapper; public static class Filter implements BuilderFilter { + private static String httpEndpointUri; + @Override public TestingTrinoAwsProxyServer.Builder filter(TestingTrinoAwsProxyServer.Builder builder) { @@ -57,6 +65,7 @@ public TestingTrinoAwsProxyServer.Builder filter(TestingTrinoAwsProxyServer.Buil try { httpCredentialsServer = createTestingHttpCredentialsServer(); httpCredentialsServer.start(); + httpEndpointUri = httpCredentialsServer.getBaseUrl().toString(); } catch (Exception e) { throw new RuntimeException("Failed to start test http credentials provider server", e); @@ -65,14 +74,17 @@ public TestingTrinoAwsProxyServer.Builder filter(TestingTrinoAwsProxyServer.Buil .addModule(new HttpCredentialsModule()) .addModule(binder -> bindIdentityType(binder, TestingIdentity.class)) .withProperty("credentials-provider.type", HTTP_CREDENTIALS_PROVIDER_IDENTIFIER) - .withProperty("credentials-provider.http.endpoint", httpCredentialsServer.getBaseUrl().toString()); + .withProperty("credentials-provider.http.endpoint", httpEndpointUri) + .withProperty("credentials-provider.http.headers", "Authorization: auth, Content-Type: application/json"); } } @Inject - public TestHttpCredentialsProvider(CredentialsProvider credentialsProvider) + public TestHttpCredentialsProvider(CredentialsProvider credentialsProvider, @ForHttpCredentialsProvider HttpClient httpClient, ObjectMapper objectMapper) { this.credentialsProvider = requireNonNull(credentialsProvider, "credentialsProvider is null"); + this.httpClient = requireNonNull(httpClient, "httpClient is null"); + this.objectMapper = requireNonNull(objectMapper, "objectMapper is null"); } @Test @@ -123,6 +135,14 @@ public void testIncorrectResponseFromServer() assertThat(actual).isEmpty(); } + @Test void testValidCredentialsIncorrectHeader() + { + HttpCredentialsProviderConfig config = new HttpCredentialsProviderConfig().setEndpoint(Filter.httpEndpointUri).setHttpHeaders(List.of("incorrect-header: incorrect-value")); + HttpCredentialsProvider customHttpCredentialsProvider = new HttpCredentialsProvider(httpClient, config, objectMapper, TestingIdentity.class); + Optional actual = customHttpCredentialsProvider.credentials("test-emulated-access-key", Optional.of("test-emulated-access-key")); + assertThat(actual).isEmpty(); + } + private static TestingHttpServer createTestingHttpCredentialsServer() throws IOException { @@ -139,6 +159,10 @@ private static class HttpCredentialsServlet protected void doGet(HttpServletRequest request, HttpServletResponse response) throws IOException { + if (Strings.isNullOrEmpty(request.getHeader("Authorization")) || Strings.isNullOrEmpty("Content-Type")) { + response.setStatus(HttpServletResponse.SC_BAD_REQUEST); + return; + } Optional sessionToken = Optional.ofNullable(request.getParameter("sessionToken")); String emulatedAccessKey = request.getPathInfo().substring(1); String credentialsIdentifier = ""; diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/credentials/http/TestHttpCredentialsProviderConfig.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/credentials/http/TestHttpCredentialsProviderConfig.java index a07d6079..e257f2e0 100644 --- a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/credentials/http/TestHttpCredentialsProviderConfig.java +++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/credentials/http/TestHttpCredentialsProviderConfig.java @@ -16,21 +16,31 @@ import com.google.common.collect.ImmutableMap; import org.junit.jupiter.api.Test; -import java.io.IOException; +import java.util.List; import java.util.Map; import static io.airlift.configuration.testing.ConfigAssertions.assertFullMapping; +import static org.junit.jupiter.api.Assertions.assertThrows; public class TestHttpCredentialsProviderConfig { @Test public void testExplicitPropertyMappings() - throws IOException { Map properties = ImmutableMap.of( - "credentials-provider.http.endpoint", "http://usersvc:9000/api/v1/users"); + "credentials-provider.http.endpoint", "http://usersvc:9000/api/v1/users", + "credentials-provider.http.headers", "x-api-key: xyz123, Content-Type: application/json"); HttpCredentialsProviderConfig expected = new HttpCredentialsProviderConfig() - .setEndpoint("http://usersvc:9000/api/v1/users"); + .setEndpoint("http://usersvc:9000/api/v1/users") + .setHttpHeaders(List.of("x-api-key: xyz123", "Content-Type: application/json")); assertFullMapping(properties, expected); } + + @Test + public void testIncorrectHttpHeader() + { + assertThrows(IllegalArgumentException.class, () -> new HttpCredentialsProviderConfig() + .setEndpoint("http://usersvc:9000/api/v1/users") + .setHttpHeaders(List.of("malformed-header"))); + } }