Skip to content

Commit

Permalink
Fix ConstantExpression serialization
Browse files Browse the repository at this point in the history
When `ConstantExpression` was created with a Trino type and a value
that didn't match this type's java type (for example Trino INTEGER with
value being Integer, instead of Long), the `ConstantExpression` would
fail to serialize to JSON. Such failure happens during sending task
status updates to workers and is currently logged and ignored, leading
to query hang. The problem could be triggered with SQL routines, which
utilize `RowExpression` (including `ConstantExpression`) serialization.

Thus, this fixes execution of SQL routines involving Row field
dereference.
  • Loading branch information
findepi committed Dec 21, 2023
1 parent 40117ea commit e567dbf
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import static com.google.common.base.Preconditions.checkArgument;
import static io.airlift.bytecode.expression.BytecodeExpressions.constantInt;
import static io.trino.sql.gen.SqlTypeBytecodeExpression.constantType;
import static java.lang.Math.toIntExact;
import static java.util.Objects.requireNonNull;

public class DereferenceCodeGenerator
Expand All @@ -44,7 +45,7 @@ public DereferenceCodeGenerator(SpecialForm specialForm)
returnType = specialForm.getType();
checkArgument(specialForm.getArguments().size() == 2);
base = specialForm.getArguments().get(0);
index = (int) ((ConstantExpression) specialForm.getArguments().get(1)).getValue();
index = toIntExact((long) ((ConstantExpression) specialForm.getArguments().get(1)).getValue());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.primitives.Primitives;
import com.google.errorprone.annotations.DoNotCall;
import io.airlift.slice.Slice;
import io.trino.spi.block.Block;
Expand Down Expand Up @@ -48,6 +49,13 @@ public static ConstantExpression fromJson(
public ConstantExpression(Object value, Type type)
{
requireNonNull(type, "type is null");
if (value != null && !Primitives.wrap(type.getJavaType()).isInstance(value)) {
throw new IllegalArgumentException("Invalid value %s of Java type %s for Trino type %s, expected instance of %s".formatted(
value,
value.getClass(),
type,
type.getJavaType()));
}

this.value = value;
this.type = type;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -702,7 +702,7 @@ protected RowExpression visitSubscriptExpression(SubscriptExpression node, Void

if (getType(node.getBase()) instanceof RowType) {
long value = (Long) ((ConstantExpression) index).getValue();
return new SpecialForm(DEREFERENCE, getType(node), base, constant((int) value - 1, INTEGER));
return new SpecialForm(DEREFERENCE, getType(node), base, constant(value - 1, INTEGER));
}

return call(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ public class TestInCodeGenerator
public void testInteger()
{
List<RowExpression> values = new ArrayList<>();
values.add(constant(Integer.MIN_VALUE, INTEGER));
values.add(constant(Integer.MAX_VALUE, INTEGER));
values.add(constant(3, INTEGER));
values.add(constant((long) Integer.MIN_VALUE, INTEGER));
values.add(constant((long) Integer.MAX_VALUE, INTEGER));
values.add(constant(3L, INTEGER));
assertThat(checkSwitchGenerationCase(INTEGER, values)).isEqualTo(DIRECT_SWITCH);

values.add(constant(null, INTEGER));
Expand All @@ -55,11 +55,11 @@ public void testInteger()
Collections.singletonList(constant(12345678901234.0, DOUBLE))));
assertThat(checkSwitchGenerationCase(INTEGER, values)).isEqualTo(DIRECT_SWITCH);

values.add(constant(6, BIGINT));
values.add(constant(7, BIGINT));
values.add(constant(6L, BIGINT));
values.add(constant(7L, BIGINT));
assertThat(checkSwitchGenerationCase(INTEGER, values)).isEqualTo(DIRECT_SWITCH);

values.add(constant(8, INTEGER));
values.add(constant(8L, INTEGER));
assertThat(checkSwitchGenerationCase(INTEGER, values)).isEqualTo(SET_CONTAINS);
}

Expand Down Expand Up @@ -130,9 +130,9 @@ public void testDouble()
public void testVarchar()
{
List<RowExpression> values = new ArrayList<>();
values.add(constant(Slices.utf8Slice("1"), DOUBLE));
values.add(constant(Slices.utf8Slice("2"), DOUBLE));
values.add(constant(Slices.utf8Slice("3"), DOUBLE));
values.add(constant(Slices.utf8Slice("1"), VARCHAR));
values.add(constant(Slices.utf8Slice("2"), VARCHAR));
values.add(constant(Slices.utf8Slice("3"), VARCHAR));
assertThat(checkSwitchGenerationCase(VARCHAR, values)).isEqualTo(HASH_SWITCH);

values.add(constant(null, VARCHAR));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6642,6 +6642,15 @@ SELECT my_pow(2, 8)
"""))
.matches("VALUES 256");

// function with dereference
assertThat(query("""
WITH FUNCTION get(input row(varchar))
RETURNS varchar
RETURN input[1]
SELECT get(ROW('abc'))
"""))
.matches("VALUES VARCHAR 'abc'");

// validations for inline functions
assertQueryFails("WITH FUNCTION a.b() RETURNS int RETURN 42 SELECT a.b()",
"line 1:6: Inline function names cannot be qualified: a.b");
Expand Down

0 comments on commit e567dbf

Please sign in to comment.