Skip to content

Commit

Permalink
Add support for validating JDBC connections
Browse files Browse the repository at this point in the history
  • Loading branch information
xsgao-github authored and wendigo committed Jan 10, 2025
1 parent fb075d0 commit 2e55865
Show file tree
Hide file tree
Showing 10 changed files with 366 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ enum SslVerificationMode
public static final ConnectionProperty<String, LoggingLevel> HTTP_LOGGING_LEVEL = new HttpLoggingLevel();
public static final ConnectionProperty<String, Map<String, String>> RESOURCE_ESTIMATES = new ResourceEstimates();
public static final ConnectionProperty<String, List<String>> SQL_PATH = new SqlPath();
public static final ConnectionProperty<String, Boolean> VALIDATE_CONNECTION = new ValidateConnection();

private static final Set<ConnectionProperty<?, ?>> ALL_PROPERTIES = ImmutableSet.<ConnectionProperty<?, ?>>builder()
// Keep sorted
Expand Down Expand Up @@ -172,6 +173,7 @@ enum SslVerificationMode
.add(TIMEZONE)
.add(TRACE_TOKEN)
.add(USER)
.add(VALIDATE_CONNECTION)
.build();

private static final Map<String, ConnectionProperty<?, ?>> KEY_LOOKUP = unmodifiableMap(ALL_PROPERTIES.stream()
Expand Down Expand Up @@ -590,6 +592,15 @@ public KerberosRemoteServiceName()
}
}

private static class ValidateConnection
extends AbstractConnectionProperty<String, Boolean>
{
public ValidateConnection()
{
super(PropertyName.VALIDATE_CONNECTION, NOT_REQUIRED, ALLOWED, BOOLEAN_CONVERTER);
}
}

private static Predicate<Properties> isKerberosEnabled()
{
return properties -> KERBEROS_REMOTE_SERVICE_NAME.getValue(properties).isPresent();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ public enum PropertyName
TIMEOUT("timeout"),
TIMEZONE("timezone"),
TRACE_TOKEN("traceToken"),
USER("user");
USER("user"),
VALIDATE_CONNECTION("validateConnection");

private final String key;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@
import static io.trino.client.uri.ConnectionProperties.TIMEZONE;
import static io.trino.client.uri.ConnectionProperties.TRACE_TOKEN;
import static io.trino.client.uri.ConnectionProperties.USER;
import static io.trino.client.uri.ConnectionProperties.VALIDATE_CONNECTION;
import static io.trino.client.uri.LoggingLevel.NONE;
import static java.lang.String.CASE_INSENSITIVE_ORDER;
import static java.lang.String.format;
Expand Down Expand Up @@ -461,6 +462,11 @@ public LoggingLevel getHttpLoggingLevel()
return resolveWithDefault(HTTP_LOGGING_LEVEL, NONE);
}

public boolean isValidateConnection()
{
return resolveWithDefault(VALIDATE_CONNECTION, false);
}

private Map<String, String> getResourceEstimates()
{
return resolveWithDefault(RESOURCE_ESTIMATES, ImmutableMap.of());
Expand Down Expand Up @@ -1047,6 +1053,11 @@ public Builder setPath(List<String> path)
return setProperty(SQL_PATH, requireNonNull(path, "path is null"));
}

public Builder setValidateConnection(boolean value)
{
return setProperty(VALIDATE_CONNECTION, value);
}

<V, T> Builder setProperty(ConnectionProperty<V, T> connectionProperty, T value)
{
properties.put(connectionProperty.getKey(), connectionProperty.encodeValue(value));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,17 @@ public void testDefaultPorts()
assertThat(secureUri.getHttpUri()).isEqualTo(URI.create("https://localhost:443"));
}

@Test
public void testValidateConnection()
{
TrinoUri uri = createTrinoUri("trino://localhost:8080");
assertThat(uri.isValidateConnection()).isFalse();
uri = createTrinoUri("trino://localhost:8080?validateConnection=true");
assertThat(uri.isValidateConnection()).isTrue();
uri = createTrinoUri("trino://localhost:8080?validateConnection=false");
assertThat(uri.isValidateConnection()).isFalse();
}

