Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adaptive planning framework in FTE #20276

Merged
merged 6 commits into from
Mar 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ public final class SystemSessionProperties
private static final String FAULT_TOLERANT_EXECUTION_SMALL_STAGE_SOURCE_SIZE_MULTIPLIER = "fault_tolerant_execution_small_stage_source_size_multiplier";
private static final String FAULT_TOLERANT_EXECUTION_SMALL_STAGE_REQUIRE_NO_MORE_PARTITIONS = "fault_tolerant_execution_small_stage_require_no_more_partitions";
private static final String FAULT_TOLERANT_EXECUTION_STAGE_ESTIMATION_FOR_EAGER_PARENT_ENABLED = "fault_tolerant_execution_stage_estimation_for_eager_parent_enabled";
public static final String FAULT_TOLERANT_EXECUTION_ADAPTIVE_QUERY_PLANNING_ENABLED = "fault_tolerant_execution_adaptive_query_planning_enabled";
public static final String ADAPTIVE_PARTIAL_AGGREGATION_ENABLED = "adaptive_partial_aggregation_enabled";
public static final String ADAPTIVE_PARTIAL_AGGREGATION_UNIQUE_ROWS_RATIO_THRESHOLD = "adaptive_partial_aggregation_unique_rows_ratio_threshold";
public static final String REMOTE_TASK_ADAPTIVE_UPDATE_REQUEST_SIZE_ENABLED = "remote_task_adaptive_update_request_size_enabled";
Expand Down Expand Up @@ -999,6 +1000,11 @@ public SystemSessionProperties(
"Enable aggressive stage output size estimation heuristic for children of stages to be executed eagerly",
queryManagerConfig.isFaultTolerantExecutionStageEstimationForEagerParentEnabled(),
true),
booleanProperty(
FAULT_TOLERANT_EXECUTION_ADAPTIVE_QUERY_PLANNING_ENABLED,
"Enable adaptive query planning for the fault tolerant execution",
queryManagerConfig.isFaultTolerantExecutionAdaptiveQueryPlanningEnabled(),
false),
booleanProperty(
ADAPTIVE_PARTIAL_AGGREGATION_ENABLED,
"When enabled, partial aggregation might be adaptively turned off when it does not provide any performance gain",
Expand Down Expand Up @@ -1846,6 +1852,11 @@ public static boolean isFaultTolerantExecutionStageEstimationForEagerParentEnabl
return session.getSystemProperty(FAULT_TOLERANT_EXECUTION_STAGE_ESTIMATION_FOR_EAGER_PARENT_ENABLED, Boolean.class);
}

/**
 * Reads the {@code fault_tolerant_execution_adaptive_query_planning_enabled} system property
 * from the given session.
 *
 * @param session session whose system properties are consulted
 * @return the property value as a {@code boolean}
 */
public static boolean isFaultTolerantExecutionAdaptiveQueryPlanningEnabled(Session session)
{
    Boolean adaptivePlanningEnabled = session.getSystemProperty(FAULT_TOLERANT_EXECUTION_ADAPTIVE_QUERY_PLANNING_ENABLED, Boolean.class);
    return adaptivePlanningEnabled;
}

public static boolean isAdaptivePartialAggregationEnabled(Session session)
{
return session.getSystemProperty(ADAPTIVE_PARTIAL_AGGREGATION_ENABLED, Boolean.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,22 +42,31 @@ public final class CachingStatsProvider
private final Session session;
private final TypeProvider types;
private final TableStatsProvider tableStatsProvider;
private final RuntimeInfoProvider runtimeInfoProvider;

private final Map<PlanNode, PlanNodeStatsEstimate> cache = new IdentityHashMap<>();

public CachingStatsProvider(StatsCalculator statsCalculator, Session session, TypeProvider types, TableStatsProvider tableStatsProvider)
{
this(statsCalculator, Optional.empty(), noLookup(), session, types, tableStatsProvider);
this(statsCalculator, Optional.empty(), noLookup(), session, types, tableStatsProvider, RuntimeInfoProvider.noImplementation());
}

public CachingStatsProvider(StatsCalculator statsCalculator, Optional<Memo> memo, Lookup lookup, Session session, TypeProvider types, TableStatsProvider tableStatsProvider)
public CachingStatsProvider(
StatsCalculator statsCalculator,
Optional<Memo> memo,
Lookup lookup,
Session session,
TypeProvider types,
TableStatsProvider tableStatsProvider,
RuntimeInfoProvider runtimeInfoProvider)
{
this.statsCalculator = requireNonNull(statsCalculator, "statsCalculator is null");
this.memo = requireNonNull(memo, "memo is null");
this.lookup = requireNonNull(lookup, "lookup is null");
this.session = requireNonNull(session, "session is null");
this.types = requireNonNull(types, "types is null");
this.tableStatsProvider = requireNonNull(tableStatsProvider, "tableStatsProvider is null");
this.runtimeInfoProvider = requireNonNull(runtimeInfoProvider, "runtimeInfoProvider is null");
}

@Override
Expand All @@ -79,7 +88,7 @@ public PlanNodeStatsEstimate getStats(PlanNode node)
return stats;
}

stats = statsCalculator.calculateStats(node, new StatsCalculator.Context(this, lookup, session, types, tableStatsProvider));
stats = statsCalculator.calculateStats(node, new StatsCalculator.Context(this, lookup, session, types, tableStatsProvider, runtimeInfoProvider));
verify(cache.put(node, stats) == null, "Stats already set");
return stats;
}
Expand Down
149 changes: 149 additions & 0 deletions core/trino-main/src/main/java/io/trino/cost/RemoteSourceStatsRule.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.trino.cost;

