From 99979fb424f4f9e410210705a3fd0f89af0357a2 Mon Sep 17 00:00:00 2001 From: Parth Gandhi Date: Wed, 13 Mar 2024 11:36:02 -0400 Subject: [PATCH] Handle Invalid Values in Window Aggregations This PR changes the return type of removeInput method to boolean to handle NaN and Infinite values in Aggregation Window Functions --- .../aggregation/AccumulatorCompiler.java | 83 +++++++++- .../aggregation/AverageAggregations.java | 14 +- .../aggregation/CountAggregation.java | 3 +- .../operator/aggregation/CountColumn.java | 3 +- .../aggregation/CountIfAggregation.java | 3 +- .../aggregation/DoubleSumAggregation.java | 11 +- .../aggregation/LongSumAggregation.java | 3 +- .../aggregation/RealAverageAggregation.java | 11 +- .../aggregation/WindowAccumulator.java | 6 +- .../window/AggregateWindowFunction.java | 19 ++- .../operator/TestRealAverageAggregation.java | 141 +++++++++++++++++ .../AbstractTestAggregationFunction.java | 5 +- .../TestDoubleAverageAggregation.java | 117 ++++++++++++++ .../aggregation/TestDoubleSumAggregation.java | 143 ++++++++++++++++++ .../window/AbstractTestWindowFunction.java | 10 ++ .../window/TestAggregateWindowFunction.java | 64 ++++++++ .../operator/window/WindowAssertions.java | 52 +++++++ 17 files changed, 657 insertions(+), 31 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/AccumulatorCompiler.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/AccumulatorCompiler.java index 1a5e6a312bbf..bff12ba50d23 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/AccumulatorCompiler.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/AccumulatorCompiler.java @@ -26,6 +26,7 @@ import io.airlift.bytecode.control.IfStatement; import io.airlift.bytecode.expression.BytecodeExpression; import io.airlift.bytecode.expression.BytecodeExpressions; +import io.airlift.bytecode.instruction.LabelNode; import io.trino.operator.window.InternalWindowIndex; import io.trino.spi.Page; import io.trino.spi.block.Block; @@ -287,22 +288,20 @@ public static Constructor generateWindowAccumulator // Generate methods generateCopy(definition, WindowAccumulator.class); - generateAddOrRemoveInputWindowIndex( + generateAddInputWindowIndex( definition, stateFields, argumentNullable, lambdaProviderFields, implementation.getInputFunction(), - "addInput", callSiteBinder); implementation.getRemoveInputFunction().ifPresent( - removeInputFunction -> generateAddOrRemoveInputWindowIndex( + removeInputFunction -> generateRemoveInputWindowIndex( definition, stateFields, argumentNullable, lambdaProviderFields, removeInputFunction, - "removeInput", callSiteBinder)); generateEvaluateFinal(definition, stateFields, implementation.getOutputFunction(), callSiteBinder); @@ -417,13 +416,12 @@ private static void generateAddInput( body.ret(); } - private static void generateAddOrRemoveInputWindowIndex( + private static void generateAddInputWindowIndex( ClassDefinition definition, List stateField, List argumentNullable, List lambdaProviderFields, MethodHandle inputFunction, - String generatedFunctionName, CallSiteBinder callSiteBinder) { // TODO: implement masking based on maskChannel field once Window Functions support DISTINCT arguments to the functions. @@ -434,7 +432,7 @@ private static void generateAddOrRemoveInputWindowIndex( MethodDefinition method = definition.declareMethod( a(PUBLIC), - generatedFunctionName, + "addInput", type(void.class), ImmutableList.of(index, startPosition, endPosition)); Scope scope = method.getScope(); @@ -462,7 +460,7 @@ private static void generateAddOrRemoveInputWindowIndex( invokeInputFunction.append(invokeDynamic( BOOTSTRAP_METHOD, ImmutableList.of(binding.getBindingId()), - generatedFunctionName, + "addInput", binding.getType(), getInvokeFunctionOnWindowIndexParameters( scope.getThis(), @@ -481,6 +479,75 @@ private static void generateAddOrRemoveInputWindowIndex( .ret(); } + private static void generateRemoveInputWindowIndex( + ClassDefinition definition, + List stateField, + List argumentNullable, + List lambdaProviderFields, + MethodHandle removeFunction, + CallSiteBinder callSiteBinder) + { + // TODO: implement masking based on maskChannel field once Window Functions support DISTINCT arguments to the functions. + + Parameter index = arg("index", WindowIndex.class); + Parameter startPosition = arg("startPosition", int.class); + Parameter endPosition = arg("endPosition", int.class); + + MethodDefinition method = definition.declareMethod( + a(PUBLIC), + "removeInput", + type(boolean.class), + ImmutableList.of(index, startPosition, endPosition)); + Scope scope = method.getScope(); + BytecodeBlock body = method.getBody(); + + Variable position = scope.declareVariable(int.class, "position"); + + // input parameters + Variable inputBlockPosition = scope.declareVariable(int.class, "inputBlockPosition"); + List inputBlockVariables = new ArrayList<>(); + for (int i = 0; i < argumentNullable.size(); i++) { + inputBlockVariables.add(scope.declareVariable(Block.class, "inputBlock" + i)); + } + + Binding binding = callSiteBinder.bind(removeFunction); + BytecodeBlock invokeRemoveFunction = new BytecodeBlock(); + // WindowIndex is built on PagesIndex, which simply wraps Blocks + // and currently does not understand ValueBlocks. + // Until PagesIndex is updated to understand ValueBlocks, the + // input function parameters must be directly unwrapped to ValueBlocks. + invokeRemoveFunction.append(inputBlockPosition.set(index.cast(InternalWindowIndex.class).invoke("getRawBlockPosition", int.class, position))); + for (int i = 0; i < inputBlockVariables.size(); i++) { + invokeRemoveFunction.append(inputBlockVariables.get(i).set(index.cast(InternalWindowIndex.class).invoke("getRawBlock", Block.class, constantInt(i), position))); + } + LabelNode returnFalse = new LabelNode("returnFalse"); + invokeRemoveFunction.append(invokeDynamic( + BOOTSTRAP_METHOD, + ImmutableList.of(binding.getBindingId()), + "removeInput", + binding.getType(), + getInvokeFunctionOnWindowIndexParameters( + scope.getThis(), + stateField, + inputBlockPosition, + inputBlockVariables, + lambdaProviderFields))); + invokeRemoveFunction.ifFalseGoto(returnFalse); + + body.append(new ForLoop() + .initialize(position.set(startPosition)) + .condition(BytecodeExpressions.lessThanOrEqual(position, endPosition)) + .update(position.increment()) + .body(new IfStatement() + .condition(anyParametersAreNull(argumentNullable, index, position)) + .ifFalse(invokeRemoveFunction))) + .push(true) + .retBoolean() + .visitLabel(returnFalse) + .push(false) + .retBoolean(); + } + private static BytecodeExpression anyParametersAreNull( List argumentNullable, Variable index, diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/AverageAggregations.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/AverageAggregations.java index cf223696ecac..ed54f947240f 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/AverageAggregations.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/AverageAggregations.java @@ -46,17 +46,23 @@ public static void input(@AggregationState LongAndDoubleState state, @SqlType(St } @RemoveInputFunction - public static void removeInput(@AggregationState LongAndDoubleState state, @SqlType(StandardTypes.BIGINT) long value) + public static boolean removeInput(@AggregationState LongAndDoubleState state, @SqlType(StandardTypes.BIGINT) long value) { state.setLong(state.getLong() - 1); state.setDouble(state.getDouble() - value); + return true; } @RemoveInputFunction - public static void removeInput(@AggregationState LongAndDoubleState state, @SqlType(StandardTypes.DOUBLE) double value) + public static boolean removeInput(@AggregationState LongAndDoubleState state, @SqlType(StandardTypes.DOUBLE) double value) { - state.setLong(state.getLong() - 1); - state.setDouble(state.getDouble() - value); + double currentValue = state.getDouble(); + if (Double.isFinite(currentValue)) { + state.setDouble(currentValue - value); + state.setLong(state.getLong() - 1); + return true; + } + return false; } @CombineFunction diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/CountAggregation.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/CountAggregation.java index 0c4ce05a19c6..ccca13d8190f 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/CountAggregation.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/CountAggregation.java @@ -37,9 +37,10 @@ public static void input(@AggregationState LongState state) } @RemoveInputFunction - public static void removeInput(@AggregationState LongState state) + public static boolean removeInput(@AggregationState LongState state) { state.setValue(state.getValue() - 1); + return true; } @CombineFunction diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/CountColumn.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/CountColumn.java index 87ccef50fbec..0f26448d9a62 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/CountColumn.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/CountColumn.java @@ -47,12 +47,13 @@ public static void input( } @RemoveInputFunction - public static void removeInput( + public static boolean removeInput( @AggregationState LongState state, @BlockPosition @SqlType("T") ValueBlock block, @BlockIndex int position) { state.setValue(state.getValue() - 1); + return true; } @CombineFunction diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/CountIfAggregation.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/CountIfAggregation.java index b67a853bce3a..15c39cd49c2f 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/CountIfAggregation.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/CountIfAggregation.java @@ -40,11 +40,12 @@ public static void input(@AggregationState LongState state, @SqlType(StandardTyp } @RemoveInputFunction - public static void removeInput(@AggregationState LongState state, @SqlType(StandardTypes.BOOLEAN) boolean value) + public static boolean removeInput(@AggregationState LongState state, @SqlType(StandardTypes.BOOLEAN) boolean value) { if (value) { state.setValue(state.getValue() - 1); } + return true; } @CombineFunction diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/DoubleSumAggregation.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/DoubleSumAggregation.java index 6be0a64818bf..9a3d641e4921 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/DoubleSumAggregation.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/DoubleSumAggregation.java @@ -38,10 +38,15 @@ public static void sum(@AggregationState LongAndDoubleState state, @SqlType(Stan } @RemoveInputFunction - public static void removeInput(@AggregationState LongAndDoubleState state, @SqlType(StandardTypes.DOUBLE) double value) + public static boolean removeInput(@AggregationState LongAndDoubleState state, @SqlType(StandardTypes.DOUBLE) double value) { - state.setLong(state.getLong() - 1); - state.setDouble(state.getDouble() - value); + double currentValue = state.getDouble(); + if (Double.isFinite(currentValue)) { + state.setDouble(currentValue - value); + state.setLong(state.getLong() - 1); + return true; + } + return false; } @CombineFunction diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/LongSumAggregation.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/LongSumAggregation.java index 53597d844720..e497b90d3ae7 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/LongSumAggregation.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/LongSumAggregation.java @@ -38,10 +38,11 @@ public static void sum(@AggregationState LongLongState state, @SqlType(StandardT } @RemoveInputFunction - public static void removeInput(@AggregationState LongLongState state, @SqlType(StandardTypes.BIGINT) long value) + public static boolean removeInput(@AggregationState LongLongState state, @SqlType(StandardTypes.BIGINT) long value) { state.setFirst(state.getFirst() - 1); state.setSecond(BigintOperators.subtract(state.getSecond(), value)); + return true; // This should always return true as the state cannot be infinite or NaN for long input values. } @CombineFunction diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/RealAverageAggregation.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/RealAverageAggregation.java index 3f0474de4f8d..675920b3b44b 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/RealAverageAggregation.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/RealAverageAggregation.java @@ -46,13 +46,18 @@ public static void input( } @RemoveInputFunction - public static void removeInput( + public static boolean removeInput( @AggregationState LongState count, @AggregationState DoubleState sum, @SqlType("REAL") long value) { - count.setValue(count.getValue() - 1); - sum.setValue(sum.getValue() - intBitsToFloat((int) value)); + double currentValue = sum.getValue(); + if (Double.isFinite(currentValue)) { + sum.setValue(currentValue - intBitsToFloat((int) value)); + count.setValue(count.getValue() - 1); + return true; + } + return false; } @CombineFunction diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/WindowAccumulator.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/WindowAccumulator.java index 375a24b31274..e4659ca7c32f 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/WindowAccumulator.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/WindowAccumulator.java @@ -24,7 +24,11 @@ public interface WindowAccumulator void addInput(WindowIndex index, int startPosition, int endPosition); - void removeInput(WindowIndex index, int startPosition, int endPosition); + /** + * @return Returns false when an NaN or Infinite input double value is + * encountered, true otherwise. + */ + boolean removeInput(WindowIndex index, int startPosition, int endPosition); void evaluateFinal(BlockBuilder blockBuilder); } diff --git a/core/trino-main/src/main/java/io/trino/operator/window/AggregateWindowFunction.java b/core/trino-main/src/main/java/io/trino/operator/window/AggregateWindowFunction.java index b2dc6d469b10..8c1c759c73a7 100644 --- a/core/trino-main/src/main/java/io/trino/operator/window/AggregateWindowFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/window/AggregateWindowFunction.java @@ -82,11 +82,13 @@ private void buildNewFrame(int frameStart, int frameEnd) if ((overlapEnd - overlapStart + 1) > (prefixRemoveLength + suffixRemoveLength)) { // It's worth keeping the overlap, and removing the now-unused prefix - if (currentStart < frameStart) { - remove(currentStart, frameStart - 1); + if (currentStart < frameStart && !remove(currentStart, frameStart - 1)) { + resetNewFrame(frameStart, frameEnd); + return; } - if (frameEnd < currentEnd) { - remove(frameEnd + 1, currentEnd); + if (frameEnd < currentEnd && !remove(frameEnd + 1, currentEnd)) { + resetNewFrame(frameStart, frameEnd); + return; } if (frameStart < currentStart) { accumulate(frameStart, currentStart - 1); @@ -101,6 +103,11 @@ private void buildNewFrame(int frameStart, int frameEnd) } // We couldn't or didn't want to modify the accumulation: instead, discard the current accumulation and start fresh. + resetNewFrame(frameStart, frameEnd); + } + + private void resetNewFrame(int frameStart, int frameEnd) + { resetAccumulator(); accumulate(frameStart, frameEnd); currentStart = frameStart; @@ -112,9 +119,9 @@ private void accumulate(int start, int end) accumulator.addInput(windowIndex, start, end); } - private void remove(int start, int end) + private boolean remove(int start, int end) { - accumulator.removeInput(windowIndex, start, end); + return accumulator.removeInput(windowIndex, start, end); } private void resetAccumulator() diff --git a/core/trino-main/src/test/java/io/trino/operator/TestRealAverageAggregation.java b/core/trino-main/src/test/java/io/trino/operator/TestRealAverageAggregation.java index 21c72d5a3a61..7dea353bf0dd 100644 --- a/core/trino-main/src/test/java/io/trino/operator/TestRealAverageAggregation.java +++ b/core/trino-main/src/test/java/io/trino/operator/TestRealAverageAggregation.java @@ -14,9 +14,16 @@ package io.trino.operator; import com.google.common.collect.ImmutableList; +import io.trino.block.BlockAssertions; +import io.trino.metadata.ResolvedFunction; import io.trino.operator.aggregation.AbstractTestAggregationFunction; +import io.trino.operator.aggregation.WindowAccumulator; +import io.trino.operator.window.PagesWindowIndex; +import io.trino.spi.Page; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.function.AggregationImplementation; +import io.trino.spi.function.WindowIndex; import io.trino.spi.type.Type; import org.junit.jupiter.api.Test; @@ -24,9 +31,11 @@ import static io.trino.block.BlockAssertions.createBlockOfReals; import static io.trino.operator.aggregation.AggregationTestUtils.assertAggregation; +import static io.trino.operator.aggregation.AggregationTestUtils.makeValidityAssertion; import static io.trino.spi.type.RealType.REAL; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; import static java.lang.Float.floatToRawIntBits; +import static org.assertj.core.api.Assertions.assertThat; public class TestRealAverageAggregation extends AbstractTestAggregationFunction @@ -99,4 +108,136 @@ protected Object getExpectedValue(int start, int length) } return sum / length; } + + protected Block[] getSequenceBlocksForRealNaNTest(int start, int length) + { + BlockBuilder blockBuilder = REAL.createBlockBuilder(null, length); + for (int i = start; i < start + length - 5; i++) { + REAL.writeLong(blockBuilder, floatToRawIntBits((float) i)); + } + REAL.writeLong(blockBuilder, floatToRawIntBits(Float.NaN)); + for (int i = start + length - 4; i < start + length; i++) { + REAL.writeLong(blockBuilder, floatToRawIntBits((float) i)); + } + return new Block[] {blockBuilder.build()}; + } + + protected Block[] getSequenceBlocksForRealInfinityTest(int start, int length) + { + BlockBuilder blockBuilder = REAL.createBlockBuilder(null, length); + for (int i = start; i < start + length - 5; i++) { + REAL.writeLong(blockBuilder, floatToRawIntBits((float) i)); + } + REAL.writeLong(blockBuilder, floatToRawIntBits(Float.POSITIVE_INFINITY)); + for (int i = start + length - 4; i < start + length; i++) { + REAL.writeLong(blockBuilder, floatToRawIntBits((float) i)); + } + return new Block[] {blockBuilder.build()}; + } + + @Test + public void testSlidingWindowForNaNAndInfinity() + { + int totalPositions = 12; + int[] windowWidths = new int[totalPositions]; + Object[] expectedValues = new Object[totalPositions]; + Object[] expectedValues2 = new Object[totalPositions]; + + for (int i = 0; i < totalPositions; ++i) { + int windowWidth = Integer.min(i, totalPositions - 1 - i); + windowWidths[i] = windowWidth; + if (i >= 4) { + expectedValues[i] = Float.NaN; + expectedValues2[i] = Float.POSITIVE_INFINITY; + } + else { + expectedValues[i] = getExpectedValue(i, windowWidth); + expectedValues2[i] = getExpectedValue(i, windowWidth); + } + } + Page inputPage = new Page(totalPositions, getSequenceBlocksForRealNaNTest(0, totalPositions)); + + PagesIndex pagesIndex = new PagesIndex.TestingFactory(false).newPagesIndex(getFunctionParameterTypes(), totalPositions); + pagesIndex.addPage(inputPage); + WindowIndex windowIndex = new PagesWindowIndex(pagesIndex, 0, totalPositions - 1); + + ResolvedFunction resolvedFunction = functionResolution.resolveFunction(getFunctionName(), fromTypes(getFunctionParameterTypes())); + AggregationImplementation aggregationImplementation = functionResolution.getPlannerContext().getFunctionManager().getAggregationImplementation(resolvedFunction); + WindowAccumulator aggregation = createWindowAccumulator(resolvedFunction, aggregationImplementation); + assertThat(resolvedFunction.getSignature().getReturnType().toString().contains("real")).isTrue(); + assertThat(resolvedFunction.getSignature().getName().toString().contains("avg")).isTrue(); + int oldStart = 0; + int oldWidth = 0; + for (int start = 0; start < totalPositions; ++start) { + int width = windowWidths[start]; + for (int oldi = oldStart; oldi < oldStart + oldWidth; ++oldi) { + if (oldi < start || oldi >= start + width) { + boolean res = aggregation.removeInput(windowIndex, oldi, oldi); + if (oldi >= 4) { + assertThat(res).isFalse(); + } + else { + assertThat(res).isTrue(); + } + } + } + for (int newi = start; newi < start + width; ++newi) { + if (newi < oldStart || newi >= oldStart + oldWidth) { + aggregation.addInput(windowIndex, newi, newi); + } + } + oldStart = start; + oldWidth = width; + + Type outputType = resolvedFunction.getSignature().getReturnType(); + BlockBuilder blockBuilder = outputType.createBlockBuilder(null, 1000); + aggregation.evaluateFinal(blockBuilder); + Block block = blockBuilder.build(); + + assertThat(makeValidityAssertion(expectedValues[start]).apply( + BlockAssertions.getOnlyValue(outputType, block), + expectedValues[start])) + .isTrue(); + } + + Page inputPage2 = new Page(totalPositions, getSequenceBlocksForRealInfinityTest(0, totalPositions)); + + PagesIndex pagesIndex2 = new PagesIndex.TestingFactory(false).newPagesIndex(getFunctionParameterTypes(), totalPositions); + pagesIndex2.addPage(inputPage2); + WindowIndex windowIndex2 = new PagesWindowIndex(pagesIndex2, 0, totalPositions - 1); + WindowAccumulator aggregation2 = createWindowAccumulator(resolvedFunction, aggregationImplementation); + oldStart = 0; + oldWidth = 0; + for (int start = 0; start < totalPositions; ++start) { + int width = windowWidths[start]; + for (int oldi = oldStart; oldi < oldStart + oldWidth; ++oldi) { + if (oldi < start || oldi >= start + width) { + boolean res = aggregation2.removeInput(windowIndex2, oldi, oldi); + if (oldi >= 4) { + assertThat(res).isFalse(); + } + else { + assertThat(res).isTrue(); + } + } + } + for (int newi = start; newi < start + width; ++newi) { + if (newi < oldStart || newi >= oldStart + oldWidth) { + aggregation2.addInput(windowIndex2, newi, newi); + } + } + oldStart = start; + oldWidth = width; + + Type outputType = resolvedFunction.getSignature().getReturnType(); + BlockBuilder blockBuilder = outputType.createBlockBuilder(null, 1000); + aggregation2.evaluateFinal(blockBuilder); + Block block = blockBuilder.build(); + + assertThat(makeValidityAssertion(expectedValues2[start]).apply( + BlockAssertions.getOnlyValue(outputType, block), + expectedValues2[start])) + .isTrue(); + } + } } diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/AbstractTestAggregationFunction.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/AbstractTestAggregationFunction.java index f75364dad388..6d4965ba3ee3 100644 --- a/core/trino-main/src/test/java/io/trino/operator/aggregation/AbstractTestAggregationFunction.java +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/AbstractTestAggregationFunction.java @@ -156,7 +156,8 @@ public void testSlidingWindow() if (aggregationImplementation.getRemoveInputFunction().isPresent()) { for (int oldi = oldStart; oldi < oldStart + oldWidth; ++oldi) { if (oldi < start || oldi >= start + width) { - aggregation.removeInput(windowIndex, oldi, oldi); + boolean res = aggregation.removeInput(windowIndex, oldi, oldi); + assertThat(res).isTrue(); } } for (int newi = start; newi < start + width; ++newi) { @@ -184,7 +185,7 @@ public void testSlidingWindow() } } - private static WindowAccumulator createWindowAccumulator(ResolvedFunction resolvedFunction, AggregationImplementation aggregationImplementation) + protected static WindowAccumulator createWindowAccumulator(ResolvedFunction resolvedFunction, AggregationImplementation aggregationImplementation) { try { Constructor constructor = generateWindowAccumulatorClass( diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestDoubleAverageAggregation.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestDoubleAverageAggregation.java index 6df7a2b2e384..2606466d980d 100644 --- a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestDoubleAverageAggregation.java +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestDoubleAverageAggregation.java @@ -14,13 +14,24 @@ package io.trino.operator.aggregation; import com.google.common.collect.ImmutableList; +import io.trino.block.BlockAssertions; +import io.trino.metadata.ResolvedFunction; +import io.trino.operator.PagesIndex; +import io.trino.operator.window.PagesWindowIndex; +import io.trino.spi.Page; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.function.AggregationImplementation; +import io.trino.spi.function.WindowIndex; import io.trino.spi.type.Type; +import org.junit.jupiter.api.Test; import java.util.List; +import static io.trino.operator.aggregation.AggregationTestUtils.makeValidityAssertion; import static io.trino.spi.type.DoubleType.DOUBLE; +import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; +import static org.assertj.core.api.Assertions.assertThat; public class TestDoubleAverageAggregation extends AbstractTestAggregationFunction @@ -60,4 +71,110 @@ protected List getFunctionParameterTypes() { return ImmutableList.of(DOUBLE); } + + @Test + public void testSlidingWindowForNaNAndInfinity() + { + int totalPositions = 12; + int[] windowWidths = new int[totalPositions]; + Object[] expectedValues = new Object[totalPositions]; + Object[] expectedValues2 = new Object[totalPositions]; + + for (int i = 0; i < totalPositions; ++i) { + int windowWidth = Integer.min(i, totalPositions - 1 - i); + windowWidths[i] = windowWidth; + if (i >= 4) { + expectedValues[i] = Double.NaN; + expectedValues2[i] = Double.POSITIVE_INFINITY; + } + else { + expectedValues[i] = getExpectedValue(i, windowWidth); + expectedValues2[i] = getExpectedValue(i, windowWidth); + } + } + Page inputPage = new Page(totalPositions, TestDoubleSumAggregation.getSequenceBlocksForDoubleNaNTest(0, totalPositions)); + + PagesIndex pagesIndex = new PagesIndex.TestingFactory(false).newPagesIndex(getFunctionParameterTypes(), totalPositions); + pagesIndex.addPage(inputPage); + WindowIndex windowIndex = new PagesWindowIndex(pagesIndex, 0, totalPositions - 1); + + ResolvedFunction resolvedFunction = functionResolution.resolveFunction(getFunctionName(), fromTypes(getFunctionParameterTypes())); + AggregationImplementation aggregationImplementation = functionResolution.getPlannerContext().getFunctionManager().getAggregationImplementation(resolvedFunction); + WindowAccumulator aggregation = createWindowAccumulator(resolvedFunction, aggregationImplementation); + assertThat(resolvedFunction.getSignature().getReturnType().toString().contains("double")).isTrue(); + assertThat(resolvedFunction.getSignature().getName().toString().contains("avg")).isTrue(); + int oldStart = 0; + int oldWidth = 0; + for (int start = 0; start < totalPositions; ++start) { + int width = windowWidths[start]; + for (int oldi = oldStart; oldi < oldStart + oldWidth; ++oldi) { + if (oldi < start || oldi >= start + width) { + boolean res = aggregation.removeInput(windowIndex, oldi, oldi); + if (oldi >= 4) { + assertThat(res).isFalse(); + } + else { + assertThat(res).isTrue(); + } + } + } + for (int newi = start; newi < start + width; ++newi) { + if (newi < oldStart || newi >= oldStart + oldWidth) { + aggregation.addInput(windowIndex, newi, newi); + } + } + oldStart = start; + oldWidth = width; + + Type outputType = resolvedFunction.getSignature().getReturnType(); + BlockBuilder blockBuilder = outputType.createBlockBuilder(null, 1000); + aggregation.evaluateFinal(blockBuilder); + Block block = blockBuilder.build(); + + assertThat(makeValidityAssertion(expectedValues[start]).apply( + BlockAssertions.getOnlyValue(outputType, block), + expectedValues[start])) + .isTrue(); + } + + Page inputPage2 = new Page(totalPositions, TestDoubleSumAggregation.getSequenceBlocksForDoubleInfinityTest(0, totalPositions)); + + PagesIndex pagesIndex2 = new PagesIndex.TestingFactory(false).newPagesIndex(getFunctionParameterTypes(), totalPositions); + pagesIndex2.addPage(inputPage2); + WindowIndex windowIndex2 = new PagesWindowIndex(pagesIndex2, 0, totalPositions - 1); + WindowAccumulator aggregation2 = createWindowAccumulator(resolvedFunction, aggregationImplementation); + oldStart = 0; + oldWidth = 0; + for (int start = 0; start < totalPositions; ++start) { + int width = windowWidths[start]; + for (int oldi = oldStart; oldi < oldStart + oldWidth; ++oldi) { + if (oldi < start || oldi >= start + width) { + boolean res = aggregation2.removeInput(windowIndex2, oldi, oldi); + if (oldi >= 4) { + assertThat(res).isFalse(); + } + else { + assertThat(res).isTrue(); + } + } + } + for (int newi = start; newi < start + width; ++newi) { + if (newi < oldStart || newi >= oldStart + oldWidth) { + aggregation2.addInput(windowIndex2, newi, newi); + } + } + oldStart = start; + oldWidth = width; + + Type outputType = resolvedFunction.getSignature().getReturnType(); + BlockBuilder blockBuilder = outputType.createBlockBuilder(null, 1000); + aggregation2.evaluateFinal(blockBuilder); + Block block = blockBuilder.build(); + + assertThat(makeValidityAssertion(expectedValues2[start]).apply( + BlockAssertions.getOnlyValue(outputType, block), + expectedValues2[start])) + .isTrue(); + } + } } diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestDoubleSumAggregation.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestDoubleSumAggregation.java index bfe5b90cb16f..46466c4c6d35 100644 --- a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestDoubleSumAggregation.java +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestDoubleSumAggregation.java @@ -14,13 +14,24 @@ package io.trino.operator.aggregation; import com.google.common.collect.ImmutableList; +import io.trino.block.BlockAssertions; +import io.trino.metadata.ResolvedFunction; +import io.trino.operator.PagesIndex; +import io.trino.operator.window.PagesWindowIndex; +import io.trino.spi.Page; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.function.AggregationImplementation; +import io.trino.spi.function.WindowIndex; import io.trino.spi.type.Type; +import org.junit.jupiter.api.Test; import java.util.List; +import static io.trino.operator.aggregation.AggregationTestUtils.makeValidityAssertion; import static io.trino.spi.type.DoubleType.DOUBLE; +import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; +import static org.assertj.core.api.Assertions.assertThat; public class TestDoubleSumAggregation extends AbstractTestAggregationFunction @@ -60,4 +71,136 @@ protected List getFunctionParameterTypes() { return ImmutableList.of(DOUBLE); } + + protected static Block[] getSequenceBlocksForDoubleNaNTest(int start, int length) + { + BlockBuilder blockBuilder = DOUBLE.createBlockBuilder(null, length); + for (int i = start; i < start + length - 5; i++) { + DOUBLE.writeDouble(blockBuilder, i); + } + DOUBLE.writeDouble(blockBuilder, Double.NaN); + for (int i = start + length - 4; i < start + length; i++) { + DOUBLE.writeDouble(blockBuilder, i); + } + return new Block[] {blockBuilder.build()}; + } + + protected static Block[] getSequenceBlocksForDoubleInfinityTest(int start, int length) + { + BlockBuilder blockBuilder = DOUBLE.createBlockBuilder(null, length); + for (int i = start; i < start + length - 5; i++) { + DOUBLE.writeDouble(blockBuilder, i); + } + DOUBLE.writeDouble(blockBuilder, Double.POSITIVE_INFINITY); + for (int i = start + length - 4; i < start + length; i++) { + DOUBLE.writeDouble(blockBuilder, i); + } + return new Block[] {blockBuilder.build()}; + } + + @Test + public void testSlidingWindowForNaNAndInfinity() + { + int totalPositions = 12; + int[] windowWidths = new int[totalPositions]; + Object[] expectedValues = new Object[totalPositions]; + Object[] expectedValues2 = new Object[totalPositions]; + + for (int i = 0; i < totalPositions; ++i) { + int windowWidth = Integer.min(i, totalPositions - 1 - i); + windowWidths[i] = windowWidth; + if (i >= 4) { + expectedValues[i] = Double.NaN; + expectedValues2[i] = Double.POSITIVE_INFINITY; + } + else { + expectedValues[i] = getExpectedValue(i, windowWidth); + expectedValues2[i] = getExpectedValue(i, windowWidth); + } + } + Page inputPage = new Page(totalPositions, getSequenceBlocksForDoubleNaNTest(0, totalPositions)); + + PagesIndex pagesIndex = new PagesIndex.TestingFactory(false).newPagesIndex(getFunctionParameterTypes(), totalPositions); + pagesIndex.addPage(inputPage); + WindowIndex windowIndex = new PagesWindowIndex(pagesIndex, 0, totalPositions - 1); + + ResolvedFunction resolvedFunction = functionResolution.resolveFunction(getFunctionName(), fromTypes(getFunctionParameterTypes())); + AggregationImplementation aggregationImplementation = functionResolution.getPlannerContext().getFunctionManager().getAggregationImplementation(resolvedFunction); + WindowAccumulator aggregation = createWindowAccumulator(resolvedFunction, aggregationImplementation); + assertThat(resolvedFunction.getSignature().getReturnType().toString().contains("double")).isTrue(); + assertThat(resolvedFunction.getSignature().getName().toString().contains("sum")).isTrue(); + int oldStart = 0; + int oldWidth = 0; + for (int start = 0; start < totalPositions; ++start) { + int width = windowWidths[start]; + for (int oldi = oldStart; oldi < oldStart + oldWidth; ++oldi) { + if (oldi < start || oldi >= start + width) { + boolean res = aggregation.removeInput(windowIndex, oldi, oldi); + if (oldi >= 4) { + assertThat(res).isFalse(); + } + else { + assertThat(res).isTrue(); + } + } + } + for (int newi = start; newi < start + width; ++newi) { + if (newi < oldStart || newi >= oldStart + oldWidth) { + aggregation.addInput(windowIndex, newi, newi); + } + } + oldStart = start; + oldWidth = width; + + Type outputType = resolvedFunction.getSignature().getReturnType(); + BlockBuilder blockBuilder = outputType.createBlockBuilder(null, 1000); + aggregation.evaluateFinal(blockBuilder); + Block block = blockBuilder.build(); + + assertThat(makeValidityAssertion(expectedValues[start]).apply( + BlockAssertions.getOnlyValue(outputType, block), + expectedValues[start])) + .isTrue(); + } + + Page inputPage2 = new Page(totalPositions, getSequenceBlocksForDoubleInfinityTest(0, totalPositions)); + + PagesIndex pagesIndex2 = new PagesIndex.TestingFactory(false).newPagesIndex(getFunctionParameterTypes(), totalPositions); + pagesIndex2.addPage(inputPage2); + WindowIndex windowIndex2 = new PagesWindowIndex(pagesIndex2, 0, totalPositions - 1); + WindowAccumulator aggregation2 = createWindowAccumulator(resolvedFunction, aggregationImplementation); + oldStart = 0; + oldWidth = 0; + for (int start = 0; start < totalPositions; ++start) { + int width = windowWidths[start]; + for (int oldi = oldStart; oldi < oldStart + oldWidth; ++oldi) { + if (oldi < start || oldi >= start + width) { + boolean res = aggregation2.removeInput(windowIndex2, oldi, oldi); + if (oldi >= 4) { + assertThat(res).isFalse(); + } + else { + assertThat(res).isTrue(); + } + } + } + for (int newi = start; newi < start + width; ++newi) { + if (newi < oldStart || newi >= oldStart + oldWidth) { + aggregation2.addInput(windowIndex2, newi, newi); + } + } + oldStart = start; + oldWidth = width; + + Type outputType = resolvedFunction.getSignature().getReturnType(); + BlockBuilder blockBuilder = outputType.createBlockBuilder(null, 1000); + aggregation2.evaluateFinal(blockBuilder); + Block block = blockBuilder.build(); + + assertThat(makeValidityAssertion(expectedValues2[start]).apply( + BlockAssertions.getOnlyValue(outputType, block), + expectedValues2[start])) + .isTrue(); + } + } } diff --git a/core/trino-main/src/test/java/io/trino/operator/window/AbstractTestWindowFunction.java b/core/trino-main/src/test/java/io/trino/operator/window/AbstractTestWindowFunction.java index e5a2e8a9c2b5..d6908421b207 100644 --- a/core/trino-main/src/test/java/io/trino/operator/window/AbstractTestWindowFunction.java +++ b/core/trino-main/src/test/java/io/trino/operator/window/AbstractTestWindowFunction.java @@ -67,6 +67,16 @@ protected MaterializedResult executeWindowQueryWithNulls(@Language("SQL") String return WindowAssertions.executeWindowQueryWithNulls(sql, queryRunner); } + protected void assertWindowQueryWithNan(@Language("SQL") String sql, MaterializedResult expected) + { + WindowAssertions.assertWindowQueryWithNan(sql, expected, queryRunner); + } + + protected void assertWindowQueryWithInfinity(@Language("SQL") String sql, MaterializedResult expected) + { + WindowAssertions.assertWindowQueryWithInfinity(sql, expected, queryRunner); + } + protected void assertUnboundedWindowQueryWithNulls(@Language("SQL") String sql, MaterializedResult expected) { assertWindowQueryWithNulls(unbounded(sql), expected); diff --git a/core/trino-main/src/test/java/io/trino/operator/window/TestAggregateWindowFunction.java b/core/trino-main/src/test/java/io/trino/operator/window/TestAggregateWindowFunction.java index 96cc52ae26a6..03d623a68cdd 100644 --- a/core/trino-main/src/test/java/io/trino/operator/window/TestAggregateWindowFunction.java +++ b/core/trino-main/src/test/java/io/trino/operator/window/TestAggregateWindowFunction.java @@ -595,4 +595,68 @@ public void testSumAllNulls() .row(null, null, null) .build()); } + + @Test + public void testAverageRowsRollingWithNanAndInfinity() + { + assertWindowQueryWithNan("avg(orderkey) OVER (ORDER BY orderdate ROWS BETWEEN 2 PRECEDING AND CURRENT ROW)", + resultBuilder(TEST_SESSION, BIGINT, VARCHAR, DOUBLE) + .row(6.0, "F", 6.0) + .row(3.0, "F", 4.5) + .row(33.0, "F", 14.0) + .row(Double.NaN, "F", Double.NaN) + .row(32.0, "O", Double.NaN) + .row(4.0, "O", Double.NaN) + .row(1.0, "O", 12.333333333333334) + .row(7.0, "O", 4.0) + .row(2.0, "O", 3.3333333333333335) + .row(34.0, "O", 14.333333333333334) + .build()); + + assertWindowQueryWithInfinity("avg(orderkey) OVER (ORDER BY orderdate ROWS BETWEEN 2 PRECEDING AND CURRENT ROW)", + resultBuilder(TEST_SESSION, BIGINT, VARCHAR, DOUBLE) + .row(6.0, "F", 6.0) + .row(3.0, "F", 4.5) + .row(33.0, "F", 14.0) + .row(Double.POSITIVE_INFINITY, "F", Double.POSITIVE_INFINITY) + .row(32.0, "O", Double.POSITIVE_INFINITY) + .row(4.0, "O", Double.POSITIVE_INFINITY) + .row(1.0, "O", 12.333333333333334) + .row(7.0, "O", 4.0) + .row(2.0, "O", 3.3333333333333335) + .row(34.0, "O", 14.333333333333334) + .build()); + } + + @Test + public void testSumRowsRollingWithNanAndInfinity() + { + assertWindowQueryWithNan("sum(orderkey) OVER (ORDER BY orderdate ROWS BETWEEN 2 PRECEDING AND CURRENT ROW)", + resultBuilder(TEST_SESSION, BIGINT, VARCHAR, DOUBLE) + .row(6.0, "F", 6.0) + .row(3.0, "F", 9.0) + .row(33.0, "F", 42.0) + .row(Double.NaN, "F", Double.NaN) + .row(32.0, "O", Double.NaN) + .row(4.0, "O", Double.NaN) + .row(1.0, "O", 37.0) + .row(7.0, "O", 12.0) + .row(2.0, "O", 10.0) + .row(34.0, "O", 43.0) + .build()); + + assertWindowQueryWithInfinity("sum(orderkey) OVER (ORDER BY orderdate ROWS BETWEEN 2 PRECEDING AND CURRENT ROW)", + resultBuilder(TEST_SESSION, BIGINT, VARCHAR, DOUBLE) + .row(6.0, "F", 6.0) + .row(3.0, "F", 9.0) + .row(33.0, "F", 42.0) + .row(Double.POSITIVE_INFINITY, "F", Double.POSITIVE_INFINITY) + .row(32.0, "O", Double.POSITIVE_INFINITY) + .row(4.0, "O", Double.POSITIVE_INFINITY) + .row(1.0, "O", 37.0) + .row(7.0, "O", 12.0) + .row(2.0, "O", 10.0) + .row(34.0, "O", 43.0) + .build()); + } } diff --git a/core/trino-main/src/test/java/io/trino/operator/window/WindowAssertions.java b/core/trino-main/src/test/java/io/trino/operator/window/WindowAssertions.java index e0b1cd1c17c5..99fdabd15bb4 100644 --- a/core/trino-main/src/test/java/io/trino/operator/window/WindowAssertions.java +++ b/core/trino-main/src/test/java/io/trino/operator/window/WindowAssertions.java @@ -54,6 +54,38 @@ public final class WindowAssertions " (CAST(NULL AS BIGINT), CAST(NULL AS VARCHAR), '1995-07-16')\n" + ") AS orders (orderkey, orderstatus, orderdate)"; + private static final String VALUES_WITH_NAN = "" + + "SELECT *\n" + + "FROM (\n" + + " VALUES\n" + + " ( 1, 'O', '1996-01-02'),\n" + + " ( 2, 'O', '1996-12-01'),\n" + + " ( 3, 'F', '1993-10-14'),\n" + + " ( 4, 'O', '1995-10-11'),\n" + + " ( nan(), 'F', '1994-07-30'),\n" + + " ( 6, 'F', '1992-02-21'),\n" + + " ( 7, 'O', '1996-01-10'),\n" + + " (32, 'O', '1995-07-16'),\n" + + " (33, 'F', '1993-10-27'),\n" + + " (34, 'O', '1998-07-21')\n" + + ") AS orders (orderkey, orderstatus, orderdate)"; + + private static final String VALUES_WITH_INFINITY = "" + + "SELECT *\n" + + "FROM (\n" + + " VALUES\n" + + " ( 1, 'O', '1996-01-02'),\n" + + " ( 2, 'O', '1996-12-01'),\n" + + " ( 3, 'F', '1993-10-14'),\n" + + " ( 4, 'O', '1995-10-11'),\n" + + " ( infinity(), 'F', '1994-07-30'),\n" + + " ( 6, 'F', '1992-02-21'),\n" + + " ( 7, 'O', '1996-01-10'),\n" + + " (32, 'O', '1995-07-16'),\n" + + " (33, 'F', '1993-10-27'),\n" + + " (34, 'O', '1998-07-21')\n" + + ") AS orders (orderkey, orderstatus, orderdate)"; + private WindowAssertions() {} public static void assertWindowQuery(@Language("SQL") String sql, MaterializedResult expected, QueryRunner queryRunner) @@ -80,4 +112,24 @@ public static MaterializedResult executeWindowQueryWithNulls(@Language("SQL") St return queryRunner.execute(query); } + + public static void assertWindowQueryWithNan(@Language("SQL") String sql, MaterializedResult expected, QueryRunner queryRunner) + { + @Language("SQL") String query = format("" + + "SELECT orderkey, orderstatus,\n%s\n" + + "FROM (%s) x", sql, VALUES_WITH_NAN); + + MaterializedResult actual = queryRunner.execute(query); + assertEqualsIgnoreOrder(actual.getMaterializedRows(), expected.getMaterializedRows()); + } + + public static void assertWindowQueryWithInfinity(@Language("SQL") String sql, MaterializedResult expected, QueryRunner queryRunner) + { + @Language("SQL") String query = format("" + + "SELECT orderkey, orderstatus,\n%s\n" + + "FROM (%s) x", sql, VALUES_WITH_INFINITY); + + MaterializedResult actual = queryRunner.execute(query); + assertEqualsIgnoreOrder(actual.getMaterializedRows(), expected.getMaterializedRows()); + } }