Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CNDB-11713: index nulls in sai to optimize hybrid ANN #1455

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 8 additions & 36 deletions src/java/org/apache/cassandra/index/sai/IndexContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,9 @@
package org.apache.cassandra.index.sai;

import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
Expand Down Expand Up @@ -77,9 +75,7 @@
import org.apache.cassandra.index.sai.metrics.ColumnQueryMetrics;
import org.apache.cassandra.index.sai.metrics.IndexMetrics;
import org.apache.cassandra.index.sai.plan.Expression;
import org.apache.cassandra.index.sai.plan.Orderer;
import org.apache.cassandra.index.sai.utils.PrimaryKey;
import org.apache.cassandra.index.sai.utils.PrimaryKeyWithSortKey;
import org.apache.cassandra.index.sai.utils.TypeUtil;
import org.apache.cassandra.index.sai.view.IndexViewManager;
import org.apache.cassandra.index.sai.view.View;
Expand All @@ -88,7 +84,6 @@
import org.apache.cassandra.schema.ColumnMetadata;
import org.apache.cassandra.schema.IndexMetadata;
import org.apache.cassandra.schema.TableId;
import org.apache.cassandra.utils.CloseableIterator;
import org.apache.cassandra.utils.FBUtilities;
import org.apache.cassandra.utils.NoSpamLogger;
import org.apache.cassandra.utils.Pair;
Expand Down Expand Up @@ -442,26 +437,26 @@ public MemtableIndex getPendingMemtableIndex(LifecycleNewTracker tracker)
// but they are not a problem as post-filtering would get rid of them.
// The keys matched in other indexes cannot be safely subtracted
// as indexes may contain false positives caused by deletes and updates.
private KeyRangeIterator getNonEqIterator(QueryContext context, Expression expression, AbstractBounds<PartitionPosition> keyRange)
private KeyRangeIterator getNonEqIterator(QueryContext context, Expression expression, AbstractBounds<PartitionPosition> keyRange, boolean isNonReducing)
{
KeyRangeIterator allKeys = scanMemtable(keyRange);
KeyRangeIterator allKeys = scanMemtable(keyRange, isNonReducing);
if (TypeUtil.supportsRounding(expression.validator))
{
return allKeys;
}
else
{
Expression negExpression = expression.negated();
KeyRangeIterator matchedKeys = searchMemtable(context, negExpression, keyRange, Integer.MAX_VALUE);
KeyRangeIterator matchedKeys = searchMemtable(context, negExpression, keyRange, isNonReducing, Integer.MAX_VALUE);
return KeyRangeAntiJoinIterator.create(allKeys, matchedKeys);
}
}

