diff --git a/sim/play.py b/sim/play.py index 3e596a66..7d168e81 100755 --- a/sim/play.py +++ b/sim/play.py @@ -22,9 +22,10 @@ from sim.env import run_dir # noqa: E402 from sim.envs import task_registry # noqa: E402 -from sim.model_export import ActorCfg, convert_model_to_onnx # noqa: E402 +from sim.model_export import ActorCfg, convert_model_to_onnx, get_actor_policy # noqa: E402 from sim.utils.helpers import get_args # noqa: E402 from sim.utils.logger import Logger # noqa: E402 +from kinfer.export.pytorch import export_to_onnx import torch # isort: skip @@ -81,8 +82,19 @@ def play(args: argparse.Namespace) -> None: # export policy as a onnx module (used to run it on web) if args.export_onnx: path = ppo_runner.alg.actor_critic - convert_model_to_onnx(path, ActorCfg(), save_path="policy.onnx") - print("Exported policy as onnx to: ", path) + policy_cfg = ActorCfg() + actor_model, sim2sim_info, input_tensors = get_actor_policy(path, policy_cfg) + + # Merge policy_cfg and sim2sim_info into a single config object + export_config = {**vars(policy_cfg), **sim2sim_info} + + policy = export_to_onnx( + actor_model, + input_tensors=input_tensors, + config=export_config, + save_path="kinfer_policy.onnx" + ) + print("Exported policy as kinfer-compatible onnx to: ", path) # Prepare for logging env_logger = Logger(env.dt) diff --git a/sim/requirements.txt b/sim/requirements.txt index a131c312..53fb29da 100755 --- a/sim/requirements.txt +++ b/sim/requirements.txt @@ -11,3 +11,5 @@ wandb tensorboard==2.14.0 onnxscript # onnxruntime + +kinfer diff --git a/sim/sim2sim.py b/sim/sim2sim.py index 829e10cd..42bb0edd 100755 --- a/sim/sim2sim.py +++ b/sim/sim2sim.py @@ -21,7 +21,7 @@ from tqdm import tqdm from sim.h5_logger import HDF5Logger -from model_export import ActorCfg, get_actor_policy, convert_model_to_onnx +from sim.model_export import ActorCfg, get_actor_policy from kinfer.export.pytorch import export_to_onnx @@ -363,10 +363,8 @@ def new_func(args, policy_cfg): if args.load_model.endswith(".onnx"): policy = ort.InferenceSession(args.load_model) else: - # Export function is able to infer input shapes - # actor_model = new_func(args, policy_cfg) - # actor_model = torch.jit.load(args.load_model) actor_model, sim2sim_info, input_tensors = get_actor_policy(args.load_model, policy_cfg) + # Merge policy_cfg and sim2sim_info into a single config object export_config = {**vars(policy_cfg), **sim2sim_info} print(export_config) @@ -377,7 +375,7 @@ def new_func(args, policy_cfg): save_path="kinfer_test.onnx" ) # policy = convert_model_to_onnx(args.load_model, policy_cfg, save_path="policy.onnx") - + model_info = parse_modelmeta( policy.get_modelmeta().custom_metadata_map.items(), verbose=True,