Skip to content

Commit

Permalink
Teach functions inliner to inline into if condition (#5073)
Browse files Browse the repository at this point in the history
Signed-off-by: Anton Korobeynikov <anton@korobeynikov.info>
  • Loading branch information
asl authored Jan 4, 2025
1 parent 52697c9 commit 057fe94
Show file tree
Hide file tree
Showing 8 changed files with 207 additions and 2 deletions.
29 changes: 27 additions & 2 deletions frontends/p4/functionsInlining.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,14 @@ void DiscoverFunctionsInlining::postorder(const IR::MethodCallExpression *mce) {
auto stat = findContext<IR::Statement>();
CHECK_NULL(stat);

BUG_CHECK(bool(RTTI::isAny<IR::MethodCallStatement, IR::AssignmentStatement>(stat)),
"%1%: unexpected statement with call", stat);
BUG_CHECK(
bool(RTTI::isAny<IR::MethodCallStatement, IR::AssignmentStatement, IR::IfStatement>(stat)),
"%1%: unexpected statement with call", stat);

if (const auto *ifStat = stat->to<IR::IfStatement>()) {
// Check that we're inside condition of IfStatement
BUG_CHECK(isInContext(ifStat->condition), "%1%: unexpected statement with call", stat);
}

auto aci = new FunctionCallInfo(caller, ac->function, stat, mce);
toInline->add(aci);
Expand Down Expand Up @@ -180,6 +186,18 @@ const IR::Node *FunctionsInliner::preorder(IR::AssignmentStatement *statement) {
return inlineBefore(callee, callExpr, statement);
}

const IR::Node *FunctionsInliner::preorder(IR::IfStatement *statement) {
auto orig = getOriginal<IR::IfStatement>();
LOG2("Visiting " << dbp(orig));
auto replMap = getReplacementMap();
if (replMap == nullptr) return statement;

auto [callee, callExpr] = get(*replMap, orig);
if (callee == nullptr) return statement;
BUG_CHECK(callExpr != nullptr, "%1%: expected a method call", statement->condition);
return inlineBefore(callee, callExpr, statement);
}

const IR::Node *FunctionsInliner::preorder(IR::P4Parser *parser) {
if (preCaller()) {
parser->visit_children(*this);
Expand Down Expand Up @@ -312,6 +330,13 @@ const IR::Statement *FunctionsInliner::inlineBefore(const IR::Node *calleeNode,
auto [it, inserted] = pendingReplacements.emplace(mce, retExpr);
BUG_CHECK(inserted, "%1%: duplicate value for pending replacements", it->first);
}
} else if (const auto *ifStatement = statement->to<IR::IfStatement>()) {
// We already checked that we are inside condition of if statement. Add
// return value to pending list to be replaced afterwards
CHECK_NULL(retExpr);
body.push_back(ifStatement->clone());
auto [it, inserted] = pendingReplacements.emplace(mce, retExpr);
BUG_CHECK(inserted, "%1%: duplicate value for pending replacements", it->first);
} else {
BUG_CHECK(statement->is<IR::MethodCallStatement>(), "%1%: expected a method call",
statement);
Expand Down
1 change: 1 addition & 0 deletions frontends/p4/functionsInlining.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ class FunctionsInliner : public AbstractInliner<FunctionsInlineList, FunctionsIn
const IR::Node *preorder(IR::MethodCallStatement *statement) override;
const IR::Node *preorder(IR::MethodCallExpression *expr) override;
const IR::Node *preorder(IR::AssignmentStatement *statement) override;
const IR::Node *preorder(IR::IfStatement *statement) override;
};

typedef InlineDriver<FunctionsInlineList, FunctionsInlineWorkList> InlineFunctionsDriver;
Expand Down
33 changes: 33 additions & 0 deletions testdata/p4_16_samples/inline-function2.p4
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
bit foo(in bit a) {
return a + 1;
}

control p(inout bit bt) {
action a(inout bit y0, bit y1) {
bit y2 = y1 > 0 ? 1w1 : 0;
if (y2 == 1) {
y0 = 0;
} else if (y1 != 1) {
y0 = y0 | 1w1;
}
}

action b() {
a(bt, foo(bt));
a(bt, 1);
}

table t {
actions = { b; }
default_action = b;
}

apply {
t.apply();
}
}

control simple<T>(inout T arg);
package m<T>(simple<T> pipe);

m(p()) main;
30 changes: 30 additions & 0 deletions testdata/p4_16_samples_outputs/inline-function2-first.p4
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
bit<1> foo(in bit<1> a) {
return a + 1w1;
}
control p(inout bit<1> bt) {
action a(inout bit<1> y0, bit<1> y1) {
bit<1> y2 = (y1 > 1w0 ? 1w1 : 1w0);
if (y2 == 1w1) {
y0 = 1w0;
} else if (y1 != 1w1) {
y0 = y0 | 1w1;
}
}
action b() {
a(bt, foo(bt));
a(bt, 1w1);
}
table t {
actions = {
b();
}
default_action = b();
}
apply {
t.apply();
}
}

control simple<T>(inout T arg);
package m<T>(simple<T> pipe);
m<bit<1>>(p()) main;
53 changes: 53 additions & 0 deletions testdata/p4_16_samples_outputs/inline-function2-frontend.p4
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
control p(inout bit<1> bt) {
@name("p.y2") bit<1> y2_0;
@name("p.tmp") bit<1> tmp;
@name("p.y0") bit<1> y0;
@name("p.y0") bit<1> y0_1;
@name("p.a_0") bit<1> a;
@name("p.retval") bit<1> retval;
@name("p.a_1") bit<1> a_2;
@name("p.retval") bit<1> retval_1;
@name("p.b") action b() {
y0 = bt;
a = bt;
retval = a + 1w1;
if (retval > 1w0) {
tmp = 1w1;
} else {
tmp = 1w0;
}
y2_0 = tmp;
if (y2_0 == 1w1) {
y0 = 1w0;
} else {
a_2 = bt;
retval_1 = a_2 + 1w1;
if (retval_1 != 1w1) {
y0 = y0 | 1w1;
}
}
bt = y0;
y0_1 = bt;
tmp = 1w1;
y2_0 = tmp;
if (y2_0 == 1w1) {
y0_1 = 1w0;
} else {
;
}
bt = y0_1;
}
@name("p.t") table t_0 {
actions = {
b();
}
default_action = b();
}
apply {
t_0.apply();
}
}

control simple<T>(inout T arg);
package m<T>(simple<T> pipe);
m<bit<1>>(p()) main;
33 changes: 33 additions & 0 deletions testdata/p4_16_samples_outputs/inline-function2-midend.p4
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
control p(inout bit<1> bt) {
@name("p.tmp") bit<1> tmp;
@name("p.y0") bit<1> y0;
@name("p.b") action b() {
y0 = bt;
if (bt + 1w1 > 1w0) {
tmp = 1w1;
} else {
tmp = 1w0;
}
if (tmp == 1w1) {
y0 = 1w0;
} else if (bt + 1w1 != 1w1) {
y0 = bt | 1w1;
}
bt = y0;
tmp = 1w1;
bt = 1w0;
}
@name("p.t") table t_0 {
actions = {
b();
}
default_action = b();
}
apply {
t_0.apply();
}
}

control simple<T>(inout T arg);
package m<T>(simple<T> pipe);
m<bit<1>>(p()) main;
30 changes: 30 additions & 0 deletions testdata/p4_16_samples_outputs/inline-function2.p4
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
bit<1> foo(in bit<1> a) {
return a + 1;
}
control p(inout bit<1> bt) {
action a(inout bit<1> y0, bit<1> y1) {
bit<1> y2 = (y1 > 0 ? 1w1 : 0);
if (y2 == 1) {
y0 = 0;
} else if (y1 != 1) {
y0 = y0 | 1w1;
}
}
action b() {
a(bt, foo(bt));
a(bt, 1);
}
table t {
actions = {
b;
}
default_action = b;
}
apply {
t.apply();
}
}

control simple<T>(inout T arg);
package m<T>(simple<T> pipe);
m(p()) main;
Empty file.

0 comments on commit 057fe94

Please sign in to comment.