Skip to content

Commit

Permalink
Support max bytes quota for requests
Browse files Browse the repository at this point in the history
`s3proxy.sts.request.quota.byte-qty`

Closes #95
  • Loading branch information
Randgalt committed Aug 14, 2024
1 parent 7f7618d commit 46f239d
Show file tree
Hide file tree
Showing 7 changed files with 286 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import io.airlift.configuration.Config;
import io.airlift.configuration.ConfigDescription;
import io.airlift.units.DataSize;
import io.airlift.units.Duration;
import io.airlift.units.MinDuration;
import jakarta.validation.constraints.Min;
Expand All @@ -33,6 +34,7 @@ public class TrinoAwsProxyConfig
private Duration presignedUrlsDuration = new Duration(15, TimeUnit.MINUTES);
private boolean generatePresignedUrlsOnHead = true;
private int requestLoggerSavedQty = 10000;
private Optional<DataSize> maxPayloadSize = Optional.empty();

@Config("aws.proxy.s3.hostname")
@ConfigDescription("Hostname to use for S3 REST operations, virtual-host style addressing is only supported if this is set")
Expand Down Expand Up @@ -116,4 +118,18 @@ public TrinoAwsProxyConfig setRequestLoggerSavedQty(int requestLoggerSavedQty)
this.requestLoggerSavedQty = requestLoggerSavedQty;
return this;
}

@NotNull
public Optional<DataSize> getMaxPayloadSize()
{
return maxPayloadSize;
}

