Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create proof trace event for tail call information #1179

Merged
merged 10 commits into from
Dec 12, 2024
7 changes: 7 additions & 0 deletions bindings/python/ast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,13 @@ void bind_proof_trace(py::module_ &m) {
"function_name",
&llvm_pattern_matching_failure_event::get_function_name);

py::class_<
llvm_function_exit_event, std::shared_ptr<llvm_function_exit_event>>(
proof_trace, "llvm_function_exit_event", step_event)
.def_property_readonly(
"rule_ordinal", &llvm_function_exit_event::get_rule_ordinal)
.def_property_readonly("is_tail", &llvm_function_exit_event::is_tail);

py::class_<llvm_function_event, std::shared_ptr<llvm_function_event>>(
proof_trace, "llvm_function_event", step_event)
.def_property_readonly("name", &llvm_function_event::get_name)
Expand Down
3 changes: 3 additions & 0 deletions docs/proof-trace.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ event ::= hook
| side_cond_exit
| config
| pattern_matching_failure
| function_exit

arg ::= kore_term

Expand All @@ -60,6 +61,8 @@ rule ::= WORD(0x22) ordinal arity variable*
side_cond_entry ::= WORD(0xEE) ordinal arity variable*
side_cond_exit ::= WORD(0x33) ordinal boolean_result

function_exit ::= WORD(0x55) ordinal boolean_result

config ::= WORD(0xFF) kore_term

string ::= <c-style null terminated string>
Expand Down
47 changes: 47 additions & 0 deletions include/kllvm/binary/ProofTraceParser.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ constexpr uint64_t rule_event_sentinel = detail::word(0x22);
constexpr uint64_t side_condition_event_sentinel = detail::word(0xEE);
constexpr uint64_t side_condition_end_sentinel = detail::word(0x33);
constexpr uint64_t pattern_matching_failure_sentinel = detail::word(0x44);
constexpr uint64_t function_exit_sentinel = detail::word(0x55);