import io.trino.cost.StatsCalculator.Context;
import io.trino.execution.scheduler.OutputDataSizeEstimate;
import io.trino.matching.Pattern;
import io.trino.spi.type.FixedWidthType;
import io.trino.spi.type.Type;
import io.trino.sql.planner.PlanFragment;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.TypeProvider;
import io.trino.sql.planner.plan.PlanFragmentId;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.RemoteSourceNode;

import java.util.List;
import java.util.Optional;

import static com.google.common.base.Verify.verify;
import static io.trino.cost.PlanNodeStatsEstimateMath.addStatsAndMaxDistinctValues;
import static io.trino.execution.scheduler.faulttolerant.OutputStatsEstimator.OutputStatsEstimateResult;
import static io.trino.sql.planner.plan.Patterns.remoteSourceNode;
import static io.trino.util.MoreMath.firstNonNaN;
import static java.lang.Double.NaN;
import static java.lang.Double.isNaN;

public class RemoteSourceStatsRule
extends SimpleStatsRule<RemoteSourceNode>
{
private static final Pattern<RemoteSourceNode> PATTERN = remoteSourceNode();

public RemoteSourceStatsRule(StatsNormalizer normalizer)
{
super(normalizer);
}

@Override
public Pattern<RemoteSourceNode> getPattern()
{
return PATTERN;
}

@Override
protected Optional<PlanNodeStatsEstimate> doCalculate(RemoteSourceNode node, Context context)
{
Optional<PlanNodeStatsEstimate> estimate = Optional.empty();
RuntimeInfoProvider runtimeInfoProvider = context.runtimeInfoProvider();

for (int i = 0; i < node.getSourceFragmentIds().size(); i++) {
PlanFragmentId planFragmentId = node.getSourceFragmentIds().get(i);
OutputStatsEstimateResult stageRuntimeStats = runtimeInfoProvider.getRuntimeOutputStats(planFragmentId);

PlanNodeStatsEstimate stageEstimatedStats = getEstimatedStats(runtimeInfoProvider, context.statsProvider(), planFragmentId);
PlanNodeStatsEstimate adjustedStageStats = adjustStats(
node.getOutputSymbols(),
context.types(),
stageRuntimeStats,
stageEstimatedStats);

estimate = estimate
.map(planNodeStatsEstimate -> addStatsAndMaxDistinctValues(planNodeStatsEstimate, adjustedStageStats))
.or(() -> Optional.of(adjustedStageStats));
}

verify(estimate.isPresent());
return estimate;
}

private PlanNodeStatsEstimate getEstimatedStats(
RuntimeInfoProvider runtimeInfoProvider,
StatsProvider statsProvider,
PlanFragmentId fragmentId)
{
PlanFragment fragment = runtimeInfoProvider.getPlanFragment(fragmentId);
PlanNode fragmentRoot = fragment.getRoot();
PlanNodeStatsEstimate estimate = fragment.getStatsAndCosts().getStats().get(fragmentRoot.getId());
// We will not have stats for the root node in a PlanFragment if collect_plan_statistics_for_all_queries
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we refactor the code which governs what is put in fragment.getStatsAndCosts().getStats() so we always have some (possibly empty) PlanNodeStatsEstimate in map so we can skip a null check?

// is disabled and query isn't an explain analyze.
if (estimate != null && !estimate.isOutputRowCountUnknown()) {
return estimate;
}
return statsProvider.getStats(fragmentRoot);
}

private PlanNodeStatsEstimate adjustStats(
List<Symbol> outputs,
TypeProvider typeProvider,
OutputStatsEstimateResult runtimeStats,
PlanNodeStatsEstimate estimateStats)
{
if (runtimeStats.isUnknown()) {
return estimateStats;
}

// We prefer runtime stats over estimated stats, because runtime stats are more accurate.
OutputDataSizeEstimate outputDataSizeEstimate = runtimeStats.outputDataSizeEstimate();
gaurav8297 marked this conversation as resolved.
Show resolved Hide resolved
PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.builder()
.setOutputRowCount(runtimeStats.outputRowCountEstimate());

double fixedWidthTypeSize = 0;
double variableTypeValuesCount = 0;

for (Symbol outputSymbol : outputs) {
Type type = typeProvider.get(outputSymbol);
SymbolStatsEstimate symbolStatistics = estimateStats.getSymbolStatistics(outputSymbol);
double nullsFraction = firstNonNaN(symbolStatistics.getNullsFraction(), 0d);
double numberOfNonNullRows = runtimeStats.outputRowCountEstimate() * (1.0 - nullsFraction);

if (type instanceof FixedWidthType) {
fixedWidthTypeSize += numberOfNonNullRows * ((FixedWidthType) type).getFixedSize();
}
else {
variableTypeValuesCount += numberOfNonNullRows;
}
}

double runtimeOutputDataSize = outputDataSizeEstimate.getTotalSizeInBytes();
double variableTypeValueAverageSize = NaN;
if (variableTypeValuesCount > 0 && runtimeOutputDataSize > fixedWidthTypeSize) {
variableTypeValueAverageSize = (runtimeOutputDataSize - fixedWidthTypeSize) / variableTypeValuesCount;
}

for (Symbol outputSymbol : outputs) {
SymbolStatsEstimate symbolStatistics = estimateStats.getSymbolStatistics(outputSymbol);
Type type = typeProvider.get(outputSymbol);
if (!(isNaN(variableTypeValueAverageSize) || type instanceof FixedWidthType)) {
symbolStatistics = SymbolStatsEstimate.buildFrom(symbolStatistics)
.setAverageRowSize(variableTypeValueAverageSize)
.build();
}
result.addSymbolStatistics(outputSymbol, symbolStatistics);
}

return result.build();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.cost;

import io.trino.sql.planner.PlanFragment;
import io.trino.sql.planner.plan.PlanFragmentId;

import java.util.List;

import static io.trino.execution.scheduler.faulttolerant.OutputStatsEstimator.OutputStatsEstimateResult;

/**
 * Source of runtime information gathered during fault tolerant execution (FTE).
 * The information is used to re-optimize the plan based on the actual runtime statistics.
 */
public interface RuntimeInfoProvider
{
    /**
     * Runtime output stats estimate for the fragment with the given id.
     */
    OutputStatsEstimateResult getRuntimeOutputStats(PlanFragmentId planFragmentId);

    /**
     * Plan fragment registered under the given id.
     */
    PlanFragment getPlanFragment(PlanFragmentId planFragmentId);

    /**
     * All plan fragments known to this provider.
     */
    List<PlanFragment> getAllPlanFragments();

    /**
     * Creates a placeholder provider whose every method throws
     * {@link UnsupportedOperationException}.
     */
    static RuntimeInfoProvider noImplementation()
    {
        String notImplementedMessage = "RuntimeInfoProvider is not implemented";
        return new RuntimeInfoProvider()
        {
            @Override
            public OutputStatsEstimateResult getRuntimeOutputStats(PlanFragmentId planFragmentId)
            {
                throw new UnsupportedOperationException(notImplementedMessage);
            }

            @Override
            public PlanFragment getPlanFragment(PlanFragmentId planFragmentId)
            {
                throw new UnsupportedOperationException(notImplementedMessage);
            }

            @Override
            public List<PlanFragment> getAllPlanFragments()
            {
                throw new UnsupportedOperationException(notImplementedMessage);
            }
        };
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.trino.cost;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.sql.planner.PlanFragment;
import io.trino.sql.planner.plan.PlanFragmentId;

import java.util.List;
import java.util.Map;

import static io.trino.execution.scheduler.faulttolerant.OutputStatsEstimator.OutputStatsEstimateResult;
import static java.util.Objects.requireNonNull;

/**
 * {@link RuntimeInfoProvider} backed by two fixed maps supplied at construction time.
 * Both maps are copied into immutable maps, so later changes to the arguments are not observed.
 */
public class StaticRuntimeInfoProvider
        implements RuntimeInfoProvider
{
    private final Map<PlanFragmentId, OutputStatsEstimateResult> runtimeOutputStats;
    private final Map<PlanFragmentId, PlanFragment> planFragments;

    /**
     * @param runtimeOutputStats runtime output stats keyed by fragment id; must not be null
     * @param planFragments plan fragments keyed by fragment id; must not be null
     */
    public StaticRuntimeInfoProvider(
            Map<PlanFragmentId, OutputStatsEstimateResult> runtimeOutputStats,
            Map<PlanFragmentId, PlanFragment> planFragments)
    {
        this.runtimeOutputStats = ImmutableMap.copyOf(requireNonNull(runtimeOutputStats, "runtimeOutputStats is null"));
        this.planFragments = ImmutableMap.copyOf(requireNonNull(planFragments, "planFragments is null"));
    }

    /**
     * Returns the stats recorded for the fragment, or {@code OutputStatsEstimateResult.unknown()}
     * when none were supplied for it.
     */
    @Override
    public OutputStatsEstimateResult getRuntimeOutputStats(PlanFragmentId planFragmentId)
    {
        return runtimeOutputStats.getOrDefault(planFragmentId, OutputStatsEstimateResult.unknown());
    }

    /**
     * Returns the fragment registered under the given id.
     *
     * @throws NullPointerException if no fragment is registered under {@code planFragmentId}
     */
    @Override
    public PlanFragment getPlanFragment(PlanFragmentId planFragmentId)
    {
        // The message is supplied lazily so the string is only formatted when the lookup fails.
        return requireNonNull(
                planFragments.get(planFragmentId),
                () -> "planFragment must not be null: %s".formatted(planFragmentId));
    }

    /**
     * Returns an immutable list of all registered fragments.
     */
    @Override
    public List<PlanFragment> getAllPlanFragments()
    {
        return ImmutableList.copyOf(planFragments.values());
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ record Context(
Lookup lookup,
Session session,
TypeProvider types,
TableStatsProvider tableStatsProvider)
TableStatsProvider tableStatsProvider,
RuntimeInfoProvider runtimeInfoProvider)
{
public Context
{
Expand All @@ -49,6 +50,7 @@ record Context(
requireNonNull(session, "session is null");
requireNonNull(types, "types is null");
requireNonNull(tableStatsProvider, "tableStatsProvider is null");
requireNonNull(runtimeInfoProvider, "runtimeInfoProvider is null");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ public List<ComposableStatsCalculator.Rule<?>> get()
rules.add(new SampleStatsRule(normalizer));
rules.add(new SortStatsRule());
rules.add(new DynamicFilterSourceStatsRule());
rules.add(new RemoteSourceStatsRule(normalizer));

return rules.build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ public class QueryManagerConfig
private double faultTolerantExecutionSmallStageSourceSizeMultiplier = 1.2;
private boolean faultTolerantExecutionSmallStageRequireNoMorePartitions;
private boolean faultTolerantExecutionStageEstimationForEagerParentEnabled = true;
private boolean faultTolerantExecutionAdaptiveQueryPlanningEnabled;

@Min(1)
public int getScheduleSplitBatchSize()
Expand Down Expand Up @@ -1105,6 +1106,19 @@ public QueryManagerConfig setFaultTolerantExecutionStageEstimationForEagerParent
return this;
}

public boolean isFaultTolerantExecutionAdaptiveQueryPlanningEnabled()
gaurav8297 marked this conversation as resolved.
Show resolved Hide resolved
{
return faultTolerantExecutionAdaptiveQueryPlanningEnabled;
}

/**
 * Sets the adaptive query planning flag for fault tolerant execution.
 *
 * @param faultTolerantExecutionAdaptiveQueryPlanningEnabled new value of the flag
 * @return this config instance, so calls can be chained
 */
@Config("fault-tolerant-execution-adaptive-query-planning-enabled")
@ConfigDescription("Enable adaptive query planning for the fault tolerant execution")
public QueryManagerConfig setFaultTolerantExecutionAdaptiveQueryPlanningEnabled(boolean faultTolerantExecutionAdaptiveQueryPlanningEnabled)
{
    this.faultTolerantExecutionAdaptiveQueryPlanningEnabled = faultTolerantExecutionAdaptiveQueryPlanningEnabled;
    return this;
}

public void applyFaultTolerantExecutionDefaults()
{
remoteTaskMaxErrorDuration = new Duration(1, MINUTES);
Expand Down
Loading
Loading