Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

KCFG optimisation-on-creation #4710

Merged
merged 14 commits into from
Dec 13, 2024
26 changes: 21 additions & 5 deletions pyk/src/pyk/kcfg/kcfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
)
from ..kast.outer import KFlatModule
from ..prelude.kbool import andBool
from ..utils import ensure_dir_path, not_none
from ..utils import ensure_dir_path, not_none, single

if TYPE_CHECKING:
from collections.abc import Iterable, Mapping, MutableMapping
Expand Down Expand Up @@ -557,6 +557,7 @@ def path_length(_path: Iterable[KCFG.Successor]) -> int:
def extend(
self,
extend_result: KCFGExtendResult,
optimize_kcfg: bool,
node: KCFG.Node,
logs: dict[int, tuple[LogEntry, ...]],
PetarMax marked this conversation as resolved.
Show resolved Hide resolved
) -> None:
Expand All @@ -569,6 +570,20 @@ def log(message: str, *, warning: bool = False) -> None:
f'Extending current KCFG with the following: {message}{result_info_message}',
)

def optimize_step(
cterm: CTerm, depth: int, next_node_logs: tuple[LogEntry, ...], rule_labels: list[str]
) -> bool:
in_edges = self.edges(target_id=node.id)
if len(in_edges) == 1:
in_edge = single(in_edges)
PetarMax marked this conversation as resolved.
Show resolved Hide resolved
self.remove_edge(in_edge.source.id, node.id)
ehildenb marked this conversation as resolved.
Show resolved Hide resolved
self.let_node(node_id=node.id, cterm=cterm)
self.create_edge(in_edge.source.id, node.id, in_edge.depth + depth, list(in_edge.rules) + rule_labels)
logs[node.id] = logs[node.id] + next_node_logs
log(f'basic block at depth {depth}: update: {node.id}')
return True
return False

match extend_result:
case Vacuous():
self.add_vacuous(node.id)
Expand All @@ -584,10 +599,11 @@ def log(message: str, *, warning: bool = False) -> None:
log(f'abstraction node: {node.id}')

case Step(cterm, depth, next_node_logs, rule_labels, _):
next_node = self.create_node(cterm)
logs[next_node.id] = next_node_logs
self.create_edge(node.id, next_node.id, depth, rules=rule_labels)
log(f'basic block at depth {depth}: {node.id} --> {next_node.id}')
if not (optimize_kcfg and optimize_step(cterm, depth, next_node_logs, rule_labels)):
next_node = self.create_node(cterm)
logs[next_node.id] = next_node_logs
self.create_edge(node.id, next_node.id, depth, rules=rule_labels)
log(f'basic block at depth {depth}: {node.id} --> {next_node.id}')

case Branch(branches, _):
branch_node_ids = self.split_on_constraints(node.id, branches)
Expand Down
30 changes: 27 additions & 3 deletions pyk/src/pyk/proof/reachability.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
class APRProofResult:
node_id: int
prior_loops_cache_update: tuple[int, ...]
optimize_kcfg: bool


@dataclass
Expand Down Expand Up @@ -220,6 +221,7 @@ def commit(self, result: APRProofResult) -> None:
assert result.cached_node_id in self._next_steps
self.kcfg.extend(
extend_result=self._next_steps.pop(result.cached_node_id),
optimize_kcfg=result.optimize_kcfg,
node=self.kcfg.node(result.node_id),
logs=self.logs,
)
Expand All @@ -230,6 +232,7 @@ def commit(self, result: APRProofResult) -> None:
self._next_steps[result.node_id] = result.extension_to_cache
self.kcfg.extend(
extend_result=result.extension_to_apply,
optimize_kcfg=result.optimize_kcfg,
node=self.kcfg.node(result.node_id),
logs=self.logs,
)
Expand Down Expand Up @@ -715,6 +718,7 @@ class APRProver(Prover[APRProof, APRProofStep, APRProofResult]):
assume_defined: bool
kcfg_explore: KCFGExplore
extra_module: KFlatModule | None
optimize_kcfg: bool

