Add concurrent writes reconciliation for INSERT in Delta Lake
Allow committing INSERT operations in a concurrent context by simply
placing these operations right after any other concurrently completed
write operations.
However, do not allow committing such operations when the table schema
or protocol has changed in the meantime.
findinpath committed Sep 6, 2023
1 parent 6a9fa01 commit c70f3a7
Showing 4 changed files with 275 additions and 27 deletions.
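In short: instead of failing an INSERT whenever the transaction log moved past the read version, the commit path now re-reads the current version, scans the intervening log entries, and blind-appends the new files as long as no metadata or protocol entry appeared in between; version races are retried. A condensed, self-contained sketch of that flow (the TransactionLog interface and LogEntry record are illustrative stand-ins; the checks mirror getInsertCommitVersion in the diff below):

import java.util.List;

final class InsertReconciliationSketch
{
    // Stand-ins for the transaction log pieces the real code reads.
    record LogEntry(boolean metadataChange, boolean protocolChange) {}

    interface TransactionLog
    {
        long currentVersion();                              // getMandatoryCurrentVersion(...)
        List<LogEntry> entriesBetween(long from, long to);  // getEntriesFromJson(...) per version
        void commit(long version);                          // append commit info + add-file entries, then flush
    }

    // Returns the version the INSERT committed at; throws on a genuine conflict.
    // The real caller wraps this in Failsafe with TRANSACTION_CONFLICT_RETRY_POLICY
    // (up to 5 retries, 200ms delay, 100ms jitter).
    static long commitInsert(TransactionLog log, long readVersion)
    {
        long currentVersion = log.currentVersion();
        if (currentVersion < readVersion) {
            throw new IllegalStateException("Transaction log version moved backwards");
        }
        if (currentVersion > readVersion) {
            // Concurrent writes finished after this INSERT was planned; appending is
            // still safe unless the table structure changed underneath us.
            for (LogEntry entry : log.entriesBetween(readVersion + 1, currentVersion)) {
                if (entry.metadataChange() || entry.protocolChange()) {
                    throw new IllegalStateException("Metadata or protocol changed since version " + readVersion);
                }
            }
        }
        long commitVersion = currentVersion + 1;
        log.commit(commitVersion); // losing the race here -> TransactionConflictException -> retried
        return commitVersion;
    }
}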
io/trino/plugin/deltalake/DeltaLakeMetadata.java
@@ -14,13 +14,16 @@
package io.trino.plugin.deltalake;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.google.common.base.Throwables;
import com.google.common.base.VerifyException;
import com.google.common.collect.Comparators;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ImmutableTable;
import com.google.common.collect.Sets;
import dev.failsafe.Failsafe;
import dev.failsafe.RetryPolicy;
import io.airlift.json.JsonCodec;
import io.airlift.log.Logger;
import io.airlift.slice.Slice;
@@ -51,6 +54,7 @@
import io.trino.plugin.deltalake.transactionlog.DeltaLakeComputedStatistics;
import io.trino.plugin.deltalake.transactionlog.DeltaLakeSchemaSupport;
import io.trino.plugin.deltalake.transactionlog.DeltaLakeSchemaSupport.ColumnMappingMode;
import io.trino.plugin.deltalake.transactionlog.DeltaLakeTransactionLogEntry;
import io.trino.plugin.deltalake.transactionlog.MetadataEntry;
import io.trino.plugin.deltalake.transactionlog.MetadataEntry.Format;
import io.trino.plugin.deltalake.transactionlog.ProtocolEntry;
@@ -61,6 +65,7 @@
import io.trino.plugin.deltalake.transactionlog.statistics.DeltaLakeFileStatistics;
import io.trino.plugin.deltalake.transactionlog.statistics.DeltaLakeJsonFileStatistics;
import io.trino.plugin.deltalake.transactionlog.writer.TransactionConflictException;
import io.trino.plugin.deltalake.transactionlog.writer.TransactionFailedException;
import io.trino.plugin.deltalake.transactionlog.writer.TransactionLogWriter;
import io.trino.plugin.deltalake.transactionlog.writer.TransactionLogWriterFactory;
import io.trino.plugin.hive.HiveType;
@@ -140,6 +145,7 @@
import java.io.IOException;
import java.net.URI;
import java.net.URISyntaxException;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayDeque;
import java.util.Collection;
@@ -150,6 +156,7 @@
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.OptionalLong;
@@ -158,6 +165,7 @@
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.LongStream;
import java.util.stream.Stream;

import static com.google.common.base.Preconditions.checkArgument;
@@ -188,6 +196,7 @@
import static io.trino.plugin.deltalake.DeltaLakeColumnType.PARTITION_KEY;
import static io.trino.plugin.deltalake.DeltaLakeColumnType.REGULAR;
import static io.trino.plugin.deltalake.DeltaLakeColumnType.SYNTHESIZED;
import static io.trino.plugin.deltalake.DeltaLakeErrorCode.DELTA_LAKE_BAD_DATA;
import static io.trino.plugin.deltalake.DeltaLakeErrorCode.DELTA_LAKE_BAD_WRITE;
import static io.trino.plugin.deltalake.DeltaLakeErrorCode.DELTA_LAKE_FILESYSTEM_ERROR;
import static io.trino.plugin.deltalake.DeltaLakeErrorCode.DELTA_LAKE_INVALID_SCHEMA;
@@ -244,6 +253,7 @@
import static io.trino.plugin.deltalake.transactionlog.MetadataEntry.configurationForNewTable;
import static io.trino.plugin.deltalake.transactionlog.TransactionLogParser.getMandatoryCurrentVersion;
import static io.trino.plugin.deltalake.transactionlog.TransactionLogUtil.getTransactionLogDir;
import static io.trino.plugin.deltalake.transactionlog.checkpoint.TransactionLogTail.getEntriesFromJson;
import static io.trino.plugin.hive.HiveMetadata.PRESTO_QUERY_ID_NAME;
import static io.trino.plugin.hive.TableType.EXTERNAL_TABLE;
import static io.trino.plugin.hive.TableType.MANAGED_TABLE;
@@ -332,6 +342,12 @@ public class DeltaLakeMetadata
private static final int CDF_SUPPORTED_WRITER_VERSION = 4;
private static final int COLUMN_MAPPING_MODE_SUPPORTED_READER_VERSION = 2;
private static final int COLUMN_MAPPING_MODE_SUPPORTED_WRITER_VERSION = 5;
private static final RetryPolicy<Object> TRANSACTION_CONFLICT_RETRY_POLICY = RetryPolicy.builder()
.handleIf(throwable -> Throwables.getCausalChain(throwable).stream().anyMatch(TransactionConflictException.class::isInstance))
.withDelay(Duration.ofMillis(200))
.withJitter(Duration.ofMillis(100))
.withMaxRetries(5)
.build();

// Matches the dummy column Databricks stores in the metastore
private static final List<Column> DUMMY_DATA_COLUMNS = ImmutableList.of(
@@ -1738,32 +1754,10 @@ public Optional<ConnectorOutputMetadata> finishInsert(

boolean writeCommitted = false;
try {
- TransactionLogWriter transactionLogWriter = transactionLogWriterFactory.newWriter(session, handle.getLocation());
-
- long createdTime = Instant.now().toEpochMilli();
-
- TrinoFileSystem fileSystem = fileSystemFactory.create(session);
- long commitVersion = getMandatoryCurrentVersion(fileSystem, handle.getLocation()) + 1;
- if (commitVersion != handle.getReadVersion() + 1) {
-     throw new TransactionConflictException(format("Conflicting concurrent writes found. Expected transaction log version: %s, actual version: %s",
-         handle.getReadVersion(),
-         commitVersion - 1));
- }
- Optional<Long> checkpointInterval = handle.getMetadataEntry().getCheckpointInterval();
- // it is not obvious why we need to persist this readVersion
- transactionLogWriter.appendCommitInfoEntry(getCommitInfoEntry(session, commitVersion, createdTime, INSERT_OPERATION, handle.getReadVersion()));
-
- ColumnMappingMode columnMappingMode = getColumnMappingMode(handle.getMetadataEntry());
- List<String> partitionColumns = getPartitionColumns(
-     handle.getMetadataEntry().getOriginalPartitionColumns(),
-     handle.getInputColumns(),
-     columnMappingMode);
- List<String> exactColumnNames = getExactColumnNames(handle.getMetadataEntry());
- appendAddFileEntries(transactionLogWriter, dataFileInfos, partitionColumns, exactColumnNames, true);
-
- transactionLogWriter.flush();
+ long commitVersion = Failsafe.with(TRANSACTION_CONFLICT_RETRY_POLICY)
+     .get(() -> getInsertCommitVersion(session, handle, dataFileInfos));
writeCommitted = true;
- writeCheckpointIfNeeded(session, handle.getTableName(), handle.getLocation(), checkpointInterval, commitVersion);
+ writeCheckpointIfNeeded(session, handle.getTableName(), handle.getLocation(), handle.getMetadataEntry().getCheckpointInterval(), commitVersion);

if (isCollectExtendedStatisticsColumnStatisticsOnWrite(session) && !computedStatistics.isEmpty() && !dataFileInfos.isEmpty()) {
// TODO (https://github.com/trinodb/trino/issues/16088) Add synchronization when version conflict for INSERT is resolved.
@@ -1778,7 +1772,7 @@ public Optional<ConnectorOutputMetadata> finishInsert(
handle.getLocation(),
maxFileModificationTime,
computedStatistics,
- exactColumnNames,
+ getExactColumnNames(handle.getMetadataEntry()),
Optional.of(extractSchema(handle.getMetadataEntry(), typeManager).stream()
.collect(toImmutableMap(DeltaLakeColumnMetadata::getName, DeltaLakeColumnMetadata::getPhysicalName))));
}
@@ -1794,6 +1788,68 @@ public Optional<ConnectorOutputMetadata> finishInsert(
return Optional.empty();
}

private long getInsertCommitVersion(ConnectorSession session, DeltaLakeInsertTableHandle handle, List<DataFileInfo> dataFileInfos)
throws IOException
{
long createdTime = Instant.now().toEpochMilli();

TrinoFileSystem fileSystem = fileSystemFactory.create(session);

String transactionLogDirectory = getTransactionLogDir(handle.getLocation());
long currentVersion = getMandatoryCurrentVersion(fileSystem, handle.getLocation());
if (currentVersion < handle.getReadVersion()) {
throw new TransactionFailedException(format("Conflicting concurrent writes found. Expected transaction log version: %s, actual version: %s",
handle.getReadVersion(),
currentVersion));
}
else if (currentVersion > handle.getReadVersion()) {
// Ensure there are no structural changes on the table if concurrent writes finished in the meantime
List<DeltaLakeTransactionLogEntry> transactionLogEntries = LongStream.rangeClosed(handle.getReadVersion() + 1, currentVersion)
.boxed()
.flatMap(version -> {
try {
return getEntriesFromJson(version, transactionLogDirectory, fileSystem)
.orElseThrow(() -> new TrinoException(DELTA_LAKE_BAD_DATA, "Delta Lake log entries are missing for version " + version))
.stream();
}
catch (IOException e) {
throw new TrinoException(DELTA_LAKE_FILESYSTEM_ERROR, "Failed to access table metadata", e);
}
})
.collect(toImmutableList());
Optional<MetadataEntry> currentMetadataEntry = transactionLogEntries.stream()
.map(DeltaLakeTransactionLogEntry::getMetaData)
.filter(Objects::nonNull)
.findFirst();
if (currentMetadataEntry.isPresent()) {
throw new TransactionFailedException(format("Conflicting concurrent writes found. Metadata changed since the version: %s", handle.getReadVersion()));
}
Optional<ProtocolEntry> currentProtocolEntry = transactionLogEntries.stream()
.map(DeltaLakeTransactionLogEntry::getProtocol)
.filter(Objects::nonNull)
.findFirst();
if (currentProtocolEntry.isPresent()) {
throw new TransactionFailedException(format("Conflicting concurrent writes found. Protocol changed since the version: %s", handle.getReadVersion()));
}
}
long commitVersion = currentVersion + 1;
// it is not obvious why we need to persist this readVersion
TransactionLogWriter transactionLogWriter = transactionLogWriterFactory.newWriter(session, handle.getLocation());
transactionLogWriter.appendCommitInfoEntry(getCommitInfoEntry(session, commitVersion, createdTime, INSERT_OPERATION, currentVersion));

ColumnMappingMode columnMappingMode = getColumnMappingMode(handle.getMetadataEntry());
List<String> partitionColumns = getPartitionColumns(
handle.getMetadataEntry().getOriginalPartitionColumns(),
handle.getInputColumns(),
columnMappingMode);
List<String> exactColumnNames = getExactColumnNames(handle.getMetadataEntry());
appendAddFileEntries(transactionLogWriter, dataFileInfos, partitionColumns, exactColumnNames, true);

transactionLogWriter.flush();

return commitVersion;
}

private static List<String> getPartitionColumns(List<String> originalPartitionColumns, List<DeltaLakeColumnHandle> dataColumns, ColumnMappingMode columnMappingMode)
{
return switch (columnMappingMode) {
Expand Down Expand Up @@ -1899,7 +1955,7 @@ public void finishMerge(ConnectorSession session, ConnectorMergeTableHandle merg
TrinoFileSystem fileSystem = fileSystemFactory.create(session);
long currentVersion = getMandatoryCurrentVersion(fileSystem, tableLocation);
if (currentVersion != handle.getReadVersion()) {
throw new TransactionConflictException(format("Conflicting concurrent writes found. Expected transaction log version: %s, actual version: %s", handle.getReadVersion(), currentVersion));
throw new TransactionFailedException(format("Conflicting concurrent writes found. Expected transaction log version: %s, actual version: %s", handle.getReadVersion(), currentVersion));
}
long commitVersion = currentVersion + 1;

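The retry and failure paths above are deliberately split across two exception types: TransactionConflictException (losing the race for the next transaction log file) is the only cause TRANSACTION_CONFLICT_RETRY_POLICY retries, while the new TransactionFailedException (next file) aborts the operation immediately. A minimal, standalone sketch of that Failsafe behavior, with a hypothetical ConflictException standing in for the real exception:

import dev.failsafe.Failsafe;
import dev.failsafe.RetryPolicy;

import java.time.Duration;
import java.util.concurrent.atomic.AtomicInteger;

public class RetryPolicyDemo
{
    // Stand-in for TransactionConflictException; the real policy walks the whole
    // causal chain via Throwables.getCausalChain rather than a plain instanceof.
    static class ConflictException extends RuntimeException {}

    public static void main(String[] args)
    {
        RetryPolicy<Object> policy = RetryPolicy.builder()
                .handleIf(throwable -> throwable instanceof ConflictException)
                .withDelay(Duration.ofMillis(200))
                .withJitter(Duration.ofMillis(100))
                .withMaxRetries(5)
                .build();

        AtomicInteger attempts = new AtomicInteger();
        // The first two attempts lose the commit race; the third succeeds.
        long commitVersion = Failsafe.with(policy).get(() -> {
            if (attempts.incrementAndGet() < 3) {
                throw new ConflictException();
            }
            return 42L;
        });
        System.out.println("committed at version " + commitVersion + " after " + attempts.get() + " attempts");
    }
}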
io/trino/plugin/deltalake/transactionlog/writer/TransactionFailedException.java
@@ -0,0 +1,28 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.plugin.deltalake.transactionlog.writer;

public class TransactionFailedException
extends RuntimeException
{
public TransactionFailedException(String message)
{
super(message);
}

public TransactionFailedException(String message, Throwable cause)
{
super(message, cause);
}
}
@@ -17,6 +17,7 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import io.airlift.concurrent.MoreFutures;
import io.airlift.units.DataSize;
import io.airlift.units.Duration;
import io.trino.Session;
@@ -32,6 +33,7 @@
import io.trino.spi.QueryId;
import io.trino.sql.planner.OptimizerConfig.JoinDistributionType;
import io.trino.testing.BaseConnectorSmokeTest;
import io.trino.testing.DataProviders;
import io.trino.testing.DistributedQueryRunner;
import io.trino.testing.MaterializedResult;
import io.trino.testing.MaterializedResultWithQueryId;
@@ -49,6 +51,9 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.ExecutorService;
import java.util.function.BiConsumer;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
@@ -80,6 +85,7 @@
import static io.trino.tpch.TpchTable.ORDERS;
import static java.lang.String.format;
import static java.util.Comparator.comparing;
import static java.util.concurrent.Executors.newFixedThreadPool;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static java.util.concurrent.TimeUnit.SECONDS;
import static org.assertj.core.api.Assertions.assertThat;
@@ -2161,6 +2167,56 @@ public void testPartitionFilterIncluded()
}
}

// Repeat test with invocationCount for better test coverage, since the tested aspect is inherently non-deterministic.
@Test(dataProviderClass = DataProviders.class, dataProvider = "trueFalse", invocationCount = 4)
public void testConcurrentInsertsReconciliation(boolean partitioned)
throws Exception
{
int threads = 3;

CyclicBarrier barrier = new CyclicBarrier(threads);
ExecutorService executor = newFixedThreadPool(threads);
String tableName = "test_concurrent_inserts_table_" + randomNameSuffix();

assertUpdate("CREATE TABLE " + tableName + " (a INT, part INT) " +
(partitioned ? " WITH (partitioned_by = ARRAY['part'])" : ""));

try {
// insert data concurrently
executor.invokeAll(ImmutableList.<Callable<Void>>builder()
.add(() -> {
barrier.await(20, SECONDS);
getQueryRunner().execute("INSERT INTO " + tableName + " VALUES (1, 10)");
return null;
})
.add(() -> {
barrier.await(20, SECONDS);
getQueryRunner().execute("INSERT INTO " + tableName + " VALUES (11, 20)");
return null;
})
.add(() -> {
barrier.await(20, SECONDS);
getQueryRunner().execute("INSERT INTO " + tableName + " VALUES (21, 30)");
return null;
})
.build())
.forEach(MoreFutures::getDone);

assertThat(query("SELECT * FROM " + tableName)).matches("VALUES (1, 10), (11, 20), (21, 30)");
assertQuery("SELECT version, operation, read_version FROM \"" + tableName + "$history\"",
"""
VALUES
(0, 'CREATE TABLE', 0),
(1, 'WRITE', 0),
(2, 'WRITE', 1),
(3, 'WRITE', 2)
""");
}
finally {
assertUpdate("DROP TABLE " + tableName);
}
}

private Set<String> getActiveFiles(String tableName)
{
return getActiveFiles(tableName, getQueryRunner().getDefaultSession());