Skip to content

Commit

Permalink
fix incorrect results when correlated subquery returns zero rows
Browse files Browse the repository at this point in the history
  • Loading branch information
Bhargavi-Sagi authored and martint committed Oct 18, 2023
1 parent f5b1e89 commit cdb4f8f
Show file tree
Hide file tree
Showing 18 changed files with 600 additions and 373 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ private PlanBuilder planScalarSubquery(PlanBuilder subPlan, Cluster<SubqueryExpr
subPlan,
root,
scalarSubquery.getQuery(),
CorrelatedJoinNode.Type.INNER,
CorrelatedJoinNode.Type.LEFT,
TRUE_LITERAL,
mapAll(cluster, subPlan.getScope(), column));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import static io.trino.sql.planner.LogicalPlanner.failFunction;
import static io.trino.sql.planner.optimizations.PlanNodeSearcher.searchFrom;
import static io.trino.sql.planner.optimizations.QueryCardinalityUtil.extractCardinality;
import static io.trino.sql.planner.plan.CorrelatedJoinNode.Type.INNER;
import static io.trino.sql.planner.plan.CorrelatedJoinNode.Type.LEFT;
import static io.trino.sql.planner.plan.Patterns.CorrelatedJoin.correlation;
import static io.trino.sql.planner.plan.Patterns.CorrelatedJoin.filter;
Expand Down Expand Up @@ -123,7 +124,7 @@ public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Co
correlatedJoinNode.getInput(),
rewrittenSubquery,
correlatedJoinNode.getCorrelation(),
producesSingleRow ? correlatedJoinNode.getType() : LEFT,
producesSingleRow ? INNER : correlatedJoinNode.getType(),
correlatedJoinNode.getFilter(),
correlatedJoinNode.getOriginSubquery()));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -482,4 +482,34 @@ public void testChecksum()
"ON TRUE"))
.matches("VALUES (1, null), (2, x'd0f70cebd131ec61')");
}

/**
 * Correlated scalar subquery whose body is a grouped aggregation
 * ({@code max(v) ... GROUP BY k}).
 * <p>
 * Outer rows 'B' (NULL correlation key) and 'D' (key 3, absent from {@code u})
 * match no rows in the subquery; the expected result keeps those rows and
 * reports NULL for the subquery column.
 */
@Test
public void testCorrelatedSubqueryWithGroupedAggregation()
{
    String sql = "WITH" +
            " t(k, v) AS (VALUES ('A', 1), ('B', NULL), ('C', 2), ('D', 3)), " +
            " u(k, v) AS (VALUES (1, 10), (1, 20), (2, 30)) " +
            "SELECT" +
            " k," +
            " (" +
            " SELECT max(v) FROM u WHERE t.v = u.k GROUP BY k" +
            " ) AS cols " +
            "FROM t";
    String expected = "VALUES ('A', 20), ('B', NULL), ('C', 30), ('D', NULL)";

    assertThat(assertions.query(sql)).matches(expected);
}

/**
 * Correlated scalar subquery whose body is a global (ungrouped) aggregation
 * ({@code max(v)} with no GROUP BY).
 * <p>
 * Outer rows 'B' (NULL correlation key) and 'D' (key 3, absent from {@code u})
 * have no matching rows in {@code u}; the expected result still contains
 * those rows, with NULL for the subquery column.
 */
