Skip to content

Commit

Permalink
Implement complex join pushdown in JDBC connectors
Browse files Browse the repository at this point in the history
Implement non-deprecated `ConnectorMetadata.applyJoin` overload in
`DefaultJdbcMetadata`. The old implementation is retained as a safety
valve. The new implementation is not limited to the
`List<JdbcJoinCondition>` model, so it allows pushdown of joins involving
more complex expressions, such as arithmetic.

The `BaseJdbcClient.implementJoin` and
`QueryBuilder.prepareJoinQuery` methods logically changed, but the old
implementation is left as the fallback. These methods were extension
points, so the old implementations are renamed to ensure implementors
are updated. For example, an implementation that overrides
`BaseJdbcClient.implementJoin` most likely wants to override the new
`implementJoin` method as well, and the rename of the old method serves
as a reminder to do so.
  • Loading branch information
findepi committed Jan 10, 2024
1 parent 4f1cc7f commit 1db64a7
Show file tree
Hide file tree
Showing 25 changed files with 441 additions and 111 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,10 @@ public QueryAssert isNotFullyPushedDown(PlanMatchPattern retainedSubplan)

/**
* Verifies join query is not fully pushed down by containing JOIN node.
*
* @deprecated because the method is not tested in BaseQueryAssertionsTest yet
*/
@Deprecated
@CanIgnoreReturnValue
public QueryAssert joinIsNotFullyPushedDown()
{
Expand All @@ -580,6 +583,7 @@ public QueryAssert joinIsNotFullyPushedDown()
.whereIsInstanceOfAny(JoinNode.class)
.findFirst()
.isEmpty()) {
// TODO show the plan when the assertion fails (like hasPlan()) and add negative test coverage in BaseQueryAssertionsTest
throw new IllegalStateException("Join node should be present in explain plan, when pushdown is not applied");
}
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ public static List<ConnectorExpression> extractConjuncts(ConnectorExpression exp

private static void extractConjuncts(ConnectorExpression expression, ImmutableList.Builder<ConnectorExpression> resultBuilder)
{
if (expression.equals(TRUE)) {
// Skip useless conjuncts.
return;
}
if (expression instanceof Call call) {
if (AND_FUNCTION_NAME.equals(call.getFunctionName())) {
for (ConnectorExpression argument : call.getArguments()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -527,29 +527,23 @@ public Optional<PreparedQuery> implementJoin(
ConnectorSession session,
JoinType joinType,
PreparedQuery leftSource,
Map<JdbcColumnHandle, String> leftProjections,
PreparedQuery rightSource,
List<JdbcJoinCondition> joinConditions,
Map<JdbcColumnHandle, String> rightAssignments,
Map<JdbcColumnHandle, String> leftAssignments,
Map<JdbcColumnHandle, String> rightProjections,
List<ParameterizedExpression> joinConditions,
JoinStatistics statistics)
{
for (JdbcJoinCondition joinCondition : joinConditions) {
if (!isSupportedJoinCondition(session, joinCondition)) {
return Optional.empty();
}
}

try (Connection connection = this.connectionFactory.openConnection(session)) {
return Optional.of(queryBuilder.prepareJoinQuery(
this,
session,
connection,
joinType,
leftSource,
leftProjections,
rightSource,
joinConditions,
leftAssignments,
rightAssignments));
rightProjections,
joinConditions));
}
catch (SQLException e) {
throw new TrinoException(JDBC_ERROR, e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -284,13 +284,13 @@ public Optional<PreparedQuery> implementJoin(
ConnectorSession session,
JoinType joinType,
PreparedQuery leftSource,
Map<JdbcColumnHandle, String> leftProjections,
PreparedQuery rightSource,
List<JdbcJoinCondition> joinConditions,
Map<JdbcColumnHandle, String> rightAssignments,
Map<JdbcColumnHandle, String> leftAssignments,
Map<JdbcColumnHandle, String> rightProjections,
List<ParameterizedExpression> joinConditions,
JoinStatistics statistics)
{
return delegate.implementJoin(session, joinType, leftSource, rightSource, joinConditions, rightAssignments, leftAssignments, statistics);
return delegate.implementJoin(session, joinType, leftSource, leftProjections, rightSource, rightProjections, joinConditions, statistics);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
import io.trino.spi.connector.Constraint;
import io.trino.spi.connector.ConstraintApplicationResult;
import io.trino.spi.connector.JoinApplicationResult;
import io.trino.spi.connector.JoinCondition;
import io.trino.spi.connector.JoinStatistics;
import io.trino.spi.connector.JoinType;
import io.trino.spi.connector.LimitApplicationResult;
Expand Down Expand Up @@ -73,6 +72,7 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.OptionalLong;
Expand Down Expand Up @@ -442,7 +442,7 @@ public Optional<JoinApplicationResult<ConnectorTableHandle>> applyJoin(
JoinType joinType,
ConnectorTableHandle left,
ConnectorTableHandle right,
List<JoinCondition> joinConditions,
ConnectorExpression joinCondition,
Map<String, ColumnHandle> leftAssignments,
Map<String, ColumnHandle> rightAssignments,
JoinStatistics statistics)
Expand Down Expand Up @@ -478,26 +478,32 @@ public Optional<JoinApplicationResult<ConnectorTableHandle>> applyJoin(
}
Map<JdbcColumnHandle, JdbcColumnHandle> newRightColumns = newRightColumnsBuilder.buildOrThrow();

ImmutableList.Builder<JdbcJoinCondition> jdbcJoinConditions = ImmutableList.builder();
for (JoinCondition joinCondition : joinConditions) {
Optional<JdbcColumnHandle> leftColumn = getVariableColumnHandle(leftAssignments, joinCondition.getLeftExpression());
Optional<JdbcColumnHandle> rightColumn = getVariableColumnHandle(rightAssignments, joinCondition.getRightExpression());
if (leftColumn.isEmpty() || rightColumn.isEmpty()) {
Map<String, ColumnHandle> assignments = ImmutableMap.<String, ColumnHandle>builder()
.putAll(leftAssignments.entrySet().stream()
.collect(toImmutableMap(Entry::getKey, entry -> newLeftColumns.get((JdbcColumnHandle) entry.getValue()))))
.putAll(rightAssignments.entrySet().stream()
.collect(toImmutableMap(Entry::getKey, entry -> newRightColumns.get((JdbcColumnHandle) entry.getValue()))))
.buildOrThrow();

ImmutableList.Builder<ParameterizedExpression> joinConditions = ImmutableList.builder();
for (ConnectorExpression conjunct : extractConjuncts(joinCondition)) {
Optional<ParameterizedExpression> converted = jdbcClient.convertPredicate(session, conjunct, assignments);
if (converted.isEmpty()) {
return Optional.empty();
}
jdbcJoinConditions.add(new JdbcJoinCondition(leftColumn.get(), joinCondition.getOperator(), rightColumn.get()));
joinConditions.add(converted.get());
}

Optional<PreparedQuery> joinQuery = jdbcClient.implementJoin(
session,
joinType,
asPreparedQuery(leftHandle),
newLeftColumns.entrySet().stream()
.collect(toImmutableMap(Entry::getKey, entry -> entry.getValue().getColumnName())),
asPreparedQuery(rightHandle),
jdbcJoinConditions.build(),
newRightColumns.entrySet().stream()
.collect(toImmutableMap(Map.Entry::getKey, entry -> entry.getValue().getColumnName())),
newLeftColumns.entrySet().stream()
.collect(toImmutableMap(Map.Entry::getKey, entry -> entry.getValue().getColumnName())),
.collect(toImmutableMap(Entry::getKey, entry -> entry.getValue().getColumnName())),
joinConditions.build(),
statistics);

if (joinQuery.isEmpty()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,35 +117,32 @@ public PreparedQuery prepareJoinQuery(
Connection connection,
JoinType joinType,
PreparedQuery leftSource,
Map<JdbcColumnHandle, String> leftProjections,
PreparedQuery rightSource,
List<JdbcJoinCondition> joinConditions,
Map<JdbcColumnHandle, String> leftAssignments,
Map<JdbcColumnHandle, String> rightAssignments)
Map<JdbcColumnHandle, String> rightProjections,
List<ParameterizedExpression> joinConditions)
{
// Verify assignments are present. This is safe assumption as join conditions are not pruned, and simplifies the code here.
verify(!leftAssignments.isEmpty(), "leftAssignments is empty");
verify(!rightAssignments.isEmpty(), "rightAssignments is empty");
// Joins with no conditions are not pushed down, so it is a safe assumption and simplifies the code here
verify(!joinConditions.isEmpty(), "joinConditions is empty");

String leftRelationAlias = "l";
String rightRelationAlias = "r";

String query = format(
"SELECT %s, %s FROM (%s) %s %s (%s) %s ON %s",
formatAssignments(client, leftRelationAlias, leftAssignments),
formatAssignments(client, rightRelationAlias, rightAssignments),
// The subquery aliases (`l` and `r`) are needed by some databases, but are not needed for expressions
// The joinConditions and output columns are aliased to use unique names.
"SELECT * FROM (SELECT %s FROM (%s) l) l %s (SELECT %s FROM (%s) r) r ON %s",
formatProjections(client, leftProjections),
leftSource.getQuery(),
leftRelationAlias,
formatJoinType(joinType),
formatProjections(client, rightProjections),
rightSource.getQuery(),
rightRelationAlias,
joinConditions.stream()
.map(condition -> formatJoinCondition(client, leftRelationAlias, rightRelationAlias, condition))
.collect(joining(" AND ")));
.map(ParameterizedExpression::expression)
.collect(joining(") AND (", "(", ")")));
List<QueryParameter> parameters = ImmutableList.<QueryParameter>builder()
.addAll(leftSource.getParameters())
.addAll(rightSource.getParameters())
.addAll(joinConditions.stream()
.flatMap(expression -> expression.parameters().stream())
.iterator())
.build();
return new PreparedQuery(query, parameters);
}
Expand Down Expand Up @@ -296,6 +293,13 @@ protected String buildJoinColumn(JdbcClient client, JdbcColumnHandle columnHandl
return client.quoted(columnHandle.getColumnName());
}

/**
 * Renders a comma-separated list of {@code <column> AS <alias>} items, one per map entry.
 * Both the source column name (taken from the key's {@code getColumnName()}) and the
 * alias (the map value) are passed through {@code client.quoted(...)}.
 *
 * @param client used to quote the column names and aliases
 * @param projections mapping from column handle to the alias it should be exposed under
 * @return the projection items joined with {@code ", "}
 */
protected String formatProjections(JdbcClient client, Map<JdbcColumnHandle, String> projections)
{
    return projections.entrySet().stream()
            .map(projection -> {
                String sourceColumn = client.quoted(projection.getKey().getColumnName());
                String outputAlias = client.quoted(projection.getValue());
                return format("%s AS %s", sourceColumn, outputAlias);
            })
            .collect(joining(", "));
}

protected String formatAssignments(JdbcClient client, String relationAlias, Map<JdbcColumnHandle, String> assignments)
{
return assignments.entrySet().stream()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -211,13 +211,13 @@ public Optional<PreparedQuery> implementJoin(
ConnectorSession session,
JoinType joinType,
PreparedQuery leftSource,
Map<JdbcColumnHandle, String> leftProjections,
PreparedQuery rightSource,
List<JdbcJoinCondition> joinConditions,
Map<JdbcColumnHandle, String> rightAssignments,
Map<JdbcColumnHandle, String> leftAssignments,
Map<JdbcColumnHandle, String> rightProjections,
List<ParameterizedExpression> joinConditions,
JoinStatistics statistics)
{
return delegate().implementJoin(session, joinType, leftSource, rightSource, joinConditions, rightAssignments, leftAssignments, statistics);
return delegate().implementJoin(session, joinType, leftSource, leftProjections, rightSource, rightProjections, joinConditions, statistics);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,10 @@ Optional<PreparedQuery> implementJoin(
ConnectorSession session,
JoinType joinType,
PreparedQuery leftSource,
Map<JdbcColumnHandle, String> leftProjections,
PreparedQuery rightSource,
List<JdbcJoinCondition> joinConditions,
Map<JdbcColumnHandle, String> rightAssignments,
Map<JdbcColumnHandle, String> leftAssignments,
Map<JdbcColumnHandle, String> rightProjections,
List<ParameterizedExpression> joinConditions,
JoinStatistics statistics);

boolean supportsTopN(ConnectorSession session, JdbcTableHandle handle, List<JdbcSortItem> sortOrder);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,10 @@ PreparedQuery prepareJoinQuery(
Connection connection,
JoinType joinType,
PreparedQuery leftSource,
Map<JdbcColumnHandle, String> leftProjections,
PreparedQuery rightSource,
List<JdbcJoinCondition> joinConditions,
Map<JdbcColumnHandle, String> leftAssignments,
Map<JdbcColumnHandle, String> rightAssignments);
Map<JdbcColumnHandle, String> rightProjections,
List<ParameterizedExpression> joinConditions);

PreparedQuery prepareDeleteQuery(
JdbcClient client,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.plugin.jdbc.expression;

import com.google.common.collect.ImmutableList;
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.plugin.base.expression.ConnectorExpressionRule;
import io.trino.plugin.jdbc.JdbcColumnHandle;
import io.trino.plugin.jdbc.QueryParameter;
import io.trino.spi.expression.Call;
import io.trino.spi.expression.FunctionName;
import io.trino.spi.expression.Variable;
import io.trino.spi.type.VarcharType;

import java.util.Optional;
import java.util.Set;

import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static io.trino.matching.Capture.newCapture;
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.argument;
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.argumentCount;
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.call;
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.functionName;
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.type;
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.variable;
import static io.trino.plugin.jdbc.CaseSensitivity.CASE_SENSITIVE;
import static io.trino.spi.type.BooleanType.BOOLEAN;

/**
 * Expression rewrite rule for comparisons between two varchar variables.
 * <p>
 * The rule matches a boolean-typed {@link Call} with exactly two arguments whose function
 * name belongs to one of the enabled comparison operators and whose arguments are both
 * {@link Variable}s of a {@link VarcharType}. The match is rewritten to
 * {@code (<left>) <operator> (<right>)} only when the column handles assigned to both
 * variables report {@code CASE_SENSITIVE} case sensitivity; otherwise the rule declines.
 */
public class RewriteCaseSensitiveComparison
        implements ConnectorExpressionRule<Call, ParameterizedExpression>
{
    private static final Capture<Variable> LEFT = newCapture();
    private static final Capture<Variable> RIGHT = newCapture();

    private final Pattern<Call> pattern;

    /**
     * @param enabledOperators comparison operators whose function names this rule should match
     */
    public RewriteCaseSensitiveComparison(Set<ComparisonOperator> enabledOperators)
    {
        Set<FunctionName> supportedFunctions = enabledOperators.stream()
                .map(ComparisonOperator::getFunctionName)
                .collect(toImmutableSet());

        pattern = call()
                .with(type().equalTo(BOOLEAN))
                .with(functionName().matching(supportedFunctions::contains))
                .with(argumentCount().equalTo(2))
                .with(argument(0).matching(varcharVariable(LEFT)))
                .with(argument(1).matching(varcharVariable(RIGHT)));
    }

    /**
     * Builds the pattern for a single comparison operand: a variable of a varchar type,
     * captured under the given capture.
     */
    private static Pattern<Variable> varcharVariable(Capture<Variable> capture)
    {
        return variable()
                .with(type().matching(VarcharType.class::isInstance))
                .capturedAs(capture);
    }

    @Override
    public Pattern<Call> getPattern()
    {
        return pattern;
    }

    /**
     * Produces the parameterized comparison for a matched call.
     *
     * @return the rewritten expression, or empty when either operand is not case sensitive
     *         or when the default rewrite of either operand yields no result
     */
    @Override
    public Optional<ParameterizedExpression> rewrite(Call expression, Captures captures, RewriteContext<ParameterizedExpression> context)
    {
        ComparisonOperator operator = ComparisonOperator.forFunctionName(expression.getFunctionName());
        Variable leftVariable = captures.get(LEFT);
        Variable rightVariable = captures.get(RIGHT);

        // Only rewrite when both sides are case sensitive.
        boolean bothCaseSensitive = isCaseSensitive(leftVariable, context) && isCaseSensitive(rightVariable, context);
        if (!bothCaseSensitive) {
            return Optional.empty();
        }

        // The right operand is rewritten only after the left one succeeded.
        Optional<ParameterizedExpression> leftRewritten = context.defaultRewrite(leftVariable);
        if (leftRewritten.isEmpty()) {
            return Optional.empty();
        }
        Optional<ParameterizedExpression> rightRewritten = context.defaultRewrite(rightVariable);
        if (rightRewritten.isEmpty()) {
            return Optional.empty();
        }

        ParameterizedExpression leftOperand = leftRewritten.get();
        ParameterizedExpression rightOperand = rightRewritten.get();

        String comparisonSql = "(%s) %s (%s)".formatted(leftOperand.expression(), operator.getOperator(), rightOperand.expression());
        // Parameters follow the textual order of the operands in the generated expression.
        ImmutableList.Builder<QueryParameter> combinedParameters = ImmutableList.builder();
        combinedParameters.addAll(leftOperand.parameters());
        combinedParameters.addAll(rightOperand.parameters());

        return Optional.of(new ParameterizedExpression(comparisonSql, combinedParameters.build()));
    }

    /**
     * Tells whether the column handle assigned to the variable reports {@code CASE_SENSITIVE}.
     * The assignment is cast to {@link JdbcColumnHandle}.
     */
    private static boolean isCaseSensitive(Variable variable, RewriteContext<?> context)
    {
        JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(variable.getName());
        return columnHandle.getJdbcTypeHandle().getCaseSensitivity().equals(Optional.of(CASE_SENSITIVE));
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import io.trino.plugin.jdbc.JdbcClient;
import io.trino.plugin.jdbc.JdbcColumnHandle;
import io.trino.plugin.jdbc.JdbcExpression;
import io.trino.plugin.jdbc.JdbcJoinCondition;
import io.trino.plugin.jdbc.JdbcOutputTableHandle;
import io.trino.plugin.jdbc.JdbcProcedureHandle;
import io.trino.plugin.jdbc.JdbcProcedureHandle.ProcedureQuery;
Expand Down Expand Up @@ -231,13 +230,13 @@ public CallableStatement buildProcedure(ConnectorSession session, Connection con
public Optional<PreparedQuery> implementJoin(ConnectorSession session,
JoinType joinType,
PreparedQuery leftSource,
Map<JdbcColumnHandle, String> leftProjections,
PreparedQuery rightSource,
List<JdbcJoinCondition> joinConditions,
Map<JdbcColumnHandle, String> rightAssignments,
Map<JdbcColumnHandle, String> leftAssignments,
Map<JdbcColumnHandle, String> rightProjections,
List<ParameterizedExpression> joinConditions,
JoinStatistics statistics)
{
return stats.getImplementJoin().wrap(() -> delegate().implementJoin(session, joinType, leftSource, rightSource, joinConditions, rightAssignments, leftAssignments, statistics));
return stats.getImplementJoin().wrap(() -> delegate().implementJoin(session, joinType, leftSource, leftProjections, rightSource, rightProjections, joinConditions, statistics));
}

@Override
Expand Down
Loading

0 comments on commit 1db64a7

Please sign in to comment.