Skip to content

Commit

Permalink
Re-add expression-less join pushdown as fallback
Browse files Browse the repository at this point in the history
Restore older JDBC join pushdown implementation not based on
`ConnectorExpression` as a fallback.

This comes as a separate commit so that the introduction of
`ConnectorExpression`-based join pushdown can be seen (e.g. reviewed) as
a _change_, not as an _addition_.
  • Loading branch information
findepi committed Jan 10, 2024
1 parent 1db64a7 commit 4450dde
Show file tree
Hide file tree
Showing 20 changed files with 466 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,41 @@ public Optional<PreparedQuery> implementJoin(
}
}

@Deprecated
@Override
public Optional<PreparedQuery> legacyImplementJoin(
ConnectorSession session,
JoinType joinType,
PreparedQuery leftSource,
PreparedQuery rightSource,
List<JdbcJoinCondition> joinConditions,
Map<JdbcColumnHandle, String> rightAssignments,
Map<JdbcColumnHandle, String> leftAssignments,
JoinStatistics statistics)
{
for (JdbcJoinCondition joinCondition : joinConditions) {
if (!isSupportedJoinCondition(session, joinCondition)) {
return Optional.empty();
}
}

try (Connection connection = this.connectionFactory.openConnection(session)) {
return Optional.of(queryBuilder.legacyPrepareJoinQuery(
this,
session,
connection,
joinType,
leftSource,
rightSource,
joinConditions,
leftAssignments,
rightAssignments));
}
catch (SQLException e) {
throw new TrinoException(JDBC_ERROR, e);
}
}

protected boolean isSupportedJoinCondition(ConnectorSession session, JdbcJoinCondition joinCondition)
{
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,20 @@ public Optional<PreparedQuery> implementJoin(
return delegate.implementJoin(session, joinType, leftSource, leftProjections, rightSource, rightProjections, joinConditions, statistics);
}

@Override
public Optional<PreparedQuery> legacyImplementJoin(
ConnectorSession session,
JoinType joinType,
PreparedQuery leftSource,
PreparedQuery rightSource,
List<JdbcJoinCondition> joinConditions,
Map<JdbcColumnHandle, String> rightAssignments,
Map<JdbcColumnHandle, String> leftAssignments,
JoinStatistics statistics)
{
return delegate.legacyImplementJoin(session, joinType, leftSource, rightSource, joinConditions, rightAssignments, leftAssignments, statistics);
}

@Override
public boolean supportsTopN(ConnectorSession session, JdbcTableHandle handle, List<JdbcSortItem> sortOrder)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import io.trino.spi.connector.Constraint;
import io.trino.spi.connector.ConstraintApplicationResult;
import io.trino.spi.connector.JoinApplicationResult;
import io.trino.spi.connector.JoinCondition;
import io.trino.spi.connector.JoinStatistics;
import io.trino.spi.connector.JoinType;
import io.trino.spi.connector.LimitApplicationResult;
Expand Down Expand Up @@ -94,6 +95,7 @@
import static io.trino.plugin.jdbc.JdbcErrorCode.JDBC_NON_TRANSIENT_ERROR;
import static io.trino.plugin.jdbc.JdbcMetadataSessionProperties.isAggregationPushdownEnabled;
import static io.trino.plugin.jdbc.JdbcMetadataSessionProperties.isComplexExpressionPushdown;
import static io.trino.plugin.jdbc.JdbcMetadataSessionProperties.isComplexJoinPushdownEnabled;
import static io.trino.plugin.jdbc.JdbcMetadataSessionProperties.isJoinPushdownEnabled;
import static io.trino.plugin.jdbc.JdbcMetadataSessionProperties.isTopNPushdownEnabled;
import static io.trino.plugin.jdbc.JdbcWriteSessionProperties.isNonTransactionalInsert;
Expand Down Expand Up @@ -447,6 +449,19 @@ public Optional<JoinApplicationResult<ConnectorTableHandle>> applyJoin(
Map<String, ColumnHandle> rightAssignments,
JoinStatistics statistics)
{
if (!isComplexJoinPushdownEnabled(session)) {
// Fallback to the old join pushdown code
return JdbcMetadata.super.applyJoin(
session,
joinType,
left,
right,
joinCondition,
leftAssignments,
rightAssignments,
statistics);
}

if (isTableHandleForProcedure(left) || isTableHandleForProcedure(right)) {
return Optional.empty();
}
Expand Down Expand Up @@ -536,6 +551,101 @@ public Optional<JoinApplicationResult<ConnectorTableHandle>> applyJoin(
precalculateStatisticsForPushdown));
}

