diff --git a/examples/walking_pro_ac.pt b/examples/walking_pro_ac.pt
new file mode 100644
index 00000000..02ce2b59
Binary files /dev/null and b/examples/walking_pro_ac.pt differ
diff --git a/sim/model_export.py b/sim/model_export.py
index 9f893791..51a812c5 100644
--- a/sim/model_export.py
+++ b/sim/model_export.py
@@ -205,6 +205,52 @@ def forward(
return actions_scaled, actions, x
+def get_actor_policy(model_path: str, cfg: ActorCfg) -> Tuple[nn.Module, dict, Tuple[Tensor, ...]]:
+    """Loads a checkpoint and returns the actor, its sim2sim metadata and example inputs.
+
+    Args:
+        model_path: Checkpoint for ``torch.load``; it must contain a "model_state_dict" entry.
+        cfg: Actor configuration, passed through to ``Actor``.
+
+    Returns:
+        The eager (unscripted) actor module, a dict holding robot effort/stiffness/damping
+        plus the action and observation counts, and a tuple of random example inputs.
+    """
+    weights = torch.load(model_path, map_location="cpu", weights_only=True)["model_state_dict"]
+    num_actor_obs = weights["actor.0.weight"].shape[1]
+    num_critic_obs = weights["critic.0.weight"].shape[1]
+    num_actions = weights["std"].shape[0]
+    # Every matching weight contributes its row count; the last one is dropped from the hidden dims.
+    actor_hidden_dims = [v.shape[0] for k, v in weights.items() if re.match(r"actor\.\d+\.weight", k)][:-1]
+    critic_hidden_dims = [v.shape[0] for k, v in weights.items() if re.match(r"critic\.\d+\.weight", k)][:-1]
+    ac_model = ActorCritic(num_actor_obs, num_critic_obs, num_actions, actor_hidden_dims, critic_hidden_dims)
+    ac_model.load_state_dict(weights)
+    a_model = Actor(ac_model.actor, cfg)
+
+    # Example input tensors; the values are random.
+    input_tensors = (
+        torch.randn(1),  # x_vel
+        torch.randn(1),  # y_vel
+        torch.randn(1),  # rot
+        torch.randn(1),  # t
+        torch.randn(a_model.num_actions),  # dof_pos
+        torch.randn(a_model.num_actions),  # dof_vel
+        torch.randn(a_model.num_actions),  # prev_actions
+        torch.randn(3),  # imu_ang_vel
+        torch.randn(3),  # imu_euler_xyz
+        a_model.get_init_buffer(),  # buffer
+    )
+
+    # Sim2sim metadata, returned next to the eager module (no TorchScript copy is built here).
+    sim2sim_info = {
+        "robot_effort": list(a_model.robot.effort().values()),
+        "robot_stiffness": list(a_model.robot.stiffness().values()),
+        "robot_damping": list(a_model.robot.damping().values()),
+        "num_actions": a_model.num_actions,
+        "num_observations": a_model.num_observations,
+    }
+    return a_model, sim2sim_info, input_tensors
+
def convert_model_to_onnx(model_path: str, cfg: ActorCfg, save_path: Optional[str] = None) -> ort.InferenceSession:
"""Converts a PyTorch model to a ONNX format.
diff --git a/sim/play.py b/sim/play.py
index 3e596a66..b5f135fc 100755
--- a/sim/play.py
+++ b/sim/play.py
@@ -22,9 +22,10 @@
from sim.env import run_dir # noqa: E402
from sim.envs import task_registry # noqa: E402
-from sim.model_export import ActorCfg, convert_model_to_onnx # noqa: E402
+from sim.model_export import ActorCfg, get_actor_policy # noqa: E402
from sim.utils.helpers import get_args # noqa: E402
from sim.utils.logger import Logger # noqa: E402
+from kinfer.export.pytorch import export_to_onnx
import torch # isort: skip
@@ -81,8 +82,19 @@ def play(args: argparse.Namespace) -> None:
# export policy as a onnx module (used to run it on web)
if args.export_onnx:
path = ppo_runner.alg.actor_critic
- convert_model_to_onnx(path, ActorCfg(), save_path="policy.onnx")
- print("Exported policy as onnx to: ", path)
+        # NOTE(review): get_actor_policy feeds `path` to torch.load; confirm it is a checkpoint path.
+        policy_cfg = ActorCfg()
+        actor_model, sim2sim_info, input_tensors = get_actor_policy(path, policy_cfg)
+        # Config attributes and sim2sim metadata share one dict; metadata wins on key clashes.
+        export_config = {**vars(policy_cfg), **sim2sim_info}
+        save_path = "kinfer_policy.onnx"
+        export_to_onnx(
+            actor_model,
+            input_tensors=input_tensors,
+            config=export_config,
+            save_path=save_path,
+        )
+        print("Exported policy as kinfer-compatible onnx to: ", save_path)
# Prepare for logging
env_logger = Logger(env.dt)
diff --git a/sim/requirements.txt b/sim/requirements.txt
index a131c312..53fb29da 100755
--- a/sim/requirements.txt
+++ b/sim/requirements.txt
@@ -11,3 +11,5 @@ wandb
tensorboard==2.14.0
onnxscript
# onnxruntime
+
+kinfer
diff --git a/sim/resources/stompypro/joints.py b/sim/resources/stompypro/joints.py
index 27c6c2c6..5d3af1e7 100755
--- a/sim/resources/stompypro/joints.py
+++ b/sim/resources/stompypro/joints.py
@@ -127,12 +127,12 @@ def default_standing(cls) -> Dict[str, float]:
Robot.legs.left.hip_pitch: -0.23,
Robot.legs.left.hip_yaw: 0.0,
Robot.legs.left.hip_roll: 0.0,
- Robot.legs.left.knee_pitch: 0.441,
- Robot.legs.left.ankle_pitch: -0.258,
+ Robot.legs.left.knee_pitch: -0.441,
+ Robot.legs.left.ankle_pitch: 0.258,
Robot.legs.right.hip_pitch: -0.23,
Robot.legs.right.hip_yaw: 0.0,
Robot.legs.right.hip_roll: 0.0,
- Robot.legs.right.knee_pitch: 0.441,
+ Robot.legs.right.knee_pitch: -0.441,
Robot.legs.right.ankle_pitch: -0.258,
}
diff --git a/sim/resources/stompypro/robot_fixed.urdf b/sim/resources/stompypro/robot_fixed.urdf
index a8b1f04c..3d8ec480 100644
--- a/sim/resources/stompypro/robot_fixed.urdf
+++ b/sim/resources/stompypro/robot_fixed.urdf
@@ -1,522 +1,1139 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
\ No newline at end of file
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/sim/resources/stompypro/robot_fixed.xml b/sim/resources/stompypro/robot_fixed.xml
index 63345107..0c122080 100644
--- a/sim/resources/stompypro/robot_fixed.xml
+++ b/sim/resources/stompypro/robot_fixed.xml
@@ -1,4 +1,4 @@
-
+
@@ -11,17 +11,27 @@
-
-
-
-
-
-
-
-
-
-
-
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
@@ -32,107 +42,87 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
@@ -140,52 +130,52 @@
-
-
-
-
-
-
-
-
-
-
+
+
+
+
+
+
+
+
+
+
-
-
-
-
-
-
-
-
-
-
+
+
+
+
+
-
-
-
-
-
+
+
+
+
+
+
+
+
+
+
-
+
diff --git a/sim/scripts/download_assets.sh b/sim/scripts/download_assets.sh
index 34e54529..9ce3dd41 100755
--- a/sim/scripts/download_assets.sh
+++ b/sim/scripts/download_assets.sh
@@ -12,7 +12,7 @@ rm meshes.zip
# Stompypro
echo
echo "Downloading Stompypro assets..."
-gdown --folder https://drive.google.com/drive/folders/1-iIqy8j4gF6JeuMc_MjxkRe4vSZl8Ozp -O sim/resources/stompypro/
+gdown --folder https://drive.google.com/drive/folders/1pFxFpnKxGe7UygpBG5S4YboDLfuIaG05 -O sim/resources/stompypro/
# Xbot
echo
diff --git a/sim/sim2sim.py b/sim/sim2sim.py
index 292462b5..59955505 100755
--- a/sim/sim2sim.py
+++ b/sim/sim2sim.py
@@ -17,10 +17,13 @@
import onnxruntime as ort
import pygame
from scipy.spatial.transform import Rotation as R
+import torch
from tqdm import tqdm
from sim.h5_logger import HDF5Logger
-from sim.model_export import ActorCfg, convert_model_to_onnx
+from sim.model_export import ActorCfg, get_actor_policy
+from kinfer.export.pytorch import export_to_onnx
+from kinfer.inference.python import ONNXModel
@dataclass
@@ -238,7 +241,11 @@ def run_mujoco(
input_data["buffer.1"] = hist_obs.astype(np.float32)
- positions, curr_actions, hist_obs = policy.run(None, input_data)
+ policy_output = policy(input_data)
+ positions = policy_output["actions_scaled"]
+ curr_actions = policy_output["actions"]
+ hist_obs = policy_output["x.3"]
+
target_q = positions
if log_h5:
@@ -290,32 +297,6 @@ def run_mujoco(
if log_h5:
logger.close()
-
-def parse_modelmeta(
- modelmeta: List[Tuple[str, str]],
- verbose: bool = False,
-) -> Dict[str, Union[float, List[float], str]]:
- parsed_meta: Dict[str, Union[float, List[float], str]] = {}
- for key, value in modelmeta:
- if value.startswith("[") and value.endswith("]"):
- parsed_meta[key] = list(map(float, value.strip("[]").split(",")))
- else:
- try:
- parsed_meta[key] = float(value)
- try:
- if int(value) == parsed_meta[key]:
- parsed_meta[key] = int(value)
- except ValueError:
- pass
- except ValueError:
- print(f"Failed to convert {value} to float")
- parsed_meta[key] = value
- if verbose:
- for key, value in parsed_meta.items():
- print(f"{key}: {value}")
- return parsed_meta
-
-
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Deployment script.")
parser.add_argument("--embodiment", type=str, required=True, help="Embodiment name.")
@@ -355,16 +336,31 @@ def parse_modelmeta(
)
if args.load_model.endswith(".onnx"):
- policy = ort.InferenceSession(args.load_model)
+ policy = ONNXModel(args.load_model)
else:
- policy = convert_model_to_onnx(
- args.load_model, policy_cfg, save_path="policy.onnx"
+ actor_model, sim2sim_info, input_tensors = get_actor_policy(args.load_model, policy_cfg)
+
+ # Merge policy_cfg and sim2sim_info into a single config object
+ export_config = {**vars(policy_cfg), **sim2sim_info}
+ print(export_config)
+ export_to_onnx(
+ actor_model,
+ input_tensors=input_tensors,
+ config=export_config,
+ save_path="kinfer_test.onnx"
)
+        policy = ONNXModel("kinfer_test.onnx")
+
+    metadata = policy.get_metadata()
+    # Keep only the metadata fields listed below, under their original keys.
+    model_info = {
+        key: metadata[key]
+        for key in (
+            "num_actions", "num_observations",
+            "robot_effort", "robot_stiffness", "robot_damping",
+        )
+    }
- model_info = parse_modelmeta(
- policy.get_modelmeta().custom_metadata_map.items(),
- verbose=True,
- )
run_mujoco(
args.embodiment,