diff --git a/crates/chunk/examples/chunk_add_nivc.rs b/crates/chunk/examples/chunk_add_nivc.rs index 83f859f..f3e351b 100644 --- a/crates/chunk/examples/chunk_add_nivc.rs +++ b/crates/chunk/examples/chunk_add_nivc.rs @@ -4,8 +4,8 @@ use arecibo::supernova::{ }; use arecibo::traits::snark::default_ck_hint; use arecibo::traits::{CurveCycleEquipped, Dual, Engine}; -use bellpepper_chunk::traits::{ChunkCircuitInner, ChunkStepCircuit}; -use bellpepper_chunk::{FoldStep, InnerCircuit}; +use bellpepper_chunk::traits::ChunkStepCircuit; +use bellpepper_chunk::IterationStep; use bellpepper_core::num::AllocatedNum; use bellpepper_core::{ConstraintSystem, SynthesisError}; use ff::{Field, PrimeField}; @@ -60,18 +60,20 @@ impl ChunkStepCircuit for ChunkStep { // NIVC `StepCircuit`` implementation #[derive(Clone, Debug)] -struct FoldStepWrapper, const N: usize> { - inner: FoldStep, +struct IterationStepWrapper, const N: usize> { + inner: IterationStep, } -impl, const N: usize> FoldStepWrapper { - pub fn new(fold_step: FoldStep) -> Self { - Self { inner: fold_step } +impl, const N: usize> IterationStepWrapper { + pub fn new(iteration_step: IterationStep) -> Self { + Self { + inner: iteration_step, + } } } impl, const N: usize> StepCircuit - for FoldStepWrapper + for IterationStepWrapper { fn arity(&self) -> usize { self.inner.arity() @@ -89,7 +91,7 @@ impl, const N: usize> StepCircuit ) -> Result<(Option>, Vec>), SynthesisError> { let (next_pc, res_inner_synth) = self.inner - .synthesize(&mut cs.namespace(|| "fold_step_wrapper"), pc, z)?; + .synthesize(&mut cs.namespace(|| "iteration_step_wrapper"), pc, z)?; Ok((next_pc, res_inner_synth)) } @@ -97,18 +99,28 @@ impl, const N: usize> StepCircuit // NIVC `NonUniformCircuit` implementation struct ChunkCircuit, const N: usize> { - inner: InnerCircuit, + iteration_steps: Vec>, } impl, const N: usize> ChunkCircuit { - pub fn new(inner: InnerCircuit) -> Self { - Self { inner } + pub fn new(inputs: &[F]) -> Self { + Self { + iteration_steps: IterationStep::from_inputs(0, inputs, F::ZERO).unwrap(), + } + } + + fn get_iteration_step(&self, step: usize) -> IterationStep { + self.iteration_steps[step].clone() + } + + fn get_iteration_circuit(&self, step: usize) -> ChunkCircuitSet { + ChunkCircuitSet::IterationStep(IterationStepWrapper::new(self.get_iteration_step(step))) } } #[derive(Clone, Debug)] enum ChunkCircuitSet, const N: usize> { - IterStep(FoldStepWrapper), + IterationStep(IterationStepWrapper), } impl, const N: usize> StepCircuit @@ -116,13 +128,13 @@ impl, const N: usize> StepCircuit { fn arity(&self) -> usize { match self { - Self::IterStep(fold_step) => fold_step.inner.arity(), + Self::IterationStep(iteration_step) => iteration_step.inner.arity(), } } fn circuit_index(&self) -> usize { match self { - Self::IterStep(fold_step) => *fold_step.inner.step_nbr(), + Self::IterationStep(iteration_step) => *iteration_step.inner.circuit_index(), } } @@ -133,7 +145,7 @@ impl, const N: usize> StepCircuit z: &[AllocatedNum], ) -> Result<(Option>, Vec>), SynthesisError> { match self { - Self::IterStep(fold_step) => fold_step.synthesize(cs, pc, z), + Self::IterationStep(iteration_step) => iteration_step.synthesize(cs, pc, z), } } } @@ -145,14 +157,16 @@ impl, const N: usize> No type C2 = TrivialSecondaryCircuit< as Engine>::Scalar>; fn num_circuits(&self) -> usize { - self.inner.num_fold_steps() + 1 } fn primary_circuit(&self, circuit_index: usize) -> Self::C1 { - if let Some(fold_step) = self.inner.circuits().get(circuit_index) { - return Self::C1::IterStep(FoldStepWrapper::new(fold_step.clone())); + match circuit_index { + 0 => { + Self::C1::IterationStep(IterationStepWrapper::new(self.iteration_steps[0].clone())) + } + _ => panic!("No circuit found for index {}", circuit_index), } - unreachable!() } fn secondary_circuit(&self) -> Self::C2 { @@ -163,54 +177,44 @@ impl, const N: usize> No fn main() { const NUM_ITERS_PER_STEP: usize = 3; - type Inner = - InnerCircuit<::Scalar, ChunkStep<::Scalar>, NUM_ITERS_PER_STEP>; type C1 = ChunkCircuit<::Scalar, ChunkStep<::Scalar>, NUM_ITERS_PER_STEP>; println!("NIVC addition accumulator with a Chunk pattern"); println!("========================================================="); - let z0_primary = vec![ - ::Scalar::zero(), - ::Scalar::zero(), - ::Scalar::zero(), + let inputs = vec![ ::Scalar::zero(), + ::Scalar::one(), + ::Scalar::from(2), + ::Scalar::from(3), + ::Scalar::from(4), + ::Scalar::from(5), + ::Scalar::from(6), + ::Scalar::from(7), + ::Scalar::from(8), + ::Scalar::from(9), + ::Scalar::from(10), ]; - // Different instantiations of circuit for each of the nova fold steps - let inner_chunk_circuit = Inner::new( - &[ - ::Scalar::one(), - ::Scalar::from(2), - ::Scalar::from(3), - ::Scalar::from(4), - ::Scalar::from(5), - ::Scalar::from(6), - ::Scalar::from(7), - ::Scalar::from(8), - ::Scalar::from(9), - ::Scalar::from(10), - ], - None, - ) - .unwrap(); + let z0_primary = &[ + &[::Scalar::zero()], + &inputs[..NUM_ITERS_PER_STEP], + ] + .concat(); - let chunk_circuit = C1::new(inner_chunk_circuit); + let intermediate_inputs = &inputs[NUM_ITERS_PER_STEP..]; - let circuit_secondary = >::secondary_circuit(&chunk_circuit); + // Different instantiations of circuit for each of the nova fold steps + let chunk_circuit = C1::new(intermediate_inputs); - // produce non-deterministic hint - assert_eq!( - >::num_circuits(&chunk_circuit), - 5 - ); + let circuit_secondary = >::secondary_circuit(&chunk_circuit); let z0_secondary = vec![ as Engine>::Scalar::ZERO]; println!( "Proving {} iterations of Chunk per step", - >::num_circuits(&chunk_circuit) + inputs.len() / NUM_ITERS_PER_STEP + 1 ); // produce public parameters @@ -237,10 +241,16 @@ fn main() { let start = Instant::now(); - for step in 0..>::num_circuits(&chunk_circuit) { - let circuit_primary = >::primary_circuit(&chunk_circuit, step); - + // We +1 the number of folding steps to account for the modulo of intermediate_inputs.len() by NUM_ITERS_PER_STEP being != 0 + for step in 0..inputs.len() / NUM_ITERS_PER_STEP + 1 { + dbg!(format!( + "-----------------------------------{}-------------------------------------", + step + )); + let circuit_primary = chunk_circuit.get_iteration_circuit(step); + dbg!(chunk_circuit.get_iteration_step(step).next_input()); let res = recursive_snark.prove_step(&pp, &circuit_primary, &circuit_secondary); + dbg!(&res); assert!(res.is_ok()); println!( "RecursiveSNARK::prove_step {}: {:?}, took {:?} ", @@ -248,7 +258,15 @@ fn main() { res.is_ok(), start.elapsed() ); + + let res = recursive_snark.verify(&pp, &z0_primary, &z0_secondary); + dbg!(&res); + assert!(res.is_ok()); } + assert_eq!( + &::Scalar::from(55), + recursive_snark.zi_primary().first().unwrap() + ); println!( "Calculated sum: {:?}", recursive_snark.zi_primary().first().unwrap() diff --git a/crates/chunk/examples/chunk_merkle_proving.rs b/crates/chunk/examples/chunk_merkle_proving.rs index d92512f..1371307 100644 --- a/crates/chunk/examples/chunk_merkle_proving.rs +++ b/crates/chunk/examples/chunk_merkle_proving.rs @@ -7,8 +7,8 @@ use arecibo::traits::{CurveCycleEquipped, Dual, Engine}; use bellpepper::gadgets::boolean::Boolean; use bellpepper::gadgets::multipack::{bytes_to_bits_le, compute_multipacking, pack_bits}; use bellpepper::gadgets::num::AllocatedNum; -use bellpepper_chunk::traits::{ChunkCircuitInner, ChunkStepCircuit}; -use bellpepper_chunk::{FoldStep, InnerCircuit}; +use bellpepper_chunk::traits::ChunkStepCircuit; +use bellpepper_chunk::IterationStep; use bellpepper_core::{ConstraintSystem, SynthesisError}; use bellpepper_keccak::sha3; use bellpepper_merkle_inclusion::traits::GadgetDigest; @@ -86,22 +86,31 @@ fn reconstruct_hash>( *****************************************/ struct MerkleChunkCircuit, const N: usize> { - inner: InnerCircuit, + pub(crate) iteration_steps: Vec>, } impl, const N: usize> MerkleChunkCircuit { - fn new(inputs: &[F], post_processing_step: Option) -> Self { + fn new(inputs: &[F]) -> Self { Self { - inner: InnerCircuit::new(inputs, post_processing_step).unwrap(), + // We expect EqualityCircuit to be called once the last `IterationStep` is done. + iteration_steps: IterationStep::from_inputs(0, inputs, F::ONE).unwrap(), } } + + fn get_iteration_step(&self, step: usize) -> IterationStep { + self.iteration_steps[step].clone() + } + + fn get_iteration_circuit(&self, step: usize) -> ChunkCircuitSet { + ChunkCircuitSet::IterationStep(IterationStepWrapper::new(self.get_iteration_step(step))) + } } #[derive(Clone, Debug)] enum ChunkCircuitSet, const N: usize> { - IterStep(FoldStepWrapper), + IterationStep(IterationStepWrapper), CheckEquality(EqualityCircuit), } @@ -110,15 +119,15 @@ impl, const N: usize> Ste { fn arity(&self) -> usize { match self { - Self::IterStep(fold_step) => fold_step.inner.arity(), + Self::IterationStep(iteration_step) => iteration_step.inner.arity(), Self::CheckEquality(equality_circuit) => equality_circuit.arity(), } } fn circuit_index(&self) -> usize { match self { - Self::IterStep(fold_step) => *fold_step.inner.step_nbr(), - Self::CheckEquality(equality_circuit) => equality_circuit.circuit_index(), + Self::IterationStep(iteration_step) => *iteration_step.inner.circuit_index(), + Self::CheckEquality(equality_circuit) => equality_circuit.circuit_index(),² } } @@ -129,7 +138,7 @@ impl, const N: usize> Ste z: &[AllocatedNum], ) -> Result<(Option>, Vec>), SynthesisError> { match self { - Self::IterStep(fold_step) => fold_step.synthesize(cs, pc, z), + Self::IterationStep(iteration_step) => iteration_step.synthesize(cs, pc, z), Self::CheckEquality(equality_circuit) => equality_circuit.synthesize(cs, pc, z), } } @@ -142,17 +151,14 @@ impl, const N: usize> No type C2 = TrivialSecondaryCircuit< as Engine>::Scalar>; fn num_circuits(&self) -> usize { - self.inner.num_fold_steps() + 1 + 2 } fn primary_circuit(&self, circuit_index: usize) -> Self::C1 { - if circuit_index == 2 { - Self::C1::CheckEquality(EqualityCircuit::new()) - } else { - if let Some(fold_step) = self.inner.circuits().get(circuit_index) { - return Self::C1::IterStep(FoldStepWrapper::new(fold_step.clone())); - } - panic!("No circuit found for index {}", circuit_index) + match circuit_index { + 0 => self.get_iteration_circuit(0), + 1 => Self::C1::CheckEquality(EqualityCircuit::new()), + _ => panic!("No circuit found for index {}", circuit_index), } } @@ -195,23 +201,26 @@ impl ChunkStepCircuit for ChunkStep { .to_bits_le(&mut cs.namespace(|| "get positional bit")) .unwrap()[0]; - let acc = reconstruct_hash(&mut cs.namespace(|| "reconstruct acc hash"), &z[0..2], 256); - - let sibling = reconstruct_hash( - &mut cs.namespace(|| "reconstruct_sibling_hash"), - &chunk_in[1..3], - 256, - ); - - let new_acc = conditional_hash::<_, _, Sha3>( - &mut cs.namespace(|| "conditional_hash"), - &acc, - &sibling, - boolean, - )?; + let mut acc = reconstruct_hash(&mut cs.namespace(|| "reconstruct acc hash"), &z[0..2], 256); + + // The inputs we handle for one inner iterations are multiple of 3. + for chunk in chunk_in.chunks(3) { + let sibling = reconstruct_hash( + &mut cs.namespace(|| "reconstruct_sibling_hash"), + &chunk[1..3], + 256, + ); + + acc = conditional_hash::<_, _, Sha3>( + &mut cs.namespace(|| "conditional_hash"), + &acc, + &sibling, + boolean, + )?; + } - let new_acc_f_1 = pack_bits(&mut cs.namespace(|| "pack_bits new_acc 1"), &new_acc[..253])?; - let new_acc_f_2 = pack_bits(&mut cs.namespace(|| "pack_bits new_acc 2"), &new_acc[253..])?; + let new_acc_f_1 = pack_bits(&mut cs.namespace(|| "pack_bits new_acc 1"), &acc[..253])?; + let new_acc_f_2 = pack_bits(&mut cs.namespace(|| "pack_bits new_acc 2"), &acc[253..])?; let z_out = vec![new_acc_f_1, new_acc_f_2, z[2].clone(), z[3].clone()]; @@ -220,25 +229,27 @@ impl ChunkStepCircuit for ChunkStep { } #[derive(Clone, Debug)] -struct FoldStepWrapper, const N: usize> { - inner: FoldStep, +struct IterationStepWrapper, const N: usize> { + inner: IterationStep, } -impl, const N: usize> FoldStepWrapper { - pub fn new(fold_step: FoldStep) -> Self { - Self { inner: fold_step } +impl, const N: usize> IterationStepWrapper { + pub fn new(iteration_step: IterationStep) -> Self { + Self { + inner: iteration_step, + } } } impl, const N: usize> StepCircuit - for FoldStepWrapper + for IterationStepWrapper { fn arity(&self) -> usize { self.inner.arity() } fn circuit_index(&self) -> usize { - *self.inner.step_nbr() + *self.inner.circuit_index() } fn synthesize>( @@ -247,11 +258,8 @@ impl, const N: usize> StepCircuit pc: Option<&AllocatedNum>, z: &[AllocatedNum], ) -> Result<(Option>, Vec>), SynthesisError> { - let (next_pc, res_inner_synth) = - self.inner - .synthesize(&mut cs.namespace(|| "fold_step_wrapper"), pc, z)?; - - Ok((next_pc, res_inner_synth)) + self.inner + .synthesize(&mut cs.namespace(|| "iteration_step_wrapper"), pc, z) } } @@ -274,7 +282,7 @@ impl StepCircuit for EqualityCircuit { } fn circuit_index(&self) -> usize { - 2 + 1 } fn synthesize>( @@ -315,6 +323,8 @@ impl StepCircuit for EqualityCircuit { } } +const NBR_CHUNK_INPUT: usize = 3; + fn main() { // produce public parameters let start = Instant::now(); @@ -351,11 +361,12 @@ fn main() { intermediate_key_hashes.append(&mut intermediate_hashes[2..4].to_vec()); // Primary circuit - type C1 = MerkleChunkCircuit<::Scalar, ChunkStep<::Scalar>, 3>; - let chunk_circuit = C1::new( - &intermediate_key_hashes[3..6], - Some(::Scalar::from(2)), - ); + type C1 = MerkleChunkCircuit< + ::Scalar, + ChunkStep<::Scalar>, + NBR_CHUNK_INPUT, + >; + let chunk_circuit = C1::new(&intermediate_key_hashes[NBR_CHUNK_INPUT..]); // Multipacking the leaf and root hashes let mut z0_primary = @@ -365,7 +376,7 @@ fn main() { // The accumulator elements are initialized to 0 z0_primary.append(&mut root_fields.clone()); - z0_primary.append(&mut intermediate_key_hashes[0..3].to_vec()); + z0_primary.append(&mut intermediate_key_hashes[..NBR_CHUNK_INPUT].to_vec()); let circuit_primary = >::primary_circuit(&chunk_circuit, 0); @@ -389,8 +400,15 @@ fn main() { ) .unwrap(); - for step in 0..>::num_circuits(&chunk_circuit) { - let circuit_primary = >::primary_circuit(&chunk_circuit, step); + // We expect nbr_inputs/chunk_value + 1 (post processning circuit) iterations. + for step in 0..intermediate_key_hashes.len() / NBR_CHUNK_INPUT + 1 { + let circuit_primary = if step == intermediate_key_hashes.len() / NBR_CHUNK_INPUT { + // Check equality + >::primary_circuit(&chunk_circuit, 1) + } else { + // Iteration step + chunk_circuit.get_iteration_circuit(step) + }; let res = recursive_snark.prove_step(&pp, &circuit_primary, &circuit_secondary); assert!(res.is_ok()); diff --git a/crates/chunk/src/lib.rs b/crates/chunk/src/lib.rs index 35774d3..ff94924 100644 --- a/crates/chunk/src/lib.rs +++ b/crates/chunk/src/lib.rs @@ -1,5 +1,5 @@ use crate::error::ChunkError; -use crate::traits::{ChunkCircuitInner, ChunkStepCircuit}; +use crate::traits::ChunkStepCircuit; use bellpepper_core::num::AllocatedNum; use bellpepper_core::{ConstraintSystem, SynthesisError}; use ff::PrimeField; @@ -8,40 +8,43 @@ use getset::Getters; pub mod error; pub mod traits; -/// `FoldStep` is the wrapper struct for a `ChunkStepCircuit` implemented by a user. It exist to synthesize multiple of -/// the `ChunkStepCircuit` instances at once. +/// `IterationStep` is the wrapper struct for a `ChunkStepCircuit` implemented by a user. #[derive(Eq, PartialEq, Debug, Getters)] #[getset(get = "pub")] -pub struct FoldStep + Clone, const N: usize> { - /// The step number of the `FoldStep` in the circuit. - step_nbr: usize, - /// Next circuit index. - next_circuit: Option, +pub struct IterationStep + Clone, const N: usize> { + /// The circuit index in the higher order `NonUniformCircuit`. + circuit_index: usize, /// The `ChunkStepCircuit` instance to be used in the `FoldStep`. circuit: C, + /// The step number of the `FoldStep` in the circuit. + step_nbr: usize, /// Number of input to be expected input_nbr: usize, /// The next input values for the next `ChunkStepCircuit` instance. next_input: [F; N], + /// Next program counter. + next_pc: F, } -impl + Clone, const N: usize> FoldStep { +impl + Clone, const N: usize> IterationStep { pub fn arity(&self) -> usize { N + C::arity() } pub fn new( + circuit_index: usize, circuit: C, inputs: [F; N], input_nbr: usize, step_nbr: usize, - next_circuit: Option, + next_circuit: F, ) -> Self { Self { + circuit_index, circuit, next_input: inputs, input_nbr, step_nbr, - next_circuit, + next_pc: next_circuit, } } @@ -63,16 +66,12 @@ impl + Clone, const N: usize> FoldStep { - AllocatedNum::alloc_infallible(cs.namespace(|| "next_circuit"), || *next_circuit) - } - None => AllocatedNum::alloc_infallible(cs.namespace(|| "next_circuit"), || F::ZERO), - }; + let next_pc = + AllocatedNum::alloc_infallible(cs.namespace(|| "next_circuit"), || self.next_pc); // Next input let next_inputs_allocated = self @@ -87,38 +86,12 @@ impl + Clone, const N: usize> FoldStep, const N: usize> Clone for FoldStep { - fn clone(&self) -> Self { - Self { - step_nbr: self.step_nbr, - circuit: self.circuit.clone(), - next_input: self.next_input, - next_circuit: self.next_circuit, - input_nbr: self.input_nbr, - } - } -} - -/// `Circuit` is a helper structure that handles the plumbing of generating the necessary number of `FoldStep` instances -/// to properly prove and verifiy a circuit. -#[derive(Debug, Getters)] -pub struct InnerCircuit, const N: usize> { - /// The `FoldStep` instances that are part of the circuit. - #[getset(get = "pub")] - circuits: Vec>, - /// The number of folding step required in the recursive snark to prove and verify the circuit. - num_fold_steps: usize, -} - -impl, const N: usize> ChunkCircuitInner - for InnerCircuit -{ - fn new( + pub fn from_inputs( + circuit_index: usize, intermediate_steps_input: &[F], - post_processing_circuit: Option, - ) -> anyhow::Result { + pc_post_iter: F, + ) -> anyhow::Result> { // We generate the `FoldStep` instances that are part of the circuit. let mut circuits = intermediate_steps_input .chunks(N) @@ -132,19 +105,21 @@ impl, const N: usize> ChunkCircuitInner, ChunkError>>()?; // As the input represents the generated values by the inner loop, we need to add one more execution to have // a complete circuit and a proper accumulator value. - circuits.push(FoldStep::new( + circuits.push(IterationStep::new( + circuit_index, C::new(), [F::ZERO; N], if intermediate_steps_input.len() % N != 0 { @@ -153,25 +128,22 @@ impl, const N: usize> ChunkCircuitInner Option<&FoldStep> { - self.circuits.first() + Ok(circuits) } +} - fn num_fold_steps(&self) -> usize { - self.num_fold_steps +impl, const N: usize> Clone for IterationStep { + fn clone(&self) -> Self { + Self { + circuit_index: self.circuit_index, + step_nbr: self.step_nbr, + circuit: self.circuit.clone(), + next_input: self.next_input, + next_pc: self.next_pc, + input_nbr: self.input_nbr, + } } } diff --git a/crates/chunk/src/traits.rs b/crates/chunk/src/traits.rs index cde7d61..0751913 100644 --- a/crates/chunk/src/traits.rs +++ b/crates/chunk/src/traits.rs @@ -1,5 +1,3 @@ -use crate::error::ChunkError; -use crate::FoldStep; use bellpepper_core::num::AllocatedNum; use bellpepper_core::{ConstraintSystem, SynthesisError}; use ff::PrimeField; @@ -30,26 +28,3 @@ pub trait ChunkStepCircuit: Clone + Sync + Send + Debug + Partial chunk_in: &[AllocatedNum], ) -> Result>, SynthesisError>; } - -/// `ChunkCircuit` is the trait used to interface with a circuit that is composed of a loop of steps. -pub trait ChunkCircuitInner, const N: usize> { - /// `new` must return a new instance of the chunk circuit. - /// # Arguments - /// * `intermediate_steps_input` - The intermediate input values for each of the step circuits. - /// * `post_processing_circuit` - The post processing circuit to be used after the loop of steps. - /// - /// # Note - /// - /// As `intermediate_steps_input` represents the input values for each of the step circuits, there is currently a need - /// to generate one last `FoldStep` instance to represent the last step in the circuit. - fn new( - intermediate_steps_input: &[F], - post_processing_circuit: Option, - ) -> anyhow::Result - where - Self: Sized; - /// `initial_input` must return the first circuit to be proven/verified. - fn initial_input(&self) -> Option<&FoldStep>; - /// `num_fold_steps` must return the number of recursive snark step necessary to prove and verify the circuit. - fn num_fold_steps(&self) -> usize; -} diff --git a/crates/chunk/tests/gadget.rs b/crates/chunk/tests/gadget.rs index 8a130bf..3ba8e50 100644 --- a/crates/chunk/tests/gadget.rs +++ b/crates/chunk/tests/gadget.rs @@ -1,8 +1,7 @@ use arecibo::provider::Bn256EngineKZG; -use arecibo::supernova::{NonUniformCircuit, StepCircuit, TrivialSecondaryCircuit}; -use arecibo::traits::{CurveCycleEquipped, Dual, Engine}; -use bellpepper_chunk::traits::{ChunkCircuitInner, ChunkStepCircuit}; -use bellpepper_chunk::{FoldStep, InnerCircuit}; +use arecibo::traits::Engine; +use bellpepper_chunk::traits::ChunkStepCircuit; +use bellpepper_chunk::IterationStep; use bellpepper_core::num::AllocatedNum; use bellpepper_core::{ConstraintSystem, SynthesisError}; use ff::PrimeField; @@ -39,110 +38,14 @@ impl ChunkStepCircuit for ChunkStep { } } -// NIVC `StepCircuit`` implementation -#[derive(Clone, Debug)] -struct FoldStepWrapper, const N: usize> { - inner: FoldStep, -} - -impl, const N: usize> FoldStepWrapper { - pub fn new(fold_step: FoldStep) -> Self { - Self { inner: fold_step } - } -} - -impl, const N: usize> StepCircuit - for FoldStepWrapper -{ - fn arity(&self) -> usize { - self.inner.arity() - } - - fn circuit_index(&self) -> usize { - *self.inner.step_nbr() - } - - fn synthesize>( - &self, - cs: &mut CS, - pc: Option<&AllocatedNum>, - z: &[AllocatedNum], - ) -> Result<(Option>, Vec>), SynthesisError> { - let (next_pc, res_inner_synth) = - self.inner - .synthesize(&mut cs.namespace(|| "fold_step_wrapper"), pc, z)?; - - Ok((next_pc, res_inner_synth)) - } -} - -// NIVC `NonUniformCircuit` implementation -struct ChunkCircuit, const N: usize> { - inner: InnerCircuit, -} - -#[derive(Clone, Debug)] -enum ChunkCircuitSet, const N: usize> { - IterStep(FoldStepWrapper), -} - -impl, const N: usize> StepCircuit - for ChunkCircuitSet -{ - fn arity(&self) -> usize { - match self { - Self::IterStep(fold_step) => fold_step.inner.arity(), - } - } - - fn circuit_index(&self) -> usize { - match self { - Self::IterStep(fold_step) => *fold_step.inner.step_nbr(), - } - } - - fn synthesize>( - &self, - cs: &mut CS, - pc: Option<&AllocatedNum>, - z: &[AllocatedNum], - ) -> Result<(Option>, Vec>), SynthesisError> { - match self { - Self::IterStep(fold_step) => fold_step.synthesize(cs, pc, z), - } - } -} - -impl, const N: usize> NonUniformCircuit - for ChunkCircuit -{ - type C1 = ChunkCircuitSet; - type C2 = TrivialSecondaryCircuit< as Engine>::Scalar>; - - fn num_circuits(&self) -> usize { - self.inner.num_fold_steps() - } - - fn primary_circuit(&self, circuit_index: usize) -> Self::C1 { - if let Some(fold_step) = self.inner.circuits().get(circuit_index) { - return Self::C1::IterStep(FoldStepWrapper::new(fold_step.clone())); - } - unreachable!() - } - - fn secondary_circuit(&self) -> Self::C2 { - Default::default() - } -} - fn verify_chunk_circuit, const N: usize>() { let test_inputs = vec![F::ONE; 18]; let expected = (test_inputs.len() / N) + if test_inputs.len() % N != 0 { 2 } else { 1 }; - let circuit = InnerCircuit::::new(&test_inputs, None).unwrap(); + let circuits = IterationStep::::from_inputs(0, &test_inputs, F::ONE).unwrap(); - let actual = circuit.num_fold_steps(); + let actual = circuits.len(); assert_eq!( &expected, &actual, @@ -158,26 +61,42 @@ fn verify_chunk_circuit, const N: usize>() } } - let actual_init = circuit.initial_input().unwrap(); - let expected_init = - FoldStep::::new(C::new(), expected_first_chunk, N, 0, Some(F::ONE)); + let actual_first = &circuits[0]; + let expected_first = + IterationStep::::new(0, C::new(), expected_first_chunk, N, 0, F::ZERO); assert_eq!( - *actual_init, expected_init, - "Expected initial input to be {:?}, got {:?}", - expected_init, actual_init + actual_first, &expected_first, + "Expected first iteration step to be {:?}, got {:?}", + expected_first, actual_first ); - let actual_circuits = circuit.circuits(); - - for (i, actual_circuit) in actual_circuits.iter().enumerate() { + for (i, circuit) in circuits[..circuits.len() - 1].iter().enumerate() { assert_eq!( &i, - actual_circuit.step_nbr(), + circuit.step_nbr(), "Expected inner step nbr to be {:?}, got {:?}", i, - actual_circuit.step_nbr() + circuit.step_nbr() ); + + if i == circuits.len() - 1 { + assert_eq!( + &F::ONE, + circuit.next_pc(), + "Expected inner step nbr to be {:?}, got {:?}", + F::from(1), + circuit.next_pc() + ); + } else { + assert_eq!( + &F::ZERO, + circuit.next_pc(), + "Expected inner step nbr to be {:?}, got {:?}", + F::from(0), + circuit.next_pc() + ); + } } }