Skip to content

Commit

Permalink
update logic
Browse files Browse the repository at this point in the history
  • Loading branch information
budzianowski committed Dec 3, 2024
1 parent 8bb1f9a commit f3b7b34
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 126 deletions.
2 changes: 1 addition & 1 deletion sim/h5_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from typing import Dict

import h5py
import matplotlib.pyplot as plt
import numpy as np
from datetime import datetime

Expand Down Expand Up @@ -117,6 +116,7 @@ def _plot_dataset(name: str, data: np.ndarray):
name (str): Name of the dataset.
data (np.ndarray): Data to be plotted.
"""
import matplotlib.pyplot as plt # dependency issues with python 3.8
plt.figure(figsize=(10, 5))
if data.ndim == 2: # Handle multi-dimensional data
for i in range(data.shape[1]):
Expand Down
161 changes: 36 additions & 125 deletions sim/play.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
"""Play a trained policy in the environment.
Run:
python sim/play.py --task g1 --log_h5
python sim/play.py --task stompymini --log_h5
python sim/play.py --task stompypro --log_h5
"""
import argparse
import copy
Expand All @@ -13,6 +12,7 @@
import uuid
from datetime import datetime
from typing import Any, Union
import math

import cv2
import h5py
Expand All @@ -21,7 +21,6 @@
import torch # isort: skip
from tqdm import tqdm

import krec
from sim.env import run_dir # noqa: E402
from sim.envs import task_registry # noqa: E402
from sim.model_export import ActorCfg, convert_model_to_onnx # noqa: E402
Expand All @@ -45,7 +44,7 @@ def play(args: argparse.Namespace) -> None:
logger.info("Configuring environment and training settings...")
env_cfg, train_cfg = task_registry.get_cfgs(name=args.task)

num_parallel_envs = 4
num_parallel_envs = 2
env_cfg.env.num_envs = num_parallel_envs
env_cfg.sim.max_gpu_contact_pairs = 2**10 * num_parallel_envs

Expand Down Expand Up @@ -88,84 +87,31 @@ def play(args: argparse.Namespace) -> None:
robot_index = 0
joint_index = 1
env_steps_to_run = 1000

now = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
if args.log_h5:
# Create directory for HDF5 files
h5_dir = run_dir() / "h5_out" / args.task / now
h5_dir.mkdir(parents=True, exist_ok=True)

# Get observation dimensions
num_joints = env.num_dof
obs_buffer = env.obs_history[0].tolist()[0]
obs_size = len(obs_buffer)

# Index mappings for observation buffer
# This is based on stompypro_env.py
# https://github.com/kscalelabs/sim/blob/54c40d55eab15a9e784e89fb47e64a668851a41b/sim/envs/humanoids/stompypro_env.py#L225
command_2d_start = 0
command_2d_end = 2
command_3d_start = 2
command_3d_end = 5
joint_pos_start = 5
joint_pos_end = joint_pos_start + num_joints
joint_vel_start = joint_pos_end
joint_vel_end = joint_vel_start + num_joints
prev_actions_start = joint_vel_end
prev_actions_end = prev_actions_start + num_joints
ang_vel_start = prev_actions_end
ang_vel_end = ang_vel_start + 3
euler_start = ang_vel_end
euler_end = euler_start + 3

num_actions = env.num_dof
obs_buffer = env.obs_buf.shape[1]
prev_actions = np.zeros((num_actions), dtype=np.double)

h5_loggers = []
for env_idx in range(env_cfg.env.num_envs):
h5_dir = run_dir() / "h5_out" / args.task / now / f"env_{env_idx}"
h5_dir.mkdir(parents=True, exist_ok=True)

h5_loggers.append(HDF5Logger(
data_name=f"{args.task}_env_{env_idx}",
num_actions=num_joints,
num_actions=num_actions,
max_timesteps=env_steps_to_run,
num_observations=obs_size,
num_observations=obs_buffer,
h5_out_dir=str(h5_dir)
))


if args.log_krec:
# Create directory for KRec files
krec_dir = run_dir() / "krec_out" / args.task / now
krec_dir.mkdir(parents=True, exist_ok=True)

start_time_ns = time.time_ns() # Current time in nanoseconds

krec_loggers = []
for env_idx in range(env_cfg.env.num_envs):

# Create KRec header for each environment
header = krec.KRecHeader(
uuid=str(uuid.uuid4()),
robot_platform=f"{args.task}-sim",
robot_serial="123",
task=args.task,
start_timestamp=start_time_ns,
end_timestamp=start_time_ns + int(env_steps_to_run * env_cfg.sim.dt * 1e9),
)

# Add actuator configs for each joint
for i in range(env.num_dof):
actuator_config = krec.ActuatorConfig(
i, # actuator id
kp=float(env.p_gains[env_idx, i].cpu()), # use env_idx instead of 0
ki=0.0,
kd=float(env.d_gains[env_idx, i].cpu()), # use env_idx instead of 0
max_torque=float(env.torque_limits[i].cpu()),
name=f"Joint{i}",
)
header.add_actuator_config(actuator_config)

# Create KRec object for this environment
krec_loggers.append(krec.KRec(header))

if args.render:
camera_properties = gymapi.CameraProperties()
camera_properties.width = 1920
Expand Down Expand Up @@ -197,24 +143,7 @@ def play(args: argparse.Namespace) -> None:

for t in tqdm(range(env_steps_to_run)):
actions = policy(obs.detach())
if args.log_h5:
# Extract the current observation
for env_idx in range(env_cfg.env.num_envs):
cur_obs = env.obs_history[env_idx].tolist()[0]

h5_loggers[env_idx].log_data({
"t": np.array([t * env.dt], dtype=np.float32),
"2D_command": np.array(cur_obs[command_2d_start:command_2d_end], dtype=np.float32),
"3D_command": np.array(cur_obs[command_3d_start:command_3d_end], dtype=np.float32),
"joint_pos": np.array(cur_obs[joint_pos_start:joint_pos_end], dtype=np.float32),
"joint_vel": np.array(cur_obs[joint_vel_start:joint_vel_end], dtype=np.float32),
"prev_actions": np.array(cur_obs[prev_actions_start:prev_actions_end], dtype=np.float32),
"curr_actions": actions.detach().cpu().numpy()[0],
"ang_vel": np.array(cur_obs[ang_vel_start:ang_vel_end], dtype=np.float32),
"euler_rotation": np.array(cur_obs[euler_start:euler_end], dtype=np.float32),
"buffer": np.array(cur_obs, dtype=np.float32)
})


if args.fix_command:
env.commands[:, 0] = 0.5
env.commands[:, 1] = 0.0
Expand Down Expand Up @@ -262,46 +191,36 @@ def play(args: argparse.Namespace) -> None:
"contact_forces_z": contact_forces_z,
}
)
actions = actions.detach().cpu().numpy()
if args.log_h5:
# Extract the current observation
for env_idx in range(env_cfg.env.num_envs):
h5_loggers[env_idx].log_data({
"t": np.array([t * env.dt], dtype=np.float32),
"2D_command": np.array(
[
np.sin(2 * math.pi * t * env.dt / env.cfg.rewards.cycle_time),
np.cos(2 * math.pi * t * env.dt / env.cfg.rewards.cycle_time),
],
dtype=np.float32,
),
"3D_command": np.array(env.commands[env_idx, :3].cpu().numpy(), dtype=np.float32),
"joint_pos": np.array(env.dof_pos[env_idx].cpu().numpy(), dtype=np.float32),
"joint_vel": np.array(env.dof_vel[env_idx].cpu().numpy(), dtype=np.float32),
"prev_actions": prev_actions[env_idx].astype(np.float32),
"curr_actions": actions[env_idx].astype(np.float32),
"ang_vel": env.base_ang_vel[env_idx].cpu().numpy().astype(np.float32),
"euler_rotation": env.base_euler_xyz[env_idx].cpu().numpy().astype(np.float32),
"buffer": env.obs_buf[env_idx].cpu().numpy().astype(np.float32)
})

prev_actions = actions

if infos["episode"]:
num_episodes = env.reset_buf.sum().item()
if num_episodes > 0:
env_logger.log_rewards(infos["episode"], num_episodes)

if args.log_krec:
# Log data for each environment
for env_idx in range(env_cfg.env.num_envs):
frame = krec.KRecFrame(
video_timestamp=start_time_ns + int(t * env.dt * 1e9),
frame_number=t,
inference_step=t // env_cfg.control.decimation,
)

# Add actuator states and commands for each joint
for i in range(env.num_dof):
state = krec.ActuatorState(
actuator_id=i,
online=True,
position=env.dof_pos[env_idx, i].item(),
velocity=env.dof_vel[env_idx, i].item(),
torque=env.torques[env_idx, i].item(),
)
command = krec.ActuatorCommand(
i, # actuator id
position=actions[env_idx, i].item(),
velocity=0.0,
torque=actions[env_idx, i].item(),
)
frame.add_actuator_state(state)
frame.add_actuator_command(command)

# Add IMU data
imu_values = krec.IMUValues(
gyro=krec.Vec3(x=obs[env_idx, 0], y=obs[env_idx, 1], z=obs[env_idx, 2]),
quaternion=krec.IMUQuaternion(x=obs[env_idx, 3], y=obs[env_idx, 4], z=obs[env_idx, 5], w=obs[env_idx, 6]),
)
frame.set_imu_values(imu_values)
krec_loggers[env_idx].add_frame(frame)

env_logger.print_rewards()

if args.render:
Expand All @@ -313,14 +232,6 @@ def play(args: argparse.Namespace) -> None:
h5_logger.close()
print(f"HDF5 file(s) saved!")

if args.log_krec:
for env_idx, krec_logger in enumerate(krec_loggers):
krec_file_path = krec_dir / f"env_{env_idx}" / f"walking_{str(uuid.uuid4())[:8]}.krec"
krec_file_path.parent.mkdir(parents=True, exist_ok=True)
print(f"Saving KRec file to {krec_file_path}")
krec_logger.save(str(krec_file_path))
print("KRec files saved!")


if __name__ == "__main__":
base_args = get_args()
Expand Down

0 comments on commit f3b7b34

Please sign in to comment.