# my_social_model.py (forked from t2kasa/social_lstm_keras_tf)
import tensorflow as tf
from keras import Input, Model, backend as K
from keras.layers import LSTM, Dense, Concatenate, Lambda, Permute, Reshape
from keras.optimizers import RMSprop

from general_utils import get_image_size, out_dim, pxy_dim
from grid import tf_grid_mask
from load_model_config import ModelConfig
from tf_normal_sampler import normal2d_log_pdf
from tf_normal_sampler import normal2d_sample

class MySocialModel:
    def __init__(self, config: ModelConfig) -> None:
        self.x_input = Input((config.obs_len, config.max_n_peds, pxy_dim))
        # y_input = Input((config.obs_len, config.max_n_peds, pxy_dim))
        self.grid_input = Input(
            (config.obs_len, config.max_n_peds, config.max_n_peds,
             config.grid_side_squared))
        self.zeros_input = Input(
            (config.obs_len, config.max_n_peds, config.lstm_state_dim))

        # Social LSTM layers
        self.lstm_layer = LSTM(config.lstm_state_dim, return_state=True)
        self.W_e_relu = Dense(config.emb_dim, activation="relu")
        self.W_a_relu = Dense(config.emb_dim, activation="relu")
        self.W_p = Dense(out_dim)

        self._build_model(config)
    def _compute_loss(self, y_batch, o_batch):
        """Compute the negative log-likelihood loss.

        :param y_batch: (batch_size, pred_len, max_n_peds, pxy_dim)
        :param o_batch: (batch_size, pred_len, max_n_peds, out_dim)
        :return: loss
        """
        not_exist_pid = 0

        y = tf.reshape(y_batch, (-1, pxy_dim))
        o = tf.reshape(o_batch, (-1, out_dim))

        pids = y[:, 0]
        # keep only the rows of pedestrians that actually exist
        exist_rows = tf.not_equal(pids, not_exist_pid)
        y_exist = tf.boolean_mask(y, exist_rows)
        o_exist = tf.boolean_mask(o, exist_rows)
        pos_exist = y_exist[:, 1:]

        # log-probability of the true positions under the predicted
        # 2D normal distribution parameters
        log_prob_exist = normal2d_log_pdf(o_exist, pos_exist)
        # clip for numerical stability
        log_prob_exist = tf.minimum(log_prob_exist, 0.0)

        loss = -log_prob_exist
        return loss
    def _compute_social_tensor(self, grid_t, prev_h_t, config):
        """Compute the social tensor $H^t_i(m, n, :)$.

        This implementation corresponds to the getSocialTensor() function
        in the original Social LSTM code.

        :param grid_t: (batch_size, max_n_peds, max_n_peds, grid_side ** 2)
            indexed as (batch_index, self_pid, other_pid, grid_index).
        :param prev_h_t: (batch_size, max_n_peds, lstm_state_dim)
        :return: H_t (batch_size, max_n_peds, (grid_side ** 2) * lstm_state_dim)
        """
        H_t = []
        for i in range(config.max_n_peds):
            # (batch_size, max_n_peds, max_n_peds, grid_side ** 2)
            # => (batch_size, max_n_peds, grid_side ** 2)
            # bind the loop variable as a default argument so each Lambda
            # captures the current value of i
            grid_it = Lambda(lambda g, i=i: g[:, i, ...])(grid_t)

            # (batch_size, max_n_peds, grid_side ** 2)
            # => (batch_size, grid_side ** 2, max_n_peds)
            grid_it_T = Permute((2, 1))(grid_it)

            # (batch_size, grid_side ** 2, lstm_state_dim)
            H_it = Lambda(lambda x: K.batch_dot(x[0], x[1]))(
                [grid_it_T, prev_h_t])
            H_t.append(H_it)

        # list of (batch_size, grid_side ** 2, lstm_state_dim)
        # => (max_n_peds, batch_size, grid_side ** 2, lstm_state_dim)
        H_t = Lambda(lambda hs: K.stack(hs, axis=0))(H_t)

        # (max_n_peds, batch_size, grid_side ** 2, lstm_state_dim)
        # => (batch_size, max_n_peds, grid_side ** 2, lstm_state_dim)
        H_t = Lambda(lambda h: K.permute_dimensions(h, (1, 0, 2, 3)))(H_t)

        # (batch_size, max_n_peds, grid_side ** 2, lstm_state_dim)
        # => (batch_size, max_n_peds, (grid_side ** 2) * lstm_state_dim)
        H_t = Reshape(
            (config.max_n_peds,
             config.grid_side_squared * config.lstm_state_dim))(H_t)

        return H_t
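
    # Hedged shape sketch (illustrative only): for a single pedestrian i,
    # the batch_dot above pools the hidden states of all pedestrians into
    # the grid cells around i. An equivalent numpy einsum with toy sizes
    # (batch_size=2, max_n_peds=3, grid cells=4, lstm_state_dim=5):
    @staticmethod
    def _demo_social_pooling():
        import numpy as np
        grid_it = np.random.rand(2, 3, 4)       # (batch, other_ped, cell)
        prev_h_t = np.random.rand(2, 3, 5)      # (batch, other_ped, state)
        grid_it_T = grid_it.transpose(0, 2, 1)  # (batch, cell, other_ped)
        H_it = np.einsum("bco,bos->bcs", grid_it_T, prev_h_t)
        return H_it                             # (batch, cell, state) = (2, 4, 5)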
    def _build_model(self, config: ModelConfig):
        # ------------------------------------------------------------------
        # Observation
        # ------------------------------------------------------------------
        o_obs_batch = []
        for t in range(config.obs_len):
            print("t:", t)
            # bind the loop variable as a default argument so each Lambda
            # captures the current value of t
            x_t = Lambda(lambda x, t=t: x[:, t, :, :])(self.x_input)
            grid_t = Lambda(lambda g, t=t: g[:, t, ...])(self.grid_input)

            h_t, c_t = [], []
            o_t = []

            if t == 0:
                prev_h_t = Lambda(lambda z: z[:, 0, :, :])(self.zeros_input)
                prev_c_t = Lambda(lambda z: z[:, 0, :, :])(self.zeros_input)

            # compute $H_t$
            # (batch_size, max_n_peds, (grid_side ** 2) * lstm_state_dim)
            H_t = self._compute_social_tensor(grid_t, prev_h_t, config)

            for ped_index in range(config.max_n_peds):
                print("(t, ped_index):", t, ped_index)
                # ----------------------------------------
                # compute $e_i^t$ and $a_i^t$
                # ----------------------------------------
                x_pos_it = Lambda(lambda x, i=ped_index: x[:, i, 1:])(x_t)
                e_it = self.W_e_relu(x_pos_it)

                H_it = Lambda(lambda H, i=ped_index: H[:, i, ...])(H_t)
                a_it = self.W_a_relu(H_it)

                # build the concatenated embedding as the LSTM input
                # emb_it: (batch_size, 1, 2 * emb_dim)
                emb_it = Concatenate()([e_it, a_it])
                emb_it = Reshape((1, 2 * config.emb_dim))(emb_it)

                # the previous states of pedestrian i are used as the
                # initial_state of the shared LSTM at this step
                prev_states_it = [
                    Lambda(lambda h, i=ped_index: h[:, i])(prev_h_t),
                    Lambda(lambda c, i=ped_index: c[:, i])(prev_c_t)]
                lstm_output, h_it, c_it = self.lstm_layer(
                    emb_it, initial_state=prev_states_it)

                h_t.append(h_it)
                c_t.append(c_it)

                # compute o_it, whose shape is (batch_size, out_dim)
                o_it = self.W_p(lstm_output)
                o_t.append(o_it)

            # convert the lists of h_it/c_it/o_it to h_t/c_t/o_t respectively
            h_t = _stack_permute_axis_zero(h_t)
            c_t = _stack_permute_axis_zero(c_t)
            o_t = _stack_permute_axis_zero(o_t)
            o_obs_batch.append(o_t)

            # current => previous
            prev_h_t = h_t
            prev_c_t = c_t

        # convert the list of o_t to o_obs_batch
        o_obs_batch = _stack_permute_axis_zero(o_obs_batch)
        # ------------------------------------------------------------------
        # Prediction
        # ------------------------------------------------------------------
        # at this point prev_h_t and prev_c_t hold the final states of the
        # observation phase
        x_obs_t_final = Lambda(lambda x: x[:, -1, :, :])(self.x_input)
        pid_obs_t_final = Lambda(lambda x_t: x_t[:, :, 0])(x_obs_t_final)
        pid_obs_t_final = Lambda(lambda p_t: K.expand_dims(p_t, 2))(
            pid_obs_t_final)

        x_pred_batch = []
        o_pred_batch = []

        for t in range(config.pred_len):
            if t == 0:
                prev_o_t = Lambda(lambda o_b: o_b[:, -1, :, :])(o_obs_batch)

            pred_pos_t = normal2d_sample(prev_o_t)
            # assume that all the pedestrians in the final observation frame
            # are present in every prediction frame
            x_pred_t = Concatenate(axis=2)([pid_obs_t_final, pred_pos_t])

            grid_t = tf_grid_mask(x_pred_t,
                                  get_image_size(config.test_dataset_kind),
                                  config.n_neighbor_pixels, config.grid_side)

            h_t, c_t, o_t = [], [], []

            # compute $H_t$
            # (batch_size, max_n_peds, (grid_side ** 2) * lstm_state_dim)
            H_t = self._compute_social_tensor(grid_t, prev_h_t, config)

            for i in range(config.max_n_peds):
                print("(t, ped_index):", t, i)
                prev_o_it = Lambda(lambda o, i=i: o[:, i, :])(prev_o_t)
                H_it = Lambda(lambda H, i=i: H[:, i, ...])(H_t)

                # pred_pos_it: (batch_size, 2)
                pred_pos_it = normal2d_sample(prev_o_it)

                # compute e_it and a_it
                # e_it: (batch_size, emb_dim)
                # a_it: (batch_size, emb_dim)
                e_it = self.W_e_relu(pred_pos_it)
                a_it = self.W_a_relu(H_it)

                # build the concatenated embedding as the LSTM input
                # emb_it: (batch_size, 1, 2 * emb_dim)
                emb_it = Concatenate()([e_it, a_it])
                emb_it = Reshape((1, 2 * config.emb_dim))(emb_it)

                # the previous states of pedestrian i are used as the
                # initial_state of the shared LSTM at this step
                prev_states_it = [
                    Lambda(lambda h, i=i: h[:, i])(prev_h_t),
                    Lambda(lambda c, i=i: c[:, i])(prev_c_t)]
                lstm_output, h_it, c_it = self.lstm_layer(
                    emb_it, initial_state=prev_states_it)

                h_t.append(h_it)
                c_t.append(c_it)

                # compute o_it, whose shape is (batch_size, out_dim)
                o_it = self.W_p(lstm_output)
                o_t.append(o_it)

            # convert the lists of h_it/c_it/o_it to h_t/c_t/o_t respectively
            h_t = _stack_permute_axis_zero(h_t)
            c_t = _stack_permute_axis_zero(c_t)
            o_t = _stack_permute_axis_zero(o_t)

            o_pred_batch.append(o_t)
            x_pred_batch.append(x_pred_t)

            # current => previous
            prev_h_t = h_t
            prev_c_t = c_t
            prev_o_t = o_t

        # convert the lists of o_t and x_pred_t to batches
        o_pred_batch = _stack_permute_axis_zero(o_pred_batch)
        x_pred_batch = _stack_permute_axis_zero(x_pred_batch)

        # o_concat_batch = Lambda(lambda os: tf.concat(os, axis=1))(
        #     [o_obs_batch, o_pred_batch])

        # this should be the model actually needed for training
        self.train_model = Model(
            [self.x_input, self.grid_input, self.zeros_input],
            o_pred_batch
        )

        lr = 0.003
        optimizer = RMSprop(lr=lr)
        self.train_model.compile(optimizer, self._compute_loss)

        self.sample_model = Model(
            [self.x_input, self.grid_input, self.zeros_input],
            x_pred_batch
        )

def _stack_permute_axis_zero(xs):
    """Stack a list of per-pedestrian (or per-frame) tensors along a new
    axis 0, then permute axes (0, 1) so that batch_size comes first."""
    xs = Lambda(lambda xs: K.stack(xs, axis=0))(xs)

    # axes (0, 1) are permuted
    perm = [1, 0] + list(range(2, xs.shape.ndims))
    xs = Lambda(lambda xs: K.permute_dimensions(xs, perm))(xs)
    return xs
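
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). It assumes a ModelConfig that
# exposes the fields referenced above; the constructor call below is
# hypothetical, so adapt it to how ModelConfig is actually loaded.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import numpy as np

    config = ModelConfig("path/to/config.json")  # hypothetical constructor
    model = MySocialModel(config)

    batch_size = 1
    x = np.zeros((batch_size, config.obs_len, config.max_n_peds, pxy_dim))
    grid = np.zeros((batch_size, config.obs_len, config.max_n_peds,
                     config.max_n_peds, config.grid_side_squared))
    zeros = np.zeros((batch_size, config.obs_len, config.max_n_peds,
                      config.lstm_state_dim))

    # 2D-normal parameters for each prediction frame:
    o_pred = model.train_model.predict([x, grid, zeros])
    # sampled pedestrian positions for each prediction frame:
    x_pred = model.sample_model.predict([x, grid, zeros])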