Skip to content

Commit

Permalink
Fix masked aggregation on dictionary input
Browse files Browse the repository at this point in the history
Co-authored-by: praveenkrishna.d <praveenkrishna@tutanota.com>
  • Loading branch information
2 people authored and martint committed Mar 14, 2024
1 parent 5568127 commit 3098d8f
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -175,22 +175,23 @@ public static Constructor<? extends AggregationMaskBuilder> generateAggregationM
Variable selectedPositionsIndex = scope.declareVariable("selectedPositionsIndex", body, constantInt(0));
Variable rawIds = scope.declareVariable(int[].class, "rawIds");
Variable rawIdsOffset = scope.declareVariable(int.class, "rawIdsOffset");
Variable dictionaryBlockPosition = scope.declareVariable("dictionaryBlockPosition", body, constantInt(0));
body.append(new IfStatement()
.condition(maskBlock.instanceOf(DictionaryBlock.class))
.ifTrue(new BytecodeBlock()
.append(maskValueBlock.set(maskBlock.cast(DictionaryBlock.class).invoke("getDictionary", ValueBlock.class, position).cast(ByteArrayBlock.class)))
.append(maskValueBlock.set(maskBlock.cast(DictionaryBlock.class).invoke("getDictionary", ValueBlock.class).cast(ByteArrayBlock.class)))
.append(rawIds.set(maskBlock.cast(DictionaryBlock.class).invoke("getRawIds", int[].class)))
.append(rawIdsOffset.set(maskBlock.cast(DictionaryBlock.class).invoke("getRawIdsOffset", int.class)))
.append(new ForLoop()
.initialize(position.set(constantInt(0)))
.condition(lessThan(position, positionCount))
.update(position.increment())
.initialize(dictionaryBlockPosition.set(constantInt(0)))
.condition(lessThan(dictionaryBlockPosition, positionCount))
.update(dictionaryBlockPosition.increment())
.body(new BytecodeBlock()
.append(position.set(rawIds.getElement(add(rawIdsOffset, position))))
.append(position.set(rawIds.getElement(add(rawIdsOffset, dictionaryBlockPosition))))
.append(new IfStatement()
.condition(isPositionSelected)
.ifTrue(new BytecodeBlock()
.append(selectedPositions.setElement(selectedPositionsIndex, position))
.append(selectedPositions.setElement(selectedPositionsIndex, dictionaryBlockPosition))
.append(selectedPositionsIndex.increment()))))))
.ifFalse(new BytecodeBlock()
.append(maskValueBlock.set(maskBlock.cast(ByteArrayBlock.class)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import io.trino.spi.Page;
import io.trino.spi.block.Block;
import io.trino.spi.block.ByteArrayBlock;
import io.trino.spi.block.DictionaryBlock;
import io.trino.spi.block.IntArrayBlock;
import io.trino.spi.block.RunLengthEncodedBlock;
import io.trino.spi.block.ShortArrayBlock;
Expand All @@ -24,6 +25,7 @@
import java.util.Arrays;
import java.util.Optional;
import java.util.function.Supplier;
import java.util.stream.IntStream;

import static io.trino.operator.aggregation.AggregationMaskCompiler.generateAggregationMaskBuilder;
import static org.assertj.core.api.Assertions.assertThat;
Expand Down Expand Up @@ -131,18 +133,22 @@ private void testApplyMask(Supplier<AggregationMaskBuilder> maskBuilderSupplier)
Arrays.fill(mask, (byte) 1);

assertAggregationMaskAll(maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.of(createMaskBlock(positionCount, mask))), positionCount);
assertAggregationMaskAll(maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.of(createMaskBlockAsDictionary(positionCount, mask))), positionCount);

Arrays.fill(mask, (byte) 0);
mask[1] = 1;
mask[3] = 1;
mask[5] = 1;
assertAggregationMaskPositions(maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.of(createMaskBlock(positionCount, mask))), positionCount, 1, 3, 5);
assertAggregationMaskPositions(maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.of(createMaskBlockAsDictionary(positionCount, mask))), positionCount, 1, 3, 5);

mask[3] = 0;
assertAggregationMaskPositions(maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.of(createMaskBlock(positionCount, mask))), positionCount, 1, 5);
assertAggregationMaskPositions(maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.of(createMaskBlockAsDictionary(positionCount, mask))), positionCount, 1, 5);

