# # Prototype Networks
# Implementation based on [orobix/Prototypical-Networks-for-Few-shot-Learning-PyTorch](https://github.com/orobix/Prototypical-Networks-for-Few-shot-Learning-PyTorch).
# Dependencies: avalanche-lib==0.3.1
# (install with `pip install avalanche-lib==0.3.1`)
import torch
from torch.nn import functional as F
from torch.nn.modules import Module
from types import SimpleNamespace
from utils import load_omniglot_data, init_seed
import os
import numpy as np
from tqdm import tqdm
import torch.nn as nn
class PrototypicalBatchSampler(object):
    '''
    PrototypicalBatchSampler: yields a batch of indexes at each iteration.
    Indexes are computed taking into account 'classes_per_it' and 'num_samples':
    at every iteration the batch indexes refer to 'num_support' + 'num_query' samples
    for 'classes_per_it' random classes.
    __len__ returns the number of episodes per epoch (same as 'self.iterations').
    '''
    def __init__(self, labels, classes_per_it, num_samples, iterations):
        '''
        Initialize the PrototypicalBatchSampler object
        Args:
        - labels: an iterable containing all the labels for the current dataset;
          sample indexes will be inferred from this iterable.
        - classes_per_it: number of random classes for each iteration
        - num_samples: number of samples for each iteration for each class (support + query)
        - iterations: number of iterations (episodes) per epoch
        '''
        super(PrototypicalBatchSampler, self).__init__()
        self.labels = labels
        self.classes_per_it = classes_per_it
        self.sample_per_class = num_samples
        self.iterations = iterations

        self.classes, self.counts = np.unique(self.labels, return_counts=True)
        self.classes = torch.LongTensor(self.classes)

        # create a matrix, indexes, of dim: classes X max(elements per class)
        # fill it with nans
        # for every class c, fill the relative row with the indices of samples belonging to c
        # in numel_per_class we store the number of samples for each class/row
        self.idxs = range(len(self.labels))
        self.indexes = np.empty((len(self.classes), max(self.counts)), dtype=int) * np.nan
        self.indexes = torch.Tensor(self.indexes)
        self.numel_per_class = torch.zeros_like(self.classes)
        for idx, label in enumerate(self.labels):
            label_idx = np.argwhere(self.classes == label).item()
            self.indexes[label_idx, np.where(np.isnan(self.indexes[label_idx]))[0][0]] = idx
            self.numel_per_class[label_idx] += 1

    def __iter__(self):
        '''
        yield a batch of indexes
        '''
        spc = self.sample_per_class
        cpi = self.classes_per_it

        for it in range(self.iterations):
            batch_size = spc * cpi
            batch = torch.LongTensor(batch_size)
            # NOTE: pick classes randomly
            c_idxs = torch.randperm(len(self.classes))[:cpi]
            # NOTE: for each class, pick a limited set of samples randomly
            for i, c in enumerate(self.classes[c_idxs]):
                s = slice(i * spc, (i + 1) * spc)
                # FIXME: replace with torch.argwhere once it exists
                label_idx = torch.arange(len(self.classes)).long()[self.classes == c].item()
                sample_idxs = torch.randperm(self.numel_per_class[label_idx])[:spc]
                batch[s] = self.indexes[label_idx][sample_idxs]
            batch = batch[torch.randperm(len(batch))]
            yield batch

    def __len__(self):
        '''
        returns the number of iterations (episodes) per epoch
        '''
        return self.iterations
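# Minimal usage sketch (illustrative helper, not called anywhere): the sampler is
# handed to a DataLoader as `batch_sampler` (see init_dataloader below), so every
# batch it yields is one episode of classes_per_it * num_samples dataset indices.
# The toy labels here are made up purely to show the expected shape.
def _sampler_usage_example():
    toy_labels = [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]
    sampler = PrototypicalBatchSampler(labels=toy_labels,
                                       classes_per_it=2,
                                       num_samples=3,  # support + query per class
                                       iterations=4)
    for episode in sampler:
        # each episode holds 2 * 3 = 6 indices into the dataset
        assert episode.shape == (6,)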
# source: https://github.com/orobix/Prototypical-Networks-for-Few-shot-Learning-PyTorch
def conv_block(in_channels, out_channels):
    '''
    returns a block conv-bn-relu-pool
    '''
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
        nn.MaxPool2d(2)
    )
class ProtoNet(nn.Module):
    '''
    Model as described in the reference paper,
    source: https://github.com/jakesnell/prototypical-networks/blob/f0c48808e496989d01db59f86d4449d7aee9ab0c/protonets/models/few_shot.py#L62-L84
    '''
    def __init__(self, x_dim=1, hid_dim=64, z_dim=64):
        super(ProtoNet, self).__init__()
        self.encoder = nn.Sequential(
            conv_block(x_dim, hid_dim),
            conv_block(hid_dim, hid_dim),
            conv_block(hid_dim, hid_dim),
            conv_block(hid_dim, z_dim),
        )

    def forward(self, x):
        x = self.encoder(x)
        return x.view(x.size(0), -1)
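# Illustrative shape check (assumes 1x28x28 inputs, the resize used by the reference
# prototypical-networks implementation; the loader in utils may differ): each of the
# four conv blocks halves the spatial size (28 -> 14 -> 7 -> 3 -> 1), so the flattened
# embedding has 64 * 1 * 1 = 64 dimensions.
def _protonet_shape_example():
    net = ProtoNet()
    dummy = torch.zeros(2, 1, 28, 28)  # fake batch of two Omniglot-sized images
    assert net(dummy).shape == (2, 64)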
def euclidean_dist(x, y):
    '''
    Compute the matrix of squared Euclidean distances between the rows
    of x (N x D) and the rows of y (M x D); the result has shape N x M.
    '''
    n = x.size(0)
    m = y.size(0)
    d = x.size(1)
    if d != y.size(1):
        raise ValueError('x and y must have the same feature dimension')

    x = x.unsqueeze(1).expand(n, m, d)
    y = y.unsqueeze(0).expand(n, m, d)

    return torch.pow(x - y, 2).sum(2)
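# Tiny sanity check (illustrative helper): the broadcasting above reproduces the
# squared pairwise distances, i.e. torch.cdist(x, y) ** 2.
def _euclidean_dist_example():
    x = torch.tensor([[0., 0.], [1., 1.]])
    y = torch.tensor([[0., 1.], [2., 2.]])
    d = euclidean_dist(x, y)  # shape (2, 2): [[1., 8.], [1., 2.]]
    assert torch.allclose(d, torch.cdist(x, y) ** 2)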
def prototypical_loss(input, target, n_support):
    '''
    Inspired by https://github.com/jakesnell/prototypical-networks/blob/master/protonets/models/few_shot.py
    Compute the barycentres (prototypes) by averaging the features of the first
    n_support samples of each class in target, then compute the distance from each
    query sample's features to each barycentre and the log-probability of each of
    the n_query samples belonging to each of the current classes; loss and accuracy
    are then computed and returned.
    Args:
    - input: the model output for a batch of samples
    - target: ground truth for the above batch of samples
    - n_support: number of samples per class to take into account when computing
      the barycentres
    '''
    target_cpu = target.to('cpu')
    input_cpu = input.to('cpu')

    def supp_idxs(c):
        # FIXME: switch to torch.where once it supports the numpy-style usage
        return target_cpu.eq(c).nonzero()[:n_support].squeeze(1)

    # FIXME: use torch.unique directly on cuda once it is supported there
    classes = torch.unique(target_cpu)
    n_classes = len(classes)
    # FIXME: switch to torch.where once it supports the numpy-style usage
    # assuming n_query, n_target constants
    n_query = target_cpu.eq(classes[0].item()).sum().item() - n_support

    # NOTE: select support samples and split them by classes
    support_idxs = list(map(supp_idxs, classes))
    # NOTE: compute prototypes by averaging embeddings for each class
    prototypes = torch.stack([input_cpu[idx_list].mean(0) for idx_list in support_idxs])
    # NOTE: select query samples
    # FIXME: switch to torch.where once it supports the numpy-style usage
    query_idxs = torch.stack(list(map(lambda c: target_cpu.eq(c).nonzero()[n_support:], classes))).view(-1)
    query_samples = input.to('cpu')[query_idxs]
    # NOTE: compute distances between query samples and prototypes
    dists = euclidean_dist(query_samples, prototypes)
    # NOTE: compute log-probabilities from the negative distances
    log_p_y = F.log_softmax(-dists, dim=1).view(n_classes, n_query, -1)
    # NOTE: once you have the log-probabilities, you can compute the cross-entropy loss and accuracy as usual
    target_inds = torch.arange(0, n_classes)
    target_inds = target_inds.view(n_classes, 1, 1)
    target_inds = target_inds.expand(n_classes, n_query, 1).long()
    loss_val = -log_p_y.gather(2, target_inds).squeeze().view(-1).mean()
    _, y_hat = log_p_y.max(2)
    acc_val = y_hat.eq(target_inds.squeeze(2)).float().mean()

    return loss_val, acc_val
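# Minimal sketch (illustrative helper with synthetic data): the loss only needs one
# episode's embeddings and labels; the support/query split is done internally via
# n_support. Three classes with 2 support + 3 query samples each should yield a
# scalar loss and an accuracy in [0, 1].
def _prototypical_loss_example():
    n_classes, n_support, n_query, dim = 3, 2, 3, 64
    labels = torch.arange(n_classes).repeat_interleave(n_support + n_query)
    embeddings = torch.randn(n_classes * (n_support + n_query), dim)
    loss, acc = prototypical_loss(embeddings, target=labels, n_support=n_support)
    assert loss.dim() == 0 and 0.0 <= acc.item() <= 1.0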
def init_sampler(opt, labels, mode):
    if 'train' in mode:
        classes_per_it = opt.classes_per_it_tr
        num_samples = opt.num_support_tr + opt.num_query_tr
    else:
        classes_per_it = opt.classes_per_it_val
        num_samples = opt.num_support_val + opt.num_query_val

    # NOTE: num_samples combines the support and query set sizes;
    # we split them inside the prototypical loss
    return PrototypicalBatchSampler(labels=labels,
                                    classes_per_it=classes_per_it,
                                    num_samples=num_samples,
                                    iterations=opt.iterations)
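# Illustrative helper (not called anywhere): spells out the episode size implied by
# the sampler settings; with the options defined in __main__ below this is
# 60 * (5 + 5) = 600 samples per training episode and 5 * (5 + 15) = 100 per
# validation episode.
def _episode_size_example(opt):
    train_episode = opt.classes_per_it_tr * (opt.num_support_tr + opt.num_query_tr)
    val_episode = opt.classes_per_it_val * (opt.num_support_val + opt.num_query_val)
    return train_episode, val_episode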
def init_dataloader(opt, mode):
    dataset = load_omniglot_data(opt, mode)
    sampler = init_sampler(opt, dataset.y, mode)
    dataloader = torch.utils.data.DataLoader(dataset, batch_sampler=sampler)
    return dataloader

def init_protonet(opt):
    '''
    Initialize the ProtoNet
    '''
    device = 'cuda:0' if torch.cuda.is_available() and opt.cuda else 'cpu'
    model = ProtoNet().to(device)
    return model

def init_optim(opt, model):
    '''
    Initialize optimizer
    '''
    return torch.optim.Adam(params=model.parameters(),
                            lr=opt.learning_rate)

def init_lr_scheduler(opt, optim):
    '''
    Initialize the learning rate scheduler
    '''
    return torch.optim.lr_scheduler.StepLR(optimizer=optim,
                                           gamma=opt.lr_scheduler_gamma,
                                           step_size=opt.lr_scheduler_step)

def save_list_to_file(path, thelist):
    with open(path, 'w') as f:
        for item in thelist:
            f.write("%s\n" % item)
def train(opt, tr_dataloader, model, optim, lr_scheduler, val_dataloader=None):
    '''
    Train the model with the prototypical learning algorithm
    '''
    device = 'cuda:0' if torch.cuda.is_available() and opt.cuda else 'cpu'

    if val_dataloader is None:
        best_state = None
    train_loss = []
    train_acc = []
    val_loss = []
    val_acc = []
    best_acc = 0

    best_model_path = os.path.join(opt.experiment_root, 'best_model.pth')
    last_model_path = os.path.join(opt.experiment_root, 'last_model.pth')

    for epoch in range(opt.epochs):
        print('=== Epoch: {} ==='.format(epoch))
        tr_iter = iter(tr_dataloader)
        model.train()
        for batch in tqdm(tr_iter):
            optim.zero_grad()
            x, y = batch
            x, y = x.to(device), y.to(device)
            model_output = model(x)
            loss, acc = prototypical_loss(model_output, target=y,
                                          n_support=opt.num_support_tr)
            loss.backward()
            optim.step()
            train_loss.append(loss.item())
            train_acc.append(acc.item())
        avg_loss = np.mean(train_loss[-opt.iterations:])
        avg_acc = np.mean(train_acc[-opt.iterations:])
        print('Avg Train Loss: {}, Avg Train Acc: {}'.format(avg_loss, avg_acc))
        lr_scheduler.step()
        if val_dataloader is None:
            continue
        val_iter = iter(val_dataloader)
        model.eval()
        for batch in val_iter:
            x, y = batch
            x, y = x.to(device), y.to(device)
            model_output = model(x)
            loss, acc = prototypical_loss(model_output, target=y,
                                          n_support=opt.num_support_val)
            val_loss.append(loss.item())
            val_acc.append(acc.item())
        avg_loss = np.mean(val_loss[-opt.iterations:])
        avg_acc = np.mean(val_acc[-opt.iterations:])
        postfix = ' (Best)' if avg_acc >= best_acc else ' (Best: {})'.format(
            best_acc)
        print('Avg Val Loss: {}, Avg Val Acc: {}{}'.format(
            avg_loss, avg_acc, postfix))
        if avg_acc >= best_acc:
            torch.save(model.state_dict(), best_model_path)
            best_acc = avg_acc
            best_state = model.state_dict()

    torch.save(model.state_dict(), last_model_path)

    for name in ['train_loss', 'train_acc', 'val_loss', 'val_acc']:
        save_list_to_file(os.path.join(opt.experiment_root,
                                       name + '.txt'), locals()[name])

    return best_state, best_acc, train_loss, train_acc, val_loss, val_acc
def test(opt, test_dataloader, model):
    '''
    Test the model trained with the prototypical learning algorithm
    '''
    device = 'cuda:0' if torch.cuda.is_available() and opt.cuda else 'cpu'
    avg_acc = list()
    for epoch in range(10):
        test_iter = iter(test_dataloader)
        for batch in test_iter:
            x, y = batch
            x, y = x.to(device), y.to(device)
            model_output = model(x)
            _, acc = prototypical_loss(model_output, target=y,
                                       n_support=opt.num_support_val)
            avg_acc.append(acc.item())
    avg_acc = np.mean(avg_acc)
    print('Test Acc: {}'.format(avg_acc))

    return avg_acc
def eval(opt):
    '''
    Initialize everything and evaluate the best saved model
    '''
    if torch.cuda.is_available() and not opt.cuda:
        print("WARNING: You have a CUDA device, so you should probably run with cuda enabled")
    init_seed(opt)
    test_dataloader = init_dataloader(opt, 'test')
    model = init_protonet(opt)
    model_path = os.path.join(opt.experiment_root, 'best_model.pth')
    model.load_state_dict(torch.load(model_path))
    test(opt=opt,
         test_dataloader=test_dataloader,
         model=model)
if __name__ == '__main__':
    # Training Configuration
    # Don't forget to increase epochs and iterations and to set `cuda` if you want to get the best performance.
    options = SimpleNamespace(
        # folders
        dataset_root='.' + os.sep + 'dataset',    # path to dataset
        experiment_root='.' + os.sep + 'output',  # root where to store models, losses and accuracies
        # training hparams
        epochs=1,                # number of epochs to train for, default=100
        learning_rate=0.001,     # learning rate for the model
        lr_scheduler_step=20,    # StepLR learning rate scheduler step
        lr_scheduler_gamma=0.5,  # StepLR learning rate scheduler gamma
        iterations=10,           # number of episodes per epoch, default=100
        classes_per_it_tr=60,    # number of random classes per episode for training
        manual_seed=7,           # input for the manual seeds initializations
        cuda=False,
        # task hparams
        num_support_tr=5,        # number of samples per class to use as support for training
        num_query_tr=5,          # number of samples per class to use as query for training
        classes_per_it_val=5,    # number of random classes per episode for validation
        num_support_val=5,       # number of samples per class to use as support for validation
        num_query_val=15         # number of samples per class to use as query for validation
    )

    # Omniglot Data
    # notice that train and test splits have different classes!
    train_data = load_omniglot_data(options, 'train')
    test_data = load_omniglot_data(options, 'test')
    # print some classes
    # list(filter(lambda s: s.startswith('T'), train_data.idx_classes.keys()))
    # list(filter(lambda s: s.startswith('T'), test_data.idx_classes.keys()))

    os.makedirs(options.experiment_root, exist_ok=True)
    init_seed(options)
    tr_dataloader = init_dataloader(options, 'train')
    val_dataloader = init_dataloader(options, 'val')
    test_dataloader = init_dataloader(options, 'test')
    model = init_protonet(options)
    optim = init_optim(options, model)
    lr_scheduler = init_lr_scheduler(options, optim)

    # meta-train
    res = train(opt=options,
                tr_dataloader=tr_dataloader,
                val_dataloader=val_dataloader,
                model=model,
                optim=optim,
                lr_scheduler=lr_scheduler)
    best_state, best_acc, train_loss, train_acc, val_loss, val_acc = res

    # meta-test
    print('Testing with last model..')
    test(opt=options,
         test_dataloader=test_dataloader,
         model=model)

    model.load_state_dict(best_state)
    print('Testing with best model..')
    test(opt=options,
         test_dataloader=test_dataloader,
         model=model)
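    # Illustrative follow-up: since train() stored 'best_model.pth' under
    # options.experiment_root, the same configuration can be re-evaluated later
    # without retraining, e.g.:
    #
    #     eval(options)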