diff --git a/testing/trino-testing/pom.xml b/testing/trino-testing/pom.xml
index 87a34ea616c0..fcc74b3b8cbc 100644
--- a/testing/trino-testing/pom.xml
+++ b/testing/trino-testing/pom.xml
@@ -57,6 +57,11 @@
             <artifactId>configuration</artifactId>
         </dependency>
 
+        <dependency>
+            <groupId>io.airlift</groupId>
+            <artifactId>http-server</artifactId>
+        </dependency>
+
         <dependency>
             <groupId>io.airlift</groupId>
             <artifactId>log</artifactId>
@@ -200,6 +205,12 @@
             <artifactId>assertj-core</artifactId>
         </dependency>
 
+        <dependency>
+            <groupId>org.eclipse.jetty</groupId>
+            <artifactId>jetty-server</artifactId>
+            <version>12.0.9</version>
+        </dependency>
+
         <dependency>
             <groupId>org.jdbi</groupId>
             <artifactId>jdbi3-core</artifactId>
@@ -246,4 +257,4 @@
             <scope>test</scope>
         </dependency>
     </dependencies>
-</project>
\ No newline at end of file
+</project>
diff --git a/testing/trino-testing/src/main/java/io/trino/testing/DistributedQueryRunner.java b/testing/trino-testing/src/main/java/io/trino/testing/DistributedQueryRunner.java
index 67c8f6be2c0c..471bd91f6e2d 100644
--- a/testing/trino-testing/src/main/java/io/trino/testing/DistributedQueryRunner.java
+++ b/testing/trino-testing/src/main/java/io/trino/testing/DistributedQueryRunner.java
@@ -20,6 +20,7 @@
 import com.google.inject.Key;
 import com.google.inject.Module;
 import io.airlift.discovery.server.testing.TestingDiscoveryServer;
+import io.airlift.http.server.HttpServer;
 import io.airlift.log.Logger;
 import io.airlift.log.Logging;
 import io.opentelemetry.sdk.testing.exporter.InMemorySpanExporter;
@@ -57,10 +58,13 @@
 import io.trino.sql.planner.Plan;
 import io.trino.testing.containers.OpenTracingCollector;
 import io.trino.transaction.TransactionManager;
+import org.eclipse.jetty.server.Connector;
+import org.eclipse.jetty.server.Server;
 import org.intellij.lang.annotations.Language;
 
 import java.io.IOException;
 import java.io.UncheckedIOException;
+import java.lang.reflect.Field;
 import java.net.URI;
 import java.nio.file.Path;
 import java.util.HashMap;
@@ -79,6 +83,7 @@
 import static com.google.common.base.Preconditions.checkState;
 import static com.google.common.base.Throwables.throwIfUnchecked;
 import static com.google.common.base.Verify.verify;
+import static com.google.common.collect.Iterables.getOnlyElement;
 import static com.google.inject.util.Modules.EMPTY_MODULE;
 import static io.airlift.log.Level.DEBUG;
 import static io.airlift.log.Level.ERROR;
@@ -88,7 +93,9 @@
 import static io.trino.execution.querystats.PlanOptimizersStatsCollector.createPlanOptimizersStatsCollector;
 import static java.lang.Boolean.parseBoolean;
 import static java.lang.System.getenv;
+import static java.util.Arrays.asList;
 import static java.util.Objects.requireNonNull;
+import static org.assertj.core.api.Assertions.assertThat;
 
 public class DistributedQueryRunner
         implements QueryRunner
@@ -100,7 +107,7 @@ public class DistributedQueryRunner
     private TestingDiscoveryServer discoveryServer;
     private TestingTrinoServer coordinator;
     private Optional<TestingTrinoServer> backupCoordinator;
-    private Runnable registerNewWorker;
+    private Consumer<Map<String, String>> registerNewWorker;
     private final InMemorySpanExporter spanExporter = InMemorySpanExporter.create();
     private final List<TestingTrinoServer> servers = new CopyOnWriteArrayList<>();
     private final List<FunctionBundle> functionBundles = new CopyOnWriteArrayList<>(ImmutableList.of(CustomFunctionBundle.CUSTOM_FUNCTIONS));
@@ -149,11 +156,14 @@ private DistributedQueryRunner(
         extraCloseables.forEach(closeable -> closer.register(() -> closeUnchecked(closeable)));
         log.debug("Created TestingDiscoveryServer in %s", nanosSince(discoveryStart));
 
-        registerNewWorker = () -> {
+        registerNewWorker = additionalWorkerProperties -> {
             @SuppressWarnings("resource")
             TestingTrinoServer ignored = createServer(
                     false,
-                    extraProperties,
+                    ImmutableMap.<String, String>builder()
+                            .putAll(extraProperties)
+                            .putAll(additionalWorkerProperties)
+                            .buildOrThrow(),
                     environment,
                     additionalModule,
                     baseDataDir,
@@ -163,7 +173,7 @@ private DistributedQueryRunner(
         };
 
         for (int i = 0; i < workerCount; i++) {
-            registerNewWorker.run();
+            registerNewWorker.accept(Map.of());
         }
 
         Map<String, String> extraCoordinatorProperties = new HashMap<>();
@@ -317,11 +327,37 @@ private static TestingTrinoServer createTestingTrinoServer(
     public void addServers(int nodeCount)
     {
         for (int i = 0; i < nodeCount; i++) {
-            registerNewWorker.run();
+            registerNewWorker.accept(Map.of());
         }
         ensureNodesGloballyVisible();
     }
 
+    /**
+     * Simulates a worker restart, e.g. as in Kubernetes after a pod is killed.
+     */
+    public void restartWorker(TestingTrinoServer server)
+            throws Exception
+    {
+        URI baseUrl = server.getBaseUrl();
+        checkState(servers.remove(server), "Server not found: %s", server);
+        HttpServer workerHttpServer = server.getInstance(Key.get(HttpServer.class));
+        // Prevent any HTTP communication with the worker, as if the worker process had been killed.
+        Field serverField = HttpServer.class.getDeclaredField("server");
+        serverField.setAccessible(true);
+        Connector httpConnector = getOnlyElement(asList(((Server) serverField.get(workerHttpServer)).getConnectors()));
+        httpConnector.stop();
+        server.close();
+
+        Map<String, String> reusePort = Map.of("http-server.http.port", Integer.toString(baseUrl.getPort()));
+        registerNewWorker.accept(reusePort);
+        // Verify the address was reused.
+        assertThat(servers.stream()
+                .map(TestingTrinoServer::getBaseUrl)
+                .filter(baseUrl::equals))
+                .hasSize(1);
+        // Do not wait for the new server to be fully registered with the other servers.
+    }
+
     private void ensureNodesGloballyVisible()
     {
         for (TestingTrinoServer server : servers) {
@@ -589,7 +625,7 @@ public final void close()
         discoveryServer = null;
         coordinator = null;
         backupCoordinator = Optional.empty();
-        registerNewWorker = () -> {
+        registerNewWorker = _ -> {
             throw new IllegalStateException("Already closed");
         };
         servers.clear();
diff --git a/testing/trino-tests/src/test/java/io/trino/tests/TestWorkerRestart.java b/testing/trino-tests/src/test/java/io/trino/tests/TestWorkerRestart.java
new file mode 100644
index 000000000000..3f3962bc8f04
--- /dev/null
+++ b/testing/trino-tests/src/test/java/io/trino/tests/TestWorkerRestart.java
@@ -0,0 +1,176 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.tests;
+
+import io.trino.execution.QueryManager;
+import io.trino.server.BasicQueryInfo;
+import io.trino.server.testing.TestingTrinoServer;
+import io.trino.testing.DistributedQueryRunner;
+import io.trino.testing.MaterializedResult;
+import io.trino.tests.tpch.TpchQueryRunner;
+import org.junit.jupiter.api.RepeatedTest;
+import org.junit.jupiter.api.Timeout;
+import org.junit.jupiter.api.parallel.Execution;
+
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Future;
+
+import static com.google.common.collect.MoreCollectors.onlyElement;
+import static io.airlift.concurrent.Threads.daemonThreadsNamed;
+import static io.trino.execution.QueryState.RUNNING;
+import static io.trino.testing.assertions.Assert.assertEventually;
+import static java.util.UUID.randomUUID;
+import static java.util.concurrent.Executors.newCachedThreadPool;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+import static org.junit.jupiter.api.parallel.ExecutionMode.SAME_THREAD;
+
+/**
+ * Test that tasks are cleanly rejected when a node restarts (same node ID, different node instance ID).
+ */
+@Execution(SAME_THREAD) // run single threaded to avoid creating multiple query runners at once
+public class TestWorkerRestart
+{
+    // When working with the test locally, it's practical to run multiple iterations at once.
+    private static final int TEST_ITERATIONS = 1;
+
+    /**
+     * Test that a query passes even if a worker is restarted just before the query.
+     */
+    @RepeatedTest(TEST_ITERATIONS)
+    @Timeout(90)
+    public void testRestartBeforeQuery()
+            throws Exception
+    {
+        try (DistributedQueryRunner queryRunner = TpchQueryRunner.builder().build();
+                ExecutorService executor = newCachedThreadPool(daemonThreadsNamed(getClass().getSimpleName() + "-%d"))) {
+            try {
+                // Ensure everything is initialized
+                assertThat((long) queryRunner.execute("SELECT count(*) FROM tpch.tiny.lineitem").getOnlyValue())
+                        .isEqualTo(60_175);
+
+                restartWorker(queryRunner);
+                // Even though the worker is restarted before we send the query, it is not yet fully announced to the coordinator.
+                // The coordinator will still try to send the query to the worker, thinking it is the previous instance of it.
+                Future<MaterializedResult> future = executor.submit(() -> queryRunner.execute("SELECT count(*) FROM tpch.sf1.lineitem -- " + randomUUID()));
+                future.get(); // query should succeed
+
+                // Ensure that the restarted worker is able to serve queries.
+                assertThat((long) queryRunner.execute("SELECT count(*) FROM tpch.tiny.lineitem").getOnlyValue())
+                        .isEqualTo(60_175);
+            }
+            finally {
+                cancelQueries(queryRunner);
+            }
+        }
+    }
+
+    /**
+     * Test that a query fails when a worker crashes during its execution, but the next query (e.g. a retried query) succeeds without issues.
+     */
+    @RepeatedTest(TEST_ITERATIONS)
+    @Timeout(90)
+    public void testRestartDuringQuery()
+            throws Exception
+    {
+        try (DistributedQueryRunner queryRunner = TpchQueryRunner.builder().build();
+                ExecutorService executor = newCachedThreadPool(daemonThreadsNamed(getClass().getSimpleName() + "-%d"))) {
+            try {
+                // Ensure everything is initialized
+                assertThat((long) queryRunner.execute("SELECT count(*) FROM tpch.tiny.lineitem").getOnlyValue())
+                        .isEqualTo(60_175);
+
+                String sql = "SELECT count(*) FROM tpch.sf1000000000.lineitem -- " + randomUUID();
+                Future<MaterializedResult> future = executor.submit(() -> queryRunner.execute(sql));
+                waitForQueryStart(queryRunner, sql);
+                restartWorker(queryRunner);
+                assertThatThrownBy(future::get)
+                        .isInstanceOf(ExecutionException.class)
+                        .cause().hasMessageFindingMatch("^Expected response code from \\S+ to be 200, but was 500" +
+                                "|Error fetching \\S+: Expected response code to be 200, but was 500");
+
+                // Ensure that the restarted worker is able to serve queries.
+                assertThat((long) queryRunner.execute("SELECT count(*) FROM tpch.tiny.lineitem").getOnlyValue())
+                        .isEqualTo(60_175);
+            }
+            finally {
+                cancelQueries(queryRunner);
+            }
+        }
+    }
+
+    /**
+     * Test that a query passes if a worker crashed before the query started but is potentially still starting up while the query is being scheduled.
+     */
+    @RepeatedTest(TEST_ITERATIONS)
+    @Timeout(90)
+    public void testStartDuringQuery()
+            throws Exception
+    {
+        try (DistributedQueryRunner queryRunner = TpchQueryRunner.builder().build();
+                ExecutorService executor = newCachedThreadPool(daemonThreadsNamed(getClass().getSimpleName() + "-%d"))) {
+            try {
+                // Ensure everything is initialized
+                assertThat((long) queryRunner.execute("SELECT count(*) FROM tpch.tiny.lineitem").getOnlyValue())
+                        .isEqualTo(60_175);
+
+                TestingTrinoServer worker = queryRunner.getServers().stream()
+                        .filter(server -> !server.isCoordinator())
+                        .findFirst().orElseThrow();
+                worker.close();
+                Future<MaterializedResult> future = executor.submit(() -> queryRunner.execute("SELECT count(*) FROM tpch.tiny.lineitem -- " + randomUUID()));
+                // The worker is already shut down, but restartWorker() will reuse its address.
+                queryRunner.restartWorker(worker);
+                assertThatThrownBy(future::get)
+                        .isInstanceOf(ExecutionException.class)
+                        .hasStackTraceContaining("Error loading catalogs on worker");
+
+                // Ensure that the restarted worker is able to serve queries.
+                assertThat((long) queryRunner.execute("SELECT count(*) FROM tpch.tiny.lineitem").getOnlyValue())
+                        .isEqualTo(60_175);
+            }
+            finally {
+                cancelQueries(queryRunner);
+            }
+        }
+    }
+
+    private static void waitForQueryStart(DistributedQueryRunner queryRunner, String sql)
+    {
+        assertEventually(() -> {
+            BasicQueryInfo queryInfo = queryRunner.getCoordinator().getQueryManager().getQueries().stream()
+                    .filter(query -> query.getQuery().equals(sql))
+                    .collect(onlyElement());
+            assertThat(queryInfo.getState()).isEqualTo(RUNNING);
+        });
+    }
+
+    private static void restartWorker(DistributedQueryRunner queryRunner)
+            throws Exception
+    {
+        TestingTrinoServer worker = queryRunner.getServers().stream()
+                .filter(server -> !server.isCoordinator())
+                .findFirst().orElseThrow();
+        queryRunner.restartWorker(worker);
+    }
+
+    private static void cancelQueries(DistributedQueryRunner queryRunner)
+    {
+        QueryManager queryManager = queryRunner.getCoordinator().getQueryManager();
+        queryManager.getQueries().stream()
+                .map(BasicQueryInfo::getQueryId)
+                .forEach(queryManager::cancelQuery);
+    }
+}
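The key trick in `restartWorker()` above is that it stops the worker's Jetty connector (reached via reflection through airlift's `HttpServer`) before closing the server, which releases the listen socket abruptly and lets the replacement worker bind the very same port. The following standalone sketch is not part of the patch and uses illustrative class and variable names only; it shows that mechanism with plain Jetty:

```java
import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.server.ServerConnector;

// Illustration only: stopping a connector releases the listen socket, so a new
// server can bind the same port immediately, which is what restartWorker()
// relies on when it passes http-server.http.port to the replacement worker.
public class PortReuseSketch
{
    public static void main(String[] args)
            throws Exception
    {
        Server oldServer = new Server();
        ServerConnector oldConnector = new ServerConnector(oldServer);
        oldConnector.setPort(0); // pick an ephemeral port
        oldServer.addConnector(oldConnector);
        oldServer.start();
        int port = oldConnector.getLocalPort();

        // Stop only the connector: the port goes dark even though the rest of the
        // server has not been shut down yet (similar to a killed worker process).
        oldConnector.stop();

        Server newServer = new Server();
        ServerConnector newConnector = new ServerConnector(newServer);
        newConnector.setPort(port); // reuse the old address
        newServer.addConnector(newConnector);
        newServer.start();

        System.out.println("Rebound on port " + newConnector.getLocalPort());

        newServer.stop();
        oldServer.stop();
    }
}
```

Stopping only the connector mimics a killed process from the coordinator's point of view: HTTP traffic to the old address fails immediately, while the rest of the old server is torn down afterwards by `server.close()`.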
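`registerNewWorker` now accepts per-worker properties and merges them over the shared `extraProperties` with `ImmutableMap.buildOrThrow()`, which rejects duplicate keys. A minimal sketch of that merge, again with illustrative values rather than anything taken from the patch:

```java
import com.google.common.collect.ImmutableMap;

import java.util.Map;

// Illustration only: the same merge shape as in registerNewWorker.
// buildOrThrow() throws on duplicate keys, so per-worker properties (e.g. the
// fixed http-server.http.port passed by restartWorker()) must not repeat keys
// already present in the shared extraProperties.
public class WorkerPropertiesSketch
{
    public static void main(String[] args)
    {
        Map<String, String> extraProperties = Map.of("query.client.timeout", "10m");
        Map<String, String> additionalWorkerProperties = Map.of("http-server.http.port", "8081");

        Map<String, String> merged = ImmutableMap.<String, String>builder()
                .putAll(extraProperties)
                .putAll(additionalWorkerProperties)
                .buildOrThrow();

        System.out.println(merged);
    }
}
```

In practice this means a per-worker override such as the reused `http-server.http.port` must not also be set in the runner-wide `extraProperties`; if it is, worker creation fails fast instead of silently picking one of the two values.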