mask[2] = 1;
assertAggregationMaskPositions(maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.of(createMaskBlock(positionCount, mask))), positionCount, 1, 2, 5);
assertAggregationMaskPositions(maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.of(createMaskBlockAsDictionary(positionCount, mask))), positionCount, 1, 2, 5);

assertAggregationMaskAll(maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.of(createMaskBlockRle(positionCount, (byte) 1))), positionCount);
assertAggregationMaskPositions(maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.of(createMaskBlockRle(positionCount, (byte) 0))), positionCount);
Expand All @@ -165,6 +171,7 @@ private void testApplyMaskNulls(Supplier<AggregationMaskBuilder> maskBuilderSupp
Arrays.fill(mask, (byte) 1);

assertAggregationMaskAll(maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.of(createMaskBlock(positionCount, mask))), positionCount);
assertAggregationMaskAll(maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.of(createMaskBlockAsDictionary(positionCount, mask))), positionCount);

boolean[] nullFlags = new boolean[positionCount];
assertAggregationMaskAll(maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.of(createMaskBlockNulls(nullFlags))), positionCount);
Expand Down Expand Up @@ -197,6 +204,19 @@ private static Block createMaskBlockRle(int positionCount, byte mask)
return RunLengthEncodedBlock.create(createMaskBlock(1, new byte[] {mask}), positionCount);
}

/**
 * Creates a mask block that is backed by a dictionary rather than a plain byte array.
 * <p>
 * The mask is spread over a backing array of twice the size: every even slot holds the
 * original mask value and every odd slot holds its inverse. A dictionary block is created
 * over that backing array and then only the even slots (the original mask values) are
 * selected. This ensures that the compiler properly unwraps the dictionary block instead
 * of reading the underlying values by outer position.
 *
 * @param positionCount number of mask positions to expose
 * @param mask mask values; the first {@code positionCount} entries are used
 * @return a block of {@code positionCount} positions selecting the original mask values
 */
private static Block createMaskBlockAsDictionary(int positionCount, byte[] mask)
{
    int expandedCount = positionCount * 2;
    byte[] expandedMask = new byte[expandedCount];
    for (int index = 0; index < positionCount; index++) {
        byte original = mask[index];
        // even slot keeps the real value, odd slot holds the opposite value
        expandedMask[2 * index] = original;
        expandedMask[2 * index + 1] = (byte) (original == 0 ? 1 : 0);
    }

    ByteArrayBlock expandedValues = new ByteArrayBlock(expandedCount, Optional.empty(), expandedMask);
    int[] identityIds = IntStream.range(0, expandedCount).toArray();
    Block expandedDictionary = DictionaryBlock.create(expandedCount, expandedValues, identityIds);

    // pick only the even slots, i.e. the original mask values
    int[] originalValuePositions = IntStream.range(0, positionCount)
            .map(index -> index * 2)
            .toArray();
    return expandedDictionary.getPositions(originalValuePositions, 0, positionCount);
}

private static Block createMaskBlockNulls(boolean[] nulls)
{
int positionCount = nulls.length;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,4 +87,25 @@ SELECT max(x), current_timestamp, current_date, current_time, localtimestamp, lo
TIME '12:24:0.000')
""");
}

/**
 * Regression test for <a href="https://github.com/trinodb/trino/issues/21002">#21002</a>.
 * <p>
 * Runs a filtered {@code max} aggregation over rows that go through a
 * {@code CROSS JOIN UNNEST}, which is used here to produce dictionary-encoded input,
 * and checks that the single grouped result is the expected timestamp.
 */
@Test
public void testAggregationMaskOnDictionaryInput()
{
    String query = """
            SELECT
            max(update_ts) FILTER (WHERE step_type = 'Rest')
            FROM (VALUES
            ('cell_id', 'Rest', TIMESTAMP '2005-09-10 13:31:00.123 Europe/Warsaw'),
            ('cell_id', 'Rest', TIMESTAMP '2005-09-10 13:31:00.123 Europe/Warsaw')
            ) AS t(cell_id, step_type, update_ts)
            -- UNNEST to produce DictionaryBlock
            CROSS JOIN UNNEST (sequence(1, 1000)) AS a(e)
            GROUP BY cell_id
            """;
    String expected = "VALUES TIMESTAMP '2005-09-10 13:31:00.123 Europe/Warsaw'";

    assertThat(assertions.query(query)).matches(expected);
}
}

0 comments on commit 3098d8f

Please sign in to comment.