Skip to content

Commit

Permalink
Handle Invalid Values in Window Aggregations
Browse files Browse the repository at this point in the history
This PR changes the return type of the removeInput method to boolean
so that NaN and infinite values can be handled in aggregation window functions.
  • Loading branch information
pgandhi999 authored and pettyjamesm committed Mar 13, 2024
1 parent cd01bb5 commit 99979fb
Show file tree
Hide file tree
Showing 17 changed files with 657 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import io.airlift.bytecode.control.IfStatement;
import io.airlift.bytecode.expression.BytecodeExpression;
import io.airlift.bytecode.expression.BytecodeExpressions;
import io.airlift.bytecode.instruction.LabelNode;
import io.trino.operator.window.InternalWindowIndex;
import io.trino.spi.Page;
import io.trino.spi.block.Block;
Expand Down Expand Up @@ -287,22 +288,20 @@ public static Constructor<? extends WindowAccumulator> generateWindowAccumulator

// Generate methods
generateCopy(definition, WindowAccumulator.class);
generateAddOrRemoveInputWindowIndex(
generateAddInputWindowIndex(
definition,
stateFields,
argumentNullable,
lambdaProviderFields,
implementation.getInputFunction(),
"addInput",
callSiteBinder);
implementation.getRemoveInputFunction().ifPresent(
removeInputFunction -> generateAddOrRemoveInputWindowIndex(
removeInputFunction -> generateRemoveInputWindowIndex(
definition,
stateFields,
argumentNullable,
lambdaProviderFields,
removeInputFunction,
"removeInput",
callSiteBinder));

generateEvaluateFinal(definition, stateFields, implementation.getOutputFunction(), callSiteBinder);
Expand Down Expand Up @@ -417,13 +416,12 @@ private static void generateAddInput(
body.ret();
}

private static void generateAddOrRemoveInputWindowIndex(
private static void generateAddInputWindowIndex(
ClassDefinition definition,
List<FieldDefinition> stateField,
List<Boolean> argumentNullable,
List<FieldDefinition> lambdaProviderFields,
MethodHandle inputFunction,
String generatedFunctionName,
CallSiteBinder callSiteBinder)
{
// TODO: implement masking based on maskChannel field once Window Functions support DISTINCT arguments to the functions.
Expand All @@ -434,7 +432,7 @@ private static void generateAddOrRemoveInputWindowIndex(

MethodDefinition method = definition.declareMethod(
a(PUBLIC),
generatedFunctionName,
"addInput",
type(void.class),
ImmutableList.of(index, startPosition, endPosition));
Scope scope = method.getScope();
Expand Down Expand Up @@ -462,7 +460,7 @@ private static void generateAddOrRemoveInputWindowIndex(
invokeInputFunction.append(invokeDynamic(
BOOTSTRAP_METHOD,
ImmutableList.of(binding.getBindingId()),
generatedFunctionName,
"addInput",
binding.getType(),
getInvokeFunctionOnWindowIndexParameters(
scope.getThis(),
Expand All @@ -481,6 +479,75 @@ private static void generateAddOrRemoveInputWindowIndex(
.ret();
}

/**
 * Generates the public {@code boolean removeInput(WindowIndex index, int startPosition, int endPosition)}
 * method on the accumulator class being defined.
 *
 * <p>The generated method walks every position from {@code startPosition} to {@code endPosition}
 * (both inclusive). For each position where the {@code anyParametersAreNull} check is false, it
 * invokes the bound remove function. As soon as one invocation returns {@code false}, the generated
 * method jumps out of the loop and returns {@code false} without visiting the remaining positions;
 * if the loop runs to completion it returns {@code true}.
 *
 * @param definition class definition that receives the generated method
 * @param stateField accumulator state fields passed on to the remove function invocation
 * @param argumentNullable one entry per aggregation argument; used for the null check and to size the
 *        list of per-argument input block variables
 * @param lambdaProviderFields lambda provider fields passed on to the remove function invocation
 * @param removeFunction method handle of the remove-input function; its result is tested as a boolean
 * @param callSiteBinder binder used to bind {@code removeFunction} for the {@code invokeDynamic} call
 */
private static void generateRemoveInputWindowIndex(
ClassDefinition definition,
List<FieldDefinition> stateField,
List<Boolean> argumentNullable,
List<FieldDefinition> lambdaProviderFields,
MethodHandle removeFunction,
CallSiteBinder callSiteBinder)
{
// TODO: implement masking based on maskChannel field once Window Functions support DISTINCT arguments to the functions.

// Parameters of the generated removeInput method.
Parameter index = arg("index", WindowIndex.class);
Parameter startPosition = arg("startPosition", int.class);
Parameter endPosition = arg("endPosition", int.class);

MethodDefinition method = definition.declareMethod(
a(PUBLIC),
"removeInput",
type(boolean.class),
ImmutableList.of(index, startPosition, endPosition));
Scope scope = method.getScope();
BytecodeBlock body = method.getBody();

// Loop counter over the window index positions in [startPosition, endPosition].
Variable position = scope.declareVariable(int.class, "position");

// input parameters
// One raw block position plus one raw block variable per aggregation argument.
Variable inputBlockPosition = scope.declareVariable(int.class, "inputBlockPosition");
List<Variable> inputBlockVariables = new ArrayList<>();
for (int i = 0; i < argumentNullable.size(); i++) {
inputBlockVariables.add(scope.declareVariable(Block.class, "inputBlock" + i));
}

Binding binding = callSiteBinder.bind(removeFunction);
// Loop body fragment: resolve the raw blocks for the current position, call the remove function,
// and bail out to the returnFalse label if it reports failure.
BytecodeBlock invokeRemoveFunction = new BytecodeBlock();
// WindowIndex is built on PagesIndex, which simply wraps Blocks
// and currently does not understand ValueBlocks.
// Until PagesIndex is updated to understand ValueBlocks, the
// input function parameters must be directly unwrapped to ValueBlocks.
invokeRemoveFunction.append(inputBlockPosition.set(index.cast(InternalWindowIndex.class).invoke("getRawBlockPosition", int.class, position)));
for (int i = 0; i < inputBlockVariables.size(); i++) {
invokeRemoveFunction.append(inputBlockVariables.get(i).set(index.cast(InternalWindowIndex.class).invoke("getRawBlock", Block.class, constantInt(i), position)));
}
// Target of the early exit; the label itself is placed after the loop's normal "return true" path.
LabelNode returnFalse = new LabelNode("returnFalse");
invokeRemoveFunction.append(invokeDynamic(
BOOTSTRAP_METHOD,
ImmutableList.of(binding.getBindingId()),
"removeInput",
binding.getType(),
getInvokeFunctionOnWindowIndexParameters(
scope.getThis(),
stateField,
inputBlockPosition,
inputBlockVariables,
lambdaProviderFields)));
// The remove function's boolean result is tested here: false jumps straight to returnFalse.
invokeRemoveFunction.ifFalseGoto(returnFalse);

// for (position = startPosition; position <= endPosition; position++) — endPosition is inclusive.
// The remove function is only invoked when the anyParametersAreNull condition is false.
body.append(new ForLoop()
.initialize(position.set(startPosition))
.condition(BytecodeExpressions.lessThanOrEqual(position, endPosition))
.update(position.increment())
.body(new IfStatement()
.condition(anyParametersAreNull(argumentNullable, index, position))
.ifFalse(invokeRemoveFunction)))
// Loop finished without any failed removal: return true.
.push(true)
.retBoolean()
// Reached only via the ifFalseGoto above: return false.
.visitLabel(returnFalse)
.push(false)
.retBoolean();
}

private static BytecodeExpression anyParametersAreNull(
List<Boolean> argumentNullable,
Variable index,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,17 +46,23 @@ public static void input(@AggregationState LongAndDoubleState state, @SqlType(St
}

@RemoveInputFunction
public static void removeInput(@AggregationState LongAndDoubleState state, @SqlType(StandardTypes.BIGINT) long value)
public static boolean removeInput(@AggregationState LongAndDoubleState state, @SqlType(StandardTypes.BIGINT) long value)
{
state.setLong(state.getLong() - 1);
state.setDouble(state.getDouble() - value);
return true;
}

@RemoveInputFunction
public static void removeInput(@AggregationState LongAndDoubleState state, @SqlType(StandardTypes.DOUBLE) double value)
public static boolean removeInput(@AggregationState LongAndDoubleState state, @SqlType(StandardTypes.DOUBLE) double value)
{
state.setLong(state.getLong() - 1);
state.setDouble(state.getDouble() - value);
double currentValue = state.getDouble();
if (Double.isFinite(currentValue)) {
state.setDouble(currentValue - value);
state.setLong(state.getLong() - 1);
return true;
}
return false;
}

@CombineFunction
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,10 @@ public static void input(@AggregationState LongState state)
}

@RemoveInputFunction
public static void removeInput(@AggregationState LongState state)
public static boolean removeInput(@AggregationState LongState state)
{
state.setValue(state.getValue() - 1);
return true;
}

@CombineFunction
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,13 @@ public static void input(
}

@RemoveInputFunction
public static void removeInput(
public static boolean removeInput(
@AggregationState LongState state,
@BlockPosition @SqlType("T") ValueBlock block,
@BlockIndex int position)
{
state.setValue(state.getValue() - 1);
return true;
}

@CombineFunction
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,12 @@ public static void input(@AggregationState LongState state, @SqlType(StandardTyp
}

@RemoveInputFunction
public static void removeInput(@AggregationState LongState state, @SqlType(StandardTypes.BOOLEAN) boolean value)
public static boolean removeInput(@AggregationState LongState state, @SqlType(StandardTypes.BOOLEAN) boolean value)
{
if (value) {
state.setValue(state.getValue() - 1);
}
return true;
}

@CombineFunction
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,15 @@ public static void sum(@AggregationState LongAndDoubleState state, @SqlType(Stan
}

@RemoveInputFunction
public static void removeInput(@AggregationState LongAndDoubleState state, @SqlType(StandardTypes.DOUBLE) double value)
public static boolean removeInput(@AggregationState LongAndDoubleState state, @SqlType(StandardTypes.DOUBLE) double value)
{
state.setLong(state.getLong() - 1);
state.setDouble(state.getDouble() - value);
double currentValue = state.getDouble();
if (Double.isFinite(currentValue)) {
state.setDouble(currentValue - value);
state.setLong(state.getLong() - 1);
return true;
}
return false;
}

@CombineFunction
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,11 @@ public static void sum(@AggregationState LongLongState state, @SqlType(StandardT
}

@RemoveInputFunction
public static void removeInput(@AggregationState LongLongState state, @SqlType(StandardTypes.BIGINT) long value)
public static boolean removeInput(@AggregationState LongLongState state, @SqlType(StandardTypes.BIGINT) long value)
{
state.setFirst(state.getFirst() - 1);
state.setSecond(BigintOperators.subtract(state.getSecond(), value));
return true; // This should always return true as the state cannot be infinite or NaN for long input values.
}

@CombineFunction
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,18 @@ public static void input(
}

@RemoveInputFunction
public static void removeInput(
public static boolean removeInput(
@AggregationState LongState count,
@AggregationState DoubleState sum,
@SqlType("REAL") long value)
{
count.setValue(count.getValue() - 1);
sum.setValue(sum.getValue() - intBitsToFloat((int) value));
double currentValue = sum.getValue();
if (Double.isFinite(currentValue)) {
sum.setValue(currentValue - intBitsToFloat((int) value));
count.setValue(count.getValue() - 1);
return true;
}
return false;
}

@CombineFunction
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,11 @@ public interface WindowAccumulator

void addInput(WindowIndex index, int startPosition, int endPosition);

void removeInput(WindowIndex index, int startPosition, int endPosition);
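/**
* @return {@code false} when the removal could not be applied because a NaN or
* infinite double value was encountered, in which case the caller is expected to
* reset the accumulator and rebuild the frame; {@code true} otherwise.
*/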
boolean removeInput(WindowIndex index, int startPosition, int endPosition);

void evaluateFinal(BlockBuilder blockBuilder);
}
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,13 @@ private void buildNewFrame(int frameStart, int frameEnd)

if ((overlapEnd - overlapStart + 1) > (prefixRemoveLength + suffixRemoveLength)) {
// It's worth keeping the overlap, and removing the now-unused prefix
if (currentStart < frameStart) {
remove(currentStart, frameStart - 1);
if (currentStart < frameStart && !remove(currentStart, frameStart - 1)) {
resetNewFrame(frameStart, frameEnd);
return;
}
if (frameEnd < currentEnd) {
remove(frameEnd + 1, currentEnd);
if (frameEnd < currentEnd && !remove(frameEnd + 1, currentEnd)) {
resetNewFrame(frameStart, frameEnd);
return;
}
if (frameStart < currentStart) {
accumulate(frameStart, currentStart - 1);
Expand All @@ -101,6 +103,11 @@ private void buildNewFrame(int frameStart, int frameEnd)
}

// We couldn't or didn't want to modify the accumulation: instead, discard the current accumulation and start fresh.
resetNewFrame(frameStart, frameEnd);
}

private void resetNewFrame(int frameStart, int frameEnd)
{
resetAccumulator();
accumulate(frameStart, frameEnd);
currentStart = frameStart;
Expand All @@ -112,9 +119,9 @@ private void accumulate(int start, int end)
accumulator.addInput(windowIndex, start, end);
}

private void remove(int start, int end)
private boolean remove(int start, int end)
{
accumulator.removeInput(windowIndex, start, end);
return accumulator.removeInput(windowIndex, start, end);
}

private void resetAccumulator()
Expand Down
Loading

0 comments on commit 99979fb

Please sign in to comment.