diff --git a/examples/lockless_example/main.cc b/examples/lockless_example/main.cc index 229e121..e22e8b7 100644 --- a/examples/lockless_example/main.cc +++ b/examples/lockless_example/main.cc @@ -128,12 +128,11 @@ class NonRTThread : public Thread { int main() { Context ctx; - auto rt_thread = std::make_shared(ctx); - auto non_rt_thread = std::make_shared(ctx); App app; - app.RegisterThread(rt_thread); - app.RegisterThread(non_rt_thread); + + auto rt_thread = app.CreateThread(ctx); + auto non_rt_thread = app.CreateThread(ctx); app.Start(); app.Join(); diff --git a/examples/logging_example/main.cc b/examples/logging_example/main.cc index 7b6cef1..89c8b6d 100644 --- a/examples/logging_example/main.cc +++ b/examples/logging_example/main.cc @@ -37,8 +37,6 @@ int main() { thread_config.cpu_affinity = std::vector{2}; thread_config.SetFifoScheduler(80); - auto thread = std::make_shared("ExampleRTThread", thread_config); - // Create a cactus_rt app configuration cactus_rt::AppConfig app_config; @@ -54,7 +52,7 @@ int main() { app_config.logger_config = logging_config; App app("LoggingExampleApp", app_config); - app.RegisterThread(thread); + auto thread = app.CreateThread("ExampleRTThread", thread_config); constexpr unsigned int time = 5; std::cout << "Testing RT loop for " << time << " seconds.\n"; diff --git a/examples/message_passing_example/main.cc b/examples/message_passing_example/main.cc index 3db98e5..48a6c84 100644 --- a/examples/message_passing_example/main.cc +++ b/examples/message_passing_example/main.cc @@ -6,12 +6,10 @@ using cactus_rt::App; int main() { - auto data_logger = std::make_shared("build/data.csv"); - auto rt_thread = std::make_shared(data_logger); - App app; - app.RegisterThread(data_logger); - app.RegisterThread(rt_thread); + + auto data_logger = app.CreateThread("build/data.csv"); + auto rt_thread = app.CreateThread(data_logger); app.Start(); rt_thread->Join(); // This thread will terminate on its own. diff --git a/examples/mutex_example/main.cc b/examples/mutex_example/main.cc index cfcc43c..c2c5e96 100644 --- a/examples/mutex_example/main.cc +++ b/examples/mutex_example/main.cc @@ -80,12 +80,10 @@ void ThreadedDemo() { // into the thread and maintain the object lifetime to this function. NaiveDoubleBuffer buf; - auto rt_thread = std::make_shared("RTThread", rt_thread_config, buf); - auto non_rt_thread = std::make_shared("NonRTThread", non_rt_thread_config, buf); - App app; + App app; - app.RegisterThread(non_rt_thread); - app.RegisterThread(rt_thread); + auto rt_thread = app.CreateThread("RTThread", rt_thread_config, buf); + auto non_rt_thread = app.CreateThread("NonRTThread", non_rt_thread_config, buf); constexpr unsigned int time = 10; app.Start(); diff --git a/examples/ros2/publisher/complex_data.cc b/examples/ros2/publisher/complex_data.cc index 0a14f9b..85bbd4a 100644 --- a/examples/ros2/publisher/complex_data.cc +++ b/examples/ros2/publisher/complex_data.cc @@ -106,7 +106,6 @@ int main(int argc, const char* argv[]) { std::cout << "Testing RT loop for " << time.count() << " seconds.\n"; auto thread = app.CreateROS2EnabledThread(time); - app.RegisterThread(thread); app.Start(); diff --git a/examples/ros2/publisher/simple_data.cc b/examples/ros2/publisher/simple_data.cc index 8f896b0..4d02267 100644 --- a/examples/ros2/publisher/simple_data.cc +++ b/examples/ros2/publisher/simple_data.cc @@ -66,7 +66,6 @@ int main(int argc, const char* argv[]) { std::cout << "Testing RT loop for " << time.count() << " seconds.\n"; auto thread = app.CreateROS2EnabledThread(time); - app.RegisterThread(thread); app.Start(); diff --git a/examples/ros2/subscriber/complex_data.cc b/examples/ros2/subscriber/complex_data.cc index c6f51f4..9f31ac9 100644 --- a/examples/ros2/subscriber/complex_data.cc +++ b/examples/ros2/subscriber/complex_data.cc @@ -86,7 +86,6 @@ int main(int argc, const char* argv[]) { std::cout << "Testing RT loop for " << time.count() << " seconds.\n"; auto thread = app.CreateROS2EnabledThread(time); - app.RegisterThread(thread); app.Start(); diff --git a/examples/ros2/subscriber/simple_data.cc b/examples/ros2/subscriber/simple_data.cc index 9fc4ea9..3363f47 100644 --- a/examples/ros2/subscriber/simple_data.cc +++ b/examples/ros2/subscriber/simple_data.cc @@ -61,7 +61,6 @@ int main(int argc, const char* argv[]) { std::cout << "Testing RT loop for " << time.count() << " seconds.\n"; auto thread = app.CreateROS2EnabledThread(time); - app.RegisterThread(thread); app.Start(); diff --git a/examples/signal_handling_example/main.cc b/examples/signal_handling_example/main.cc index 5faae95..ed41d6c 100644 --- a/examples/signal_handling_example/main.cc +++ b/examples/signal_handling_example/main.cc @@ -31,10 +31,8 @@ int main() { config.cpu_affinity = std::vector{2}; config.SetFifoScheduler(80); - auto thread = std::make_shared("ExampleRTThread", config); App app; - - app.RegisterThread(thread); + auto thread = app.CreateThread("ExampleRTThread", config); // Sets up the signal handlers for SIGINT and SIGTERM (by default). cactus_rt::SetUpTerminationSignalHandler(); diff --git a/examples/simple_deadline_example/main.cc b/examples/simple_deadline_example/main.cc index 56eb4b5..c60bbc8 100644 --- a/examples/simple_deadline_example/main.cc +++ b/examples/simple_deadline_example/main.cc @@ -35,10 +35,9 @@ int main() { config.period_ns = 1'000'000; config.SetDeadlineScheduler(500'000 /* runtime */, 1'000'000 /* deadline*/); - auto thread = std::make_shared("ExampleRTThread", config); App app; + auto thread = app.CreateThread("ExampleRTThread", config); - app.RegisterThread(thread); constexpr unsigned int time = 5; std::cout << "Testing RT loop for " << time << " seconds.\n"; diff --git a/examples/simple_example/main.cc b/examples/simple_example/main.cc index 4775ccd..70f3f73 100644 --- a/examples/simple_example/main.cc +++ b/examples/simple_example/main.cc @@ -65,12 +65,8 @@ int main() { // We first create cactus_rt App object. App app; - // We then create a thread object - auto thread = std::make_shared(); - - // We then register the thread with the app, which allows the app to start, - // stop, and join the thread via App::Start, App::RequestStop, and App::Join. - app.RegisterThread(thread); + // We then create a thread object. + auto thread = app.CreateThread(); constexpr unsigned int time = 5; std::cout << "Testing RT loop for " << time << " seconds.\n"; diff --git a/examples/tracing_example/main.cc b/examples/tracing_example/main.cc index f588c6a..8e3701e 100644 --- a/examples/tracing_example/main.cc +++ b/examples/tracing_example/main.cc @@ -88,7 +88,6 @@ class SecondRTThread : public CyclicThread { protected: LoopControl Loop(int64_t /*now*/) noexcept final { - const auto span = Tracer().WithSpan("Sense"); WasteTime(std::chrono::microseconds(2000)); return LoopControl::Continue; } @@ -98,11 +97,10 @@ int main() { cactus_rt::AppConfig app_config; app_config.tracer_config.trace_aggregator_cpu_affinity = {0}; // doesn't work yet - auto thread1 = std::make_shared(); - auto thread2 = std::make_shared(); - App app("TracingExampleApp", app_config); - app.RegisterThread(thread1); - app.RegisterThread(thread2); + App app("TracingExampleApp", app_config); + + auto thread1 = app.CreateThread(); + auto thread2 = app.CreateThread(); std::cout << "Testing RT loop for 15 seconds with two trace sessions.\n"; diff --git a/include/cactus_rt/app.h b/include/cactus_rt/app.h index c1e66c6..c2907ad 100644 --- a/include/cactus_rt/app.h +++ b/include/cactus_rt/app.h @@ -33,9 +33,9 @@ class App { TracerConfig tracer_config_; - std::vector> threads_; + std::shared_ptr trace_aggregator_; // Must be above threads_ to guarantee destructor order. - std::shared_ptr trace_aggregator_; + std::vector> threads_; void SetDefaultLogFormat(quill::Config& cfg) { // Create a handler of stdout @@ -62,19 +62,19 @@ class App { App(App&&) noexcept = delete; App& operator=(App&&) noexcept = delete; - /** - * @brief Registers a thread to be automatically started by the app. The start - * order of the threads are in the order of registration. - * - * @param thread A shared ptr to a thread. - */ - void RegisterThread(std::shared_ptr thread); + template + std::shared_ptr CreateThread(Args&&... args) { + static_assert(std::is_base_of_v, "Must derive from cactus_rt::Thread"); + std::shared_ptr thread = std::make_shared(std::forward(args)...); - /** - * @brief Sets up the trace aggregator. Call this before starting the thread - * if you don't want to call RegisterThread and maintain tracing capabilities. - */ - void SetupTraceAggregator(Thread& thread); + Thread* base_thread = thread.get(); + base_thread->trace_aggregator_ = trace_aggregator_; + base_thread->created_by_app_ = true; + + threads_.push_back(thread); + + return thread; + } /** * @brief Starts the app by locking the memory and reserving the memory. Also diff --git a/include/cactus_rt/cyclic_thread.h b/include/cactus_rt/cyclic_thread.h index 9dca435..97d1edf 100644 --- a/include/cactus_rt/cyclic_thread.h +++ b/include/cactus_rt/cyclic_thread.h @@ -14,8 +14,10 @@ class CyclicThread : public Thread { Stop, }; + protected: /** - * @brief Create a cyclic thread + * @brief Create a cyclic thread. + * * @param name The thread name * @param config A cactus_rt::CyclicThreadConfig that specifies configuration parameters for this thread */ @@ -23,7 +25,6 @@ class CyclicThread : public Thread { period_ns_(config.period_ns) { } - protected: void Run() noexcept final; /** diff --git a/include/cactus_rt/ros2/app.h b/include/cactus_rt/ros2/app.h index 848b9fc..acb76f7 100644 --- a/include/cactus_rt/ros2/app.h +++ b/include/cactus_rt/ros2/app.h @@ -61,7 +61,7 @@ class App : public cactus_rt::App { template std::shared_ptr CreateROS2EnabledThread(Args&&... args) { static_assert(std::is_base_of_v, "Must derive ROS2 thread from Ros2ThreadMixin"); - std::shared_ptr thread = std::make_shared(std::forward(args)...); + std::shared_ptr thread = CreateThread(std::forward(args)...); thread->SetRos2Adapter(ros2_adapter_); thread->InitializeForRos2(); diff --git a/include/cactus_rt/thread.h b/include/cactus_rt/thread.h index ef25e62..336a3f3 100644 --- a/include/cactus_rt/thread.h +++ b/include/cactus_rt/thread.h @@ -18,7 +18,13 @@ namespace cactus_rt { /// @private constexpr size_t kDefaultStackSize = 8 * 1024 * 1024; // 8MB default stack space should be plenty +class App; + class Thread { + friend class App; + + bool created_by_app_ = false; // A guard to prevent users to create the thread without using App::CreateThread. + ThreadConfig config_; std::string name_; std::vector cpu_affinity_; @@ -39,11 +45,10 @@ class Thread { static void* RunThread(void* data); // Non-owning TraceAggregator pointer. Used only for notifying that the thread - // has started/stopped for tracing purposes. Set by Thread::Start and read at - // the beginning of Thread::RunThread. + // has started/stopped for tracing purposes. Set by App::CreateThread. std::weak_ptr trace_aggregator_; - public: + protected: /** * Creates a new thread. * @@ -61,6 +66,7 @@ class Thread { } } + public: /** * Returns the name of the thread * @@ -111,16 +117,6 @@ class Thread { */ void Start(int64_t start_monotonic_time_ns); - /** - * @brief Sets the trace_aggregator_ pointer so the thread can notify the - * trace_aggregator_ when it starts. This should only be called by App. - * - * @private - */ - inline void SetTraceAggregator(std::weak_ptr trace_aggregator) { - trace_aggregator_ = trace_aggregator; - } - protected: inline quill::Logger* Logger() const { return logger_; } diff --git a/src/cactus_rt/app.cc b/src/cactus_rt/app.cc index fb8b2ac..ede9ee7 100644 --- a/src/cactus_rt/app.cc +++ b/src/cactus_rt/app.cc @@ -16,15 +16,6 @@ using FileSink = cactus_rt::tracing::FileSink; namespace cactus_rt { -void App::SetupTraceAggregator(Thread& thread) { - thread.SetTraceAggregator(trace_aggregator_); -} - -void App::RegisterThread(std::shared_ptr thread) { - SetupTraceAggregator(*thread); - threads_.push_back(thread); -} - App::App(std::string name, AppConfig config) : name_(name), heap_size_(config.heap_size), diff --git a/src/cactus_rt/ros2/app.cc b/src/cactus_rt/ros2/app.cc index a824b51..5cddc98 100644 --- a/src/cactus_rt/ros2/app.cc +++ b/src/cactus_rt/ros2/app.cc @@ -54,7 +54,6 @@ App::App( // Must initialize rclcpp before making the Ros2Adapter; ros2_adapter_ = std::make_shared(name, ros2_adapter_config); ros2_executor_thread_ = CreateROS2EnabledThread(); - SetupTraceAggregator(*ros2_executor_thread_); } App::~App() { diff --git a/src/cactus_rt/thread.cc b/src/cactus_rt/thread.cc index 8b25120..89f21c7 100644 --- a/src/cactus_rt/thread.cc +++ b/src/cactus_rt/thread.cc @@ -15,6 +15,11 @@ namespace cactus_rt { void* Thread::RunThread(void* data) { auto* thread = static_cast(data); + + if (!thread->created_by_app_) { + throw std::runtime_error(std::string("do not create Thread manually, use App::CreateThread to create thread") + thread->name_); + } + thread->config_.scheduler->SetSchedAttr(); pthread_setname_np(pthread_self(), thread->name_.c_str()); @@ -25,7 +30,11 @@ void* Thread::RunThread(void* data) { if (auto trace_aggregator = thread->trace_aggregator_.lock()) { trace_aggregator->RegisterThreadTracer(thread->tracer_); } else { - LOG_WARNING(thread->Logger(), "thread {} does not have app_ and tracing is disabled for this thread. Did you call App::RegisterThread?", thread->name_); + LOG_WARNING( + thread->Logger(), + "thread {} does not have app_ and tracing is disabled for this thread. Did the App/Thread go out of scope before the thread is launched?", + thread->name_ + ); } quill::preallocate(); // Pre-allocates thread-local data to avoid the need to allocate on the first log message diff --git a/tests/tracing/multi_threaded_test.cc b/tests/tracing/multi_threaded_test.cc index 9fcf468..8e39374 100644 --- a/tests/tracing/multi_threaded_test.cc +++ b/tests/tracing/multi_threaded_test.cc @@ -24,16 +24,12 @@ class MultiThreadTracingTest : public ::testing::Test { } protected: - cactus_rt::App app_; - std::shared_ptr regular_thread_; - std::shared_ptr cyclic_thread_; - std::shared_ptr sink_; + cactus_rt::App app_; + std::shared_ptr sink_; public: MultiThreadTracingTest() : app_(kAppName, CreateAppConfig()), - regular_thread_(std::make_shared()), - cyclic_thread_(std::make_shared()), sink_(std::make_shared()) {} protected: @@ -51,19 +47,19 @@ class MultiThreadTracingTest : public ::testing::Test { }; TEST_F(MultiThreadTracingTest, TraceFromMultipleThreads) { - app_.RegisterThread(regular_thread_); - app_.RegisterThread(cyclic_thread_); + auto regular_thread = app_.CreateThread(); + auto cyclic_thread = app_.CreateThread(); app_.Start(); - regular_thread_->RunOneIteration([](MockRegularThread* self) { + regular_thread->RunOneIteration([](MockRegularThread* self) { self->TracerForTest().InstantEvent("Event1"); WasteTime(std::chrono::microseconds(1000)); }); - cyclic_thread_->Join(); - regular_thread_->RequestStop(); - regular_thread_->Join(); + cyclic_thread->Join(); + regular_thread->RequestStop(); + regular_thread->Join(); app_.StopTraceSession(); @@ -138,7 +134,7 @@ TEST_F(MultiThreadTracingTest, TraceFromMultipleThreads) { TEST_F(MultiThreadTracingTest, CyclicThreadTracesLoop) { // TODO: move the configuration for the number of loops and time per loop here // so it's easier to check the assertions are working. - app_.RegisterThread(cyclic_thread_); + auto cyclic_thread = app_.CreateThread(); app_.Start(); // The cyclic thread should shutdown on its own. @@ -196,12 +192,11 @@ TEST_F(MultiThreadTracingTest, CyclicThreadTracesSleepAndDoesNotTraceLoopIfConfi const char* thread_name = "CustomCyclicThread"; - auto cyclic_thread = std::make_shared( + auto cyclic_thread = app_.CreateThread( thread_name, tracer_config ); - app_.RegisterThread(cyclic_thread); app_.Start(); // The cyclic thread should shutdown on its own. @@ -266,7 +261,7 @@ TEST_F(MultiThreadTracingTest, CyclicThreadTracesLoopOverrun) { const char* thread_name = "CustomCyclicThread"; - auto cyclic_thread = std::make_shared( + auto cyclic_thread = app_.CreateThread( thread_name, tracer_config, [](int64_t num_iterations) { @@ -275,8 +270,6 @@ TEST_F(MultiThreadTracingTest, CyclicThreadTracesLoopOverrun) { } } ); - - app_.RegisterThread(cyclic_thread); app_.Start(); // The cyclic thread should shutdown on its own. diff --git a/tests/tracing/single_threaded_test.cc b/tests/tracing/single_threaded_test.cc index baf1f4a..5f15ade 100644 --- a/tests/tracing/single_threaded_test.cc +++ b/tests/tracing/single_threaded_test.cc @@ -31,12 +31,11 @@ class SingleThreadTracingTest : public ::testing::Test { public: SingleThreadTracingTest() : app_(kAppName, CreateAppConfig()), - regular_thread_(std::make_shared()), + regular_thread_(app_.CreateThread()), sink_(std::make_shared()) {} protected: void SetUp() override { - app_.RegisterThread(regular_thread_); app_.StartTraceSession(sink_); // TODO: make each test manually start the trace session! app_.Start(); while (!regular_thread_->Started()) {