class llvm_step_event : public std::enable_shared_from_this<llvm_step_event> {
public:
Expand Down Expand Up @@ -172,6 +173,29 @@ class llvm_pattern_matching_failure_event : public llvm_step_event {
const override;
};

class llvm_function_exit_event : public llvm_step_event {
private:
uint64_t rule_ordinal_;
bool is_tail_;

llvm_function_exit_event(uint64_t rule_ordinal, bool is_tail)
: rule_ordinal_(rule_ordinal)
, is_tail_(is_tail) { }

public:
static sptr<llvm_function_exit_event>
create(uint64_t rule_ordinal, bool is_tail) {
return sptr<llvm_function_exit_event>(
new llvm_function_exit_event(rule_ordinal, is_tail));
}

[[nodiscard]] uint64_t get_rule_ordinal() const { return rule_ordinal_; }
[[nodiscard]] bool is_tail() const { return is_tail_; }

void print(std::ostream &out, bool expand_terms, unsigned indent = 0U)
const override;
};

class llvm_event;

class llvm_function_event : public llvm_step_event {
Expand Down Expand Up @@ -599,6 +623,27 @@ class proof_trace_parser {
return event;
}

sptr<llvm_function_exit_event> static parse_function_exit(
proof_trace_buffer &buffer) {
if (!buffer.check_word(function_exit_sentinel)) {
return nullptr;
}

uint64_t ordinal = 0;
if (!buffer.read_uint64(ordinal)) {
return nullptr;
}

bool is_tail = false;
if (!buffer.read_bool(is_tail)) {
return nullptr;
}

auto event = llvm_function_exit_event::create(ordinal, is_tail);

return event;
}

bool parse_argument(proof_trace_buffer &buffer, llvm_event &event) {
if (buffer.eof() || buffer.peek() != '\x7F') {
return false;
Expand Down Expand Up @@ -634,6 +679,8 @@ class proof_trace_parser {
case pattern_matching_failure_sentinel:
return parse_pattern_matching_failure(buffer);

case function_exit_sentinel: return parse_function_exit(buffer);

default: return nullptr;
}
}
Expand Down
108 changes: 103 additions & 5 deletions include/kllvm/codegen/ProofEvent.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@
#include "kllvm/ast/AST.h"
#include "kllvm/codegen/Decision.h"
#include "kllvm/codegen/DecisionParser.h"
#include "kllvm/codegen/Options.h"
#include "kllvm/codegen/Util.h"

#include "llvm/IR/Instructions.h"

#include <fmt/format.h>

#include <map>
#include <tuple>

Expand All @@ -21,27 +24,58 @@ class proof_event {

/*
* Load the boolean flag that controls whether proof hint output is enabled or
* not, then create a branch at the end of this basic block depending on the
* result.
* not, then create a branch at the specified location depending on the
* result. The location can be before a given instruction or at the end of a
* given basic block.
*
* Returns a pair of blocks [proof enabled, merge]; the first of these is
* intended for self-contained behaviour only relevant in proof output mode,
* while the second is for the continuation of the interpreter's previous
* behaviour.
*/
template <typename Location>
std::pair<llvm::BasicBlock *, llvm::BasicBlock *>
proof_branch(std::string const &label, llvm::BasicBlock *insert_at_end);
proof_branch(std::string const &label, Location *insert_loc);

/*
* Return the parent function of the given location.

* Template specializations for llvm::Instruction and llvm::BasicBlock.
*/
template <typename Location>
llvm::Function *get_parent_function(Location *loc);

/*
* Return the parent basic block of the given location.

* Template specializations for llvm::Instruction and llvm::BasicBlock.
*/
template <typename Location>
llvm::BasicBlock *get_parent_block(Location *loc);

/*
* If the given location is an Instruction, this method moves the instruction
* to the merge block.
* If the given location is a BasicBlock, this method simply emits a no-op
* instruction to the merge block.

* Template specializations for llvm::Instruction and llvm::BasicBlock.
*/
template <typename Location>
void fix_insert_loc(Location *loc, llvm::BasicBlock *merge_block);

/*
* Set up a standard event prelude by creating a pair of basic blocks for the
* proof output and continuation, then loading the output filename from its
* global.
* global. The location for the prelude can be before a given instruction or
* at the end of a given basic block.
*
* Returns a triple [proof enabled, merge, proof_writer]; see `proofBranch`
* and `emitGetOutputFileName`.
*/
template <typename Location>
std::tuple<llvm::BasicBlock *, llvm::BasicBlock *, llvm::Value *>
event_prelude(std::string const &label, llvm::BasicBlock *insert_at_end);
event_prelude(std::string const &label, Location *insert_loc);

/*
* Set up a check of whether a new proof hint chunk should be started. The
Expand Down Expand Up @@ -172,6 +206,13 @@ class proof_event {
llvm::Value *proof_writer, std::string const &function_name,
llvm::BasicBlock *insert_at_end);

/*
* Emit a call to the `function_exit` API of the specified `proof_writer`.
*/
llvm::CallInst *emit_write_function_exit(
llvm::Value *proof_writer, uint64_t ordinal, bool is_tail,
llvm::BasicBlock *insert_at_end);

/*
* Emit a call to the `start_new_chunk` API of the specified `proof_writer`.
*/
Expand Down Expand Up @@ -228,6 +269,10 @@ class proof_event {
[[nodiscard]] llvm::BasicBlock *pattern_matching_failure(
kore_composite_pattern const &pattern, llvm::BasicBlock *current_block);

template <typename Location>
[[nodiscard]] llvm::BasicBlock *
function_exit(uint64_t ordinal, bool is_tail, Location *insert_loc);

proof_event(kore_definition *definition, llvm::Module *module)
: definition_(definition)
, module_(module)
Expand All @@ -236,4 +281,57 @@ class proof_event {

} // namespace kllvm

//===----------------------------------------------------------------------===//
// Implementation for method templates
//===----------------------------------------------------------------------===//

template <typename Location>
std::pair<llvm::BasicBlock *, llvm::BasicBlock *>
kllvm::proof_event::proof_branch(
std::string const &label, Location *insert_loc) {
auto *i1_ty = llvm::Type::getInt1Ty(ctx_);

auto *proof_output_flag = module_->getOrInsertGlobal("proof_output", i1_ty);
auto *proof_output = new llvm::LoadInst(
i1_ty, proof_output_flag, "proof_output", insert_loc);

auto *f = get_parent_function(insert_loc);
auto *true_block
= llvm::BasicBlock::Create(ctx_, fmt::format("if_{}", label), f);
auto *merge_block
= llvm::BasicBlock::Create(ctx_, fmt::format("tail_{}", label), f);

llvm::BranchInst::Create(true_block, merge_block, proof_output, insert_loc);

fix_insert_loc(insert_loc, merge_block);

return {true_block, merge_block};
}

template <typename Location>
std::tuple<llvm::BasicBlock *, llvm::BasicBlock *, llvm::Value *>
kllvm::proof_event::event_prelude(
std::string const &label, Location *insert_loc) {
auto [true_block, merge_block] = proof_branch(label, insert_loc);
return {true_block, merge_block, emit_get_proof_trace_writer(true_block)};
}

template <typename Location>
llvm::BasicBlock *kllvm::proof_event::function_exit(
uint64_t ordinal, bool is_tail, Location *insert_loc) {

if (!proof_hint_instrumentation) {
return get_parent_block(insert_loc);
}

auto [true_block, merge_block, proof_writer]
= event_prelude("function_exit", insert_loc);

emit_write_function_exit(proof_writer, ordinal, is_tail, true_block);

llvm::BranchInst::Create(merge_block, true_block);

return merge_block;
}

#endif // PROOF_EVENT_H
2 changes: 2 additions & 0 deletions include/runtime/header.h
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,8 @@ void write_side_condition_event_post_to_proof_trace(
void *proof_writer, uint64_t ordinal, bool side_cond_result);
void write_pattern_matching_failure_to_proof_trace(
void *proof_writer, char const *function_name);
void write_function_exit_to_proof_trace(
void *proof_writer, uint64_t ordinal, bool is_tail);
void write_configuration_to_proof_trace(
void *proof_writer, block *config, bool is_initial);
void start_new_chunk_in_proof_trace(void *proof_writer);
Expand Down
23 changes: 23 additions & 0 deletions include/runtime/proof_trace_writer.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class proof_trace_writer {
side_condition_event_post(uint64_t ordinal, bool side_cond_result)
= 0;
virtual void pattern_matching_failure(char const *function_name) = 0;
virtual void function_exit(uint64_t ordinal, bool is_tail) = 0;
virtual void configuration(block *config, bool is_initial) = 0;
virtual void start_new_chunk() = 0;
virtual void end_of_trace() = 0;
Expand Down Expand Up @@ -163,6 +164,12 @@ class proof_trace_file_writer : public proof_trace_writer {
write_null_terminated_string(function_name);
}

void function_exit(uint64_t ordinal, bool is_tail) override {
write_uint64(kllvm::function_exit_sentinel);
write_uint64(ordinal);
write_bool(is_tail);
}

void configuration(block *config, bool is_initial) override {
write_uint64(kllvm::config_sentinel);
serialize_configuration_to_proof_trace(file_, config, 0);
Expand Down Expand Up @@ -227,6 +234,15 @@ class proof_trace_callback_writer : public proof_trace_writer {
, result(result) { }
};

struct function_exit_construction {
uint64_t ordinal;
bool is_tail;

function_exit_construction(uint64_t ordinal, bool is_tail)
: ordinal(ordinal)
, is_tail(is_tail) { }
};

struct call_event_construction {
char const *hook_name;
char const *symbol_name;
Expand Down Expand Up @@ -281,6 +297,8 @@ class proof_trace_callback_writer : public proof_trace_writer {
side_condition_result_construction const &event) { }
virtual void pattern_matching_failure_callback(
pattern_matching_failure_construction const &event) { }
virtual void function_exit_callback(function_exit_construction const &event) {
}
virtual void configuration_event_callback(
kore_configuration_construction const &config, bool is_initial) { }

Expand Down Expand Up @@ -366,6 +384,11 @@ class proof_trace_callback_writer : public proof_trace_writer {
pattern_matching_failure_callback(pm_failure);
}

void function_exit(uint64_t ordinal, bool is_tail) override {
function_exit_construction function_exit(ordinal, is_tail);
function_exit_callback(function_exit);
}

void configuration(block *config, bool is_initial) override {
kore_configuration_construction configuration(config);
configuration_event_callback(configuration, is_initial);
Expand Down
8 changes: 8 additions & 0 deletions lib/binary/ProofTraceParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,14 @@ void llvm_pattern_matching_failure_event::print(
"{}pattern matching failure: {}\n", indent, function_name_);
}

void llvm_function_exit_event::print(
std::ostream &out, bool expand_terms, unsigned ind) const {
std::string indent(ind * indent_size, ' ');
out << fmt::format(
"{}function exit: {} {}\n", indent, rule_ordinal_,
(is_tail_ ? "tail" : "notail"));
}

void llvm_function_event::print(
std::ostream &out, bool expand_terms, unsigned ind) const {
std::string indent(ind * indent_size, ' ');
Expand Down
21 changes: 21 additions & 0 deletions lib/codegen/CreateTerm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1174,6 +1174,7 @@ bool can_tail_call(llvm::Type *type) {
return int_type->getBitWidth() <= 192;
}

// NOLINTNEXTLINE(*-cognitive-complexity)
bool make_function(
std::string const &name, kore_pattern *pattern, kore_definition *definition,
llvm::Module *module, bool tailcc, bool big_step, bool apply,
Expand Down Expand Up @@ -1276,6 +1277,10 @@ bool make_function(
call->setTailCallKind(llvm::CallInst::TCK_MustTail);
retval = call;
} else {
size_t ordinal = 0;
if (apply) {
ordinal = std::stoll(name.substr(11));
}
if (auto *call = llvm::dyn_cast<llvm::CallInst>(retval)) {
// check that musttail requirements are met:
// 1. Call is in tail position (guaranteed)
Expand All @@ -1286,6 +1291,22 @@ bool make_function(
if (call->getCallingConv() == llvm::CallingConv::Tail
&& can_tail_call(call->getType())) {
call->setTailCallKind(llvm::CallInst::TCK_MustTail);
if (apply) {
current_block
= proof_event(definition, module)
.function_exit(
ordinal, true, llvm::dyn_cast<llvm::Instruction>(call));
}
} else {
if (apply) {
current_block = proof_event(definition, module)
.function_exit(ordinal, false, current_block);
}
}
} else {
if (apply) {
current_block = proof_event(definition, module)
.function_exit(ordinal, false, current_block);
}
}
}
Expand Down
1 change: 1 addition & 0 deletions lib/codegen/Decision.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -603,6 +603,7 @@ void leaf_node::codegen(decision *d) {
d->current_block_
= proof_event(d->definition_, d->module_)
.rewrite_event_pre(axiom, arity, vars, subst, d->current_block_);
// maybe report here as part of the rule event whether a tail call happened

if (d->profile_matching_) {
llvm::CallInst::Create(
Expand Down
Loading
Loading