Add an optional maximum payload size (aws.proxy.request.payload.max-size) that is enforced on
request bodies read by the proxy and on response bodies streamed back to the client.

NOTE(review): this patch was reconstructed from a copy whose line breaks were lost. Indentation
and blank context lines are inferred from the hunk line counts, the "index" hashes are carried
over unchanged, and the raw "Map presignedUrls" context lines may be missing their type
arguments -- verify against the target tree before applying.

diff --git a/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/TrinoAwsProxyConfig.java b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/TrinoAwsProxyConfig.java
index 4832cc37..f9e7be2e 100644
--- a/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/TrinoAwsProxyConfig.java
+++ b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/TrinoAwsProxyConfig.java
@@ -15,6 +15,7 @@
 
 import io.airlift.configuration.Config;
 import io.airlift.configuration.ConfigDescription;
+import io.airlift.units.DataSize;
 import io.airlift.units.Duration;
 import io.airlift.units.MinDuration;
 import jakarta.validation.constraints.Min;
@@ -33,6 +34,7 @@ public class TrinoAwsProxyConfig
     private Duration presignedUrlsDuration = new Duration(15, TimeUnit.MINUTES);
     private boolean generatePresignedUrlsOnHead = true;
     private int requestLoggerSavedQty = 10000;
+    private Optional<DataSize> maxPayloadSize = Optional.empty();
 
     @Config("aws.proxy.s3.hostname")
     @ConfigDescription("Hostname to use for S3 REST operations, virtual-host style addressing is only supported if this is set")
@@ -116,4 +118,18 @@ public TrinoAwsProxyConfig setRequestLoggerSavedQty(int requestLoggerSavedQty)
         this.requestLoggerSavedQty = requestLoggerSavedQty;
         return this;
     }
+
+    @NotNull
+    public Optional<DataSize> getMaxPayloadSize()
+    {
+        return maxPayloadSize;
+    }
+
+    @Config("aws.proxy.request.payload.max-size")
+    @ConfigDescription("Max request/response payload size, optional")
+    public TrinoAwsProxyConfig setMaxPayloadSize(DataSize maxPayloadSize)
+    {
+        this.maxPayloadSize = Optional.of(requireNonNull(maxPayloadSize, "maxPayloadSize is null"));
+        return this;
+    }
 }
diff --git a/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/TrinoAwsProxyServerModule.java b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/TrinoAwsProxyServerModule.java
index 59bfd30e..622a15d1 100644
--- a/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/TrinoAwsProxyServerModule.java
+++ b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/TrinoAwsProxyServerModule.java
@@ -34,6 +34,7 @@
 import io.trino.aws.proxy.server.credentials.file.FileBasedCredentialsModule;
 import io.trino.aws.proxy.server.credentials.http.HttpCredentialsModule;
 import io.trino.aws.proxy.server.remote.RemoteS3Module;
+import io.trino.aws.proxy.server.rest.LimitStreamController;
 import io.trino.aws.proxy.server.rest.RequestFilter;
 import io.trino.aws.proxy.server.rest.RequestLoggerController;
 import io.trino.aws.proxy.server.rest.S3PresignController;
@@ -91,6 +92,7 @@ protected void setup(Binder binder)
 
         binder.bind(CredentialsController.class).in(Scopes.SINGLETON);
         binder.bind(RequestLoggerController.class).in(Scopes.SINGLETON);
+        binder.bind(LimitStreamController.class).in(Scopes.SINGLETON);
 
         // TODO config, etc.
         httpClientBinder(binder).bindHttpClient("ProxyClient", ForProxyClient.class);