@Deprecated
@Override
public Optional<JoinApplicationResult<ConnectorTableHandle>> applyJoin(
ConnectorSession session,
JoinType joinType,
ConnectorTableHandle left,
ConnectorTableHandle right,
List<JoinCondition> joinConditions,
Map<String, ColumnHandle> leftAssignments,
Map<String, ColumnHandle> rightAssignments,
JoinStatistics statistics)
{
if (isTableHandleForProcedure(left) || isTableHandleForProcedure(right)) {
return Optional.empty();
}

if (!isJoinPushdownEnabled(session)) {
return Optional.empty();
}

JdbcTableHandle leftHandle = flushAttributesAsQuery(session, (JdbcTableHandle) left);
JdbcTableHandle rightHandle = flushAttributesAsQuery(session, (JdbcTableHandle) right);

if (!leftHandle.getAuthorization().equals(rightHandle.getAuthorization())) {
return Optional.empty();
}
int nextSyntheticColumnId = max(leftHandle.getNextSyntheticColumnId(), rightHandle.getNextSyntheticColumnId());

ImmutableMap.Builder<JdbcColumnHandle, JdbcColumnHandle> newLeftColumnsBuilder = ImmutableMap.builder();
OptionalInt maxColumnNameLength = jdbcClient.getMaxColumnNameLength(session);
for (JdbcColumnHandle column : jdbcClient.getColumns(session, leftHandle)) {
newLeftColumnsBuilder.put(column, createSyntheticJoinProjectionColumn(column, nextSyntheticColumnId, maxColumnNameLength));
nextSyntheticColumnId++;
}
Map<JdbcColumnHandle, JdbcColumnHandle> newLeftColumns = newLeftColumnsBuilder.buildOrThrow();

ImmutableMap.Builder<JdbcColumnHandle, JdbcColumnHandle> newRightColumnsBuilder = ImmutableMap.builder();
for (JdbcColumnHandle column : jdbcClient.getColumns(session, rightHandle)) {
newRightColumnsBuilder.put(column, createSyntheticJoinProjectionColumn(column, nextSyntheticColumnId, maxColumnNameLength));
nextSyntheticColumnId++;
}
Map<JdbcColumnHandle, JdbcColumnHandle> newRightColumns = newRightColumnsBuilder.buildOrThrow();

ImmutableList.Builder<JdbcJoinCondition> jdbcJoinConditions = ImmutableList.builder();
for (JoinCondition joinCondition : joinConditions) {
Optional<JdbcColumnHandle> leftColumn = getVariableColumnHandle(leftAssignments, joinCondition.getLeftExpression());
Optional<JdbcColumnHandle> rightColumn = getVariableColumnHandle(rightAssignments, joinCondition.getRightExpression());
if (leftColumn.isEmpty() || rightColumn.isEmpty()) {
return Optional.empty();
}
jdbcJoinConditions.add(new JdbcJoinCondition(leftColumn.get(), joinCondition.getOperator(), rightColumn.get()));
}

Optional<PreparedQuery> joinQuery = jdbcClient.legacyImplementJoin(
session,
joinType,
asPreparedQuery(leftHandle),
asPreparedQuery(rightHandle),
jdbcJoinConditions.build(),
newRightColumns.entrySet().stream()
.collect(toImmutableMap(Entry::getKey, entry -> entry.getValue().getColumnName())),
newLeftColumns.entrySet().stream()
.collect(toImmutableMap(Entry::getKey, entry -> entry.getValue().getColumnName())),
statistics);

if (joinQuery.isEmpty()) {
return Optional.empty();
}

return Optional.of(new JoinApplicationResult<>(
new JdbcTableHandle(
new JdbcQueryRelationHandle(joinQuery.get()),
TupleDomain.all(),
ImmutableList.of(),
Optional.empty(),
OptionalLong.empty(),
Optional.of(
ImmutableList.<JdbcColumnHandle>builder()
.addAll(newLeftColumns.values())
.addAll(newRightColumns.values())
.build()),
leftHandle.getAllReferencedTables().flatMap(leftReferencedTables ->
rightHandle.getAllReferencedTables().map(rightReferencedTables ->
ImmutableSet.<SchemaTableName>builder()
.addAll(leftReferencedTables)
.addAll(rightReferencedTables)
.build())),
nextSyntheticColumnId,
leftHandle.getAuthorization(),
leftHandle.getUpdateAssignments()),
ImmutableMap.copyOf(newLeftColumns),
ImmutableMap.copyOf(newRightColumns),
precalculateStatisticsForPushdown));
}

