-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmobilenet.lua
62 lines (45 loc) · 1.68 KB
/
mobilenet.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
--[[
Torch implementation of MobileNet
--]]
require 'nn'
--- Standard 3x3 convolution unit: conv (no bias) -> BatchNorm -> ReLU.
-- Used as the MobileNet stem. Every MobileNet layer except the final fully
-- connected one is followed by BatchNorm and a ReLU, matching the two
-- halves of `conv_dw`; without the ReLU the stem would be a purely linear
-- layer feeding straight into the first depthwise convolution.
-- The convolution bias is dropped because the following BatchNorm
-- re-centres each channel (and has its own learnable shift), so a conv
-- bias would be redundant.
-- @tparam number inp input feature planes
-- @tparam number outp output feature planes
-- @tparam number stride stride on both spatial axes (padding is 1)
-- @return nn.Sequential
function conv_bn(inp, outp, stride)
  local net = nn.Sequential()
  net:add(nn.SpatialConvolution(inp, outp, 3, 3, stride, stride, 1, 1):noBias())
  net:add(nn.SpatialBatchNormalization(outp))
  net:add(nn.ReLU(true)) -- in-place ReLU, same as in conv_dw
  return net
end
--- Depthwise-separable convolution unit.
-- First half: 3x3 depthwise convolution (one filter per input plane,
-- padding 1) -> BatchNorm -> in-place ReLU.
-- Second half: 1x1 pointwise convolution without bias -> BatchNorm ->
-- in-place ReLU.
-- Modules are created in the same order they are added, so random weight
-- initialisation draws from the RNG in a stable order.
-- @tparam number inp input feature planes
-- @tparam number outp output feature planes of the pointwise convolution
-- @tparam number stride stride of the depthwise convolution (both axes)
-- @return nn.Sequential
function conv_dw(inp, outp, stride)
  return nn.Sequential()
    -- depthwise 3x3 (channel multiplier 1)
    :add(nn.SpatialDepthWiseConvolution(inp, 1, 3, 3, stride, stride, 1, 1))
    :add(nn.SpatialBatchNormalization(inp))
    :add(nn.ReLU(true))
    -- pointwise 1x1 projection to outp planes
    :add(nn.SpatialConvolution(inp, outp, 1, 1, 1, 1):noBias())
    :add(nn.SpatialBatchNormalization(outp))
    :add(nn.ReLU(true))
end
--- Build a MobileNet (width multiplier 1.0) classifier.
-- The convolutional trunk is probed with a dummy 1x3xHxW input to find the
-- size of its final feature map. That map is reduced by global average
-- pooling, and the classifier is sized from its plane count. For inputs
-- whose final map is 7x7 (e.g. 224x224) this is exactly a 7x7 average pool
-- followed by Linear(1024, nClasses). Other input sizes, which used to hit a
-- size mismatch at the Linear layer, now build a consistent model.
-- @tparam number nClasses number of output classes
-- @tparam[opt=224] number H input height
-- @tparam[opt=224] number W input width
-- @return nn.Sequential ending in nn.LogSoftMax (log-probabilities)
function mobilenet(nClasses, H, W)
  assert(nClasses, 'mobilenet: nClasses is required')
  H = H or 224
  W = W or 224

  -- {input planes, output planes, depthwise stride} for each separable unit
  local dwConfig = {
    {   32,   64, 1 },
    {   64,  128, 2 },
    {  128,  128, 1 },
    {  128,  256, 2 },
    {  256,  256, 1 },
    {  256,  512, 2 },
    {  512,  512, 1 },
    {  512,  512, 1 },
    {  512,  512, 1 },
    {  512,  512, 1 },
    {  512,  512, 1 },
    {  512, 1024, 2 },
    { 1024, 1024, 1 },
  }

  local model = nn.Sequential()
  model:add(conv_bn(3, 32, 2))
  for _, cfg in ipairs(dwConfig) do
    model:add(conv_dw(cfg[1], cfg[2], cfg[3]))
  end

  -- Probe the trunk for this H, W. Run in evaluate mode so the random dummy
  -- batch does not update the BatchNorm running statistics, then restore
  -- training mode (the state freshly constructed modules start in).
  model:evaluate()
  local features = model:forward(torch.randn(1, 3, H, W))
  model:training()
  local nPlanes = features:size(2)
  local featH, featW = features:size(3), features:size(4)

  -- Global average pooling: the kernel covers the whole map, giving 1x1.
  -- SpatialAveragePooling takes (kW, kH), i.e. width first.
  model:add(nn.SpatialAveragePooling(featW, featH))
  model:add(nn.View(nPlanes))
  -- Size the classifier from the probed plane count instead of a constant.
  model:add(nn.Linear(nPlanes, nClasses))
  model:add(nn.LogSoftMax())
  return model
end