-
Notifications
You must be signed in to change notification settings - Fork 3
/
distillation_001.py
71 lines (69 loc) · 1.89 KB
/
distillation_001.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from src.utils import get_lr
from src import constants
# Experiment configuration: a handful of shared hyper-parameters followed by
# the ``config`` dict that is built from them.
# NOTE(review): presumably ``config`` is read by the training entry point;
# confirm against the caller.

# Named once here because ``config`` below reads each of them
# (``image_size``, ``batch_size`` and ``base_lr`` are used in several places).
image_size = (64, 64)
batch_size = 32
base_lr = 3e-4
frame_stack_size = 16

config = {
    "image_size": image_size,
    "batch_size": batch_size,
    "base_lr": base_lr,
    # Floor value expressed as 1% of the base learning rate.
    "min_base_lr": base_lr * 0.01,
    "ema_decay": 0.999,
    "train_epoch_size": 72_000,
    # Two parallel two-element lists.
    # NOTE(review): presumably one epoch count per stage name -- confirm.
    "num_epochs": [3, 18],
    "stages": ["warmup", "train"],
    "num_dataloader_workers": 8,
    "init_weights": True,
    "argus_params": {
        # (registered name, keyword parameters) pair.
        "nn_module": (
            "dwiseneuro",
            {
                "readout_outputs": constants.num_neurons,
                "in_channels": 5,
                # ``core_features`` and ``spatial_strides`` both have nine
                # entries; every stride of 2 sits at the position where the
                # feature width starts a new group (64 -> 128 -> 256).
                "core_features": (64, 64, 64, 64, 128, 128, 128, 256, 256),
                "spatial_strides": (2, 1, 1, 1, 2, 1, 1, 2, 1),
                "spatial_kernel": 3,
                "temporal_kernel": 5,
                "expansion_ratio": 6,
                "se_reduce_ratio": 32,
                # NOTE(review): the ``* 2`` factor presumably mirrors
                # ``groups`` below -- confirm before changing either.
                "cortex_features": (512 * 2, 1024 * 2, 2048 * 2),
                "groups": 2,
                "softplus_beta": 0.07,
                "drop_rate": 0.4,
                "drop_path_rate": 0.1,
            },
        ),
        "loss": (
            "mice_poisson",
            {
                "log_input": False,
                "full": False,
                "eps": 1e-8,
            },
        ),
        "optimizer": (
            "AdamW",
            {
                # Learning rate is derived from the base rate and batch size
                # by the project helper ``get_lr``.
                "lr": get_lr(base_lr, batch_size),
                "weight_decay": 0.05,
            },
        ),
        "device": "cuda:0",
        "frame_stack": {
            "size": frame_stack_size,
            "step": 2,
            "position": "last",
        },
        "inputs_processor": (
            "stack_inputs",
            {
                "size": image_size,
                "pad_fill_value": 0.0,
            },
        ),
        "responses_processor": ("identity", {}),
        "amp": True,
        "iter_size": 1,
    },
    "cutmix": {
        "alpha": 1.0,
        "prob": 0.5,
    },
    # NOTE(review): ``experiment`` presumably names the run used as the
    # distillation source -- confirm against the consumer.
    "distill": {
        "experiment": "true_batch_001",
        "ratio": 0.36,
    },
}