diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/credentials/http/TestHttpCredentialsProvider.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/credentials/http/TestHttpCredentialsProvider.java index 330164fa..b8121c9d 100644 --- a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/credentials/http/TestHttpCredentialsProvider.java +++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/credentials/http/TestHttpCredentialsProvider.java @@ -16,11 +16,8 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.inject.Inject; -import io.airlift.http.server.HttpServerConfig; -import io.airlift.http.server.HttpServerInfo; import io.airlift.http.server.testing.TestingHttpServer; -import io.airlift.json.ObjectMapperProvider; -import io.airlift.node.NodeInfo; +import io.trino.aws.proxy.server.testing.TestingHttpCredentialsProviderServlet; import io.trino.aws.proxy.server.testing.TestingIdentity; import io.trino.aws.proxy.server.testing.TestingTrinoAwsProxyServer; import io.trino.aws.proxy.server.testing.harness.BuilderFilter; @@ -28,40 +25,32 @@ import io.trino.aws.proxy.spi.credentials.Credential; import io.trino.aws.proxy.spi.credentials.Credentials; import io.trino.aws.proxy.spi.credentials.CredentialsProvider; -import jakarta.servlet.http.HttpServlet; -import jakarta.servlet.http.HttpServletRequest; -import jakarta.servlet.http.HttpServletResponse; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import java.io.IOException; import java.util.Map; import java.util.Optional; -import java.util.concurrent.atomic.AtomicInteger; import static io.trino.aws.proxy.server.credentials.http.HttpCredentialsModule.HTTP_CREDENTIALS_PROVIDER_IDENTIFIER; +import static io.trino.aws.proxy.server.testing.TestingHttpCredentialsProviderServlet.DUMMY_EMULATED_ACCESS_KEY; +import static io.trino.aws.proxy.server.testing.TestingHttpCredentialsProviderServlet.DUMMY_EMULATED_SECRET_KEY; +import static io.trino.aws.proxy.server.testing.TestingHttpCredentialsProviderServlet.DUMMY_REMOTE_ACCESS_KEY; +import static io.trino.aws.proxy.server.testing.TestingHttpCredentialsProviderServlet.DUMMY_REMOTE_SECRET_KEY; +import static io.trino.aws.proxy.server.testing.TestingUtil.createTestingHttpServer; import static io.trino.aws.proxy.spi.plugin.TrinoAwsProxyServerBinding.bindIdentityType; -import static jakarta.ws.rs.core.MediaType.APPLICATION_JSON; import static java.util.Objects.requireNonNull; import static org.assertj.core.api.Assertions.assertThat; @TrinoAwsProxyTest(filters = TestHttpCredentialsProvider.Filter.class) public class TestHttpCredentialsProvider { - private static final String DUMMY_EMULATED_ACCESS_KEY = "test-emulated-access-key"; - private static final String DUMMY_EMULATED_SECRET_KEY = "test-emulated-secret-key"; - private static final String DUMMY_REMOTE_ACCESS_KEY = "test-remote-access-key"; - private static final String DUMMY_REMOTE_SECRET_KEY = "test-remote-secret-key"; - private final CredentialsProvider credentialsProvider; - private final HttpCredentialsServlet httpCredentialsServlet; + private final TestingHttpCredentialsProviderServlet httpCredentialsServlet; private final HttpCredentialsProvider httpCredentialsProvider; public static class Filter implements BuilderFilter { - private static String httpEndpointUri; - @Override public TestingTrinoAwsProxyServer.Builder filter(TestingTrinoAwsProxyServer.Builder builder) { @@ -73,9 +62,11 @@ public TestingTrinoAwsProxyServer.Builder filter(TestingTrinoAwsProxyServer.Buil .buildOrThrow(); String headerConfigAsString = "Authorization: some-auth, Content-Type: application/json, Some-Dummy-Header:test,,value"; - HttpCredentialsServlet httpCredentialsServlet = new HttpCredentialsServlet(expectedHeaders); + TestingHttpCredentialsProviderServlet httpCredentialsServlet; + String httpEndpointUri; try { - httpCredentialsServer = createTestingHttpCredentialsServer(httpCredentialsServlet); + httpCredentialsServlet = new TestingHttpCredentialsProviderServlet(expectedHeaders); + httpCredentialsServer = createTestingHttpServer(httpCredentialsServlet); httpCredentialsServer.start(); httpEndpointUri = httpCredentialsServer.getBaseUrl().toString(); } @@ -90,12 +81,12 @@ public TestingTrinoAwsProxyServer.Builder filter(TestingTrinoAwsProxyServer.Buil .withProperty("credentials-provider.http.headers", headerConfigAsString) .withProperty("credentials-provider.http.cache-size", "2") .withProperty("credentials-provider.http.cache-ttl", "10m") - .addModule(binder -> binder.bind(HttpCredentialsServlet.class).toInstance(httpCredentialsServlet)); + .addModule(binder -> binder.bind(TestingHttpCredentialsProviderServlet.class).toInstance(httpCredentialsServlet)); } } @Inject - public TestHttpCredentialsProvider(CredentialsProvider credentialsProvider, HttpCredentialsServlet httpCredentialsServlet, CredentialsProvider httpCredentialsProvider) + public TestHttpCredentialsProvider(CredentialsProvider credentialsProvider, TestingHttpCredentialsProviderServlet httpCredentialsServlet, CredentialsProvider httpCredentialsProvider) { this.credentialsProvider = requireNonNull(credentialsProvider, "credentialsProvider is null"); this.httpCredentialsServlet = requireNonNull(httpCredentialsServlet, "httpCredentialsServlet is null"); @@ -185,75 +176,4 @@ private void testNoCredentialsRetrieved(String emulatedAccessKey, Optional expectedHeaders; - private final AtomicInteger requestCounter; - - private HttpCredentialsServlet(Map expectedHeaders) - { - this.expectedHeaders = ImmutableMap.copyOf(expectedHeaders); - this.requestCounter = new AtomicInteger(); - } - - @Override - protected void doGet(HttpServletRequest request, HttpServletResponse response) - throws IOException - { - requestCounter.addAndGet(1); - for (Map.Entry expectedHeader : expectedHeaders.entrySet()) { - if (!expectedHeader.getValue().equals(request.getHeader(expectedHeader.getKey()))) { - response.setStatus(HttpServletResponse.SC_BAD_REQUEST); - return; - } - } - Optional sessionToken = Optional.ofNullable(request.getParameter("sessionToken")); - String emulatedAccessKey = request.getPathInfo().substring(1); - // The session token in the request is legal if it is either: - // - Not present - // - Matching our test logic: it should be equal to the access-key + "-token" - boolean isLegalSessionToken = sessionToken - .map(presentSessionToken -> "%s-token".formatted(emulatedAccessKey).equals(presentSessionToken)) - .orElse(true); - if (!isLegalSessionToken) { - response.setStatus(HttpServletResponse.SC_NOT_FOUND); - return; - } - switch (emulatedAccessKey) { - case DUMMY_EMULATED_ACCESS_KEY -> { - Credential emulated = new Credential(DUMMY_EMULATED_ACCESS_KEY, DUMMY_EMULATED_SECRET_KEY, sessionToken); - Credential remote = new Credential(DUMMY_REMOTE_ACCESS_KEY, DUMMY_REMOTE_SECRET_KEY); - Credentials credentials = new Credentials(emulated, Optional.of(remote), Optional.empty(), Optional.of(new TestingIdentity("test-username", ImmutableList.of(), "xyzpdq"))); - String jsonCredentials = new ObjectMapperProvider().get().writeValueAsString(credentials); - response.setContentType(APPLICATION_JSON); - response.getWriter().print(jsonCredentials); - } - case "incorrect-response" -> { - response.getWriter().print("incorrect response"); - } - default -> response.setStatus(HttpServletResponse.SC_NOT_FOUND); - } - } - - private int getRequestCount() - { - return requestCounter.get(); - } - - private void resetRequestCount() - { - requestCounter.set(0); - } - } } diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/rest/TestProxiedErrorResponses.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/rest/TestProxiedErrorResponses.java index 4f6dc2c6..20585a24 100644 --- a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/rest/TestProxiedErrorResponses.java +++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/rest/TestProxiedErrorResponses.java @@ -14,15 +14,11 @@ package io.trino.aws.proxy.server.rest; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; import com.google.inject.BindingAnnotation; import com.google.inject.Inject; import com.google.inject.Key; import io.airlift.http.client.HttpStatus; -import io.airlift.http.server.HttpServerConfig; -import io.airlift.http.server.HttpServerInfo; import io.airlift.http.server.testing.TestingHttpServer; -import io.airlift.node.NodeInfo; import io.trino.aws.proxy.server.remote.PathStyleRemoteS3Facade; import io.trino.aws.proxy.server.testing.TestingRemoteS3Facade; import io.trino.aws.proxy.server.testing.TestingTrinoAwsProxyServer.Builder; @@ -43,6 +39,7 @@ import java.util.Optional; import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static io.trino.aws.proxy.server.testing.TestingUtil.createTestingHttpServer; import static io.trino.aws.proxy.server.testing.TestingUtil.getFileFromStorage; import static java.lang.annotation.ElementType.FIELD; import static java.lang.annotation.ElementType.METHOD; @@ -87,7 +84,7 @@ public Builder filter(Builder builder) { TestingHttpServer httpErrorResponseServer; try { - httpErrorResponseServer = createTestingHttpErrorResponseServer(); + httpErrorResponseServer = createTestingHttpServer(new HttpErrorResponseServlet()); httpErrorResponseServer.start(); } catch (Exception e) { @@ -120,15 +117,6 @@ private void assertThrownAwsError(HttpStatus status) exception -> assertThat(exception.awsErrorDetails().errorCode()).isEqualTo(status.reason())); } - private static TestingHttpServer createTestingHttpErrorResponseServer() - throws IOException - { - NodeInfo nodeInfo = new NodeInfo("test"); - HttpServerConfig config = new HttpServerConfig().setHttpPort(0); - HttpServerInfo httpServerInfo = new HttpServerInfo(config, nodeInfo); - return new TestingHttpServer(httpServerInfo, nodeInfo, config, new HttpErrorResponseServlet(), ImmutableMap.of()); - } - private static class HttpErrorResponseServlet extends HttpServlet { diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/testing/TestingHttpCredentialsProviderServlet.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/testing/TestingHttpCredentialsProviderServlet.java new file mode 100644 index 00000000..df8f7db9 --- /dev/null +++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/testing/TestingHttpCredentialsProviderServlet.java @@ -0,0 +1,98 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.aws.proxy.server.testing; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.airlift.json.ObjectMapperProvider; +import io.trino.aws.proxy.spi.credentials.Credential; +import io.trino.aws.proxy.spi.credentials.Credentials; +import jakarta.servlet.http.HttpServlet; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; + +import java.io.IOException; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicInteger; + +import static jakarta.ws.rs.core.MediaType.APPLICATION_JSON; + +public class TestingHttpCredentialsProviderServlet + extends HttpServlet +{ + public static final String DUMMY_EMULATED_ACCESS_KEY = "test-emulated-access-key"; + public static final String DUMMY_EMULATED_SECRET_KEY = "test-emulated-secret-key"; + public static final String DUMMY_REMOTE_ACCESS_KEY = "test-remote-access-key"; + public static final String DUMMY_REMOTE_SECRET_KEY = "test-remote-secret-key"; + + private final Map expectedHeaders; + private final AtomicInteger requestCounter; + + public TestingHttpCredentialsProviderServlet(Map expectedHeaders) + { + this.expectedHeaders = ImmutableMap.copyOf(expectedHeaders); + this.requestCounter = new AtomicInteger(); + } + + @Override + protected void doGet(HttpServletRequest request, HttpServletResponse response) + throws IOException + { + requestCounter.addAndGet(1); + for (Map.Entry expectedHeader : expectedHeaders.entrySet()) { + if (!expectedHeader.getValue().equals(request.getHeader(expectedHeader.getKey()))) { + response.setStatus(HttpServletResponse.SC_BAD_REQUEST); + return; + } + } + + Optional sessionToken = Optional.ofNullable(request.getParameter("sessionToken")); + String emulatedAccessKey = request.getPathInfo().substring(1); + // The session token in the request is legal if it is either: + // - Not present + // - Matching our test logic: it should be equal to the access-key + "-token" + boolean isLegalSessionToken = sessionToken + .map(presentSessionToken -> "%s-token".formatted(emulatedAccessKey).equals(presentSessionToken)) + .orElse(true); + if (!isLegalSessionToken) { + response.setStatus(HttpServletResponse.SC_NOT_FOUND); + return; + } + switch (emulatedAccessKey) { + case DUMMY_EMULATED_ACCESS_KEY -> { + Credential emulated = new Credential(DUMMY_EMULATED_ACCESS_KEY, DUMMY_EMULATED_SECRET_KEY, sessionToken); + Credential remote = new Credential(DUMMY_REMOTE_ACCESS_KEY, DUMMY_REMOTE_SECRET_KEY); + Credentials credentials = new Credentials(emulated, Optional.of(remote), Optional.empty(), Optional.of(new TestingIdentity("test-username", ImmutableList.of(), "xyzpdq"))); + String jsonCredentials = new ObjectMapperProvider().get().writeValueAsString(credentials); + response.setContentType(APPLICATION_JSON); + response.getWriter().print(jsonCredentials); + } + case "incorrect-response" -> { + response.getWriter().print("incorrect response"); + } + default -> response.setStatus(HttpServletResponse.SC_NOT_FOUND); + } + } + + public int getRequestCount() + { + return requestCounter.get(); + } + + public void resetRequestCount() + { + requestCounter.set(0); + } +} diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/testing/TestingUtil.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/testing/TestingUtil.java index d41c5ed1..7d1a1015 100644 --- a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/testing/TestingUtil.java +++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/testing/TestingUtil.java @@ -13,11 +13,17 @@ */ package io.trino.aws.proxy.server.testing; +import com.google.common.collect.ImmutableMap; import com.google.common.hash.Hashing; import com.google.common.io.Resources; import com.google.inject.BindingAnnotation; +import io.airlift.http.server.HttpServerConfig; +import io.airlift.http.server.HttpServerInfo; +import io.airlift.http.server.testing.TestingHttpServer; +import io.airlift.node.NodeInfo; import io.trino.aws.proxy.spi.credentials.Credential; import io.trino.aws.proxy.spi.credentials.Credentials; +import jakarta.servlet.Servlet; import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.s3.S3ClientBuilder; @@ -148,4 +154,13 @@ public static String sha256(String content) { return Hashing.sha256().newHasher().putString(content, StandardCharsets.UTF_8).hash().toString(); } + + public static TestingHttpServer createTestingHttpServer(Servlet servlet) + throws IOException + { + NodeInfo nodeInfo = new NodeInfo("test"); + HttpServerConfig config = new HttpServerConfig().setHttpPort(0); + HttpServerInfo httpServerInfo = new HttpServerInfo(config, nodeInfo); + return new TestingHttpServer(httpServerInfo, nodeInfo, config, servlet, ImmutableMap.of()); + } }