From 05404daab0bae2dfa4de4c899824374ef59a54a6 Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Wed, 20 Dec 2023 17:05:14 +0100 Subject: [PATCH] Fix ConstantExpression serialization When `ConstantExpression` was created with a Trino type and a value that didn't match this type's java type (for example Trino INTEGER with value being Integer, instead of Long), the `ConstantExpression` would fail to serialize to JSON. Such failure happens during sending task status updates to workers and is currently logged and ignored, leading to query hang. The problem could be triggered with SQL routines, which utilize `RowExpression` (including `ConstantExpression`) serialization. Thus, this fixes execution of SQL routines involving Row field dereference. --- .../sql/gen/DereferenceCodeGenerator.java | 3 ++- .../sql/relational/ConstantExpression.java | 8 ++++++++ .../SqlToRowExpressionTranslator.java | 2 +- .../io/trino/sql/gen/TestInCodeGenerator.java | 18 +++++++++--------- .../testing/AbstractTestEngineOnlyQueries.java | 9 +++++++++ 5 files changed, 29 insertions(+), 11 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/sql/gen/DereferenceCodeGenerator.java b/core/trino-main/src/main/java/io/trino/sql/gen/DereferenceCodeGenerator.java index 6f5736356791..8770e094926d 100644 --- a/core/trino-main/src/main/java/io/trino/sql/gen/DereferenceCodeGenerator.java +++ b/core/trino-main/src/main/java/io/trino/sql/gen/DereferenceCodeGenerator.java @@ -29,6 +29,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static io.airlift.bytecode.expression.BytecodeExpressions.constantInt; import static io.trino.sql.gen.SqlTypeBytecodeExpression.constantType; +import static java.lang.Math.toIntExact; import static java.util.Objects.requireNonNull; public class DereferenceCodeGenerator @@ -44,7 +45,7 @@ public DereferenceCodeGenerator(SpecialForm specialForm) returnType = specialForm.getType(); checkArgument(specialForm.getArguments().size() == 2); base = specialForm.getArguments().get(0); - index = (int) ((ConstantExpression) specialForm.getArguments().get(1)).getValue(); + index = toIntExact((long) ((ConstantExpression) specialForm.getArguments().get(1)).getValue()); } @Override diff --git a/core/trino-main/src/main/java/io/trino/sql/relational/ConstantExpression.java b/core/trino-main/src/main/java/io/trino/sql/relational/ConstantExpression.java index 9cbf9fd8c824..fa7f80412064 100644 --- a/core/trino-main/src/main/java/io/trino/sql/relational/ConstantExpression.java +++ b/core/trino-main/src/main/java/io/trino/sql/relational/ConstantExpression.java @@ -15,6 +15,7 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.primitives.Primitives; import com.google.errorprone.annotations.DoNotCall; import io.airlift.slice.Slice; import io.trino.spi.block.Block; @@ -48,6 +49,13 @@ public static ConstantExpression fromJson( public ConstantExpression(Object value, Type type) { requireNonNull(type, "type is null"); + if (value != null && !Primitives.wrap(type.getJavaType()).isInstance(value)) { + throw new IllegalArgumentException("Invalid value %s of Java type %s for Trino type %s, expected instance of %s".formatted( + value, + value.getClass(), + type, + type.getJavaType())); + } this.value = value; this.type = type; diff --git a/core/trino-main/src/main/java/io/trino/sql/relational/SqlToRowExpressionTranslator.java b/core/trino-main/src/main/java/io/trino/sql/relational/SqlToRowExpressionTranslator.java index 8fd3e2efd5b8..dfedb6720c67 100644 --- a/core/trino-main/src/main/java/io/trino/sql/relational/SqlToRowExpressionTranslator.java +++ b/core/trino-main/src/main/java/io/trino/sql/relational/SqlToRowExpressionTranslator.java @@ -702,7 +702,7 @@ protected RowExpression visitSubscriptExpression(SubscriptExpression node, Void if (getType(node.getBase()) instanceof RowType) { long value = (Long) ((ConstantExpression) index).getValue(); - return new SpecialForm(DEREFERENCE, getType(node), base, constant((int) value - 1, INTEGER)); + return new SpecialForm(DEREFERENCE, getType(node), base, constant(value - 1, INTEGER)); } return call( diff --git a/core/trino-main/src/test/java/io/trino/sql/gen/TestInCodeGenerator.java b/core/trino-main/src/test/java/io/trino/sql/gen/TestInCodeGenerator.java index 2fcaa99882b5..9bfb5af90025 100644 --- a/core/trino-main/src/test/java/io/trino/sql/gen/TestInCodeGenerator.java +++ b/core/trino-main/src/test/java/io/trino/sql/gen/TestInCodeGenerator.java @@ -43,9 +43,9 @@ public class TestInCodeGenerator public void testInteger() { List values = new ArrayList<>(); - values.add(constant(Integer.MIN_VALUE, INTEGER)); - values.add(constant(Integer.MAX_VALUE, INTEGER)); - values.add(constant(3, INTEGER)); + values.add(constant((long) Integer.MIN_VALUE, INTEGER)); + values.add(constant((long) Integer.MAX_VALUE, INTEGER)); + values.add(constant(3L, INTEGER)); assertThat(checkSwitchGenerationCase(INTEGER, values)).isEqualTo(DIRECT_SWITCH); values.add(constant(null, INTEGER)); @@ -55,11 +55,11 @@ public void testInteger() Collections.singletonList(constant(12345678901234.0, DOUBLE)))); assertThat(checkSwitchGenerationCase(INTEGER, values)).isEqualTo(DIRECT_SWITCH); - values.add(constant(6, BIGINT)); - values.add(constant(7, BIGINT)); + values.add(constant(6L, BIGINT)); + values.add(constant(7L, BIGINT)); assertThat(checkSwitchGenerationCase(INTEGER, values)).isEqualTo(DIRECT_SWITCH); - values.add(constant(8, INTEGER)); + values.add(constant(8L, INTEGER)); assertThat(checkSwitchGenerationCase(INTEGER, values)).isEqualTo(SET_CONTAINS); } @@ -130,9 +130,9 @@ public void testDouble() public void testVarchar() { List values = new ArrayList<>(); - values.add(constant(Slices.utf8Slice("1"), DOUBLE)); - values.add(constant(Slices.utf8Slice("2"), DOUBLE)); - values.add(constant(Slices.utf8Slice("3"), DOUBLE)); + values.add(constant(Slices.utf8Slice("1"), VARCHAR)); + values.add(constant(Slices.utf8Slice("2"), VARCHAR)); + values.add(constant(Slices.utf8Slice("3"), VARCHAR)); assertThat(checkSwitchGenerationCase(VARCHAR, values)).isEqualTo(HASH_SWITCH); values.add(constant(null, VARCHAR)); diff --git a/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestEngineOnlyQueries.java b/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestEngineOnlyQueries.java index e3699a15e98f..734cd082242f 100644 --- a/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestEngineOnlyQueries.java +++ b/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestEngineOnlyQueries.java @@ -6642,6 +6642,15 @@ SELECT my_pow(2, 8) """)) .matches("VALUES 256"); + // function with dereference + assertThat(query(""" + WITH FUNCTION get(input row(varchar)) + RETURNS varchar + RETURN input[1] + SELECT get(ROW('abc')) + """)) + .matches("VALUES VARCHAR 'abc'"); + // validations for inline functions assertQueryFails("WITH FUNCTION a.b() RETURNS int RETURN 42 SELECT a.b()", "line 1:6: Inline function names cannot be qualified: a.b");