Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve electra attestation aggregation #8346

Merged
merged 8 commits into from
May 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ public MatchingDataAttestationGroup(
this.spec = spec;
this.attestationData = attestationData;
this.committeesSize = committeesSize;
includedValidators = createEmptyAttestationBits();
this.includedValidators = createEmptyAttestationBits();
}

private AttestationBitsAggregator createEmptyAttestationBits() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,7 @@
import com.google.common.base.MoreObjects;
import it.unimi.dsi.fastutil.ints.Int2IntMap;
import it.unimi.dsi.fastutil.ints.Int2IntOpenHashMap;
import it.unimi.dsi.fastutil.ints.IntArrayList;
import it.unimi.dsi.fastutil.ints.IntList;
import java.util.ArrayList;
import java.util.List;
import java.util.BitSet;
import java.util.stream.IntStream;
import tech.pegasys.teku.infrastructure.ssz.collections.SszBitlist;
import tech.pegasys.teku.infrastructure.ssz.collections.SszBitvector;
Expand All @@ -29,6 +26,7 @@
class AttestationBitsAggregatorElectra implements AttestationBitsAggregator {
private SszBitlist aggregationBits;
private SszBitvector committeeBits;
private Int2IntMap committeeBitsStartingPositions;
private final Int2IntMap committeesSize;

AttestationBitsAggregatorElectra(
Expand All @@ -38,6 +36,7 @@ class AttestationBitsAggregatorElectra implements AttestationBitsAggregator {
this.aggregationBits = aggregationBits;
this.committeeBits = committeeBits;
this.committeesSize = committeesSize;
this.committeeBitsStartingPositions = calculateCommitteeStartingPositions(committeeBits);
}

static AttestationBitsAggregator fromAttestationSchema(
Expand Down Expand Up @@ -70,101 +69,116 @@ private boolean or(
final SszBitlist otherAggregatedBits,
final boolean isAggregation) {

final SszBitvector aggregatedCommitteeBits = committeeBits.or(otherCommitteeBits);
final SszBitvector combinedCommitteeBits = committeeBits.or(otherCommitteeBits);

final Int2IntMap committeeBitsStartingPositions =
calculateCommitteeStartingPositions(committeeBits);
final Int2IntMap otherCommitteeBitsStartingPositions =
calculateCommitteeStartingPositions(otherCommitteeBits);
final Int2IntMap aggregatedCommitteeBitsStartingPositions =
calculateCommitteeStartingPositions(aggregatedCommitteeBits);

final IntList aggregatedCommitteeIndices = aggregatedCommitteeBits.getAllSetBits();
calculateCommitteeStartingPositions(combinedCommitteeBits);

// create an aggregation bit big as last boundary for last committee bit
final int lastCommitteeIndex =
aggregatedCommitteeIndices.getInt(aggregatedCommitteeIndices.size() - 1);
final int lastCommitteeIndex = combinedCommitteeBits.getLastSetBitIndex();
final int lastCommitteeStartingPosition =
aggregatedCommitteeBitsStartingPositions.get(lastCommitteeIndex);
final int combinedAggregationBitsSize =
lastCommitteeStartingPosition + committeesSize.get(lastCommitteeIndex);

final IntList aggregationIndices = new IntArrayList();

// aggregateBits contains a new set of bits
final BitSet combinedAggregationIndices = new BitSet(combinedAggregationBitsSize);

// let's go over all aggregated committees to calculate indices for the combined aggregation
// bits
try {
aggregatedCommitteeBits
combinedCommitteeBits
.streamAllSetBits()
.forEach(
committeeIndex -> {
int committeeSize = committeesSize.get(committeeIndex);
int destinationStart = aggregatedCommitteeBitsStartingPositions.get(committeeIndex);

final List<SszBitlist> sources = new ArrayList<>();
final List<Integer> sourcesStartingPosition = new ArrayList<>();
SszBitlist source1 = null, maybeSource2 = null;
int source1StartingPosition = 0, source2StartingPosition = 0;

if (committeeBitsStartingPositions.containsKey(committeeIndex)) {
sources.add(aggregationBits);
sourcesStartingPosition.add(committeeBitsStartingPositions.get(committeeIndex));
source1 = aggregationBits;
source1StartingPosition = committeeBitsStartingPositions.get(committeeIndex);
}
if (otherCommitteeBitsStartingPositions.containsKey(committeeIndex)) {
sources.add(otherAggregatedBits);
sourcesStartingPosition.add(
otherCommitteeBitsStartingPositions.get(committeeIndex));
if (source1 != null) {
maybeSource2 = otherAggregatedBits;
source2StartingPosition =
otherCommitteeBitsStartingPositions.get(committeeIndex);
} else {
source1 = otherAggregatedBits;
source1StartingPosition =
otherCommitteeBitsStartingPositions.get(committeeIndex);
}
}

IntStream.range(0, committeeSize)
.forEach(
positionInCommittee -> {
if (orSingleBit(
positionInCommittee,
sources,
sourcesStartingPosition,
isAggregation)) {
aggregationIndices.add(destinationStart + positionInCommittee);
}
});
// Now that we know:
// 1. which aggregationBits (this or other or both) will contribute to the result
// 2. the offset of the committee for each contributing aggregation bits
// We can go over the committee and calculate the combined aggregate bits
for (int positionInCommittee = 0;
positionInCommittee < committeeSize;
positionInCommittee++) {
if (orSingleBit(
positionInCommittee,
source1,
source1StartingPosition,
maybeSource2,
source2StartingPosition,
isAggregation)) {
combinedAggregationIndices.set(destinationStart + positionInCommittee);
}
}
});
} catch (final CannotAggregateException __) {
return false;
}

committeeBits = aggregatedCommitteeBits;
committeeBits = combinedCommitteeBits;
aggregationBits =
aggregationBits
.getSchema()
.ofBits(
lastCommitteeStartingPosition + committeesSize.get(lastCommitteeIndex),
aggregationIndices.toIntArray());
.wrapBitSet(combinedAggregationBitsSize, combinedAggregationIndices);
committeeBitsStartingPositions = aggregatedCommitteeBitsStartingPositions;

return true;
}

private boolean orSingleBit(
final int positionInCommittee,
final List<SszBitlist> sources,
final List<Integer> sourcesStartingPosition,
final boolean isAggregating) {
boolean aggregatedBit = false;
for (int s = 0; s < sources.size(); s++) {
final boolean sourceBit =
sources.get(s).getBit(sourcesStartingPosition.get(s) + positionInCommittee);

if (!aggregatedBit && sourceBit) {
aggregatedBit = true;
} else if (isAggregating && aggregatedBit && sourceBit) {
throw new CannotAggregateException();
}
final SszBitlist source1,
final int source1StartingPosition,
final SszBitlist maybeSource2,
final int source2StartingPosition,
final boolean isAggregation) {

final boolean source1Bit = source1.getBit(source1StartingPosition + positionInCommittee);

if (maybeSource2 == null) {
return source1Bit;
}
return aggregatedBit;

final boolean source2Bit = maybeSource2.getBit(source2StartingPosition + positionInCommittee);

if (isAggregation && source1Bit && source2Bit) {
throw new CannotAggregateException();
}

return source1Bit || source2Bit;
}

private Int2IntMap calculateCommitteeStartingPositions(final SszBitvector committeeBits) {
final Int2IntMap committeeBitsStartingPositions = new Int2IntOpenHashMap();
final IntList committeeIndices = committeeBits.getAllSetBits();
int currentOffset = 0;
for (final int index : committeeIndices) {
committeeBitsStartingPositions.put(index, currentOffset);
currentOffset += committeesSize.get(index);
}
final int[] currentOffset = {0};
committeeBits
.streamAllSetBits()
.forEach(
index -> {
committeeBitsStartingPositions.put(index, currentOffset[0]);
currentOffset[0] += committeesSize.get(index);
});

return committeeBitsStartingPositions;
}
Expand All @@ -175,14 +189,14 @@ public boolean isSuperSetOf(final Attestation other) {
return false;
}

if (committeeBits.equals(other.getCommitteeBitsRequired())) {
if (committeeBits.getBitCount() == other.getCommitteeBitsRequired().getBitCount()) {
// this committeeBits is a superset of the other, and bit count is the same, so they are the
// same set and we can directly compare aggregation bits.
return aggregationBits.isSuperSetOf(other.getAggregationBits());
}

final SszBitvector otherCommitteeBits = other.getCommitteeBitsRequired();

final Int2IntMap committeeBitsStartingPositions =
calculateCommitteeStartingPositions(committeeBits);
final Int2IntMap otherCommitteeBitsStartingPositions =
calculateCommitteeStartingPositions(otherCommitteeBits);

Expand Down Expand Up @@ -235,6 +249,7 @@ public String toString() {
.add("aggregationBits", aggregationBits)
.add("committeeBits", committeeBits)
.add("committeesSize", committeesSize)
.add("committeeBitsStartingPositions", committeeBitsStartingPositions)
.toString();
}
}
Loading