diff --git a/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/credentials/http/HttpCredentialsProvider.java b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/credentials/http/HttpCredentialsProvider.java index 3255b9e7..7e5696e3 100644 --- a/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/credentials/http/HttpCredentialsProvider.java +++ b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/credentials/http/HttpCredentialsProvider.java @@ -15,6 +15,8 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.module.SimpleModule; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Multimaps; import com.google.inject.Inject; import io.airlift.http.client.FullJsonResponseHandler.JsonResponse; import io.airlift.http.client.HttpClient; @@ -28,6 +30,7 @@ import jakarta.ws.rs.core.UriBuilder; import java.net.URI; +import java.util.Map; import java.util.Optional; import static io.airlift.http.client.FullJsonResponseHandler.createFullJsonResponseHandler; @@ -40,6 +43,7 @@ public class HttpCredentialsProvider private final HttpClient httpClient; private final JsonCodec jsonCodec; private final URI httpCredentialsProviderEndpoint; + private final Map httpHeaders; @Inject public HttpCredentialsProvider(@ForHttpCredentialsProvider HttpClient httpClient, HttpCredentialsProviderConfig config, ObjectMapper objectMapper, Class identityClass) @@ -49,6 +53,7 @@ public HttpCredentialsProvider(@ForHttpCredentialsProvider HttpClient httpClient this.httpCredentialsProviderEndpoint = config.getEndpoint(); ObjectMapper adjustedObjectMapper = objectMapper.registerModule(new SimpleModule().addAbstractTypeMapping(Identity.class, identityClass)); this.jsonCodec = new JsonCodecFactory(() -> adjustedObjectMapper).jsonCodec(Credentials.class); + this.httpHeaders = ImmutableMap.copyOf(config.getHttpHeaders()); } @Override @@ -56,10 +61,10 @@ public Optional credentials(String emulatedAccessKey, Optional uriBuilder.queryParam("sessionToken", sessionToken)); - Request request = prepareGet() - .setUri(uriBuilder.build()) - .build(); - JsonResponse response = httpClient.execute(request, createFullJsonResponseHandler(jsonCodec)); + Request.Builder requestBuilder = prepareGet() + .addHeaders(Multimaps.forMap(httpHeaders)) + .setUri(uriBuilder.build()); + JsonResponse response = httpClient.execute(requestBuilder.build(), createFullJsonResponseHandler(jsonCodec)); if (response.getStatusCode() == HttpStatus.NOT_FOUND.code() || !response.hasValue()) { return Optional.empty(); } diff --git a/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/credentials/http/HttpCredentialsProviderConfig.java b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/credentials/http/HttpCredentialsProviderConfig.java index 1188b61a..dda051e0 100644 --- a/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/credentials/http/HttpCredentialsProviderConfig.java +++ b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/credentials/http/HttpCredentialsProviderConfig.java @@ -17,10 +17,15 @@ import jakarta.validation.constraints.NotNull; import java.net.URI; +import java.util.List; +import java.util.Map; + +import static com.google.common.collect.ImmutableMap.toImmutableMap; public class HttpCredentialsProviderConfig { private URI endpoint; + private Map httpHeaders = Map.of(); @NotNull public URI getEndpoint() @@ -34,4 +39,25 @@ public HttpCredentialsProviderConfig setEndpoint(String endpoint) this.endpoint = URI.create(endpoint); return this; } + + public Map getHttpHeaders() + { + return httpHeaders; + } + + @Config("credentials-provider.http.headers") + public HttpCredentialsProviderConfig setHttpHeaders(List httpHeaderList) + { + try { + this.httpHeaders = httpHeaderList.stream() + .map(s -> s.split(":", 2)) + .collect(toImmutableMap( + a -> a[0].trim(), + a -> a[1].trim())); + } + catch (IndexOutOfBoundsException e) { + throw new IllegalArgumentException("Invalid HTTP header list: %s" + String.join(",", httpHeaderList)); + } + return this; + } } diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/credentials/http/TestHttpCredentialsProvider.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/credentials/http/TestHttpCredentialsProvider.java index 6788a59f..51872708 100644 --- a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/credentials/http/TestHttpCredentialsProvider.java +++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/credentials/http/TestHttpCredentialsProvider.java @@ -13,9 +13,12 @@ */ package io.trino.aws.proxy.server.credentials.http; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.base.Strings; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.inject.Inject; +import io.airlift.http.client.HttpClient; import io.airlift.http.server.HttpServerConfig; import io.airlift.http.server.HttpServerInfo; import io.airlift.http.server.testing.TestingHttpServer; @@ -34,6 +37,7 @@ import org.junit.jupiter.api.Test; import java.io.IOException; +import java.util.List; import java.util.Optional; import static io.trino.aws.proxy.server.credentials.http.HttpCredentialsModule.HTTP_CREDENTIALS_PROVIDER_IDENTIFIER; @@ -46,10 +50,14 @@ public class TestHttpCredentialsProvider { private final CredentialsProvider credentialsProvider; + private final HttpClient httpClient; + private final ObjectMapper objectMapper; public static class Filter implements BuilderFilter { + private static String httpEndpointUri; + @Override public TestingTrinoAwsProxyServer.Builder filter(TestingTrinoAwsProxyServer.Builder builder) { @@ -57,6 +65,7 @@ public TestingTrinoAwsProxyServer.Builder filter(TestingTrinoAwsProxyServer.Buil try { httpCredentialsServer = createTestingHttpCredentialsServer(); httpCredentialsServer.start(); + httpEndpointUri = httpCredentialsServer.getBaseUrl().toString(); } catch (Exception e) { throw new RuntimeException("Failed to start test http credentials provider server", e); @@ -65,14 +74,17 @@ public TestingTrinoAwsProxyServer.Builder filter(TestingTrinoAwsProxyServer.Buil .addModule(new HttpCredentialsModule()) .addModule(binder -> bindIdentityType(binder, TestingIdentity.class)) .withProperty("credentials-provider.type", HTTP_CREDENTIALS_PROVIDER_IDENTIFIER) - .withProperty("credentials-provider.http.endpoint", httpCredentialsServer.getBaseUrl().toString()); + .withProperty("credentials-provider.http.endpoint", httpEndpointUri) + .withProperty("credentials-provider.http.headers", "Authorization: auth, Content-Type: application/json"); } } @Inject - public TestHttpCredentialsProvider(CredentialsProvider credentialsProvider) + public TestHttpCredentialsProvider(CredentialsProvider credentialsProvider, @ForHttpCredentialsProvider HttpClient httpClient, ObjectMapper objectMapper) { this.credentialsProvider = requireNonNull(credentialsProvider, "credentialsProvider is null"); + this.httpClient = requireNonNull(httpClient, "httpClient is null"); + this.objectMapper = requireNonNull(objectMapper, "objectMapper is null"); } @Test @@ -123,6 +135,14 @@ public void testIncorrectResponseFromServer() assertThat(actual).isEmpty(); } + @Test void testValidCredentialsIncorrectHeader() + { + HttpCredentialsProviderConfig config = new HttpCredentialsProviderConfig().setEndpoint(Filter.httpEndpointUri).setHttpHeaders(List.of("incorrect-header: incorrect-value")); + HttpCredentialsProvider customHttpCredentialsProvider = new HttpCredentialsProvider(httpClient, config, objectMapper, TestingIdentity.class); + Optional actual = customHttpCredentialsProvider.credentials("test-emulated-access-key", Optional.of("test-emulated-access-key")); + assertThat(actual).isEmpty(); + } + private static TestingHttpServer createTestingHttpCredentialsServer() throws IOException { @@ -139,6 +159,10 @@ private static class HttpCredentialsServlet protected void doGet(HttpServletRequest request, HttpServletResponse response) throws IOException { + if (Strings.isNullOrEmpty(request.getHeader("Authorization")) || Strings.isNullOrEmpty("Content-Type")) { + response.setStatus(HttpServletResponse.SC_BAD_REQUEST); + return; + } Optional sessionToken = Optional.ofNullable(request.getParameter("sessionToken")); String emulatedAccessKey = request.getPathInfo().substring(1); String credentialsIdentifier = ""; diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/credentials/http/TestHttpCredentialsProviderConfig.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/credentials/http/TestHttpCredentialsProviderConfig.java index a07d6079..26868b87 100644 --- a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/credentials/http/TestHttpCredentialsProviderConfig.java +++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/credentials/http/TestHttpCredentialsProviderConfig.java @@ -16,21 +16,42 @@ import com.google.common.collect.ImmutableMap; import org.junit.jupiter.api.Test; -import java.io.IOException; +import java.util.List; import java.util.Map; import static io.airlift.configuration.testing.ConfigAssertions.assertFullMapping; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; public class TestHttpCredentialsProviderConfig { @Test public void testExplicitPropertyMappings() - throws IOException { Map properties = ImmutableMap.of( - "credentials-provider.http.endpoint", "http://usersvc:9000/api/v1/users"); + "credentials-provider.http.endpoint", "http://usersvc:9000/api/v1/users", + "credentials-provider.http.headers", "x-api-key: xyz123, Content-Type: application/json"); HttpCredentialsProviderConfig expected = new HttpCredentialsProviderConfig() - .setEndpoint("http://usersvc:9000/api/v1/users"); + .setEndpoint("http://usersvc:9000/api/v1/users") + .setHttpHeaders(List.of("x-api-key: xyz123", "Content-Type: application/json")); assertFullMapping(properties, expected); } + + @Test + public void testValidHttpHeaderVariation1() + { + HttpCredentialsProviderConfig config = new HttpCredentialsProviderConfig() + .setEndpoint("http://usersvc:9000/api/v1/users") + .setHttpHeaders(List.of("x-api-key: Authorization: xyz123")); + Map httpHeaders = config.getHttpHeaders(); + assertThat(httpHeaders.get("x-api-key")).isEqualTo("Authorization: xyz123"); + } + + @Test + public void testIncorrectHttpHeader() + { + assertThrows(IllegalArgumentException.class, () -> new HttpCredentialsProviderConfig() + .setEndpoint("http://usersvc:9000/api/v1/users") + .setHttpHeaders(List.of("malformed-header"))); + } }