diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationMaskCompiler.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationMaskCompiler.java index 3336a20b6fbf..12c8656ff7b2 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationMaskCompiler.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationMaskCompiler.java @@ -175,22 +175,23 @@ public static Constructor generateAggregationM Variable selectedPositionsIndex = scope.declareVariable("selectedPositionsIndex", body, constantInt(0)); Variable rawIds = scope.declareVariable(int[].class, "rawIds"); Variable rawIdsOffset = scope.declareVariable(int.class, "rawIdsOffset"); + Variable dictionaryBlockPosition = scope.declareVariable("dictionaryBlockPosition", body, constantInt(0)); body.append(new IfStatement() .condition(maskBlock.instanceOf(DictionaryBlock.class)) .ifTrue(new BytecodeBlock() - .append(maskValueBlock.set(maskBlock.cast(DictionaryBlock.class).invoke("getDictionary", ValueBlock.class, position).cast(ByteArrayBlock.class))) + .append(maskValueBlock.set(maskBlock.cast(DictionaryBlock.class).invoke("getDictionary", ValueBlock.class).cast(ByteArrayBlock.class))) .append(rawIds.set(maskBlock.cast(DictionaryBlock.class).invoke("getRawIds", int[].class))) .append(rawIdsOffset.set(maskBlock.cast(DictionaryBlock.class).invoke("getRawIdsOffset", int.class))) .append(new ForLoop() - .initialize(position.set(constantInt(0))) - .condition(lessThan(position, positionCount)) - .update(position.increment()) + .initialize(dictionaryBlockPosition.set(constantInt(0))) + .condition(lessThan(dictionaryBlockPosition, positionCount)) + .update(dictionaryBlockPosition.increment()) .body(new BytecodeBlock() - .append(position.set(rawIds.getElement(add(rawIdsOffset, position)))) + .append(position.set(rawIds.getElement(add(rawIdsOffset, dictionaryBlockPosition)))) .append(new IfStatement() .condition(isPositionSelected) .ifTrue(new BytecodeBlock() - .append(selectedPositions.setElement(selectedPositionsIndex, position)) + .append(selectedPositions.setElement(selectedPositionsIndex, dictionaryBlockPosition)) .append(selectedPositionsIndex.increment())))))) .ifFalse(new BytecodeBlock() .append(maskValueBlock.set(maskBlock.cast(ByteArrayBlock.class))) diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestAggregationMaskCompiler.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestAggregationMaskCompiler.java index dc2aa78e36b6..b21daadc2d2b 100644 --- a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestAggregationMaskCompiler.java +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestAggregationMaskCompiler.java @@ -16,6 +16,7 @@ import io.trino.spi.Page; import io.trino.spi.block.Block; import io.trino.spi.block.ByteArrayBlock; +import io.trino.spi.block.DictionaryBlock; import io.trino.spi.block.IntArrayBlock; import io.trino.spi.block.RunLengthEncodedBlock; import io.trino.spi.block.ShortArrayBlock; @@ -24,6 +25,7 @@ import java.util.Arrays; import java.util.Optional; import java.util.function.Supplier; +import java.util.stream.IntStream; import static io.trino.operator.aggregation.AggregationMaskCompiler.generateAggregationMaskBuilder; import static org.assertj.core.api.Assertions.assertThat; @@ -131,18 +133,22 @@ private void testApplyMask(Supplier maskBuilderSupplier) Arrays.fill(mask, (byte) 1); assertAggregationMaskAll(maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.of(createMaskBlock(positionCount, mask))), positionCount); + assertAggregationMaskAll(maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.of(createMaskBlockAsDictionary(positionCount, mask))), positionCount); Arrays.fill(mask, (byte) 0); mask[1] = 1; mask[3] = 1; mask[5] = 1; assertAggregationMaskPositions(maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.of(createMaskBlock(positionCount, mask))), positionCount, 1, 3, 5); + assertAggregationMaskPositions(maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.of(createMaskBlockAsDictionary(positionCount, mask))), positionCount, 1, 3, 5); mask[3] = 0; assertAggregationMaskPositions(maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.of(createMaskBlock(positionCount, mask))), positionCount, 1, 5); + assertAggregationMaskPositions(maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.of(createMaskBlockAsDictionary(positionCount, mask))), positionCount, 1, 5); mask[2] = 1; assertAggregationMaskPositions(maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.of(createMaskBlock(positionCount, mask))), positionCount, 1, 2, 5); + assertAggregationMaskPositions(maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.of(createMaskBlockAsDictionary(positionCount, mask))), positionCount, 1, 2, 5); assertAggregationMaskAll(maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.of(createMaskBlockRle(positionCount, (byte) 1))), positionCount); assertAggregationMaskPositions(maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.of(createMaskBlockRle(positionCount, (byte) 0))), positionCount); @@ -165,6 +171,7 @@ private void testApplyMaskNulls(Supplier maskBuilderSupp Arrays.fill(mask, (byte) 1); assertAggregationMaskAll(maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.of(createMaskBlock(positionCount, mask))), positionCount); + assertAggregationMaskAll(maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.of(createMaskBlockAsDictionary(positionCount, mask))), positionCount); boolean[] nullFlags = new boolean[positionCount]; assertAggregationMaskAll(maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.of(createMaskBlockNulls(nullFlags))), positionCount); @@ -197,6 +204,19 @@ private static Block createMaskBlockRle(int positionCount, byte mask) return RunLengthEncodedBlock.create(createMaskBlock(1, new byte[] {mask}), positionCount); } + private static Block createMaskBlockAsDictionary(int positionCount, byte[] mask) + { + // spread the mask out and then create a dictionary block choosing only the original mask values + // this ensures that the compiler properly handles unwraps the dictionary block + byte[] newMask = new byte[positionCount * 2]; + for (int i = positionCount - 1; i >= 0; i--) { + newMask[i * 2] = mask[i]; + newMask[(i * 2) + 1] = (byte) (mask[i] == 0 ? 1 : 0); + } + Block block = DictionaryBlock.create(positionCount * 2, new ByteArrayBlock(positionCount * 2, Optional.empty(), newMask), IntStream.range(0, positionCount * 2).toArray()); + return block.getPositions(IntStream.range(0, positionCount).map(i -> i * 2).toArray(), 0, positionCount); + } + private static Block createMaskBlockNulls(boolean[] nulls) { int positionCount = nulls.length; diff --git a/core/trino-main/src/test/java/io/trino/sql/query/TestAggregation.java b/core/trino-main/src/test/java/io/trino/sql/query/TestAggregation.java index 651b2c766583..a0d6b67dffc6 100644 --- a/core/trino-main/src/test/java/io/trino/sql/query/TestAggregation.java +++ b/core/trino-main/src/test/java/io/trino/sql/query/TestAggregation.java @@ -87,4 +87,25 @@ SELECT max(x), current_timestamp, current_date, current_time, localtimestamp, lo TIME '12:24:0.000') """); } + + /** + * Regression test for #21002 + */ + @Test + public void testAggregationMaskOnDictionaryInput() + { + assertThat(assertions.query( + """ + SELECT + max(update_ts) FILTER (WHERE step_type = 'Rest') + FROM (VALUES + ('cell_id', 'Rest', TIMESTAMP '2005-09-10 13:31:00.123 Europe/Warsaw'), + ('cell_id', 'Rest', TIMESTAMP '2005-09-10 13:31:00.123 Europe/Warsaw') + ) AS t(cell_id, step_type, update_ts) + -- UNNEST to produce DictionaryBlock + CROSS JOIN UNNEST (sequence(1, 1000)) AS a(e) + GROUP BY cell_id + """)) + .matches("VALUES TIMESTAMP '2005-09-10 13:31:00.123 Europe/Warsaw'"); + } }