Test query execution while worker restarts
findepi committed May 13, 2024
1 parent 23882f2 commit 86248b3
Showing 3 changed files with 230 additions and 7 deletions.
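
The core change adds a restartWorker(TestingTrinoServer) method to DistributedQueryRunner, plus a TestWorkerRestart test that exercises it. A minimal usage sketch, using only names that appear in the diff below (the query text and builder configuration are illustrative):

// Sketch: restart a worker mid-test and keep querying; assumes a test method that declares throws Exception.
try (DistributedQueryRunner queryRunner = TpchQueryRunner.builder().build()) {
    TestingTrinoServer worker = queryRunner.getServers().stream()
            .filter(server -> !server.isCoordinator())
            .findFirst().orElseThrow();
    queryRunner.restartWorker(worker); // simulates a killed and re-created worker, reusing the same HTTP port
    queryRunner.execute("SELECT count(*) FROM tpch.tiny.lineitem");
}
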
13 changes: 12 additions & 1 deletion testing/trino-testing/pom.xml
@@ -57,6 +57,11 @@
<artifactId>configuration</artifactId>
</dependency>

<dependency>
<groupId>io.airlift</groupId>
<artifactId>http-server</artifactId>
</dependency>

<dependency>
<groupId>io.airlift</groupId>
<artifactId>log</artifactId>
@@ -200,6 +205,12 @@
<artifactId>assertj-core</artifactId>
</dependency>

<dependency>
<groupId>org.eclipse.jetty</groupId>
<artifactId>jetty-server</artifactId>
<version>12.0.9</version>
</dependency>

<dependency>
<groupId>org.jdbi</groupId>
<artifactId>jdbi3-core</artifactId>
@@ -246,4 +257,4 @@
<scope>test</scope>
</dependency>
</dependencies>
</project>
</project>
48 changes: 42 additions & 6 deletions testing/trino-testing/src/main/java/io/trino/testing/DistributedQueryRunner.java
@@ -20,6 +20,7 @@
import com.google.inject.Key;
import com.google.inject.Module;
import io.airlift.discovery.server.testing.TestingDiscoveryServer;
import io.airlift.http.server.HttpServer;
import io.airlift.log.Logger;
import io.airlift.log.Logging;
import io.opentelemetry.sdk.testing.exporter.InMemorySpanExporter;
@@ -57,10 +58,13 @@
import io.trino.sql.planner.Plan;
import io.trino.testing.containers.OpenTracingCollector;
import io.trino.transaction.TransactionManager;
import org.eclipse.jetty.server.Connector;
import org.eclipse.jetty.server.Server;
import org.intellij.lang.annotations.Language;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.lang.reflect.Field;
import java.net.URI;
import java.nio.file.Path;
import java.util.HashMap;
@@ -79,6 +83,7 @@
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Throwables.throwIfUnchecked;
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.Iterables.getOnlyElement;
import static com.google.inject.util.Modules.EMPTY_MODULE;
import static io.airlift.log.Level.DEBUG;
import static io.airlift.log.Level.ERROR;
@@ -88,7 +93,9 @@
import static io.trino.execution.querystats.PlanOptimizersStatsCollector.createPlanOptimizersStatsCollector;
import static java.lang.Boolean.parseBoolean;
import static java.lang.System.getenv;
import static java.util.Arrays.asList;
import static java.util.Objects.requireNonNull;
import static org.assertj.core.api.Assertions.assertThat;

public class DistributedQueryRunner
implements QueryRunner
@@ -100,7 +107,7 @@ public class DistributedQueryRunner
private TestingDiscoveryServer discoveryServer;
private TestingTrinoServer coordinator;
private Optional<TestingTrinoServer> backupCoordinator;
private Runnable registerNewWorker;
private Consumer<Map<String, String>> registerNewWorker;
private final InMemorySpanExporter spanExporter = InMemorySpanExporter.create();
private final List<TestingTrinoServer> servers = new CopyOnWriteArrayList<>();
private final List<FunctionBundle> functionBundles = new CopyOnWriteArrayList<>(ImmutableList.of(CustomFunctionBundle.CUSTOM_FUNCTIONS));
@@ -149,11 +156,14 @@ private DistributedQueryRunner(
extraCloseables.forEach(closeable -> closer.register(() -> closeUnchecked(closeable)));
log.debug("Created TestingDiscoveryServer in %s", nanosSince(discoveryStart));

registerNewWorker = () -> {
registerNewWorker = additionalWorkerProperties -> {
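// additionalWorkerProperties lets callers pass extra config for a single worker, e.g. to pin its HTTP port on restart.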
@SuppressWarnings("resource")
TestingTrinoServer ignored = createServer(
false,
extraProperties,
ImmutableMap.<String, String>builder()
.putAll(extraProperties)
.putAll(additionalWorkerProperties)
.buildOrThrow(),
environment,
additionalModule,
baseDataDir,
@@ -163,7 +173,7 @@ private DistributedQueryRunner(
};

for (int i = 0; i < workerCount; i++) {
registerNewWorker.run();
registerNewWorker.accept(Map.of());
}

Map<String, String> extraCoordinatorProperties = new HashMap<>();
@@ -317,11 +327,37 @@ private static TestingTrinoServer createTestingTrinoServer(
public void addServers(int nodeCount)
{
for (int i = 0; i < nodeCount; i++) {
registerNewWorker.run();
registerNewWorker.accept(Map.of());
}
ensureNodesGloballyVisible();
}

/**
* Simulates a worker restart, e.g. as happens in Kubernetes after a pod is killed.
*/
public void restartWorker(TestingTrinoServer server)
throws Exception
{
URI baseUrl = server.getBaseUrl();
checkState(servers.remove(server), "Server not found: %s", server);
HttpServer workerHttpServer = server.getInstance(Key.get(HttpServer.class));
// Prevent any HTTP communication with the worker, as if the worker process was killed.
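// Airlift's HttpServer keeps the underlying Jetty Server in a private field, hence the reflective access to reach its connector.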
Field serverField = HttpServer.class.getDeclaredField("server");
serverField.setAccessible(true);
Connector httpConnector = getOnlyElement(asList(((Server) serverField.get(workerHttpServer)).getConnectors()));
httpConnector.stop();
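// With the connector stopped, the worker no longer accepts HTTP requests and its listening port is freed for reuse.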
server.close();

Map<String, String> reusePort = Map.of("http-server.http.port", Integer.toString(baseUrl.getPort()));
registerNewWorker.accept(reusePort);
// Verify the address was reused.
assertThat(servers.stream()
.map(TestingTrinoServer::getBaseUrl)
.filter(baseUrl::equals))
.hasSize(1);
// Do not wait for the new server to be fully registered with the other servers
}

private void ensureNodesGloballyVisible()
{
for (TestingTrinoServer server : servers) {
@@ -589,7 +625,7 @@ public final void close()
discoveryServer = null;
coordinator = null;
backupCoordinator = Optional.empty();
registerNewWorker = () -> {
registerNewWorker = _ -> {
throw new IllegalStateException("Already closed");
};
servers.clear();
176 changes: 176 additions & 0 deletions testing/trino-tests/src/test/java/io/trino/tests/TestWorkerRestart.java
@@ -0,0 +1,176 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.tests;

import io.trino.execution.QueryManager;
import io.trino.server.BasicQueryInfo;
import io.trino.server.testing.TestingTrinoServer;
import io.trino.testing.DistributedQueryRunner;
import io.trino.testing.MaterializedResult;
import io.trino.tests.tpch.TpchQueryRunner;
import org.junit.jupiter.api.RepeatedTest;
import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.api.parallel.Execution;

import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;

import static com.google.common.collect.MoreCollectors.onlyElement;
import static io.airlift.concurrent.Threads.daemonThreadsNamed;
import static io.trino.execution.QueryState.RUNNING;
import static io.trino.testing.assertions.Assert.assertEventually;
import static java.util.UUID.randomUUID;
import static java.util.concurrent.Executors.newCachedThreadPool;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.junit.jupiter.api.parallel.ExecutionMode.SAME_THREAD;

/**
* Test that tasks are cleanly rejected when a node restarts (same node ID, different node instance ID).
*/
@Execution(SAME_THREAD) // run single threaded to avoid creating multiple query runners at once
public class TestWorkerRestart
{
// When working with the test locally it's practical to run multiple iterations at once.
private static final int TEST_ITERATIONS = 1;

/**
* Test that a query passes even if a worker is restarted just before the query is submitted.
*/
@RepeatedTest(TEST_ITERATIONS)
@Timeout(90)
public void testRestartBeforeQuery()
throws Exception
{
try (DistributedQueryRunner queryRunner = TpchQueryRunner.builder().build();
ExecutorService executor = newCachedThreadPool(daemonThreadsNamed(getClass().getSimpleName() + "-%d"))) {
try {
// Ensure everything initialized
assertThat((long) queryRunner.execute("SELECT count(*) FROM tpch.tiny.lineitem").getOnlyValue())
.isEqualTo(60_175);

restartWorker(queryRunner);
// Even though the worker is restarted before we send the query, it may not be fully announced to the coordinator yet.
// The coordinator will still try to send the query to the worker, thinking it is the previous instance of it.
Future<MaterializedResult> future = executor.submit(() -> queryRunner.execute("SELECT count(*) FROM tpch.sf1.lineitem -- " + randomUUID()));
future.get(); // query should succeed

// Ensure that the restarted worker is able to serve queries.
assertThat((long) queryRunner.execute("SELECT count(*) FROM tpch.tiny.lineitem").getOnlyValue())
.isEqualTo(60_175);
}
finally {
cancelQueries(queryRunner);
}
}
}

/**
* Test that a query fails when a worker crashes during its execution, but the next query (e.g. a retried query) succeeds without issues.
*/
@RepeatedTest(TEST_ITERATIONS)
@Timeout(90)
public void testRestartDuringQuery()
throws Exception
{
try (DistributedQueryRunner queryRunner = TpchQueryRunner.builder().build();
ExecutorService executor = newCachedThreadPool(daemonThreadsNamed(getClass().getSimpleName() + "-%d"))) {
try {
// Ensure everything initialized
assertThat((long) queryRunner.execute("SELECT count(*) FROM tpch.tiny.lineitem").getOnlyValue())
.isEqualTo(60_175);

String sql = "SELECT count(*) FROM tpch.sf1000000000.lineitem -- " + randomUUID();
Future<MaterializedResult> future = executor.submit(() -> queryRunner.execute(sql));
waitForQueryStart(queryRunner, sql);
restartWorker(queryRunner);
assertThatThrownBy(future::get)
.isInstanceOf(ExecutionException.class)
.cause().hasMessageFindingMatch("^Expected response code from \\S+ to be 200, but was 500" +
"|Error fetching \\S+: Expected response code to be 200, but was 500");

// Ensure that the restarted worker is able to serve queries.
assertThat((long) queryRunner.execute("SELECT count(*) FROM tpch.tiny.lineitem").getOnlyValue())
.isEqualTo(60_175);
}
finally {
cancelQueries(queryRunner);
}
}
}

/**
* Test that a query passes if a worker crashed before the query started but is potentially still starting up while the query is being scheduled.
*/
@RepeatedTest(TEST_ITERATIONS)
@Timeout(90)
public void testStartDuringQuery()
throws Exception
{
try (DistributedQueryRunner queryRunner = TpchQueryRunner.builder().build();
ExecutorService executor = newCachedThreadPool(daemonThreadsNamed(getClass().getSimpleName() + "-%d"))) {
try {
// Ensure everything initialized
assertThat((long) queryRunner.execute("SELECT count(*) FROM tpch.tiny.lineitem").getOnlyValue())
.isEqualTo(60_175);

TestingTrinoServer worker = queryRunner.getServers().stream()
.filter(server -> !server.isCoordinator())
.findFirst().orElseThrow();
worker.close();
Future<MaterializedResult> future = executor.submit(() -> queryRunner.execute("SELECT count(*) FROM tpch.tiny.lineitem -- " + randomUUID()));
// the worker is shut down already, but restartWorker() will reuse its address
queryRunner.restartWorker(worker);
assertThatThrownBy(future::get)
.isInstanceOf(ExecutionException.class)
.hasStackTraceContaining("Error loading catalogs on worker");

// Ensure that the restarted worker is able to serve queries.
assertThat((long) queryRunner.execute("SELECT count(*) FROM tpch.tiny.lineitem").getOnlyValue())
.isEqualTo(60_175);
}
finally {
cancelQueries(queryRunner);
}
}
}

private static void waitForQueryStart(DistributedQueryRunner queryRunner, String sql)
{
assertEventually(() -> {
BasicQueryInfo queryInfo = queryRunner.getCoordinator().getQueryManager().getQueries().stream()
.filter(query -> query.getQuery().equals(sql))
.collect(onlyElement());
assertThat(queryInfo.getState()).isEqualTo(RUNNING);
});
}

private static void restartWorker(DistributedQueryRunner queryRunner)
throws Exception
{
TestingTrinoServer worker = queryRunner.getServers().stream()
.filter(server -> !server.isCoordinator())
.findFirst().orElseThrow();
queryRunner.restartWorker(worker);
}

private static void cancelQueries(DistributedQueryRunner queryRunner)
{
QueryManager queryManager = queryRunner.getCoordinator().getQueryManager();
queryManager.getQueries().stream()
.map(BasicQueryInfo::getQueryId)
.forEach(queryManager::cancelQuery);
}
}
