forked from t2kasa/social_lstm_keras_tf
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path evaluate_social_model.py
83 lines (63 loc) · 2.95 KB
/
evaluate_social_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import json
import os
from argparse import ArgumentParser, Namespace
from typing import Optional, Sequence

import matplotlib.pyplot as plt
import numpy as np

from data_utils import obs_pred_split
from evaluation_metrics import compute_abe, compute_fde
from load_model_config import load_model_config
from my_social_model import MySocialModel
from provide_train_test import provide_train_test
from vizualize_trajectories import visualize_trajectories
def _load_eval_args() -> Namespace:
parser = ArgumentParser()
parser.add_argument("--trained_model_config", type=str, required=True)
parser.add_argument("--trained_model_file", type=str, required=True)
return parser.parse_args()
def _save_trajectory_figures(out_fig_dir, x_true, x_model, obs_len,
                             pred_len) -> None:
    """Save one PNG per sample comparing true and predicted trajectories.

    Creates ``out_fig_dir`` if needed and writes ``0000.png``, ``0001.png``,
    ... indexed by sample position. Each figure is closed after saving, even
    when saving fails, so open figures do not accumulate across samples.

    Raises:
        ValueError: if ``x_true`` and ``x_model`` hold different numbers of
            samples.
    """
    if len(x_true) != len(x_model):
        raise ValueError(
            "true/predicted sample counts differ: {} vs {}".format(
                len(x_true), len(x_model)))
    os.makedirs(out_fig_dir, exist_ok=True)
    for s, (true_seq, model_seq) in enumerate(zip(x_true, x_model)):
        fig = visualize_trajectories(true_seq, model_seq, obs_len, pred_len)
        try:
            fig.savefig(os.path.join(out_fig_dir, "{0:04d}.png".format(s)))
        finally:
            plt.close(fig)


def _write_report(report_file, ade, fde) -> None:
    """Write ``{"ade": ..., "fde": ...}`` (as Python floats) to a JSON file."""
    report = {"ade": float(ade), "fde": float(fde)}
    with open(report_file, "w", encoding="utf-8") as f:
        json.dump(report, f)


def main() -> None:
    """Evaluate a trained social model on the test split.

    Reads the model config and trained weights named on the command line,
    predicts ``config.pred_len`` steps after each observed test sequence and
    writes into ``<directory of trained_model_file>/eval``:

    * ``figs/NNNN.png``: one figure per test sample, with the true and
      predicted trajectories;
    * ``report.json``: ``{"ade": ..., "fde": ...}``.

    ADE and FDE are also printed to stdout.
    """
    args = _load_eval_args()
    config = load_model_config(args.trained_model_config)

    out_dir = os.path.join(os.path.dirname(args.trained_model_file), "eval")
    os.makedirs(out_dir, exist_ok=True)

    # load data and split the test split into observation / prediction parts
    _, test_data = provide_train_test(config)
    obs_len_test, pred_len_test = obs_pred_split(
        config.obs_len, config.pred_len, *test_data)

    # Weights are loaded into `train_model`, but prediction runs on
    # `sample_model`; presumably the two share layers. TODO confirm in
    # MySocialModel.
    my_model = MySocialModel(config)
    my_model.train_model.load_weights(args.trained_model_file)

    # predict `pred_len` steps following each observation sequence
    x_obs_len_test, _, grid_obs_len_test, zeros_obs_len_test = obs_len_test
    x_pred_len_test, *_ = pred_len_test
    x_pred_len_model = my_model.sample_model.predict(
        [x_obs_len_test, grid_obs_len_test, zeros_obs_len_test],
        batch_size=config.batch_size, verbose=1)

    # visualization: observed part followed by the true / predicted part,
    # joined along axis 1 (presumably the time axis)
    x_concat_test = np.concatenate([x_obs_len_test, x_pred_len_test], axis=1)
    x_concat_model = np.concatenate([x_obs_len_test, x_pred_len_model], axis=1)
    _save_trajectory_figures(os.path.join(out_dir, "figs"),
                             x_concat_test, x_concat_model,
                             config.obs_len, config.pred_len)

    # evaluation on the predicted part only
    ade = compute_abe(x_pred_len_test, x_pred_len_model)
    fde = compute_fde(x_pred_len_test, x_pred_len_model)
    _write_report(os.path.join(out_dir, "report.json"), ade, fde)

    print("Average displacement error: {}".format(ade))
    print("Final displacement error: {}".format(fde))
if __name__ == "__main__":
    # Only run the evaluation when executed as a script, never on import.
    main()