@Config("aws.proxy.request.payload.max-size")
@ConfigDescription("Max request/response payload size, optional")
public TrinoAwsProxyConfig setMaxPayloadSize(DataSize maxPayloadSize)
{
this.maxPayloadSize = Optional.of(requireNonNull(maxPayloadSize, "requestByteQuota is null"));
return this;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import io.trino.aws.proxy.server.credentials.file.FileBasedCredentialsModule;
import io.trino.aws.proxy.server.credentials.http.HttpCredentialsModule;
import io.trino.aws.proxy.server.remote.RemoteS3Module;
import io.trino.aws.proxy.server.rest.LimitStreamController;
import io.trino.aws.proxy.server.rest.RequestFilter;
import io.trino.aws.proxy.server.rest.RequestLoggerController;
import io.trino.aws.proxy.server.rest.S3PresignController;
Expand Down Expand Up @@ -91,6 +92,7 @@ protected void setup(Binder binder)

binder.bind(CredentialsController.class).in(Scopes.SINGLETON);
binder.bind(RequestLoggerController.class).in(Scopes.SINGLETON);
binder.bind(LimitStreamController.class).in(Scopes.SINGLETON);

// TODO config, etc.
httpClientBinder(binder).bindHttpClient("ProxyClient", ForProxyClient.class);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.aws.proxy.server.rest;

import com.google.common.io.CountingInputStream;
import com.google.common.io.CountingOutputStream;
import com.google.inject.Inject;
import io.airlift.units.DataSize;
import io.trino.aws.proxy.server.TrinoAwsProxyConfig;
import jakarta.ws.rs.WebApplicationException;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.Optional;

import static jakarta.ws.rs.core.Response.Status.REQUEST_ENTITY_TOO_LARGE;

public class LimitStreamController
{
private final Optional<DataSize> quota;

@Inject
public LimitStreamController(TrinoAwsProxyConfig trinoAwsProxyConfig)
{
quota = trinoAwsProxyConfig.getMaxPayloadSize();
}

public InputStream wrap(InputStream inputStream)
{
return quota.map(q -> internalWrap(inputStream, q.toBytes())).orElse(inputStream);
}

private static InputStream internalWrap(InputStream inputStream, long quota)
{
CountingInputStream delegate = new CountingInputStream(inputStream);
return new InputStream()
{
@Override
public int read()
throws IOException
{
return validate(delegate.read());
}

@Override
public int read(byte[] b, int off, int len)
throws IOException
{
return validate(delegate.read(b, off, len));
}

@Override
public long skip(long n)
throws IOException
{
return validate(delegate.skip(n));
}

@Override
public void mark(int readlimit)
{
delegate.mark(readlimit);
validate();
}

@Override
public void reset()
throws IOException
{
delegate.reset();
validate();
}

@Override
public boolean markSupported()
{
return validate(delegate.markSupported());
}

@Override
public void close()
throws IOException
{
delegate.close();
}

private void validate()
{
validate(null);
}

private <T> T validate(T value)
{
if (delegate.getCount() > quota) {
throw new WebApplicationException(REQUEST_ENTITY_TOO_LARGE);
}
return value;
}
};
}

public OutputStream wrap(OutputStream outputStream)
{
return quota.map(q -> internalWrap(outputStream, q.toBytes())).orElse(outputStream);
}

private OutputStream internalWrap(OutputStream outputStream, long quota)
{
CountingOutputStream delegate = new CountingOutputStream(outputStream);

return new OutputStream()
{
@Override
public void write(byte[] b)
throws IOException
{
delegate.write(b);
}

@Override
public void write(byte[] b, int off, int len)
throws IOException
{
delegate.write(b, off, len);
validate();
}

@Override
public void flush()
throws IOException
{
delegate.flush();
}

@Override
public void close()
throws IOException
{
delegate.close();
}

@Override
public void write(int b)
throws IOException
{
delegate.write(b);
validate();
}

private void validate()
{
if (delegate.getCount() > quota) {
throw new WebApplicationException(REQUEST_ENTITY_TOO_LARGE);
}
}
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,14 @@ class StreamingResponseHandler
private final Map<String, URI> presignedUrls;
private final RequestLoggingSession requestLoggingSession;
private final AtomicBoolean hasBeenResumed = new AtomicBoolean(false);
private final LimitStreamController limitStreamController;

StreamingResponseHandler(AsyncResponse asyncResponse, Map<String, URI> presignedUrls, RequestLoggingSession requestLoggingSession)
StreamingResponseHandler(AsyncResponse asyncResponse, Map<String, URI> presignedUrls, RequestLoggingSession requestLoggingSession, LimitStreamController limitStreamController)
{
this.asyncResponse = requireNonNull(asyncResponse, "asyncResponse is null");
this.presignedUrls = ImmutableMap.copyOf(presignedUrls);
this.requestLoggingSession = requireNonNull(requestLoggingSession, "requestLoggingSession is null");
this.limitStreamController = requireNonNull(limitStreamController, "quotaStreamController is null");
}

@Override
Expand All @@ -71,7 +73,7 @@ public Void handle(Request request, Response response)
// HttpClient/Jersey timeouts control behavior. The configured HttpClient idle timeout
// controls whether the InputStream will time out. Jersey configuration controls
// OutputStream and general request timeouts.
inputStream.transferTo(output);
inputStream.transferTo(limitStreamController.wrap(output));
output.flush();
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ public class TrinoS3ProxyClient
private final RemoteS3Facade remoteS3Facade;
private final S3SecurityController s3SecurityController;
private final S3PresignController s3PresignController;
private final LimitStreamController limitStreamController;
private final ExecutorService executorService = Executors.newVirtualThreadPerTaskExecutor();
private final boolean generatePresignedUrlsOnHead;

Expand All @@ -82,13 +83,15 @@ public TrinoS3ProxyClient(
RemoteS3Facade remoteS3Facade,
S3SecurityController s3SecurityController,
TrinoAwsProxyConfig trinoAwsProxyConfig,
S3PresignController s3PresignController)
S3PresignController s3PresignController,
LimitStreamController limitStreamController)
{
this.httpClient = requireNonNull(httpClient, "httpClient is null");
this.signingController = requireNonNull(signingController, "signingController is null");
this.remoteS3Facade = requireNonNull(remoteS3Facade, "objectStore is null");
this.s3SecurityController = requireNonNull(s3SecurityController, "securityController is null");
this.s3PresignController = requireNonNull(s3PresignController, "presignController is null");
this.limitStreamController = requireNonNull(limitStreamController, "quotaStreamController is null");

generatePresignedUrlsOnHead = trinoAwsProxyConfig.isGeneratePresignedUrlsOnHead();
}
Expand Down Expand Up @@ -172,7 +175,7 @@ public void proxyRequest(SigningMetadata signingMetadata, ParsedS3Request reques
Request remoteRequest = remoteRequestBuilder.build();

executorService.submit(() -> {
StreamingResponseHandler responseHandler = new StreamingResponseHandler(asyncResponse, presignedUrls, requestLoggingSession);
StreamingResponseHandler responseHandler = new StreamingResponseHandler(asyncResponse, presignedUrls, requestLoggingSession, limitStreamController);
try {
httpClient.execute(remoteRequest, responseHandler);
}
Expand All @@ -185,13 +188,14 @@ public void proxyRequest(SigningMetadata signingMetadata, ParsedS3Request reques
private Optional<InputStream> contentInputStream(RequestContent requestContent, SigningMetadata signingMetadata)
{
return switch (requestContent.contentType()) {
case AWS_CHUNKED, AWS_CHUNKED_IN_W3C_CHUNKED -> requestContent.inputStream().map(inputStream -> new AwsChunkedInputStream(inputStream, signingMetadata.requiredSigningContext().chunkSigningSession(), requestContent.contentLength().orElseThrow()));
case AWS_CHUNKED, AWS_CHUNKED_IN_W3C_CHUNKED -> requestContent.inputStream()
.map(inputStream -> new AwsChunkedInputStream(limitStreamController.wrap(inputStream), signingMetadata.requiredSigningContext().chunkSigningSession(), requestContent.contentLength().orElseThrow()));

case STANDARD, W3C_CHUNKED -> requestContent.inputStream().map(inputStream -> {
SigningContext signingContext = signingMetadata.requiredSigningContext();
return signingContext.contentHash()
.filter(contentHash -> !contentHash.startsWith("STREAMING-") && !contentHash.startsWith("UNSIGNED-"))
.map(contentHash -> (InputStream) new HashCheckInputStream(inputStream, contentHash, requestContent.contentLength()))
.map(contentHash -> (InputStream) new HashCheckInputStream(limitStreamController.wrap(inputStream), contentHash, requestContent.contentLength()))
.orElse(inputStream);
});

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.aws.proxy.server;

import com.google.common.collect.ImmutableList;
import com.google.inject.Inject;
import io.trino.aws.proxy.server.testing.TestingTrinoAwsProxyServer;
import io.trino.aws.proxy.server.testing.containers.S3Container.ForS3Container;
import io.trino.aws.proxy.server.testing.harness.TrinoAwsProxyTest;
import io.trino.aws.proxy.server.testing.harness.TrinoAwsProxyTestCommonModules.WithConfiguredBuckets;
import org.junit.jupiter.api.Test;
import software.amazon.awssdk.core.exception.SdkServiceException;
import software.amazon.awssdk.core.sync.RequestBody;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.model.GetObjectRequest;
import software.amazon.awssdk.services.s3.model.PutObjectRequest;
import software.amazon.awssdk.services.s3.model.S3Exception;

import java.util.List;

import static io.airlift.http.client.HttpStatus.REQUEST_ENTITY_TOO_LARGE;
import static java.util.Objects.requireNonNull;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.assertj.core.api.InstanceOfAssertFactories.type;

@TrinoAwsProxyTest(filters = TestMaxPayloadSize.Filter.class)
public class TestMaxPayloadSize
{
private final S3Client s3Client;
private final S3Client storageClient;
private final List<String> configuredBuckets;

public static class Filter
extends WithConfiguredBuckets
{
@Override
public TestingTrinoAwsProxyServer.Builder filter(TestingTrinoAwsProxyServer.Builder builder)
{
return super.filter(builder).withProperty("aws.proxy.request.payload.max-size", "1B");
}
}

@Inject
public TestMaxPayloadSize(S3Client s3Client, @ForS3Container S3Client storageClient, @ForS3Container List<String> configuredBuckets)
{
this.s3Client = requireNonNull(s3Client, "s3Client is null");
this.storageClient = requireNonNull(storageClient, "storageClient is null");
this.configuredBuckets = ImmutableList.copyOf(configuredBuckets);
}

@Test
public void testLimitOnPut()
{
PutObjectRequest putObjectRequest = PutObjectRequest.builder().bucket(configuredBuckets.getFirst()).key("put").build();
assertThatThrownBy(() -> s3Client.putObject(putObjectRequest, RequestBody.fromString("this is too big")))
.asInstanceOf(type(S3Exception.class))
.extracting(SdkServiceException::statusCode)
.isEqualTo(REQUEST_ENTITY_TOO_LARGE.code());
}

@Test
public void testLimitOnGet()
{
PutObjectRequest putObjectRequest = PutObjectRequest.builder().bucket(configuredBuckets.getFirst()).key("get").build();
storageClient.putObject(putObjectRequest, RequestBody.fromString("this is too big for get"));

GetObjectRequest getObjectRequest = GetObjectRequest.builder().bucket(configuredBuckets.getFirst()).key("get").build();
assertThatThrownBy(() -> s3Client.getObject(getObjectRequest).readAllBytes())
.asInstanceOf(type(S3Exception.class))
.extracting(SdkServiceException::statusCode)
.isEqualTo(REQUEST_ENTITY_TOO_LARGE.code());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import com.google.inject.Inject;
import io.airlift.http.client.HttpClient;
import io.airlift.http.client.Request;
import io.trino.aws.proxy.server.TrinoAwsProxyConfig;
import io.trino.aws.proxy.server.rest.TestHangingStreamingResponseHandler.ForTimeout;
import jakarta.annotation.PreDestroy;
import jakarta.ws.rs.GET;
Expand Down Expand Up @@ -63,7 +64,7 @@ public void callHangingRequest(@Context UriInfo uriInfo, @Suspended AsyncRespons
{
// simulate calling a remote request and streaming the result while the remote server hangs
Request request = prepareGet().setUri(uriInfo.getBaseUri().resolve("hang")).build();
httpClient.execute(request, new StreamingResponseHandler(asyncResponse, ImmutableMap.of(), () -> {}));
httpClient.execute(request, new StreamingResponseHandler(asyncResponse, ImmutableMap.of(), () -> {}, new LimitStreamController(new TrinoAwsProxyConfig())));
}

@GET
Expand Down

0 comments on commit 46f239d

Please sign in to comment.