diff --git a/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java b/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java index 0e6cbb17d701..ac91fad1f005 100644 --- a/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java +++ b/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java @@ -196,6 +196,7 @@ public final class SystemSessionProperties private static final String FAULT_TOLERANT_EXECUTION_SMALL_STAGE_SOURCE_SIZE_MULTIPLIER = "fault_tolerant_execution_small_stage_source_size_multiplier"; private static final String FAULT_TOLERANT_EXECUTION_SMALL_STAGE_REQUIRE_NO_MORE_PARTITIONS = "fault_tolerant_execution_small_stage_require_no_more_partitions"; private static final String FAULT_TOLERANT_EXECUTION_STAGE_ESTIMATION_FOR_EAGER_PARENT_ENABLED = "fault_tolerant_execution_stage_estimation_for_eager_parent_enabled"; + public static final String FAULT_TOLERANT_EXECUTION_ADAPTIVE_QUERY_PLANNING_ENABLED = "fault_tolerant_execution_adaptive_query_planning_enabled"; public static final String ADAPTIVE_PARTIAL_AGGREGATION_ENABLED = "adaptive_partial_aggregation_enabled"; public static final String ADAPTIVE_PARTIAL_AGGREGATION_UNIQUE_ROWS_RATIO_THRESHOLD = "adaptive_partial_aggregation_unique_rows_ratio_threshold"; public static final String REMOTE_TASK_ADAPTIVE_UPDATE_REQUEST_SIZE_ENABLED = "remote_task_adaptive_update_request_size_enabled"; @@ -999,6 +1000,11 @@ public SystemSessionProperties( "Enable aggressive stage output size estimation heuristic for children of stages to be executed eagerly", queryManagerConfig.isFaultTolerantExecutionStageEstimationForEagerParentEnabled(), true), + booleanProperty( + FAULT_TOLERANT_EXECUTION_ADAPTIVE_QUERY_PLANNING_ENABLED, + "Enable adaptive query planning for the fault tolerant execution", + queryManagerConfig.isFaultTolerantExecutionAdaptiveQueryPlanningEnabled(), + false), booleanProperty( ADAPTIVE_PARTIAL_AGGREGATION_ENABLED, "When enabled, partial aggregation might be adaptively turned off when it does not provide any performance gain", @@ -1846,6 +1852,11 @@ public static boolean isFaultTolerantExecutionStageEstimationForEagerParentEnabl return session.getSystemProperty(FAULT_TOLERANT_EXECUTION_STAGE_ESTIMATION_FOR_EAGER_PARENT_ENABLED, Boolean.class); } + public static boolean isFaultTolerantExecutionAdaptiveQueryPlanningEnabled(Session session) + { + return session.getSystemProperty(FAULT_TOLERANT_EXECUTION_ADAPTIVE_QUERY_PLANNING_ENABLED, Boolean.class); + } + public static boolean isAdaptivePartialAggregationEnabled(Session session) { return session.getSystemProperty(ADAPTIVE_PARTIAL_AGGREGATION_ENABLED, Boolean.class); diff --git a/core/trino-main/src/main/java/io/trino/cost/CachingStatsProvider.java b/core/trino-main/src/main/java/io/trino/cost/CachingStatsProvider.java index df7911d38612..be253bdd3143 100644 --- a/core/trino-main/src/main/java/io/trino/cost/CachingStatsProvider.java +++ b/core/trino-main/src/main/java/io/trino/cost/CachingStatsProvider.java @@ -42,15 +42,23 @@ public final class CachingStatsProvider private final Session session; private final TypeProvider types; private final TableStatsProvider tableStatsProvider; + private final RuntimeInfoProvider runtimeInfoProvider; private final Map cache = new IdentityHashMap<>(); public CachingStatsProvider(StatsCalculator statsCalculator, Session session, TypeProvider types, TableStatsProvider tableStatsProvider) { - this(statsCalculator, Optional.empty(), noLookup(), session, types, tableStatsProvider); + this(statsCalculator, 
Optional.empty(), noLookup(), session, types, tableStatsProvider, RuntimeInfoProvider.noImplementation()); } - public CachingStatsProvider(StatsCalculator statsCalculator, Optional memo, Lookup lookup, Session session, TypeProvider types, TableStatsProvider tableStatsProvider) + public CachingStatsProvider( + StatsCalculator statsCalculator, + Optional memo, + Lookup lookup, + Session session, + TypeProvider types, + TableStatsProvider tableStatsProvider, + RuntimeInfoProvider runtimeInfoProvider) { this.statsCalculator = requireNonNull(statsCalculator, "statsCalculator is null"); this.memo = requireNonNull(memo, "memo is null"); @@ -58,6 +66,7 @@ public CachingStatsProvider(StatsCalculator statsCalculator, Optional memo this.session = requireNonNull(session, "session is null"); this.types = requireNonNull(types, "types is null"); this.tableStatsProvider = requireNonNull(tableStatsProvider, "tableStatsProvider is null"); + this.runtimeInfoProvider = requireNonNull(runtimeInfoProvider, "runtimeInfoProvider is null"); } @Override @@ -79,7 +88,7 @@ public PlanNodeStatsEstimate getStats(PlanNode node) return stats; } - stats = statsCalculator.calculateStats(node, new StatsCalculator.Context(this, lookup, session, types, tableStatsProvider)); + stats = statsCalculator.calculateStats(node, new StatsCalculator.Context(this, lookup, session, types, tableStatsProvider, runtimeInfoProvider)); verify(cache.put(node, stats) == null, "Stats already set"); return stats; } diff --git a/core/trino-main/src/main/java/io/trino/cost/RemoteSourceStatsRule.java b/core/trino-main/src/main/java/io/trino/cost/RemoteSourceStatsRule.java new file mode 100644 index 000000000000..7d88517949c3 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/cost/RemoteSourceStatsRule.java @@ -0,0 +1,149 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.trino.cost; + +import io.trino.cost.StatsCalculator.Context; +import io.trino.execution.scheduler.OutputDataSizeEstimate; +import io.trino.matching.Pattern; +import io.trino.spi.type.FixedWidthType; +import io.trino.spi.type.Type; +import io.trino.sql.planner.PlanFragment; +import io.trino.sql.planner.Symbol; +import io.trino.sql.planner.TypeProvider; +import io.trino.sql.planner.plan.PlanFragmentId; +import io.trino.sql.planner.plan.PlanNode; +import io.trino.sql.planner.plan.RemoteSourceNode; + +import java.util.List; +import java.util.Optional; + +import static com.google.common.base.Verify.verify; +import static io.trino.cost.PlanNodeStatsEstimateMath.addStatsAndMaxDistinctValues; +import static io.trino.execution.scheduler.faulttolerant.OutputStatsEstimator.OutputStatsEstimateResult; +import static io.trino.sql.planner.plan.Patterns.remoteSourceNode; +import static io.trino.util.MoreMath.firstNonNaN; +import static java.lang.Double.NaN; +import static java.lang.Double.isNaN; + +public class RemoteSourceStatsRule + extends SimpleStatsRule +{ + private static final Pattern PATTERN = remoteSourceNode(); + + public RemoteSourceStatsRule(StatsNormalizer normalizer) + { + super(normalizer); + } + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + protected Optional doCalculate(RemoteSourceNode node, Context context) + { + Optional estimate = Optional.empty(); + RuntimeInfoProvider runtimeInfoProvider = context.runtimeInfoProvider(); + + for (int i = 0; i < node.getSourceFragmentIds().size(); i++) { + PlanFragmentId planFragmentId = node.getSourceFragmentIds().get(i); + OutputStatsEstimateResult stageRuntimeStats = runtimeInfoProvider.getRuntimeOutputStats(planFragmentId); + + PlanNodeStatsEstimate stageEstimatedStats = getEstimatedStats(runtimeInfoProvider, context.statsProvider(), planFragmentId); + PlanNodeStatsEstimate adjustedStageStats = adjustStats( + node.getOutputSymbols(), + context.types(), + stageRuntimeStats, + stageEstimatedStats); + + estimate = estimate + .map(planNodeStatsEstimate -> addStatsAndMaxDistinctValues(planNodeStatsEstimate, adjustedStageStats)) + .or(() -> Optional.of(adjustedStageStats)); + } + + verify(estimate.isPresent()); + return estimate; + } + + private PlanNodeStatsEstimate getEstimatedStats( + RuntimeInfoProvider runtimeInfoProvider, + StatsProvider statsProvider, + PlanFragmentId fragmentId) + { + PlanFragment fragment = runtimeInfoProvider.getPlanFragment(fragmentId); + PlanNode fragmentRoot = fragment.getRoot(); + PlanNodeStatsEstimate estimate = fragment.getStatsAndCosts().getStats().get(fragmentRoot.getId()); + // We will not have stats for the root node in a PlanFragment if collect_plan_statistics_for_all_queries + // is disabled and query isn't an explain analyze. + if (estimate != null && !estimate.isOutputRowCountUnknown()) { + return estimate; + } + return statsProvider.getStats(fragmentRoot); + } + + private PlanNodeStatsEstimate adjustStats( + List outputs, + TypeProvider typeProvider, + OutputStatsEstimateResult runtimeStats, + PlanNodeStatsEstimate estimateStats) + { + if (runtimeStats.isUnknown()) { + return estimateStats; + } + + // We prefer runtime stats over estimated stats, because runtime stats are more accurate. 
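The adjustStats logic that follows splits the measured output size into a fixed-width portion and a variable-width remainder, then spreads the remainder over the variable-width values. A minimal sketch of the arithmetic with hypothetical numbers (1,000 output rows, 20,000 bytes reported at runtime, one BIGINT and one VARCHAR output symbol, no nulls); the variable names mirror the locals computed below:

    long outputRowCountEstimate = 1_000;
    double runtimeOutputDataSize = 20_000;                      // OutputDataSizeEstimate.getTotalSizeInBytes()
    double fixedWidthTypeSize = outputRowCountEstimate * 8;     // BIGINT contributes 8 bytes per non-null row -> 8,000
    double variableTypeValuesCount = outputRowCountEstimate;    // one non-null VARCHAR value per row
    double variableTypeValueAverageSize =
            (runtimeOutputDataSize - fixedWidthTypeSize) / variableTypeValuesCount;   // (20,000 - 8,000) / 1,000 = 12 bytes

With these inputs each variable-width symbol gets its averageRowSize overridden to 12 bytes, while fixed-width symbols keep the statistics taken from the estimated plan.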
+ OutputDataSizeEstimate outputDataSizeEstimate = runtimeStats.outputDataSizeEstimate(); + PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.builder() + .setOutputRowCount(runtimeStats.outputRowCountEstimate()); + + double fixedWidthTypeSize = 0; + double variableTypeValuesCount = 0; + + for (Symbol outputSymbol : outputs) { + Type type = typeProvider.get(outputSymbol); + SymbolStatsEstimate symbolStatistics = estimateStats.getSymbolStatistics(outputSymbol); + double nullsFraction = firstNonNaN(symbolStatistics.getNullsFraction(), 0d); + double numberOfNonNullRows = runtimeStats.outputRowCountEstimate() * (1.0 - nullsFraction); + + if (type instanceof FixedWidthType) { + fixedWidthTypeSize += numberOfNonNullRows * ((FixedWidthType) type).getFixedSize(); + } + else { + variableTypeValuesCount += numberOfNonNullRows; + } + } + + double runtimeOutputDataSize = outputDataSizeEstimate.getTotalSizeInBytes(); + double variableTypeValueAverageSize = NaN; + if (variableTypeValuesCount > 0 && runtimeOutputDataSize > fixedWidthTypeSize) { + variableTypeValueAverageSize = (runtimeOutputDataSize - fixedWidthTypeSize) / variableTypeValuesCount; + } + + for (Symbol outputSymbol : outputs) { + SymbolStatsEstimate symbolStatistics = estimateStats.getSymbolStatistics(outputSymbol); + Type type = typeProvider.get(outputSymbol); + if (!(isNaN(variableTypeValueAverageSize) || type instanceof FixedWidthType)) { + symbolStatistics = SymbolStatsEstimate.buildFrom(symbolStatistics) + .setAverageRowSize(variableTypeValueAverageSize) + .build(); + } + result.addSymbolStatistics(outputSymbol, symbolStatistics); + } + + return result.build(); + } +} diff --git a/core/trino-main/src/main/java/io/trino/cost/RuntimeInfoProvider.java b/core/trino-main/src/main/java/io/trino/cost/RuntimeInfoProvider.java new file mode 100644 index 000000000000..23f3d6629d3e --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/cost/RuntimeInfoProvider.java @@ -0,0 +1,58 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.cost; + +import io.trino.sql.planner.PlanFragment; +import io.trino.sql.planner.plan.PlanFragmentId; + +import java.util.List; + +import static io.trino.execution.scheduler.faulttolerant.OutputStatsEstimator.OutputStatsEstimateResult; + +/** + * Provides runtime information from FTE execution. This is used to re-optimize the plan based + * on the actual runtime statistics. 
+ */ +public interface RuntimeInfoProvider +{ + OutputStatsEstimateResult getRuntimeOutputStats(PlanFragmentId planFragmentId); + + PlanFragment getPlanFragment(PlanFragmentId planFragmentId); + + List getAllPlanFragments(); + + static RuntimeInfoProvider noImplementation() + { + return new RuntimeInfoProvider() + { + @Override + public OutputStatsEstimateResult getRuntimeOutputStats(PlanFragmentId planFragmentId) + { + throw new UnsupportedOperationException("RuntimeInfoProvider is not implemented"); + } + + @Override + public PlanFragment getPlanFragment(PlanFragmentId planFragmentId) + { + throw new UnsupportedOperationException("RuntimeInfoProvider is not implemented"); + } + + @Override + public List getAllPlanFragments() + { + throw new UnsupportedOperationException("RuntimeInfoProvider is not implemented"); + } + }; + } +} diff --git a/core/trino-main/src/main/java/io/trino/cost/StaticRuntimeInfoProvider.java b/core/trino-main/src/main/java/io/trino/cost/StaticRuntimeInfoProvider.java new file mode 100644 index 000000000000..6a7380a09d32 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/cost/StaticRuntimeInfoProvider.java @@ -0,0 +1,61 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.trino.cost; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.trino.sql.planner.PlanFragment; +import io.trino.sql.planner.plan.PlanFragmentId; + +import java.util.List; +import java.util.Map; + +import static io.trino.execution.scheduler.faulttolerant.OutputStatsEstimator.OutputStatsEstimateResult; +import static java.util.Objects.requireNonNull; + +public class StaticRuntimeInfoProvider + implements RuntimeInfoProvider +{ + private final Map runtimeOutputStats; + private final Map planFragments; + + public StaticRuntimeInfoProvider( + Map runtimeOutputStats, + Map planFragments) + { + this.runtimeOutputStats = ImmutableMap.copyOf(requireNonNull(runtimeOutputStats, "runtimeOutputStats is null")); + this.planFragments = ImmutableMap.copyOf(requireNonNull(planFragments, "planFragments is null")); + } + + @Override + public OutputStatsEstimateResult getRuntimeOutputStats(PlanFragmentId planFragmentId) + { + return runtimeOutputStats.getOrDefault(planFragmentId, OutputStatsEstimateResult.unknown()); + } + + @Override + public PlanFragment getPlanFragment(PlanFragmentId planFragmentId) + { + PlanFragment planFragment = planFragments.get(planFragmentId); + requireNonNull(planFragment, "planFragment must not be null: %s".formatted(planFragmentId)); + return planFragment; + } + + @Override + public List getAllPlanFragments() + { + return ImmutableList.copyOf(planFragments.values()); + } +} diff --git a/core/trino-main/src/main/java/io/trino/cost/StatsCalculator.java b/core/trino-main/src/main/java/io/trino/cost/StatsCalculator.java index eeb8dc08ba6e..3a0b7d235585 100644 --- a/core/trino-main/src/main/java/io/trino/cost/StatsCalculator.java +++ b/core/trino-main/src/main/java/io/trino/cost/StatsCalculator.java @@ 
-40,7 +40,8 @@ record Context( Lookup lookup, Session session, TypeProvider types, - TableStatsProvider tableStatsProvider) + TableStatsProvider tableStatsProvider, + RuntimeInfoProvider runtimeInfoProvider) { public Context { @@ -49,6 +50,7 @@ record Context( requireNonNull(session, "session is null"); requireNonNull(types, "types is null"); requireNonNull(tableStatsProvider, "tableStatsProvider is null"); + requireNonNull(runtimeInfoProvider, "runtimeInfoProvider is null"); } } } diff --git a/core/trino-main/src/main/java/io/trino/cost/StatsCalculatorModule.java b/core/trino-main/src/main/java/io/trino/cost/StatsCalculatorModule.java index 539b34ab3bf7..ad10a6112262 100644 --- a/core/trino-main/src/main/java/io/trino/cost/StatsCalculatorModule.java +++ b/core/trino-main/src/main/java/io/trino/cost/StatsCalculatorModule.java @@ -85,6 +85,7 @@ public List> get() rules.add(new SampleStatsRule(normalizer)); rules.add(new SortStatsRule()); rules.add(new DynamicFilterSourceStatsRule()); + rules.add(new RemoteSourceStatsRule(normalizer)); return rules.build(); } diff --git a/core/trino-main/src/main/java/io/trino/execution/QueryManagerConfig.java b/core/trino-main/src/main/java/io/trino/execution/QueryManagerConfig.java index eee5560cb84f..bd7cce16ea7f 100644 --- a/core/trino-main/src/main/java/io/trino/execution/QueryManagerConfig.java +++ b/core/trino-main/src/main/java/io/trino/execution/QueryManagerConfig.java @@ -145,6 +145,7 @@ public class QueryManagerConfig private double faultTolerantExecutionSmallStageSourceSizeMultiplier = 1.2; private boolean faultTolerantExecutionSmallStageRequireNoMorePartitions; private boolean faultTolerantExecutionStageEstimationForEagerParentEnabled = true; + private boolean faultTolerantExecutionAdaptiveQueryPlanningEnabled; @Min(1) public int getScheduleSplitBatchSize() @@ -1105,6 +1106,19 @@ public QueryManagerConfig setFaultTolerantExecutionStageEstimationForEagerParent return this; } + public boolean isFaultTolerantExecutionAdaptiveQueryPlanningEnabled() + { + return faultTolerantExecutionAdaptiveQueryPlanningEnabled; + } + + @Config("fault-tolerant-execution-adaptive-query-planning-enabled") + @ConfigDescription("Enable adaptive query planning for the fault tolerant execution") + public QueryManagerConfig setFaultTolerantExecutionAdaptiveQueryPlanningEnabled(boolean faultTolerantExecutionAdaptiveQueryPlanningEnabled) + { + this.faultTolerantExecutionAdaptiveQueryPlanningEnabled = faultTolerantExecutionAdaptiveQueryPlanningEnabled; + return this; + } + public void applyFaultTolerantExecutionDefaults() { remoteTaskMaxErrorDuration = new Duration(1, MINUTES); diff --git a/core/trino-main/src/main/java/io/trino/execution/SqlQueryExecution.java b/core/trino-main/src/main/java/io/trino/execution/SqlQueryExecution.java index d7b7ee488a0a..d924f73eaebf 100644 --- a/core/trino-main/src/main/java/io/trino/execution/SqlQueryExecution.java +++ b/core/trino-main/src/main/java/io/trino/execution/SqlQueryExecution.java @@ -13,6 +13,7 @@ */ package io.trino.execution; +import com.google.common.collect.ImmutableList; import com.google.common.util.concurrent.ListenableFuture; import com.google.errorprone.annotations.ThreadSafe; import com.google.inject.Inject; @@ -24,6 +25,7 @@ import io.opentelemetry.context.Context; import io.trino.Session; import io.trino.SystemSessionProperties; +import io.trino.cost.CachingTableStatsProvider; import io.trino.cost.CostCalculator; import io.trino.cost.StatsCalculator; import io.trino.exchange.ExchangeManagerRegistry; @@ -56,6 +58,7 @@ 
import io.trino.sql.analyzer.Analysis; import io.trino.sql.analyzer.Analyzer; import io.trino.sql.analyzer.AnalyzerFactory; +import io.trino.sql.planner.AdaptivePlanner; import io.trino.sql.planner.InputExtractor; import io.trino.sql.planner.IrTypeAnalyzer; import io.trino.sql.planner.LogicalPlanner; @@ -67,6 +70,7 @@ import io.trino.sql.planner.PlanOptimizersFactory; import io.trino.sql.planner.SplitSourceFactory; import io.trino.sql.planner.SubPlan; +import io.trino.sql.planner.optimizations.AdaptivePlanOptimizer; import io.trino.sql.planner.optimizations.PlanOptimizer; import io.trino.sql.planner.plan.OutputNode; import io.trino.sql.tree.ExplainAnalyze; @@ -95,6 +99,7 @@ import static io.trino.execution.QueryState.PLANNING; import static io.trino.server.DynamicFilterService.DynamicFiltersStats; import static io.trino.spi.StandardErrorCode.STACK_OVERFLOW; +import static io.trino.sql.planner.sanity.PlanSanityChecker.DISTRIBUTED_PLAN_SANITY_CHECKER; import static io.trino.tracing.ScopedSpan.scopedSpan; import static java.lang.Thread.currentThread; import static java.util.Objects.requireNonNull; @@ -116,6 +121,7 @@ public class SqlQueryExecution private final OutputStatsEstimatorFactory outputStatsEstimatorFactory; private final TaskExecutionStats taskExecutionStats; private final List planOptimizers; + private final List adaptivePlanOptimizers; private final PlanFragmenter planFragmenter; private final RemoteTaskFactory remoteTaskFactory; private final int scheduleSplitBatchSize; @@ -155,6 +161,7 @@ private SqlQueryExecution( OutputStatsEstimatorFactory outputStatsEstimatorFactory, TaskExecutionStats taskExecutionStats, List planOptimizers, + List adaptivePlanOptimizers, PlanFragmenter planFragmenter, RemoteTaskFactory remoteTaskFactory, int scheduleSplitBatchSize, @@ -208,6 +215,9 @@ private SqlQueryExecution( // analyze query this.analysis = analyze(preparedQuery, stateMachine, warningCollector, planOptimizersStatsCollector, analyzerFactory); + // for adaptive planner + this.adaptivePlanOptimizers = ImmutableList.copyOf(requireNonNull(adaptivePlanOptimizers, "adaptivePlanOptimizers is null")); + stateMachine.addStateChangeListener(state -> { if (!state.isDone()) { return; @@ -401,11 +411,12 @@ public void start() }, directExecutor()); try { - PlanRoot plan = planQuery(); + CachingTableStatsProvider tableStatsProvider = new CachingTableStatsProvider(plannerContext.getMetadata(), getSession()); + PlanRoot plan = planQuery(tableStatsProvider); // DynamicFilterService needs plan for query to be registered. // Query should be registered before dynamic filter suppliers are requested in distribution planning. 
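As a brief aside on the knob introduced above: adaptive planning is off by default and is surfaced both as a config property and a session property. A minimal sketch of toggling it programmatically; the getter, setter, and default come from this diff, while the standalone usage itself is illustrative:

    QueryManagerConfig config = new QueryManagerConfig();
    boolean defaultValue = config.isFaultTolerantExecutionAdaptiveQueryPlanningEnabled();   // false by default
    // Same switch as the fault-tolerant-execution-adaptive-query-planning-enabled config property and the
    // fault_tolerant_execution_adaptive_query_planning_enabled session property defined earlier in this diff.
    config.setFaultTolerantExecutionAdaptiveQueryPlanningEnabled(true);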
registerDynamicFilteringQuery(plan); - planDistribution(plan); + planDistribution(plan, tableStatsProvider); } finally { synchronized (this) { @@ -457,20 +468,20 @@ public void addFinalQueryInfoListener(StateChangeListener stateChange stateMachine.addQueryInfoStateChangeListener(stateChangeListener); } - private PlanRoot planQuery() + private PlanRoot planQuery(CachingTableStatsProvider tableStatsProvider) { Span span = tracer.spanBuilder("planner") .setParent(Context.current().with(getSession().getQuerySpan())) .startSpan(); try (var ignored = scopedSpan(span)) { - return doPlanQuery(); + return doPlanQuery(tableStatsProvider); } catch (StackOverflowError e) { throw new TrinoException(STACK_OVERFLOW, "statement is too large (stack overflow during analysis)", e); } } - private PlanRoot doPlanQuery() + private PlanRoot doPlanQuery(CachingTableStatsProvider tableStatsProvider) { // plan query PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator(); @@ -482,7 +493,8 @@ private PlanRoot doPlanQuery() statsCalculator, costCalculator, stateMachine.getWarningCollector(), - planOptimizersStatsCollector); + planOptimizersStatsCollector, + tableStatsProvider); Plan plan = logicalPlanner.plan(analysis); queryPlan.set(plan); @@ -503,7 +515,7 @@ private PlanRoot doPlanQuery() return new PlanRoot(fragmentedPlan, !explainAnalyze); } - private void planDistribution(PlanRoot plan) + private void planDistribution(PlanRoot plan, CachingTableStatsProvider tableStatsProvider) { // if query was canceled, skip creating scheduler if (stateMachine.isDone()) { @@ -563,6 +575,16 @@ private void planDistribution(PlanRoot plan) failureDetector, dynamicFilterService, taskExecutionStats, + new AdaptivePlanner( + stateMachine.getSession(), + plannerContext, + adaptivePlanOptimizers, + planFragmenter, + DISTRIBUTED_PLAN_SANITY_CHECKER, + typeAnalyzer, + stateMachine.getWarningCollector(), + planOptimizersStatsCollector, + tableStatsProvider), plan.getRoot()); break; default: @@ -755,6 +777,7 @@ public static class SqlQueryExecutionFactory private final OutputStatsEstimatorFactory outputStatsEstimatorFactory; private final TaskExecutionStats taskExecutionStats; private final List planOptimizers; + private final List adaptivePlanOptimizers; private final PlanFragmenter planFragmenter; private final RemoteTaskFactory remoteTaskFactory; private final ExecutorService queryExecutor; @@ -823,7 +846,9 @@ public static class SqlQueryExecutionFactory this.failureDetector = requireNonNull(failureDetector, "failureDetector is null"); this.nodeTaskMap = requireNonNull(nodeTaskMap, "nodeTaskMap is null"); this.executionPolicies = requireNonNull(executionPolicies, "executionPolicies is null"); - this.planOptimizers = planOptimizersFactory.get(); + requireNonNull(planOptimizersFactory, "planOptimizersFactory is null"); + this.planOptimizers = planOptimizersFactory.getPlanOptimizers(); + this.adaptivePlanOptimizers = planOptimizersFactory.getAdaptivePlanOptimizers(); this.statsCalculator = requireNonNull(statsCalculator, "statsCalculator is null"); this.costCalculator = requireNonNull(costCalculator, "costCalculator is null"); this.dynamicFilterService = requireNonNull(dynamicFilterService, "dynamicFilterService is null"); @@ -862,6 +887,7 @@ public QueryExecution createQueryExecution( outputStatsEstimatorFactory, taskExecutionStats, planOptimizers, + adaptivePlanOptimizers, planFragmenter, remoteTaskFactory, scheduleSplitBatchSize, diff --git 
a/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/ByEagerParentOutputStatsEstimator.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/ByEagerParentOutputStatsEstimator.java index 90a1fbebf243..75a475e8e317 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/ByEagerParentOutputStatsEstimator.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/ByEagerParentOutputStatsEstimator.java @@ -49,6 +49,6 @@ public Optional getEstimatedOutputStats(StageExecutio for (int i = 0; i < outputPartitionsCount; ++i) { estimateBuilder.add(0); } - return Optional.of(new OutputStatsEstimateResult(estimateBuilder.build(), 0, "FOR_EAGER_PARENT")); + return Optional.of(new OutputStatsEstimateResult(estimateBuilder.build(), 0, "FOR_EAGER_PARENT", false)); } } diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/BySmallStageOutputStatsEstimator.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/BySmallStageOutputStatsEstimator.java index ca9f210c90c3..a8ea496bd93c 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/BySmallStageOutputStatsEstimator.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/BySmallStageOutputStatsEstimator.java @@ -143,6 +143,6 @@ public Optional getEstimatedOutputStats(StageExecutio estimateBuilder.add(inputSizeEstimate / outputPartitionsCount); } // TODO: For now we can skip calculating outputRowCountEstimate since we won't run adaptive planner in the case of small inputs - return Optional.of(new OutputStatsEstimateResult(estimateBuilder.build(), 0, "BY_SMALL_INPUT")); + return Optional.of(new OutputStatsEstimateResult(estimateBuilder.build(), 0, "BY_SMALL_INPUT", false)); } } diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/ByTaskProgressOutputStatsEstimator.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/ByTaskProgressOutputStatsEstimator.java index 25ea6f610f8e..12b5a3e3ac80 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/ByTaskProgressOutputStatsEstimator.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/ByTaskProgressOutputStatsEstimator.java @@ -72,6 +72,6 @@ public Optional getEstimatedOutputStats(StageExecutio estimateBuilder.add((long) (partitionSize / progress)); } long outputRowCountEstimate = (long) (stageExecution.getOutputRowCount() / progress); - return Optional.of(new OutputStatsEstimateResult(new OutputDataSizeEstimate(estimateBuilder.build()), outputRowCountEstimate, "BY_PROGRESS")); + return Optional.of(new OutputStatsEstimateResult(new OutputDataSizeEstimate(estimateBuilder.build()), outputRowCountEstimate, "BY_PROGRESS", true)); } } diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/EventDrivenFaultTolerantQueryScheduler.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/EventDrivenFaultTolerantQueryScheduler.java index 221e0ded36db..d98330010c82 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/EventDrivenFaultTolerantQueryScheduler.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/EventDrivenFaultTolerantQueryScheduler.java @@ -42,6 +42,8 @@ import io.opentelemetry.api.trace.Tracer; import io.opentelemetry.context.Context; import io.trino.Session; 
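The estimator changes above add a fourth component to OutputStatsEstimateResult: progress-based estimates are flagged as accurate, while the eager-parent and small-stage heuristics are not. A rough sketch of what the distinction looks like to a consumer; the four-argument constructor, the kind strings, and the flag semantics come from this diff, the numbers are made up:

    // Measured from real task progress, so the adaptive planner may treat it as accurate.
    OutputStatsEstimateResult byProgress = new OutputStatsEstimateResult(
            ImmutableLongArray.of(96_000, 104_000, 99_500),   // per-partition output bytes (Guava ImmutableLongArray)
            250_000,                                          // estimated output row count
            "BY_PROGRESS",
            true);
    // Heuristic guess made for an eager parent, never treated as accurate.
    OutputStatsEstimateResult forEagerParent =
            new OutputStatsEstimateResult(ImmutableLongArray.of(0, 0, 0), 0, "FOR_EAGER_PARENT", false);
    boolean safeToSkipReplanning = byProgress.isAccurate();   // the adaptive planner keys off this flag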
+import io.trino.cost.RuntimeInfoProvider; +import io.trino.cost.StaticRuntimeInfoProvider; import io.trino.exchange.ExchangeContextInstance; import io.trino.exchange.SpoolingExchangeInput; import io.trino.execution.BasicStageStats; @@ -91,10 +93,9 @@ import io.trino.spi.exchange.ExchangeSourceHandle; import io.trino.spi.exchange.ExchangeSourceOutputSelector; import io.trino.split.RemoteSplit; +import io.trino.sql.planner.AdaptivePlanner; import io.trino.sql.planner.NodePartitioningManager; import io.trino.sql.planner.PlanFragment; -import io.trino.sql.planner.PlanFragmentIdAllocator; -import io.trino.sql.planner.PlanNodeIdAllocator; import io.trino.sql.planner.SubPlan; import io.trino.sql.planner.optimizations.PlanNodeSearcher; import io.trino.sql.planner.plan.AggregationNode; @@ -116,7 +117,6 @@ import java.lang.ref.SoftReference; import java.util.ArrayDeque; import java.util.ArrayList; -import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; @@ -144,8 +144,6 @@ import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; -import static com.google.common.collect.ImmutableSet.toImmutableSet; -import static io.airlift.units.DataSize.succinctBytes; import static io.trino.SystemSessionProperties.getFaultTolerantExecutionMaxPartitionCount; import static io.trino.SystemSessionProperties.getFaultTolerantExecutionRuntimeAdaptivePartitioningMaxTaskSize; import static io.trino.SystemSessionProperties.getFaultTolerantExecutionRuntimeAdaptivePartitioningPartitionCount; @@ -156,6 +154,7 @@ import static io.trino.SystemSessionProperties.getRetryMaxDelay; import static io.trino.SystemSessionProperties.getRetryPolicy; import static io.trino.SystemSessionProperties.getTaskRetryAttemptsPerTask; +import static io.trino.SystemSessionProperties.isFaultTolerantExecutionAdaptiveQueryPlanningEnabled; import static io.trino.SystemSessionProperties.isFaultTolerantExecutionRuntimeAdaptivePartitioningEnabled; import static io.trino.SystemSessionProperties.isFaultTolerantExecutionStageEstimationForEagerParentEnabled; import static io.trino.execution.BasicStageStats.aggregateBasicStageStats; @@ -178,14 +177,9 @@ import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static io.trino.spi.StandardErrorCode.REMOTE_HOST_GONE; import static io.trino.spi.exchange.Exchange.SourceHandlesDeliveryMode.EAGER; -import static io.trino.sql.planner.RuntimeAdaptivePartitioningRewriter.consumesHashPartitionedInput; -import static io.trino.sql.planner.RuntimeAdaptivePartitioningRewriter.getMaxPlanFragmentId; -import static io.trino.sql.planner.RuntimeAdaptivePartitioningRewriter.getMaxPlanId; -import static io.trino.sql.planner.RuntimeAdaptivePartitioningRewriter.overridePartitionCountRecursively; import static io.trino.sql.planner.SystemPartitioningHandle.COORDINATOR_DISTRIBUTION; import static io.trino.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION; import static io.trino.sql.planner.TopologicalOrderSubPlanVisitor.sortPlanInTopologicalOrder; -import static io.trino.sql.planner.plan.ExchangeNode.Type.REPLICATE; import static io.trino.tracing.TrinoAttributes.FAILURE_MESSAGE; import static io.trino.util.Failures.toFailure; import static java.lang.Math.max; @@ -223,6 +217,8 @@ public class EventDrivenFaultTolerantQueryScheduler private final FailureDetector failureDetector; private final DynamicFilterService 
dynamicFilterService; private final TaskExecutionStats taskExecutionStats; + private final AdaptivePlanner adaptivePlanner; + private final boolean adaptiveQueryPlanningEnabled; private final SubPlan originalPlan; private final boolean stageEstimationForEagerParentEnabled; @@ -253,6 +249,7 @@ public EventDrivenFaultTolerantQueryScheduler( FailureDetector failureDetector, DynamicFilterService dynamicFilterService, TaskExecutionStats taskExecutionStats, + AdaptivePlanner adaptivePlanner, SubPlan originalPlan) { this.queryStateMachine = requireNonNull(queryStateMachine, "queryStateMachine is null"); @@ -276,6 +273,8 @@ public EventDrivenFaultTolerantQueryScheduler( this.failureDetector = requireNonNull(failureDetector, "failureDetector is null"); this.dynamicFilterService = requireNonNull(dynamicFilterService, "dynamicFilterService is null"); this.taskExecutionStats = requireNonNull(taskExecutionStats, "taskExecutionStats is null"); + this.adaptivePlanner = requireNonNull(adaptivePlanner, "adaptivePlanner is null"); + this.adaptiveQueryPlanningEnabled = isFaultTolerantExecutionAdaptiveQueryPlanningEnabled(queryStateMachine.getSession()); this.originalPlan = requireNonNull(originalPlan, "originalPlan is null"); this.stageEstimationForEagerParentEnabled = isFaultTolerantExecutionStageEstimationForEagerParentEnabled(queryStateMachine.getSession()); @@ -359,7 +358,9 @@ public synchronized void start() isFaultTolerantExecutionRuntimeAdaptivePartitioningEnabled(session), getFaultTolerantExecutionRuntimeAdaptivePartitioningPartitionCount(session), getFaultTolerantExecutionRuntimeAdaptivePartitioningMaxTaskSize(session), - stageEstimationForEagerParentEnabled); + stageEstimationForEagerParentEnabled, + adaptivePlanner, + adaptiveQueryPlanningEnabled); queryExecutor.submit(scheduler::run); } catch (Throwable t) { @@ -447,11 +448,7 @@ public StageInfo getStageInfo() // make sure that plan is not staler than stageInfos since `getStageInfo` is called asynchronously SubPlan plan = requireNonNull(this.plan.get(), "plan is null"); Set reportedFragments = new HashSet<>(); - StageInfo stageInfo = getStageInfo(plan, stageInfos, reportedFragments); - // TODO Some stages may no longer be present in the plan when adaptive re-planning is implemented - // TODO Figure out how to report statistics for such stages - verify(reportedFragments.containsAll(stageInfos.keySet()), "some stages are left unreported"); - return stageInfo; + return getStageInfo(plan, stageInfos, reportedFragments); } private StageInfo getStageInfo(SubPlan plan, Map infos, Set reportedFragments) @@ -710,6 +707,10 @@ private static class Scheduler private SubPlan plan; private List planInTopologicalOrder; + + private final AdaptivePlanner adaptivePlanner; + private final boolean adaptiveQueryPlanningEnabled; + private final Map stageExecutions = new HashMap<>(); private final Map isReadyForExecutionCache = new HashMap<>(); private final SetMultimap stageConsumers = HashMultimap.create(); @@ -753,7 +754,9 @@ public Scheduler( boolean runtimeAdaptivePartitioningEnabled, int runtimeAdaptivePartitioningPartitionCount, DataSize runtimeAdaptivePartitioningMaxTaskSize, - boolean stageEstimationForEagerParentEnabled) + boolean stageEstimationForEagerParentEnabled, + AdaptivePlanner adaptivePlanner, + boolean adaptiveQueryPlanningEnabled) { this.queryStateMachine = requireNonNull(queryStateMachine, "queryStateMachine is null"); this.metadata = requireNonNull(metadata, "metadata is null"); @@ -785,6 +788,8 @@ public Scheduler( 
this.runtimeAdaptivePartitioningEnabled = runtimeAdaptivePartitioningEnabled; this.runtimeAdaptivePartitioningPartitionCount = runtimeAdaptivePartitioningPartitionCount; this.runtimeAdaptivePartitioningMaxTaskSizeInBytes = requireNonNull(runtimeAdaptivePartitioningMaxTaskSize, "runtimeAdaptivePartitioningMaxTaskSize is null").toBytes(); + this.adaptivePlanner = requireNonNull(adaptivePlanner, "adaptivePlanner is null"); + this.adaptiveQueryPlanningEnabled = adaptiveQueryPlanningEnabled; this.stageEstimationForEagerParentEnabled = stageEstimationForEagerParentEnabled; this.schedulerSpan = tracer.spanBuilder("scheduler") .setParent(Context.current().with(queryStateMachine.getSession().getQuerySpan())) @@ -1079,22 +1084,12 @@ private SubPlan optimizePlan(SubPlan plan) { // Re-optimize plan here based on available runtime statistics. // Fragments changed due to re-optimization as well as their downstream stages are expected to be assigned new fragment ids. - plan = updateStagesPartitioning(plan); - return plan; - } - - private SubPlan updateStagesPartitioning(SubPlan plan) - { - if (!runtimeAdaptivePartitioningEnabled || runtimeAdaptivePartitioningApplied) { + if (!adaptiveQueryPlanningEnabled) { return plan; } for (SubPlan subPlan : planInTopologicalOrder) { PlanFragment fragment = subPlan.getFragment(); - if (!consumesHashPartitionedInput(fragment)) { - // no input hash partitioning present - continue; - } StageId stageId = getStageId(fragment.getId()); if (stageExecutions.containsKey(stageId)) { @@ -1102,61 +1097,55 @@ private SubPlan updateStagesPartitioning(SubPlan plan) continue; } + if (subPlan.getChildren().isEmpty()) { + // Skip leaf fragments since adaptive planner can't do much with them. + continue; + } + IsReadyForExecutionResult isReadyForExecutionResult = isReadyForExecution(subPlan); // Caching is not only needed to avoid duplicate calls, but also to avoid the case that a stage that // is not ready now but becomes ready when updateStageExecutions. - // We want to avoid starting an execution without considering changing the number of partitions. + // We want to avoid starting an execution without considering changes in plan. 
// TODO: think about how to eliminate the cache - isReadyForExecutionCache.put(subPlan, isReadyForExecutionResult); - if (!isReadyForExecutionResult.isReadyForExecution()) { - // not ready for execution - continue; - } + IsReadyForExecutionResult oldValue = isReadyForExecutionCache.put(subPlan, isReadyForExecutionResult); - // calculate (estimated) input data size to determine if we want to change number of partitions at runtime - List partitionedInputBytes = fragment.getRemoteSourceNodes().stream() - .filter(remoteSourceNode -> remoteSourceNode.getExchangeType() != REPLICATE) - .map(remoteSourceNode -> remoteSourceNode.getSourceFragmentIds().stream() - .mapToLong(sourceFragmentId -> { - StageId sourceStageId = getStageId(sourceFragmentId); - OutputDataSizeEstimate outputDataSizeEstimate = isReadyForExecutionResult.getSourceOutputSizeEstimates().get(sourceStageId); - verify(outputDataSizeEstimate != null, "outputDataSizeEstimate not found for source stage %s", sourceStageId); - return outputDataSizeEstimate.getTotalSizeInBytes(); - }) - .sum()) - .collect(toImmutableList()); - // Currently the memory estimation is simplified: - // if it's an aggregation, then we use the total input bytes as the memory consumption - // if it involves multiple joins, conservatively we assume the smallest remote source will be streamed through - // and use the sum of input bytes of other remote sources as the memory consumption - // TODO: more accurate memory estimation based on context (https://github.com/trinodb/trino/issues/18698) - long estimatedMemoryConsumptionInBytes = (partitionedInputBytes.size() == 1) ? partitionedInputBytes.get(0) : - partitionedInputBytes.stream().mapToLong(Long::longValue).sum() - Collections.min(partitionedInputBytes); - - int partitionCount = fragment.getPartitionCount().orElse(maxPartitionCount); - if (estimatedMemoryConsumptionInBytes > runtimeAdaptivePartitioningMaxTaskSizeInBytes * partitionCount) { - log.info("Stage %s has an estimated memory consumption of %s, changing partition count from %s to %s", - stageId, succinctBytes(estimatedMemoryConsumptionInBytes), partitionCount, runtimeAdaptivePartitioningPartitionCount); - runtimeAdaptivePartitioningApplied = true; - PlanFragmentIdAllocator planFragmentIdAllocator = new PlanFragmentIdAllocator(getMaxPlanFragmentId(planInTopologicalOrder) + 1); - PlanNodeIdAllocator planNodeIdAllocator = new PlanNodeIdAllocator(getMaxPlanId(planInTopologicalOrder) + 1); - return overridePartitionCountRecursively( - plan, - partitionCount, - runtimeAdaptivePartitioningPartitionCount, - planFragmentIdAllocator, - planNodeIdAllocator, - planInTopologicalOrder.stream() - .map(SubPlan::getFragment) - .map(PlanFragment::getId) - .filter(planFragmentId -> stageExecutions.containsKey(getStageId(planFragmentId))) - .collect(toImmutableSet())); + // Run adaptive planner only if the stage is ready for execution, and it was not ready before. + // The second condition ensures that we don't repeatedly re-optimize the plan if the stage was + // already ready for execution. 
+ if (isReadyForExecutionResult.isReadyForExecution() + && (oldValue == null || !oldValue.isReadyForExecution())) { + return adaptivePlanner.optimize(plan, createRuntimeInfoProvider()); } } return plan; } + private RuntimeInfoProvider createRuntimeInfoProvider() + { + ImmutableMap.Builder stageRuntimeOutputStats = ImmutableMap.builder(); + ImmutableMap.Builder planFragments = ImmutableMap.builder(); + planInTopologicalOrder.forEach(subPlan -> planFragments.put(subPlan.getFragment().getId(), subPlan.getFragment())); + stageExecutions.forEach((stageId, stageExecution) -> { + if (isStageRuntimeStatsReady(stageExecution)) { + OutputStatsEstimateResult runtimeOutputStats = stageExecution.getOutputStats(stageExecutions::get, false).get(); + stageRuntimeOutputStats.put( + stageExecution.getStageFragment().getId(), + runtimeOutputStats); + } + }); + + return new StaticRuntimeInfoProvider(stageRuntimeOutputStats.buildOrThrow(), planFragments.buildOrThrow()); + } + + private boolean isStageRuntimeStatsReady(StageExecution stageExecution) + { + return stageExecution + .getOutputStats(stageExecutions::get, false) + .map(OutputStatsEstimateResult::isAccurate) + .orElse(false); + } + private void updateStageExecutions() { Set currentPlanStages = new HashSet<>(); @@ -1987,6 +1976,11 @@ public PlanFragmentId getStageFragmentId() return stage.getFragment().getId(); } + public PlanFragment getStageFragment() + { + return stage.getFragment(); + } + public StageState getState() { return stage.getState(); @@ -2435,7 +2429,7 @@ public Optional getOutputStats(Function getEstimatedOutputStats( EventDrivenFaultTolerantQueryScheduler.StageExecution stageExecution, Function stageExecutionLookup, @@ -32,11 +34,12 @@ Optional getEstimatedOutputStats( record OutputStatsEstimateResult( OutputDataSizeEstimate outputDataSizeEstimate, long outputRowCountEstimate, - String kind) + String kind, + boolean isAccurate) { - OutputStatsEstimateResult(ImmutableLongArray partitionDataSizes, long outputRowCountEstimate, String kind) + public OutputStatsEstimateResult(ImmutableLongArray partitionDataSizes, long outputRowCountEstimate, String kind, boolean isAccurate) { - this(new OutputDataSizeEstimate(partitionDataSizes), outputRowCountEstimate, kind); + this(new OutputDataSizeEstimate(partitionDataSizes), outputRowCountEstimate, kind, isAccurate); } public OutputStatsEstimateResult @@ -44,5 +47,15 @@ record OutputStatsEstimateResult( requireNonNull(outputDataSizeEstimate, "outputDataSizeEstimate is null"); requireNonNull(kind, "kind is null"); } + + public static OutputStatsEstimateResult unknown() + { + return UNKNOWN; + } + + public boolean isUnknown() + { + return this == UNKNOWN; + } } } diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/QueryExplainer.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/QueryExplainer.java index 6854140023a0..23dcfa0beb30 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/QueryExplainer.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/QueryExplainer.java @@ -15,6 +15,7 @@ import io.trino.Session; import io.trino.client.NodeVersion; +import io.trino.cost.CachingTableStatsProvider; import io.trino.cost.CostCalculator; import io.trino.cost.StatsCalculator; import io.trino.execution.querystats.PlanOptimizersStatsCollector; @@ -76,7 +77,7 @@ public class QueryExplainer CostCalculator costCalculator, NodeVersion version) { - this.planOptimizers = requireNonNull(planOptimizersFactory.get(), "planOptimizers is null"); + this.planOptimizers = 
requireNonNull(planOptimizersFactory.getPlanOptimizers(), "planOptimizers is null"); this.planFragmenter = requireNonNull(planFragmenter, "planFragmenter is null"); this.plannerContext = requireNonNull(plannerContext, "plannerContext is null"); this.analyzerFactory = requireNonNull(analyzerFactory, "analyzerFactory is null"); @@ -172,7 +173,8 @@ public Plan getLogicalPlan(Session session, Statement statement, List planOptimizers; + private final PlanFragmenter planFragmenter; + private final PlanSanityChecker planSanityChecker; + private final IrTypeAnalyzer typeAnalyzer; + private final WarningCollector warningCollector; + private final PlanOptimizersStatsCollector planOptimizersStatsCollector; + private final CachingTableStatsProvider tableStatsProvider; + + public AdaptivePlanner( + Session session, + PlannerContext plannerContext, + List planOptimizers, + PlanFragmenter planFragmenter, + PlanSanityChecker planSanityChecker, + IrTypeAnalyzer typeAnalyzer, + WarningCollector warningCollector, + PlanOptimizersStatsCollector planOptimizersStatsCollector, + CachingTableStatsProvider tableStatsProvider) + { + this.session = requireNonNull(session, "session is null"); + this.plannerContext = requireNonNull(plannerContext, "plannerContext is null"); + this.planOptimizers = requireNonNull(planOptimizers, "planOptimizers is null"); + this.planFragmenter = requireNonNull(planFragmenter, "planFragmenter is null"); + this.planSanityChecker = requireNonNull(planSanityChecker, "planSanityChecker is null"); + this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); + this.warningCollector = requireNonNull(warningCollector, "warningCollector is null"); + this.planOptimizersStatsCollector = requireNonNull(planOptimizersStatsCollector, "planOptimizersStatsCollector is null"); + this.tableStatsProvider = requireNonNull(tableStatsProvider, "tableStatsProvider is null"); + } + + public SubPlan optimize(SubPlan root, RuntimeInfoProvider runtimeInfoProvider) + { + // No need to run optimizer since the root is already finished or its stats are almost accurate based on + // estimate by progress. + // TODO: We need add an ability to re-plan fragment whose stats are estimated by progress. + if (runtimeInfoProvider.getRuntimeOutputStats(root.getFragment().getId()).isAccurate()) { + return root; + } + + List subPlans = traverse(root).collect(toImmutableList()); + + // create a new fragment id allocator and symbol allocator + PlanFragmentIdAllocator fragmentIdAllocator = new PlanFragmentIdAllocator(getMaxPlanFragmentId(subPlans) + 1); + SymbolAllocator symbolAllocator = createSymbolAllocator(subPlans); + + // rewrite remote source nodes to exchange nodes, except for fragments which are finisher or whose stats are + // estimated by progress. + ReplaceUnchangedFragmentsWithRemoteSourcesRewriter rewriter = new ReplaceUnchangedFragmentsWithRemoteSourcesRewriter(runtimeInfoProvider, symbolAllocator.getTypes()); + PlanNode currentAdaptivePlan = rewriteWith(rewriter, root.getFragment().getRoot(), root.getChildren()); + + // Remove the adaptive plan node and replace it with initial plan + PlanNode initialPlan = getInitialPlan(currentAdaptivePlan); + // Remove the adaptive plan node and replace it with current plan + PlanNode currentPlan = getCurrentPlan(currentAdaptivePlan); + + // Collect the sub plans for each remote exchange and remote source node. We will use this map during + // re-fragmentation as a cache for all unchanged sub plans. 
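To make the role of that cache concrete, a hypothetical lookup; the node ids are invented, and ExchangeSourceId is the record declared at the end of this class:

    // Exchange node "20" in the current plan reads the sub plan whose fragment root is node "7".
    ExchangeSourceId key = new ExchangeSourceId(new PlanNodeId("20"), new PlanNodeId("7"));
    SubPlan cached = exchangeSourceIdToSubPlan.get(key);
    // If neither node "20" nor node "7" is downstream of a changed node, the cached SubPlan is handed back to
    // createSubPlans untouched, so the fragment keeps its id and speculatively started tasks are not restarted.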
+ ExchangeSourceIdToSubPlanCollector exchangeSourceIdToSubPlanCollector = new ExchangeSourceIdToSubPlanCollector(); + currentPlan.accept(exchangeSourceIdToSubPlanCollector, subPlans); + Map exchangeSourceIdToSubPlan = exchangeSourceIdToSubPlanCollector.getExchangeSourceIdToSubPlan(); + + // optimize the current plan + PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator(getMaxPlanId(currentPlan) + 1); + AdaptivePlanOptimizer.Result optimizationResult = optimizePlan(currentPlan, symbolAllocator, runtimeInfoProvider, idAllocator); + + // Check whether there are some changes in the plan after optimization + if (optimizationResult.changedPlanNodes().isEmpty()) { + return root; + } + + // Add the adaptive plan node recursively where initialPlan remain as it is and optimizedPlan as new currentPlan + PlanNode adaptivePlan = addAdaptivePlanNode(idAllocator, initialPlan, optimizationResult.plan(), symbolAllocator.getTypes(), optimizationResult.changedPlanNodes()); + // validate the adaptive plan + try (var ignored = scopedSpan(plannerContext.getTracer(), "validate-adaptive-plan")) { + planSanityChecker.validateAdaptivePlan(adaptivePlan, session, plannerContext, typeAnalyzer, symbolAllocator.getTypes(), warningCollector); + } + + // Fragment the adaptive plan + return planFragmenter.createSubPlans( + session, + new Plan(adaptivePlan, symbolAllocator.getTypes(), StatsAndCosts.empty()), + false, + warningCollector, + fragmentIdAllocator, + new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), adaptivePlan.getOutputSymbols()), + // We do not change the subPlans which have no changes and are not downstream of the + // changed plan nodes. This optimization is done to avoid unnecessary stage restart due to speculative + // execution. + getUnchangedSubPlans(adaptivePlan, optimizationResult.changedPlanNodes(), exchangeSourceIdToSubPlan)); + } + + private AdaptivePlanOptimizer.Result optimizePlan( + PlanNode plan, + SymbolAllocator symbolAllocator, + RuntimeInfoProvider runtimeInfoProvider, + PlanNodeIdAllocator idAllocator) + { + AdaptivePlanOptimizer.Result result = new AdaptivePlanOptimizer.Result(plan, Set.of()); + ImmutableSet.Builder changedPlanNodes = ImmutableSet.builder(); + for (AdaptivePlanOptimizer optimizer : planOptimizers) { + result = optimizer.optimizeAndMarkPlanChanges( + result.plan(), + new PlanOptimizer.Context( + session, + symbolAllocator.getTypes(), + symbolAllocator, + idAllocator, + warningCollector, + planOptimizersStatsCollector, + tableStatsProvider, + runtimeInfoProvider)); + changedPlanNodes.addAll(result.changedPlanNodes()); + } + return new AdaptivePlanOptimizer.Result(result.plan(), changedPlanNodes.build()); + } + + private PlanNode addAdaptivePlanNode( + PlanNodeIdAllocator idAllocator, + PlanNode initialPlan, + PlanNode optimizedPlanNode, + TypeProvider types, + Set changedPlanNodes) + { + // We should check optimizedPlanNode here instead of initialPlan since it is possible that new + // nodes have been added, and they aren't part of initialPlan. However, we should put the adaptive plan node + // above them. + if (changedPlanNodes.contains(optimizedPlanNode.getId())) { + return new AdaptivePlanNode( + idAllocator.getNextId(), + initialPlan, + getFilteredSymbols(initialPlan, types), + optimizedPlanNode); + } + + // This condition should always be true because if a plan node is changed, then it should be captured in the + // changedPlanNodes set based on the semantics of PlanOptimizer#optimizeAndMarkPlanChanges. 
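To picture where the wrapper lands, assume (hypothetically) that an adaptive rule rewrote only a join node:

    // before re-optimization                  after re-optimization
    //   Output                                  Output
    //     Join        <- id reported changed      AdaptivePlanNode(initialPlan = old Join subtree,
    //       RemoteSource[A]                                        currentPlan = rewritten Join subtree)
    //       RemoteSource[B]

The recursion stops at the changed node and wraps it; ancestors are rebuilt through replaceChildren, and unchanged subtrees stay shared between the initial and current plans. At execution time LocalExecutionPlanner simply delegates to the current plan, as shown by its visitAdaptivePlanNode override later in this diff.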
+ verify(initialPlan.getSources().size() == optimizedPlanNode.getSources().size()); + ImmutableList.Builder sources = ImmutableList.builder(); + for (int i = 0; i < initialPlan.getSources().size(); i++) { + PlanNode initialSource = initialPlan.getSources().get(i); + PlanNode optimizedSource = optimizedPlanNode.getSources().get(i); + sources.add(addAdaptivePlanNode(idAllocator, initialSource, optimizedSource, types, changedPlanNodes)); + } + return optimizedPlanNode.replaceChildren(sources.build()); + } + + private Map getUnchangedSubPlans( + PlanNode adaptivePlan, Set changedPlanIds, Map exchangeSourceIdToSubPlan) + { + Set changedPlanIdsWithDownstream = new HashSet<>(); + for (PlanNodeId changedId : changedPlanIds) { + changedPlanIdsWithDownstream.addAll(getDownstreamPlanNodeIds(adaptivePlan, changedId)); + } + + return exchangeSourceIdToSubPlan.entrySet().stream() + .filter(entry -> !changedPlanIdsWithDownstream.contains(entry.getKey().exchangeId()) + && !changedPlanIdsWithDownstream.contains(entry.getKey().sourceId())) + .collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue)); + } + + private Set getDownstreamPlanNodeIds(PlanNode root, PlanNodeId id) + { + if (root.getId().equals(id)) { + return ImmutableSet.of(id); + } + Set upstreamNodes = new HashSet<>(); + root.getSources().stream() + .map(source -> getDownstreamPlanNodeIds(source, id)) + .forEach(upstreamNodes::addAll); + if (!upstreamNodes.isEmpty()) { + upstreamNodes.add(root.getId()); + } + return upstreamNodes; + } + + private PlanNode getCurrentPlan(PlanNode node) + { + return rewriteWith(new CurrentPlanRewriter(), node); + } + + private PlanNode getInitialPlan(PlanNode node) + { + return rewriteWith(new InitialPlanRewriter(), node); + } + + private SymbolAllocator createSymbolAllocator(List subPlans) + { + Map usedSymbols = new HashMap<>(); + subPlans.stream() + .map(SubPlan::getFragment) + .map(PlanFragment::getSymbols) + .forEach(usedSymbols::putAll); + return new SymbolAllocator(usedSymbols); + } + + private int getMaxPlanFragmentId(List subPlans) + { + return subPlans.stream() + .map(SubPlan::getFragment) + .map(PlanFragment::getId) + .mapToInt(fragmentId -> Integer.parseInt(fragmentId.toString())) + .max() + .orElseThrow(); + } + + private int getMaxPlanId(PlanNode node) + { + return traverse(node) + .map(PlanNode::getId) + .mapToInt(planNodeId -> Integer.parseInt(planNodeId.toString())) + .max() + .orElseThrow(); + } + + private Stream traverse(PlanNode node) + { + Iterable iterable = Traverser.forTree(PlanNode::getSources).depthFirstPreOrder(node); + return StreamSupport.stream(iterable.spliterator(), false); + } + + private Stream traverse(SubPlan subPlan) + { + Iterable iterable = Traverser.forTree(SubPlan::getChildren).depthFirstPreOrder(subPlan); + return StreamSupport.stream(iterable.spliterator(), false); + } + + private static class ReplaceUnchangedFragmentsWithRemoteSourcesRewriter + extends SimplePlanRewriter> + { + private final RuntimeInfoProvider runtimeInfoProvider; + private final TypeProvider types; + + private ReplaceUnchangedFragmentsWithRemoteSourcesRewriter(RuntimeInfoProvider runtimeInfoProvider, TypeProvider types) + { + this.runtimeInfoProvider = requireNonNull(runtimeInfoProvider, "runtimeInfoProvider is null"); + this.types = requireNonNull(types, "types is null"); + } + + @Override + public PlanNode visitAdaptivePlanNode(AdaptivePlanNode node, RewriteContext> context) + { + // It is possible that the initial plan also contains remote source nodes, therefore we need to + // rewrite them 
as well. + PlanNode initialPlan = context.rewrite(node.getInitialPlan(), context.get()); + PlanNode currentPlan = context.rewrite(node.getCurrentPlan(), context.get()); + return new AdaptivePlanNode(node.getId(), initialPlan, getFilteredSymbols(initialPlan, types), currentPlan); + } + + @Override + public PlanNode visitRemoteSource(RemoteSourceNode node, RewriteContext> context) + { + // We won't run optimizer rules on sub plans which are either finished or their stats are almost accurate + // based are estimated by progress. + // TODO: We need add an ability to re-plan fragment whose stats are estimated by progress. + if (node.getSourceFragmentIds().stream() + .anyMatch(planFragmentId -> runtimeInfoProvider.getRuntimeOutputStats(planFragmentId).isAccurate())) { + return node; + } + + List sourceSubPlans = context.get().stream() + .filter(subPlan -> node.getSourceFragmentIds().contains(subPlan.getFragment().getId())) + .collect(toImmutableList()); + + ImmutableList.Builder sourceNodesBuilder = ImmutableList.builder(); + for (SubPlan sourceSubPlan : sourceSubPlans) { + PlanNode sourceNode = context.rewrite(sourceSubPlan.getFragment().getRoot(), sourceSubPlan.getChildren()); + sourceNodesBuilder.add(sourceNode); + } + + List sourceNodes = sourceNodesBuilder.build(); + List> inputs = sourceNodes.stream().map(PlanNode::getOutputSymbols).collect(toImmutableList()); + PartitioningScheme partitioningScheme = node.getSourceFragmentIds().stream() + .map(runtimeInfoProvider::getPlanFragment) + .map(PlanFragment::getOutputPartitioningScheme) + .findFirst() + .orElseThrow(); + + return new ExchangeNode( + node.getId(), + node.getExchangeType(), + REMOTE, + partitioningScheme, + sourceNodes, + inputs, + node.getOrderingScheme()); + } + } + + private static class ExchangeSourceIdToSubPlanCollector + extends SimplePlanVisitor> + { + private final Map exchangeSourceIdToSubPlan = new HashMap<>(); + + @Override + public Void visitExchange(ExchangeNode node, List context) + { + // Process the source nodes first + visitPlan(node, context); + + // No need to process the exchange node if it is not a remote exchange + if (node.getScope() != REMOTE) { + return null; + } + + // Find the sub plans for this exchange node + List sourceIds = node.getSources().stream().map(PlanNode::getId).collect(toImmutableList()); + List sourceSubPlans = context.stream() + .filter(subPlan -> sourceIds.contains(subPlan.getFragment().getRoot().getId())) + .collect(toImmutableList()); + verify( + sourceSubPlans.size() == sourceIds.size(), + "Source subPlans not found for exchange node"); + + for (SubPlan sourceSubPlan : sourceSubPlans) { + PlanNodeId sourceId = sourceSubPlan.getFragment().getRoot().getId(); + exchangeSourceIdToSubPlan.put(new ExchangeSourceId(node.getId(), sourceId), sourceSubPlan); + } + return null; + } + + @Override + public Void visitRemoteSource(RemoteSourceNode node, List context) + { + List sourceSubPlans = context.stream() + .filter(subPlan -> node.getSourceFragmentIds().contains(subPlan.getFragment().getId())) + .collect(toImmutableList()); + + for (SubPlan sourceSubPlan : sourceSubPlans) { + PlanNodeId sourceId = sourceSubPlan.getFragment().getRoot().getId(); + exchangeSourceIdToSubPlan.put(new ExchangeSourceId(node.getId(), sourceId), sourceSubPlan); + } + return null; + } + + public Map getExchangeSourceIdToSubPlan() + { + return ImmutableMap.copyOf(exchangeSourceIdToSubPlan); + } + } + + private static class CurrentPlanRewriter + extends SimplePlanRewriter> + { + @Override + public PlanNode 
visitAdaptivePlanNode(AdaptivePlanNode node, RewriteContext> context) + { + verify( + !containsAdaptivePlanNode(node.getCurrentPlan()), + "Adaptive plan node cannot have a nested adaptive plan node"); + return node.getCurrentPlan(); + } + } + + private static class InitialPlanRewriter + extends SimplePlanRewriter> + { + @Override + public PlanNode visitAdaptivePlanNode(AdaptivePlanNode node, RewriteContext> context) + { + verify( + !containsAdaptivePlanNode(node.getInitialPlan()), + "Adaptive plan node cannot have a nested adaptive plan node"); + return node.getInitialPlan(); + } + } + + private static boolean containsAdaptivePlanNode(PlanNode node) + { + return PlanNodeSearcher.searchFrom(node) + .whereIsInstanceOfAny(AdaptivePlanNode.class) + .matches(); + } + + private static Map getFilteredSymbols(PlanNode node, TypeProvider types) + { + Set dependencies = SymbolsExtractor.extractOutputSymbols(node); + return Maps.filterKeys(types.allTypes(), in(dependencies)); + } + + public record ExchangeSourceId(PlanNodeId exchangeId, PlanNodeId sourceId) + { + public ExchangeSourceId + { + requireNonNull(exchangeId, "exchangeId is null"); + requireNonNull(sourceId, "sourceId is null"); + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java index 005e6374f906..f23fbea131b8 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java @@ -188,6 +188,7 @@ import io.trino.sql.gen.OrderingCompiler; import io.trino.sql.gen.PageFunctionCompiler; import io.trino.sql.planner.optimizations.IndexJoinOptimizer; +import io.trino.sql.planner.plan.AdaptivePlanNode; import io.trino.sql.planner.plan.AggregationNode; import io.trino.sql.planner.plan.AggregationNode.Aggregation; import io.trino.sql.planner.plan.AggregationNode.Step; @@ -3723,6 +3724,12 @@ else if (context.getDriverInstanceCount().isPresent()) { return new PhysicalOperation(new LocalExchangeSourceOperatorFactory(context.getNextOperatorId(), node.getId(), localExchange), makeLayout(node), context); } + @Override + public PhysicalOperation visitAdaptivePlanNode(AdaptivePlanNode node, LocalExecutionPlanContext context) + { + return node.getCurrentPlan().accept(this, context); + } + @Override protected PhysicalOperation visitPlan(PlanNode node, LocalExecutionPlanContext context) { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java index c80b0aa5b4fc..6afe9dc37a84 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java @@ -26,6 +26,7 @@ import io.trino.cost.CachingTableStatsProvider; import io.trino.cost.CostCalculator; import io.trino.cost.CostProvider; +import io.trino.cost.RuntimeInfoProvider; import io.trino.cost.StatsAndCosts; import io.trino.cost.StatsCalculator; import io.trino.cost.StatsProvider; @@ -185,6 +186,7 @@ public enum Stage private final CostCalculator costCalculator; private final WarningCollector warningCollector; private final PlanOptimizersStatsCollector planOptimizersStatsCollector; + private final CachingTableStatsProvider tableStatsProvider; public LogicalPlanner( Session session, @@ -195,9 +197,10 @@ public LogicalPlanner( StatsCalculator statsCalculator, CostCalculator costCalculator, 
WarningCollector warningCollector, - PlanOptimizersStatsCollector planOptimizersStatsCollector) + PlanOptimizersStatsCollector planOptimizersStatsCollector, + CachingTableStatsProvider tableStatsProvider) { - this(session, planOptimizers, DISTRIBUTED_PLAN_SANITY_CHECKER, idAllocator, plannerContext, typeAnalyzer, statsCalculator, costCalculator, warningCollector, planOptimizersStatsCollector); + this(session, planOptimizers, DISTRIBUTED_PLAN_SANITY_CHECKER, idAllocator, plannerContext, typeAnalyzer, statsCalculator, costCalculator, warningCollector, planOptimizersStatsCollector, tableStatsProvider); } public LogicalPlanner( @@ -210,7 +213,8 @@ public LogicalPlanner( StatsCalculator statsCalculator, CostCalculator costCalculator, WarningCollector warningCollector, - PlanOptimizersStatsCollector planOptimizersStatsCollector) + PlanOptimizersStatsCollector planOptimizersStatsCollector, + CachingTableStatsProvider tableStatsProvider) { this.session = requireNonNull(session, "session is null"); this.planOptimizers = requireNonNull(planOptimizers, "planOptimizers is null"); @@ -224,6 +228,7 @@ public LogicalPlanner( this.costCalculator = requireNonNull(costCalculator, "costCalculator is null"); this.warningCollector = requireNonNull(warningCollector, "warningCollector is null"); this.planOptimizersStatsCollector = requireNonNull(planOptimizersStatsCollector, "planOptimizersStatsCollector is null"); + this.tableStatsProvider = requireNonNull(tableStatsProvider, "tableStatsProvider is null"); } public Plan plan(Analysis analysis) @@ -259,8 +264,6 @@ public Plan plan(Analysis analysis, Stage stage, boolean collectPlanStatistics) planSanityChecker.validateIntermediatePlan(root, session, plannerContext, typeAnalyzer, symbolAllocator.getTypes(), warningCollector); } - CachingTableStatsProvider tableStatsProvider = new CachingTableStatsProvider(metadata, session); - if (stage.ordinal() >= OPTIMIZED.ordinal()) { try (var ignored = scopedSpan(plannerContext.getTracer(), "optimizer")) { for (PlanOptimizer optimizer : planOptimizers) { @@ -303,7 +306,7 @@ private PlanNode runOptimizer(PlanNode root, TableStatsProvider tableStatsProvid { PlanNode result; try (var ignored = optimizerSpan(optimizer)) { - result = optimizer.optimize(root, new PlanOptimizer.Context(session, symbolAllocator.getTypes(), symbolAllocator, idAllocator, warningCollector, planOptimizersStatsCollector, tableStatsProvider)); + result = optimizer.optimize(root, new PlanOptimizer.Context(session, symbolAllocator.getTypes(), symbolAllocator, idAllocator, warningCollector, planOptimizersStatsCollector, tableStatsProvider, RuntimeInfoProvider.noImplementation())); } if (result == null) { throw new NullPointerException(optimizer.getClass().getName() + " returned a null plan"); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/PlanFragmenter.java b/core/trino-main/src/main/java/io/trino/sql/planner/PlanFragmenter.java index 48ccb73292b3..0849bf65995a 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/PlanFragmenter.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/PlanFragmenter.java @@ -14,6 +14,7 @@ package io.trino.sql.planner; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Maps; import com.google.inject.Inject; @@ -35,6 +36,8 @@ import io.trino.spi.TrinoWarning; import io.trino.spi.connector.ConnectorPartitioningHandle; import io.trino.spi.type.Type; +import 
io.trino.sql.planner.optimizations.PlanNodeSearcher; +import io.trino.sql.planner.plan.AdaptivePlanNode; import io.trino.sql.planner.plan.ExchangeNode; import io.trino.sql.planner.plan.ExplainAnalyzeNode; import io.trino.sql.planner.plan.MergeWriterNode; @@ -64,6 +67,7 @@ import java.util.Map; import java.util.Optional; import java.util.Set; +import java.util.stream.Stream; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; @@ -74,6 +78,7 @@ import static io.trino.SystemSessionProperties.isForceSingleNodeOutput; import static io.trino.spi.StandardErrorCode.QUERY_HAS_TOO_MANY_STAGES; import static io.trino.spi.connector.StandardWarningCode.TOO_MANY_STAGES; +import static io.trino.sql.planner.AdaptivePlanner.ExchangeSourceId; import static io.trino.sql.planner.SchedulingOrderVisitor.scheduleOrder; import static io.trino.sql.planner.SystemPartitioningHandle.COORDINATOR_DISTRIBUTION; import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION; @@ -119,15 +124,42 @@ public PlanFragmenter( } public SubPlan createSubPlans(Session session, Plan plan, boolean forceSingleNode, WarningCollector warningCollector) + { + return createSubPlans( + session, + plan, + forceSingleNode, + warningCollector, + new PlanFragmentIdAllocator(0), + new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), plan.getRoot().getOutputSymbols()), + ImmutableMap.of()); + } + + public SubPlan createSubPlans( + Session session, + Plan plan, + boolean forceSingleNode, + WarningCollector warningCollector, + PlanFragmentIdAllocator idAllocator, + PartitioningScheme outputPartitioningScheme, + Map unchangedSubPlans) { List activeCatalogs = transactionManager.getActiveCatalogs(session.getTransactionId().orElseThrow()).stream() .map(CatalogInfo::getCatalogHandle) .flatMap(catalogHandle -> catalogManager.getCatalogProperties(catalogHandle).stream()) .collect(toImmutableList()); List languageScalarFunctions = languageFunctionManager.serializeFunctionsForWorkers(session); - Fragmenter fragmenter = new Fragmenter(session, metadata, functionManager, plan.getTypes(), plan.getStatsAndCosts(), activeCatalogs, languageScalarFunctions); - - FragmentProperties properties = new FragmentProperties(new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), plan.getRoot().getOutputSymbols())); + Fragmenter fragmenter = new Fragmenter( + session, + metadata, + functionManager, + plan.getTypes(), + plan.getStatsAndCosts(), + activeCatalogs, + languageScalarFunctions, + idAllocator, + unchangedSubPlans); + FragmentProperties properties = new FragmentProperties(outputPartitioningScheme); if (forceSingleNode || isForceSingleNodeOutput(session)) { properties = properties.setSingleNodeDistribution(); } @@ -214,8 +246,6 @@ private SubPlan reassignPartitioningHandleIfNecessaryHelper(Session session, Sub private static class Fragmenter extends SimplePlanRewriter { - private static final int ROOT_FRAGMENT_ID = 0; - private final Session session; private final Metadata metadata; private final FunctionManager functionManager; @@ -223,7 +253,9 @@ private static class Fragmenter private final StatsAndCosts statsAndCosts; private final List activeCatalogs; private final List languageFunctions; - private final PlanFragmentIdAllocator idAllocator = new PlanFragmentIdAllocator(ROOT_FRAGMENT_ID + 1); + private final PlanFragmentIdAllocator idAllocator; + private final Map unchangedSubPlans; + private final 
PlanFragmentId rootFragmentID; public Fragmenter( Session session, @@ -232,7 +264,9 @@ public Fragmenter( TypeProvider types, StatsAndCosts statsAndCosts, List activeCatalogs, - List languageFunctions) + List languageFunctions, + PlanFragmentIdAllocator idAllocator, + Map unchangedSubPlans) { this.session = requireNonNull(session, "session is null"); this.metadata = requireNonNull(metadata, "metadata is null"); @@ -241,11 +275,14 @@ public Fragmenter( this.statsAndCosts = requireNonNull(statsAndCosts, "statsAndCosts is null"); this.activeCatalogs = requireNonNull(activeCatalogs, "activeCatalogs is null"); this.languageFunctions = requireNonNull(languageFunctions, "languageFunctions is null"); + this.idAllocator = requireNonNull(idAllocator, "idAllocator is null"); + this.unchangedSubPlans = ImmutableMap.copyOf(requireNonNull(unchangedSubPlans, "unchangedSubPlans is null")); + this.rootFragmentID = idAllocator.getNextId(); } public SubPlan buildRootFragment(PlanNode root, FragmentProperties properties) { - return buildFragment(root, properties, new PlanFragmentId(String.valueOf(ROOT_FRAGMENT_ID))); + return buildFragment(root, properties, rootFragmentID); } private SubPlan buildFragment(PlanNode root, FragmentProperties properties, PlanFragmentId fragmentId) @@ -406,6 +443,67 @@ public PlanNode visitTableFunctionProcessor(TableFunctionProcessorNode node, Rew return context.defaultRewrite(node, context.get()); } + @Override + public PlanNode visitAdaptivePlanNode(AdaptivePlanNode node, RewriteContext context) + { + // This is needed to make the initial plan more concise by replacing the exchange nodes with + // remote source nodes for stages that are not being changed by the adaptive planner in the + // case of FTE. This is a cosmetic change and does not affect the execution of the plan. This is + // useful for easier debugging and understanding of the plan. + // Example: + // - Before: + // - AdaptivePlan + // - InitialPlan + // - SomeInitialPlanNode + // - Exchange + // - TableScan + // - CurrentPlan + // - NewPlanNode + // - RemoteSourceNode(1) + // - After: + // - AdaptivePlan + // - InitialPlan + // - SomeInitialPlanNode + // - RemoteSourceNode(1) + // - CurrentPlan + // - NewPlanNode + // - RemoteSourceNode(1) + // As shown in the example, the exchange node is replaced with a remote source node in the initial plan. 
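+            // Editorial note (illustrative, not part of this change): the rewrite below relies on two inputs — the
+            // RemoteSourceNodes already present in the current plan and its child fragments, and the ExchangeSourceIds
+            // of the fragments the adaptive planner left unchanged. For every remote ExchangeNode in the initial plan
+            // whose id matches an unchanged exchange, ExchangeNodeToRemoteSourceRewriter substitutes the
+            // RemoteSourceNode with the same id, so the printed initial plan refers to the same fragments as the
+            // current plan.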
+ + AdaptivePlanNode adaptivePlan = (AdaptivePlanNode) context.defaultRewrite(node, context.get()); + List remoteSourceNodes = getAllRemoteSourceNodes(adaptivePlan.getCurrentPlan(), context.get().getChildren()); + ExchangeNodeToRemoteSourceRewriter rewriter = new ExchangeNodeToRemoteSourceRewriter(remoteSourceNodes, unchangedSubPlans.keySet()); + PlanNode newInitialPlan = SimplePlanRewriter.rewriteWith(rewriter, adaptivePlan.getInitialPlan()); + Set dependencies = SymbolsExtractor.extractOutputSymbols(newInitialPlan); + Map filteredSymbols = Maps.filterKeys(types.allTypes(), in(dependencies)); + return new AdaptivePlanNode(adaptivePlan.getId(), newInitialPlan, filteredSymbols, adaptivePlan.getCurrentPlan()); + } + + @Override + public PlanNode visitRemoteSource(RemoteSourceNode node, RewriteContext context) + { + List completedChildren = unchangedSubPlans.values().stream() + .filter(subPlan -> node.getSourceFragmentIds().contains(subPlan.getFragment().getId())) + .collect(toImmutableList()); + checkState(completedChildren.size() == node.getSourceFragmentIds().size(), "completedSubPlans should contain all remote source children"); + + if (node.getExchangeType() == ExchangeNode.Type.GATHER) { + context.get().setSingleNodeDistribution(); + } + else if (node.getExchangeType() == ExchangeNode.Type.REPARTITION) { + for (SubPlan child : completedChildren) { + PartitioningScheme partitioningScheme = child.getFragment().getOutputPartitioningScheme(); + context.get().setDistribution( + partitioningScheme.getPartitioning().getHandle(), + partitioningScheme.getPartitionCount(), + metadata, + session); + } + } + context.get().addChildren(completedChildren); + return node; + } + @Override public PlanNode visitExchange(ExchangeNode exchange, RewriteContext context) { @@ -431,7 +529,11 @@ else if (exchange.getType() == ExchangeNode.Type.REPARTITION) { for (int sourceIndex = 0; sourceIndex < exchange.getSources().size(); sourceIndex++) { FragmentProperties childProperties = new FragmentProperties(partitioningScheme.translateOutputLayout(exchange.getInputs().get(sourceIndex))); childrenProperties.add(childProperties); - childrenBuilder.add(buildSubPlan(exchange.getSources().get(sourceIndex), childProperties, context)); + childrenBuilder.add(buildSubPlan( + exchange.getSources().get(sourceIndex), + new ExchangeSourceId(exchange.getId(), exchange.getSources().get(sourceIndex).getId()), + childProperties, + context)); } List children = childrenBuilder.build(); @@ -451,13 +553,29 @@ else if (exchange.getType() == ExchangeNode.Type.REPARTITION) { isWorkerCoordinatorBoundary(context.get(), childrenProperties.build()) ? 
getRetryPolicy(session) : RetryPolicy.NONE); } - private SubPlan buildSubPlan(PlanNode node, FragmentProperties properties, RewriteContext context) + private SubPlan buildSubPlan(PlanNode node, ExchangeSourceId exchangeSourceId, FragmentProperties properties, RewriteContext context) { + SubPlan subPlan = unchangedSubPlans.get(exchangeSourceId); + if (subPlan != null) { + return subPlan; + } PlanFragmentId planFragmentId = idAllocator.getNextId(); PlanNode child = context.rewrite(node, properties); return buildFragment(child, properties, planFragmentId); } + private List getAllRemoteSourceNodes(PlanNode node, List children) + { + return Stream.concat( + children.stream() + .map(SubPlan::getFragment) + .flatMap(fragment -> fragment.getRemoteSourceNodes().stream()), + PlanNodeSearcher.searchFrom(node) + .whereIsInstanceOfAny(RemoteSourceNode.class) + .findAll().stream()) + .collect(toImmutableList()); + } + private static boolean isWorkerCoordinatorBoundary(FragmentProperties fragmentProperties, List childFragmentsProperties) { if (!fragmentProperties.getPartitioningHandle().isCoordinatorOnly()) { @@ -711,4 +829,34 @@ public PlanNode visitTableScan(TableScanNode node, RewriteContext context) node.getUseConnectorNodePartitioning()); } } + + private static final class ExchangeNodeToRemoteSourceRewriter + extends SimplePlanRewriter + { + private final List remoteSourceNodes; + private final Set unchangedRemoteExchanges; + + public ExchangeNodeToRemoteSourceRewriter(List remoteSourceNodes, Set unchangedRemoteExchanges) + { + this.remoteSourceNodes = requireNonNull(remoteSourceNodes, "remoteSourceNodes is null"); + this.unchangedRemoteExchanges = requireNonNull(unchangedRemoteExchanges, "unchangedRemoteExchanges is null"); + } + + @Override + public PlanNode visitExchange(ExchangeNode node, RewriteContext context) + { + if (node.getScope() != REMOTE || !isUnchangedFragment(node.getId())) { + return context.defaultRewrite(node, context.get()); + } + return remoteSourceNodes.stream() + .filter(remoteSource -> remoteSource.getId().equals(node.getId())) + .findFirst() + .orElse(node); + } + + private boolean isUnchangedFragment(PlanNodeId exchangeID) + { + return unchangedRemoteExchanges.stream().anyMatch(fragment -> fragment.exchangeId().equals(exchangeID)); + } + } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java b/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java index c91798a8d396..e4ec21272be9 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java @@ -238,6 +238,8 @@ import io.trino.sql.planner.iterative.rule.UnwrapSingleColumnRowInApply; import io.trino.sql.planner.iterative.rule.UnwrapYearInComparison; import io.trino.sql.planner.iterative.rule.UseNonPartitionedJoinLookupSource; +import io.trino.sql.planner.optimizations.AdaptivePartitioning; +import io.trino.sql.planner.optimizations.AdaptivePlanOptimizer; import io.trino.sql.planner.optimizations.AddExchanges; import io.trino.sql.planner.optimizations.AddLocalExchanges; import io.trino.sql.planner.optimizations.BeginTableWrite; @@ -266,6 +268,7 @@ public class PlanOptimizers implements PlanOptimizersFactory { private final List optimizers; + private final List adaptivePlanOptimizers; private final RuleStatsRecorder ruleStats; private final OptimizerStatsRecorder optimizerStats = new OptimizerStatsRecorder(); @@ -1007,6 +1010,7 @@ public PlanOptimizers( // TODO: figure out how to 
improve the set flattening optimizer so that it can run at any point this.optimizers = builder.build(); + this.adaptivePlanOptimizers = ImmutableList.of(new AdaptivePartitioning()); } @VisibleForTesting @@ -1067,11 +1071,17 @@ public static Set> columnPruningRules(Metadata metadata) } @Override - public List get() + public List getPlanOptimizers() { return optimizers; } + @Override + public List getAdaptivePlanOptimizers() + { + return adaptivePlanOptimizers; + } + @Override public Map, OptimizerStats> getOptimizerStats() { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizersFactory.java b/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizersFactory.java index b25870547499..8aa845326e9d 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizersFactory.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizersFactory.java @@ -15,6 +15,7 @@ import com.google.common.collect.ImmutableMap; import io.trino.sql.planner.iterative.RuleStats; +import io.trino.sql.planner.optimizations.AdaptivePlanOptimizer; import io.trino.sql.planner.optimizations.OptimizerStats; import io.trino.sql.planner.optimizations.PlanOptimizer; @@ -23,7 +24,9 @@ public interface PlanOptimizersFactory { - List get(); + List getPlanOptimizers(); + + List getAdaptivePlanOptimizers(); default Map, OptimizerStats> getOptimizerStats() { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/RuntimeAdaptivePartitioningRewriter.java b/core/trino-main/src/main/java/io/trino/sql/planner/RuntimeAdaptivePartitioningRewriter.java deleted file mode 100644 index 7a88e1b76e48..000000000000 --- a/core/trino-main/src/main/java/io/trino/sql/planner/RuntimeAdaptivePartitioningRewriter.java +++ /dev/null @@ -1,218 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package io.trino.sql.planner; - -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; -import com.google.common.graph.Traverser; -import io.trino.sql.planner.plan.PlanFragmentId; -import io.trino.sql.planner.plan.PlanNode; -import io.trino.sql.planner.plan.RemoteSourceNode; -import io.trino.sql.planner.plan.SimplePlanRewriter; - -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.Set; -import java.util.stream.Stream; -import java.util.stream.StreamSupport; - -import static com.google.common.collect.ImmutableList.toImmutableList; -import static com.google.common.collect.Iterators.getOnlyElement; -import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_BROADCAST_DISTRIBUTION; -import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION; -import static io.trino.sql.planner.SystemPartitioningHandle.SCALED_WRITER_HASH_DISTRIBUTION; -import static io.trino.sql.planner.plan.ExchangeNode.Type.REPLICATE; -import static io.trino.sql.planner.plan.SimplePlanRewriter.rewriteWith; -import static java.util.Objects.requireNonNull; - -public final class RuntimeAdaptivePartitioningRewriter -{ - private RuntimeAdaptivePartitioningRewriter() {} - - public static SubPlan overridePartitionCountRecursively( - SubPlan subPlan, - int oldPartitionCount, - int newPartitionCount, - PlanFragmentIdAllocator planFragmentIdAllocator, - PlanNodeIdAllocator planNodeIdAllocator, - Set startedFragments) - { - PlanFragment fragment = subPlan.getFragment(); - if (startedFragments.contains(fragment.getId())) { - // already started, nothing to change for subPlan and its descendants - return subPlan; - } - - PartitioningScheme outputPartitioningScheme = fragment.getOutputPartitioningScheme(); - if (outputPartitioningScheme.getPartitioning().getHandle().equals(FIXED_BROADCAST_DISTRIBUTION)) { - // the result of the subtree will be broadcast, then no need to change partition count for the subtree - // as the planner will only broadcast fragment output if it sees input data is small or filter ratio is high - return subPlan; - } - if (producesHashPartitionedOutput(fragment)) { - fragment = fragment.withOutputPartitioningScheme(outputPartitioningScheme.withPartitionCount(Optional.of(newPartitionCount))); - } - - if (consumesHashPartitionedInput(fragment)) { - fragment = fragment.withPartitionCount(Optional.of(newPartitionCount)); - } - else { - // no input partitioning, then no need to insert extra exchanges to sources - return new SubPlan( - fragment, - subPlan.getChildren().stream() - .map(child -> overridePartitionCountRecursively( - child, - oldPartitionCount, - newPartitionCount, - planFragmentIdAllocator, - planNodeIdAllocator, - startedFragments)) - .collect(toImmutableList())); - } - - // insert extra exchanges to sources - ImmutableList.Builder newSources = ImmutableList.builder(); - ImmutableMap.Builder runtimeAdaptivePlanFragmentIdMapping = ImmutableMap.builder(); - for (SubPlan source : subPlan.getChildren()) { - PlanFragment sourceFragment = source.getFragment(); - RemoteSourceNode sourceRemoteSourceNode = getOnlyElement(fragment.getRemoteSourceNodes().stream() - .filter(remoteSourceNode -> remoteSourceNode.getSourceFragmentIds().contains(sourceFragment.getId())) - .iterator()); - requireNonNull(sourceRemoteSourceNode, "sourceRemoteSourceNode is null"); - if (sourceRemoteSourceNode.getExchangeType() == REPLICATE) { - // since exchange type is REPLICATE, also no need to change partition count for the 
subtree as the - // planner will only broadcast fragment output if it sees input data is small or filter ratio is high - newSources.add(source); - continue; - } - if (!startedFragments.contains(sourceFragment.getId())) { - // source not started yet, then no need to insert extra exchanges to sources - newSources.add(overridePartitionCountRecursively( - source, - oldPartitionCount, - newPartitionCount, - planFragmentIdAllocator, - planNodeIdAllocator, - startedFragments)); - runtimeAdaptivePlanFragmentIdMapping.put(sourceFragment.getId(), sourceFragment.getId()); - continue; - } - RemoteSourceNode runtimeAdaptiveRemoteSourceNode = new RemoteSourceNode( - planNodeIdAllocator.getNextId(), - sourceFragment.getId(), - sourceFragment.getOutputPartitioningScheme().getOutputLayout(), - sourceRemoteSourceNode.getOrderingScheme(), - sourceRemoteSourceNode.getExchangeType(), - sourceRemoteSourceNode.getRetryPolicy()); - PlanFragment runtimeAdaptivePlanFragment = new PlanFragment( - planFragmentIdAllocator.getNextId(), - runtimeAdaptiveRemoteSourceNode, - sourceFragment.getSymbols(), - FIXED_HASH_DISTRIBUTION, - Optional.of(oldPartitionCount), - ImmutableList.of(), // partitioned sources will be empty as the fragment will only read from `runtimeAdaptiveRemoteSourceNode` - sourceFragment.getOutputPartitioningScheme().withPartitionCount(Optional.of(newPartitionCount)), - sourceFragment.getStatsAndCosts(), - sourceFragment.getActiveCatalogs(), - sourceFragment.getLanguageFunctions(), - sourceFragment.getJsonRepresentation()); - SubPlan newSource = new SubPlan( - runtimeAdaptivePlanFragment, - ImmutableList.of(overridePartitionCountRecursively( - source, - oldPartitionCount, - newPartitionCount, - planFragmentIdAllocator, - planNodeIdAllocator, - startedFragments))); - newSources.add(newSource); - runtimeAdaptivePlanFragmentIdMapping.put(sourceFragment.getId(), runtimeAdaptivePlanFragment.getId()); - } - - return new SubPlan( - fragment.withRoot(rewriteWith( - new UpdateRemoteSourceFragmentIdsRewriter(runtimeAdaptivePlanFragmentIdMapping.buildOrThrow()), - fragment.getRoot())), - newSources.build()); - } - - public static boolean consumesHashPartitionedInput(PlanFragment fragment) - { - return isPartitioned(fragment.getPartitioning()); - } - - public static boolean producesHashPartitionedOutput(PlanFragment fragment) - { - return isPartitioned(fragment.getOutputPartitioningScheme().getPartitioning().getHandle()); - } - - public static int getMaxPlanFragmentId(List subPlans) - { - return subPlans.stream() - .map(SubPlan::getFragment) - .map(PlanFragment::getId) - .mapToInt(fragmentId -> Integer.parseInt(fragmentId.toString())) - .max() - .orElseThrow(); - } - - public static int getMaxPlanId(List subPlans) - { - return subPlans.stream() - .map(SubPlan::getFragment) - .map(PlanFragment::getRoot) - .mapToInt(root -> traverse(root) - .map(PlanNode::getId) - .mapToInt(planNodeId -> Integer.parseInt(planNodeId.toString())) - .max() - .orElseThrow()) - .max() - .orElseThrow(); - } - - private static boolean isPartitioned(PartitioningHandle partitioningHandle) - { - return partitioningHandle.equals(FIXED_HASH_DISTRIBUTION) || partitioningHandle.equals(SCALED_WRITER_HASH_DISTRIBUTION); - } - - private static Stream traverse(PlanNode node) - { - Iterable iterable = Traverser.forTree(PlanNode::getSources).depthFirstPreOrder(node); - return StreamSupport.stream(iterable.spliterator(), false); - } - - private static class UpdateRemoteSourceFragmentIdsRewriter - extends SimplePlanRewriter - { - private final Map 
runtimeAdaptivePlanFragmentIdMapping; - - public UpdateRemoteSourceFragmentIdsRewriter(Map runtimeAdaptivePlanFragmentIdMapping) - { - this.runtimeAdaptivePlanFragmentIdMapping = requireNonNull(runtimeAdaptivePlanFragmentIdMapping, "runtimeAdaptivePlanFragmentIdMapping is null"); - } - - @Override - public PlanNode visitRemoteSource(RemoteSourceNode node, RewriteContext context) - { - if (node.getExchangeType() == REPLICATE) { - return node; - } - return node.withSourceFragmentIds(node.getSourceFragmentIds().stream() - .map(runtimeAdaptivePlanFragmentIdMapping::get) - .collect(toImmutableList())); - } - } -} diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/SplitSourceFactory.java b/core/trino-main/src/main/java/io/trino/sql/planner/SplitSourceFactory.java index 8c69a2cd6268..6c0681aca9a9 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/SplitSourceFactory.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/SplitSourceFactory.java @@ -30,6 +30,7 @@ import io.trino.split.SplitSource; import io.trino.sql.DynamicFilters; import io.trino.sql.PlannerContext; +import io.trino.sql.planner.plan.AdaptivePlanNode; import io.trino.sql.planner.plan.AggregationNode; import io.trino.sql.planner.plan.AssignUniqueId; import io.trino.sql.planner.plan.DistinctLimitNode; @@ -470,6 +471,12 @@ public Map visitExchange(ExchangeNode node, Void contex return processSources(node.getSources(), context); } + @Override + public Map visitAdaptivePlanNode(AdaptivePlanNode node, Void context) + { + return processSources(node.getSources(), context); + } + private Map processSources(List sources, Void context) { ImmutableMap.Builder result = ImmutableMap.builder(); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/IterativeOptimizer.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/IterativeOptimizer.java index 3c446fe4197f..4f7a7cf5388d 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/IterativeOptimizer.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/IterativeOptimizer.java @@ -14,6 +14,7 @@ package io.trino.sql.planner.iterative; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; import io.airlift.log.Logger; import io.airlift.units.Duration; import io.trino.Session; @@ -22,6 +23,7 @@ import io.trino.cost.CachingStatsProvider; import io.trino.cost.CostCalculator; import io.trino.cost.CostProvider; +import io.trino.cost.RuntimeInfoProvider; import io.trino.cost.StatsAndCosts; import io.trino.cost.StatsCalculator; import io.trino.cost.StatsProvider; @@ -37,10 +39,13 @@ import io.trino.sql.planner.PlanNodeIdAllocator; import io.trino.sql.planner.RuleStatsRecorder; import io.trino.sql.planner.SymbolAllocator; +import io.trino.sql.planner.optimizations.AdaptivePlanOptimizer; import io.trino.sql.planner.optimizations.PlanOptimizer; import io.trino.sql.planner.plan.PlanNode; +import io.trino.sql.planner.plan.PlanNodeId; import io.trino.sql.planner.planprinter.PlanPrinter; +import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Optional; @@ -60,7 +65,7 @@ import static java.util.stream.Collectors.joining; public class IterativeOptimizer - implements PlanOptimizer + implements AdaptivePlanOptimizer { private static final Logger LOG = Logger.get(IterativeOptimizer.class); @@ -95,17 +100,17 @@ public IterativeOptimizer(PlannerContext plannerContext, RuleStatsRecorder stats } @Override - public PlanNode optimize(PlanNode 
plan, PlanOptimizer.Context context) + public Result optimizeAndMarkPlanChanges(PlanNode plan, PlanOptimizer.Context context) { // only disable new rules if we have legacy rules to fall back to if (useLegacyRules.test(context.session()) && !legacyRules.isEmpty()) { for (PlanOptimizer optimizer : legacyRules) { plan = optimizer.optimize(plan, context); } - - return plan; + return new Result(plan, ImmutableSet.of()); } + Set changedPlanNodeIds = new HashSet<>(); Memo memo = new Memo(context.idAllocator(), plan); Lookup lookup = Lookup.from(planNode -> Stream.of(memo.resolve(planNode))); @@ -119,10 +124,11 @@ public PlanNode optimize(PlanNode plan, PlanOptimizer.Context context) timeout.toMillis(), context.session(), context.warningCollector(), - context.tableStatsProvider()); - exploreGroup(memo.getRootGroup(), optimizerContext); + context.tableStatsProvider(), + context.runtimeInfoProvider()); + exploreGroup(memo.getRootGroup(), optimizerContext, changedPlanNodeIds); context.planOptimizersStatsCollector().add(optimizerContext.getIterativeOptimizerStatsCollector()); - return memo.extract(); + return new Result(memo.extract(), ImmutableSet.copyOf(changedPlanNodeIds)); } // Used for diagnostics. @@ -131,18 +137,18 @@ public Set> getRules() return rules; } - private boolean exploreGroup(int group, Context context) + private boolean exploreGroup(int group, Context context, Set changedPlanNodeIds) { // tracks whether this group or any children groups change as // this method executes - boolean progress = exploreNode(group, context); + boolean progress = exploreNode(group, context, changedPlanNodeIds); - while (exploreChildren(group, context)) { + while (exploreChildren(group, context, changedPlanNodeIds)) { progress = true; // if children changed, try current group again // in case we can match additional rules - if (!exploreNode(group, context)) { + if (!exploreNode(group, context, changedPlanNodeIds)) { // no additional matches, so bail out break; } @@ -151,7 +157,7 @@ private boolean exploreGroup(int group, Context context) return progress; } - private boolean exploreNode(int group, Context context) + private boolean exploreNode(int group, Context context, Set changedPlanNodeIds) { PlanNode node = context.memo.getNode(group); @@ -174,7 +180,9 @@ private boolean exploreNode(int group, Context context) invoked = true; Rule.Result result = transform(node, rule, context); timeEnd = nanoTime(); - + if (result.getTransformedPlan().isPresent()) { + changedPlanNodeIds.add(result.getTransformedPlan().get().getId()); + } if (result.getTransformedPlan().isPresent()) { node = context.memo.replace(group, result.getTransformedPlan().get(), rule.getClass().getName()); @@ -247,7 +255,7 @@ private Rule.Result transform(PlanNode node, Rule rule, Context context) return Rule.Result.empty(); } - private boolean exploreChildren(int group, Context context) + private boolean exploreChildren(int group, Context context, Set changedPlanNodeIds) { boolean progress = false; @@ -255,7 +263,7 @@ private boolean exploreChildren(int group, Context context) for (PlanNode child : expression.getSources()) { checkState(child instanceof GroupReference, "Expected child to be a group reference. 
Found: " + child.getClass().getName()); - if (exploreGroup(((GroupReference) child).getGroupId(), context)) { + if (exploreGroup(((GroupReference) child).getGroupId(), context, changedPlanNodeIds)) { progress = true; } } @@ -265,7 +273,7 @@ private boolean exploreChildren(int group, Context context) private Rule.Context ruleContext(Context context) { - StatsProvider statsProvider = new CachingStatsProvider(statsCalculator, Optional.of(context.memo), context.lookup, context.session, context.symbolAllocator.getTypes(), context.tableStatsProvider); + StatsProvider statsProvider = new CachingStatsProvider(statsCalculator, Optional.of(context.memo), context.lookup, context.session, context.symbolAllocator.getTypes(), context.tableStatsProvider, context.runtimeStatsProvider); CostProvider costProvider = new CachingCostProvider(costCalculator, statsProvider, Optional.of(context.memo), context.session, context.symbolAllocator.getTypes()); return new Rule.Context() @@ -331,6 +339,7 @@ private static class Context private final Session session; private final WarningCollector warningCollector; private final TableStatsProvider tableStatsProvider; + private final RuntimeInfoProvider runtimeStatsProvider; private final PlanOptimizersStatsCollector iterativeOptimizerStatsCollector; @@ -343,7 +352,8 @@ public Context( long timeoutInMilliseconds, Session session, WarningCollector warningCollector, - TableStatsProvider tableStatsProvider) + TableStatsProvider tableStatsProvider, + RuntimeInfoProvider runtimeStatsProvider) { checkArgument(timeoutInMilliseconds >= 0, "Timeout has to be a non-negative number [milliseconds]"); @@ -357,6 +367,7 @@ public Context( this.warningCollector = warningCollector; this.iterativeOptimizerStatsCollector = createPlanOptimizersStatsCollector(); this.tableStatsProvider = tableStatsProvider; + this.runtimeStatsProvider = runtimeStatsProvider; } public void checkTimeoutNotExhausted() diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AdaptivePartitioning.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AdaptivePartitioning.java new file mode 100644 index 000000000000..ef64cc0c52f5 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AdaptivePartitioning.java @@ -0,0 +1,216 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.sql.planner.optimizations; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import io.airlift.log.Logger; +import io.trino.cost.RuntimeInfoProvider; +import io.trino.sql.planner.PartitioningHandle; +import io.trino.sql.planner.PartitioningScheme; +import io.trino.sql.planner.PlanFragment; +import io.trino.sql.planner.PlanNodeIdAllocator; +import io.trino.sql.planner.plan.ExchangeNode; +import io.trino.sql.planner.plan.PlanNode; +import io.trino.sql.planner.plan.PlanNodeId; +import io.trino.sql.planner.plan.RemoteSourceNode; +import io.trino.sql.planner.plan.SimplePlanRewriter; + +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Optional; +import java.util.Set; + +import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.airlift.units.DataSize.succinctBytes; +import static io.trino.SystemSessionProperties.getFaultTolerantExecutionMaxPartitionCount; +import static io.trino.SystemSessionProperties.getFaultTolerantExecutionRuntimeAdaptivePartitioningMaxTaskSize; +import static io.trino.SystemSessionProperties.getFaultTolerantExecutionRuntimeAdaptivePartitioningPartitionCount; +import static io.trino.SystemSessionProperties.isFaultTolerantExecutionRuntimeAdaptivePartitioningEnabled; +import static io.trino.execution.scheduler.faulttolerant.OutputStatsEstimator.OutputStatsEstimateResult; +import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_BROADCAST_DISTRIBUTION; +import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION; +import static io.trino.sql.planner.SystemPartitioningHandle.SCALED_WRITER_HASH_DISTRIBUTION; +import static io.trino.sql.planner.plan.ExchangeNode.Scope.REMOTE; +import static io.trino.sql.planner.plan.ExchangeNode.Type.REPARTITION; +import static io.trino.sql.planner.plan.ExchangeNode.Type.REPLICATE; +import static io.trino.sql.planner.plan.SimplePlanRewriter.rewriteWith; +import static java.util.Objects.requireNonNull; + +/** + * This optimizer is responsible for changing the partition count of hash partitioned fragments + * at runtime. This uses the runtime output stats from FTE to determine estimated memory consumption + * and changes the partition count if the estimated memory consumption is higher than the + * runtimeAdaptivePartitioningMaxTaskSizeInBytes * current partitionCount. 
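+ * For example (illustrative numbers, not taken from this change): a fragment consuming two repartitioned inputs of
+ * 400 GB and 10 GB is estimated to need about 400 GB of memory (the sum minus the smallest input, on the assumption
+ * that the smallest source is streamed through the join); with a max task size of 1 GB and a current partition count
+ * of 50, 400 GB exceeds 1 GB * 50, so the fragment is re-planned with the configured runtime adaptive partition count.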
+ */ +public class AdaptivePartitioning + implements AdaptivePlanOptimizer +{ + private static final Logger log = Logger.get(AdaptivePartitioning.class); + + @Override + public Result optimizeAndMarkPlanChanges(PlanNode plan, Context context) + { + // Skip if runtime adaptive partitioning is not enabled + if (!isFaultTolerantExecutionRuntimeAdaptivePartitioningEnabled(context.session())) { + return new Result(plan, ImmutableSet.of()); + } + + int maxPartitionCount = getFaultTolerantExecutionMaxPartitionCount(context.session()); + int runtimeAdaptivePartitioningPartitionCount = getFaultTolerantExecutionRuntimeAdaptivePartitioningPartitionCount(context.session()); + long runtimeAdaptivePartitioningMaxTaskSizeInBytes = getFaultTolerantExecutionRuntimeAdaptivePartitioningMaxTaskSize(context.session()).toBytes(); + RuntimeInfoProvider runtimeInfoProvider = context.runtimeInfoProvider(); + for (PlanFragment fragment : runtimeInfoProvider.getAllPlanFragments()) { + // Skip if the stage is not consuming hash partitioned input or if the runtime stats are accurate which + // basically means that the stage can't be re-planned in the current implementation of AdaptivePlaner. + // TODO: We need add an ability to re-plan fragment whose stats are estimated by progress. + if (!consumesHashPartitionedInput(fragment) || runtimeInfoProvider.getRuntimeOutputStats(fragment.getId()).isAccurate()) { + continue; + } + + int partitionCount = fragment.getPartitionCount().orElse(maxPartitionCount); + // Skip if partition count is already at the maximum + if (partitionCount >= runtimeAdaptivePartitioningPartitionCount) { + continue; + } + + // calculate (estimated) input data size to determine if we want to change number of partitions at runtime + List partitionedInputBytes = fragment.getRemoteSourceNodes().stream() + // skip for replicate exchange since it's assumed that broadcast join will be chosen by + // static optimizer only if build size is small. + // TODO: Fix this assumption by using runtime stats + .filter(remoteSourceNode -> remoteSourceNode.getExchangeType() != REPLICATE) + .map(remoteSourceNode -> remoteSourceNode.getSourceFragmentIds().stream() + .mapToLong(sourceFragmentId -> { + OutputStatsEstimateResult runtimeStats = runtimeInfoProvider.getRuntimeOutputStats(sourceFragmentId); + return runtimeStats.outputDataSizeEstimate().getTotalSizeInBytes(); + }) + .sum()) + .collect(toImmutableList()); + + // Currently the memory estimation is simplified: + // if it's an aggregation, then we use the total input bytes as the memory consumption + // if it involves multiple joins, conservatively we assume the smallest remote source will be streamed through + // and use the sum of input bytes of other remote sources as the memory consumption + // TODO: more accurate memory estimation based on context (https://github.com/trinodb/trino/issues/18698) + long estimatedMemoryConsumptionInBytes = (partitionedInputBytes.size() == 1) ? 
partitionedInputBytes.get(0) : + partitionedInputBytes.stream().mapToLong(Long::longValue).sum() - Collections.min(partitionedInputBytes); + + if (estimatedMemoryConsumptionInBytes > runtimeAdaptivePartitioningMaxTaskSizeInBytes * partitionCount) { + log.info("Stage %s has an estimated memory consumption of %s, changing partition count from %s to %s", + fragment.getId(), succinctBytes(estimatedMemoryConsumptionInBytes), partitionCount, runtimeAdaptivePartitioningPartitionCount); + Rewriter rewriter = new Rewriter(runtimeAdaptivePartitioningPartitionCount, context.idAllocator(), runtimeInfoProvider); + PlanNode planNode = rewriteWith(rewriter, plan); + return new Result(planNode, rewriter.getChangedPlanIds()); + } + } + + return new Result(plan, ImmutableSet.of()); + } + + public static boolean consumesHashPartitionedInput(PlanFragment fragment) + { + return isPartitioned(fragment.getPartitioning()); + } + + private static boolean isPartitioned(PartitioningHandle partitioningHandle) + { + return partitioningHandle.equals(FIXED_HASH_DISTRIBUTION) || partitioningHandle.equals(SCALED_WRITER_HASH_DISTRIBUTION); + } + + private static class Rewriter + extends SimplePlanRewriter + { + private final int partitionCount; + private final PlanNodeIdAllocator idAllocator; + private final RuntimeInfoProvider runtimeInfoProvider; + private final Set changedPlanIds = new HashSet<>(); + + private Rewriter(int partitionCount, PlanNodeIdAllocator idAllocator, RuntimeInfoProvider runtimeInfoProvider) + { + this.partitionCount = partitionCount; + this.idAllocator = requireNonNull(idAllocator, "idAllocator is null"); + this.runtimeInfoProvider = requireNonNull(runtimeInfoProvider, "runtimeInfoProvider is null"); + } + + @Override + public PlanNode visitExchange(ExchangeNode node, RewriteContext context) + { + if (node.getPartitioningScheme().getPartitioning().getHandle().equals(FIXED_BROADCAST_DISTRIBUTION)) { + // the result of the subtree will be broadcast, then no need to change partition count for the subtree + // as the planner will only broadcast fragment output if it sees input data is small + // or filter ratio is high + return node; + } + List sources = node.getSources().stream() + .map(context::rewrite) + .collect(toImmutableList()); + PartitioningScheme partitioningScheme = node.getPartitioningScheme(); + + // for FTE it only makes sense to set partition count fot hash partitioned fragments + if (node.getPartitioningScheme().getPartitioning().getHandle() == FIXED_HASH_DISTRIBUTION) { + partitioningScheme = partitioningScheme.withPartitionCount(Optional.of(partitionCount)); + changedPlanIds.add(node.getId()); + } + + return new ExchangeNode( + node.getId(), + node.getType(), + node.getScope(), + partitioningScheme, + sources, + node.getInputs(), + node.getOrderingScheme()); + } + + @Override + public PlanNode visitRemoteSource(RemoteSourceNode node, RewriteContext context) + { + if (node.getExchangeType() != REPARTITION) { + return node; + } + + Optional sourcePartitioningScheme = node.getSourceFragmentIds().stream() + .map(runtimeInfoProvider::getPlanFragment) + .map(PlanFragment::getOutputPartitioningScheme) + .filter(scheme -> isPartitioned(scheme.getPartitioning().getHandle())) + .findFirst(); + + if (sourcePartitioningScheme.isEmpty()) { + return node; + } + + PartitioningScheme newPartitioningSchema = sourcePartitioningScheme.get() + .withPartitionCount(Optional.of(partitionCount)) + .withPartitioningHandle(FIXED_HASH_DISTRIBUTION); + + PlanNodeId nodeId = idAllocator.getNextId(); + 
changedPlanIds.add(nodeId); + return new ExchangeNode( + nodeId, + REPARTITION, + REMOTE, + newPartitioningSchema, + ImmutableList.of(node), + ImmutableList.of(node.getOutputSymbols()), + node.getOrderingScheme()); + } + + public Set getChangedPlanIds() + { + return ImmutableSet.copyOf(changedPlanIds); + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AdaptivePlanOptimizer.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AdaptivePlanOptimizer.java new file mode 100644 index 000000000000..d21e9f24cc51 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AdaptivePlanOptimizer.java @@ -0,0 +1,54 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.trino.sql.planner.optimizations; + +import io.trino.sql.planner.plan.PlanNode; +import io.trino.sql.planner.plan.PlanNodeId; + +import java.util.Set; + +import static java.util.Objects.requireNonNull; + +/** + * This optimizer is needed for adaptive optimization in FTE. + */ +public interface AdaptivePlanOptimizer + extends PlanOptimizer +{ + @Override + default PlanNode optimize(PlanNode plan, Context context) + { + return optimizeAndMarkPlanChanges(plan, context).plan(); + } + + /** + * Optimize the plan and return the changes made to the plan. + */ + Result optimizeAndMarkPlanChanges(PlanNode plan, Context context); + + record Result(PlanNode plan, Set changedPlanNodes) + { + /** + * @param plan The optimized plan + * @param changedPlanNodes The set of PlanNodeIds that were changed during optimization, as well as the new + * PlanNodeIds that were added to the optimized plan. 
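+     * For example, {@link AdaptivePartitioning} returns the rewritten plan together with the ids of the exchange
+     * nodes it added or modified; the adaptive planner can then treat sub plans that are not downstream of any
+     * changed id as unchanged and reuse their existing fragments (editorial illustration based on this change).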
+ */ + public Result(PlanNode plan, Set changedPlanNodes) + { + this.plan = requireNonNull(plan, "plan is null"); + this.changedPlanNodes = requireNonNull(changedPlanNodes, "changedPlanNodes is null"); + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PlanOptimizer.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PlanOptimizer.java index c97682c306d2..d6a3a14bf3df 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PlanOptimizer.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PlanOptimizer.java @@ -14,6 +14,7 @@ package io.trino.sql.planner.optimizations; import io.trino.Session; +import io.trino.cost.RuntimeInfoProvider; import io.trino.cost.TableStatsProvider; import io.trino.execution.querystats.PlanOptimizersStatsCollector; import io.trino.execution.warnings.WarningCollector; @@ -35,7 +36,8 @@ record Context( PlanNodeIdAllocator idAllocator, WarningCollector warningCollector, PlanOptimizersStatsCollector planOptimizersStatsCollector, - TableStatsProvider tableStatsProvider) + TableStatsProvider tableStatsProvider, + RuntimeInfoProvider runtimeInfoProvider) { public Context( Session session, @@ -44,7 +46,8 @@ public Context( PlanNodeIdAllocator idAllocator, WarningCollector warningCollector, PlanOptimizersStatsCollector planOptimizersStatsCollector, - TableStatsProvider tableStatsProvider) + TableStatsProvider tableStatsProvider, + RuntimeInfoProvider runtimeInfoProvider) { this.session = requireNonNull(session, "session is null"); this.types = requireNonNull(types, "types is null"); @@ -53,6 +56,7 @@ public Context( this.warningCollector = requireNonNull(warningCollector, "warningCollector is null"); this.tableStatsProvider = requireNonNull(tableStatsProvider, "tableStatsProvider is null"); this.planOptimizersStatsCollector = requireNonNull(planOptimizersStatsCollector, "planOptimizersStatsCollector is null"); + this.runtimeInfoProvider = requireNonNull(runtimeInfoProvider, "runtimeInfoProvider is null"); } } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/AdaptivePlanNode.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/AdaptivePlanNode.java new file mode 100644 index 000000000000..9d56920a8b0c --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/AdaptivePlanNode.java @@ -0,0 +1,95 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.sql.planner.plan; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Iterables; +import io.trino.spi.type.Type; +import io.trino.sql.planner.Symbol; + +import java.util.List; +import java.util.Map; + +import static java.util.Objects.requireNonNull; + +public class AdaptivePlanNode + extends PlanNode +{ + private final PlanNode initialPlan; + // We do not store the initial plan types in PlanFragment#types since initial plan is only stored for + // printing purposes. Therefore, we need to store the types separately here to be able to print the + // initial plan. + private final Map initialPlanTypes; + private final PlanNode currentPlan; + + @JsonCreator + public AdaptivePlanNode( + @JsonProperty("id") PlanNodeId id, + @JsonProperty("initialPlan") PlanNode initialPlan, + @JsonProperty("initialPlanTypes") Map initialPlanTypes, + @JsonProperty("currentPlan") PlanNode currentPlan) + { + super(id); + + this.initialPlan = requireNonNull(initialPlan, "initialPlan is null"); + this.initialPlanTypes = ImmutableMap.copyOf(requireNonNull(initialPlanTypes, "initialPlanTypes is null")); + this.currentPlan = requireNonNull(currentPlan, "currentPlan is null"); + } + + @JsonProperty + public PlanNode getInitialPlan() + { + return initialPlan; + } + + @JsonProperty + public Map getInitialPlanTypes() + { + return initialPlanTypes; + } + + @JsonProperty + public PlanNode getCurrentPlan() + { + return currentPlan; + } + + @Override + public List getSources() + { + // The initial plan is not used in the execution, so it is not a source of the adaptive plan. + return ImmutableList.of(currentPlan); + } + + @Override + public List getOutputSymbols() + { + return currentPlan.getOutputSymbols(); + } + + @Override + public PlanNode replaceChildren(List newChildren) + { + return new AdaptivePlanNode(getId(), initialPlan, initialPlanTypes, Iterables.getOnlyElement(newChildren)); + } + + @Override + public R accept(PlanVisitor visitor, C context) + { + return visitor.visitAdaptivePlanNode(this, context); + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/Patterns.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/Patterns.java index a4372a903837..6afb8f26b1bb 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/Patterns.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/Patterns.java @@ -243,6 +243,11 @@ public static Pattern except() return typeOf(ExceptNode.class); } + public static Pattern remoteSourceNode() + { + return typeOf(RemoteSourceNode.class); + } + public static Property source() { return optionalProperty( diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/PlanNode.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/PlanNode.java index 36d2841ae58d..ded4a172afe5 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/PlanNode.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/PlanNode.java @@ -71,6 +71,7 @@ @JsonSubTypes.Type(value = PatternRecognitionNode.class, name = "patternRecognition"), @JsonSubTypes.Type(value = TableFunctionNode.class, name = "tableFunction"), @JsonSubTypes.Type(value = TableFunctionProcessorNode.class, name = "tableFunctionProcessor"), + @JsonSubTypes.Type(value = AdaptivePlanNode.class, name = "adaptivePlanNode"), }) public abstract class PlanNode 
{ diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/PlanVisitor.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/PlanVisitor.java index bd8dbafb45d8..45ab769d0f3c 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/PlanVisitor.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/PlanVisitor.java @@ -258,4 +258,9 @@ public R visitTableFunctionProcessor(TableFunctionProcessorNode node, C context) { return visitPlan(node, context); } + + public R visitAdaptivePlanNode(AdaptivePlanNode node, C context) + { + return visitPlan(node, context); + } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/JsonRenderer.java b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/JsonRenderer.java index 0402f142c3f4..ab5ce776d4e0 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/JsonRenderer.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/JsonRenderer.java @@ -14,8 +14,10 @@ package io.trino.sql.planner.planprinter; import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableList; import io.airlift.json.JsonCodec; import io.trino.cost.PlanNodeStatsAndCostSummary; +import io.trino.sql.planner.plan.PlanNodeId; import java.util.List; import java.util.Map; @@ -33,18 +35,15 @@ public class JsonRenderer @Override public String render(PlanRepresentation plan) { - return CODEC.toJson(renderJson(plan, plan.getRoot())); + return CODEC.toJson(renderJson(plan, plan.getRoot(), false)); } - protected JsonRenderedNode renderJson(PlanRepresentation plan, NodeRepresentation node) + protected JsonRenderedNode renderJson(PlanRepresentation plan, NodeRepresentation node, boolean isAdaptivePlanInitialNode) { - List children = node.getChildren().stream() - .map(plan::getNode) - .filter(Optional::isPresent) - .map(Optional::get) - .map(n -> renderJson(plan, n)) - .collect(toImmutableList()); - + ImmutableList.Builder children = ImmutableList.builder(); + // Add initial children first + children.addAll(renderChildren(plan, node.getInitialChildren(), true)); + children.addAll(renderChildren(plan, node.getChildren(), isAdaptivePlanInitialNode)); return new JsonRenderedNode( node.getId().toString(), node.getName(), @@ -52,7 +51,17 @@ protected JsonRenderedNode renderJson(PlanRepresentation plan, NodeRepresentatio node.getOutputs(), node.getDetails(), node.getEstimates(), - children); + children.build()); + } + + private List renderChildren(PlanRepresentation plan, List children, boolean isAdaptivePlanInitialNode) + { + return children.stream() + .map(isAdaptivePlanInitialNode ? 
plan::getInitialNode : plan::getNode) + .filter(Optional::isPresent) + .map(Optional::get) + .map(n -> renderJson(plan, n, isAdaptivePlanInitialNode)) + .collect(toImmutableList()); } public static class JsonRenderedNode diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/NodeRepresentation.java b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/NodeRepresentation.java index d7e7faeba91e..22eb56d9a944 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/NodeRepresentation.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/NodeRepresentation.java @@ -46,6 +46,7 @@ public class NodeRepresentation private final Map descriptor; private final List outputs; private final List children; + private final List initialChildren; private final List remoteSources; private final Optional stats; private final List estimatedStats; @@ -65,6 +66,8 @@ public NodeRepresentation( List estimatedCost, Optional reorderJoinStatsAndCost, List children, + // This is used in the case of adaptive plan node + List initialChildren, List remoteSources) { this.id = requireNonNull(id, "id is null"); @@ -77,6 +80,7 @@ public NodeRepresentation( this.estimatedCost = requireNonNull(estimatedCost, "estimatedCost is null"); this.reorderJoinStatsAndCost = requireNonNull(reorderJoinStatsAndCost, "reorderJoinStatsAndCost is null"); this.children = requireNonNull(children, "children is null"); + this.initialChildren = requireNonNull(initialChildren, "initialChildren is null"); this.remoteSources = requireNonNull(remoteSources, "remoteSources is null"); checkArgument(estimatedCost.size() == estimatedStats.size(), "size of cost and stats list does not match"); @@ -123,6 +127,11 @@ public List getChildren() return children; } + public List getInitialChildren() + { + return initialChildren; + } + public List getRemoteSources() { return remoteSources; diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java index a94ce5cc86fc..055006104d33 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java @@ -61,6 +61,7 @@ import io.trino.sql.planner.Symbol; import io.trino.sql.planner.TypeProvider; import io.trino.sql.planner.iterative.GroupReference; +import io.trino.sql.planner.plan.AdaptivePlanNode; import io.trino.sql.planner.plan.AggregationNode; import io.trino.sql.planner.plan.AggregationNode.Aggregation; import io.trino.sql.planner.plan.ApplyNode; @@ -239,7 +240,7 @@ public class PlanPrinter this.representation = new PlanRepresentation(planRoot, types, totalCpuTime, totalScheduledTime, totalBlockedTime); Visitor visitor = new Visitor(types, estimatedStatsAndCosts, stats); - planRoot.accept(visitor, new Context()); + planRoot.accept(visitor, new Context(Optional.empty(), Optional.empty(), false)); } private String toText(boolean verbose, int level) @@ -255,7 +256,7 @@ String toJson() JsonRenderedNode toJsonRenderedNode() { - return new JsonRenderer().renderJson(representation, representation.getRoot()); + return new JsonRenderer().renderJson(representation, representation.getRoot(), false); } public static String jsonFragmentPlan(PlanNode root, Map symbols, Metadata metadata, FunctionManager functionManager, Session session) @@ -585,7 +586,7 @@ private static String formatFragment( builder.append(indentString(1)); String 
hashColumn = partitioningScheme.getHashColumn().map(anonymizer::anonymize).map(column -> "[" + column + "]").orElse(""); if (replicateNullsAndAny) { - builder.append(format("Output partitioning: %s (replicate nulls and any) [%s]%s\n", + builder.append(format("Output partitioning: %s (replicate nulls and any) [%s]%s", anonymizer.anonymize(partitioningScheme.getPartitioning().getHandle()), Joiner.on(", ").join(arguments), hashColumn)); @@ -596,8 +597,8 @@ private static String formatFragment( Joiner.on(", ").join(arguments), hashColumn)); } - - fragment.getPartitionCount().ifPresent(partitionCount -> builder.append(format("%sPartition count: %s\n", indentString(1), partitionCount))); + partitioningScheme.getPartitionCount().ifPresent(partitionCount -> builder.append(format("%sOutput partition count: %s\n", indentString(1), partitionCount))); + fragment.getPartitionCount().ifPresent(partitionCount -> builder.append(format("%sInput partition count: %s\n", indentString(1), partitionCount))); builder.append( new PlanPrinter( @@ -679,7 +680,7 @@ public Visitor(TypeProvider types, StatsAndCosts estimatedStatsAndCosts, Optiona public Void visitExplainAnalyze(ExplainAnalyzeNode node, Context context) { addNode(node, "ExplainAnalyze", context); - return processChildren(node, new Context(context.types())); + return processChildren(node, new Context(context.types(), context.isInitialPlan())); } @Override @@ -711,8 +712,8 @@ public Void visitJoin(JoinNode node, Context context) if (!node.getDynamicFilters().isEmpty()) { nodeOutput.appendDetails("dynamicFilterAssignments = %s", printDynamicFilterAssignments(node.getDynamicFilters())); } - node.getLeft().accept(this, new Context(context.types())); - node.getRight().accept(this, new Context(context.types())); + node.getLeft().accept(this, new Context(context.types(), context.isInitialPlan())); + node.getRight().accept(this, new Context(context.types(), context.isInitialPlan())); return null; } @@ -726,8 +727,8 @@ public Void visitSpatialJoin(SpatialJoinNode node, Context context) context); nodeOutput.appendDetails("Distribution: %s", node.getDistributionType()); - node.getLeft().accept(this, new Context(context.types())); - node.getRight().accept(this, new Context(context.types())); + node.getLeft().accept(this, new Context(context.types(), context.isInitialPlan())); + node.getRight().accept(this, new Context(context.types(), context.isInitialPlan())); return null; } @@ -743,8 +744,8 @@ public Void visitSemiJoin(SemiJoinNode node, Context context) context); node.getDistributionType().ifPresent(distributionType -> nodeOutput.appendDetails("Distribution: %s", distributionType)); node.getDynamicFilterId().ifPresent(dynamicFilterId -> nodeOutput.appendDetails("dynamicFilterId: %s", dynamicFilterId)); - node.getSource().accept(this, new Context(context.types())); - node.getFilteringSource().accept(this, new Context(context.types())); + node.getSource().accept(this, new Context(context.types(), context.isInitialPlan())); + node.getFilteringSource().accept(this, new Context(context.types(), context.isInitialPlan())); return null; } @@ -757,7 +758,7 @@ public Void visitDynamicFilterSource(DynamicFilterSourceNode node, Context conte "DynamicFilterSource", ImmutableMap.of("dynamicFilterAssignments", printDynamicFilterAssignments(node.getDynamicFilters())), context); - node.getSource().accept(this, new Context(context.types())); + node.getSource().accept(this, new Context(context.types(), context.isInitialPlan())); return null; } @@ -795,8 +796,8 @@ public Void 
visitIndexJoin(IndexJoinNode node, Context context) "criteria", Joiner.on(" AND ").join(anonymizeExpressions(joinExpressions)), "hash", formatHash(node.getProbeHashSymbol(), node.getIndexHashSymbol())), context); - node.getProbeSource().accept(this, new Context(context.types())); - node.getIndexSource().accept(this, new Context(context.types())); + node.getProbeSource().accept(this, new Context(context.types(), context.isInitialPlan())); + node.getIndexSource().accept(this, new Context(context.types(), context.isInitialPlan())); return null; } @@ -808,7 +809,7 @@ public Void visitOffset(OffsetNode node, Context context) "Offset", ImmutableMap.of("count", String.valueOf(node.getCount())), context); - return processChildren(node, new Context(context.types())); + return processChildren(node, new Context(context.types(), context.isInitialPlan())); } @Override @@ -821,7 +822,7 @@ public Void visitLimit(LimitNode node, Context context) "withTies", formatBoolean(node.isWithTies()), "inputPreSortedBy", formatSymbols(node.getPreSortedInputs())), context); - return processChildren(node, new Context(context.types())); + return processChildren(node, new Context(context.types(), context.isInitialPlan())); } @Override @@ -833,7 +834,7 @@ public Void visitDistinctLimit(DistinctLimitNode node, Context context) "limit", String.valueOf(node.getLimit()), "hash", formatHash(node.getHashSymbol())), context); - return processChildren(node, new Context(context.types())); + return processChildren(node, new Context(context.types(), context.isInitialPlan())); } @Override @@ -860,7 +861,7 @@ public Void visitAggregation(AggregationNode node, Context context) node.getAggregations().forEach((symbol, aggregation) -> nodeOutput.appendDetails("%s := %s", anonymizer.anonymize(symbol), formatAggregation(anonymizer, aggregation))); - return processChildren(node, new Context(context.types())); + return processChildren(node, new Context(context.types(), context.isInitialPlan())); } @Override @@ -884,7 +885,7 @@ public Void visitGroupId(GroupIdNode node, Context context) nodeOutput.appendDetails("%s := %s", anonymizer.anonymize(mapping.getKey()), anonymizer.anonymize(mapping.getValue())); } - return processChildren(node, new Context(context.types())); + return processChildren(node, new Context(context.types(), context.isInitialPlan())); } @Override @@ -898,7 +899,7 @@ public Void visitMarkDistinct(MarkDistinctNode node, Context context) "hash", formatHash(node.getHashSymbol())), context); - return processChildren(node, new Context(context.types())); + return processChildren(node, new Context(context.types(), context.isInitialPlan())); } @Override @@ -949,7 +950,7 @@ public Void visitWindow(WindowNode node, Context context) Joiner.on(", ").join(anonymizeExpressions(function.getArguments())), frameInfo); } - return processChildren(node, new Context(context.types())); + return processChildren(node, new Context(context.types(), context.isInitialPlan())); } @Override @@ -1019,7 +1020,7 @@ public Void visitPatternRecognition(PatternRecognitionNode node, Context context appendValuePointers(nodeOutput, entry.getValue()); } - return processChildren(node, new Context(context.types())); + return processChildren(node, new Context(context.types(), context.isInitialPlan())); } private void appendValuePointers(NodeRepresentation nodeOutput, ExpressionAndValuePointers expressionAndPointers) @@ -1129,7 +1130,7 @@ public Void visitTopNRanking(TopNRankingNode node, Context context) nodeOutput.appendDetails("%s := %s", 
anonymizer.anonymize(node.getRankingSymbol()), node.getRankingType()); - return processChildren(node, new Context(context.types())); + return processChildren(node, new Context(context.types(), context.isInitialPlan())); } @Override @@ -1151,7 +1152,7 @@ public Void visitRowNumber(RowNumberNode node, Context context) context); nodeOutput.appendDetails("%s := %s", anonymizer.anonymize(node.getRowNumberSymbol()), "row_number()"); - return processChildren(node, new Context(context.types())); + return processChildren(node, new Context(context.types(), context.isInitialPlan())); } @Override @@ -1282,6 +1283,7 @@ private Void visitScanFilterAndProjectInfo( allNodes, ImmutableList.of(sourceNode), ImmutableList.of(), + ImmutableList.of(), Optional.empty(), context); @@ -1329,7 +1331,7 @@ private Void visitScanFilterAndProjectInfo( return null; } - sourceNode.accept(this, new Context(context.types())); + sourceNode.accept(this, new Context(context.types(), context.isInitialPlan())); return null; } @@ -1424,7 +1426,7 @@ public Void visitUnnest(UnnestNode node, Context context) } descriptor.put("unnest", formatOutputs(getTypes(context), unnestInputs)); addNode(node, name, descriptor.buildOrThrow(), context); - return processChildren(node, new Context(context.types())); + return processChildren(node, new Context(context.types(), context.isInitialPlan())); } @Override @@ -1442,7 +1444,7 @@ public Void visitOutput(OutputNode node, Context context) nodeOutput.appendDetails("%s := %s", anonymizer.anonymizeColumn(name), anonymizer.anonymize(symbol)); } } - return processChildren(node, new Context(context.types())); + return processChildren(node, new Context(context.types(), context.isInitialPlan())); } @Override @@ -1454,7 +1456,7 @@ public Void visitTopN(TopNNode node, Context context) "count", String.valueOf(node.getCount()), "orderBy", formatOrderingScheme(node.getOrderingScheme())), context); - return processChildren(node, new Context(context.types())); + return processChildren(node, new Context(context.types(), context.isInitialPlan())); } @Override @@ -1465,7 +1467,7 @@ public Void visitSort(SortNode node, Context context) ImmutableMap.of("orderBy", formatOrderingScheme(node.getOrderingScheme())), context); - return processChildren(node, new Context(context.types())); + return processChildren(node, new Context(context.types(), context.isInitialPlan())); } @Override @@ -1476,6 +1478,7 @@ public Void visitRemoteSource(RemoteSourceNode node, Context context) ImmutableMap.of("sourceFragmentIds", formatCollection(node.getSourceFragmentIds(), Objects::toString)), ImmutableList.of(), ImmutableList.of(), + ImmutableList.of(), node.getSourceFragmentIds(), Optional.empty(), context); @@ -1488,7 +1491,7 @@ public Void visitUnion(UnionNode node, Context context) { addNode(node, "Union", context); - return processChildren(node, new Context(context.types())); + return processChildren(node, new Context(context.types(), context.isInitialPlan())); } @Override @@ -1499,7 +1502,7 @@ public Void visitIntersect(IntersectNode node, Context context) ImmutableMap.of("isDistinct", formatBoolean(node.isDistinct())), context); - return processChildren(node, new Context(context.types())); + return processChildren(node, new Context(context.types(), context.isInitialPlan())); } @Override @@ -1510,7 +1513,7 @@ public Void visitExcept(ExceptNode node, Context context) ImmutableMap.of("isDistinct", formatBoolean(node.isDistinct())), context); - return processChildren(node, new Context(context.types())); + return processChildren(node, 
new Context(context.types(), context.isInitialPlan())); } @Override @@ -1538,7 +1541,7 @@ public Void visitTableWriter(TableWriterNode node, Context context) printStatisticAggregations(nodeOutput, node.getStatisticsAggregation().get(), node.getStatisticsAggregationDescriptor().get()); } - return processChildren(node, new Context(context.types())); + return processChildren(node, new Context(context.types(), context.isInitialPlan())); } @Override @@ -1548,7 +1551,7 @@ public Void visitStatisticsWriterNode(StatisticsWriterNode node, Context context "StatisticsWriter", ImmutableMap.of("target", anonymizer.anonymize(node.getTarget())), context); - return processChildren(node, new Context(context.types())); + return processChildren(node, new Context(context.types(), context.isInitialPlan())); } @Override @@ -1565,7 +1568,7 @@ public Void visitTableFinish(TableFinishNode node, Context context) printStatisticAggregations(nodeOutput, node.getStatisticsAggregation().get(), node.getStatisticsAggregationDescriptor().get()); } - return processChildren(node, new Context(context.types())); + return processChildren(node, new Context(context.types(), context.isInitialPlan())); } private void printStatisticAggregations(NodeRepresentation nodeOutput, StatisticAggregations aggregations, StatisticAggregationsDescriptor descriptor) @@ -1631,7 +1634,7 @@ public Void visitSample(SampleNode node, Context context) "ratio", String.valueOf(node.getSampleRatio())), context); - return processChildren(node, new Context(context.types())); + return processChildren(node, new Context(context.types(), context.isInitialPlan())); } @Override @@ -1664,7 +1667,25 @@ else if (node.getScope() == Scope.LOCAL) { "hashColumn", formatHash(node.getPartitioningScheme().getHashColumn())), context); } - return processChildren(node, new Context(context.types())); + return processChildren(node, new Context(context.types(), context.isInitialPlan())); + } + + @Override + public Void visitAdaptivePlanNode(AdaptivePlanNode node, Context context) + { + addNode( + node, + "AdaptivePlan", + ImmutableMap.of(), + ImmutableList.of(node.getId()), + ImmutableList.of(node.getCurrentPlan()), + ImmutableList.of(node.getInitialPlan()), + ImmutableList.of(), + Optional.empty(), + context); + node.getInitialPlan().accept(this, new Context("Initial Plan", Optional.of(TypeProvider.viewOf(node.getInitialPlanTypes())), true)); + node.getCurrentPlan().accept(this, new Context("Current Plan", Optional.empty(), false)); + return null; } @Override @@ -1677,7 +1698,7 @@ public Void visitTableExecute(TableExecuteNode node, Context context) nodeOutput.appendDetails("%s := %s", anonymizer.anonymizeColumn(name), anonymizer.anonymize(symbol)); } - return processChildren(node, new Context(context.types())); + return processChildren(node, new Context(context.types(), context.isInitialPlan())); } @Override @@ -1697,7 +1718,7 @@ public Void visitMergeWriter(MergeWriterNode node, Context context) "MergeWriter", ImmutableMap.of("table", anonymizer.anonymize(node.getTarget())), context); - return processChildren(node, new Context(context.types())); + return processChildren(node, new Context(context.types(), context.isInitialPlan())); } @Override @@ -1710,7 +1731,7 @@ public Void visitMergeProcessor(MergeProcessorNode node, Context context) nodeOutput.appendDetails("redistribution columns: %s", anonymize(node.getRedistributionColumnSymbols())); nodeOutput.appendDetails("data columns: %s", anonymize(node.getDataColumnSymbols())); - return processChildren(node, new 
Context(context.types())); + return processChildren(node, new Context(context.types(), context.isInitialPlan())); } @Override @@ -1721,7 +1742,7 @@ public Void visitTableDelete(TableDeleteNode node, Context context) ImmutableMap.of("target", anonymizer.anonymize(node.getTarget())), context); - return processChildren(node, new Context(context.types())); + return processChildren(node, new Context(context.types(), context.isInitialPlan())); } @Override @@ -1732,7 +1753,7 @@ public Void visitTableUpdate(TableUpdateNode node, Context context) ImmutableMap.of("target", anonymizer.anonymize(node.getTarget())), context); - return processChildren(node, new Context(context.types())); + return processChildren(node, new Context(context.types(), context.isInitialPlan())); } @Override @@ -1740,7 +1761,7 @@ public Void visitEnforceSingleRow(EnforceSingleRowNode node, Context context) { addNode(node, "EnforceSingleRow", context); - return processChildren(node, new Context(context.types())); + return processChildren(node, new Context(context.types(), context.isInitialPlan())); } @Override @@ -1748,7 +1769,7 @@ public Void visitAssignUniqueId(AssignUniqueId node, Context context) { addNode(node, "AssignUniqueId", context); - return processChildren(node, new Context(context.types())); + return processChildren(node, new Context(context.types(), context.isInitialPlan())); } @Override @@ -1774,7 +1795,7 @@ public Void visitApply(ApplyNode node, Context context) context); printAssignments(nodeOutput, node.getSubqueryAssignments()); - return processChildren(node, new Context(context.types())); + return processChildren(node, new Context(context.types(), context.isInitialPlan())); } @Override @@ -1787,7 +1808,7 @@ public Void visitCorrelatedJoin(CorrelatedJoinNode node, Context context) "filter", formatFilter(node.getFilter())), context); - return processChildren(node, new Context(context.types())); + return processChildren(node, new Context(context.types(), context.isInitialPlan())); } @Override @@ -1816,7 +1837,7 @@ public Void visitTableFunction(TableFunctionNode node, Context context) } for (int i = 0; i < node.getSources().size(); i++) { - node.getSources().get(i).accept(this, new Context(node.getTableArgumentProperties().get(i).getArgumentName(), Optional.empty())); + node.getSources().get(i).accept(this, new Context(node.getTableArgumentProperties().get(i).getArgumentName(), context.types(), context.isInitialPlan())); } return null; @@ -1927,7 +1948,7 @@ public Void visitTableFunctionProcessor(TableFunctionProcessorNode node, Context addNode(node, "TableFunctionProcessor", descriptor.put("hash", formatHash(node.getHashSymbol())).buildOrThrow(), context); - return processChildren(node, new Context(context.types())); + return processChildren(node, new Context(context.types(), context.isInitialPlan())); } @Override @@ -2125,7 +2146,7 @@ public NodeRepresentation addNode(PlanNode node, String name, Map descriptor, List children, Optional reorderJoinStatsAndCost, Context context) { - return addNode(node, name, descriptor, ImmutableList.of(node.getId()), children, ImmutableList.of(), reorderJoinStatsAndCost, context); + return addNode(node, name, descriptor, ImmutableList.of(node.getId()), children, ImmutableList.of(), ImmutableList.of(), reorderJoinStatsAndCost, context); } public NodeRepresentation addNode( @@ -2134,11 +2155,13 @@ public NodeRepresentation addNode( Map descriptor, List allNodes, List children, + List initialChildren, List remoteSources, Optional reorderJoinStatsAndCost, Context context) { List 
childrenIds = children.stream().map(PlanNode::getId).collect(toImmutableList()); + List initialChildrenIds = initialChildren.stream().map(PlanNode::getId).collect(toImmutableList()); List estimatedStats = allNodes.stream() .map(nodeId -> estimatedStatsAndCosts.getStats().getOrDefault(nodeId, PlanNodeStatsEstimate.unknown())) .collect(toList()); @@ -2162,9 +2185,15 @@ public NodeRepresentation addNode( estimatedCosts, reorderJoinStatsAndCost, childrenIds, + initialChildrenIds, remoteSources); - representation.addNode(nodeOutput); + if (context.isInitialPlan()) { + representation.addInitialNode(nodeOutput); + } + else { + representation.addNode(nodeOutput); + } return nodeOutput; } @@ -2253,21 +2282,16 @@ public Expression rewriteFunctionCall(FunctionCall node, Void context, Expressio }, expression); } - private record Context(Optional tag, Optional types) + private record Context(Optional tag, Optional types, boolean isInitialPlan) { - public Context() - { - this(Optional.empty(), Optional.empty()); - } - - public Context(String tag, Optional types) + public Context(Optional types, boolean isInitialPlan) { - this(Optional.of(tag), types); + this(Optional.empty(), types, isInitialPlan); } - public Context(Optional types) + public Context(String tag, Optional types, boolean isInitialPlan) { - this(Optional.empty(), types); + this(Optional.of(tag), types, isInitialPlan); } private Context diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanRepresentation.java b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanRepresentation.java index 969618898b74..20b58e91b945 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanRepresentation.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanRepresentation.java @@ -33,6 +33,9 @@ class PlanRepresentation private final Optional totalBlockedTime; private final Map nodeInfo = new HashMap<>(); + // Record the initial plan node info for adaptive plan since it is possible that the plan node id remain the same + // but the plan node itself changes + private final Map initialNodeInfo = new HashMap<>(); public PlanRepresentation(PlanNode root, TypeProvider types, Optional totalCpuTime, Optional totalScheduledTime, Optional totalBlockedTime) { @@ -73,8 +76,18 @@ public Optional getNode(PlanNodeId id) return Optional.ofNullable(nodeInfo.get(id)); } + public Optional getInitialNode(PlanNodeId id) + { + return Optional.ofNullable(initialNodeInfo.get(id)); + } + public void addNode(NodeRepresentation node) { nodeInfo.put(node.getId(), node); } + + public void addInitialNode(NodeRepresentation node) + { + initialNodeInfo.put(node.getId(), node); + } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/TextRenderer.java b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/TextRenderer.java index 64118f8f96f1..c01d2c2e2e66 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/TextRenderer.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/TextRenderer.java @@ -20,6 +20,7 @@ import io.trino.plugin.base.metrics.TDigestHistogram; import io.trino.spi.metrics.Metric; import io.trino.spi.metrics.Metrics; +import io.trino.sql.planner.plan.PlanNodeId; import java.util.Iterator; import java.util.List; @@ -60,10 +61,10 @@ public String render(PlanRepresentation plan) StringBuilder output = new StringBuilder(); NodeRepresentation root = plan.getRoot(); boolean hasChildren = hasChildren(root, plan); - return 
writeTextOutput(output, plan, Indent.newInstance(level, hasChildren), root); + return writeTextOutput(output, plan, Indent.newInstance(level, hasChildren), root, false); } - private String writeTextOutput(StringBuilder output, PlanRepresentation plan, Indent indent, NodeRepresentation node) + private String writeTextOutput(StringBuilder output, PlanRepresentation plan, Indent indent, NodeRepresentation node, boolean isAdaptivePlanInitialNode) { output.append(indent.nodeIndent()) .append(node.getName()) @@ -102,18 +103,30 @@ private String writeTextOutput(StringBuilder output, PlanRepresentation plan, In } } - List children = node.getChildren().stream() - .map(plan::getNode) + // Print the initial children first, then the children + printChildren(output, plan, node.getInitialChildren(), indent, true); + printChildren(output, plan, node.getChildren(), indent, isAdaptivePlanInitialNode); + + return output.toString(); + } + + private void printChildren( + StringBuilder output, + PlanRepresentation plan, + List childrenIds, + Indent indent, + boolean isAdaptivePlanInitialNode) + { + List children = childrenIds.stream() + .map(isAdaptivePlanInitialNode ? plan::getInitialNode : plan::getNode) .filter(Optional::isPresent) .map(Optional::get) .collect(toList()); for (Iterator iterator = children.iterator(); iterator.hasNext(); ) { NodeRepresentation child = iterator.next(); - writeTextOutput(output, plan, indent.forChild(!iterator.hasNext(), hasChildren(child, plan)), child); + writeTextOutput(output, plan, indent.forChild(!iterator.hasNext(), hasChildren(child, plan)), child, isAdaptivePlanInitialNode); } - - return output.toString(); } private String printStats(PlanRepresentation plan, NodeRepresentation node) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/sanity/PlanSanityChecker.java b/core/trino-main/src/main/java/io/trino/sql/planner/sanity/PlanSanityChecker.java index 9815a29c5b55..ff71947fc7a3 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/sanity/PlanSanityChecker.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/sanity/PlanSanityChecker.java @@ -64,6 +64,21 @@ public PlanSanityChecker(boolean forceSingleNode) new DynamicFiltersChecker(), new TableScanValidator(), new TableExecuteStructureValidator()) + .putAll( + Stage.AFTER_ADAPTIVE_PLANNING, + new ValidateDependenciesChecker(), + new NoDuplicatePlanNodeIdsChecker(), + new SugarFreeChecker(), + new AllFunctionsResolved(), + new TypeValidator(), + new NoSubqueryExpressionLeftChecker(), + new NoIdentifierLeftChecker(), + new VerifyOnlyOneOutputNode(), + new VerifyNoFilteredAggregations(), + new VerifyUseConnectorNodePartitioningSet(), + new ValidateScaledWritersUsage(), + new TableScanValidator(), + new TableExecuteStructureValidator()) .build(); } @@ -75,28 +90,7 @@ public void validateFinalPlan( TypeProvider types, WarningCollector warningCollector) { - try { - checkers.get(Stage.FINAL).forEach(checker -> checker.validate(planNode, session, plannerContext, typeAnalyzer, types, warningCollector)); - } - catch (RuntimeException e) { - try { - int nestLevel = 4; // so that it renders reasonably within exception stacktrace - String explain = textLogicalPlan( - planNode, - types, - plannerContext.getMetadata(), - plannerContext.getFunctionManager(), - StatsAndCosts.empty(), - session, - nestLevel, - false); - e.addSuppressed(new Exception("Current plan:\n" + explain)); - } - catch (RuntimeException ignore) { - // ignored - } - throw e; - } + validate(Stage.FINAL, planNode, session, plannerContext, 
typeAnalyzer, types, warningCollector); } public void validateIntermediatePlan( @@ -106,9 +100,32 @@ public void validateIntermediatePlan( IrTypeAnalyzer typeAnalyzer, TypeProvider types, WarningCollector warningCollector) + { + validate(Stage.INTERMEDIATE, planNode, session, plannerContext, typeAnalyzer, types, warningCollector); + } + + public void validateAdaptivePlan( + PlanNode planNode, + Session session, + PlannerContext plannerContext, + IrTypeAnalyzer typeAnalyzer, + TypeProvider types, + WarningCollector warningCollector) + { + validate(Stage.AFTER_ADAPTIVE_PLANNING, planNode, session, plannerContext, typeAnalyzer, types, warningCollector); + } + + private void validate( + Stage stage, + PlanNode planNode, + Session session, + PlannerContext plannerContext, + IrTypeAnalyzer typeAnalyzer, + TypeProvider types, + WarningCollector warningCollector) { try { - checkers.get(Stage.INTERMEDIATE).forEach(checker -> checker.validate(planNode, session, plannerContext, typeAnalyzer, types, warningCollector)); + checkers.get(stage).forEach(checker -> checker.validate(planNode, session, plannerContext, typeAnalyzer, types, warningCollector)); } catch (RuntimeException e) { try { @@ -144,6 +161,6 @@ void validate( private enum Stage { - INTERMEDIATE, FINAL + INTERMEDIATE, FINAL, AFTER_ADAPTIVE_PLANNING } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/sanity/ValidateDependenciesChecker.java b/core/trino-main/src/main/java/io/trino/sql/planner/sanity/ValidateDependenciesChecker.java index 4326d7307fb1..7cb2e2aa7f62 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/sanity/ValidateDependenciesChecker.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/sanity/ValidateDependenciesChecker.java @@ -21,6 +21,7 @@ import io.trino.sql.planner.IrTypeAnalyzer; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.TypeProvider; +import io.trino.sql.planner.plan.AdaptivePlanNode; import io.trino.sql.planner.plan.AggregationNode; import io.trino.sql.planner.plan.AggregationNode.Aggregation; import io.trino.sql.planner.plan.ApplyNode; @@ -124,6 +125,15 @@ protected Void visitPlan(PlanNode node, Set boundSymbols) throw new UnsupportedOperationException("not yet implemented: " + node.getClass().getName()); } + @Override + public Void visitAdaptivePlanNode(AdaptivePlanNode node, Set boundSymbols) + { + PlanNode source = node.getCurrentPlan(); + source.accept(this, boundSymbols); // visit child + + return null; + } + @Override public Void visitExplainAnalyze(ExplainAnalyzeNode node, Set boundSymbols) { diff --git a/core/trino-main/src/main/java/io/trino/testing/PlanTester.java b/core/trino-main/src/main/java/io/trino/testing/PlanTester.java index fbcb6b52e380..f91f7fcb2d45 100644 --- a/core/trino-main/src/main/java/io/trino/testing/PlanTester.java +++ b/core/trino-main/src/main/java/io/trino/testing/PlanTester.java @@ -45,12 +45,14 @@ import io.trino.connector.system.TableCommentSystemTable; import io.trino.connector.system.TablePropertiesSystemTable; import io.trino.connector.system.TransactionsSystemTable; +import io.trino.cost.CachingTableStatsProvider; import io.trino.cost.ComposableStatsCalculator; import io.trino.cost.CostCalculator; import io.trino.cost.CostCalculatorUsingExchanges; import io.trino.cost.CostCalculatorWithEstimatedExchanges; import io.trino.cost.CostComparator; import io.trino.cost.FilterStatsCalculator; +import io.trino.cost.RuntimeInfoProvider; import io.trino.cost.ScalarStatsCalculator; import io.trino.cost.StatsCalculator; import 
io.trino.cost.StatsCalculatorModule.StatsRulesProvider; @@ -153,6 +155,7 @@ import io.trino.sql.gen.OrderingCompiler; import io.trino.sql.gen.PageFunctionCompiler; import io.trino.sql.parser.SqlParser; +import io.trino.sql.planner.AdaptivePlanner; import io.trino.sql.planner.CompilerConfig; import io.trino.sql.planner.IrTypeAnalyzer; import io.trino.sql.planner.LocalExecutionPlanner; @@ -164,8 +167,10 @@ import io.trino.sql.planner.PlanFragmenter; import io.trino.sql.planner.PlanNodeIdAllocator; import io.trino.sql.planner.PlanOptimizers; +import io.trino.sql.planner.PlanOptimizersFactory; import io.trino.sql.planner.RuleStatsRecorder; import io.trino.sql.planner.SubPlan; +import io.trino.sql.planner.optimizations.AdaptivePlanOptimizer; import io.trino.sql.planner.optimizations.PlanOptimizer; import io.trino.sql.planner.plan.PlanNode; import io.trino.sql.planner.plan.PlanNodeId; @@ -803,6 +808,16 @@ public Plan createPlan(Session session, @Language("SQL") String sql) } public List getPlanOptimizers(boolean forceSingleNode) + { + return getPlanOptimizersFactory(forceSingleNode).getPlanOptimizers(); + } + + public List getAdaptivePlanOptimizers() + { + return getPlanOptimizersFactory(false).getAdaptivePlanOptimizers(); + } + + public PlanOptimizersFactory getPlanOptimizersFactory(boolean forceSingleNode) { return new PlanOptimizers( plannerContext, @@ -818,7 +833,7 @@ public List getPlanOptimizers(boolean forceSingleNode) new CostComparator(optimizerConfig), taskCountEstimator, nodePartitioningManager, - new RuleStatsRecorder()).get(); + new RuleStatsRecorder()); } public Plan createPlan(Session session, @Language("SQL") String sql, List optimizers, LogicalPlanner.Stage stage, WarningCollector warningCollector, PlanOptimizersStatsCollector planOptimizersStatsCollector) @@ -850,17 +865,33 @@ public Plan createPlan(Session session, @Language("SQL") String sql, List optimizers, WarningCollector warningCollector, PlanOptimizersStatsCollector planOptimizersStatsCollector, RuntimeInfoProvider runtimeInfoProvider) + { + AdaptivePlanner adaptivePlanner = new AdaptivePlanner( + session, + getPlannerContext(), + optimizers, + planFragmenter, + new PlanSanityChecker(false), + new IrTypeAnalyzer(plannerContext), + warningCollector, + planOptimizersStatsCollector, + new CachingTableStatsProvider(getPlannerContext().getMetadata(), session)); + return adaptivePlanner.optimize(subPlan, runtimeInfoProvider); + } + private QueryExplainerFactory createQueryExplainerFactory(List optimizers) { return new QueryExplainerFactory( - () -> optimizers, + createPlanOptimizersFactory(optimizers), planFragmenter, plannerContext, statsCalculator, @@ -868,6 +899,24 @@ private QueryExplainerFactory createQueryExplainerFactory(List op new NodeVersion("test")); } + private PlanOptimizersFactory createPlanOptimizersFactory(List optimizers) + { + return new PlanOptimizersFactory() + { + @Override + public List getPlanOptimizers() + { + return optimizers; + } + + @Override + public List getAdaptivePlanOptimizers() + { + throw new UnsupportedOperationException(); + } + }; + } + private AnalyzerFactory createAnalyzerFactory(QueryExplainerFactory queryExplainerFactory) { return new AnalyzerFactory( diff --git a/core/trino-main/src/test/java/io/trino/cost/StatsCalculatorAssertion.java b/core/trino-main/src/test/java/io/trino/cost/StatsCalculatorAssertion.java index 62c13a6545d7..058bd870807c 100644 --- a/core/trino-main/src/test/java/io/trino/cost/StatsCalculatorAssertion.java +++ 
b/core/trino-main/src/test/java/io/trino/cost/StatsCalculatorAssertion.java @@ -40,6 +40,7 @@ public class StatsCalculatorAssertion private final TypeProvider types; private final Map sourcesStats; + private RuntimeInfoProvider runtimeInfoProvider = RuntimeInfoProvider.noImplementation(); private Optional tableStatsProvider = Optional.empty(); @@ -86,6 +87,12 @@ public StatsCalculatorAssertion withTableStatisticsProvider(TableStatsProvider t return this; } + public StatsCalculatorAssertion withRuntimeInfoProvider(RuntimeInfoProvider runtimeInfoProvider) + { + this.runtimeInfoProvider = runtimeInfoProvider; + return this; + } + public StatsCalculatorAssertion check(Consumer statisticsAssertionConsumer) { PlanNodeStatsEstimate statsEstimate = queryRunner.getStatsCalculator().calculateStats( @@ -95,7 +102,8 @@ public StatsCalculatorAssertion check(Consumer statistic noLookup(), session, types, - tableStatsProvider.orElseGet(() -> new CachingTableStatsProvider(queryRunner.getPlannerContext().getMetadata(), session)))); + tableStatsProvider.orElseGet(() -> new CachingTableStatsProvider(queryRunner.getPlannerContext().getMetadata(), session)), + runtimeInfoProvider)); statisticsAssertionConsumer.accept(PlanNodeStatsAssertion.assertThat(statsEstimate)); return this; } @@ -110,7 +118,8 @@ public StatsCalculatorAssertion check(Rule rule, Consumer new CachingTableStatsProvider(queryRunner.getPlannerContext().getMetadata(), session)))); + tableStatsProvider.orElseGet(() -> new CachingTableStatsProvider(queryRunner.getPlannerContext().getMetadata(), session)), + runtimeInfoProvider)); checkState(statsEstimate.isPresent(), "Expected stats estimates to be present"); statisticsAssertionConsumer.accept(PlanNodeStatsAssertion.assertThat(statsEstimate.get())); return this; diff --git a/core/trino-main/src/test/java/io/trino/cost/TestRemoteSourceStatsRule.java b/core/trino-main/src/test/java/io/trino/cost/TestRemoteSourceStatsRule.java new file mode 100644 index 000000000000..c4ba857983a6 --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/cost/TestRemoteSourceStatsRule.java @@ -0,0 +1,174 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.cost; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.primitives.ImmutableLongArray; +import io.trino.sql.planner.Partitioning; +import io.trino.sql.planner.PartitioningScheme; +import io.trino.sql.planner.PlanFragment; +import io.trino.sql.planner.Symbol; +import io.trino.sql.planner.plan.PlanFragmentId; +import io.trino.sql.planner.plan.PlanNodeId; +import io.trino.sql.planner.plan.TableScanNode; +import org.junit.jupiter.api.Test; + +import java.util.Optional; + +import static io.trino.execution.scheduler.faulttolerant.OutputStatsEstimator.OutputStatsEstimateResult; +import static io.trino.operator.RetryPolicy.TASK; +import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.spi.type.DoubleType.DOUBLE; +import static io.trino.spi.type.VarcharType.VARCHAR; +import static io.trino.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION; +import static io.trino.sql.planner.SystemPartitioningHandle.SOURCE_DISTRIBUTION; +import static io.trino.sql.planner.plan.ExchangeNode.Type.REPARTITION; +import static io.trino.testing.TestingHandles.TEST_TABLE_HANDLE; +import static io.trino.testing.TestingMetadata.TestingColumnHandle; +import static java.lang.Double.NaN; + +public class TestRemoteSourceStatsRule + extends BaseStatsCalculatorTest +{ + @Test + public void testStatsRule() + { + assertRemoteSourceStats(0.1, 3325.333333); + } + + @Test + public void testStatsRuleWithNaNNullFraction() + { + // NaN null fraction is replaced with 0.0 + assertRemoteSourceStats(NaN, 2992); + } + + private void assertRemoteSourceStats(double nullFraction, double avgRowSize) + { + StatsAndCosts statsAndCosts = createStatsAndCosts(nullFraction); + tester().assertStatsFor(pb -> pb + .remoteSource( + ImmutableList.of(new PlanFragmentId("fragment")), + ImmutableList.of( + pb.symbol("col_a", VARCHAR), + pb.symbol("col_b", VARCHAR), + pb.symbol("col_c", BIGINT), + pb.symbol("col_d", DOUBLE)), + Optional.empty(), + REPARTITION, + TASK)) + .withRuntimeInfoProvider(createRuntimeInfoProvider(statsAndCosts)) + .check(check -> check + .outputRowsCount(1_000_000) + .symbolStats(new Symbol("col_a"), assertion -> assertion + .averageRowSize(avgRowSize) + .distinctValuesCount(100) + .nullsFraction(nullFraction) + .lowValueUnknown() + .highValueUnknown()) + .symbolStats(new Symbol("col_b"), assertion -> assertion + .averageRowSize(avgRowSize) + .distinctValuesCount(233) + .nullsFraction(nullFraction) + .lowValueUnknown() + .highValueUnknown()) + .symbolStats(new Symbol("col_c"), assertion -> assertion + .averageRowSize(NaN) + .distinctValuesCount(98) + .nullsFraction(nullFraction) + .highValue(100) + .lowValue(3)) + .symbolStats(new Symbol("col_d"), assertion -> assertion + .averageRowSize(NaN) + .distinctValuesCount(300) + .nullsFraction(nullFraction) + .highValue(100) + .lowValue(3))); + } + + private RuntimeInfoProvider createRuntimeInfoProvider(StatsAndCosts statsAndCosts) + { + PlanFragment planFragment = createPlanFragment(statsAndCosts); + return new StaticRuntimeInfoProvider( + ImmutableMap.of(planFragment.getId(), createRuntimeOutputStatsEstimate()), + ImmutableMap.of(planFragment.getId(), planFragment)); + } + + private OutputStatsEstimateResult createRuntimeOutputStatsEstimate() + { + return new OutputStatsEstimateResult(ImmutableLongArray.of(1_000_000_000L, 2_000_000_000L, 3_000_000_000L), 1_000_000L, "FINISHED", true); + } + + private PlanFragment createPlanFragment(StatsAndCosts statsAndCosts) + { + 
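        // Reference note (an illustrative derivation, not asserted by the production code): the expected
        // average row sizes used in assertRemoteSourceStats above appear to follow from the runtime
        // estimate of 1 + 2 + 3 = 6 billion bytes over 1,000,000 rows, i.e. 6000 bytes per row, with
        // 8 bytes attributed to each fixed-width column (BIGINT, DOUBLE) scaled by (1 - nullFraction) and
        // the remainder split evenly across the two VARCHAR columns:
        //     (6000 - 2 * 8 * 0.9) / (2 * 0.9) ≈ 3325.333 when nullFraction = 0.1
        //     (6000 - 2 * 8) / 2 = 2992 when the NaN null fraction is replaced with 0.0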
return new PlanFragment( + new PlanFragmentId("fragment"), + TableScanNode.newInstance( + new PlanNodeId("plan_id"), + TEST_TABLE_HANDLE, + ImmutableList.of(new Symbol("col_a"), new Symbol("col_b"), new Symbol("col_c"), new Symbol("col_d")), + ImmutableMap.of( + new Symbol("col_a"), new TestingColumnHandle("col_a", 0, VARCHAR), + new Symbol("col_b"), new TestingColumnHandle("col_b", 1, VARCHAR), + new Symbol("col_c"), new TestingColumnHandle("col_c", 2, BIGINT), + new Symbol("col_d"), new TestingColumnHandle("col_d", 3, DOUBLE)), + false, + Optional.empty()), + ImmutableMap.of( + new Symbol("col_a"), VARCHAR, + new Symbol("col_b"), VARCHAR, + new Symbol("col_c"), BIGINT, + new Symbol("col_d"), DOUBLE), + SOURCE_DISTRIBUTION, + Optional.empty(), + ImmutableList.of(new PlanNodeId("plan_id")), + new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), ImmutableList.of(new Symbol("col_c"))), + statsAndCosts, + ImmutableList.of(), + ImmutableList.of(), + Optional.empty()); + } + + private StatsAndCosts createStatsAndCosts(double nullFraction) + { + PlanNodeStatsEstimate symbolEstimate = new PlanNodeStatsEstimate(10000, + ImmutableMap.of( + new Symbol("col_a"), + SymbolStatsEstimate.builder() + .setNullsFraction(nullFraction) + .setDistinctValuesCount(100) + .build(), + new Symbol("col_b"), + SymbolStatsEstimate.builder() + .setNullsFraction(nullFraction) + .setDistinctValuesCount(233) + .build(), + new Symbol("col_c"), + SymbolStatsEstimate.builder() + .setNullsFraction(nullFraction) + .setDistinctValuesCount(98) + .setHighValue(100) + .setLowValue(3) + .build(), + new Symbol("col_d"), + SymbolStatsEstimate.builder() + .setNullsFraction(nullFraction) + .setDistinctValuesCount(300) + .setHighValue(100) + .setLowValue(3) + .build())); + return new StatsAndCosts(ImmutableMap.of(new PlanNodeId("plan_id"), symbolEstimate), ImmutableMap.of()); + } +} diff --git a/core/trino-main/src/test/java/io/trino/execution/TestQueryManagerConfig.java b/core/trino-main/src/test/java/io/trino/execution/TestQueryManagerConfig.java index 70167a834d5e..1f68b81b82a5 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestQueryManagerConfig.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestQueryManagerConfig.java @@ -109,6 +109,7 @@ public void testDefaults() .setFaultTolerantExecutionSmallStageSourceSizeMultiplier(1.2) .setFaultTolerantExecutionSmallStageRequireNoMorePartitions(false) .setFaultTolerantExecutionStageEstimationForEagerParentEnabled(true) + .setFaultTolerantExecutionAdaptiveQueryPlanningEnabled(false) .setMaxWriterTaskCount(100)); } @@ -186,6 +187,7 @@ public void testExplicitPropertyMappings() .put("fault-tolerant-execution-small-stage-source-size-multiplier", "1.6") .put("fault-tolerant-execution-small-stage-require-no-more-partitions", "true") .put("fault-tolerant-execution-stage-estimation-for-eager-parent-enabled", "false") + .put("fault-tolerant-execution-adaptive-query-planning-enabled", "true") .buildOrThrow(); QueryManagerConfig expected = new QueryManagerConfig() @@ -258,6 +260,7 @@ public void testExplicitPropertyMappings() .setFaultTolerantExecutionSmallStageSourceSizeMultiplier(1.6) .setFaultTolerantExecutionSmallStageRequireNoMorePartitions(true) .setFaultTolerantExecutionStageEstimationForEagerParentEnabled(false) + .setFaultTolerantExecutionAdaptiveQueryPlanningEnabled(true) .setMaxWriterTaskCount(101); assertFullMapping(properties, expected); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestAdaptivePlanner.java 
b/core/trino-main/src/test/java/io/trino/sql/planner/TestAdaptivePlanner.java new file mode 100644 index 000000000000..9ef343e9e2fa --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestAdaptivePlanner.java @@ -0,0 +1,301 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.planner; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.common.primitives.ImmutableLongArray; +import io.trino.Session; +import io.trino.matching.Captures; +import io.trino.matching.Pattern; +import io.trino.sql.planner.assertions.BasePlanTest; +import io.trino.sql.planner.assertions.SubPlanMatcher; +import io.trino.sql.planner.iterative.IterativeOptimizer; +import io.trino.sql.planner.iterative.Rule; +import io.trino.sql.planner.plan.AggregationNode; +import io.trino.sql.planner.plan.JoinNode; +import io.trino.sql.planner.plan.Patterns; +import io.trino.sql.planner.plan.PlanFragmentId; +import io.trino.sql.planner.plan.PlanNodeId; +import io.trino.sql.planner.plan.TableScanNode; +import org.junit.jupiter.api.Test; + +import java.util.HashSet; +import java.util.Set; + +import static io.trino.execution.scheduler.faulttolerant.OutputStatsEstimator.OutputStatsEstimateResult; +import static io.trino.sql.planner.assertions.PlanMatchPattern.adaptivePlan; +import static io.trino.sql.planner.assertions.PlanMatchPattern.any; +import static io.trino.sql.planner.assertions.PlanMatchPattern.exchange; +import static io.trino.sql.planner.assertions.PlanMatchPattern.join; +import static io.trino.sql.planner.assertions.PlanMatchPattern.node; +import static io.trino.sql.planner.assertions.PlanMatchPattern.output; +import static io.trino.sql.planner.assertions.PlanMatchPattern.remoteSource; +import static io.trino.sql.planner.plan.JoinType.INNER; + +public class TestAdaptivePlanner + extends BasePlanTest +{ + @Test + public void testJoinOrderSwitchRule() + { + Session session = Session.builder(getPlanTester().getDefaultSession()) + .setSystemProperty("join_distribution_type", "PARTITIONED") + .build(); + + SubPlanMatcher matcher = SubPlanMatcher.builder() + .fragmentMatcher(fm -> fm + .fragmentId(3) + .planPattern( + any( + adaptivePlan( + join(INNER, builder -> builder + .equiCriteria(ImmutableList.of(aliases -> new JoinNode.EquiJoinClause(new Symbol("nationkey"), new Symbol("nationkey_1")))) + .left(remoteSource(ImmutableList.of(new PlanFragmentId("1")))) + .right(any(remoteSource(ImmutableList.of(new PlanFragmentId("2")))))), + join(INNER, builder -> builder + .equiCriteria(ImmutableList.of(aliases -> new JoinNode.EquiJoinClause(new Symbol("nationkey_1"), new Symbol("nationkey")))) + .right(remoteSource(ImmutableList.of(new PlanFragmentId("1")))) + .left(any(remoteSource(ImmutableList.of(new PlanFragmentId("2")))))))))) + .children( + spb -> spb.fragmentMatcher(fm -> fm.fragmentId(2).planPattern(node(TableScanNode.class))), + spb -> spb.fragmentMatcher(fm -> 
fm.fragmentId(1).planPattern(any(node(TableScanNode.class))))) + .build(); + + assertAdaptivePlan( + "SELECT n.name FROM supplier AS s JOIN nation AS n on s.nationkey = n.nationkey", + session, + ImmutableList.of(new IterativeOptimizer( + getPlanTester().getPlannerContext(), + new RuleStatsRecorder(), + getPlanTester().getStatsCalculator(), + getPlanTester().getCostCalculator(), + ImmutableSet.>builder() + .add(new TestJoinOrderSwitchRule()) + .build())), + ImmutableMap.of( + new PlanFragmentId("1"), createRuntimeStats(ImmutableLongArray.of(10000L, 10000L, 10000L), 10000), + new PlanFragmentId("2"), createRuntimeStats(ImmutableLongArray.of(200L, 2000L, 1000L), 500)), + matcher); + } + + @Test + public void testNoChangeInFragmentIdsForUnchangedSubPlans() + { + Session session = Session.builder(getPlanTester().getDefaultSession()) + .setSystemProperty("join_distribution_type", "PARTITIONED") + .build(); + + SubPlanMatcher matcher = SubPlanMatcher.builder() + .fragmentMatcher(fm -> fm + // This fragment id should change since it is downstream of adaptive stage + .fragmentId(5) + .planPattern( + output( + node(AggregationNode.class, + exchange( + remoteSource(ImmutableList.of(new PlanFragmentId("6")))))))) + .children( + spb -> spb.fragmentMatcher(fm -> fm + // This fragment id should change since it has adaptive plan + .fragmentId(6) + .planPattern(node(AggregationNode.class, + adaptivePlan( + join(INNER, builder -> builder + .equiCriteria(ImmutableList.of(aliases -> new JoinNode.EquiJoinClause(new Symbol("nationkey"), new Symbol("count")))) + .left(remoteSource(ImmutableList.of(new PlanFragmentId("2")))) + .right(any(remoteSource(ImmutableList.of(new PlanFragmentId("3")))))), + join(INNER, builder -> builder + .equiCriteria(ImmutableList.of(aliases -> new JoinNode.EquiJoinClause(new Symbol("count"), new Symbol("nationkey")))) + .right(remoteSource(ImmutableList.of(new PlanFragmentId("2")))) + .left(any(remoteSource(ImmutableList.of(new PlanFragmentId("3")))))))))) + .children( + spb2 -> spb2.fragmentMatcher(fm -> fm + // This fragment id should not change + .fragmentId(3) + .planPattern( + node(AggregationNode.class, + exchange( + remoteSource(ImmutableList.of(new PlanFragmentId("4"))))))) + .children(spb3 -> spb3.fragmentMatcher(fm -> fm + // This fragment id should not change + .fragmentId(4) + .planPattern(node(AggregationNode.class, node(TableScanNode.class))))), + spb2 -> spb2.fragmentMatcher(fm -> fm + // This fragment id should not change + .fragmentId(2).planPattern(any(node(TableScanNode.class)))))) + .build(); + + assertAdaptivePlan( + """ + WITH t AS (SELECT regionkey, count(*) as some_count FROM nation group by regionkey) + SELECT max(s.nationkey), sum(t.regionkey) + FROM supplier AS s + JOIN t + ON s.nationkey = t.some_count + """, + session, + ImmutableList.of(new IterativeOptimizer( + getPlanTester().getPlannerContext(), + new RuleStatsRecorder(), + getPlanTester().getStatsCalculator(), + getPlanTester().getCostCalculator(), + ImmutableSet.>builder() + .add(new TestJoinOrderSwitchRule()) + .build())), + ImmutableMap.of( + new PlanFragmentId("3"), createRuntimeStats(ImmutableLongArray.of(10000L, 10000L, 10000L), 10000), + new PlanFragmentId("2"), createRuntimeStats(ImmutableLongArray.of(200L, 2000L, 1000L), 500)), + matcher + ); + } + + @Test + public void testNoChangeToRootSubPlanIfStatsAreAccurate() + { + Session session = Session.builder(getPlanTester().getDefaultSession()) + .setSystemProperty("join_distribution_type", "PARTITIONED") + .build(); + + SubPlanMatcher matcher = 
SubPlanMatcher.builder() + .fragmentMatcher(fm -> fm + .fragmentId(0) + .planPattern( + any(join(INNER, builder -> builder + .equiCriteria(ImmutableList.of(aliases -> new JoinNode.EquiJoinClause(new Symbol("nationkey"), new Symbol("nationkey_1")))) + .left(remoteSource(ImmutableList.of(new PlanFragmentId("1")))) + .right(any(remoteSource(ImmutableList.of(new PlanFragmentId("2"))))))))) + .children( + spb -> spb.fragmentMatcher(fm -> fm.fragmentId(1).planPattern(any(node(TableScanNode.class)))), + spb -> spb.fragmentMatcher(fm -> fm.fragmentId(2).planPattern(node(TableScanNode.class)))) + .build(); + + assertAdaptivePlan( + "SELECT n.name FROM supplier AS s JOIN nation AS n on s.nationkey = n.nationkey", + session, + ImmutableList.of(new IterativeOptimizer( + getPlanTester().getPlannerContext(), + new RuleStatsRecorder(), + getPlanTester().getStatsCalculator(), + getPlanTester().getCostCalculator(), + ImmutableSet.>builder() + .add(new TestJoinOrderSwitchRule()) + .build())), + ImmutableMap.of( + new PlanFragmentId("1"), createRuntimeStats(ImmutableLongArray.of(10000L, 10000L, 10000L), 10000), + new PlanFragmentId("2"), createRuntimeStats(ImmutableLongArray.of(200L, 2000L, 1000L), 500), + // Since the runtime stats are accurate, adaptivePlanner will not change this subplan + new PlanFragmentId("0"), createRuntimeStats(ImmutableLongArray.of(10000L, 10000L, 10000L), 10000)), + matcher); + } + + @Test + public void testNoChangeToNestedSubPlanIfStatsAreAccurate() + { + Session session = Session.builder(getPlanTester().getDefaultSession()) + .setSystemProperty("join_distribution_type", "PARTITIONED") + .build(); + + SubPlanMatcher matcher = SubPlanMatcher.builder() + .fragmentMatcher(fm -> fm + // This fragment id should change since it is downstream of adaptive stage + .fragmentId(0) + .planPattern( + output( + node(AggregationNode.class, + exchange( + remoteSource(ImmutableList.of(new PlanFragmentId("1")))))))) + .children( + spb -> spb.fragmentMatcher(fm -> fm + // This fragment id should change since it has adaptive plan + .fragmentId(1) + .planPattern(node(AggregationNode.class, + join(INNER, builder -> builder + .equiCriteria(ImmutableList.of(aliases -> new JoinNode.EquiJoinClause(new Symbol("nationkey"), new Symbol("count")))) + .left(remoteSource(ImmutableList.of(new PlanFragmentId("2")))) + .right(any(remoteSource(ImmutableList.of(new PlanFragmentId("3"))))))))) + .children( + spb2 -> spb2.fragmentMatcher(fm -> fm + // This fragment id should not change + .fragmentId(2).planPattern(any(node(TableScanNode.class)))), + spb2 -> spb2.fragmentMatcher(fm -> fm + // This fragment id should not change + .fragmentId(3) + .planPattern( + node(AggregationNode.class, + exchange( + remoteSource(ImmutableList.of(new PlanFragmentId("4"))))))) + .children(spb3 -> spb3.fragmentMatcher(fm -> fm + // This fragment id should not change + .fragmentId(4) + .planPattern(node(AggregationNode.class, node(TableScanNode.class))))))) + .build(); + + assertAdaptivePlan( + """ + WITH t AS (SELECT regionkey, count(*) as some_count FROM nation group by regionkey) + SELECT max(s.nationkey), sum(t.regionkey) + FROM supplier AS s + JOIN t + ON s.nationkey = t.some_count + """, + session, + ImmutableList.of(new IterativeOptimizer( + getPlanTester().getPlannerContext(), + new RuleStatsRecorder(), + getPlanTester().getStatsCalculator(), + getPlanTester().getCostCalculator(), + ImmutableSet.>builder() + .add(new TestJoinOrderSwitchRule()) + .build())), + ImmutableMap.of( + // Since the runtime stats are accurate, 
adaptivePlanner will not change this subplan + new PlanFragmentId("1"), createRuntimeStats(ImmutableLongArray.of(10000L, 10000L, 10000L), 10000), + new PlanFragmentId("3"), createRuntimeStats(ImmutableLongArray.of(10000L, 10000L, 10000L), 10000), + new PlanFragmentId("4"), createRuntimeStats(ImmutableLongArray.of(10000L, 10000L, 10000L), 10000), + new PlanFragmentId("2"), createRuntimeStats(ImmutableLongArray.of(200L, 2000L, 1000L), 500)), + matcher + ); + } + + private OutputStatsEstimateResult createRuntimeStats(ImmutableLongArray partitionDataSizes, long outputRowCountEstimate) + { + return new OutputStatsEstimateResult(partitionDataSizes, outputRowCountEstimate, "FINISHED", true); + } + + // This is a test rule which switches the join order of two tables. + private static class TestJoinOrderSwitchRule + implements Rule + { + private static final Pattern PATTERN = Patterns.join(); + private final Set alreadyVisited = new HashSet<>(); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Result apply(JoinNode node, Captures captures, Context context) + { + if (alreadyVisited.contains(node.getId())) { + return Result.empty(); + } + alreadyVisited.add(node.getId()); + return Result.ofPlanNode(node.flipChildren()); + } + } +} diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/AdaptivePlanMatcher.java b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/AdaptivePlanMatcher.java new file mode 100644 index 000000000000..dbfc60d5a56d --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/AdaptivePlanMatcher.java @@ -0,0 +1,56 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package io.trino.sql.planner.assertions;
+
+import io.trino.Session;
+import io.trino.cost.StatsProvider;
+import io.trino.metadata.Metadata;
+import io.trino.sql.planner.plan.AdaptivePlanNode;
+import io.trino.sql.planner.plan.PlanNode;
+
+import static com.google.common.base.Preconditions.checkState;
+import static io.trino.sql.planner.iterative.Lookup.noLookup;
+import static java.util.Objects.requireNonNull;
+
+public class AdaptivePlanMatcher
+        implements Matcher
+{
+    private final PlanMatchPattern initialPlan;
+
+    public AdaptivePlanMatcher(PlanMatchPattern initialPlan)
+    {
+        this.initialPlan = requireNonNull(initialPlan, "initialPlan is null");
+    }
+
+    @Override
+    public boolean shapeMatches(PlanNode node)
+    {
+        return node instanceof AdaptivePlanNode;
+    }
+
+    @Override
+    public MatchResult detailMatches(PlanNode node, StatsProvider statsProvider, Session session, Metadata metadata, SymbolAliases symbolAliases)
+    {
+        checkState(shapeMatches(node), "Plan testing framework error: shapeMatches returned false in detailMatches in %s", this.getClass().getName());
+        AdaptivePlanNode adaptivePlanNode = (AdaptivePlanNode) node;
+        return adaptivePlanNode.getInitialPlan().accept(new PlanMatchingVisitor(session, metadata, statsProvider, noLookup()), initialPlan);
+    }
+
+    @Override
+    public String toString()
+    {
+        return "AdaptivePlanMatcher\n" +
+                "initialPlan: " + "\n" + initialPlan;
+    }
+}
diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/BasePlanTest.java b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/BasePlanTest.java
index a36c35a32610..26c5e1f933cc 100644
--- a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/BasePlanTest.java
+++ b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/BasePlanTest.java
@@ -16,19 +16,26 @@
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.ImmutableSet;
+import com.google.common.graph.Traverser;
 import io.trino.Session;
+import io.trino.cost.RuntimeInfoProvider;
+import io.trino.cost.StaticRuntimeInfoProvider;
+import io.trino.execution.warnings.WarningCollector;
 import io.trino.metadata.Metadata;
 import io.trino.plugin.tpch.TpchConnectorFactory;
 import io.trino.spi.connector.CatalogHandle;
 import io.trino.sql.planner.LogicalPlanner;
 import io.trino.sql.planner.Plan;
+import io.trino.sql.planner.PlanFragment;
 import io.trino.sql.planner.RuleStatsRecorder;
 import io.trino.sql.planner.SubPlan;
 import io.trino.sql.planner.iterative.IterativeOptimizer;
 import io.trino.sql.planner.iterative.Rule;
 import io.trino.sql.planner.iterative.rule.RemoveRedundantIdentityProjections;
+import io.trino.sql.planner.optimizations.AdaptivePlanOptimizer;
 import io.trino.sql.planner.optimizations.PlanOptimizer;
 import io.trino.sql.planner.optimizations.UnaliasSymbolReferences;
+import io.trino.sql.planner.plan.PlanFragmentId;
 import io.trino.testing.PlanTester;
 import org.intellij.lang.annotations.Language;
 import org.junit.jupiter.api.AfterAll;
@@ -40,15 +47,22 @@
 import java.util.Map;
 import java.util.function.Consumer;
 import java.util.function.Predicate;
+import java.util.stream.Stream;
+import java.util.stream.StreamSupport;
 
+import static com.google.common.collect.ImmutableMap.toImmutableMap;
 import static io.airlift.testing.Closeables.closeAllRuntimeException;
+import static io.trino.client.NodeVersion.UNKNOWN;
 import static io.trino.execution.querystats.PlanOptimizersStatsCollector.createPlanOptimizersStatsCollector;
+import static 
io.trino.execution.scheduler.faulttolerant.OutputStatsEstimator.OutputStatsEstimateResult; import static io.trino.execution.warnings.WarningCollector.NOOP; import static io.trino.sql.planner.LogicalPlanner.Stage.OPTIMIZED; import static io.trino.sql.planner.LogicalPlanner.Stage.OPTIMIZED_AND_VALIDATED; import static io.trino.sql.planner.PlanOptimizers.columnPruningRules; +import static io.trino.sql.planner.planprinter.PlanPrinter.textDistributedPlan; import static io.trino.testing.TestingHandles.TEST_CATALOG_NAME; import static io.trino.testing.TestingSession.testSessionBuilder; +import static java.lang.String.format; import static java.util.Objects.requireNonNull; import static java.util.stream.Collectors.toList; import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; @@ -266,4 +280,73 @@ protected SubPlan subplan(@Language("SQL") String sql, LogicalPlanner.Stage stag throw new AssertionError("Planning failed for SQL: " + sql, e); } } + + protected SubPlan createAdaptivePlan(@Language("SQL") String sql, List optimizers, Map completeStageStats) + { + return createAdaptivePlan(sql, planTester.getDefaultSession(), optimizers, completeStageStats); + } + + protected SubPlan createAdaptivePlan(@Language("SQL") String sql, Session session, List optimizers, Map completeStageStats) + { + try { + return planTester.inTransaction(session, transactionSession -> { + Plan plan = planTester.createPlan(transactionSession, sql, planTester.getPlanOptimizers(false), OPTIMIZED_AND_VALIDATED, WarningCollector.NOOP, createPlanOptimizersStatsCollector()); + SubPlan subPlan = planTester.createSubPlans(transactionSession, plan, false); + return planTester.createAdaptivePlan(transactionSession, subPlan, optimizers, WarningCollector.NOOP, createPlanOptimizersStatsCollector(), createRuntimeInfoProvider(subPlan, completeStageStats)); + }); + } + catch (RuntimeException e) { + throw new AssertionError("Adaptive Planning failed for SQL: " + sql, e); + } + } + + protected void assertAdaptivePlan(@Language("SQL") String sql, Map completeStageStats, SubPlanMatcher subPlanMatcher) + { + assertAdaptivePlan(sql, planTester.getDefaultSession(), planTester.getAdaptivePlanOptimizers(), completeStageStats, subPlanMatcher); + } + + protected void assertAdaptivePlan(@Language("SQL") String sql, Session session, Map completeStageStats, SubPlanMatcher subPlanMatcher) + { + assertAdaptivePlan(sql, session, planTester.getAdaptivePlanOptimizers(), completeStageStats, subPlanMatcher); + } + + protected void assertAdaptivePlan(@Language("SQL") String sql, Session session, List optimizers, Map completeStageStats, SubPlanMatcher subPlanMatcher) + { + try { + planTester.inTransaction(session, transactionSession -> { + Plan plan = planTester.createPlan(transactionSession, sql, planTester.getPlanOptimizers(false), OPTIMIZED_AND_VALIDATED, WarningCollector.NOOP, createPlanOptimizersStatsCollector()); + SubPlan subPlan = planTester.createSubPlans(transactionSession, plan, false); + SubPlan adaptivePlan = planTester.createAdaptivePlan(transactionSession, subPlan, optimizers, WarningCollector.NOOP, createPlanOptimizersStatsCollector(), createRuntimeInfoProvider(subPlan, completeStageStats)); + String formattedPlan = textDistributedPlan(adaptivePlan, planTester.getPlannerContext().getMetadata(), planTester.getPlannerContext().getFunctionManager(), transactionSession, false, UNKNOWN); + if (!subPlanMatcher.matches(adaptivePlan, planTester.getStatsCalculator(), transactionSession, planTester.getPlannerContext().getMetadata())) { + throw 
new AssertionError(format( + "Adaptive plan does not match, expected [\n\n%s\n] but found [\n\n%s\n]", + subPlanMatcher, + formattedPlan)); + } + return null; + }); + } + catch (RuntimeException e) { + e.addSuppressed(new Exception("Query: " + sql)); + throw e; + } + } + + private RuntimeInfoProvider createRuntimeInfoProvider( + SubPlan subPlan, + Map completeStageStats) + { + Map fragments = traverse(subPlan) + .map(SubPlan::getFragment) + .collect(toImmutableMap(PlanFragment::getId, val -> val)); + + return new StaticRuntimeInfoProvider(completeStageStats, fragments); + } + + private Stream traverse(SubPlan subPlan) + { + Iterable iterable = Traverser.forTree(SubPlan::getChildren).depthFirstPreOrder(subPlan); + return StreamSupport.stream(iterable.spliterator(), false); + } } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/PlanFragmentMatcher.java b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/PlanFragmentMatcher.java new file mode 100644 index 000000000000..91725fd3b503 --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/PlanFragmentMatcher.java @@ -0,0 +1,189 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.planner.assertions; + +import io.trino.Session; +import io.trino.cost.CachingStatsProvider; +import io.trino.cost.CachingTableStatsProvider; +import io.trino.cost.StatsAndCosts; +import io.trino.cost.StatsCalculator; +import io.trino.cost.StatsProvider; +import io.trino.metadata.Metadata; +import io.trino.sql.planner.PartitioningHandle; +import io.trino.sql.planner.PartitioningScheme; +import io.trino.sql.planner.PlanFragment; +import io.trino.sql.planner.TypeProvider; +import io.trino.sql.planner.plan.PlanFragmentId; +import io.trino.sql.planner.plan.PlanNodeId; + +import java.util.List; +import java.util.Optional; + +import static io.trino.sql.planner.iterative.Lookup.noLookup; +import static java.util.Objects.requireNonNull; + +public class PlanFragmentMatcher +{ + private final PlanFragmentId fragmentId; + private final Optional planPattern; + private final Optional partitioning; + private final Optional inputPartitionCount; + private final Optional outputPartitionCount; + private final Optional> partitionedSources; + private final Optional outputPartitioningScheme; + private final Optional statsAndCosts; + + public static Builder builder() + { + return new Builder(); + } + + public PlanFragmentMatcher( + PlanFragmentId fragmentId, + Optional planPattern, + Optional partitioning, + Optional inputPartitionCount, + Optional outputPartitionCount, + Optional> partitionedSources, + Optional outputPartitioningScheme, + Optional statsAndCosts) + { + this.fragmentId = requireNonNull(fragmentId, "fragmentId is null"); + this.planPattern = requireNonNull(planPattern, "planPattern is null"); + this.partitioning = requireNonNull(partitioning, "partitioning is null"); + this.inputPartitionCount = requireNonNull(inputPartitionCount, "inputPartitionCount is null"); + 
this.outputPartitionCount = requireNonNull(outputPartitionCount, "outputPartitionCount is null"); + this.partitionedSources = requireNonNull(partitionedSources, "partitionedSources is null"); + this.outputPartitioningScheme = requireNonNull(outputPartitioningScheme, "outputPartitioningScheme is null"); + this.statsAndCosts = requireNonNull(statsAndCosts, "statsAndCosts is null"); + } + + public boolean matches(PlanFragment fragment, StatsCalculator statsCalculator, Session session, Metadata metadata) + { + if (!fragmentId.equals(fragment.getId())) { + return false; + } + if (planPattern.isPresent()) { + StatsProvider statsProvider = new CachingStatsProvider(statsCalculator, session, TypeProvider.viewOf(fragment.getSymbols()), new CachingTableStatsProvider(metadata, session)); + MatchResult matches = fragment.getRoot().accept(new PlanMatchingVisitor(session, metadata, statsProvider, noLookup()), planPattern.get()); + if (!matches.isMatch()) { + return false; + } + } + if (partitioning.isPresent() && !partitioning.get().equals(fragment.getPartitioning())) { + return false; + } + if (inputPartitionCount.isPresent() && !inputPartitionCount.equals(fragment.getPartitionCount())) { + return false; + } + if (outputPartitionCount.isPresent() && !outputPartitionCount.equals(fragment.getOutputPartitioningScheme().getPartitionCount())) { + return false; + } + if (partitionedSources.isPresent() && !partitionedSources.get().equals(fragment.getPartitionedSources())) { + return false; + } + if (outputPartitioningScheme.isPresent() && !outputPartitioningScheme.get().equals(fragment.getOutputPartitioningScheme())) { + return false; + } + return statsAndCosts.isEmpty() || statsAndCosts.get().equals(fragment.getStatsAndCosts()); + } + + @Override + public String toString() + { + StringBuilder builder = new StringBuilder(); + builder.append("Fragment ").append(fragmentId).append("\n"); + planPattern.ifPresent(planPattern -> builder.append("PlanPattern: \n").append(planPattern).append("\n")); + partitioning.ifPresent(partitioning -> builder.append("Partitioning: ").append(partitioning).append("\n")); + outputPartitioningScheme.ifPresent(outputPartitioningScheme -> builder.append("OutputPartitioningScheme: ").append(outputPartitioningScheme).append("\n")); + inputPartitionCount.ifPresent(partitionCount -> builder.append("InputPartitionCount: ").append(partitionCount).append("\n")); + outputPartitionCount.ifPresent(partitionCount -> builder.append("OutputPartitionCount: ").append(partitionCount).append("\n")); + partitionedSources.ifPresent(partitionedSources -> builder.append("PartitionedSources: ").append(partitionedSources).append("\n")); + statsAndCosts.ifPresent(statsAndCosts -> builder.append("StatsAndCosts: ").append(statsAndCosts).append("\n")); + return builder.toString(); + } + + public static class Builder + { + private PlanFragmentId fragmentId; + private Optional planPattern = Optional.empty(); + private Optional partitioning = Optional.empty(); + private Optional inputPartitionCount = Optional.empty(); + private Optional outputPartitionCount = Optional.empty(); + private Optional> partitionedSources = Optional.empty(); + private Optional outputPartitioningScheme = Optional.empty(); + private Optional statsAndCosts = Optional.empty(); + + public Builder fragmentId(int fragmentId) + { + this.fragmentId = new PlanFragmentId(String.valueOf(fragmentId)); + return this; + } + + public Builder planPattern(PlanMatchPattern planPattern) + { + this.planPattern = Optional.of(planPattern); + return this; + } + 
+ public Builder partitioning(PartitioningHandle partitioning) + { + this.partitioning = Optional.of(partitioning); + return this; + } + + public Builder inputPartitionCount(int inputPartitionCount) + { + this.inputPartitionCount = Optional.of(inputPartitionCount); + return this; + } + + public Builder outputPartitionCount(int outputPartitionCount) + { + this.outputPartitionCount = Optional.of(outputPartitionCount); + return this; + } + + public Builder partitionedSources(List partitionedSources) + { + this.partitionedSources = Optional.of(partitionedSources); + return this; + } + + public Builder outputPartitioningScheme(PartitioningScheme outputPartitioningScheme) + { + this.outputPartitioningScheme = Optional.of(outputPartitioningScheme); + return this; + } + + public Builder statsAndCosts(StatsAndCosts statsAndCosts) + { + this.statsAndCosts = Optional.of(statsAndCosts); + return this; + } + + public PlanFragmentMatcher build() + { + return new PlanFragmentMatcher( + fragmentId, + planPattern, + partitioning, + inputPartitionCount, + outputPartitionCount, + partitionedSources, + outputPartitioningScheme, + statsAndCosts); + } + } +} diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/PlanMatchPattern.java b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/PlanMatchPattern.java index 3fc221f9e877..e505f2473318 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/PlanMatchPattern.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/PlanMatchPattern.java @@ -31,6 +31,7 @@ import io.trino.sql.planner.iterative.GroupReference; import io.trino.sql.planner.iterative.rule.test.PlanBuilder; import io.trino.sql.planner.optimizations.SymbolMapper; +import io.trino.sql.planner.plan.AdaptivePlanNode; import io.trino.sql.planner.plan.AggregationNode; import io.trino.sql.planner.plan.AggregationNode.Step; import io.trino.sql.planner.plan.ApplyNode; @@ -54,8 +55,10 @@ import io.trino.sql.planner.plan.MergeWriterNode; import io.trino.sql.planner.plan.OffsetNode; import io.trino.sql.planner.plan.OutputNode; +import io.trino.sql.planner.plan.PlanFragmentId; import io.trino.sql.planner.plan.PlanNode; import io.trino.sql.planner.plan.ProjectNode; +import io.trino.sql.planner.plan.RemoteSourceNode; import io.trino.sql.planner.plan.SemiJoinNode; import io.trino.sql.planner.plan.SortNode; import io.trino.sql.planner.plan.SpatialJoinNode; @@ -143,6 +146,21 @@ public static PlanMatchPattern anyNot(Class excludeNodeClass return any(sources).with(new NotPlanNodeMatcher(excludeNodeClass)); } + public static PlanMatchPattern adaptivePlan(PlanMatchPattern initialPlan, PlanMatchPattern currentPlan) + { + return node(AdaptivePlanNode.class, currentPlan).with(new AdaptivePlanMatcher(initialPlan)); + } + + public static PlanMatchPattern remoteSource(List sourceFragmentIds) + { + return node(RemoteSourceNode.class) + .with(new RemoteSourceMatcher( + sourceFragmentIds, + Optional.empty(), + Optional.empty(), + Optional.empty())); + } + public static PlanMatchPattern tableScan(String expectedTableName) { return node(TableScanNode.class) @@ -1206,7 +1224,10 @@ private void toString(StringBuilder builder, int indent) .collect(toImmutableList()); for (Matcher matcher : matchersToPrint) { - builder.append(indentString(indent + 1)).append(matcher.toString()).append("\n"); + builder + .append(indentString(indent + 1)) + .append(matcher.toString().replace("\n", "\n" + indentString(indent + 1))) + .append("\n"); } for (PlanMatchPattern pattern : 
sourcePatterns) { diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/RemoteSourceMatcher.java b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/RemoteSourceMatcher.java new file mode 100644 index 000000000000..4031479c9590 --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/RemoteSourceMatcher.java @@ -0,0 +1,89 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.planner.assertions; + +import io.trino.Session; +import io.trino.cost.StatsProvider; +import io.trino.metadata.Metadata; +import io.trino.operator.RetryPolicy; +import io.trino.sql.planner.OrderingScheme; +import io.trino.sql.planner.plan.ExchangeNode; +import io.trino.sql.planner.plan.PlanFragmentId; +import io.trino.sql.planner.plan.PlanNode; +import io.trino.sql.planner.plan.RemoteSourceNode; + +import java.util.List; +import java.util.Optional; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkState; +import static io.trino.sql.planner.assertions.MatchResult.NO_MATCH; +import static io.trino.sql.planner.assertions.MatchResult.match; +import static java.util.Objects.requireNonNull; + +public class RemoteSourceMatcher + implements Matcher +{ + private final List sourceFragmentIds; + private final Optional orderingScheme; + private final Optional exchangeType; + private final Optional retryPolicy; + + public RemoteSourceMatcher( + List sourceFragmentIds, + Optional orderingScheme, + Optional exchangeType, + Optional retryPolicy) + { + this.sourceFragmentIds = requireNonNull(sourceFragmentIds, "sourceFragmentIds is null"); + this.orderingScheme = requireNonNull(orderingScheme, "orderingScheme is null"); + this.exchangeType = requireNonNull(exchangeType, "exchangeType is null"); + this.retryPolicy = requireNonNull(retryPolicy, "retryPolicy is null"); + } + + @Override + public boolean shapeMatches(PlanNode node) + { + return node instanceof RemoteSourceNode && sourceFragmentIds.equals(((RemoteSourceNode) node).getSourceFragmentIds()); + } + + @Override + public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session session, Metadata metadata, SymbolAliases symbolAliases) + { + checkState(shapeMatches(node), "Plan testing framework error: shapeMatches returned false in detailMatches in %s", this.getClass().getName()); + RemoteSourceNode remoteSourceNode = (RemoteSourceNode) node; + + if (orderingScheme.isPresent() && !remoteSourceNode.getOrderingScheme().equals(orderingScheme)) { + return NO_MATCH; + } + + if (exchangeType.isPresent() && !remoteSourceNode.getExchangeType().equals(exchangeType.get())) { + return NO_MATCH; + } + + if (retryPolicy.isPresent() && !remoteSourceNode.getRetryPolicy().equals(retryPolicy.get())) { + return NO_MATCH; + } + + return match(); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("sourceFragmentIds", sourceFragmentIds) + .toString(); + } +} diff --git 
a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/SubPlanMatcher.java b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/SubPlanMatcher.java new file mode 100644 index 000000000000..1ac3dd96ff83 --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/SubPlanMatcher.java @@ -0,0 +1,99 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.planner.assertions; + +import com.google.common.collect.ImmutableList; +import io.trino.Session; +import io.trino.cost.StatsCalculator; +import io.trino.metadata.Metadata; +import io.trino.sql.planner.SubPlan; + +import java.util.Arrays; +import java.util.List; +import java.util.function.Function; + +import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.util.Objects.requireNonNull; + +public class SubPlanMatcher +{ + private final PlanFragmentMatcher fragmentMatcher; + private final List children; + + public static Builder builder() + { + return new Builder(); + } + + public SubPlanMatcher( + PlanFragmentMatcher fragmentMatcher, + List children) + { + this.fragmentMatcher = requireNonNull(fragmentMatcher, "fragmentMatcher is null"); + this.children = requireNonNull(children, "children is null"); + } + + public boolean matches(SubPlan subPlan, StatsCalculator statsCalculator, Session session, Metadata metadata) + { + if (subPlan.getChildren().size() != children.size()) { + // Shape of the plan does not match + return false; + } + + for (int i = 0; i < children.size(); i++) { + if (!children.get(i).matches(subPlan.getChildren().get(i), statsCalculator, session, metadata)) { + return false; + } + } + + return fragmentMatcher.matches(subPlan.getFragment(), statsCalculator, session, metadata); + } + + @Override + public String toString() + { + StringBuilder builder = new StringBuilder(); + builder.append(fragmentMatcher.toString()).append("\n"); + for (SubPlanMatcher child : children) { + builder.append(child.toString()); + } + return builder.toString(); + } + + public static class Builder + { + private PlanFragmentMatcher fragmentMatcher; + private List children = ImmutableList.of(); + + public Builder fragmentMatcher(Function fragmentBuilder) + { + this.fragmentMatcher = fragmentBuilder.apply(PlanFragmentMatcher.builder()).build(); + return this; + } + + @SafeVarargs + public final Builder children(Function... 
children) + { + this.children = Arrays.stream(children) + .map(child -> child.apply(new Builder()).build()) + .collect(toImmutableList()); + return this; + } + + public SubPlanMatcher build() + { + return new SubPlanMatcher(fragmentMatcher, children); + } + } +} diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestJoinEnumerator.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestJoinEnumerator.java index fa874f0c1bfe..34125fae3563 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestJoinEnumerator.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestJoinEnumerator.java @@ -23,6 +23,7 @@ import io.trino.cost.CostComparator; import io.trino.cost.CostProvider; import io.trino.cost.PlanCostEstimate; +import io.trino.cost.RuntimeInfoProvider; import io.trino.cost.StatsProvider; import io.trino.execution.warnings.WarningCollector; import io.trino.sql.planner.PlanNodeIdAllocator; @@ -122,7 +123,8 @@ private Rule.Context createContext() noLookup(), planTester.getDefaultSession(), symbolAllocator.getTypes(), - new CachingTableStatsProvider(planTester.getPlannerContext().getMetadata(), planTester.getDefaultSession())); + new CachingTableStatsProvider(planTester.getPlannerContext().getMetadata(), planTester.getDefaultSession()), + RuntimeInfoProvider.noImplementation()); CachingCostProvider costProvider = new CachingCostProvider( planTester.getCostCalculator(), statsProvider, diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/PlanBuilder.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/PlanBuilder.java index b9be826e4406..f788d2139784 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/PlanBuilder.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/PlanBuilder.java @@ -1022,6 +1022,22 @@ public JoinNode join(JoinType type, PlanNode left, PlanNode right, List [6] 7 + // / \ | / \ + // 3* 4* 1* [8] [9] + // | | + // 3* 4* + SubPlanMatcher matcher = SubPlanMatcher.builder() + .fragmentMatcher(fm -> fm.fragmentId(5)) + .children( + sb -> sb.fragmentMatcher(fm -> fm.fragmentId(6).outputPartitionCount(10).inputPartitionCount(1)) + .children(sb1 -> sb1.fragmentMatcher(fm -> fm.fragmentId(1).outputPartitionCount(1))), + sb -> sb.fragmentMatcher(fm -> fm.fragmentId(7).inputPartitionCount(10).outputPartitionCount(10)) + .children( + sb1 -> sb1.fragmentMatcher(fm -> fm.fragmentId(8).outputPartitionCount(10).inputPartitionCount(1)) + .children(sb2 -> sb2.fragmentMatcher(fm -> fm.fragmentId(3).outputPartitionCount(1))), + sb1 -> sb1.fragmentMatcher(fm -> fm.fragmentId(9).outputPartitionCount(10).inputPartitionCount(1)) + .children(sb2 -> sb2.fragmentMatcher(fm -> fm.fragmentId(4).outputPartitionCount(1))))) + .build(); + + assertAdaptivePlan( + """ + SELECT n1.* FROM nation n1 + RIGHT JOIN + (SELECT n.nationkey FROM (SELECT * FROM lineitem WHERE suppkey BETWEEN 20 and 30) l LEFT JOIN nation n on l.suppkey = n.nationkey) n2 + ON n1.nationkey = n2.nationkey + 1 + """, + getSession(), + ImmutableMap.of( + new PlanFragmentId("3"), createRuntimeStats(ImmutableLongArray.of(ONE_MB, ONE_MB * 2, ONE_MB), 10000), + new PlanFragmentId("4"), createRuntimeStats(ImmutableLongArray.of(ONE_MB, ONE_MB, ONE_MB), 500), + new PlanFragmentId("1"), createRuntimeStats(ImmutableLongArray.of(ONE_MB, ONE_MB, ONE_MB), 500)), + matcher); + } + + @Test + public void testSkipBroadcastSubtree() + { + // result 
of fragment 7 will be broadcast, + // so no runtime adaptive partitioning will be applied to its subtree + // already started: 4, 10, 11, 12 + // added fragments: 13 + // 0 0 + // | | + // 1 1 + // / \ / \ + // 2 7 => 2 7 + // / \ | / \ | + // 3 6 8 3 6 8 + // / \ / \ / \ / \ + // 4* 5 9 12* [13] 5 9 12* + // / \ | / \ + // 10* 11* 4* 10* 11* + + SubPlanMatcher matcher = SubPlanMatcher.builder() + .fragmentMatcher(fm -> fm.fragmentId(13)) + .children(sb -> sb.fragmentMatcher(fm -> fm.fragmentId(14).inputPartitionCount(10)) + .children( + sb1 -> sb1.fragmentMatcher(fm -> fm.fragmentId(15).outputPartitionCount(10).inputPartitionCount(10)) + .children( + sb2 -> sb2.fragmentMatcher(fm -> fm.fragmentId(16).outputPartitionCount(10).inputPartitionCount(10)) + .children( + sb3 -> sb3.fragmentMatcher(fm -> fm.fragmentId(17).outputPartitionCount(10).inputPartitionCount(1)) + .children(sb4 -> sb4.fragmentMatcher(fm -> fm.fragmentId(4).outputPartitionCount(1))), + sb3 -> sb3.fragmentMatcher(fm -> fm.fragmentId(18).outputPartitionCount(10))), + sb2 -> sb2.fragmentMatcher(fm -> fm.fragmentId(19).outputPartitionCount(10))), + sb1 -> sb1.fragmentMatcher(fm -> fm.fragmentId(7)) + .children( + sb2 -> sb2.fragmentMatcher(fm -> fm.fragmentId(8).inputPartitionCount(1)) + .children( + sb3 -> sb3.fragmentMatcher(fm -> fm.fragmentId(9).outputPartitionCount(1).inputPartitionCount(1)) + .children( + sb4 -> sb4.fragmentMatcher(fm -> fm.fragmentId(10).outputPartitionCount(1)), + sb4 -> sb4.fragmentMatcher(fm -> fm.fragmentId(11).outputPartitionCount(1))), + sb3 -> sb3.fragmentMatcher(fm -> fm.fragmentId(12).outputPartitionCount(1)))))) + .build(); + assertAdaptivePlan( + "SELECT\n" + + " ps.partkey,\n" + + " sum(ps.supplycost * ps.availqty) AS value\n" + + "FROM\n" + + " partsupp ps,\n" + + " supplier s,\n" + + " nation n\n" + + "WHERE\n" + + " ps.suppkey = s.suppkey\n" + + " AND s.nationkey = n.nationkey\n" + + " AND n.name = 'GERMANY'\n" + + "GROUP BY\n" + + " ps.partkey\n" + + "HAVING\n" + + " sum(ps.supplycost * ps.availqty) > (\n" + + " SELECT sum(ps.supplycost * ps.availqty) * 0.0001\n" + + " FROM\n" + + " partsupp ps,\n" + + " supplier s,\n" + + " nation n\n" + + " WHERE\n" + + " ps.suppkey = s.suppkey\n" + + " AND s.nationkey = n.nationkey\n" + + " AND n.name = 'GERMANY'\n" + + " )\n" + + "ORDER BY\n" + + " value DESC", + getSession(), + ImmutableMap.of( + new PlanFragmentId("4"), createRuntimeStats(ImmutableLongArray.of(ONE_MB, ONE_MB * 2, ONE_MB), 10000), + new PlanFragmentId("10"), createRuntimeStats(ImmutableLongArray.of(ONE_MB, ONE_MB, ONE_MB), 500), + new PlanFragmentId("11"), createRuntimeStats(ImmutableLongArray.of(ONE_MB, ONE_MB, ONE_MB), 500), + new PlanFragmentId("12"), createRuntimeStats(ImmutableLongArray.of(ONE_MB, ONE_MB, ONE_MB), 500)), + matcher); + } + + + + private Session getSession() + { + return Session.builder(getPlanTester().getDefaultSession()) + .setSystemProperty(RETRY_POLICY, TASK.name()) + .setSystemProperty(JOIN_REORDERING_STRATEGY, OptimizerConfig.JoinReorderingStrategy.NONE.name()) + .setSystemProperty(JOIN_DISTRIBUTION_TYPE, OptimizerConfig.JoinDistributionType.PARTITIONED.name()) + .setSystemProperty(FAULT_TOLERANT_EXECUTION_RUNTIME_ADAPTIVE_PARTITIONING_ENABLED, "true") + .setSystemProperty(FAULT_TOLERANT_EXECUTION_MAX_PARTITION_COUNT, "2") + .setSystemProperty(FAULT_TOLERANT_EXECUTION_MIN_PARTITION_COUNT, "1") + .setSystemProperty(FAULT_TOLERANT_EXECUTION_RUNTIME_ADAPTIVE_PARTITIONING_PARTITION_COUNT, "10") + 
.setSystemProperty(FAULT_TOLERANT_EXECUTION_RUNTIME_ADAPTIVE_PARTITIONING_MAX_TASK_SIZE, "1MB") + .build(); + } + + private OutputStatsEstimator.OutputStatsEstimateResult createRuntimeStats(ImmutableLongArray partitionDataSizes, long outputRowCountEstimate) + { + return new OutputStatsEstimator.OutputStatsEstimateResult(partitionDataSizes, outputRowCountEstimate, "FINISHED", true); + } +} diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestRemoveUnsupportedDynamicFilters.java b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestRemoveUnsupportedDynamicFilters.java index 609de2540e07..daa53f987b63 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestRemoveUnsupportedDynamicFilters.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestRemoveUnsupportedDynamicFilters.java @@ -16,6 +16,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.trino.cost.CachingTableStatsProvider; +import io.trino.cost.RuntimeInfoProvider; import io.trino.cost.StatsAndCosts; import io.trino.execution.warnings.WarningCollector; import io.trino.metadata.Metadata; @@ -490,7 +491,8 @@ private PlanNode removeUnsupportedDynamicFilters(PlanNode root) new PlanNodeIdAllocator(), WarningCollector.NOOP, createPlanOptimizersStatsCollector(), - new CachingTableStatsProvider(metadata, session))); + new CachingTableStatsProvider(metadata, session), + RuntimeInfoProvider.noImplementation())); new DynamicFiltersChecker().validate(rewrittenPlan, session, plannerContext, new IrTypeAnalyzer(plannerContext), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestUnaliasSymbolReferences.java b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestUnaliasSymbolReferences.java index 81df7efba5bf..8ec3196a9d38 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestUnaliasSymbolReferences.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestUnaliasSymbolReferences.java @@ -17,6 +17,7 @@ import com.google.common.collect.ImmutableMap; import io.trino.Session; import io.trino.cost.CachingTableStatsProvider; +import io.trino.cost.RuntimeInfoProvider; import io.trino.cost.StatsAndCosts; import io.trino.execution.warnings.WarningCollector; import io.trino.metadata.Metadata; @@ -158,7 +159,8 @@ private void assertOptimizedPlan(PlanOptimizer optimizer, PlanCreator planCreato idAllocator, WarningCollector.NOOP, createPlanOptimizersStatsCollector(), - new CachingTableStatsProvider(metadata, session))); + new CachingTableStatsProvider(metadata, session), + RuntimeInfoProvider.noImplementation())); Plan actual = new Plan(optimized, planBuilder.getTypes(), StatsAndCosts.empty()); PlanAssert.assertPlan(session, planTester.getPlannerContext().getMetadata(), planTester.getPlannerContext().getFunctionManager(), planTester.getStatsCalculator(), actual, pattern); diff --git a/testing/trino-faulttolerant-tests/src/test/java/io/trino/faulttolerant/TestOverridePartitionCountRecursively.java b/testing/trino-faulttolerant-tests/src/test/java/io/trino/faulttolerant/TestOverridePartitionCountRecursively.java deleted file mode 100644 index f68702a6e9fb..000000000000 --- a/testing/trino-faulttolerant-tests/src/test/java/io/trino/faulttolerant/TestOverridePartitionCountRecursively.java +++ /dev/null @@ -1,307 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file 
except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.trino.faulttolerant; - -import com.google.common.collect.ImmutableMap; -import com.google.common.collect.ImmutableSet; -import io.trino.Session; -import io.trino.connector.CoordinatorDynamicCatalogManager; -import io.trino.connector.InMemoryCatalogStore; -import io.trino.connector.LazyCatalogFactory; -import io.trino.execution.QueryManagerConfig; -import io.trino.execution.warnings.WarningCollector; -import io.trino.metadata.Metadata; -import io.trino.plugin.exchange.filesystem.FileSystemExchangePlugin; -import io.trino.plugin.hive.HiveQueryRunner; -import io.trino.security.AllowAllAccessControl; -import io.trino.sql.planner.PartitioningHandle; -import io.trino.sql.planner.Plan; -import io.trino.sql.planner.PlanFragment; -import io.trino.sql.planner.PlanFragmentIdAllocator; -import io.trino.sql.planner.PlanFragmenter; -import io.trino.sql.planner.PlanNodeIdAllocator; -import io.trino.sql.planner.SubPlan; -import io.trino.sql.planner.plan.PlanFragmentId; -import io.trino.testing.AbstractTestQueryFramework; -import io.trino.testing.FaultTolerantExecutionConnectorTestHelper; -import io.trino.testing.QueryRunner; -import org.intellij.lang.annotations.Language; -import org.junit.jupiter.api.Test; - -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.Set; - -import static com.google.common.collect.ImmutableSet.toImmutableSet; -import static com.google.common.util.concurrent.MoreExecutors.directExecutor; -import static io.trino.SystemSessionProperties.getFaultTolerantExecutionMaxPartitionCount; -import static io.trino.sql.planner.RuntimeAdaptivePartitioningRewriter.consumesHashPartitionedInput; -import static io.trino.sql.planner.RuntimeAdaptivePartitioningRewriter.getMaxPlanFragmentId; -import static io.trino.sql.planner.RuntimeAdaptivePartitioningRewriter.getMaxPlanId; -import static io.trino.sql.planner.RuntimeAdaptivePartitioningRewriter.overridePartitionCountRecursively; -import static io.trino.sql.planner.SystemPartitioningHandle.COORDINATOR_DISTRIBUTION; -import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_BROADCAST_DISTRIBUTION; -import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION; -import static io.trino.sql.planner.SystemPartitioningHandle.SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION; -import static io.trino.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION; -import static io.trino.sql.planner.SystemPartitioningHandle.SOURCE_DISTRIBUTION; -import static io.trino.sql.planner.TopologicalOrderSubPlanVisitor.sortPlanInTopologicalOrder; -import static io.trino.testing.TransactionBuilder.transaction; -import static io.trino.tpch.TpchTable.getTables; -import static java.util.Objects.requireNonNull; -import static org.assertj.core.api.Assertions.assertThat; - -public class TestOverridePartitionCountRecursively - extends AbstractTestQueryFramework -{ - private static final int PARTITION_COUNT_OVERRIDE = 40; - - @Override - protected QueryRunner createQueryRunner() - throws Exception - { - ImmutableMap.Builder 
extraPropertiesWithRuntimeAdaptivePartitioning = ImmutableMap.builder(); - extraPropertiesWithRuntimeAdaptivePartitioning.putAll(FaultTolerantExecutionConnectorTestHelper.getExtraProperties()); - extraPropertiesWithRuntimeAdaptivePartitioning.putAll(FaultTolerantExecutionConnectorTestHelper.enforceRuntimeAdaptivePartitioningProperties()); - - return HiveQueryRunner.builder() - .setExtraProperties(extraPropertiesWithRuntimeAdaptivePartitioning.buildOrThrow()) - .setAdditionalSetup(runner -> { - runner.installPlugin(new FileSystemExchangePlugin()); - runner.loadExchangeManager("filesystem", ImmutableMap.of("exchange.base-directories", - System.getProperty("java.io.tmpdir") + "/trino-local-file-system-exchange-manager")); - }) - .setInitialTables(getTables()) - .build(); - } - - @Test - public void testCreateTableAs() - { - // already started: 3, 5, 6 - // added fragments: 7, 8, 9 - // 0 0 - // | | - // 1 1 - // | | - // 2 2 - // / \ / \ - // 3* 4 => [7] 4 - // / \ | / \ - // 5* 6* 3* [8] [9] - // | | - // 5* 6* - assertOverridePartitionCountRecursively( - noJoinReordering(), - "CREATE TABLE tmp AS " + - "SELECT n1.* FROM nation n1 " + - "RIGHT JOIN " + - "(SELECT n.nationkey FROM (SELECT * FROM lineitem WHERE suppkey BETWEEN 20 and 30) l LEFT JOIN nation n on l.suppkey = n.nationkey) n2" + - " ON n1.nationkey = n2.nationkey + 1", - ImmutableMap.builder() - .put(0, new FragmentPartitioningInfo(COORDINATOR_DISTRIBUTION, Optional.empty(), SINGLE_DISTRIBUTION, Optional.empty())) - .put(1, new FragmentPartitioningInfo(SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION, Optional.empty(), SINGLE_DISTRIBUTION, Optional.empty())) - .put(2, new FragmentPartitioningInfo(FIXED_HASH_DISTRIBUTION, Optional.empty(), SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION, Optional.empty())) - .put(3, new FragmentPartitioningInfo(SOURCE_DISTRIBUTION, Optional.empty(), FIXED_HASH_DISTRIBUTION, Optional.empty())) - .put(4, new FragmentPartitioningInfo(FIXED_HASH_DISTRIBUTION, Optional.empty(), FIXED_HASH_DISTRIBUTION, Optional.empty())) - .put(5, new FragmentPartitioningInfo(SOURCE_DISTRIBUTION, Optional.empty(), FIXED_HASH_DISTRIBUTION, Optional.empty())) - .put(6, new FragmentPartitioningInfo(SOURCE_DISTRIBUTION, Optional.empty(), FIXED_HASH_DISTRIBUTION, Optional.empty())) - .buildOrThrow(), - ImmutableMap.builder() - .put(0, new FragmentPartitioningInfo(COORDINATOR_DISTRIBUTION, Optional.empty(), SINGLE_DISTRIBUTION, Optional.empty())) - .put(1, new FragmentPartitioningInfo(SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION, Optional.empty(), SINGLE_DISTRIBUTION, Optional.empty())) - .put(2, new FragmentPartitioningInfo(FIXED_HASH_DISTRIBUTION, Optional.of(40), SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION, Optional.empty())) - .put(3, new FragmentPartitioningInfo(SOURCE_DISTRIBUTION, Optional.empty(), FIXED_HASH_DISTRIBUTION, Optional.empty())) - .put(4, new FragmentPartitioningInfo(FIXED_HASH_DISTRIBUTION, Optional.of(40), FIXED_HASH_DISTRIBUTION, Optional.of(40))) - .put(5, new FragmentPartitioningInfo(SOURCE_DISTRIBUTION, Optional.empty(), FIXED_HASH_DISTRIBUTION, Optional.empty())) - .put(6, new FragmentPartitioningInfo(SOURCE_DISTRIBUTION, Optional.empty(), FIXED_HASH_DISTRIBUTION, Optional.empty())) - .put(7, new FragmentPartitioningInfo(FIXED_HASH_DISTRIBUTION, Optional.of(5), FIXED_HASH_DISTRIBUTION, Optional.of(40))) - .put(8, new FragmentPartitioningInfo(FIXED_HASH_DISTRIBUTION, Optional.of(5), FIXED_HASH_DISTRIBUTION, Optional.of(40))) - .put(9, new FragmentPartitioningInfo(FIXED_HASH_DISTRIBUTION, Optional.of(5), FIXED_HASH_DISTRIBUTION, 
Optional.of(40))) - .buildOrThrow(), - ImmutableSet.of(3, 5, 6)); - } - - @Test - public void testSkipBroadcastSubtree() - { - // result of fragment 7 will be broadcast, - // so no runtime adaptive partitioning will be applied to its subtree - // already started: 4, 10, 11, 12 - // added fragments: 13 - // 0 0 - // | | - // 1 1 - // / \ / \ - // 2 7 => 2 7 - // / \ | / \ | - // 3 6 8 3 6 8 - // / \ / \ / \ / \ - // 4* 5 9 12* [13] 5 9 12* - // / \ | / \ - // 10* 11* 4* 10* 11* - assertOverridePartitionCountRecursively( - noJoinReordering(), - "SELECT\n" + - " ps.partkey,\n" + - " sum(ps.supplycost * ps.availqty) AS value\n" + - "FROM\n" + - " partsupp ps,\n" + - " supplier s,\n" + - " nation n\n" + - "WHERE\n" + - " ps.suppkey = s.suppkey\n" + - " AND s.nationkey = n.nationkey\n" + - " AND n.name = 'GERMANY'\n" + - "GROUP BY\n" + - " ps.partkey\n" + - "HAVING\n" + - " sum(ps.supplycost * ps.availqty) > (\n" + - " SELECT sum(ps.supplycost * ps.availqty) * 0.0001\n" + - " FROM\n" + - " partsupp ps,\n" + - " supplier s,\n" + - " nation n\n" + - " WHERE\n" + - " ps.suppkey = s.suppkey\n" + - " AND s.nationkey = n.nationkey\n" + - " AND n.name = 'GERMANY'\n" + - " )\n" + - "ORDER BY\n" + - " value DESC", - ImmutableMap.builder() - .put(0, new FragmentPartitioningInfo(SINGLE_DISTRIBUTION, Optional.empty(), SINGLE_DISTRIBUTION, Optional.empty())) - .put(1, new FragmentPartitioningInfo(FIXED_HASH_DISTRIBUTION, Optional.of(4), SINGLE_DISTRIBUTION, Optional.empty())) - .put(2, new FragmentPartitioningInfo(FIXED_HASH_DISTRIBUTION, Optional.of(4), FIXED_HASH_DISTRIBUTION, Optional.of(4))) - .put(3, new FragmentPartitioningInfo(FIXED_HASH_DISTRIBUTION, Optional.of(4), FIXED_HASH_DISTRIBUTION, Optional.of(4))) - .put(4, new FragmentPartitioningInfo(SOURCE_DISTRIBUTION, Optional.empty(), FIXED_HASH_DISTRIBUTION, Optional.of(4))) - .put(5, new FragmentPartitioningInfo(SOURCE_DISTRIBUTION, Optional.empty(), FIXED_HASH_DISTRIBUTION, Optional.of(4))) - .put(6, new FragmentPartitioningInfo(SOURCE_DISTRIBUTION, Optional.empty(), FIXED_HASH_DISTRIBUTION, Optional.of(4))) - .put(7, new FragmentPartitioningInfo(SINGLE_DISTRIBUTION, Optional.empty(), FIXED_BROADCAST_DISTRIBUTION, Optional.empty())) - .put(8, new FragmentPartitioningInfo(FIXED_HASH_DISTRIBUTION, Optional.of(4), SINGLE_DISTRIBUTION, Optional.empty())) - .put(9, new FragmentPartitioningInfo(FIXED_HASH_DISTRIBUTION, Optional.of(4), FIXED_HASH_DISTRIBUTION, Optional.of(4))) - .put(10, new FragmentPartitioningInfo(SOURCE_DISTRIBUTION, Optional.empty(), FIXED_HASH_DISTRIBUTION, Optional.of(4))) - .put(11, new FragmentPartitioningInfo(SOURCE_DISTRIBUTION, Optional.empty(), FIXED_HASH_DISTRIBUTION, Optional.of(4))) - .put(12, new FragmentPartitioningInfo(SOURCE_DISTRIBUTION, Optional.empty(), FIXED_HASH_DISTRIBUTION, Optional.of(4))) - .buildOrThrow(), - ImmutableMap.builder() - .put(0, new FragmentPartitioningInfo(SINGLE_DISTRIBUTION, Optional.empty(), SINGLE_DISTRIBUTION, Optional.empty())) - .put(1, new FragmentPartitioningInfo(FIXED_HASH_DISTRIBUTION, Optional.of(40), SINGLE_DISTRIBUTION, Optional.empty())) - .put(2, new FragmentPartitioningInfo(FIXED_HASH_DISTRIBUTION, Optional.of(40), FIXED_HASH_DISTRIBUTION, Optional.of(40))) - .put(3, new FragmentPartitioningInfo(FIXED_HASH_DISTRIBUTION, Optional.of(40), FIXED_HASH_DISTRIBUTION, Optional.of(40))) - .put(4, new FragmentPartitioningInfo(SOURCE_DISTRIBUTION, Optional.empty(), FIXED_HASH_DISTRIBUTION, Optional.of(4))) - .put(5, new FragmentPartitioningInfo(SOURCE_DISTRIBUTION, Optional.empty(), 
FIXED_HASH_DISTRIBUTION, Optional.of(40))) - .put(6, new FragmentPartitioningInfo(SOURCE_DISTRIBUTION, Optional.empty(), FIXED_HASH_DISTRIBUTION, Optional.of(40))) - .put(7, new FragmentPartitioningInfo(SINGLE_DISTRIBUTION, Optional.empty(), FIXED_BROADCAST_DISTRIBUTION, Optional.empty())) - .put(8, new FragmentPartitioningInfo(FIXED_HASH_DISTRIBUTION, Optional.of(4), SINGLE_DISTRIBUTION, Optional.empty())) - .put(9, new FragmentPartitioningInfo(FIXED_HASH_DISTRIBUTION, Optional.of(4), FIXED_HASH_DISTRIBUTION, Optional.of(4))) - .put(10, new FragmentPartitioningInfo(SOURCE_DISTRIBUTION, Optional.empty(), FIXED_HASH_DISTRIBUTION, Optional.of(4))) - .put(11, new FragmentPartitioningInfo(SOURCE_DISTRIBUTION, Optional.empty(), FIXED_HASH_DISTRIBUTION, Optional.of(4))) - .put(12, new FragmentPartitioningInfo(SOURCE_DISTRIBUTION, Optional.empty(), FIXED_HASH_DISTRIBUTION, Optional.of(4))) - .put(13, new FragmentPartitioningInfo(FIXED_HASH_DISTRIBUTION, Optional.of(4), FIXED_HASH_DISTRIBUTION, Optional.of(40))) - .buildOrThrow(), - ImmutableSet.of(4, 10, 11, 12)); - } - - private void assertOverridePartitionCountRecursively( - Session session, - @Language("SQL") String sql, - Map fragmentPartitioningInfoBefore, - Map fragmentPartitioningInfoAfter, - Set startedFragments) - { - SubPlan plan = getSubPlan(session, sql); - List planInTopologicalOrder = sortPlanInTopologicalOrder(plan); - assertThat(planInTopologicalOrder).hasSize(fragmentPartitioningInfoBefore.size()); - for (SubPlan subPlan : planInTopologicalOrder) { - PlanFragment fragment = subPlan.getFragment(); - int fragmentIdAsInt = Integer.parseInt(fragment.getId().toString()); - FragmentPartitioningInfo fragmentPartitioningInfo = fragmentPartitioningInfoBefore.get(fragmentIdAsInt); - assertThat(fragment.getPartitionCount()).isEqualTo(fragmentPartitioningInfo.inputPartitionCount()); - assertThat(fragment.getPartitioning()).isEqualTo(fragmentPartitioningInfo.inputPartitioning()); - assertThat(fragment.getOutputPartitioningScheme().getPartitionCount()).isEqualTo(fragmentPartitioningInfo.outputPartitionCount()); - assertThat(fragment.getOutputPartitioningScheme().getPartitioning().getHandle()).isEqualTo(fragmentPartitioningInfo.outputPartitioning()); - } - - PlanFragmentIdAllocator planFragmentIdAllocator = new PlanFragmentIdAllocator(getMaxPlanFragmentId(planInTopologicalOrder) + 1); - PlanNodeIdAllocator planNodeIdAllocator = new PlanNodeIdAllocator(getMaxPlanId(planInTopologicalOrder) + 1); - int oldPartitionCount = planInTopologicalOrder.stream() - .mapToInt(subPlan -> { - PlanFragment fragment = subPlan.getFragment(); - if (consumesHashPartitionedInput(fragment)) { - return fragment.getPartitionCount().orElse(getFaultTolerantExecutionMaxPartitionCount(session)); - } - else { - return 0; - } - }) - .max() - .orElseThrow(); - assertThat(oldPartitionCount > 0).isTrue(); - - SubPlan newPlan = overridePartitionCountRecursively( - plan, - oldPartitionCount, - PARTITION_COUNT_OVERRIDE, - planFragmentIdAllocator, - planNodeIdAllocator, - startedFragments.stream().map(fragmentIdAsInt -> new PlanFragmentId(String.valueOf(fragmentIdAsInt))).collect(toImmutableSet())); - planInTopologicalOrder = sortPlanInTopologicalOrder(newPlan); - assertThat(planInTopologicalOrder).hasSize(fragmentPartitioningInfoAfter.size()); - for (SubPlan subPlan : planInTopologicalOrder) { - PlanFragment fragment = subPlan.getFragment(); - int fragmentIdAsInt = Integer.parseInt(fragment.getId().toString()); - FragmentPartitioningInfo fragmentPartitioningInfo = 
fragmentPartitioningInfoAfter.get(fragmentIdAsInt); - assertThat(fragment.getPartitionCount()).isEqualTo(fragmentPartitioningInfo.inputPartitionCount()); - assertThat(fragment.getPartitioning()).isEqualTo(fragmentPartitioningInfo.inputPartitioning()); - assertThat(fragment.getOutputPartitioningScheme().getPartitionCount()).isEqualTo(fragmentPartitioningInfo.outputPartitionCount()); - assertThat(fragment.getOutputPartitioningScheme().getPartitioning().getHandle()).isEqualTo(fragmentPartitioningInfo.outputPartitioning()); - } - } - - private SubPlan getSubPlan(Session session, @Language("SQL") String sql) - { - QueryRunner queryRunner = getDistributedQueryRunner(); - Metadata metadata = queryRunner.getPlannerContext().getMetadata(); - return transaction(queryRunner.getTransactionManager(), metadata, new AllowAllAccessControl()) - .singleStatement() - .execute(session, transactionSession -> { - Plan plan = queryRunner.createPlan(transactionSession, sql); - // metadata.getCatalogHandle() registers the catalog for the transaction - transactionSession.getCatalog().ifPresent(catalog -> metadata.getCatalogHandle(transactionSession, catalog)); - return new PlanFragmenter( - metadata, - queryRunner.getPlannerContext().getFunctionManager(), - queryRunner.getTransactionManager(), - new CoordinatorDynamicCatalogManager(new InMemoryCatalogStore(), new LazyCatalogFactory(), directExecutor()), - queryRunner.getPlannerContext().getLanguageFunctionManager(), - new QueryManagerConfig()).createSubPlans(transactionSession, plan, false, WarningCollector.NOOP); - }); - } - - private record FragmentPartitioningInfo( - PartitioningHandle inputPartitioning, - Optional inputPartitionCount, - PartitioningHandle outputPartitioning, - Optional outputPartitionCount) - { - FragmentPartitioningInfo { - requireNonNull(inputPartitioning, "inputPartitioning is null"); - requireNonNull(inputPartitionCount, "inputPartitionCount is null"); - requireNonNull(outputPartitioning, "outputPartitioning is null"); - requireNonNull(outputPartitionCount, "outputPartitionCount is null"); - } - } -} diff --git a/testing/trino-testing/src/main/java/io/trino/testing/FaultTolerantExecutionConnectorTestHelper.java b/testing/trino-testing/src/main/java/io/trino/testing/FaultTolerantExecutionConnectorTestHelper.java index 774460d419de..a68749ae025f 100644 --- a/testing/trino-testing/src/main/java/io/trino/testing/FaultTolerantExecutionConnectorTestHelper.java +++ b/testing/trino-testing/src/main/java/io/trino/testing/FaultTolerantExecutionConnectorTestHelper.java @@ -54,6 +54,7 @@ public static Map getExtraProperties() public static Map enforceRuntimeAdaptivePartitioningProperties() { return ImmutableMap.builder() + .put("fault-tolerant-execution-adaptive-query-planning-enabled", "true") .put("fault-tolerant-execution-runtime-adaptive-partitioning-enabled", "true") .put("fault-tolerant-execution-runtime-adaptive-partitioning-partition-count", "40") // to ensure runtime adaptive partitioning is triggered
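Note: with the "fault-tolerant-execution-adaptive-query-planning-enabled" toggle folded into enforceRuntimeAdaptivePartitioningProperties(), a fault-tolerant test query runner picks it up together with the runtime adaptive partitioning settings. A minimal sketch of that wiring, modeled on the setup removed from TestOverridePartitionCountRecursively above (the HiveQueryRunner and exchange-manager details are illustrative only, not part of this change):

    ImmutableMap.Builder<String, String> extraProperties = ImmutableMap.builder();
    extraProperties.putAll(FaultTolerantExecutionConnectorTestHelper.getExtraProperties());
    // after this change, also sets fault-tolerant-execution-adaptive-query-planning-enabled=true
    extraProperties.putAll(FaultTolerantExecutionConnectorTestHelper.enforceRuntimeAdaptivePartitioningProperties());

    QueryRunner queryRunner = HiveQueryRunner.builder()
            .setExtraProperties(extraProperties.buildOrThrow())
            .setAdditionalSetup(runner -> {
                // exchange manager used by fault-tolerant execution
                runner.installPlugin(new FileSystemExchangePlugin());
                runner.loadExchangeManager("filesystem", ImmutableMap.of(
                        "exchange.base-directories", System.getProperty("java.io.tmpdir") + "/trino-local-file-system-exchange-manager"));
            })
            .build();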