Skip to content

Commit

Permalink
Extract a helper method with @Language annotation
Browse files Browse the repository at this point in the history
  • Loading branch information
ebyhr committed Aug 9, 2024
1 parent 46dd56f commit 7541ffe
Showing 1 changed file with 36 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,10 @@
import java.util.Map;
import org.apache.iceberg.rest.requests.ImmutableRegisterTableRequest;
import org.apache.iceberg.rest.responses.LoadTableResponse;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.intellij.lang.annotations.Language;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll;
Expand Down Expand Up @@ -179,7 +181,7 @@ public void before() {
spark =
withCatalog(withCatalog(sessionBuilder, CATALOG_NAME), EXTERNAL_CATALOG_NAME).getOrCreate();

spark.sql("USE " + CATALOG_NAME);
onSpark("USE " + CATALOG_NAME);
}

private SparkSession.Builder withCatalog(SparkSession.Builder builder, String catalogName) {
Expand Down Expand Up @@ -215,14 +217,14 @@ public void after() {
}

private void cleanupCatalog(String catalogName) {
spark.sql("USE " + catalogName);
List<Row> namespaces = spark.sql("SHOW NAMESPACES").collectAsList();
onSpark("USE " + catalogName);
List<Row> namespaces = onSpark("SHOW NAMESPACES").collectAsList();
for (Row namespace : namespaces) {
List<Row> tables = spark.sql("SHOW TABLES IN " + namespace.getString(0)).collectAsList();
List<Row> tables = onSpark("SHOW TABLES IN " + namespace.getString(0)).collectAsList();
for (Row table : tables) {
spark.sql("DROP TABLE " + namespace.getString(0) + "." + table.getString(1));
onSpark("DROP TABLE " + namespace.getString(0) + "." + table.getString(1));
}
spark.sql("DROP NAMESPACE " + namespace.getString(0));
onSpark("DROP NAMESPACE " + namespace.getString(0));
}
try (Response response =
EXT.client()
Expand All @@ -240,36 +242,36 @@ private void cleanupCatalog(String catalogName) {

@Test
public void testCreateTable() {
long namespaceCount = spark.sql("SHOW NAMESPACES").count();
long namespaceCount = onSpark("SHOW NAMESPACES").count();
assertThat(namespaceCount).isEqualTo(0L);

spark.sql("CREATE NAMESPACE ns1");
spark.sql("USE ns1");
spark.sql("CREATE TABLE tb1 (col1 integer, col2 string)");
spark.sql("INSERT INTO tb1 VALUES (1, 'a'), (2, 'b'), (3, 'c')");
long recordCount = spark.sql("SELECT * FROM tb1").count();
onSpark("CREATE NAMESPACE ns1");
onSpark("USE ns1");
onSpark("CREATE TABLE tb1 (col1 integer, col2 string)");
onSpark("INSERT INTO tb1 VALUES (1, 'a'), (2, 'b'), (3, 'c')");
long recordCount = onSpark("SELECT * FROM tb1").count();
assertThat(recordCount).isEqualTo(3);
}

@Test
public void testCreateAndUpdateExternalTable() {
long namespaceCount = spark.sql("SHOW NAMESPACES").count();
long namespaceCount = onSpark("SHOW NAMESPACES").count();
assertThat(namespaceCount).isEqualTo(0L);

spark.sql("CREATE NAMESPACE ns1");
spark.sql("USE ns1");
spark.sql("CREATE TABLE tb1 (col1 integer, col2 string)");
spark.sql("INSERT INTO tb1 VALUES (1, 'a'), (2, 'b'), (3, 'c')");
long recordCount = spark.sql("SELECT * FROM tb1").count();
onSpark("CREATE NAMESPACE ns1");
onSpark("USE ns1");
onSpark("CREATE TABLE tb1 (col1 integer, col2 string)");
onSpark("INSERT INTO tb1 VALUES (1, 'a'), (2, 'b'), (3, 'c')");
long recordCount = onSpark("SELECT * FROM tb1").count();
assertThat(recordCount).isEqualTo(3);

spark.sql("USE " + EXTERNAL_CATALOG_NAME);
List<Row> existingNamespaces = spark.sql("SHOW NAMESPACES").collectAsList();
onSpark("USE " + EXTERNAL_CATALOG_NAME);
List<Row> existingNamespaces = onSpark("SHOW NAMESPACES").collectAsList();
assertThat(existingNamespaces).isEmpty();

spark.sql("CREATE NAMESPACE externalns1");
spark.sql("USE externalns1");
List<Row> existingTables = spark.sql("SHOW TABLES").collectAsList();
onSpark("CREATE NAMESPACE externalns1");
onSpark("USE externalns1");
List<Row> existingTables = onSpark("SHOW TABLES").collectAsList();
assertThat(existingTables).isEmpty();

LoadTableResponse tableResponse = loadTable(CATALOG_NAME, "ns1", "tb1");
Expand All @@ -293,16 +295,16 @@ public void testCreateAndUpdateExternalTable() {
assertThat(registerResponse).returns(Response.Status.OK.getStatusCode(), Response::getStatus);
}

long tableCount = spark.sql("SHOW TABLES").count();
long tableCount = onSpark("SHOW TABLES").count();
assertThat(tableCount).isEqualTo(1);
List<Row> tables = spark.sql("SHOW TABLES").collectAsList();
List<Row> tables = onSpark("SHOW TABLES").collectAsList();
assertThat(tables).hasSize(1).extracting(row -> row.getString(1)).containsExactly("mytb1");
long rowCount = spark.sql("SELECT * FROM mytb1").count();
long rowCount = onSpark("SELECT * FROM mytb1").count();
assertThat(rowCount).isEqualTo(3);
assertThatThrownBy(() -> spark.sql("INSERT INTO mytb1 VALUES (20, 'new_text')"))
assertThatThrownBy(() -> onSpark("INSERT INTO mytb1 VALUES (20, 'new_text')"))
.isInstanceOf(Exception.class);

spark.sql("INSERT INTO " + CATALOG_NAME + ".ns1.tb1 VALUES (20, 'new_text')");
onSpark("INSERT INTO " + CATALOG_NAME + ".ns1.tb1 VALUES (20, 'new_text')");
tableResponse = loadTable(CATALOG_NAME, "ns1", "tb1");
TableUpdateNotification updateNotification =
new TableUpdateNotification(
Expand All @@ -328,8 +330,8 @@ public void testCreateAndUpdateExternalTable() {
.returns(Response.Status.NO_CONTENT.getStatusCode(), Response::getStatus);
}
// refresh the table so it queries for the latest metadata.json
spark.sql("REFRESH TABLE mytb1");
rowCount = spark.sql("SELECT * FROM mytb1").count();
onSpark("REFRESH TABLE mytb1");
rowCount = onSpark("SELECT * FROM mytb1").count();
assertThat(rowCount).isEqualTo(4);
}

Expand All @@ -348,4 +350,8 @@ private LoadTableResponse loadTable(String catalog, String namespace, String tab
return response.readEntity(LoadTableResponse.class);
}
}

/**
 * Runs a SQL statement through the test's {@code spark} session.
 *
 * <p>This is a thin wrapper over {@code spark.sql(sql)}. The parameter carries the
 * {@code @Language("SQL")} annotation, so tooling that understands that annotation can
 * treat string arguments at call sites as SQL.
 *
 * @param sql the SQL text to execute
 * @return whatever {@code spark.sql(sql)} returns for the statement
 */
private static Dataset<Row> onSpark(@Language("SQL") String sql) {
return spark.sql(sql);
}
}

0 comments on commit 7541ffe

Please sign in to comment.