-
Notifications
You must be signed in to change notification settings - Fork 9
/
experiment_svhn.py
97 lines (79 loc) · 2.86 KB
/
experiment_svhn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
"""
Experiments on SVHN dataset.
"""
import numpy as np
import torch
import fire
from experiments import Model, Dataset
from etn import coordinates, networks, transformers
class SVHNModel(Model):
    """SVHN (single-digit) classification experiment.

    A configuration layer over ``experiments.Model``. It supplies
    SVHN-specific defaults for the transformer options, the classification
    network options, the optimizer options and the learning-rate schedule.
    Caller-supplied option dicts are merged on top of those defaults, so a
    caller only has to pass the keys they want to change.
    """

    # transformer defaults
    tf_default_opts = {
        'in_channels': 3,
        'kernel_size': 3,
        'nf': 32,
        'strides': (1, 1),
    }

    # classification network defaults
    net_default_opts = {
        'nf': 32,
        'p_dropout': 0.3,
        'pad_mode': ('constant', 'constant'),
    }

    # optimizer defaults
    optimizer_default_opts = {
        'amsgrad': True,
        'lr': 2e-3,
        'weight_decay': 0.,
    }

    # learning rate schedule defaults
    lr_default_schedule = {
        'step_size': 1,
        'gamma': 0.99,
    }

    # Dataset mean and standard deviation, one value per input channel
    # (three values, matching tf_default_opts['in_channels'] == 3). They are
    # passed to Dataset as its normalization in _load_dataset.
    normalization_mean = torch.FloatTensor([0.4379, 0.4440, 0.4729])
    normalization_std = torch.FloatTensor([0.1981, 0.2010, 0.1970])

    @staticmethod
    def _merged(defaults, overrides):
        """Return a new dict: a copy of *defaults* updated with *overrides*.

        *overrides* may be None, meaning "use the defaults unchanged".
        The *defaults* dict itself is never modified.
        """
        merged = dict(defaults)
        if overrides is not None:
            merged.update(overrides)
        return merged

    def __init__(self,
                 tfs=None,
                 coords=coordinates.identity_grid,
                 net=networks.resnet10,
                 equivariant=True,
                 tf_opts=None,
                 net_opts=None,
                 seed=None,
                 load_path=None):
        """SVHN model.

        Args:
            tfs: transformer classes passed to ``Model``. None (the default)
                means ``[transformers.Translation,
                transformers.RotationScale, transformers.ScaleX]``. A new
                list is built on each call.
            coords: forwarded unchanged to ``Model.__init__``.
            net: classification network constructor, forwarded unchanged.
            equivariant: forwarded unchanged to ``Model.__init__``.
            tf_opts: overrides merged on top of ``tf_default_opts``.
                None means no overrides.
            net_opts: overrides merged on top of ``net_default_opts``.
                None means no overrides.
            seed: forwarded unchanged to ``Model.__init__``.
            load_path: forwarded unchanged to ``Model.__init__``.
        """
        if tfs is None:
            # Build a fresh list per call. A list literal used as a default
            # would be created once and shared by every instance.
            tfs = [transformers.Translation,
                   transformers.RotationScale,
                   transformers.ScaleX]
        super().__init__(tfs=tfs, coords=coords, net=net,
                         equivariant=equivariant,
                         tf_opts=self._merged(self.tf_default_opts, tf_opts),
                         net_opts=self._merged(self.net_default_opts,
                                               net_opts),
                         seed=seed, load_path=load_path)

    def __str__(self):
        return "Street View House Numbers classification (single-digit)"

    def _load_dataset(self, path, num_examples=None):
        """Build a ``Dataset`` from *path*, normalized with this class's mean/std.

        *num_examples* is forwarded unchanged to ``Dataset``. Presumably
        this method is invoked by the ``Model`` base class; confirm there.
        """
        return Dataset(path=path, num_examples=num_examples,
                       normalization=(self.normalization_mean,
                                      self.normalization_std))

    def train(self,
              num_epochs=300,
              batch_size=128,
              optimizer_opts=None,
              lr_schedule=None,
              **kwargs):
        """Train via ``Model.train`` with SVHN-specific defaults.

        Args:
            num_epochs: forwarded unchanged to ``Model.train``.
            batch_size: forwarded unchanged to ``Model.train``.
            optimizer_opts: overrides merged on top of
                ``optimizer_default_opts``. None means no overrides.
            lr_schedule: overrides merged on top of ``lr_default_schedule``.
                None means no overrides.
            **kwargs: any further keyword arguments, forwarded unchanged to
                ``Model.train``.
        """
        # NOTE(review): the 'step_size'/'gamma' keys look like arguments for
        # torch.optim.lr_scheduler.StepLR. Confirm how Model.train uses them.
        super().train(
            num_epochs=num_epochs,
            batch_size=batch_size,
            optimizer_opts=self._merged(self.optimizer_default_opts,
                                        optimizer_opts),
            lr_schedule=self._merged(self.lr_default_schedule, lr_schedule),
            **kwargs)
if __name__ == '__main__':
    # Command-line entry point. Python Fire turns the SVHNModel constructor
    # arguments into CLI flags and its public methods (e.g. `train`) into
    # subcommands.
    fire.Fire(component=SVHNModel)