diff --git a/polaris-service/src/test/java/io/polaris/service/catalog/PolarisSparkIntegrationTest.java b/polaris-service/src/test/java/io/polaris/service/catalog/PolarisSparkIntegrationTest.java index fbde8030b..7feec7d50 100644 --- a/polaris-service/src/test/java/io/polaris/service/catalog/PolarisSparkIntegrationTest.java +++ b/polaris-service/src/test/java/io/polaris/service/catalog/PolarisSparkIntegrationTest.java @@ -43,8 +43,10 @@ import java.util.Map; import org.apache.iceberg.rest.requests.ImmutableRegisterTableRequest; import org.apache.iceberg.rest.responses.LoadTableResponse; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.SparkSession; +import org.intellij.lang.annotations.Language; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeAll; @@ -179,7 +181,7 @@ public void before() { spark = withCatalog(withCatalog(sessionBuilder, CATALOG_NAME), EXTERNAL_CATALOG_NAME).getOrCreate(); - spark.sql("USE " + CATALOG_NAME); + onSpark("USE " + CATALOG_NAME); } private SparkSession.Builder withCatalog(SparkSession.Builder builder, String catalogName) { @@ -215,14 +217,14 @@ public void after() { } private void cleanupCatalog(String catalogName) { - spark.sql("USE " + catalogName); - List namespaces = spark.sql("SHOW NAMESPACES").collectAsList(); + onSpark("USE " + catalogName); + List namespaces = onSpark("SHOW NAMESPACES").collectAsList(); for (Row namespace : namespaces) { - List tables = spark.sql("SHOW TABLES IN " + namespace.getString(0)).collectAsList(); + List tables = onSpark("SHOW TABLES IN " + namespace.getString(0)).collectAsList(); for (Row table : tables) { - spark.sql("DROP TABLE " + namespace.getString(0) + "." + table.getString(1)); + onSpark("DROP TABLE " + namespace.getString(0) + "." + table.getString(1)); } - spark.sql("DROP NAMESPACE " + namespace.getString(0)); + onSpark("DROP NAMESPACE " + namespace.getString(0)); } try (Response response = EXT.client() @@ -240,36 +242,36 @@ private void cleanupCatalog(String catalogName) { @Test public void testCreateTable() { - long namespaceCount = spark.sql("SHOW NAMESPACES").count(); + long namespaceCount = onSpark("SHOW NAMESPACES").count(); assertThat(namespaceCount).isEqualTo(0L); - spark.sql("CREATE NAMESPACE ns1"); - spark.sql("USE ns1"); - spark.sql("CREATE TABLE tb1 (col1 integer, col2 string)"); - spark.sql("INSERT INTO tb1 VALUES (1, 'a'), (2, 'b'), (3, 'c')"); - long recordCount = spark.sql("SELECT * FROM tb1").count(); + onSpark("CREATE NAMESPACE ns1"); + onSpark("USE ns1"); + onSpark("CREATE TABLE tb1 (col1 integer, col2 string)"); + onSpark("INSERT INTO tb1 VALUES (1, 'a'), (2, 'b'), (3, 'c')"); + long recordCount = onSpark("SELECT * FROM tb1").count(); assertThat(recordCount).isEqualTo(3); } @Test public void testCreateAndUpdateExternalTable() { - long namespaceCount = spark.sql("SHOW NAMESPACES").count(); + long namespaceCount = onSpark("SHOW NAMESPACES").count(); assertThat(namespaceCount).isEqualTo(0L); - spark.sql("CREATE NAMESPACE ns1"); - spark.sql("USE ns1"); - spark.sql("CREATE TABLE tb1 (col1 integer, col2 string)"); - spark.sql("INSERT INTO tb1 VALUES (1, 'a'), (2, 'b'), (3, 'c')"); - long recordCount = spark.sql("SELECT * FROM tb1").count(); + onSpark("CREATE NAMESPACE ns1"); + onSpark("USE ns1"); + onSpark("CREATE TABLE tb1 (col1 integer, col2 string)"); + onSpark("INSERT INTO tb1 VALUES (1, 'a'), (2, 'b'), (3, 'c')"); + long recordCount = onSpark("SELECT * FROM tb1").count(); assertThat(recordCount).isEqualTo(3); - spark.sql("USE " + EXTERNAL_CATALOG_NAME); - List existingNamespaces = spark.sql("SHOW NAMESPACES").collectAsList(); + onSpark("USE " + EXTERNAL_CATALOG_NAME); + List existingNamespaces = onSpark("SHOW NAMESPACES").collectAsList(); assertThat(existingNamespaces).isEmpty(); - spark.sql("CREATE NAMESPACE externalns1"); - spark.sql("USE externalns1"); - List existingTables = spark.sql("SHOW TABLES").collectAsList(); + onSpark("CREATE NAMESPACE externalns1"); + onSpark("USE externalns1"); + List existingTables = onSpark("SHOW TABLES").collectAsList(); assertThat(existingTables).isEmpty(); LoadTableResponse tableResponse = loadTable(CATALOG_NAME, "ns1", "tb1"); @@ -293,16 +295,16 @@ public void testCreateAndUpdateExternalTable() { assertThat(registerResponse).returns(Response.Status.OK.getStatusCode(), Response::getStatus); } - long tableCount = spark.sql("SHOW TABLES").count(); + long tableCount = onSpark("SHOW TABLES").count(); assertThat(tableCount).isEqualTo(1); - List tables = spark.sql("SHOW TABLES").collectAsList(); + List tables = onSpark("SHOW TABLES").collectAsList(); assertThat(tables).hasSize(1).extracting(row -> row.getString(1)).containsExactly("mytb1"); - long rowCount = spark.sql("SELECT * FROM mytb1").count(); + long rowCount = onSpark("SELECT * FROM mytb1").count(); assertThat(rowCount).isEqualTo(3); - assertThatThrownBy(() -> spark.sql("INSERT INTO mytb1 VALUES (20, 'new_text')")) + assertThatThrownBy(() -> onSpark("INSERT INTO mytb1 VALUES (20, 'new_text')")) .isInstanceOf(Exception.class); - spark.sql("INSERT INTO " + CATALOG_NAME + ".ns1.tb1 VALUES (20, 'new_text')"); + onSpark("INSERT INTO " + CATALOG_NAME + ".ns1.tb1 VALUES (20, 'new_text')"); tableResponse = loadTable(CATALOG_NAME, "ns1", "tb1"); TableUpdateNotification updateNotification = new TableUpdateNotification( @@ -328,8 +330,8 @@ public void testCreateAndUpdateExternalTable() { .returns(Response.Status.NO_CONTENT.getStatusCode(), Response::getStatus); } // refresh the table so it queries for the latest metadata.json - spark.sql("REFRESH TABLE mytb1"); - rowCount = spark.sql("SELECT * FROM mytb1").count(); + onSpark("REFRESH TABLE mytb1"); + rowCount = onSpark("SELECT * FROM mytb1").count(); assertThat(rowCount).isEqualTo(4); } @@ -348,4 +350,8 @@ private LoadTableResponse loadTable(String catalog, String namespace, String tab return response.readEntity(LoadTableResponse.class); } } + + private static Dataset onSpark(@Language("SQL") String sql) { + return spark.sql(sql); + } }