Skip to content

Commit

Permalink
export with params
Browse files Browse the repository at this point in the history
  • Loading branch information
WT-MM committed Nov 23, 2024
1 parent 2c1e791 commit a1bb322
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 4 deletions.
10 changes: 8 additions & 2 deletions sim/model_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def forward(

return actions_scaled, actions, x

def get_actor_policy(model_path: str, cfg: ActorCfg) -> Actor:
def get_actor_policy(model_path: str, cfg: ActorCfg) -> Tuple[nn.Module, dict, Tuple[Tensor, ...]]:
all_weights = torch.load(model_path, map_location="cpu", weights_only=True)
weights = all_weights["model_state_dict"]
num_actor_obs = weights["actor.0.weight"].shape[1]
Expand Down Expand Up @@ -243,7 +243,13 @@ def get_actor_policy(model_path: str, cfg: ActorCfg) -> Actor:
num_actions = a_model.num_actions
num_observations = a_model.num_observations

return a_model.policy
return a_model, {
"robot_effort": robot_effort,
"robot_stiffness": robot_stiffness,
"robot_damping": robot_damping,
"num_actions": num_actions,
"num_observations": num_observations,
}, input_tensors


def convert_model_to_onnx(model_path: str, cfg: ActorCfg, save_path: Optional[str] = None) -> ort.InferenceSession:
Expand Down
12 changes: 10 additions & 2 deletions sim/sim2sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,8 +366,16 @@ def new_func(args, policy_cfg):
# Export function is able to infer input shapes
# actor_model = new_func(args, policy_cfg)
# actor_model = torch.jit.load(args.load_model)
actor_model = get_actor_policy(args.load_model, policy_cfg)
policy = export_to_onnx(actor_model, input_tensors=None, config=policy_cfg, save_path="kinfer_test.onnx")
actor_model, sim2sim_info, input_tensors = get_actor_policy(args.load_model, policy_cfg)
# Merge policy_cfg and sim2sim_info into a single config object
export_config = {**vars(policy_cfg), **sim2sim_info}
print(export_config)
policy = export_to_onnx(
actor_model,
input_tensors=input_tensors,
config=export_config,
save_path="kinfer_test.onnx"
)
# policy = convert_model_to_onnx(args.load_model, policy_cfg, save_path="policy.onnx")

model_info = parse_modelmeta(
Expand Down

0 comments on commit a1bb322

Please sign in to comment.