private static boolean isBuilderHelperMethod(String name)
{
if (name.equals("setSslVerificationNone")) {
Expand Down
106 changes: 105 additions & 1 deletion client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoConnection.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,13 @@
import io.trino.client.StatementClient;
import jakarta.annotation.Nullable;
import okhttp3.Call;
import okhttp3.HttpUrl;
import okhttp3.Request;
import okhttp3.Response;

import java.io.IOException;
import java.io.InterruptedIOException;
import java.net.ProtocolException;
import java.net.URI;
import java.nio.charset.CharsetEncoder;
import java.sql.Array;
Expand Down Expand Up @@ -56,6 +62,7 @@
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
Expand All @@ -66,13 +73,18 @@
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Strings.nullToEmpty;
import static com.google.common.base.Throwables.getCausalChain;
import static com.google.common.collect.Maps.fromProperties;
import static io.airlift.units.Duration.nanosSince;
import static io.trino.client.StatementClientFactory.newStatementClient;
import static io.trino.jdbc.ClientInfoProperty.APPLICATION_NAME;
import static io.trino.jdbc.ClientInfoProperty.CLIENT_INFO;
import static io.trino.jdbc.ClientInfoProperty.CLIENT_TAGS;
import static io.trino.jdbc.ClientInfoProperty.TRACE_TOKEN;
import static java.lang.String.format;
import static java.net.HttpURLConnection.HTTP_BAD_METHOD;
import static java.net.HttpURLConnection.HTTP_OK;
import static java.net.HttpURLConnection.HTTP_UNAUTHORIZED;
import static java.nio.charset.StandardCharsets.US_ASCII;
import static java.util.Collections.newSetFromMap;
import static java.util.Objects.requireNonNull;
Expand All @@ -85,6 +97,8 @@ public class TrinoConnection
{
private static final Logger logger = Logger.getLogger(TrinoConnection.class.getPackage().getName());

private static final int CONNECTION_TIMEOUT_SECONDS = 30; // Not configurable

private final AtomicBoolean closed = new AtomicBoolean();
private final AtomicBoolean autoCommit = new AtomicBoolean(true);
private final AtomicInteger isolationLevel = new AtomicInteger(TRANSACTION_READ_UNCOMMITTED);
Expand Down Expand Up @@ -119,8 +133,10 @@ public class TrinoConnection
private final Set<TrinoStatement> statements = newSetFromMap(new ConcurrentHashMap<>());
private boolean useExplicitPrepare = true;
private boolean assumeNullCatalogMeansCurrentCatalog;
private final boolean validateConnection;

TrinoConnection(TrinoDriverUri uri, Call.Factory httpCallFactory, Call.Factory segmentHttpCallFactory)
throws SQLException
{
requireNonNull(uri, "uri is null");
this.jdbcUri = uri.getUri();
Expand Down Expand Up @@ -156,6 +172,67 @@ public class TrinoConnection

uri.getExplicitPrepare().ifPresent(value -> this.useExplicitPrepare = value);
uri.getAssumeNullCatalogMeansCurrentCatalog().ifPresent(value -> this.assumeNullCatalogMeansCurrentCatalog = value);

this.validateConnection = uri.isValidateConnection();
if (validateConnection) {
try {
if (!isConnectionValid(CONNECTION_TIMEOUT_SECONDS)) {
throw new SQLException("Invalid authentication to Trino server", "28000");
}
}
catch (UnsupportedOperationException | IOException e) {
throw new SQLException("Unable to connect to Trino server", "08001", e);
}
}
}

private boolean isConnectionValid(int timeout)
throws IOException, UnsupportedOperationException
{
HttpUrl url = HttpUrl.get(httpUri)
.newBuilder()
.encodedPath("/v1/statement")
.build();

Request headRequest = new Request.Builder()
.url(url)
.head()
.build();

Exception lastException = null;
Duration timeoutDuration = new Duration(timeout, TimeUnit.SECONDS);
long start = System.nanoTime();

while (timeoutDuration.compareTo(nanosSince(start)) > 0) {
try (Response response = httpCallFactory.newCall(headRequest).execute()) {
switch (response.code()) {
case HTTP_OK:
return true;
case HTTP_UNAUTHORIZED:
return false;
case HTTP_BAD_METHOD:
throw new UnsupportedOperationException("Trino server does not support HEAD /v1/statement");
}

try {
MILLISECONDS.sleep(250);
}
catch (InterruptedException e) {
Thread.currentThread().interrupt();
return false;
}
}
catch (IOException e) {
if (getCausalChain(e).stream().anyMatch(TrinoConnection::isTransientConnectionValidationException)) {
lastException = e;
}
else {
throw e;
}
}
}

throw new IOException(format("Connection validation timed out after %ss", timeout), lastException);
}

@Override
Expand Down Expand Up @@ -528,7 +605,26 @@ public boolean isValid(int timeout)
if (timeout < 0) {
throw new SQLException("Timeout is negative");
}
return !isClosed();

if (isClosed()) {
return false;
}

if (!validateConnection) {
return true;
}

try {
return isConnectionValid(timeout);
}
catch (UnsupportedOperationException e) {
logger.log(Level.FINE, "Trino server does not support connection validation", e);
return false;
}
catch (IOException e) {
logger.log(Level.FINE, "Connection validation has failed", e);
return false;
}
}

@Override
Expand Down Expand Up @@ -901,6 +997,14 @@ else if (applicationName != null) {
return source;
}

private static boolean isTransientConnectionValidationException(Throwable e)
{
if (e instanceof InterruptedIOException && e.getMessage().equals("timeout")) {
return true;
}
return e instanceof ProtocolException;
}

private static final class SqlExceptionHolder
{
@Nullable
Expand Down
Loading

0 comments on commit 2e55865

Please sign in to comment.