def __init__(
self,
Expand All @@ -727,6 +731,7 @@ def __init__(
direct_subproof_rules: bool = False,
assume_defined: bool = False,
extra_module: KFlatModule | None = None,
optimize_kcfg: bool = False,
) -> None:

self.kcfg_explore = kcfg_explore
Expand All @@ -739,6 +744,7 @@ def __init__(
self.direct_subproof_rules = direct_subproof_rules
self.assume_defined = assume_defined
self.extra_module = extra_module
self.optimize_kcfg = optimize_kcfg

def close(self) -> None:
self.kcfg_explore.cterm_symbolic._kore_client.close()
Expand Down Expand Up @@ -808,14 +814,24 @@ def step_proof(self, step: APRProofStep) -> list[APRProofResult]:
_LOGGER.info(f'Prior loop heads for node {step.node.id}: {(step.node.id, prior_loops)}')
if len(prior_loops) > step.bmc_depth:
_LOGGER.warning(f'Bounded node {step.proof_id}: {step.node.id} at bmc depth {step.bmc_depth}')
return [APRProofBoundedResult(node_id=step.node.id, prior_loops_cache_update=prior_loops)]
return [
APRProofBoundedResult(
node_id=step.node.id, optimize_kcfg=self.optimize_kcfg, prior_loops_cache_update=prior_loops
)
]

# Check if the current node and target are terminal
is_terminal = self.kcfg_explore.kcfg_semantics.is_terminal(step.node.cterm)
target_is_terminal = self.kcfg_explore.kcfg_semantics.is_terminal(step.target.cterm)

terminal_result: list[APRProofResult] = (
[APRProofTerminalResult(node_id=step.node.id, prior_loops_cache_update=prior_loops)] if is_terminal else []
[
APRProofTerminalResult(
node_id=step.node.id, optimize_kcfg=self.optimize_kcfg, prior_loops_cache_update=prior_loops
)
]
if is_terminal
else []
)

# Subsumption is checked if and only if the target node
Expand All @@ -826,7 +842,12 @@ def step_proof(self, step: APRProofStep) -> list[APRProofResult]:
# Information about the subsumed node being terminal must be returned
# so that the set of terminal nodes is correctly updated
return terminal_result + [
APRProofSubsumeResult(csubst=csubst, node_id=step.node.id, prior_loops_cache_update=prior_loops)
APRProofSubsumeResult(
csubst=csubst,
optimize_kcfg=self.optimize_kcfg,
node_id=step.node.id,
prior_loops_cache_update=prior_loops,
)
]

if is_terminal:
Expand All @@ -849,6 +870,7 @@ def step_proof(self, step: APRProofStep) -> list[APRProofResult]:
APRProofUseCacheResult(
node_id=step.node.id,
cached_node_id=step.use_cache,
optimize_kcfg=self.optimize_kcfg,
prior_loops_cache_update=prior_loops,
)
]
Expand Down Expand Up @@ -876,6 +898,7 @@ def step_proof(self, step: APRProofStep) -> list[APRProofResult]:
extension_to_apply=extend_results[0],
extension_to_cache=extend_results[1],
prior_loops_cache_update=prior_loops,
optimize_kcfg=self.optimize_kcfg,
)
]

Expand All @@ -885,6 +908,7 @@ def step_proof(self, step: APRProofStep) -> list[APRProofResult]:
node_id=step.node.id,
extension_to_apply=extend_results[0],
prior_loops_cache_update=prior_loops,
optimize_kcfg=self.optimize_kcfg,
)
]

Expand Down
78 changes: 78 additions & 0 deletions pyk/src/tests/integration/proof/test_imp.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,35 @@ def same_loop(self, c1: CTerm, c2: CTerm) -> bool:
),
)

APR_PROVE_WITH_KCFG_OPTIMS_TEST_DATA: Iterable[
tuple[str, Path, str, str, int | None, int | None, Iterable[str], bool, ProofStatus, int]
] = (
(
'imp-simple-sum-100',
K_FILES / 'imp-simple-spec.k',
'IMP-SIMPLE-SPEC',
'sum-100',
None,
None,
[],
True,
ProofStatus.PASSED,
4,
),
(
'imp-simple-long-branches',
K_FILES / 'imp-simple-spec.k',
'IMP-SIMPLE-SPEC',
'long-branches',
None,
None,
[],
True,
ProofStatus.PASSED,
7,
),
)

PATH_CONSTRAINTS_TEST_DATA: Iterable[
tuple[str, Path, str, str, int | None, int | None, Iterable[str], Iterable[str], str]
] = (
Expand Down Expand Up @@ -918,6 +947,55 @@ def test_all_path_reachability_prove(
assert proof.status == proof_status
assert leaf_number(proof) == expected_leaf_number

@pytest.mark.parametrize(
'test_id,spec_file,spec_module,claim_id,max_iterations,max_depth,cut_rules,admit_deps,proof_status,expected_max_node_number',
APR_PROVE_WITH_KCFG_OPTIMS_TEST_DATA,
ids=[test_id for test_id, *_ in APR_PROVE_WITH_KCFG_OPTIMS_TEST_DATA],
)
def test_all_path_reachability_prove_with_kcfg_optims(
self,
kprove: KProve,
kcfg_explore: KCFGExplore,
test_id: str,
spec_file: str,
spec_module: str,
claim_id: str,
max_iterations: int | None,
max_depth: int | None,
cut_rules: Iterable[str],
admit_deps: bool,
proof_status: ProofStatus,
expected_max_node_number: int,
tmp_path_factory: TempPathFactory,
) -> None:
proof_dir = tmp_path_factory.mktemp(f'apr_tmp_proofs-{test_id}')
spec_modules = kprove.parse_modules(Path(spec_file), module_name=spec_module)
spec_label = f'{spec_module}.{claim_id}'
proofs = APRProof.from_spec_modules(
kprove.definition,
spec_modules,
spec_labels=[spec_label],
logs={},
proof_dir=proof_dir,
)
proof = single([p for p in proofs if p.id == spec_label])
if admit_deps:
for subproof in proof.subproofs:
subproof.admit()
subproof.write_proof_data()

prover = APRProver(
kcfg_explore=kcfg_explore, execute_depth=max_depth, cut_point_rules=cut_rules, optimize_kcfg=True
)
prover.advance_proof(proof, max_iterations=max_iterations)

kcfg_show = KCFGShow(kprove, node_printer=APRProofNodePrinter(proof, kprove, full_printer=True))
cfg_lines = kcfg_show.show(proof.kcfg)
_LOGGER.info('\n'.join(cfg_lines))

assert proof.status == proof_status
assert proof.kcfg._node_id == expected_max_node_number

def test_terminal_node_subsumption(
self,
kprove: KProve,
Expand Down
Loading