diff --git a/CMakeLists.txt b/CMakeLists.txt index bc21ef31e7..b516031deb 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -24,8 +24,8 @@ endif() project (P4C) -# set (CMAKE_CXX_EXTENSIONS OFF) # prefer using -std=c++17 rather than -std=gnu++17 -set (CMAKE_CXX_STANDARD 17) +# set (CMAKE_CXX_EXTENSIONS OFF) # prefer using -std=c++20 rather than -std=gnu++20 +set (CMAKE_CXX_STANDARD 20) set (CMAKE_CXX_STANDARD_REQUIRED ON) set (CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake) diff --git a/frontends/p4/removeReturns.cpp b/frontends/p4/removeReturns.cpp index 3a2c8201f1..53bdacdd54 100644 --- a/frontends/p4/removeReturns.cpp +++ b/frontends/p4/removeReturns.cpp @@ -20,7 +20,7 @@ limitations under the License. namespace P4 { -bool MoveToElseAfterBranch::preorder(IR::BlockStatement *block) { +bool MoveToElseAfterBranch::preorder(IR::COWptr block) { movedToIfBranch = false; for (auto it = block->components.begin(); it != block->components.end();) { if (movedToIfBranch) @@ -32,7 +32,7 @@ bool MoveToElseAfterBranch::preorder(IR::BlockStatement *block) { return false; } -bool MoveToElseAfterBranch::moveFromParentTo(const IR::Statement *&child) { +template bool MoveToElseAfterBranch::moveFromParentTo(T &child) { auto parent = getParent(); size_t next = getContext()->child_index + 1; if (!parent || next >= parent->components.size()) { @@ -43,9 +43,9 @@ bool MoveToElseAfterBranch::moveFromParentTo(const IR::Statement *&child) { IR::BlockStatement *modified = nullptr; if (!child) modified = new IR::BlockStatement; - else if (auto *t = child->to()) + else if (auto *t = child->template to()) modified = t->clone(); - else if (child->is()) + else if (child->template is()) modified = new IR::BlockStatement; else modified = new IR::BlockStatement({child}); @@ -55,7 +55,7 @@ bool MoveToElseAfterBranch::moveFromParentTo(const IR::Statement *&child) { return true; } -bool MoveToElseAfterBranch::preorder(IR::IfStatement *ifStmt) { +bool MoveToElseAfterBranch::preorder(IR::COWptr ifStmt) { hasJumped = false; bool movedCode = false; visit(ifStmt->ifTrue, "ifTrue", 1); @@ -74,11 +74,11 @@ bool MoveToElseAfterBranch::preorder(IR::IfStatement *ifStmt) { return false; } -bool MoveToElseAfterBranch::preorder(IR::SwitchStatement *swch) { +bool MoveToElseAfterBranch::preorder(IR::COWptr swch) { // TBD: if there is exactly one case that falls through (all others end with a branch) // then we could move subsequent code into that case, as it done with 'if' bool canFallThrough = false; - for (auto &c : swch->cases) { + for (auto c : swch->cases) { hasJumped = false; visit(c, "cases"); canFallThrough |= !hasJumped; @@ -87,7 +87,7 @@ bool MoveToElseAfterBranch::preorder(IR::SwitchStatement *swch) { return false; } -void MoveToElseAfterBranch::postorder(IR::LoopStatement *) { +void MoveToElseAfterBranch::postorder(IR::COWptr) { // after a loop body is never unreachable hasJumped = false; } diff --git a/frontends/p4/removeReturns.h b/frontends/p4/removeReturns.h index f13ea82e8f..2a0e186271 100644 --- a/frontends/p4/removeReturns.h +++ b/frontends/p4/removeReturns.h @@ -65,7 +65,7 @@ introduce a boolean flag and extra tests to remove those branches. precondition: switchAddDefault pass has run to ensure switch statements cover all cases */ -class MoveToElseAfterBranch : public Modifier { +class MoveToElseAfterBranch : public COWModifier { /* This pass does not use (inherit from) ControlFlowVisitor, even though it is doing * control flow analysis, as it turns out to be more efficient to do it directly here * by overloading the branching constructs (if/switch/loops) and not cloning the visitor, @@ -80,22 +80,22 @@ class MoveToElseAfterBranch : public Modifier { * indicating that it needs to be removed from the BlockStatment */ bool movedToIfBranch = false; - bool preorder(IR::BlockStatement *) override; - bool moveFromParentTo(const IR::Statement *&child); - bool preorder(IR::IfStatement *) override; - bool preorder(IR::SwitchStatement *) override; - void postorder(IR::LoopStatement *) override; + bool preorder(IR::COWptr) override; + template bool moveFromParentTo(T &child); + bool preorder(IR::COWptr) override; + bool preorder(IR::COWptr) override; + void postorder(IR::COWptr) override; bool branch() { hasJumped = true; // no need to visit children return false; } - bool preorder(IR::BreakStatement *) override { return branch(); } - bool preorder(IR::ContinueStatement *) override { return branch(); } - bool preorder(IR::ExitStatement *) override { return branch(); } - bool preorder(IR::ReturnStatement *) override { return branch(); } + bool preorder(IR::COWptr) override { return branch(); } + bool preorder(IR::COWptr) override { return branch(); } + bool preorder(IR::COWptr) override { return branch(); } + bool preorder(IR::COWptr) override { return branch(); } // Only visit statements, skip all expressions - bool preorder(IR::Expression *) override { return false; } + bool preorder(IR::COWptr) override { return false; } public: MoveToElseAfterBranch() {} diff --git a/ir/copy_on_write_inl.h b/ir/copy_on_write_inl.h new file mode 100644 index 0000000000..743ad8ae93 --- /dev/null +++ b/ir/copy_on_write_inl.h @@ -0,0 +1,238 @@ +/* +Copyright 2024 NVIDIA CORPORATION. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#ifndef IR_COPY_ON_WRITE_INL_H_ +#define IR_COPY_ON_WRITE_INL_H_ + +/* template methods declared in "copy_on_write_ptr.h" that can't be defined there + * due to order-of-declaration issues. */ + +namespace P4::IR { + +template const T *COWinfo::get() const { + if (clone) return clone->checkedTo(); + return orig->checkedTo(); +} + +template const T *COWinfo::getOrig() const { + return orig->checkedTo(); +} + +template T *COWinfo::mkClone() { + if (!clone) clone = orig->clone(); + return clone->checkedTo(); +} + +template T *COWinfo::getClone() const { + BUG_CHECK(clone != nullptr, "Not yet cloned in getClone"); + return clone->checkedTo(); +} + +/* specializations for IR::Node pointers */ +template +requires std::derived_from +struct COWfieldref { + COWinfo *info; + + const U *get() const { return info->get()->*field; } + operator const U *() const { return get(); } + const U *operator=(const U *val) const { + if (!info->isCloned() && info->get()->*field == val) return val; + return info->mkClone()->*field = val; + } + void set(const U *val) const { *this = val; } + template requires (std::is_base_of_v && !std::same_as) + void set(const V *val) const { *this = val ? val->template checkedTo() : (U *)nullptr; } + const U *operator->() const { return get(); } + const U &operator*() const { return *get(); } +}; + +template +struct COW_element_ref { + COWinfo *info; + bool is_const; + union { + typename C::const_iterator ci; + typename C::iterator ni; + }; + COW_element_ref(COWinfo *inf, typename C::const_iterator i) + : info(inf), is_const(true) { ci = i; } + COW_element_ref(COWinfo *inf, typename C::iterator i) + : info(inf), is_const(false) { ni = i; } + void clone_fixup() { + if (is_const) { + // messy problem -- need to clone (iff not yet cloned) and then find the + // corresponding iterator in the clone + auto i = (info->mkClone()->*field).begin(); + auto &orig_vec = info->getOrig()->*field; + for (auto oi = orig_vec.begin(); oi != ci; ++oi, ++i) + BUG_CHECK(oi != orig_vec.end(), "Invalid iterator in clone_fixup"); + ni = i; + is_const = false; + } + } + U get() { + if (is_const && info->isCloned()) clone_fixup(); + return *ci; + } + operator U() { + if (is_const && info->isCloned()) clone_fixup(); + return *ci; + } + U operator=(U val) { + clone_fixup(); + return *ni = val; + } + void set(U val) { + clone_fixup(); + *ni = val; + } + void set(const Node *val) requires std::is_pointer_v + { + set(val ? val->checkedTo::type>() : (U)nullptr); + } +}; + +template +struct COW_iterator { + COW_element_ref ref; + COW_iterator(COWinfo *inf, typename C::const_iterator i) : ref(inf, i) {} + COW_iterator(COWinfo *inf, typename C::iterator i) : ref(inf, i) {} + COW_iterator &operator++() { ++ref.ci; return *this; } + COW_iterator operator++(int) { COW_iterator rv = *this; ++ref.ci; return rv; } + COW_iterator &operator--() { --ref.ci; return *this; } + COW_iterator operator--(int) { COW_iterator rv = *this; --ref.ci; return rv; } + bool operator==(const COW_iterator &i) const { return ref.ci == i.ref.ci; } + bool operator!=(const COW_iterator &i) const { return ref.ci != i.ref.ci; } + COW_element_ref &operator *() { return ref; } +}; + + +/* specialization for safe_vector */ +template T::*field> +struct COWfieldref, field> { + COWinfo *info; + + using iterator = COW_iterator, U, field>; + + void visit_children(Visitor &) { BUG("TBD"); } + + const safe_vector &get() const { return info->get()->*field; } + safe_vector &modify() const { return info->mkClone()->*field; } + operator const safe_vector&() const { return get(); } + safe_vector &operator=(const safe_vector &val) const { return modify() = val; } + safe_vector &operator=(safe_vector &&val) const { return modify() = std::move(val); } + iterator begin() { + if (info->isCloned()) + return iterator(info, (info->getClone()->*field).begin()); + else + return iterator(info, (info->get()->*field).begin()); + } + iterator end() { + if (info->isCloned()) + return iterator(info, (info->getClone()->*field).begin()); + else + return iterator(info, (info->get()->*field).begin()); + } + // FIXME need to add insert/appeand/prepend/emplace_back specializations +}; + + +/* specializations for IR::Vector */ +template T::*field> +struct COWfieldref, field> { + COWinfo *info; + + using iterator = COW_iterator, const U *, field>; + + void visit_children(Visitor &) { BUG("TBD"); } + + const Vector &get() const { return info->get()->*field; } + Vector &modify() const { return info->mkClone()->*field; } + operator const Vector&() const { return get(); } + Vector &operator=(const Vector &val) const { return modify() = val; } + Vector &operator=(Vector &&val) const { return modify() = std::move(val); } + iterator begin() { + if (info->isCloned()) + return iterator(info, (info->getClone()->*field).begin()); + else + return iterator(info, (info->get()->*field).begin()); + } + iterator end() { + if (info->isCloned()) + return iterator(info, (info->getClone()->*field).begin()); + else + return iterator(info, (info->get()->*field).begin()); + } + iterator erase(iterator i) { + i.ref.clone_fixup(); + Vector &vec = info->getClone()->*field; + return iterator(info, vec.erase(i.ref.ni)); + } + // FIXME need to add insert/appeand/prepend/emplace_back specializations +}; + +/* specializations for IR::IndexedVector */ +template T::*field> +struct COWfieldref, field> { + COWinfo *info; + + using iterator = COW_iterator, const U *, field>; + + void visit_children(Visitor &) { BUG("TBD"); } + + const IndexedVector &get() const { return info->get()->*field; } + IndexedVector &modify() const { return info->mkClone()->*field; } + operator const IndexedVector&() const { return get(); } + IndexedVector &operator=(const IndexedVector &val) const { return modify() = val; } + IndexedVector &operator=(IndexedVector &&val) const { return modify() = std::move(val); } + iterator begin() { + if (info->isCloned()) + return iterator(info, (info->getClone()->*field).begin()); + else + return iterator(info, (info->get()->*field).begin()); + } + iterator end() { + if (info->isCloned()) + return iterator(info, (info->getClone()->*field).begin()); + else + return iterator(info, (info->get()->*field).begin()); + } + iterator erase(iterator i) { + i.ref.clone_fixup(); + IndexedVector &vec = info->getClone()->*field; + return iterator(info, vec.erase(i.ref.ni)); + } + // FIXME need to add insert/appeand/prepend/emplace_back/removeByName specializations +}; + +/* specializations for IR::NameMap */ +template class MAP, + NameMap T::*field> +struct COWfieldref, field> { + COWinfo *info; + + using iterator = COW_iterator, const U *, field>; + + void visit_children(Visitor &) { BUG("TBD"); } +}; + +// FIXME -- need NodeMap specializations if any backend ever uses that template + +} // namespace P4::IR + +#endif /* IR_COPY_ON_WRITE_INL_H_ */ diff --git a/ir/copy_on_write_ptr.h b/ir/copy_on_write_ptr.h new file mode 100644 index 0000000000..029b6e0fa9 --- /dev/null +++ b/ir/copy_on_write_ptr.h @@ -0,0 +1,119 @@ +/* +Copyright 2024 NVIDIA CORPORATION. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#ifndef IR_COPY_ON_WRITE_PTR_H_ +#define IR_COPY_ON_WRITE_PTR_H_ + +#include "lib/exceptions.h" + +namespace P4 { + +class COWModifier; +class COWTransform; +class Visitor; +class Visitor_Context; + +namespace IR { + +class Node; + +template +concept COWref = requires(T r) { + r.set(r.get()); +}; + +class COWNode_info { + protected: + const Node *orig; + Node *clone; + Visitor_Context *ctxt; + + COWNode_info() = delete; + COWNode_info(const COWNode_info &) = delete; + COWNode_info(const Node *n, Visitor_Context *c) : orig(n), clone(nullptr), ctxt(c) {} +}; + +template class COWinfo : public COWNode_info { + public: + const T *get() const; + const T *getOrig() const; + T *mkClone(); + T *getClone() const; + bool isCloned() const { return clone != nullptr; } + private: + friend COWModifier; + friend COWTransform; + COWinfo(const T *n, Visitor_Context *c) : COWNode_info(n, c) {} +}; + +template struct COWfieldref { + COWinfo *info; + + const U &get() const { return *info->get(); } + U &modify() const { return *info->mkClone(); } + operator const U&() const { return get(); } + const U &operator=(const U &val) const { + if (!info->isCloned() && info->get()->*field == val) return val; + return info->mkClone()->*field = val; + } + const U &operator=(U &&val) const { + if (!info->isCloned() && info->get()->*field == val) return val; + return info->mkClone()->*field = std::move(val); + } + void set(const U &val) const { *this = val; } +}; + +template class COWptr { + COWinfo *info; + + template friend class COWptr; + friend COWModifier; + friend COWTransform; + friend T; + COWptr(COWinfo *p) : info(p) {} + COWptr(COWNode_info *p) : info(static_cast *>(p)) { + BUG_CHECK(info->get()->template is(), "incorrect type in COWptr ctor"); + } + +public: + COWptr() = default; + COWptr(const COWptr &) = default; + COWptr(COWptr &&) = default; + template requires std::derived_from + COWptr(const COWptr &a) : COWptr(a.info) {} + ~COWptr() = default; + COWptr &operator=(const COWptr &) = default; + COWptr &operator=(COWptr &&) = default; + + operator const T *() const { return info->get(); } + typename T::COWref operator->() const { return typename T::COWref(info); } + typename T::COWref operator*() const { return typename T::COWref(info); } + void visit_children(Visitor &v) const { info->get()->COW_visit_children(info, v); } + bool apply_visitor_preorder(COWModifier &v) const { + return info->get()->apply_visitor_preorder(info, v); } + void apply_visitor_postorder(COWModifier &v) const { + info->get()->apply_visitor_postorder(info, v); } + const IR::Node *apply_visitor_preorder(COWTransform &v) const { + return info->get()->apply_visitor_preorder(info, v); } + const IR::Node *apply_visitor_postorder(COWTransform &v) const { + return info->get()->apply_visitor_postorder(info, v); } +}; + +} + +} + +#endif /* IR_COPY_ON_WRITE_PTR_H_ */ diff --git a/ir/indexed_vector.h b/ir/indexed_vector.h index a1c40dfb5c..b50c9600db 100644 --- a/ir/indexed_vector.h +++ b/ir/indexed_vector.h @@ -207,6 +207,7 @@ class IndexedVector : public Vector { static cstring static_type_name() { return "IndexedVector<" + T::static_type_name() + ">"; } void visit_children(Visitor &v) override; void visit_children(Visitor &v) const override; + void COW_visit_children(COWNode_info *, Visitor &v) const override; void toJSON(JSONGenerator &json) const override; static IndexedVector *fromJSON(JSONLoader &json); diff --git a/ir/ir-inline.h b/ir/ir-inline.h index 471d4ce812..adaf7db39a 100644 --- a/ir/ir-inline.h +++ b/ir/ir-inline.h @@ -44,6 +44,25 @@ namespace P4 { Node::traceVisit("Mod loop_revisit"); \ v.loop_revisit(this); \ } \ + TEMPLATE INLINE bool IR::CLASS TT::apply_visitor_preorder(COWNode_info *info, \ + COWModifier &v) const { \ + Node::traceVisit("Mod pre (COW)"); \ + return v.preorder(IR::COWptr(info)); \ + } \ + TEMPLATE INLINE void IR::CLASS TT::apply_visitor_postorder(COWNode_info *info, \ + COWModifier &v) const { \ + Node::traceVisit("Mod post (COW)"); \ + v.postorder(IR::COWptr(info)); \ + } \ + TEMPLATE INLINE void IR::CLASS TT::apply_visitor_revisit(COWModifier &v, \ + const Node *n) const { \ + Node::traceVisit("Mod revisit (COW)"); \ + v.revisit(this, n); \ + } \ + TEMPLATE INLINE void IR::CLASS TT::apply_visitor_loop_revisit(COWModifier &v) const { \ + Node::traceVisit("Mod loop_revisit (COW)"); \ + v.loop_revisit(this); \ + } \ TEMPLATE INLINE bool IR::CLASS TT::apply_visitor_preorder(Inspector &v) const { \ Node::traceVisit("Insp pre"); \ return v.preorder(this); \ diff --git a/ir/namemap.h b/ir/namemap.h index b5581fda7a..5498eb0e61 100644 --- a/ir/namemap.h +++ b/ir/namemap.h @@ -148,6 +148,7 @@ class NameMap : public Node { static cstring static_type_name() { return "NameMap<" + T::static_type_name() + ">"; } void visit_children(Visitor &v) override; void visit_children(Visitor &v) const override; + void COW_visit_children(COWNode_info *, Visitor &v) const override; void toJSON(JSONGenerator &json) const override; static NameMap *fromJSON(JSONLoader &json); diff --git a/ir/node.h b/ir/node.h index 53e5eab1b1..3437bd210f 100644 --- a/ir/node.h +++ b/ir/node.h @@ -21,6 +21,7 @@ limitations under the License. #include "ir-tree-macros.h" #include "ir/gen-tree-macro.h" +#include "ir/copy_on_write_ptr.h" #include "lib/castable.h" #include "lib/cstring.h" #include "lib/exceptions.h" @@ -31,7 +32,9 @@ class Visitor; struct Visitor_Context; class Inspector; class Modifier; +class COWModifier; class Transform; +class COWTransform; class JSONGenerator; class JSONLoader; } // namespace P4 @@ -72,11 +75,19 @@ class INode : public Util::IHasSourceInfo, public IHasDbPrint, public ICastable virtual cstring node_type_name() const = 0; virtual void validate() const {} +#if 1 + using ICastable::checkedTo; +#else // default checkedTo implementation for nodes: just fallback to generic ICastable method template std::enable_if_t, const T *> checkedTo() const { return ICastable::checkedTo(); } + template + std::enable_if_t, T *> checkedTo() { + return ICastable::checkedTo(); + } +#endif // alternative checkedTo implementation that produces slightly better error message // due to node_type_name() / static_type_name() being available @@ -87,6 +98,13 @@ class INode : public Util::IHasSourceInfo, public IHasDbPrint, public ICastable T::static_type_name()); return result; } + template + std::enable_if_t, T *> checkedTo() { + auto *result = to(); + BUG_CHECK(result, "Cast failed: %1% with type %2% is not a %3%.", this, node_type_name(), + T::static_type_name()); + return result; + } DECLARE_TYPEINFO_WITH_TYPEID(INode, NodeKind::INode); }; @@ -97,6 +115,10 @@ class Node : public virtual INode { virtual void apply_visitor_postorder(Modifier &v); virtual void apply_visitor_revisit(Modifier &v, const Node *n) const; virtual void apply_visitor_loop_revisit(Modifier &v) const; + virtual bool apply_visitor_preorder(COWNode_info *, COWModifier &v) const ; + virtual void apply_visitor_postorder(COWNode_info *, COWModifier &v) const ; + virtual void apply_visitor_revisit(COWModifier &v, const Node *n) const; + virtual void apply_visitor_loop_revisit(COWModifier &v) const; virtual bool apply_visitor_preorder(Inspector &v) const; virtual void apply_visitor_postorder(Inspector &v) const; virtual void apply_visitor_revisit(Inspector &v) const; @@ -105,6 +127,10 @@ class Node : public virtual INode { virtual const Node *apply_visitor_postorder(Transform &v); virtual void apply_visitor_revisit(Transform &v, const Node *n) const; virtual void apply_visitor_loop_revisit(Transform &v) const; + virtual const Node *apply_visitor_preorder(COWNode_info *, COWTransform &v) const ; + virtual const Node *apply_visitor_postorder(COWNode_info *, COWTransform &v) const ; + virtual void apply_visitor_revisit(COWTransform &v, const Node *n) const; + virtual void apply_visitor_loop_revisit(COWTransform &v) const; Node &operator=(const Node &) = default; Node &operator=(Node &&) = default; @@ -114,7 +140,9 @@ class Node : public virtual INode { friend class ::P4::Visitor; friend class ::P4::Inspector; friend class ::P4::Modifier; + friend class ::P4::COWModifier; friend class ::P4::Transform; + friend class ::P4::COWTransform; cstring prepareSourceInfoForJSON(Util::SourceInfo &si, unsigned *lineNumber, unsigned *columnNumber) const; @@ -161,6 +189,7 @@ class Node : public virtual INode { #undef DEFINE_OPEQ_FUNC virtual void visit_children(Visitor &) {} virtual void visit_children(Visitor &) const {} + virtual void COW_visit_children(COWNode_info *, Visitor &) const = 0; bool operator!=(const Node &n) const { return !operator==(n); } @@ -172,6 +201,16 @@ class Node : public virtual INode { sink.Append(n->toString()); } + union COWref { + private: + COWNode_info *_info; + public: + COWfieldref srcInfo; + COWref(COWNode_info *i) { _info = i; } + COWref *operator->() { return this; } + void visit_children(Visitor &) {} + }; + DECLARE_TYPEINFO_WITH_TYPEID(Node, NodeKind::Node, INode); }; @@ -198,21 +237,29 @@ inline bool equiv(const INode *a, const INode *b) { IRNODE_COMMON_SUBCLASS(T) // NOLINTEND(bugprone-macro-parentheses) -#define IRNODE_COMMON_SUBCLASS(T) \ - public: \ - using Node::operator==; \ - bool apply_visitor_preorder(Modifier &v) override; \ - void apply_visitor_postorder(Modifier &v) override; \ - void apply_visitor_revisit(Modifier &v, const Node *n) const override; \ - void apply_visitor_loop_revisit(Modifier &v) const override; \ - bool apply_visitor_preorder(Inspector &v) const override; \ - void apply_visitor_postorder(Inspector &v) const override; \ - void apply_visitor_revisit(Inspector &v) const override; \ - void apply_visitor_loop_revisit(Inspector &v) const override; \ - const Node *apply_visitor_preorder(Transform &v) override; \ - const Node *apply_visitor_postorder(Transform &v) override; \ - void apply_visitor_revisit(Transform &v, const Node *n) const override; \ - void apply_visitor_loop_revisit(Transform &v) const override; +#define IRNODE_COMMON_SUBCLASS(T) \ + public: \ + using Node::operator==; \ + bool apply_visitor_preorder(Modifier &v) override; \ + void apply_visitor_postorder(Modifier &v) override; \ + void apply_visitor_revisit(Modifier &v, const Node *n) const override; \ + void apply_visitor_loop_revisit(Modifier &v) const override; \ + bool apply_visitor_preorder(COWNode_info *, COWModifier &v) const override; \ + void apply_visitor_postorder(COWNode_info *, COWModifier &v) const override; \ + void apply_visitor_revisit(COWModifier &v, const Node *n) const override; \ + void apply_visitor_loop_revisit(COWModifier &v) const override; \ + bool apply_visitor_preorder(Inspector &v) const override; \ + void apply_visitor_postorder(Inspector &v) const override; \ + void apply_visitor_revisit(Inspector &v) const override; \ + void apply_visitor_loop_revisit(Inspector &v) const override; \ + const Node *apply_visitor_preorder(Transform &v) override; \ + const Node *apply_visitor_postorder(Transform &v) override; \ + void apply_visitor_revisit(Transform &v, const Node *n) const override; \ + void apply_visitor_loop_revisit(Transform &v) const override; \ + const Node *apply_visitor_preorder(COWNode_info *, COWTransform &v) const override; \ + const Node *apply_visitor_postorder(COWNode_info *, COWTransform &v) const override; \ + void apply_visitor_revisit(COWTransform &v, const Node *n) const override; \ + void apply_visitor_loop_revisit(COWTransform &v) const override; /* only define 'apply' for a limited number of classes (those we want to call * visitors directly on), as defining it and making it virtual would mean that diff --git a/ir/vector.h b/ir/vector.h index 6e680413d8..cbaf230cd3 100644 --- a/ir/vector.h +++ b/ir/vector.h @@ -203,6 +203,7 @@ class Vector : public VectorBase { static cstring static_type_name() { return "Vector<" + T::static_type_name() + ">"; } void visit_children(Visitor &v) override; void visit_children(Visitor &v) const override; + void COW_visit_children(COWNode_info *, Visitor &v) const override; virtual void parallel_visit_children(Visitor &v); virtual void parallel_visit_children(Visitor &v) const; void toJSON(JSONGenerator &json) const override; diff --git a/ir/visitor.cpp b/ir/visitor.cpp index 4224c25c17..d007b1c4f0 100644 --- a/ir/visitor.cpp +++ b/ir/visitor.cpp @@ -365,6 +365,11 @@ Visitor::profile_t Modifier::init_apply(const IR::Node *root) { visited = std::make_shared(forceClone); return rv; } +Visitor::profile_t COWModifier::init_apply(const IR::Node *root) { + auto rv = Visitor::init_apply(root); + visited = std::make_shared(forceClone); + return rv; +} Visitor::profile_t Inspector::init_apply(const IR::Node *root) { auto rv = Visitor::init_apply(root); visited = std::make_shared(); @@ -375,6 +380,11 @@ Visitor::profile_t Transform::init_apply(const IR::Node *root) { visited = std::make_shared(forceClone); return rv; } +Visitor::profile_t COWTransform::init_apply(const IR::Node *root) { + auto rv = Visitor::init_apply(root); + visited = std::make_shared(forceClone); + return rv; +} void Visitor::end_apply() {} void Visitor::end_apply(const IR::Node *) {} @@ -402,11 +412,15 @@ Visitor::profile_t::~profile_t() { void Inspector::visitOnce() const { visited->visitOnce(getOriginal()); } void Modifier::visitOnce() const { visited->visitOnce(getOriginal()); } +void COWModifier::visitOnce() const { visited->visitOnce(getOriginal()); } void Transform::visitOnce() const { visited->visitOnce(getOriginal()); } +void COWTransform::visitOnce() const { visited->visitOnce(getOriginal()); } void Inspector::visitAgain() const { visited->visitAgain(getOriginal()); } void Modifier::visitAgain() const { visited->visitAgain(getOriginal()); } +void COWModifier::visitAgain() const { visited->visitAgain(getOriginal()); } void Transform::visitAgain() const { visited->visitAgain(getOriginal()); } +void COWTransform::visitAgain() const { visited->visitAgain(getOriginal()); } void Visitor::print_context() const { std::ostream &out = std::cout; @@ -428,10 +442,18 @@ void Modifier::visitor_const_error() { BUG("Modifier called const visit function -- missing template " "instantiation in gen-tree-macro.h?"); } +void COWModifier::visitor_const_error() { + BUG("Modifier called const visit function -- missing template " + "instantiation in gen-tree-macro.h?"); +} void Transform::visitor_const_error() { BUG("Transform called const visit function -- missing template " "instantiation in gen-tree-macro.h?"); } +void COWTransform::visitor_const_error() { + BUG("Transform called const visit function -- missing template " + "instantiation in gen-tree-macro.h?"); +} struct PushContext { Visitor::Context current; @@ -507,6 +529,43 @@ const IR::Node *Modifier::apply_visitor(const IR::Node *n, const char *name) { return n; } +const IR::Node *COWModifier::apply_visitor(const IR::Node *n, const char *name) { + if (ctxt && name) ctxt->child_name = name; + if (n) { + PushContext local(ctxt, n); + switch (visited->try_start(n, visitDagOnce)) { + case VisitStatus::Busy: + n->apply_visitor_loop_revisit(*this); + // FIXME -- should have a way of updating the node? Needs to be decided + // by the visitor somehow, but it is tough + break; + case VisitStatus::Done: + n->apply_visitor_revisit(*this, visited->result(n)); + n = visited->result(n); + break; + default: { // New or Revisit + IR::COWinfo clone_info(n, &local.current); + IR::COWptr cloner(&clone_info); + if (!dontForwardChildrenBeforePreorder) { + ForwardChildren forward_children(*visited); + cloner.visit_children(forward_children); + } + if (cloner.apply_visitor_preorder(*this)) { + cloner.visit_children(*this); + cloner.apply_visitor_postorder(*this); + } + if (visited->finish(n, cloner)) (n = cloner)->validate(); + break; + } + } + } + if (ctxt) + ctxt->child_index++; + else + visited.reset(); + return n; +} + const IR::Node *Inspector::apply_visitor(const IR::Node *n, const char *name) { if (ctxt && name) ctxt->child_name = name; if (n && !join_flows(n)) { @@ -606,8 +665,12 @@ void Inspector::revisit_visited() { visited->revisit_visited(); } bool Inspector::visit_in_progress(const IR::Node *n) const { return visited->busy(n); } void Modifier::revisit_visited() { visited->revisit_visited(); } bool Modifier::visit_in_progress(const IR::Node *n) const { return visited->busy(n); } +void COWModifier::revisit_visited() { visited->revisit_visited(); } +bool COWModifier::visit_in_progress(const IR::Node *n) const { return visited->busy(n); } void Transform::revisit_visited() { visited->revisit_visited(); } bool Transform::visit_in_progress(const IR::Node *n) const { return visited->busy(n); } +void COWTransform::revisit_visited() { visited->revisit_visited(); } +bool COWTransform::visit_in_progress(const IR::Node *n) const { return visited->busy(n); } #define DEFINE_VISIT_FUNCTIONS(CLASS, BASE) \ bool Modifier::preorder(IR::CLASS *n) { return preorder(static_cast(n)); } \ @@ -618,6 +681,16 @@ bool Transform::visit_in_progress(const IR::Node *n) const { return visited->bus void Modifier::loop_revisit(const IR::CLASS *o) { \ loop_revisit(static_cast(o)); \ } \ + bool COWModifier::preorder(IR::COWptr n) { \ + return preorder(IR::COWptr(n)); \ + } \ + void COWModifier::postorder(IR::COWptr n) { postorder(IR::COWptr(n)); } \ + void COWModifier::revisit(const IR::CLASS *o, const IR::CLASS *n) { \ + revisit(static_cast(o), static_cast(n)); \ + } \ + void COWModifier::loop_revisit(const IR::CLASS *o) { \ + loop_revisit(static_cast(o)); \ + } \ bool Inspector::preorder(const IR::CLASS *n) { \ return preorder(static_cast(n)); \ } \ @@ -637,6 +710,18 @@ bool Transform::visit_in_progress(const IR::Node *n) const { return visited->bus } \ void Transform::loop_revisit(const IR::CLASS *o) { \ return loop_revisit(static_cast(o)); \ + } \ + const IR::Node *COWTransform::preorder(IR::COWptr n) { \ + return preorder(IR::COWptr(n)); \ + } \ + const IR::Node *COWTransform::postorder(IR::COWptr n) { \ + return postorder(IR::COWptr(n)); \ + } \ + void COWTransform::revisit(const IR::CLASS *o, const IR::Node *n) { \ + return revisit(static_cast(o), n); \ + } \ + void COWTransform::loop_revisit(const IR::CLASS *o) { \ + return loop_revisit(static_cast(o)); \ } IRNODE_ALL_SUBCLASSES(DEFINE_VISIT_FUNCTIONS) @@ -736,11 +821,21 @@ bool Modifier::check_clone(const Visitor *v) { BUG_CHECK(t && t->visited == visited, "Clone failed to copy base object"); return Visitor::check_clone(v); } +bool COWModifier::check_clone(const Visitor *v) { + auto *t = dynamic_cast(v); + BUG_CHECK(t && t->visited == visited, "Clone failed to copy base object"); + return Visitor::check_clone(v); +} bool Transform::check_clone(const Visitor *v) { auto *t = dynamic_cast(v); BUG_CHECK(t && t->visited == visited, "Clone failed to copy base object"); return Visitor::check_clone(v); } +bool COWTransform::check_clone(const Visitor *v) { + auto *t = dynamic_cast(v); + BUG_CHECK(t && t->visited == visited, "Clone failed to copy base object"); + return Visitor::check_clone(v); +} ControlFlowVisitor &ControlFlowVisitor::flow_clone() { auto *rv = clone(); diff --git a/ir/visitor.h b/ir/visitor.h index 30d39addee..1b58956a42 100644 --- a/ir/visitor.h +++ b/ir/visitor.h @@ -33,6 +33,7 @@ limitations under the License. #include "ir/gen-tree-macro.h" #include "ir/ir-tree-macros.h" #include "ir/node.h" +#include "ir/copy_on_write_ptr.h" #include "ir/vector.h" #include "lib/castable.h" #include "lib/cstring.h" @@ -166,6 +167,13 @@ class Visitor { } n.visit_children(*this); } + template requires IR::COWref + void visit(COW ref, const char *name = 0) { + auto o = ref.get(); + auto n = apply_visitor(o, name); + if (n != o) ref.set(n); + } + template void parallel_visit(IR::Vector &v, const char *name = 0) { if (name && ctxt) ctxt->child_name = name; @@ -378,7 +386,9 @@ class Visitor { const Context *ctxt = nullptr; // should be readonly to subclasses friend class Inspector; friend class Modifier; + friend class COWModifier; friend class Transform; + friend class COWTransform; friend class ControlFlowVisitor; }; @@ -410,6 +420,34 @@ class Modifier : public virtual Visitor { bool forceClone = false; // force clone whole tree even if unchanged }; +class COWModifier : public virtual Visitor { + std::shared_ptr visited; + void visitor_const_error() override; + bool check_clone(const Visitor *) override; + + public: + profile_t init_apply(const IR::Node *root) override; + const IR::Node *apply_visitor(const IR::Node *n, const char *name = 0) override; + virtual bool preorder(IR::COWptr) { return true; } + virtual void postorder(IR::COWptr) {} + virtual void revisit(const IR::Node *, const IR::Node *) {} + virtual void loop_revisit(const IR::Node *) { BUG("IR loop detected"); } +#define DECLARE_VISIT_FUNCTIONS(CLASS, BASE) \ + virtual bool preorder(IR::COWptr); \ + virtual void postorder(IR::COWptr); \ + virtual void revisit(const IR::CLASS *, const IR::CLASS *); \ + virtual void loop_revisit(const IR::CLASS *); + IRNODE_ALL_SUBCLASSES(DECLARE_VISIT_FUNCTIONS) +#undef DECLARE_VISIT_FUNCTIONS + void revisit_visited(); + bool visit_in_progress(const IR::Node *) const; + void visitOnce() const override; + void visitAgain() const override; + + protected: + bool forceClone = false; // force clone whole tree even if unchanged +}; + class Inspector : public virtual Visitor { std::shared_ptr visited; bool check_clone(const Visitor *) override; @@ -470,6 +508,42 @@ class Transform : public virtual Visitor { bool forceClone = false; // force clone whole tree even if unchanged }; +class COWTransform : public virtual Visitor { + std::shared_ptr visited; + bool prune_flag = false; + void visitor_const_error() override; + bool check_clone(const Visitor *) override; + + public: + profile_t init_apply(const IR::Node *root) override; + const IR::Node *apply_visitor(const IR::Node *, const char *name = 0) override; + virtual const IR::Node *preorder(IR::COWptr n) { return n; } + virtual const IR::Node *postorder(IR::COWptr n) { return n; } + virtual void revisit(const IR::Node *, const IR::Node *) {} + virtual void loop_revisit(const IR::Node *) { BUG("IR loop detected"); } +#define DECLARE_VISIT_FUNCTIONS(CLASS, BASE) \ + virtual const IR::Node *preorder(IR::COWptr); \ + virtual const IR::Node *postorder(IR::COWptr); \ + virtual void revisit(const IR::CLASS *, const IR::Node *); \ + virtual void loop_revisit(const IR::CLASS *); + IRNODE_ALL_SUBCLASSES(DECLARE_VISIT_FUNCTIONS) +#undef DECLARE_VISIT_FUNCTIONS + void revisit_visited(); + bool visit_in_progress(const IR::Node *) const; + void visitOnce() const override; + void visitAgain() const override; + // can only be called usefully from a 'preorder' function (directly or indirectly) + void prune() { prune_flag = true; } + + protected: + const IR::Node *transform_child(const IR::Node *child) { + auto *rv = apply_visitor(child); + prune_flag = true; + return rv; + } + bool forceClone = false; // force clone whole tree even if unchanged +}; + // turn this on for extra info tracking control joinFlows for debugging #define DEBUG_FLOW_JOIN 0 diff --git a/tools/ir-generator/irclass.cpp b/tools/ir-generator/irclass.cpp index 013ca5d48d..0b09d0f855 100644 --- a/tools/ir-generator/irclass.cpp +++ b/tools/ir-generator/irclass.cpp @@ -125,6 +125,7 @@ void IrDefinitions::generate(std::ostream &t, std::ostream &out, std::ostream &i << "#include \n\n" << "#include \"lib/big_int.h\" // IWYU pragma: keep\n" << "// Special IR classes and types\n" + << "#include \"ir/copy_on_write_ptr.h\" // IWYU pragma: keep\n" << "#include \"ir/dbprint.h\" // IWYU pragma: keep\n" << "#include \"ir/id.h\" // IWYU pragma: keep\n" << "#include \"ir/indexed_vector.h\" // IWYU pragma: keep\n" @@ -132,6 +133,8 @@ void IrDefinitions::generate(std::ostream &t, std::ostream &out, std::ostream &i << "#include \"ir/node.h\" // IWYU pragma: keep\n" << "#include \"ir/nodemap.h\" // IWYU pragma: keep\n" << "#include \"ir/vector.h\" // IWYU pragma: keep\n" + << "// copy_on_write_inl.h must be after vector.h and indexed_vector.h\n" + << "#include \"ir/copy_on_write_inl.h\" // IWYU pragma: keep\n" << "#include \"lib/ordered_map.h\" // IWYU pragma: keep\n" << std::endl << "namespace P4 {\n" @@ -185,6 +188,9 @@ void IrDefinitions::generate(std::ostream &t, std::ostream &out, std::ostream &i e->generate_impl(impl); } + for (auto cls : *getClasses()) + cls->outputCOWref(out); + out << "#endif /* " << macroname << " */" << std::endl; ///////////////////////////////// tree @@ -357,10 +363,31 @@ void IrMethod::generate_impl(std::ostream &out) const { out << LineDirective(srcInfo); generate_proto(out, true, false); out << " const " << body << std::endl; + if (clss->kind == NodeKind::Concrete || clss->kind == NodeKind::Template) + outputCOWref_visit_children(out); } if (srcInfo.isValid()) out << LineDirective(); } +void IrMethod::outputCOWref_visit_children(std::ostream &out) const { + out << LineDirective(srcInfo); + out << "void IR::" << clss->containedIn << clss->name + << "::COWref::visit_children(Visitor &v) "; + // Since we can't call the base class visit_children directly (as COWrefs are unions + // and C++ does not support inheritance for unions), we need to insert a `reinterpret_cast` + // into the body code where that happens. This is safe as the base class COWref has + // a subset of the fields of the derived class COWref and is compatible + const char *p = body.c_str(); + while (const char *vccall = std::strstr(p, "::visit_children(")) { + const char *t = vccall; + while (t > p && (std::isalnum(t[-1]) || t[-1] == '_' || t[-1] == ':')) t--; + out.write(p, t - p) << "reinterpret_cast<"; + out.write(t, vccall - t) << "::COWref *>(this)->"; + p = vccall + 2; + } + out << p << std::endl; +} + //////////////////////////////////////////////////////////////////////////////////// void IrApply::generate_hdr(std::ostream &out) const { @@ -408,6 +435,43 @@ cstring IrClass::qualified_name(const IrNamespace *in) const { return rv; } +IrElement::access_t IrClass::outputCOWfieldrefs(std::ostream &out) const { + auto access = IrElement::Private; + if (concreteParent) { + access = concreteParent->outputCOWfieldrefs(out); + } else { + out << (access = IrElement::Public); + out << indent << "COWfieldref srcInfo;\n"; + } + for (auto e : elements) { + if (auto *fld = e->to()) { + if (fld->isStatic) continue; + if (e->access != access) out << indent << (access = e->access); + out << indent << "COWfieldref<" << name << ", "; + const IrClass *cls = fld->type->resolve(fld->clss ? fld->clss->containedIn : nullptr); + if (cls != nullptr && !fld->isInline) out << "const "; + out << fld->type->toString(); + if (cls != nullptr && !fld->isInline) out << "*"; + out << fld->type->declSuffix() << ", &" << name << "::" << fld->name << "> " + << fld->name << ";\n"; + } + } + return access; +} + +void IrClass::outputCOWref(std::ostream &out) const { + if (kind != NodeKind::Concrete && kind != NodeKind::Template) return; + out << "union P4::IR::" << name << "::COWref {\n"; + out << IrElement::Private; + out << indent << "COWNode_info *_info;\n"; + if (outputCOWfieldrefs(out) != IrElement::Public) + out << IrElement::Public; + out << indent << "COWref(COWNode_info *i) { _info = i; }\n"; + out << indent << "COWref *operator->() { return this; }\n"; + out << indent << "void visit_children(Visitor &);\n"; + out << "};\n"; +} + void IrClass::generate_hdr(std::ostream &out) const { if (kind != NodeKind::Nested) { out << "namespace P4::IR {" << std::endl; @@ -418,7 +482,10 @@ void IrClass::generate_hdr(std::ostream &out) const { bool concreteParent = false; for (auto p : parentClasses) { - if (p->kind != NodeKind::Interface) concreteParent = true; + if (p->kind != NodeKind::Interface) { + BUG_CHECK(!concreteParent && p == this->concreteParent, "inconsisten concreteParent"); + concreteParent = true; + } } const char *sep = " : "; @@ -448,6 +515,10 @@ void IrClass::generate_hdr(std::ostream &out) const { out << indent << "IRNODE" << (kind == NodeKind::Abstract ? "_ABSTRACT" : "") << "_SUBCLASS(" << name << ")" << std::endl; + if (kind == NodeKind::Concrete || kind == NodeKind::Template) { + out << indent << "union COWref;\n"; + } + auto *irNamespace = IrNamespace::get(nullptr, "IR"_cs); if (kind != NodeKind::Nested) { out << indent << "DECLARE_TYPEINFO_WITH_TYPEID(" << name diff --git a/tools/ir-generator/irclass.h b/tools/ir-generator/irclass.h index 0588b07a4c..9fc9e030fd 100644 --- a/tools/ir-generator/irclass.h +++ b/tools/ir-generator/irclass.h @@ -139,6 +139,7 @@ class IrMethod : public IrElement { void generate_proto(std::ostream &, bool, bool) const; void generate_hdr(std::ostream &) const override; void generate_impl(std::ostream &) const override; + void outputCOWref_visit_children(std::ostream &out) const; struct info_t { const Type *rtype; std::vector args; @@ -343,6 +344,8 @@ class IrClass : public IrElement { void generate_hdr(std::ostream &out) const override; void generate_impl(std::ostream &out) const override; void generateTreeMacro(std::ostream &out) const; + access_t outputCOWfieldrefs(std::ostream &out) const; + void outputCOWref(std::ostream &out) const; void resolve() override; cstring toString() const override { return name; } std::string fullName() const; diff --git a/tools/ir-generator/methods.cpp b/tools/ir-generator/methods.cpp index 5ba1bf436b..51c1720da1 100644 --- a/tools/ir-generator/methods.cpp +++ b/tools/ir-generator/methods.cpp @@ -220,6 +220,14 @@ const ordered_map IrMethod::Generate = { buf << "}"; return needed ? buf : cstring(); }}}, + {"COW_visit_children"_cs, + {&NamedType::Void(), + {new IrField(new PointerType(&NamedType::COWNode_info()), "info"_cs), + new IrField(&ReferenceType::VisitorRef, "v"_cs)}, + CONST + IN_IMPL + OVERRIDE + CONCRETE_ONLY, + [](IrClass *, Util::SourceInfo, cstring) -> cstring { + return ""_cs; + }}}, {"validate"_cs, {&NamedType::Void(), {}, diff --git a/tools/ir-generator/type.cpp b/tools/ir-generator/type.cpp index ae04c42a39..0ad1c453c0 100644 --- a/tools/ir-generator/type.cpp +++ b/tools/ir-generator/type.cpp @@ -127,6 +127,11 @@ NamedType &NamedType::SourceInfo() { return nt; } +NamedType &NamedType::COWNode_info() { + static NamedType nt("COWNode_info"_cs); + return nt; +} + cstring NamedType::toString() const { if (resolved) return resolved->fullName(); if (!lookup && name == "ID") return "IR::ID"_cs; // hack -- ID is in namespace P4::IR diff --git a/tools/ir-generator/type.h b/tools/ir-generator/type.h index f8b5df094e..9356ecd1bb 100644 --- a/tools/ir-generator/type.h +++ b/tools/ir-generator/type.h @@ -111,6 +111,7 @@ class NamedType : public Type { static NamedType &JSONLoader(); static NamedType &JSONObject(); static NamedType &SourceInfo(); + static NamedType &COWNode_info(); }; class TemplateInstantiation : public Type {