diff --git a/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/LimitStreamController.java b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/LimitStreamController.java
new file mode 100644
index 00000000..9db751fd
--- /dev/null
+++ b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/LimitStreamController.java
@@ -0,0 +1,180 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.aws.proxy.server.rest;
+
+import com.google.common.io.CountingInputStream;
+import com.google.common.io.CountingOutputStream;
+import com.google.inject.Inject;
+import io.airlift.units.DataSize;
+import io.trino.aws.proxy.server.TrinoAwsProxyConfig;
+import jakarta.ws.rs.WebApplicationException;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.util.Optional;
+
+import static jakarta.ws.rs.core.Response.Status.REQUEST_ENTITY_TOO_LARGE;
+
+/**
+ * Applies the optional {@code aws.proxy.request.payload.max-size} limit to payload streams.
+ * When no limit is configured the {@code wrap} methods return their argument unchanged.
+ */
+public class LimitStreamController
+{
+    private final Optional<DataSize> quota;
+
+    @Inject
+    public LimitStreamController(TrinoAwsProxyConfig trinoAwsProxyConfig)
+    {
+        quota = trinoAwsProxyConfig.getMaxPayloadSize();
+    }
+
+    /**
+     * Limits how many bytes may be consumed from {@code inputStream}.
+     *
+     * @return a stream that throws a {@link WebApplicationException} with status 413 once more
+     *         bytes than the configured limit have been read or skipped
+     */
+    public InputStream wrap(InputStream inputStream)
+    {
+        return quota.map(q -> internalWrap(inputStream, q.toBytes())).orElse(inputStream);
+    }
+
+    private static InputStream internalWrap(InputStream inputStream, long quota)
+    {
+        CountingInputStream delegate = new CountingInputStream(inputStream);
+        return new InputStream()
+        {
+            @Override
+            public int read()
+                    throws IOException
+            {
+                int result = delegate.read();
+                validate();
+                return result;
+            }
+
+            @Override
+            public int read(byte[] b, int off, int len)
+                    throws IOException
+            {
+                int result = delegate.read(b, off, len);
+                validate();
+                return result;
+            }
+
+            @Override
+            public long skip(long n)
+                    throws IOException
+            {
+                long skipped = delegate.skip(n);
+                validate();
+                return skipped;
+            }
+
+            // mark/reset/markSupported never increase the byte count, so they are plain delegations
+            @Override
+            public void mark(int readlimit)
+            {
+                delegate.mark(readlimit);
+            }
+
+            @Override
+            public void reset()
+                    throws IOException
+            {
+                delegate.reset();
+            }
+
+            @Override
+            public boolean markSupported()
+            {
+                return delegate.markSupported();
+            }
+
+            @Override
+            public void close()
+                    throws IOException
+            {
+                delegate.close();
+            }
+
+            private void validate()
+            {
+                if (delegate.getCount() > quota) {
+                    throw new WebApplicationException(REQUEST_ENTITY_TOO_LARGE);
+                }
+            }
+        };
+    }
+
+    /**
+     * Limits how many bytes may be written to {@code outputStream}.
+     *
+     * @return a stream that throws a {@link WebApplicationException} with status 413 instead of
+     *         forwarding a write that would exceed the configured limit
+     */
+    public OutputStream wrap(OutputStream outputStream)
+    {
+        return quota.map(q -> internalWrap(outputStream, q.toBytes())).orElse(outputStream);
+    }
+
+    private static OutputStream internalWrap(OutputStream outputStream, long quota)
+    {
+        CountingOutputStream delegate = new CountingOutputStream(outputStream);
+
+        // write(byte[]) is inherited from OutputStream and routes through write(byte[], int, int)
+        return new OutputStream()
+        {
+            @Override
+            public void write(int b)
+                    throws IOException
+            {
+                ensureWithinQuota(1);
+                delegate.write(b);
+            }
+
+            @Override
+            public void write(byte[] b, int off, int len)
+                    throws IOException
+            {
+                ensureWithinQuota(len);
+                delegate.write(b, off, len);
+            }
+
+            @Override
+            public void flush()
+                    throws IOException
+            {
+                delegate.flush();
+            }
+
+            @Override
+            public void close()
+                    throws IOException
+            {
+                delegate.close();
+            }
+
+            // checked before writing so that no byte beyond the limit reaches the underlying stream
+            private void ensureWithinQuota(long additional)
+            {
+                if (additional > quota - delegate.getCount()) {
+                    throw new WebApplicationException(REQUEST_ENTITY_TOO_LARGE);
+                }
+            }
+        };
+    }
+}
diff --git a/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/StreamingResponseHandler.java b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/StreamingResponseHandler.java
index 09f48e49..2a6571f6 100644
--- a/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/StreamingResponseHandler.java
+++ b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/StreamingResponseHandler.java
@@ -42,12 +42,14 @@ class StreamingResponseHandler
     private final Map presignedUrls;
     private final RequestLoggingSession requestLoggingSession;
     private final AtomicBoolean hasBeenResumed = new AtomicBoolean(false);
+    private final LimitStreamController limitStreamController;
 
-    StreamingResponseHandler(AsyncResponse asyncResponse, Map presignedUrls, RequestLoggingSession requestLoggingSession)
+    StreamingResponseHandler(AsyncResponse asyncResponse, Map presignedUrls, RequestLoggingSession requestLoggingSession, LimitStreamController limitStreamController)
     {
         this.asyncResponse = requireNonNull(asyncResponse, "asyncResponse is null");
         this.presignedUrls = ImmutableMap.copyOf(presignedUrls);
         this.requestLoggingSession = requireNonNull(requestLoggingSession, "requestLoggingSession is null");
+        this.limitStreamController = requireNonNull(limitStreamController, "limitStreamController is null");
     }
 
     @Override