public KeyRangeIterator searchMemtable(QueryContext context, Expression expression, AbstractBounds<PartitionPosition> keyRange, int limit)
public KeyRangeIterator searchMemtable(QueryContext context, Expression expression, AbstractBounds<PartitionPosition> keyRange, boolean isNonReducing, int limit)
{
if (expression.getOp().isNonEquality())
{
return getNonEqIterator(context, expression, keyRange);
return getNonEqIterator(context, expression, keyRange, isNonReducing);
}

Collection<MemtableIndex> memtables = liveMemtables.values();
Expand All @@ -471,7 +466,7 @@ public KeyRangeIterator searchMemtable(QueryContext context, Expression expressi
return KeyRangeIterator.empty();
}

KeyRangeUnionIterator.Builder builder = KeyRangeUnionIterator.builder();
KeyRangeUnionIterator.Builder builder = KeyRangeUnionIterator.builder(isNonReducing);

try
{
Expand All @@ -489,15 +484,15 @@ public KeyRangeIterator searchMemtable(QueryContext context, Expression expressi
}
}

private KeyRangeIterator scanMemtable(AbstractBounds<PartitionPosition> keyRange)
private KeyRangeIterator scanMemtable(AbstractBounds<PartitionPosition> keyRange, boolean isNonReducing)
{
Collection<Memtable> memtables = liveMemtables.keySet();
if (memtables.isEmpty())
{
return KeyRangeIterator.empty();
}

KeyRangeIterator.Builder builder = KeyRangeUnionIterator.builder(memtables.size());
KeyRangeIterator.Builder builder = KeyRangeUnionIterator.builder(memtables.size(), isNonReducing);

try
{
Expand All @@ -516,29 +511,6 @@ private KeyRangeIterator scanMemtable(AbstractBounds<PartitionPosition> keyRange
}
}

/**
 * Orders the supplied primary keys against every live memtable index.
 * <p>
 * Each live memtable index is asked to order {@code source}; the resulting iterators are
 * returned one per memtable index. If any memtable index throws, the iterators collected so far
 * are closed quietly and the exception is rethrown.
 *
 * @param context the query context handed to each memtable index
 * @param source  the primary keys to order
 * @param orderer the orderer handed to each memtable index
 * @param limit   the limit handed to each memtable index
 * @return one iterator per live memtable index, or an empty list when there are no live memtables
 */
public List<CloseableIterator<PrimaryKeyWithSortKey>> orderResultsBy(QueryContext context, List<PrimaryKey> source, Orderer orderer, int limit)
{
    Collection<MemtableIndex> indexes = liveMemtables.values();
    if (indexes.isEmpty())
        return List.of();

    List<CloseableIterator<PrimaryKeyWithSortKey>> iterators = new ArrayList<>(indexes.size());
    try
    {
        for (MemtableIndex memtableIndex : indexes)
        {
            iterators.add(memtableIndex.orderResultsBy(context, source, orderer, limit));
        }
    }
    catch (Exception ex)
    {
        // Do not leak the iterators that were already opened before the failure.
        FileUtils.closeQuietly(iterators);
        throw ex;
    }
    return iterators;
}

public long liveMemtableWriteCount()
{
return liveMemtables.values().stream().mapToLong(MemtableIndex::writeCount).sum();
Expand Down
14 changes: 7 additions & 7 deletions src/java/org/apache/cassandra/index/sai/SSTableIndex.java
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ private static SearchableIndex createSearchableIndex(SSTableContext sstableConte
if (CassandraRelevantProperties.SAI_INDEX_READS_DISABLED.getBoolean())
{
logger.info("Creating dummy (empty) index searcher for sstable {} as SAI index reads are disabled", sstableContext.sstable.descriptor);
return new EmptyIndex();
return new EmptyIndex(null);
}

return perIndexComponents.onDiskFormat().newSearchableIndex(sstableContext, perIndexComponents);
Expand Down Expand Up @@ -230,11 +230,11 @@ public KeyRangeIterator search(Expression expression,
int limit) throws IOException
{
if (expression.getOp().isNonEquality())
{
return getNonEqIterator(expression, keyRange, context, defer);
}

return searchableIndex.search(expression, keyRange, context, defer, limit);
else if (expression.getOp() == Expression.Op.IS_NULL)
return searchableIndex.searchNulls(keyRange, context);
else
return searchableIndex.search(expression, keyRange, context, defer, limit);
}

public List<CloseableIterator<PrimaryKeyWithSortKey>> orderBy(Orderer orderer,
Expand Down Expand Up @@ -332,9 +332,9 @@ public int hashCode()
return Objects.hashCode(sstableContext, indexContext);
}

public List<CloseableIterator<PrimaryKeyWithSortKey>> orderResultsBy(QueryContext context, List<PrimaryKey> keys, Orderer orderer, int limit, long totalRows) throws IOException
public List<CloseableIterator<PrimaryKeyWithSortKey>> orderResultsBy(QueryContext context, List<PrimaryKey> keys, Orderer orderer, int limit, long totalRows, boolean canSkipOutOfWindowPKs) throws IOException
{
return searchableIndex.orderResultsBy(context, keys, orderer, limit, totalRows);
return searchableIndex.orderResultsBy(context, keys, orderer, limit, totalRows, canSkipOutOfWindowPKs);
}

public String toString()
Expand Down
19 changes: 18 additions & 1 deletion src/java/org/apache/cassandra/index/sai/disk/EmptyIndex.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.apache.cassandra.db.virtual.SimpleDataSet;
import org.apache.cassandra.dht.AbstractBounds;
import org.apache.cassandra.index.sai.QueryContext;
import org.apache.cassandra.index.sai.SSTableContext;
import org.apache.cassandra.index.sai.disk.v1.Segment;
import org.apache.cassandra.index.sai.iterators.KeyRangeIterator;
import org.apache.cassandra.index.sai.plan.Expression;
Expand All @@ -38,6 +39,13 @@

public class EmptyIndex implements SearchableIndex
{
private final SSTableContext sstableContext;

/**
 * Creates an empty index that remembers the sstable context it belongs to.
 *
 * @param sstableContext the context stored on this index; no validation is performed here,
 *                       so callers may pass {@code null}
 */
public EmptyIndex(SSTableContext sstableContext)
{
this.sstableContext = sstableContext;
}

@Override
public long indexFileCacheSize()
{
Expand Down Expand Up @@ -96,6 +104,15 @@ public KeyRangeIterator search(Expression expression,
return KeyRangeIterator.empty();
}

/**
 * Returns the keys whose indexed column is null within the given key range.
 * <p>
 * Because this index holds no values, every row of the sstable counts as null for the column, so
 * the result is an iteration over the sstable's primary keys restricted to {@code keyRange}.
 *
 * @param keyRange the partition range to restrict the scan to
 * @param context  the query context (not used by this implementation)
 * @return an iterator over the primary keys in range, or an empty iterator when this index was
 *         created without an sstable context
 * @throws IOException if the primary key iterator cannot be created
 */
@Override
public KeyRangeIterator searchNulls(AbstractBounds<PartitionPosition> keyRange, QueryContext context) throws IOException
{
    // This index may be constructed with a null context (e.g. as a placeholder searcher).
    // Without a context there is no primary key map to scan, so behave like search() and
    // return nothing instead of failing with a NullPointerException.
    if (sstableContext == null)
        return KeyRangeIterator.empty();

    // If an index is empty, then all the rows were null.
    // TODO test the pathological case where we have a large sstable with an empty column.
    // In that case, we might have a performance regression here.
    return PrimaryKeyMapIterator.create(sstableContext, keyRange);
}

@Override
public List<CloseableIterator<PrimaryKeyWithSortKey>> orderBy(Orderer orderer,
Expression slice,
Expand Down Expand Up @@ -136,7 +153,7 @@ public void close() throws IOException
}

@Override
public List<CloseableIterator<PrimaryKeyWithSortKey>> orderResultsBy(QueryContext context, List<PrimaryKey> keys, Orderer orderer, int limit, long totalRows) throws IOException
public List<CloseableIterator<PrimaryKeyWithSortKey>> orderResultsBy(QueryContext context, List<PrimaryKey> keys, Orderer orderer, int limit, long totalRows, boolean canSkipOutOfWindowPKs) throws IOException
{
return List.of();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ protected PrimaryKey computeNext()
return endOfData();

var primaryKey = primaryKeyMap.primaryKeyFromRowId(rowId);
return new PrimaryKeyWithSource(primaryKey, primaryKeyMap.getSSTableId(), rowId);
return new PrimaryKeyWithSource(primaryKey, primaryKeyMap.getSSTableId(), rowId, primaryKeyMap.getMinTimestamp(), primaryKeyMap.getMaxTimestamp());
}
catch (Throwable t)
{
Expand Down Expand Up @@ -163,7 +163,7 @@ private long getNextRowId() throws IOException
{
long targetSstableRowId;
if (skipToToken instanceof PrimaryKeyWithSource
&& ((PrimaryKeyWithSource) skipToToken).getSourceSstableId().equals(primaryKeyMap.getSSTableId()))
&& ((PrimaryKeyWithSource) skipToToken).matchesSource(primaryKeyMap.getSSTableId()))
{
targetSstableRowId = ((PrimaryKeyWithSource) skipToToken).getSourceRowId();
}
Expand Down
12 changes: 12 additions & 0 deletions src/java/org/apache/cassandra/index/sai/disk/PrimaryKeyMap.java
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,18 @@ default void close() throws IOException
*/
SSTableId<?> getSSTableId();

/**
 * Returns the minimum timestamp for the sstable associated with this {@link PrimaryKeyMap}.
 *
 * @return the minimum timestamp
 */
long getMinTimestamp();

/**
 * Returns the maximum timestamp for the sstable associated with this {@link PrimaryKeyMap}.
 *
 * @return the maximum timestamp
 */
long getMaxTimestamp();

/**
* Returns a {@link PrimaryKey} for a row Id
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,19 @@ private enum KeyFilter

private final PrimaryKeyMap keys;
private final KeyFilter filter;
private final long minTimestamp;
private final long maxTimestamp;
private long currentRowId;


private PrimaryKeyMapIterator(PrimaryKeyMap keys, PrimaryKey min, PrimaryKey max, long startRowId, KeyFilter filter)
private PrimaryKeyMapIterator(PrimaryKeyMap keys, PrimaryKey min, PrimaryKey max, long startRowId, KeyFilter filter, long minTimestamp, long maxTimestamp)
{
super(min, max, keys.count());
this.keys = keys;
this.filter = filter;
this.currentRowId = startRowId;
this.minTimestamp = minTimestamp;
this.maxTimestamp = maxTimestamp;
}

public static KeyRangeIterator create(SSTableContext ctx, AbstractBounds<PartitionPosition> keyRange) throws IOException
Expand All @@ -76,16 +80,12 @@ public static KeyRangeIterator create(SSTableContext ctx, AbstractBounds<Partiti
else // the table doesn't contain anything we want to filter out, so let's use the cheap option
filter = KeyFilter.ALL;

if (perSSTableComponents.isEmpty())
long count = ctx.primaryKeyMapFactory.count();
if (count == 0)
return KeyRangeIterator.empty();

PrimaryKeyMap keys = ctx.primaryKeyMapFactory.newPerSSTablePrimaryKeyMap();
long count = keys.count();
if (keys.count() == 0)
{
keys.close();
return KeyRangeIterator.empty();
}
assert keys.count() == count : "Expected " + count + " keys, but got " + keys.count();

PrimaryKey.Factory pkFactory = ctx.primaryKeyFactory();
Token minToken = keyRange.left.getToken();
Expand All @@ -96,24 +96,26 @@ public static KeyRangeIterator create(SSTableContext ctx, AbstractBounds<Partiti
? minKeyBound
: sstableMinKey;
long startRowId = minToken.isMinimum() ? 0 : keys.ceiling(minKey);
return new PrimaryKeyMapIterator(keys, sstableMinKey, sstableMaxKey, startRowId, filter);
return new PrimaryKeyMapIterator(keys, sstableMinKey, sstableMaxKey, startRowId, filter, ctx.sstable().getMinTimestamp(), ctx.sstable().getMaxTimestamp());
}

@Override
protected void performSkipTo(PrimaryKey nextKey)
{
this.currentRowId = keys.ceiling(nextKey);
long possibleNextRowId = keys.ceiling(nextKey);
this.currentRowId = Math.max(possibleNextRowId, currentRowId);
}

@Override
protected PrimaryKey computeNext()
{
while (currentRowId >= 0 && currentRowId < keys.count())
{
PrimaryKey key = keys.primaryKeyFromRowId(currentRowId++);
long rowId = currentRowId++;
PrimaryKey key = keys.primaryKeyFromRowId(rowId);
if (filter == KeyFilter.KEYS_WITH_CLUSTERING && key.hasEmptyClustering())
continue;
return key;
return new PrimaryKeyWithSource(key, keys.getSSTableId(), rowId, minTimestamp, maxTimestamp);
}
return endOfData();
}
Expand Down
Loading