Skip to content

Commit

Permalink
fixed deprecation warning about tostring_rgb
Browse files Browse the repository at this point in the history
  • Loading branch information
niksirbi committed Oct 24, 2024
1 parent ddcd22a commit 99f84a1
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 8 deletions.
14 changes: 10 additions & 4 deletions neuralplayground/arenas/discritized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,9 +413,15 @@ def render(self, history_length=30, display=True):
history = self.history[-history_length:]
ax = self.plot_trajectory(history_data=history, ax=ax)
canvas.draw()
image = np.frombuffer(canvas.tostring_rgb(), dtype="uint8")
image = image.reshape(f.canvas.get_width_height()[::-1] + (3,))
print(image.shape)
width, height = f.canvas.get_width_height()
# Get the RGBA buffer from the canvas
image = np.frombuffer(canvas.buffer_rgba(), dtype="uint8")
image = image.reshape((height, width, 4))
# Remove the alpha channel (RGBA -> RGB)
image_rgb = image[:, :, :3]
# Convert RGB to BGR for OpenCV
image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)
print(image_bgr.shape)
if display:
cv2.imshow("2D_env", image)
cv2.imshow("2D_env", image_bgr)
cv2.waitKey(10)
14 changes: 10 additions & 4 deletions neuralplayground/arenas/simple2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,9 +354,15 @@ def render(self, history_length=30, display=True):
history = self.history[-history_length:]
ax = self.plot_trajectory(history_data=history, ax=ax)
canvas.draw()
image = np.frombuffer(canvas.tostring_rgb(), dtype="uint8")
image = image.reshape(f.canvas.get_width_height()[::-1] + (3,))
print(image.shape)
width, height = f.canvas.get_width_height()
# Get the RGBA buffer from the canvas
image = np.frombuffer(canvas.buffer_rgba(), dtype="uint8")
image = image.reshape((height, width, 4))
# Remove the alpha channel (RGBA -> RGB)
image_rgb = image[:, :, :3]
# Convert RGB to BGR for OpenCV
image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)
print(image_bgr.shape)
if display:
cv2.imshow("2D_env", image)
cv2.imshow("2D_env", image_bgr)
cv2.waitKey(10)

0 comments on commit 99f84a1

Please sign in to comment.