From 99f84a12e22250810f7910368464060e400d1648 Mon Sep 17 00:00:00 2001 From: niksirbi Date: Wed, 23 Oct 2024 13:44:02 +0100 Subject: [PATCH] fixed deprecation warning about tostring_rgb --- neuralplayground/arenas/discritized_objects.py | 14 ++++++++++---- neuralplayground/arenas/simple2d.py | 14 ++++++++++---- 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/neuralplayground/arenas/discritized_objects.py b/neuralplayground/arenas/discritized_objects.py index 0720778..a33f068 100644 --- a/neuralplayground/arenas/discritized_objects.py +++ b/neuralplayground/arenas/discritized_objects.py @@ -413,9 +413,15 @@ def render(self, history_length=30, display=True): history = self.history[-history_length:] ax = self.plot_trajectory(history_data=history, ax=ax) canvas.draw() - image = np.frombuffer(canvas.tostring_rgb(), dtype="uint8") - image = image.reshape(f.canvas.get_width_height()[::-1] + (3,)) - print(image.shape) + width, height = f.canvas.get_width_height() + # Get the RGBA buffer from the canvas + image = np.frombuffer(canvas.buffer_rgba(), dtype="uint8") + image = image.reshape((height, width, 4)) + # Remove the alpha channel (RGBA -> RGB) + image_rgb = image[:, :, :3] + # Convert RGB to BGR for OpenCV + image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR) + print(image_bgr.shape) if display: - cv2.imshow("2D_env", image) + cv2.imshow("2D_env", image_bgr) cv2.waitKey(10) diff --git a/neuralplayground/arenas/simple2d.py b/neuralplayground/arenas/simple2d.py index d98e421..90142f5 100644 --- a/neuralplayground/arenas/simple2d.py +++ b/neuralplayground/arenas/simple2d.py @@ -354,9 +354,15 @@ def render(self, history_length=30, display=True): history = self.history[-history_length:] ax = self.plot_trajectory(history_data=history, ax=ax) canvas.draw() - image = np.frombuffer(canvas.tostring_rgb(), dtype="uint8") - image = image.reshape(f.canvas.get_width_height()[::-1] + (3,)) - print(image.shape) + width, height = f.canvas.get_width_height() + # Get the RGBA buffer from the canvas + image = np.frombuffer(canvas.buffer_rgba(), dtype="uint8") + image = image.reshape((height, width, 4)) + # Remove the alpha channel (RGBA -> RGB) + image_rgb = image[:, :, :3] + # Convert RGB to BGR for OpenCV + image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR) + print(image_bgr.shape) if display: - cv2.imshow("2D_env", image) + cv2.imshow("2D_env", image_bgr) cv2.waitKey(10)