Skip to content

Commit

Permalink
refactor(merkle-chunk): revamp to stick to stick to supernova api
Browse files Browse the repository at this point in the history
  • Loading branch information
tchataigner committed Feb 29, 2024
1 parent f22c71b commit 19ba9a9
Show file tree
Hide file tree
Showing 5 changed files with 216 additions and 314 deletions.
126 changes: 72 additions & 54 deletions crates/chunk/examples/chunk_add_nivc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ use arecibo::supernova::{
};
use arecibo::traits::snark::default_ck_hint;
use arecibo::traits::{CurveCycleEquipped, Dual, Engine};
use bellpepper_chunk::traits::{ChunkCircuitInner, ChunkStepCircuit};
use bellpepper_chunk::{FoldStep, InnerCircuit};
use bellpepper_chunk::traits::ChunkStepCircuit;
use bellpepper_chunk::IterationStep;
use bellpepper_core::num::AllocatedNum;
use bellpepper_core::{ConstraintSystem, SynthesisError};
use ff::{Field, PrimeField};
Expand Down Expand Up @@ -60,18 +60,20 @@ impl<F: PrimeField> ChunkStepCircuit<F> for ChunkStep<F> {

// NIVC `StepCircuit`` implementation
#[derive(Clone, Debug)]
struct FoldStepWrapper<F: PrimeField, C: ChunkStepCircuit<F>, const N: usize> {
inner: FoldStep<F, C, N>,
struct IterationStepWrapper<F: PrimeField, C: ChunkStepCircuit<F>, const N: usize> {
inner: IterationStep<F, C, N>,
}

impl<F: PrimeField, C: ChunkStepCircuit<F>, const N: usize> FoldStepWrapper<F, C, N> {
pub fn new(fold_step: FoldStep<F, C, N>) -> Self {
Self { inner: fold_step }
impl<F: PrimeField, C: ChunkStepCircuit<F>, const N: usize> IterationStepWrapper<F, C, N> {
pub fn new(iteration_step: IterationStep<F, C, N>) -> Self {
Self {
inner: iteration_step,
}
}
}

impl<F: PrimeField, C: ChunkStepCircuit<F>, const N: usize> StepCircuit<F>
for FoldStepWrapper<F, C, N>
for IterationStepWrapper<F, C, N>
{
fn arity(&self) -> usize {
self.inner.arity()
Expand All @@ -89,40 +91,50 @@ impl<F: PrimeField, C: ChunkStepCircuit<F>, const N: usize> StepCircuit<F>
) -> Result<(Option<AllocatedNum<F>>, Vec<AllocatedNum<F>>), SynthesisError> {
let (next_pc, res_inner_synth) =
self.inner
.synthesize(&mut cs.namespace(|| "fold_step_wrapper"), pc, z)?;
.synthesize(&mut cs.namespace(|| "iteration_step_wrapper"), pc, z)?;

Ok((next_pc, res_inner_synth))
}
}

// NIVC `NonUniformCircuit` implementation
struct ChunkCircuit<F: PrimeField, C: ChunkStepCircuit<F>, const N: usize> {
inner: InnerCircuit<F, C, N>,
iteration_steps: Vec<IterationStep<F, C, N>>,
}

impl<F: PrimeField, C: ChunkStepCircuit<F>, const N: usize> ChunkCircuit<F, C, N> {
pub fn new(inner: InnerCircuit<F, C, N>) -> Self {
Self { inner }
pub fn new(inputs: &[F]) -> Self {
Self {
iteration_steps: IterationStep::from_inputs(0, inputs, F::ZERO).unwrap(),
}
}

fn get_iteration_step(&self, step: usize) -> IterationStep<F, C, N> {
self.iteration_steps[step].clone()
}

fn get_iteration_circuit(&self, step: usize) -> ChunkCircuitSet<F, C, N> {
ChunkCircuitSet::IterationStep(IterationStepWrapper::new(self.get_iteration_step(step)))
}
}

#[derive(Clone, Debug)]
enum ChunkCircuitSet<F: PrimeField, C: ChunkStepCircuit<F>, const N: usize> {
IterStep(FoldStepWrapper<F, C, N>),
IterationStep(IterationStepWrapper<F, C, N>),
}

impl<F: PrimeField, C: ChunkStepCircuit<F>, const N: usize> StepCircuit<F>
for ChunkCircuitSet<F, C, N>
{
fn arity(&self) -> usize {
match self {
Self::IterStep(fold_step) => fold_step.inner.arity(),
Self::IterationStep(iteration_step) => iteration_step.inner.arity(),
}
}

fn circuit_index(&self) -> usize {
match self {
Self::IterStep(fold_step) => *fold_step.inner.step_nbr(),
Self::IterationStep(iteration_step) => *iteration_step.inner.circuit_index(),
}
}

Expand All @@ -133,7 +145,7 @@ impl<F: PrimeField, C: ChunkStepCircuit<F>, const N: usize> StepCircuit<F>
z: &[AllocatedNum<F>],
) -> Result<(Option<AllocatedNum<F>>, Vec<AllocatedNum<F>>), SynthesisError> {
match self {
Self::IterStep(fold_step) => fold_step.synthesize(cs, pc, z),
Self::IterationStep(iteration_step) => iteration_step.synthesize(cs, pc, z),
}
}
}
Expand All @@ -145,14 +157,16 @@ impl<E1: CurveCycleEquipped, C: ChunkStepCircuit<E1::Scalar>, const N: usize> No
type C2 = TrivialSecondaryCircuit<<Dual<E1> as Engine>::Scalar>;

fn num_circuits(&self) -> usize {
self.inner.num_fold_steps()
1
}

fn primary_circuit(&self, circuit_index: usize) -> Self::C1 {
if let Some(fold_step) = self.inner.circuits().get(circuit_index) {
return Self::C1::IterStep(FoldStepWrapper::new(fold_step.clone()));
match circuit_index {
0 => {
Self::C1::IterationStep(IterationStepWrapper::new(self.iteration_steps[0].clone()))
}
_ => panic!("No circuit found for index {}", circuit_index),
}
unreachable!()
}

fn secondary_circuit(&self) -> Self::C2 {
Expand All @@ -163,54 +177,44 @@ impl<E1: CurveCycleEquipped, C: ChunkStepCircuit<E1::Scalar>, const N: usize> No
fn main() {
const NUM_ITERS_PER_STEP: usize = 3;

type Inner =
InnerCircuit<<E1 as Engine>::Scalar, ChunkStep<<E1 as Engine>::Scalar>, NUM_ITERS_PER_STEP>;
type C1 =
ChunkCircuit<<E1 as Engine>::Scalar, ChunkStep<<E1 as Engine>::Scalar>, NUM_ITERS_PER_STEP>;

println!("NIVC addition accumulator with a Chunk pattern");
println!("=========================================================");

let z0_primary = vec![
<E1 as Engine>::Scalar::zero(),
<E1 as Engine>::Scalar::zero(),
<E1 as Engine>::Scalar::zero(),
let inputs = vec![
<E1 as Engine>::Scalar::zero(),
<E1 as Engine>::Scalar::one(),
<E1 as Engine>::Scalar::from(2),
<E1 as Engine>::Scalar::from(3),
<E1 as Engine>::Scalar::from(4),
<E1 as Engine>::Scalar::from(5),
<E1 as Engine>::Scalar::from(6),
<E1 as Engine>::Scalar::from(7),
<E1 as Engine>::Scalar::from(8),
<E1 as Engine>::Scalar::from(9),
<E1 as Engine>::Scalar::from(10),
];

// Different instantiations of circuit for each of the nova fold steps
let inner_chunk_circuit = Inner::new(
&[
<E1 as Engine>::Scalar::one(),
<E1 as Engine>::Scalar::from(2),
<E1 as Engine>::Scalar::from(3),
<E1 as Engine>::Scalar::from(4),
<E1 as Engine>::Scalar::from(5),
<E1 as Engine>::Scalar::from(6),
<E1 as Engine>::Scalar::from(7),
<E1 as Engine>::Scalar::from(8),
<E1 as Engine>::Scalar::from(9),
<E1 as Engine>::Scalar::from(10),
],
None,
)
.unwrap();
let z0_primary = &[
&[<E1 as Engine>::Scalar::zero()],
&inputs[..NUM_ITERS_PER_STEP],
]
.concat();

let chunk_circuit = C1::new(inner_chunk_circuit);
let intermediate_inputs = &inputs[NUM_ITERS_PER_STEP..];

let circuit_secondary = <C1 as NonUniformCircuit<E1>>::secondary_circuit(&chunk_circuit);
// Different instantiations of circuit for each of the nova fold steps
let chunk_circuit = C1::new(intermediate_inputs);

// produce non-deterministic hint
assert_eq!(
<C1 as NonUniformCircuit<E1>>::num_circuits(&chunk_circuit),
5
);
let circuit_secondary = <C1 as NonUniformCircuit<E1>>::secondary_circuit(&chunk_circuit);

let z0_secondary = vec![<Dual<E1> as Engine>::Scalar::ZERO];

println!(
"Proving {} iterations of Chunk per step",
<C1 as NonUniformCircuit<E1>>::num_circuits(&chunk_circuit)
inputs.len() / NUM_ITERS_PER_STEP + 1
);

// produce public parameters
Expand All @@ -237,18 +241,32 @@ fn main() {

let start = Instant::now();

for step in 0..<C1 as NonUniformCircuit<E1>>::num_circuits(&chunk_circuit) {
let circuit_primary = <C1 as NonUniformCircuit<E1>>::primary_circuit(&chunk_circuit, step);

// We +1 the number of folding steps to account for the modulo of intermediate_inputs.len() by NUM_ITERS_PER_STEP being != 0
for step in 0..inputs.len() / NUM_ITERS_PER_STEP + 1 {
dbg!(format!(
"-----------------------------------{}-------------------------------------",
step
));
let circuit_primary = chunk_circuit.get_iteration_circuit(step);
dbg!(chunk_circuit.get_iteration_step(step).next_input());
let res = recursive_snark.prove_step(&pp, &circuit_primary, &circuit_secondary);
dbg!(&res);
assert!(res.is_ok());
println!(
"RecursiveSNARK::prove_step {}: {:?}, took {:?} ",
step,
res.is_ok(),
start.elapsed()
);

let res = recursive_snark.verify(&pp, &z0_primary, &z0_secondary);
dbg!(&res);
assert!(res.is_ok());
}
assert_eq!(
&<E1 as Engine>::Scalar::from(55),
recursive_snark.zi_primary().first().unwrap()
);
println!(
"Calculated sum: {:?}",
recursive_snark.zi_primary().first().unwrap()
Expand Down
Loading

0 comments on commit 19ba9a9

Please sign in to comment.