@VisibleForTesting
static JdbcColumnHandle createSyntheticJoinProjectionColumn(JdbcColumnHandle column, int nextSyntheticColumnId, OptionalInt optionalMaxColumnNameLength)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,46 @@ public PreparedQuery prepareJoinQuery(
return new PreparedQuery(query, parameters);
}

@Override
public PreparedQuery legacyPrepareJoinQuery(
JdbcClient client,
ConnectorSession session,
Connection connection,
JoinType joinType,
PreparedQuery leftSource,
PreparedQuery rightSource,
List<JdbcJoinCondition> joinConditions,
Map<JdbcColumnHandle, String> leftAssignments,
Map<JdbcColumnHandle, String> rightAssignments)
{
// Verify assignments are present. This is safe assumption as join conditions are not pruned, and simplifies the code here.
verify(!leftAssignments.isEmpty(), "leftAssignments is empty");
verify(!rightAssignments.isEmpty(), "rightAssignments is empty");
// Joins wih no conditions are not pushed down, so it is a same assumption and simplifies the code here
verify(!joinConditions.isEmpty(), "joinConditions is empty");

String leftRelationAlias = "l";
String rightRelationAlias = "r";

String query = format(
"SELECT %s, %s FROM (%s) %s %s (%s) %s ON %s",
formatAssignments(client, leftRelationAlias, leftAssignments),
formatAssignments(client, rightRelationAlias, rightAssignments),
leftSource.getQuery(),
leftRelationAlias,
formatJoinType(joinType),
rightSource.getQuery(),
rightRelationAlias,
joinConditions.stream()
.map(condition -> formatJoinCondition(client, leftRelationAlias, rightRelationAlias, condition))
.collect(joining(" AND ")));
List<QueryParameter> parameters = ImmutableList.<QueryParameter>builder()
.addAll(leftSource.getParameters())
.addAll(rightSource.getParameters())
.build();
return new PreparedQuery(query, parameters);
}

@Override
public PreparedQuery prepareDeleteQuery(
JdbcClient client,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,20 @@ public Optional<PreparedQuery> implementJoin(
return delegate().implementJoin(session, joinType, leftSource, leftProjections, rightSource, rightProjections, joinConditions, statistics);
}

@Override
public Optional<PreparedQuery> legacyImplementJoin(
ConnectorSession session,
JoinType joinType,
PreparedQuery leftSource,
PreparedQuery rightSource,
List<JdbcJoinCondition> joinConditions,
Map<JdbcColumnHandle, String> rightAssignments,
Map<JdbcColumnHandle, String> leftAssignments,
JoinStatistics statistics)
{
return delegate().legacyImplementJoin(session, joinType, leftSource, rightSource, joinConditions, rightAssignments, leftAssignments, statistics);
}

@Override
public JdbcOutputTableHandle beginCreateTable(ConnectorSession session, ConnectorTableMetadata tableMetadata)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,17 @@ Optional<PreparedQuery> implementJoin(
List<ParameterizedExpression> joinConditions,
JoinStatistics statistics);

@Deprecated
Optional<PreparedQuery> legacyImplementJoin(
ConnectorSession session,
JoinType joinType,
PreparedQuery leftSource,
PreparedQuery rightSource,
List<JdbcJoinCondition> joinConditions,
Map<JdbcColumnHandle, String> rightAssignments,
Map<JdbcColumnHandle, String> leftAssignments,
JoinStatistics statistics);

boolean supportsTopN(ConnectorSession session, JdbcTableHandle handle, List<JdbcSortItem> sortOrder);

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ public class JdbcMetadataConfig
* in terms of performance and money due to an increased network traffic.
*/
private boolean joinPushdownEnabled;
private boolean complexJoinPushdownEnabled = true;
private boolean aggregationPushdownEnabled = true;

private boolean topNPushdownEnabled = true;
Expand Down Expand Up @@ -67,6 +68,19 @@ public JdbcMetadataConfig setJoinPushdownEnabled(boolean joinPushdownEnabled)
return this;
}

public boolean isComplexJoinPushdownEnabled()
{
return complexJoinPushdownEnabled;
}

