Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement field integer division and remainder operations #1349

Open
wants to merge 6 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelogs/unreleased/1349-dark64
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Implement field division and remainder operations using euclidean division
2 changes: 1 addition & 1 deletion zokrates_analysis/src/expression_validator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ExpressionValidator {
| FieldElementExpression::Xor(_)
| FieldElementExpression::LeftShift(_)
| FieldElementExpression::RightShift(_) => Err(Error(format!(
"Found non-constant bitwise operation in field element expression `{}`",
"Field element expression `{}` must be a constant expression",
e
))),
FieldElementExpression::Pow(e) => {
Expand Down
6 changes: 6 additions & 0 deletions zokrates_analysis/src/flatten_complex_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1101,6 +1101,12 @@ fn fold_field_expression<'ast, T: Field>(
typed::FieldElementExpression::Div(e) => {
zir::FieldElementExpression::Div(f.fold_binary_expression(statements_buffer, e))
}
typed::FieldElementExpression::IDiv(e) => {
zir::FieldElementExpression::IDiv(f.fold_binary_expression(statements_buffer, e))
}
typed::FieldElementExpression::Rem(e) => {
zir::FieldElementExpression::Rem(f.fold_binary_expression(statements_buffer, e))
}
typed::FieldElementExpression::Pow(e) => {
zir::FieldElementExpression::Pow(f.fold_binary_expression(statements_buffer, e))
}
Expand Down
57 changes: 57 additions & 0 deletions zokrates_analysis/src/panic_extractor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,44 @@ impl<'ast, T: Field> Folder<'ast, T> for PanicExtractor<'ast, T> {
);
FieldElementExpression::div(n, d)
}
FieldElementExpression::IDiv(e) => {
let n = self.fold_field_expression(*e.left);
let d = self.fold_field_expression(*e.right);
self.panic_buffer.push(
ZirStatement::assertion(
BooleanExpression::not(
BooleanExpression::field_eq(
d.clone().span(span),
FieldElementExpression::value(T::zero()).span(span),
)
.span(span),
)
.span(span),
RuntimeError::DivisionByZero,
)
.span(span),
);
FieldElementExpression::idiv(n, d)
}
FieldElementExpression::Rem(e) => {
let n = self.fold_field_expression(*e.left);
let d = self.fold_field_expression(*e.right);
self.panic_buffer.push(
ZirStatement::assertion(
BooleanExpression::not(
BooleanExpression::field_eq(
d.clone().span(span),
FieldElementExpression::value(T::zero()).span(span),
)
.span(span),
)
.span(span),
RuntimeError::DivisionByZero,
)
.span(span),
);
FieldElementExpression::rem(n, d)
}
e => fold_field_expression_cases(self, e),
}
}
Expand Down Expand Up @@ -150,6 +188,25 @@ impl<'ast, T: Field> Folder<'ast, T> for PanicExtractor<'ast, T> {
);
UExpression::div(n, d).into_inner()
}
UExpressionInner::Rem(e) => {
let n = self.fold_uint_expression(*e.left);
let d = self.fold_uint_expression(*e.right);
self.panic_buffer.push(
ZirStatement::assertion(
BooleanExpression::not(
BooleanExpression::uint_eq(
d.clone().span(span),
UExpression::value(0).annotate(b).span(span),
)
.span(span),
)
.span(span),
RuntimeError::DivisionByZero,
)
.span(span),
);
UExpression::rem(n, d).into_inner()
}
e => fold_uint_expression_cases(self, b, e),
}
}
Expand Down
97 changes: 91 additions & 6 deletions zokrates_analysis/src/propagation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,12 @@ impl fmt::Display for Error {
Error::Type(s) => write!(f, "{}", s),
Error::AssertionFailed(err) => write!(f, "Assertion failed ({})", err),
Error::InvalidValue(s) => write!(f, "{}", s),
Error::OutOfBounds(index, size) => write!(
f,
"Out of bounds index ({} >= {}) found during static analysis",
index, size
),
Error::OutOfBounds(index, size) => {
write!(f, "Out of bounds index ({} >= {})", index, size)
}
Error::VariableLength(message) => write!(f, "{}", message),
Error::DivisionByZero => {
write!(f, "Division by zero detected during static analysis",)
write!(f, "Division by zero detected",)
}
}
}
Expand Down Expand Up @@ -856,6 +854,22 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> {
Ok(UExpression::and(e1.annotate(bitwidth), e2.annotate(bitwidth)).into_inner())
}
},
UExpressionInner::Or(e) => match (
self.fold_uint_expression(*e.left)?.into_inner(),
self.fold_uint_expression(*e.right)?.into_inner(),
) {
(UExpressionInner::Value(v1), UExpressionInner::Value(v2)) => {
Ok(UExpression::value(v1.value | v2.value))
}
(UExpressionInner::Value(v), e) | (e, UExpressionInner::Value(v))
if v.value == 0 =>
{
Ok(e)
}
(e1, e2) => {
Ok(UExpression::or(e1.annotate(bitwidth), e2.annotate(bitwidth)).into_inner())
}
},
UExpressionInner::Not(e) => {
let e = self.fold_uint_expression(*e.inner)?.into_inner();
match e {
Expand Down Expand Up @@ -939,6 +953,35 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> {
(e1, e2) => Ok(e1 / e2),
}
}
FieldElementExpression::IDiv(e) => {
let left = self.fold_field_expression(*e.left)?;
let right = self.fold_field_expression(*e.right)?;

Ok(match (left, right) {
(FieldElementExpression::Value(n1), FieldElementExpression::Value(n2)) => {
FieldElementExpression::value(
T::try_from(n1.value.to_biguint().div(n2.value.to_biguint())).unwrap(),
)
}
(e1, e2) => FieldElementExpression::idiv(e1, e2),
})
}
FieldElementExpression::Rem(e) => {
let left = self.fold_field_expression(*e.left)?;
let right = self.fold_field_expression(*e.right)?;

Ok(match (left, right) {
(_, FieldElementExpression::Value(n)) if n.value == T::from(1) => {
FieldElementExpression::value(T::zero())
}
(FieldElementExpression::Value(n1), FieldElementExpression::Value(n2)) => {
FieldElementExpression::value(
T::try_from(n1.value.to_biguint().rem(n2.value.to_biguint())).unwrap(),
)
}
(e1, e2) => e1 % e2,
})
}
FieldElementExpression::Neg(e) => match self.fold_field_expression(*e.inner)? {
FieldElementExpression::Value(n) => {
Ok(FieldElementExpression::value(T::zero() - n.value))
Expand Down Expand Up @@ -1606,6 +1649,48 @@ mod tests {
);
}

#[test]
fn idiv() {
let e = FieldElementExpression::idiv(
FieldElementExpression::value(Bn128Field::from(7)),
FieldElementExpression::value(Bn128Field::from(2)),
);

assert_eq!(
Propagator::default().fold_field_expression(e),
Ok(FieldElementExpression::value(Bn128Field::from(3)))
);
}

#[test]
fn rem() {
let mut propagator = Propagator::default();

assert_eq!(
propagator.fold_field_expression(FieldElementExpression::rem(
FieldElementExpression::value(Bn128Field::from(5)),
FieldElementExpression::value(Bn128Field::from(2)),
)),
Ok(FieldElementExpression::value(Bn128Field::from(1)))
);

assert_eq!(
propagator.fold_field_expression(FieldElementExpression::rem(
FieldElementExpression::value(Bn128Field::from(2)),
FieldElementExpression::value(Bn128Field::from(5)),
)),
Ok(FieldElementExpression::value(Bn128Field::from(2)))
);

assert_eq!(
propagator.fold_field_expression(FieldElementExpression::rem(
FieldElementExpression::identifier("a".into()),
FieldElementExpression::value(Bn128Field::from(1)),
)),
Ok(FieldElementExpression::value(Bn128Field::from(0)))
);
}

#[test]
fn pow() {
let e = FieldElementExpression::pow(
Expand Down
112 changes: 106 additions & 6 deletions zokrates_analysis/src/zir_propagation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,11 @@ pub enum Error {
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Error::OutOfBounds(index, size) => write!(
f,
"Out of bounds index ({} >= {}) found in zir during static analysis",
index, size
),
Error::OutOfBounds(index, size) => {
write!(f, "Out of bounds index ({} >= {})", index, size)
}
Error::DivisionByZero => {
write!(f, "Division by zero detected in zir during static analysis",)
write!(f, "Division by zero detected",)
}
Error::AssertionFailed(err) => write!(f, "Assertion failed ({})", err),
}
Expand Down Expand Up @@ -343,6 +341,42 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> {
(e1, e2) => Ok(FieldElementExpression::div(e1, e2).span(e.span)),
}
}
FieldElementExpression::IDiv(e) => {
let left = self.fold_field_expression(*e.left)?;
let right = self.fold_field_expression(*e.right)?;

match (left, right) {
(_, FieldElementExpression::Value(n)) if n.value == T::from(0) => {
Err(Error::DivisionByZero)
}
(e, FieldElementExpression::Value(n)) if n.value == T::from(1) => Ok(e),
(FieldElementExpression::Value(n1), FieldElementExpression::Value(n2)) => {
Ok(FieldElementExpression::value(
T::try_from(n1.value.to_biguint().div(n2.value.to_biguint())).unwrap(),
))
}
(e1, e2) => Ok(FieldElementExpression::idiv(e1, e2).span(e.span)),
}
}
FieldElementExpression::Rem(e) => {
let left = self.fold_field_expression(*e.left)?;
let right = self.fold_field_expression(*e.right)?;

match (left, right) {
(_, FieldElementExpression::Value(n)) if n.value == T::from(0) => {
Err(Error::DivisionByZero)
}
(_, FieldElementExpression::Value(n)) if n.value == T::from(1) => {
Ok(FieldElementExpression::value(T::zero()))
}
(FieldElementExpression::Value(n1), FieldElementExpression::Value(n2)) => {
Ok(FieldElementExpression::value(
T::try_from(n1.value.to_biguint().rem(n2.value.to_biguint())).unwrap(),
))
}
(e1, e2) => Ok(FieldElementExpression::rem(e1, e2).span(e.span)),
}
}
FieldElementExpression::Pow(e) => {
let exponent = self.fold_uint_expression(*e.right)?;
match (self.fold_field_expression(*e.left)?, exponent.into_inner()) {
Expand Down Expand Up @@ -1099,6 +1133,72 @@ mod tests {
);
}

#[test]
fn idiv() {
let mut propagator = ZirPropagator::default();

assert_eq!(
propagator.fold_field_expression(FieldElementExpression::idiv(
FieldElementExpression::value(Bn128Field::from(7)),
FieldElementExpression::value(Bn128Field::from(2)),
)),
Ok(FieldElementExpression::value(Bn128Field::from(3)))
);

assert_eq!(
propagator.fold_field_expression(FieldElementExpression::idiv(
FieldElementExpression::identifier("a".into()),
FieldElementExpression::value(Bn128Field::from(1)),
)),
Ok(FieldElementExpression::identifier("a".into()))
);

assert_eq!(
propagator.fold_field_expression(FieldElementExpression::idiv(
FieldElementExpression::identifier("a".into()),
FieldElementExpression::value(Bn128Field::from(0)),
)),
Err(Error::DivisionByZero)
);
}

#[test]
fn rem() {
let mut propagator = ZirPropagator::default();

assert_eq!(
propagator.fold_field_expression(FieldElementExpression::rem(
FieldElementExpression::value(Bn128Field::from(5)),
FieldElementExpression::value(Bn128Field::from(2)),
)),
Ok(FieldElementExpression::value(Bn128Field::from(1)))
);

assert_eq!(
propagator.fold_field_expression(FieldElementExpression::rem(
FieldElementExpression::value(Bn128Field::from(2)),
FieldElementExpression::value(Bn128Field::from(5)),
)),
Ok(FieldElementExpression::value(Bn128Field::from(2)))
);

assert_eq!(
propagator.fold_field_expression(FieldElementExpression::rem(
FieldElementExpression::identifier("a".into()),
FieldElementExpression::value(Bn128Field::from(1)),
)),
Ok(FieldElementExpression::value(Bn128Field::from(0)))
);

assert_eq!(
propagator.fold_field_expression(FieldElementExpression::div(
FieldElementExpression::identifier("a".into()),
FieldElementExpression::value(Bn128Field::from(0)),
)),
Err(Error::DivisionByZero)
);
}

#[test]
fn pow() {
let mut propagator = ZirPropagator::<Bn128Field>::default();
Expand Down
7 changes: 7 additions & 0 deletions zokrates_ast/src/common/operators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,13 @@ impl OperatorStr for OpDiv {
const STR: &'static str = "/";
}

#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)]
pub struct OpIDiv;

impl OperatorStr for OpIDiv {
const STR: &'static str = "\\";
}

#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)]
pub struct OpRem;

Expand Down
6 changes: 5 additions & 1 deletion zokrates_ast/src/ir/check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crate::ir::Parameter;
use crate::ir::ProgIterator;
use crate::ir::Statement;
use crate::ir::Variable;
use crate::Solver;
use std::collections::HashSet;
use zokrates_field::Field;

Expand Down Expand Up @@ -46,7 +47,10 @@ impl<'ast, T: Field> Folder<'ast, T> for UnconstrainedVariableDetector {
&mut self,
d: DirectiveStatement<'ast, T>,
) -> Vec<Statement<'ast, T>> {
self.variables.extend(d.outputs.iter());
match d.solver {
Solver::Zir(_) => {} // we do not check variables introduced by assembly
_ => self.variables.extend(d.outputs.iter()), // this is not necessary, but we keep it as a sanity check
};
vec![Statement::Directive(d)]
}
}
Loading
Loading