@@ -71,7 +73,7 @@ public Void handle(Request request, Response response)
             // HttpClient/Jersey timeouts control behavior. The configured HttpClient idle timeout
             // controls whether the InputStream will time out. Jersey configuration controls
             // OutputStream and general request timeouts.
-            inputStream.transferTo(output);
+            inputStream.transferTo(limitStreamController.wrap(output));
             output.flush();
         };
 
diff --git a/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/TrinoS3ProxyClient.java b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/TrinoS3ProxyClient.java
index 5250da8c..531068bc 100644
--- a/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/TrinoS3ProxyClient.java
+++ b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/TrinoS3ProxyClient.java
@@ -67,6 +67,7 @@ public class TrinoS3ProxyClient
     private final RemoteS3Facade remoteS3Facade;
     private final S3SecurityController s3SecurityController;
     private final S3PresignController s3PresignController;
+    private final LimitStreamController limitStreamController;
     private final ExecutorService executorService = Executors.newVirtualThreadPerTaskExecutor();
     private final boolean generatePresignedUrlsOnHead;
 
@@ -82,13 +83,15 @@ public TrinoS3ProxyClient(
             RemoteS3Facade remoteS3Facade,
             S3SecurityController s3SecurityController,
             TrinoAwsProxyConfig trinoAwsProxyConfig,
-            S3PresignController s3PresignController)
+            S3PresignController s3PresignController,
+            LimitStreamController limitStreamController)
     {
         this.httpClient = requireNonNull(httpClient, "httpClient is null");
         this.signingController = requireNonNull(signingController, "signingController is null");
         this.remoteS3Facade = requireNonNull(remoteS3Facade, "objectStore is null");
         this.s3SecurityController = requireNonNull(s3SecurityController, "securityController is null");
         this.s3PresignController = requireNonNull(s3PresignController, "presignController is null");
+        this.limitStreamController = requireNonNull(limitStreamController, "limitStreamController is null");
 
         generatePresignedUrlsOnHead = trinoAwsProxyConfig.isGeneratePresignedUrlsOnHead();
     }