@Config("join-pushdown.with-expressions")
@ConfigDescription("Enable join pushdown with complex expressions")
public JdbcMetadataConfig setComplexJoinPushdownEnabled(boolean complexJoinPushdownEnabled)
{
this.complexJoinPushdownEnabled = complexJoinPushdownEnabled;
return this;
}

public boolean isAggregationPushdownEnabled()
{
return aggregationPushdownEnabled;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ public class JdbcMetadataSessionProperties
{
public static final String COMPLEX_EXPRESSION_PUSHDOWN = "complex_expression_pushdown";
public static final String JOIN_PUSHDOWN_ENABLED = "join_pushdown_enabled";
public static final String COMPLEX_JOIN_PUSHDOWN_ENABLED = "complex_join_pushdown_enabled";
public static final String AGGREGATION_PUSHDOWN_ENABLED = "aggregation_pushdown_enabled";
public static final String TOPN_PUSHDOWN_ENABLED = "topn_pushdown_enabled";
public static final String DOMAIN_COMPACTION_THRESHOLD = "domain_compaction_threshold";
Expand All @@ -54,6 +55,11 @@ public JdbcMetadataSessionProperties(JdbcMetadataConfig jdbcMetadataConfig, @Max
"Enable join pushdown",
jdbcMetadataConfig.isJoinPushdownEnabled(),
false))
.add(booleanProperty(
COMPLEX_JOIN_PUSHDOWN_ENABLED,
"Enable join pushdown with non-comparison expressions",
jdbcMetadataConfig.isComplexJoinPushdownEnabled(),
false))
.add(booleanProperty(
AGGREGATION_PUSHDOWN_ENABLED,
"Enable aggregation pushdown",
Expand Down Expand Up @@ -89,6 +95,11 @@ public static boolean isJoinPushdownEnabled(ConnectorSession session)
return session.getProperty(JOIN_PUSHDOWN_ENABLED, Boolean.class);
}

public static boolean isComplexJoinPushdownEnabled(ConnectorSession session)
{
return session.getProperty(COMPLEX_JOIN_PUSHDOWN_ENABLED, Boolean.class);
}

public static boolean isAggregationPushdownEnabled(ConnectorSession session)
{
return session.getProperty(AGGREGATION_PUSHDOWN_ENABLED, Boolean.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,17 @@ PreparedQuery prepareJoinQuery(
Map<JdbcColumnHandle, String> rightProjections,
List<ParameterizedExpression> joinConditions);

PreparedQuery legacyPrepareJoinQuery(
JdbcClient client,
ConnectorSession session,
Connection connection,
JoinType joinType,
PreparedQuery leftSource,
PreparedQuery rightSource,
List<JdbcJoinCondition> joinConditions,
Map<JdbcColumnHandle, String> leftAssignments,
Map<JdbcColumnHandle, String> rightAssignments);

PreparedQuery prepareDeleteQuery(
JdbcClient client,
ConnectorSession session,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import io.trino.plugin.jdbc.JdbcClient;
import io.trino.plugin.jdbc.JdbcColumnHandle;
import io.trino.plugin.jdbc.JdbcExpression;
import io.trino.plugin.jdbc.JdbcJoinCondition;
import io.trino.plugin.jdbc.JdbcOutputTableHandle;
import io.trino.plugin.jdbc.JdbcProcedureHandle;
import io.trino.plugin.jdbc.JdbcProcedureHandle.ProcedureQuery;
Expand Down Expand Up @@ -239,6 +240,19 @@ public Optional<PreparedQuery> implementJoin(ConnectorSession session,
return stats.getImplementJoin().wrap(() -> delegate().implementJoin(session, joinType, leftSource, leftProjections, rightSource, rightProjections, joinConditions, statistics));
}

@Override
public Optional<PreparedQuery> legacyImplementJoin(ConnectorSession session,
JoinType joinType,
PreparedQuery leftSource,
PreparedQuery rightSource,
List<JdbcJoinCondition> joinConditions,
Map<JdbcColumnHandle, String> rightAssignments,
Map<JdbcColumnHandle, String> leftAssignments,
JoinStatistics statistics)
{
return stats.getImplementJoin().wrap(() -> delegate().legacyImplementJoin(session, joinType, leftSource, rightSource, joinConditions, rightAssignments, leftAssignments, statistics));
}

@Override
public Optional<String> getTableComment(ResultSet resultSet)
throws SQLException
Expand Down
Loading

0 comments on commit 4450dde

Please sign in to comment.