Skip to content

Commit

Permalink
sharded funcs
Browse files Browse the repository at this point in the history
  • Loading branch information
gabriel-barrett committed Nov 8, 2024
1 parent e1ff74b commit 0e313cf
Show file tree
Hide file tree
Showing 6 changed files with 159 additions and 16 deletions.
1 change: 1 addition & 0 deletions src/lair/air.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ where
<AB as AirBuilder>::Var: Debug,
{
fn eval(&self, builder: &mut AB) {
assert!(self.func.sharded);
self.func.eval(builder, self.toplevel, self.layout_sizes)
}
}
Expand Down
2 changes: 2 additions & 0 deletions src/lair/bytecode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,8 @@ pub struct Func<F> {
pub(crate) input_size: usize,
pub(crate) output_size: usize,
pub(crate) body: Block<F>,
// This last field is purely to catch potential bugs
pub(crate) sharded: bool,
}

impl<F> Func<F> {
Expand Down
5 changes: 3 additions & 2 deletions src/lair/execute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -264,8 +264,8 @@ impl<F: PrimeField32> QueryRecord<F> {
let inv_func_queries = toplevel
.func_map
.iter()
.map(|(_, func)| {
if func.invertible {
.map(|(_, funcs)| {
if funcs.full_func.invertible {
Some(FxHashMap::default())
} else {
None
Expand Down Expand Up @@ -382,6 +382,7 @@ impl<F: PrimeField32, C1: Chipset<F>, C2: Chipset<F>> Toplevel<F, C1, C2> {
queries: &mut QueryRecord<F>,
dbg_func_idx: Option<usize>,
) -> Result<List<F>> {
assert!(!func.sharded);
let (out, depth) = func.execute(args, self, queries, dbg_func_idx)?;
let mut public_values = Vec::with_capacity(args.len() + out.len());
public_values.extend(args);
Expand Down
3 changes: 2 additions & 1 deletion src/lair/func_chip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ impl<'a, F, C1: Chipset<F>, C2: Chipset<F>> FuncChip<'a, F, C1, C2> {
toplevel
.func_map
.values()
.map(|func| FuncChip::from_func(func, toplevel))
.map(|funcs| FuncChip::from_func(&funcs.full_func, toplevel))
.collect()
}

Expand Down Expand Up @@ -92,6 +92,7 @@ impl<F> Func<F> {
&self,
toplevel: &Toplevel<F, C1, C2>,
) -> LayoutSizes {
assert!(self.sharded);
let input = self.input_size;
// last nonce, last count
let mut aux = 2;
Expand Down
163 changes: 150 additions & 13 deletions src/lair/toplevel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,32 @@
use either::Either;
use p3_field::Field;
use rustc_hash::FxHashMap;
use rustc_hash::{FxHashMap, FxHashSet};

use super::{bytecode::*, chipset::Chipset, expr::*, map::Map, FxIndexMap, List, Name};

/// This struct holds the complete Lair function and the sharded functions, containing
/// only the paths belonging to the same group
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct FuncStruct<F> {
pub(crate) full_func: Func<F>,
pub(crate) sharded_funcs: FxHashMap<ReturnGroup, Func<F>>,
}

impl<F: Field + Ord> FuncStruct<F> {
pub fn from_func(full_func: Func<F>, groups: &FxHashSet<ReturnGroup>) -> Self {
let sharded_funcs = full_func.shard(groups);
Self {
full_func,
sharded_funcs,
}
}
}

#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Toplevel<F, C1: Chipset<F>, C2: Chipset<F>> {
/// Lair functions reachable by the `Call` operator
pub(crate) func_map: FxIndexMap<Name, Func<F>>,
pub(crate) func_map: FxIndexMap<Name, FuncStruct<F>>,
/// Extern chips reachable by the `ExternCall` operator. The two different
/// chipset types can be used to encode native and custom chips.
pub(crate) chip_map: FxIndexMap<Name, Either<C1, C2>>,
Expand Down Expand Up @@ -42,8 +60,9 @@ impl<F: Field + Ord, C1: Chipset<F>, C2: Chipset<F>> Toplevel<F, C1, C2> {
.enumerate()
.map(|(i, func)| {
func.check(&info_map, &chip_map);
let cfunc = func.expand().compile(i, &info_map, &chip_map);
(func.name, cfunc)
let (cfunc, groups) = func.expand().compile(i, &info_map, &chip_map);
let func_struct = FuncStruct::from_func(cfunc, &groups);
(func.name, func_struct)
})
.collect();
Toplevel { func_map, chip_map }
Expand All @@ -58,17 +77,21 @@ impl<F: Field + Ord, C1: Chipset<F>, C2: Chipset<F>> Toplevel<F, C1, C2> {
impl<F, C1: Chipset<F>, C2: Chipset<F>> Toplevel<F, C1, C2> {
#[inline]
pub fn func_by_index(&self, i: usize) -> &Func<F> {
self.func_map
&self
.func_map
.get_index(i)
.unwrap_or_else(|| panic!("Func index {i} out of bounds"))
.1
.full_func
}

#[inline]
pub fn func_by_name(&self, name: &'static str) -> &Func<F> {
self.func_map
&self
.func_map
.get(&Name(name))
.unwrap_or_else(|| panic!("Func {name} not found"))
.full_func
}

#[inline]
Expand All @@ -85,7 +108,7 @@ impl<F, C1: Chipset<F>, C2: Chipset<F>> Toplevel<F, C1, C2> {
}
}

/// A map from `Var` its block identifier. Variables in this map are always bound
/// A map from `Var` to its block identifier. Variables in this map are always bound
type BindMap = FxHashMap<Var, usize>;

/// A map that tells whether a `Var`, from a certain block, has been used or not
Expand Down Expand Up @@ -173,6 +196,7 @@ struct LinkCtx<'a, C1, C2> {
var_index: usize,
return_ident: usize,
return_idents: Vec<usize>,
return_groups: FxHashSet<ReturnGroup>,
link_map: LinkMap,
info_map: &'a FxIndexMap<Name, FuncInfo>,
chip_map: &'a FxIndexMap<Name, Either<C1, C2>>,
Expand Down Expand Up @@ -258,28 +282,31 @@ impl<F: Field + Ord> FuncE<F> {
func_index: usize,
info_map: &FxIndexMap<Name, FuncInfo>,
chip_map: &FxIndexMap<Name, Either<C1, C2>>,
) -> Func<F> {
let ctx = &mut LinkCtx {
) -> (Func<F>, FxHashSet<ReturnGroup>) {
let mut ctx = LinkCtx {
var_index: 0,
return_ident: 0,
return_idents: vec![],
return_groups: FxHashSet::default(),
link_map: FxHashMap::default(),
info_map,
chip_map,
};
self.input_params.iter().for_each(|var| {
link_new(var, ctx);
link_new(var, &mut ctx);
});
let body = self.body.compile(ctx);
Func {
let body = self.body.compile(&mut ctx);
let func = Func {
name: self.name,
invertible: self.invertible,
partial: self.partial,
index: func_index,
body,
input_size: self.input_params.total_size(),
output_size: self.output_size,
}
sharded: false,
};
(func, ctx.return_groups)
}
}

Expand Down Expand Up @@ -536,6 +563,7 @@ impl<F: Field + Ord> CtrlE<F> {
let ctrl = Ctrl::Return(ctx.return_ident, return_vec, *group);
ctx.return_idents.push(ctx.return_ident);
ctx.return_ident += 1;
ctx.return_groups.insert(*group);
ctrl
}
CtrlE::Choose(v, cases) => {
Expand Down Expand Up @@ -883,3 +911,112 @@ impl<F: Field + Ord> OpE<F> {
}
}
}

impl<F: Field + Ord> Func<F> {
fn shard(&self, groups: &FxHashSet<ReturnGroup>) -> FxHashMap<ReturnGroup, Func<F>> {
assert!(!self.sharded);
let mut map = FxHashMap::default();
for group in groups.iter() {
let body = self
.body
.shard(*group, &mut 0)
.expect("Group {group} does not exist");
let func = Func {
body,
sharded: true,
..*self
};
map.insert(*group, func);
}
map
}
}

impl<F: Field + Ord> Block<F> {
fn shard(&self, group: ReturnGroup, return_ident: &mut usize) -> Option<Self> {
let (ctrl, return_idents) = self.ctrl.shard(group, return_ident)?;
let block = Block {
ctrl,
ops: self.ops.clone(),
return_idents: return_idents.into(),
};
Some(block)
}
}

impl<F: Field + Ord> Ctrl<F> {
fn shard(&self, group: ReturnGroup, return_ident: &mut usize) -> Option<(Self, Vec<usize>)> {
match self {
Ctrl::Return(_, out, return_group) => {
if group != *return_group {
return None;
}
let ctrl = Ctrl::Return(*return_ident, out.clone(), group);
let return_idents = vec![*return_ident];
*return_ident += 1;
Some((ctrl, return_idents))
}
Ctrl::Choose(var, cases, branches) => {
let mut return_idents = vec![];
let mut map = FxHashMap::default();
let mut idx = 0;
let branches: Vec<_> = branches
.iter()
.enumerate()
.filter_map(|(i, branch)| {
let branch = branch.shard(group, return_ident)?;
map.insert(i, idx);
idx += 1;
return_idents.extend_from_slice(&branch.return_idents);
Some(branch)
})
.collect();
let cases_branches = Map::from_vec(
cases
.branches
.iter()
.filter_map(|(f, i)| {
let idx = map.get(i)?;
Some((*f, *idx))
})
.collect(),
);
let default: Option<_> = cases.default.as_ref().and_then(|i| {
let idx = map.get(i)?;
Some((*idx).into())
});
let cases = Cases {
branches: cases_branches,
default,
};
let ctrl = Ctrl::Choose(*var, cases, branches.into());
Some((ctrl, return_idents))
}
Ctrl::ChooseMany(vars, cases) => {
let mut return_idents = vec![];
let cases_branches = Map::from_vec(
cases
.branches
.iter()
.filter_map(|(fs, branch)| {
let branch = branch.shard(group, return_ident)?;
return_idents.extend(&branch.return_idents);
Some((fs.clone(), branch))
})
.collect(),
);
let default: Option<_> = cases.default.as_ref().and_then(|branch| {
let branch = branch.shard(group, return_ident)?;
return_idents.extend(&branch.return_idents);
Some(branch.into())
});
let cases = Cases {
branches: cases_branches,
default,
};
let ctrl = Ctrl::ChooseMany(vars.clone(), cases);
Some((ctrl, return_idents))
}
}
}
}
1 change: 1 addition & 0 deletions src/lair/trace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ impl<'a, T> ColumnMutSlice<'a, T> {
impl<F: PrimeField32, C1: Chipset<F>, C2: Chipset<F>> FuncChip<'_, F, C1, C2> {
/// Per-row parallel trace generation
pub fn generate_trace(&self, shard: &Shard<'_, F>) -> RowMajorMatrix<F> {
assert!(self.func.sharded);
let func_queries = &shard.queries().func_queries()[self.func.index];
let range = shard.get_func_range(self.func.index);
let width = self.width();
Expand Down

0 comments on commit 0e313cf

Please sign in to comment.