@Test
public void testCorrelatedSubqueryWithGlobalAggregation()
{
    String sql = "WITH" +
            " t(k, v) AS (VALUES ('A', 1), ('B', NULL), ('C', 2), ('D', 3)), " +
            " u(k, v) AS (VALUES (1, 10), (1, 20), (2, 30)) " +
            "SELECT" +
            " k," +
            " (" +
            " SELECT max(v) FROM u WHERE t.v = u.k" +
            " ) AS cols " +
            "FROM t";
    String expected = "VALUES ('A', 20), ('B', NULL), ('C', 30), ('D', NULL)";

    assertThat(assertions.query(sql)).matches(expected);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
import static io.trino.sql.planner.plan.AggregationNode.Step.FINAL;
import static io.trino.sql.planner.plan.AggregationNode.Step.PARTIAL;
import static io.trino.sql.planner.plan.AggregationNode.Step.SINGLE;
import static io.trino.sql.planner.plan.JoinNode.Type.INNER;
import static io.trino.sql.planner.plan.JoinNode.Type.LEFT;
import static io.trino.testing.TestingHandles.TEST_CATALOG_NAME;
import static io.trino.testing.TestingSession.testSessionBuilder;
import static java.lang.String.format;
Expand Down Expand Up @@ -204,7 +204,7 @@ public void testCorrelatedSubqueriesWithTopN()
"SELECT (SELECT t.a FROM (VALUES 1, 2, 3) t(a) WHERE t.a = t2.b ORDER BY a LIMIT 1) FROM (VALUES 1.0, 2.0) t2(b)",
"VALUES 1, 2",
output(
join(INNER, builder -> builder
join(LEFT, builder -> builder
.equiCriteria("cast_b", "cast_a")
.left(
project(
Expand All @@ -228,7 +228,7 @@ public void testCorrelatedSubqueriesWithTopN()
"SELECT (SELECT t.a FROM (VALUES 1, 2, 3, 4, 5) t(a) WHERE t.a = t2.b * t2.c - 1 ORDER BY a LIMIT 1) FROM (VALUES (1, 2), (2, 3)) t2(b, c)",
"VALUES 1, 5",
output(
join(INNER, builder -> builder
join(LEFT, builder -> builder
.equiCriteria("expr", "a")
.left(
project(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,93 +1,122 @@
cross join:
cross join:
cross join:
cross join:
cross join:
cross join:
cross join:
cross join:
cross join:
cross join:
cross join:
cross join:
cross join:
cross join:
cross join:
final aggregation over ()
local exchange (GATHER, SINGLE, [])
remote exchange (GATHER, SINGLE, [])
partial aggregation over ()
scan store_sales
local exchange (GATHER, SINGLE, [])
remote exchange (GATHER, SINGLE, [])
scan reason
final aggregation over ()
local exchange (GATHER, SINGLE, [])
remote exchange (GATHER, SINGLE, [])
partial aggregation over ()
scan store_sales
join (LEFT, REPLICATED):
join (LEFT, REPLICATED):
join (LEFT, REPLICATED):
join (LEFT, REPLICATED):
join (LEFT, REPLICATED):
join (LEFT, REPLICATED):
join (LEFT, REPLICATED):
join (RIGHT, PARTITIONED):
remote exchange (REPARTITION, HASH, [])
final aggregation over ()
local exchange (GATHER, SINGLE, [])
remote exchange (GATHER, SINGLE, [])
partial aggregation over ()
scan store_sales
local exchange (GATHER, SINGLE, [])
join (RIGHT, PARTITIONED):
remote exchange (REPARTITION, HASH, [])
final aggregation over ()
local exchange (GATHER, SINGLE, [])
remote exchange (GATHER, SINGLE, [])
partial aggregation over ()
scan store_sales
local exchange (GATHER, SINGLE, [])
join (RIGHT, PARTITIONED):
remote exchange (REPARTITION, HASH, [])
final aggregation over ()
local exchange (GATHER, SINGLE, [])
remote exchange (GATHER, SINGLE, [])
partial aggregation over ()
scan store_sales
final aggregation over ()
local exchange (GATHER, SINGLE, [])
remote exchange (GATHER, SINGLE, [])
partial aggregation over ()
scan store_sales
final aggregation over ()
local exchange (GATHER, SINGLE, [])
remote exchange (GATHER, SINGLE, [])
partial aggregation over ()
scan store_sales
final aggregation over ()
local exchange (GATHER, SINGLE, [])
remote exchange (GATHER, SINGLE, [])
partial aggregation over ()
scan store_sales
join (RIGHT, PARTITIONED):
remote exchange (REPARTITION, HASH, [])
final aggregation over ()
local exchange (GATHER, SINGLE, [])
remote exchange (GATHER, SINGLE, [])
partial aggregation over ()
scan store_sales
local exchange (GATHER, SINGLE, [])
join (RIGHT, PARTITIONED):
remote exchange (REPARTITION, HASH, [])
final aggregation over ()
local exchange (GATHER, SINGLE, [])
remote exchange (GATHER, SINGLE, [])
partial aggregation over ()
scan store_sales
local exchange (GATHER, SINGLE, [])
join (RIGHT, PARTITIONED):
remote exchange (REPARTITION, HASH, [])
final aggregation over ()
local exchange (GATHER, SINGLE, [])
remote exchange (GATHER, SINGLE, [])
partial aggregation over ()
scan store_sales
local exchange (GATHER, SINGLE, [])
join (RIGHT, PARTITIONED):
remote exchange (REPARTITION, HASH, [])
final aggregation over ()
local exchange (GATHER, SINGLE, [])
remote exchange (GATHER, SINGLE, [])
partial aggregation over ()
scan store_sales
local exchange (GATHER, SINGLE, [])
join (RIGHT, PARTITIONED):
remote exchange (REPARTITION, HASH, [])
final aggregation over ()
local exchange (GATHER, SINGLE, [])
remote exchange (GATHER, SINGLE, [])
partial aggregation over ()
scan store_sales
local exchange (GATHER, SINGLE, [])
remote exchange (REPARTITION, HASH, [])
scan reason
local exchange (GATHER, SINGLE, [])
remote exchange (REPLICATE, BROADCAST, [])
final aggregation over ()
local exchange (GATHER, SINGLE, [])
remote exchange (GATHER, SINGLE, [])
partial aggregation over ()
scan store_sales
local exchange (GATHER, SINGLE, [])
remote exchange (REPLICATE, BROADCAST, [])
final aggregation over ()
local exchange (GATHER, SINGLE, [])
remote exchange (GATHER, SINGLE, [])
partial aggregation over ()
scan store_sales
local exchange (GATHER, SINGLE, [])
remote exchange (REPLICATE, BROADCAST, [])
final aggregation over ()
local exchange (GATHER, SINGLE, [])
remote exchange (GATHER, SINGLE, [])
partial aggregation over ()
scan store_sales
local exchange (GATHER, SINGLE, [])
remote exchange (REPLICATE, BROADCAST, [])
final aggregation over ()
local exchange (GATHER, SINGLE, [])
remote exchange (GATHER, SINGLE, [])
partial aggregation over ()
scan store_sales
local exchange (GATHER, SINGLE, [])
remote exchange (REPLICATE, BROADCAST, [])
final aggregation over ()
local exchange (GATHER, SINGLE, [])
remote exchange (GATHER, SINGLE, [])
partial aggregation over ()
scan store_sales
local exchange (GATHER, SINGLE, [])
remote exchange (REPLICATE, BROADCAST, [])
final aggregation over ()
local exchange (GATHER, SINGLE, [])
remote exchange (GATHER, SINGLE, [])
partial aggregation over ()
scan store_sales
local exchange (GATHER, SINGLE, [])
remote exchange (REPLICATE, BROADCAST, [])
final aggregation over ()
local exchange (GATHER, SINGLE, [])
remote exchange (GATHER, SINGLE, [])
partial aggregation over ()
scan store_sales
final aggregation over ()
local exchange (GATHER, SINGLE, [])
remote exchange (GATHER, SINGLE, [])
partial aggregation over ()
scan store_sales
final aggregation over ()
local exchange (GATHER, SINGLE, [])
remote exchange (GATHER, SINGLE, [])
partial aggregation over ()
scan store_sales
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ local exchange (GATHER, SINGLE, [])
local exchange (GATHER, SINGLE, [])
remote exchange (REPARTITION, HASH, ["ss_customer_sk"])
partial aggregation over (ss_customer_sk)
cross join:
cross join:
join (LEFT, REPLICATED):
join (LEFT, REPLICATED):
join (INNER, REPLICATED):
join (INNER, REPLICATED):
dynamic filter (["ss_customer_sk", "ss_sold_date_sk"])
Expand Down Expand Up @@ -48,8 +48,7 @@ local exchange (GATHER, SINGLE, [])
scan date_dim
local exchange (GATHER, SINGLE, [])
remote exchange (REPLICATE, BROADCAST, [])
dynamic filter (["d_month_seq_26"])
scan date_dim
scan date_dim
local exchange (GATHER, SINGLE, [])
remote exchange (REPLICATE, BROADCAST, [])
local exchange (GATHER, SINGLE, [])
Expand Down
Loading

0 comments on commit cdb4f8f

Please sign in to comment.