diff --git a/pyk/src/pyk/k2lean4/k2lean4.py b/pyk/src/pyk/k2lean4/k2lean4.py index 0b0f2583c5..df94fe0b2f 100644 --- a/pyk/src/pyk/k2lean4/k2lean4.py +++ b/pyk/src/pyk/k2lean4/k2lean4.py @@ -2,18 +2,19 @@ import re from dataclasses import dataclass +from graphlib import TopologicalSorter from typing import TYPE_CHECKING from ..konvert import unmunge from ..kore.internal import CollectionKind from ..kore.syntax import SortApp -from ..utils import check_type -from .model import Abbrev, Ctor, ExplBinder, Inductive, Module, Signature, Term +from .model import Abbrev, Ctor, ExplBinder, Inductive, Module, Signature, StructCtor, Structure, Term if TYPE_CHECKING: from typing import Final from ..kore.internal import KoreDefn + from ..kore.syntax import SymbolDecl from .model import Command @@ -29,13 +30,15 @@ class K2Lean4: def sort_module(self) -> Module: commands = [] commands += self._inductives() + commands += self._cells() commands += self._collections() return Module(commands=commands) def _inductives(self) -> list[Command]: def is_inductive(sort: str) -> bool: decl = self.defn.sorts[sort] - return not decl.hooked and 'hasDomainValues' not in decl.attrs_by_key + attrs = decl.attrs_by_key + return not decl.hooked and 'hasDomainValues' not in attrs and not decl.name.endswith('Cell') sorts = sorted(sort for sort in self.defn.sorts if is_inductive(sort)) return [self._inductive(sort) for sort in sorts] @@ -52,13 +55,50 @@ def _inj_ctor(self, sort: str, subsort: str) -> Ctor: return Ctor(f'inj_{subsort}', Signature((ExplBinder(('x',), Term(subsort)),), Term(sort))) def _symbol_ctor(self, sort: str, symbol: str) -> Ctor: - param_sorts = ( - check_type(sort, SortApp).name for sort in self.defn.symbols[symbol].param_sorts - ) # TODO eliminate check_type + decl = self.defn.symbols[symbol] + param_sorts = _param_sorts(decl) symbol = self._symbol_ident(symbol) binders = tuple(ExplBinder((f'x{i}',), Term(sort)) for i, sort in enumerate(param_sorts)) return Ctor(symbol, Signature(binders, Term(sort))) + def _cells(self) -> list[Command]: + def ordering(sorts: list[str]) -> dict[str, list[str]]: + res: dict[str, list[str]] = {} + for sort in sorts: + (cell_ctor,) = self.defn.constructors[sort] # Cells have exactly one constructor + decl = self.defn.symbols[cell_ctor] + res[sort] = [sort for sort in _param_sorts(decl) if self._is_cell(sort)] + return res + + sorts = [sort for sort in self.defn.sorts if self._is_cell(sort)] + sorts = list(TopologicalSorter(ordering(sorts)).static_order()) + return [self._cell(sort) for sort in sorts] + + def _cell(self, sort: str) -> Inductive | Structure: + (cell_ctor,) = self.defn.constructors[sort] + decl = self.defn.symbols[cell_ctor] + param_sorts = _param_sorts(decl) + + if any(not self._is_cell(sort) for sort in param_sorts): + assert len(param_sorts) == 1 + (param_sort,) = param_sorts + ctor = Ctor('mk', Signature((ExplBinder(('val',), Term(param_sort)),), Term(sort))) + return Inductive(sort, Signature((), Term('Type')), ctors=(ctor,)) + + param_names = [] + for param_sort in param_sorts: + assert param_sort.startswith('Sort') + assert param_sort.endswith('Cell') + name = param_sort[4:-4] + name = name[0].lower() + name[1:] + param_names.append(name) + fields = tuple(ExplBinder((name,), Term(sort)) for name, sort in zip(param_names, param_sorts, strict=True)) + return Structure(sort, Signature((), Term('Type')), ctor=StructCtor(fields)) + + @staticmethod + def _is_cell(sort: str) -> bool: + return sort.endswith('Cell') + @staticmethod def _symbol_ident(symbol: str) -> str: if symbol.startswith('Lbl'): @@ -74,7 +114,7 @@ def _collections(self) -> list[Command]: def _collection(self, sort: str) -> Abbrev: coll = self.defn.collections[sort] elem = self.defn.symbols[coll.element] - sorts = ' '.join(check_type(sort, SortApp).name for sort in elem.param_sorts) # TODO eliminate check_type + sorts = ' '.join(_param_sorts(elem)) assert sorts match coll.kind: case CollectionKind.LIST: @@ -84,3 +124,9 @@ def _collection(self, sort: str) -> Abbrev: case CollectionKind.SET: val = Term(f'SetHook {sorts}') return Abbrev(sort, val, Signature((), Term('Type'))) + + +def _param_sorts(decl: SymbolDecl) -> list[str]: + from ..utils import check_type + + return [check_type(sort, SortApp).name for sort in decl.param_sorts] # TODO eliminate check_type diff --git a/pyk/src/pyk/k2lean4/model.py b/pyk/src/pyk/k2lean4/model.py index fdbc630014..7a5c2af6c0 100644 --- a/pyk/src/pyk/k2lean4/model.py +++ b/pyk/src/pyk/k2lean4/model.py @@ -117,6 +117,101 @@ def __str__(self) -> str: return '\n'.join(lines) +@final +@dataclass(frozen=True) +class Structure(Declaration): + ident: DeclId + signature: Signature | None + extends: tuple[Term, ...] + ctor: StructCtor | None + deriving: tuple[str, ...] + modifiers: Modifiers | None + + def __init__( + self, + ident: str | DeclId, + signature: Signature | None = None, + extends: Iterable[Term] | None = None, + ctor: StructCtor | None = None, + deriving: Iterable[str] | None = None, + modifiers: Modifiers | None = None, + ): + ident = DeclId(ident) if isinstance(ident, str) else ident + extends = tuple(extends) if extends is not None else () + deriving = tuple(deriving) if deriving is not None else () + object.__setattr__(self, 'ident', ident) + object.__setattr__(self, 'signature', signature) + object.__setattr__(self, 'extends', extends) + object.__setattr__(self, 'ctor', ctor) + object.__setattr__(self, 'deriving', deriving) + object.__setattr__(self, 'modifiers', modifiers) + + def __str__(self) -> str: + lines = [] + + modifiers = f'{self.modifiers} ' if self.modifiers else '' + binders = ( + ' '.join(str(binder) for binder in self.signature.binders) + if self.signature and self.signature.binders + else '' + ) + binders = f' {binders}' if binders else '' + extends = ', '.join(str(extend) for extend in self.extends) + extends = f' extends {extends}' if extends else '' + ty = f' : {self.signature.ty}' if self.signature and self.signature.ty else '' + where = ' where' if self.ctor else '' + lines.append(f'{modifiers}structure {self.ident}{binders}{extends}{ty}{where}') + + if self.deriving: + lines.append(f' deriving {self.deriving}') + + if self.ctor: + lines.extend(f' {line}' for line in str(self.ctor).splitlines()) + + return '\n'.join(lines) + + +@final +@dataclass(frozen=True) +class StructCtor: + fields: tuple[Binder, ...] # TODO implement StructField + ident: StructIdent | None + + def __init__( + self, + fields: Iterable[Binder], + ident: str | StructIdent | None = None, + ): + fields = tuple(fields) + ident = StructIdent(ident) if isinstance(ident, str) else ident + object.__setattr__(self, 'fields', fields) + object.__setattr__(self, 'ident', ident) + + def __str__(self) -> str: + lines = [] + if self.ident: + lines.append(f'{self.ident} ::') + for field in self.fields: + if isinstance(field, ExplBinder) and len(field.idents) == 1: + (ident,) = field.idents + ty = '' if field.ty is None else f' : {field.ty}' + lines.append(f'{ident}{ty}') + else: + lines.append(str(field)) + return '\n'.join(lines) + + +@final +@dataclass(frozen=True) +class StructIdent: + ident: str + modifiers: Modifiers | None = None + + def __str__(self) -> str: + modifiers = f'{self.modifiers} ' if self.modifiers else '' + return f'{modifiers}{ self.ident}' + + @final @dataclass(frozen=True) class DeclId: