Skip to content

Commit

Permalink
sharded func chips, execute
Browse files Browse the repository at this point in the history
  • Loading branch information
gabriel-barrett committed Nov 8, 2024
1 parent 0e313cf commit ce11eaa
Show file tree
Hide file tree
Showing 8 changed files with 57 additions and 36 deletions.
12 changes: 3 additions & 9 deletions benches/fib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,7 @@ fn evaluation(c: &mut Criterion) {
b.iter_batched(
|| (args.clone(), record.clone()),
|(args, mut queries)| {
toplevel
.execute(lurk_main.func(), &args, &mut queries, None)
.unwrap();
lurk_main.execute(&args, &mut queries, None).unwrap();
},
BatchSize::SmallInput,
)
Expand All @@ -90,9 +88,7 @@ fn trace_generation(c: &mut Criterion) {
c.bench_function(&format!("fib-trace-generation-{arg}"), |b| {
let (toplevel, ..) = build_lurk_toplevel_native();
let (args, lurk_main, mut record) = setup(arg, &toplevel);
toplevel
.execute(lurk_main.func(), &args, &mut record, None)
.unwrap();
lurk_main.execute(&args, &mut record, None).unwrap();
let lair_chips = build_lair_chip_vector(&lurk_main);
b.iter(|| {
lair_chips.par_iter().for_each(|func_chip| {
Expand All @@ -112,9 +108,7 @@ fn e2e(c: &mut Criterion) {
b.iter_batched(
|| (record.clone(), args.clone()),
|(mut record, args)| {
toplevel
.execute(lurk_main.func(), &args, &mut record, None)
.unwrap();
lurk_main.execute(&args, &mut record, None).unwrap();
let config = BabyBearPoseidon2::new();
let machine = StarkMachine::new(
config,
Expand Down
12 changes: 3 additions & 9 deletions benches/lcs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,7 @@ fn evaluation(c: &mut Criterion) {
b.iter_batched(
|| (args.clone(), record.clone()),
|(args, mut queries)| {
toplevel
.execute(lurk_main.func(), &args, &mut queries, None)
.unwrap();
lurk_main.execute(&args, &mut queries, None).unwrap();
},
BatchSize::SmallInput,
)
Expand All @@ -94,9 +92,7 @@ fn trace_generation(c: &mut Criterion) {
c.bench_function("lcs-trace-generation", |b| {
let (toplevel, ..) = build_lurk_toplevel_native();
let (args, lurk_main, mut record) = setup(args.0, args.1, &toplevel);
toplevel
.execute(lurk_main.func(), &args, &mut record, None)
.unwrap();
lurk_main.execute(&args, &mut record, None).unwrap();
let lair_chips = build_lair_chip_vector(&lurk_main);
b.iter(|| {
lair_chips.par_iter().for_each(|func_chip| {
Expand All @@ -116,9 +112,7 @@ fn e2e(c: &mut Criterion) {
b.iter_batched(
|| (record.clone(), args.clone()),
|(mut record, args)| {
toplevel
.execute(lurk_main.func(), &args, &mut record, None)
.unwrap();
lurk_main.execute(&args, &mut record, None).unwrap();
let config = BabyBearPoseidon2::new();
let machine = StarkMachine::new(
config,
Expand Down
12 changes: 3 additions & 9 deletions benches/sum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,7 @@ fn evaluation(c: &mut Criterion) {
b.iter_batched(
|| (args.clone(), record.clone()),
|(args, mut queries)| {
toplevel
.execute(lurk_main.func(), &args, &mut queries, None)
.unwrap();
lurk_main.execute(&args, &mut queries, None).unwrap();
},
BatchSize::SmallInput,
)
Expand All @@ -95,9 +93,7 @@ fn trace_generation(c: &mut Criterion) {
c.bench_function(&format!("sum-trace-generation-{arg}"), |b| {
let (toplevel, ..) = build_lurk_toplevel_native();
let (args, lurk_main, mut record) = setup(arg, &toplevel);
toplevel
.execute(lurk_main.func(), &args, &mut record, None)
.unwrap();
lurk_main.execute(&args, &mut record, None).unwrap();
let lair_chips = build_lair_chip_vector(&lurk_main);
b.iter(|| {
lair_chips.par_iter().for_each(|func_chip| {
Expand All @@ -117,9 +113,7 @@ fn e2e(c: &mut Criterion) {
b.iter_batched(
|| (record.clone(), args.clone()),
|(mut record, args)| {
toplevel
.execute(lurk_main.func(), &args, &mut record, None)
.unwrap();
lurk_main.execute(&args, &mut record, None).unwrap();
let config = BabyBearPoseidon2::new();
let machine = StarkMachine::new(
config,
Expand Down
4 changes: 1 addition & 3 deletions src/core/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,7 @@ fn run_tests<C2: Chipset<F>>(
input[16..].copy_from_slice(&env.digest);

let lurk_main = FuncChip::from_name("lurk_main", toplevel);
let result = toplevel
.execute(lurk_main.func, &input, &mut record, None)
.unwrap();
let result = lurk_main.execute(&input, &mut record, None).unwrap();

assert_eq!(result.as_ref(), &expected_cloj(zstore).flatten());

Expand Down
15 changes: 15 additions & 0 deletions src/lair/execute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use super::{
bytecode::{Ctrl, Func, Op},
chipset::Chipset,
expr::ReturnGroup,
func_chip::FuncChip,
toplevel::Toplevel,
FxIndexMap, List,
};
Expand Down Expand Up @@ -374,6 +375,20 @@ impl<F: PrimeField32> QueryRecord<F> {
}
}

impl<F: PrimeField32, C1: Chipset<F>, C2: Chipset<F>> FuncChip<'_, F, C1, C2> {
#[inline]
pub fn execute(
&self,
args: &[F],
queries: &mut QueryRecord<F>,
dbg_func_idx: Option<usize>,
) -> Result<List<F>> {
let toplevel = self.toplevel;
let func = toplevel.func_by_index(self.func.index);
toplevel.execute(func, args, queries, dbg_func_idx)
}
}

impl<F: PrimeField32, C1: Chipset<F>, C2: Chipset<F>> Toplevel<F, C1, C2> {
pub fn execute(
&self,
Expand Down
13 changes: 10 additions & 3 deletions src/lair/func_chip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,15 @@ pub struct FuncChip<'a, F, C1: Chipset<F>, C2: Chipset<F>> {
impl<'a, F, C1: Chipset<F>, C2: Chipset<F>> FuncChip<'a, F, C1, C2> {
#[inline]
pub fn from_name(name: &'static str, toplevel: &'a Toplevel<F, C1, C2>) -> Self {
let func = toplevel.func_by_name(name);
let main_group = 0;
let func = toplevel.sharded_func_by_name(name, main_group);
Self::from_func(func, toplevel)
}

#[inline]
pub fn from_index(idx: usize, toplevel: &'a Toplevel<F, C1, C2>) -> Self {
let func = toplevel.func_by_index(idx);
let main_group = 0;
let func = toplevel.sharded_func_by_index(idx, main_group);
Self::from_func(func, toplevel)
}

Expand All @@ -59,7 +61,12 @@ impl<'a, F, C1: Chipset<F>, C2: Chipset<F>> FuncChip<'a, F, C1, C2> {
toplevel
.func_map
.values()
.map(|funcs| FuncChip::from_func(&funcs.full_func, toplevel))
.flat_map(|funcs| {
funcs
.sharded_funcs
.values()
.map(|func| FuncChip::from_func(func, toplevel))
})
.collect()
}

Expand Down
21 changes: 21 additions & 0 deletions src/lair/toplevel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,17 @@ impl<F, C1: Chipset<F>, C2: Chipset<F>> Toplevel<F, C1, C2> {
.full_func
}

#[inline]
pub fn sharded_func_by_index(&self, i: usize, group: ReturnGroup) -> &Func<F> {
self.func_map
.get_index(i)
.unwrap_or_else(|| panic!("Func index {i} out of bounds"))
.1
.sharded_funcs
.get(&group)
.unwrap_or_else(|| panic!("Group {group} not found"))
}

#[inline]
pub fn func_by_name(&self, name: &'static str) -> &Func<F> {
&self
Expand All @@ -94,6 +105,16 @@ impl<F, C1: Chipset<F>, C2: Chipset<F>> Toplevel<F, C1, C2> {
.full_func
}

#[inline]
pub fn sharded_func_by_name(&self, name: &'static str, group: ReturnGroup) -> &Func<F> {
self.func_map
.get(&Name(name))
.unwrap_or_else(|| panic!("Func {name} not found"))
.sharded_funcs
.get(&group)
.unwrap_or_else(|| panic!("Group {group} not found"))
}

#[inline]
pub fn num_funcs(&self) -> usize {
self.func_map.len()
Expand Down
4 changes: 1 addition & 3 deletions tests/fib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,7 @@ fn fib_e2e() {
let (args, lurk_main, mut record) = setup(arg, &toplevel);
let start_time = Instant::now();

toplevel
.execute(lurk_main.func(), &args, &mut record, None)
.unwrap();
lurk_main.execute(&args, &mut record, None).unwrap();
let config = BabyBearPoseidon2::new();
let machine = StarkMachine::new(
config,
Expand Down

0 comments on commit ce11eaa

Please sign in to comment.