Skip to content

Commit

Permalink
Cleanup predicate simplifier code (#369)
Browse files Browse the repository at this point in the history
We currently have two code paths for predicate simplification. This is
messy and confusing. This PR refactors this into a single code-path
parametrized on whether we want to recurse into equation application or
not.
  • Loading branch information
goodlyrottenapple authored Nov 16, 2023
1 parent 8ecb99b commit ac35589
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 124 deletions.
6 changes: 3 additions & 3 deletions library/Booster/JsonRpc.hs
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ respond stateVar =
[] -> term
ps -> KoreJson.KJAnd tSort $ term : ps
pure $ Right (addHeader result, patternTraces)
(Left ApplyEquations.SideConditionsFalse{}, patternTraces, _) -> do
(Left ApplyEquations.SideConditionFalse{}, patternTraces, _) -> do
let tSort = fromMaybe (error "unknown sort") $ sortOfJson req.state.term
pure $ Right (addHeader $ KoreJson.KJBottom tSort, patternTraces)
(Left (ApplyEquations.EquationLoop terms), _traces, _) ->
Expand Down Expand Up @@ -441,7 +441,7 @@ mkLogEquationTrace
, origin
, result = Failure{reason = "Indeterminate side-condition", _ruleId}
}
ApplyEquations.ConditionFalse
ApplyEquations.ConditionFalse{}
| logFailedSimplifications ->
Just $
Simplification
Expand Down Expand Up @@ -573,7 +573,7 @@ mkLogRewriteTrace
, origin = Booster
, result = Failure{reason = "Internal error: " <> err, _ruleId = Nothing}
}
ApplyEquations.SideConditionsFalse _predicates ->
ApplyEquations.SideConditionFalse _predicate ->
Simplification
{ originalTerm = Nothing
, originalTermIndex = Nothing
Expand Down
87 changes: 40 additions & 47 deletions library/Booster/Pattern/ApplyEquations.hs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ import Control.Monad.Logger.CallStack (
)
import Control.Monad.Trans.Class
import Control.Monad.Trans.Except
import Control.Monad.Trans.Maybe
import Control.Monad.Trans.Reader (ReaderT (..), ask)
import Control.Monad.Trans.State
import Data.Bifunctor (second)
Expand Down Expand Up @@ -69,7 +68,7 @@ data EquationFailure
= IndexIsNone Term
| TooManyIterations Int Term Term
| EquationLoop [Term]
| SideConditionsFalse [Predicate]
| SideConditionFalse Predicate
| InternalError Text
deriving stock (Eq, Show)

Expand Down Expand Up @@ -128,21 +127,22 @@ instance Pretty EquationTrace where
]
++ map pretty cs
++ ["using " <> locationInfo]
ConditionFalse ->
ConditionFalse p ->
vsep
[ "Simplifying term"
, prettyTerm
, "failed with false condition"
, pretty p
, "using " <> locationInfo
]
EnsuresFalse ps ->
vsep $
EnsuresFalse p ->
vsep
[ "Simplifying term"
, prettyTerm
, "using " <> locationInfo
, "resulted in ensuring false conditions"
, "resulted in ensuring false condition"
, pretty p
]
<> map pretty ps
MatchConstraintViolated constrained varName ->
vsep
[ "Concreteness constraint violated: "
Expand Down Expand Up @@ -305,7 +305,7 @@ evaluatePattern' Pattern{term, constraints} = do
allPs <- predicates <$> getState
let otherPs = Set.delete p allPs
EquationT $ lift $ lift $ modify $ \s -> s{predicates = otherPs}
newP <- simplifyConstraint' p
newP <- simplifyConstraint' True p
pushConstraints $ Set.singleton newP

----------------------------------------
Expand Down Expand Up @@ -446,8 +446,8 @@ data ApplyEquationResult
| FailedMatch MatchFailReason
| IndeterminateMatch
| IndeterminateCondition [Predicate]
| ConditionFalse
| EnsuresFalse [Predicate]
| ConditionFalse Predicate
| EnsuresFalse Predicate
| RuleNotPreservingDefinedness
| MatchConstraintViolated Constrained VarName
deriving stock (Eq, Show)
Expand All @@ -468,8 +468,8 @@ handleFunctionEquation success continue abort = \case
FailedMatch _ -> continue
IndeterminateMatch -> abort
IndeterminateCondition{} -> abort
ConditionFalse -> continue
EnsuresFalse ps -> throw $ SideConditionsFalse ps
ConditionFalse _ -> continue
EnsuresFalse p -> throw $ SideConditionFalse p
RuleNotPreservingDefinedness -> abort
MatchConstraintViolated{} -> continue

Expand All @@ -479,8 +479,8 @@ handleSimplificationEquation success continue _abort = \case
FailedMatch _ -> continue
IndeterminateMatch -> continue
IndeterminateCondition{} -> continue
ConditionFalse -> continue
EnsuresFalse ps -> throw $ SideConditionsFalse ps
ConditionFalse _ -> continue
EnsuresFalse p -> throw $ SideConditionFalse p
RuleNotPreservingDefinedness -> continue
MatchConstraintViolated{} -> continue

Expand Down Expand Up @@ -580,39 +580,32 @@ applyEquation term rule = fmap (either id Success) $ runExceptT $ do
concatMap
(splitBoolPredicates . substituteInPredicate subst)
rule.requires
unclearConditions' <- runMaybeT $ catMaybes <$> mapM checkConstraint required
unclearConditions' <- catMaybes <$> mapM (checkConstraint ConditionFalse) required

case unclearConditions' of
Nothing -> throwE ConditionFalse
Just unclearConditions ->
if not $ null unclearConditions
then throwE $ IndeterminateCondition unclearConditions
else do
-- check ensured conditions, filter any
-- true ones, prune if any is false
let ensured =
concatMap
(splitBoolPredicates . substituteInPredicate subst)
(Set.toList rule.ensures)
mbEnsuredConditions <-
runMaybeT $ catMaybes <$> mapM checkConstraint ensured
case mbEnsuredConditions of
-- throws if an ensured condition found to be false
Nothing -> throwE $ EnsuresFalse ensured
-- pushes new ensured conditions and return result
Just conditions ->
lift $ pushConstraints $ Set.fromList conditions
pure $ substituteInTerm subst rule.rhs
[] -> do
-- check ensured conditions, filter any
-- true ones, prune if any is false
let ensured =
concatMap
(splitBoolPredicates . substituteInPredicate subst)
(Set.toList rule.ensures)
ensuredConditions <-
-- throws if an ensured condition found to be false
catMaybes <$> mapM (checkConstraint EnsuresFalse) ensured
lift $ pushConstraints $ Set.fromList ensuredConditions
pure $ substituteInTerm subst rule.rhs
unclearConditions -> throwE $ IndeterminateCondition unclearConditions
where
-- evaluate/simplify a predicate, cut the operation short when it
-- is Bottom.
checkConstraint ::
(Predicate -> ApplyEquationResult) ->
Predicate ->
MaybeT (ExceptT ApplyEquationResult (EquationT io)) (Maybe Predicate)
checkConstraint p = do
mApi <- (.llvmApi) <$> lift (lift getConfig)
case simplifyPredicate mApi p of
Bottom -> fail "side condition was false"
ExceptT ApplyEquationResult (EquationT io) (Maybe Predicate)
checkConstraint whenBottom p =
lift (simplifyConstraint' False p) >>= \case
Bottom -> throwE $ whenBottom p
Top -> pure Nothing
_other -> pure $ Just p

Expand Down Expand Up @@ -688,15 +681,15 @@ simplifyConstraint ::
Predicate ->
io (Either EquationFailure Predicate, [EquationTrace], SimplifierCache)
simplifyConstraint doTracing def mbApi cache p =
runEquationT doTracing def mbApi cache $ simplifyConstraint' p
runEquationT doTracing def mbApi cache $ simplifyConstraint' True p

-- version for internal nested evaluation
simplifyConstraint' :: MonadLoggerIO io => Predicate -> EquationT io Predicate
simplifyConstraint' :: MonadLoggerIO io => Bool -> Predicate -> EquationT io Predicate
-- We are assuming all predicates are of the form 'true \equals P' and
-- evaluating them using simplifyBool if they are concrete.
-- Non-concrete \equals predicates are simplified using evaluateTerm.
simplifyConstraint' = \case
EqualsTerm TrueBool t@(Term attributes _)
simplifyConstraint' recurse = \case
p@(EqualsTerm TrueBool t@(Term attributes _))
| isConcrete t && attributes.canBeEvaluated -> do
mbApi <- (.llvmApi) <$> getConfig
case mbApi of
Expand All @@ -705,12 +698,12 @@ simplifyConstraint' = \case
then pure Top
else pure Bottom
Nothing ->
evalBool t >>= prune
if recurse then evalBool t >>= prune else pure p
| otherwise ->
evalBool t >>= prune
if recurse then evalBool t >>= prune else pure p
EqualsTerm t TrueBool ->
-- normalise to 'true' in first argument (like 'kore-rpc')
simplifyConstraint' (EqualsTerm TrueBool t)
simplifyConstraint' recurse (EqualsTerm TrueBool t)
other ->
pure other -- should not occur, predicates should be 'true \equals _'
where
Expand Down
4 changes: 2 additions & 2 deletions library/Booster/Pattern/Rewrite.hs
Original file line number Diff line number Diff line change
Expand Up @@ -605,8 +605,8 @@ performRewrite doTracing def mLlvmLibrary mbMaxDepth cutLabels terminalLabels pa
Right newPattern -> do
rewriteTrace $ RewriteSimplified traces Nothing
pure $ Just newPattern
Left r@(SideConditionsFalse _ps) -> do
logSimplify "Side conditions were found to be false, pruning"
Left r@(SideConditionFalse _p) -> do
logSimplify "A side condition was found to be false, pruning"
rewriteTrace $ RewriteSimplified traces (Just r)
pure Nothing
-- NB any errors here might be caused by simplifying one
Expand Down
73 changes: 1 addition & 72 deletions library/Booster/Pattern/Simplify.hs
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,11 @@ Copyright : (c) Runtime Verification, 2022
License : BSD-3-Clause
-}
module Booster.Pattern.Simplify (
simplifyPredicate,
splitBoolPredicates,
simplifyConcrete,
) where

import Booster.Definition.Base
import Booster.LLVM (simplifyBool, simplifyTerm)
import Booster.LLVM.Internal qualified as LLVM
import Booster.Pattern.Base
import Booster.Pattern.Util (isConcrete, sortOfTerm)
import Data.Bifunctor (bimap)
import Booster.Pattern.Util (isConcrete)

{- | We want to break apart predicates of type `X #Equals Y1 andBool ... Yn` into
`X #Equals Y1, ..., X #Equals Yn` in the case when some of the `Y`s are abstract
Expand All @@ -25,68 +19,3 @@ splitBoolPredicates = \case
EqualsTerm (AndBool ls) r -> concatMap (splitBoolPredicates . flip EqualsTerm r) ls
EqualsTerm l (AndBool rs) -> concatMap (splitBoolPredicates . EqualsTerm l) rs
other -> [other]

simplifyPredicate :: Maybe LLVM.API -> Predicate -> Predicate
simplifyPredicate mApi = \case
AndPredicate l r -> case (simplifyPredicate mApi l, simplifyPredicate mApi r) of
(Bottom, _) -> Bottom
(_, Bottom) -> Bottom
(Top, r') -> r'
(l', Top) -> l'
(l', r') -> AndPredicate l' r'
Bottom -> Bottom
p@(Ceil _) -> p
p@(EqualsTerm l r) ->
case (mApi, sortOfTerm l == SortBool && isConcrete l && isConcrete r) of
(Just api, True) ->
if simplifyBool api l == simplifyBool api r
then Top
else Bottom
_ -> p
EqualsPredicate l r -> EqualsPredicate (simplifyPredicate mApi l) (simplifyPredicate mApi r)
p@(Exists _ _) -> p
p@(Forall _ _) -> p
Iff l r -> Iff (simplifyPredicate mApi l) (simplifyPredicate mApi r)
Implies l r -> Implies (simplifyPredicate mApi l) (simplifyPredicate mApi r)
p@(In _ _) -> p
Not p -> case simplifyPredicate mApi p of
Top -> Bottom
Bottom -> Top
p' -> p'
Or l r -> Or (simplifyPredicate mApi l) (simplifyPredicate mApi r)
Top -> Top

{- | traverses a term top-down, using a given LLVM dy.lib to simplify
the concrete parts (leaving variables alone)
-}
simplifyConcrete :: Maybe LLVM.API -> KoreDefinition -> Term -> Term
simplifyConcrete Nothing _ trm = trm
simplifyConcrete (Just mApi) def trm = recurse trm
where
recurse :: Term -> Term
-- recursion scheme for this?
-- cata $ \case does not work here, would need helpers for TermF not Term
-- t | isConcreteF t -> simplifyTerm dl def t (sortOfTerm t)
-- other -> embed other\
recurse t@(Term attributes _)
| attributes.isEvaluated =
t
| isConcrete t && attributes.canBeEvaluated =
simplifyTerm mApi def t (sortOfTerm t)
| otherwise =
case t of
var@Var{} ->
var -- nothing to do. Should have isEvaluated set
dv@DomainValue{} ->
dv -- nothing to do. Should have isEvaluated set
AndTerm t1 t2 ->
AndTerm (recurse t1) (recurse t2)
SymbolApplication sym sorts args ->
SymbolApplication sym sorts (map recurse args)
Injection sources target sub ->
Injection sources target $ recurse sub
KMap mdef keyVals rest -> KMap mdef (map (bimap recurse recurse) keyVals) (recurse <$> rest)
KList ldef heads rest ->
KList ldef (map recurse heads) (fmap (bimap recurse (map recurse)) rest)
KSet sdef heads rest ->
KSet sdef (map recurse heads) (fmap recurse rest)

0 comments on commit ac35589

Please sign in to comment.