diff --git a/core/trino-main/src/main/java/io/trino/sql/gen/DereferenceCodeGenerator.java b/core/trino-main/src/main/java/io/trino/sql/gen/DereferenceCodeGenerator.java index 6f5736356791..8770e094926d 100644 --- a/core/trino-main/src/main/java/io/trino/sql/gen/DereferenceCodeGenerator.java +++ b/core/trino-main/src/main/java/io/trino/sql/gen/DereferenceCodeGenerator.java @@ -29,6 +29,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static io.airlift.bytecode.expression.BytecodeExpressions.constantInt; import static io.trino.sql.gen.SqlTypeBytecodeExpression.constantType; +import static java.lang.Math.toIntExact; import static java.util.Objects.requireNonNull; public class DereferenceCodeGenerator @@ -44,7 +45,7 @@ public DereferenceCodeGenerator(SpecialForm specialForm) returnType = specialForm.getType(); checkArgument(specialForm.getArguments().size() == 2); base = specialForm.getArguments().get(0); - index = (int) ((ConstantExpression) specialForm.getArguments().get(1)).getValue(); + index = toIntExact((long) ((ConstantExpression) specialForm.getArguments().get(1)).getValue()); } @Override diff --git a/core/trino-main/src/main/java/io/trino/sql/relational/ConstantExpression.java b/core/trino-main/src/main/java/io/trino/sql/relational/ConstantExpression.java index 9cbf9fd8c824..fa7f80412064 100644 --- a/core/trino-main/src/main/java/io/trino/sql/relational/ConstantExpression.java +++ b/core/trino-main/src/main/java/io/trino/sql/relational/ConstantExpression.java @@ -15,6 +15,7 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.primitives.Primitives; import com.google.errorprone.annotations.DoNotCall; import io.airlift.slice.Slice; import io.trino.spi.block.Block; @@ -48,6 +49,13 @@ public static ConstantExpression fromJson( public ConstantExpression(Object value, Type type) { requireNonNull(type, "type is null"); + if (value != null && !Primitives.wrap(type.getJavaType()).isInstance(value)) { + throw new IllegalArgumentException("Invalid value %s of Java type %s for Trino type %s, expected instance of %s".formatted( + value, + value.getClass(), + type, + type.getJavaType())); + } this.value = value; this.type = type; diff --git a/core/trino-main/src/main/java/io/trino/sql/relational/SqlToRowExpressionTranslator.java b/core/trino-main/src/main/java/io/trino/sql/relational/SqlToRowExpressionTranslator.java index 8fd3e2efd5b8..dfedb6720c67 100644 --- a/core/trino-main/src/main/java/io/trino/sql/relational/SqlToRowExpressionTranslator.java +++ b/core/trino-main/src/main/java/io/trino/sql/relational/SqlToRowExpressionTranslator.java @@ -702,7 +702,7 @@ protected RowExpression visitSubscriptExpression(SubscriptExpression node, Void if (getType(node.getBase()) instanceof RowType) { long value = (Long) ((ConstantExpression) index).getValue(); - return new SpecialForm(DEREFERENCE, getType(node), base, constant((int) value - 1, INTEGER)); + return new SpecialForm(DEREFERENCE, getType(node), base, constant(value - 1, INTEGER)); } return call( diff --git a/core/trino-main/src/test/java/io/trino/sql/gen/TestInCodeGenerator.java b/core/trino-main/src/test/java/io/trino/sql/gen/TestInCodeGenerator.java index 2fcaa99882b5..9bfb5af90025 100644 --- a/core/trino-main/src/test/java/io/trino/sql/gen/TestInCodeGenerator.java +++ b/core/trino-main/src/test/java/io/trino/sql/gen/TestInCodeGenerator.java @@ -43,9 +43,9 @@ public class TestInCodeGenerator public void testInteger() { List values = new ArrayList<>(); - values.add(constant(Integer.MIN_VALUE, INTEGER)); - values.add(constant(Integer.MAX_VALUE, INTEGER)); - values.add(constant(3, INTEGER)); + values.add(constant((long) Integer.MIN_VALUE, INTEGER)); + values.add(constant((long) Integer.MAX_VALUE, INTEGER)); + values.add(constant(3L, INTEGER)); assertThat(checkSwitchGenerationCase(INTEGER, values)).isEqualTo(DIRECT_SWITCH); values.add(constant(null, INTEGER)); @@ -55,11 +55,11 @@ public void testInteger() Collections.singletonList(constant(12345678901234.0, DOUBLE)))); assertThat(checkSwitchGenerationCase(INTEGER, values)).isEqualTo(DIRECT_SWITCH); - values.add(constant(6, BIGINT)); - values.add(constant(7, BIGINT)); + values.add(constant(6L, BIGINT)); + values.add(constant(7L, BIGINT)); assertThat(checkSwitchGenerationCase(INTEGER, values)).isEqualTo(DIRECT_SWITCH); - values.add(constant(8, INTEGER)); + values.add(constant(8L, INTEGER)); assertThat(checkSwitchGenerationCase(INTEGER, values)).isEqualTo(SET_CONTAINS); } @@ -130,9 +130,9 @@ public void testDouble() public void testVarchar() { List values = new ArrayList<>(); - values.add(constant(Slices.utf8Slice("1"), DOUBLE)); - values.add(constant(Slices.utf8Slice("2"), DOUBLE)); - values.add(constant(Slices.utf8Slice("3"), DOUBLE)); + values.add(constant(Slices.utf8Slice("1"), VARCHAR)); + values.add(constant(Slices.utf8Slice("2"), VARCHAR)); + values.add(constant(Slices.utf8Slice("3"), VARCHAR)); assertThat(checkSwitchGenerationCase(VARCHAR, values)).isEqualTo(HASH_SWITCH); values.add(constant(null, VARCHAR)); diff --git a/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestEngineOnlyQueries.java b/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestEngineOnlyQueries.java index e3699a15e98f..734cd082242f 100644 --- a/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestEngineOnlyQueries.java +++ b/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestEngineOnlyQueries.java @@ -6642,6 +6642,15 @@ SELECT my_pow(2, 8) """)) .matches("VALUES 256"); + // function with dereference + assertThat(query(""" + WITH FUNCTION get(input row(varchar)) + RETURNS varchar + RETURN input[1] + SELECT get(ROW('abc')) + """)) + .matches("VALUES VARCHAR 'abc'"); + // validations for inline functions assertQueryFails("WITH FUNCTION a.b() RETURNS int RETURN 42 SELECT a.b()", "line 1:6: Inline function names cannot be qualified: a.b");