-
Notifications
You must be signed in to change notification settings - Fork 1
/
model.py
30 lines (25 loc) · 957 Bytes
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch.nn as nn
# model definition
class MLP(nn.Module):
    """Four-layer fully connected regression network.

    Maps ``n_inputs`` features through three ReLU hidden layers to a single
    scalar output. The output layer uses ``LeakyReLU(0.2)``, so the final
    value is unbounded above and only weakly attenuated below zero
    (presumably a regression head — confirm against the training loss).

    Args:
        n_inputs: number of input features.
        hidden_layer1: width of the first hidden layer.
        hidden_layer2: width of the second hidden layer.
        hidden_layer3: width of the third hidden layer.
    """

    # define model elements
    def __init__(self, n_inputs, hidden_layer1, hidden_layer2, hidden_layer3):
        # PEP 3135 zero-argument super() — equivalent to the legacy
        # super(MLP, self) form but robust to class renames.
        super().__init__()
        self.layer_1 = nn.Linear(n_inputs, hidden_layer1)
        self.act1 = nn.ReLU()
        self.layer_2 = nn.Linear(hidden_layer1, hidden_layer2)
        self.act2 = nn.ReLU()
        self.layer_3 = nn.Linear(hidden_layer2, hidden_layer3)
        self.act3 = nn.ReLU()
        # Final projection to one output unit; LeakyReLU keeps a small
        # gradient (slope 0.2) for negative pre-activations.
        self.layer_4 = nn.Linear(hidden_layer3, 1)
        self.act4 = nn.LeakyReLU(0.2)

    # forward propagate input
    def forward(self, x):
        """Run a forward pass.

        Args:
            x: input tensor of shape ``(batch, n_inputs)``.

        Returns:
            Tensor of shape ``(batch, 1)``.
        """
        x = self.act1(self.layer_1(x))
        x = self.act2(self.layer_2(x))
        x = self.act3(self.layer_3(x))
        x = self.act4(self.layer_4(x))
        return x
if __name__ == "__main__":
    # Smoke test: instantiate the network and list its module tree.
    model = MLP(13, 39, 39, 13)
    # Distinct loop variable — the original reused the model's name `m`
    # as the loop target, clobbering the model reference mid-iteration.
    for idx, module in enumerate(model.modules()):
        print(idx, "-", module)