From 6d87b5112da2feb63386319bb131ce9508ee95a1 Mon Sep 17 00:00:00 2001 From: Pranav Ramachandra Date: Mon, 12 Aug 2024 13:44:34 +0100 Subject: [PATCH] Added header to HTTP credential provider --- README.md | 15 ++++++ .../http/HttpCredentialsProvider.java | 13 +++-- .../http/HttpCredentialsProviderConfig.java | 28 +++++++++++ .../http/TestHttpCredentialsProvider.java | 11 ++++- .../TestHttpCredentialsProviderConfig.java | 47 +++++++++++++++++-- 5 files changed, 105 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index cd5fd638..f6b32d49 100644 --- a/README.md +++ b/README.md @@ -68,3 +68,18 @@ spark = SparkSession\ # try spark sql commands ``` + +## Module Configurations + +### HTTP Credentials Provider + +The HTTP credentials provider provides an option to include additional headers on requests sent to the HTTP service (e.g., for authentication). + +These can be configured with `credentials-provider.http.headers`. This config entry is formatted as a comma-separated list of header names and values, where each entry is in the format `header-name:header-value`. + +For instance, `header1:value1,header2:value2`. +If a header name or value should contain a comma, these can be escaped by doubling them (`,,` translates to a single comma in the literal header name or value, and is not treated as a separator). + +E.g.: setting this config property to `"x-api-key: xyz,,123, Authorization: key,,,,123"` results in 2 headers: +- `x-api-key`: with value `xyz,123` +- `Authorization`: with value `key,,123` \ No newline at end of file diff --git a/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/credentials/http/HttpCredentialsProvider.java b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/credentials/http/HttpCredentialsProvider.java index 3255b9e7..7e5696e3 100644 --- a/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/credentials/http/HttpCredentialsProvider.java +++ b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/credentials/http/HttpCredentialsProvider.java @@ -15,6 +15,8 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.module.SimpleModule; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Multimaps; import com.google.inject.Inject; import io.airlift.http.client.FullJsonResponseHandler.JsonResponse; import io.airlift.http.client.HttpClient; @@ -28,6 +30,7 @@ import jakarta.ws.rs.core.UriBuilder; import java.net.URI; +import java.util.Map; import java.util.Optional; import static io.airlift.http.client.FullJsonResponseHandler.createFullJsonResponseHandler; @@ -40,6 +43,7 @@ public class HttpCredentialsProvider private final HttpClient httpClient; private final JsonCodec jsonCodec; private final URI httpCredentialsProviderEndpoint; + private final Map httpHeaders; @Inject public HttpCredentialsProvider(@ForHttpCredentialsProvider HttpClient httpClient, HttpCredentialsProviderConfig config, ObjectMapper objectMapper, Class identityClass) @@ -49,6 +53,7 @@ public HttpCredentialsProvider(@ForHttpCredentialsProvider HttpClient httpClient this.httpCredentialsProviderEndpoint = config.getEndpoint(); ObjectMapper adjustedObjectMapper = objectMapper.registerModule(new SimpleModule().addAbstractTypeMapping(Identity.class, identityClass)); this.jsonCodec = new JsonCodecFactory(() -> adjustedObjectMapper).jsonCodec(Credentials.class); + this.httpHeaders = ImmutableMap.copyOf(config.getHttpHeaders()); } @Override @@ -56,10 +61,10 @@ public Optional credentials(String emulatedAccessKey, Optional uriBuilder.queryParam("sessionToken", sessionToken)); - Request request = prepareGet() - .setUri(uriBuilder.build()) - .build(); - JsonResponse response = httpClient.execute(request, createFullJsonResponseHandler(jsonCodec)); + Request.Builder requestBuilder = prepareGet() + .addHeaders(Multimaps.forMap(httpHeaders)) + .setUri(uriBuilder.build()); + JsonResponse response = httpClient.execute(requestBuilder.build(), createFullJsonResponseHandler(jsonCodec)); if (response.getStatusCode() == HttpStatus.NOT_FOUND.code() || !response.hasValue()) { return Optional.empty(); } diff --git a/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/credentials/http/HttpCredentialsProviderConfig.java b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/credentials/http/HttpCredentialsProviderConfig.java index 1188b61a..c563534f 100644 --- a/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/credentials/http/HttpCredentialsProviderConfig.java +++ b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/credentials/http/HttpCredentialsProviderConfig.java @@ -13,14 +13,19 @@ */ package io.trino.aws.proxy.server.credentials.http; +import com.google.common.base.Splitter; import io.airlift.configuration.Config; import jakarta.validation.constraints.NotNull; import java.net.URI; +import java.util.Map; + +import static com.google.common.collect.ImmutableMap.toImmutableMap; public class HttpCredentialsProviderConfig { private URI endpoint; + private Map httpHeaders = Map.of(); @NotNull public URI getEndpoint() @@ -34,4 +39,27 @@ public HttpCredentialsProviderConfig setEndpoint(String endpoint) this.endpoint = URI.create(endpoint); return this; } + + public Map getHttpHeaders() + { + return httpHeaders; + } + + @Config("credentials-provider.http.headers") + public HttpCredentialsProviderConfig setHttpHeaders(String httpHeadersList) + { + try { + this.httpHeaders = Splitter.on(",").trimResults().omitEmptyStrings() + .splitToStream(httpHeadersList.replaceAll(",,", "\r")) + .map(item -> item.replace("\r", ",")) + .map(s -> s.split(":", 2)) + .collect(toImmutableMap( + a -> a[0].trim(), + a -> a[1].trim())); + } + catch (IndexOutOfBoundsException e) { + throw new IllegalArgumentException("Invalid HTTP header list: " + httpHeadersList); + } + return this; + } } diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/credentials/http/TestHttpCredentialsProvider.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/credentials/http/TestHttpCredentialsProvider.java index 6788a59f..343bd685 100644 --- a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/credentials/http/TestHttpCredentialsProvider.java +++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/credentials/http/TestHttpCredentialsProvider.java @@ -13,6 +13,7 @@ */ package io.trino.aws.proxy.server.credentials.http; +import com.google.common.base.Strings; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.inject.Inject; @@ -50,6 +51,8 @@ public class TestHttpCredentialsProvider public static class Filter implements BuilderFilter { + private static String httpEndpointUri; + @Override public TestingTrinoAwsProxyServer.Builder filter(TestingTrinoAwsProxyServer.Builder builder) { @@ -57,6 +60,7 @@ public TestingTrinoAwsProxyServer.Builder filter(TestingTrinoAwsProxyServer.Buil try { httpCredentialsServer = createTestingHttpCredentialsServer(); httpCredentialsServer.start(); + httpEndpointUri = httpCredentialsServer.getBaseUrl().toString(); } catch (Exception e) { throw new RuntimeException("Failed to start test http credentials provider server", e); @@ -65,7 +69,8 @@ public TestingTrinoAwsProxyServer.Builder filter(TestingTrinoAwsProxyServer.Buil .addModule(new HttpCredentialsModule()) .addModule(binder -> bindIdentityType(binder, TestingIdentity.class)) .withProperty("credentials-provider.type", HTTP_CREDENTIALS_PROVIDER_IDENTIFIER) - .withProperty("credentials-provider.http.endpoint", httpCredentialsServer.getBaseUrl().toString()); + .withProperty("credentials-provider.http.endpoint", httpEndpointUri) + .withProperty("credentials-provider.http.headers", "Authorization: auth, Content-Type: application/json"); } } @@ -139,6 +144,10 @@ private static class HttpCredentialsServlet protected void doGet(HttpServletRequest request, HttpServletResponse response) throws IOException { + if (Strings.isNullOrEmpty(request.getHeader("Authorization")) || Strings.isNullOrEmpty("Content-Type")) { + response.setStatus(HttpServletResponse.SC_BAD_REQUEST); + return; + } Optional sessionToken = Optional.ofNullable(request.getParameter("sessionToken")); String emulatedAccessKey = request.getPathInfo().substring(1); String credentialsIdentifier = ""; diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/credentials/http/TestHttpCredentialsProviderConfig.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/credentials/http/TestHttpCredentialsProviderConfig.java index a07d6079..6b1f2022 100644 --- a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/credentials/http/TestHttpCredentialsProviderConfig.java +++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/credentials/http/TestHttpCredentialsProviderConfig.java @@ -16,21 +16,60 @@ import com.google.common.collect.ImmutableMap; import org.junit.jupiter.api.Test; -import java.io.IOException; import java.util.Map; import static io.airlift.configuration.testing.ConfigAssertions.assertFullMapping; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; public class TestHttpCredentialsProviderConfig { @Test public void testExplicitPropertyMappings() - throws IOException { Map properties = ImmutableMap.of( - "credentials-provider.http.endpoint", "http://usersvc:9000/api/v1/users"); + "credentials-provider.http.endpoint", "http://usersvc:9000/api/v1/users", + "credentials-provider.http.headers", "x-api-key: xyz123, Content-Type: application/json"); HttpCredentialsProviderConfig expected = new HttpCredentialsProviderConfig() - .setEndpoint("http://usersvc:9000/api/v1/users"); + .setEndpoint("http://usersvc:9000/api/v1/users") + .setHttpHeaders("x-api-key: xyz123, Content-Type: application/json"); assertFullMapping(properties, expected); } + + @Test + public void testValidHttpHeaderVariation1() + { + HttpCredentialsProviderConfig config = new HttpCredentialsProviderConfig() + .setEndpoint("http://usersvc:9000/api/v1/users") + .setHttpHeaders("x-api-key: Authorization: xyz123"); + Map httpHeaders = config.getHttpHeaders(); + assertThat(httpHeaders.get("x-api-key")).isEqualTo("Authorization: xyz123"); + } + + @Test + public void testValidHttpHeaderVariation2() + { + HttpCredentialsProviderConfig config = new HttpCredentialsProviderConfig() + .setEndpoint("http://usersvc:9000/api/v1/users") + .setHttpHeaders("x-api-key: xyz,,123, Authorization: key,,,,123"); + Map httpHeaders = config.getHttpHeaders(); + assertThat(httpHeaders.get("x-api-key")).isEqualTo("xyz,123"); + assertThat(httpHeaders.get("Authorization")).isEqualTo("key,,123"); + } + + @Test + public void testIncorrectHttpHeader1() + { + assertThrows(IllegalArgumentException.class, () -> new HttpCredentialsProviderConfig() + .setEndpoint("http://usersvc:9000/api/v1/users") + .setHttpHeaders("malformed-header")); + } + + @Test + public void testIncorrectHttpHeader2() + { + assertThrows(IllegalArgumentException.class, () -> new HttpCredentialsProviderConfig() + .setEndpoint("http://usersvc:9000/api/v1/users") + .setHttpHeaders("x-api-key: xyz,,,123, Authorization: key123")); + } }