@@ -172,7 +175,7 @@ public void proxyRequest(SigningMetadata signingMetadata, ParsedS3Request reques
         Request remoteRequest = remoteRequestBuilder.build();
 
         executorService.submit(() -> {
-            StreamingResponseHandler responseHandler = new StreamingResponseHandler(asyncResponse, presignedUrls, requestLoggingSession);
+            StreamingResponseHandler responseHandler = new StreamingResponseHandler(asyncResponse, presignedUrls, requestLoggingSession, limitStreamController);
             try {
                 httpClient.execute(remoteRequest, responseHandler);
             }
@@ -185,13 +188,14 @@ public void proxyRequest(SigningMetadata signingMetadata, ParsedS3Request reques
     private Optional<InputStream> contentInputStream(RequestContent requestContent, SigningMetadata signingMetadata)
     {
         return switch (requestContent.contentType()) {
-            case AWS_CHUNKED, AWS_CHUNKED_IN_W3C_CHUNKED -> requestContent.inputStream().map(inputStream -> new AwsChunkedInputStream(inputStream, signingMetadata.requiredSigningContext().chunkSigningSession(), requestContent.contentLength().orElseThrow()));
+            case AWS_CHUNKED, AWS_CHUNKED_IN_W3C_CHUNKED -> requestContent.inputStream()
+                    .map(inputStream -> new AwsChunkedInputStream(limitStreamController.wrap(inputStream), signingMetadata.requiredSigningContext().chunkSigningSession(), requestContent.contentLength().orElseThrow()));
 
             case STANDARD, W3C_CHUNKED -> requestContent.inputStream().map(inputStream -> {
                 SigningContext signingContext = signingMetadata.requiredSigningContext();
                 return signingContext.contentHash()
                         .filter(contentHash -> !contentHash.startsWith("STREAMING-") && !contentHash.startsWith("UNSIGNED-"))
-                        .map(contentHash -> (InputStream) new HashCheckInputStream(inputStream, contentHash, requestContent.contentLength()))
-                        .orElse(inputStream);
+                        .map(contentHash -> (InputStream) new HashCheckInputStream(limitStreamController.wrap(inputStream), contentHash, requestContent.contentLength()))
+                        .orElseGet(() -> limitStreamController.wrap(inputStream));
             });
 
diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestMaxPayloadSize.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestMaxPayloadSize.java
new file mode 100644
index 00000000..56958cd7
--- /dev/null
+++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestMaxPayloadSize.java
@@ -0,0 +1,88 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.aws.proxy.server;
+
+import com.google.common.collect.ImmutableList;
+import com.google.inject.Inject;
+import io.trino.aws.proxy.server.testing.TestingTrinoAwsProxyServer;
+import io.trino.aws.proxy.server.testing.containers.S3Container.ForS3Container;
+import io.trino.aws.proxy.server.testing.harness.TrinoAwsProxyTest;
+import io.trino.aws.proxy.server.testing.harness.TrinoAwsProxyTestCommonModules.WithConfiguredBuckets;
+import org.junit.jupiter.api.Test;
+import software.amazon.awssdk.core.exception.SdkServiceException;
+import software.amazon.awssdk.core.sync.RequestBody;
+import software.amazon.awssdk.services.s3.S3Client;
+import software.amazon.awssdk.services.s3.model.GetObjectRequest;
+import software.amazon.awssdk.services.s3.model.PutObjectRequest;
+import software.amazon.awssdk.services.s3.model.S3Exception;
+
+import java.util.List;
+
+import static io.airlift.http.client.HttpStatus.REQUEST_ENTITY_TOO_LARGE;
+import static java.util.Objects.requireNonNull;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+import static org.assertj.core.api.InstanceOfAssertFactories.type;
+
+/**
+ * Verifies that payloads larger than {@code aws.proxy.request.payload.max-size} are rejected
+ * with status 413 for both uploads and downloads.
+ */
+@TrinoAwsProxyTest(filters = TestMaxPayloadSize.Filter.class)
+public class TestMaxPayloadSize
+{
+    private final S3Client s3Client;
+    private final S3Client storageClient;
+    private final List<String> configuredBuckets;
+
+    public static class Filter
+            extends WithConfiguredBuckets
+    {
+        @Override
+        public TestingTrinoAwsProxyServer.Builder filter(TestingTrinoAwsProxyServer.Builder builder)
+        {
+            return super.filter(builder).withProperty("aws.proxy.request.payload.max-size", "1B");
+        }
+    }
+
+    @Inject
+    public TestMaxPayloadSize(S3Client s3Client, @ForS3Container S3Client storageClient, @ForS3Container List<String> configuredBuckets)
+    {
+        this.s3Client = requireNonNull(s3Client, "s3Client is null");
+        this.storageClient = requireNonNull(storageClient, "storageClient is null");
+        this.configuredBuckets = ImmutableList.copyOf(configuredBuckets);
+    }
+
+    @Test
+    public void testLimitOnPut()
+    {
+        PutObjectRequest putObjectRequest = PutObjectRequest.builder().bucket(configuredBuckets.getFirst()).key("put").build();
+        assertThatThrownBy(() -> s3Client.putObject(putObjectRequest, RequestBody.fromString("this is too big")))
+                .asInstanceOf(type(S3Exception.class))
+                .extracting(SdkServiceException::statusCode)
+                .isEqualTo(REQUEST_ENTITY_TOO_LARGE.code());
+    }
+
+    @Test
+    public void testLimitOnGet()
+    {
+        PutObjectRequest putObjectRequest = PutObjectRequest.builder().bucket(configuredBuckets.getFirst()).key("get").build();
+        storageClient.putObject(putObjectRequest, RequestBody.fromString("this is too big for get"));
+
+        GetObjectRequest getObjectRequest = GetObjectRequest.builder().bucket(configuredBuckets.getFirst()).key("get").build();
+        assertThatThrownBy(() -> s3Client.getObject(getObjectRequest).readAllBytes())
+                .asInstanceOf(type(S3Exception.class))
+                .extracting(SdkServiceException::statusCode)
+                .isEqualTo(REQUEST_ENTITY_TOO_LARGE.code());
+    }
+}
diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/rest/HangingResource.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/rest/HangingResource.java
index c6839d5b..eb7b3cf4 100644
--- a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/rest/HangingResource.java
+++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/rest/HangingResource.java
@@ -17,6 +17,7 @@
 import com.google.inject.Inject;
 import io.airlift.http.client.HttpClient;
 import io.airlift.http.client.Request;
+import io.trino.aws.proxy.server.TrinoAwsProxyConfig;
 import io.trino.aws.proxy.server.rest.TestHangingStreamingResponseHandler.ForTimeout;
 import jakarta.annotation.PreDestroy;
 import jakarta.ws.rs.GET;
@@ -63,7 +64,7 @@ public void callHangingRequest(@Context UriInfo uriInfo, @Suspended AsyncRespons
     {
         // simulate calling a remote request and streaming the result while the remote server hangs
         Request request = prepareGet().setUri(uriInfo.getBaseUri().resolve("hang")).build();
-        httpClient.execute(request, new StreamingResponseHandler(asyncResponse, ImmutableMap.of(), () -> {}));
+        httpClient.execute(request, new StreamingResponseHandler(asyncResponse, ImmutableMap.of(), () -> {}, new LimitStreamController(new TrinoAwsProxyConfig())));
     }
 
     @GET