-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathutils.py
62 lines (52 loc) · 2.05 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import torch
from tqdm import tqdm
import config
import numpy as np
def intersection_over_union(pred, true, empty_iou=1.0):
    """
    Compute the binary intersection-over-union (IoU) for a batch of masks.

    Pixels whose ground-truth label equals 255 are treated as "ignore" and are
    excluded from both the intersection and the union. On both sides, any
    non-zero value counts as the positive class (binary IoU).

    Args:
        pred (torch.Tensor): predicted labels; must be indexable by a mask
            built from ``true`` (assumed same shape as ``true`` — confirm).
        true (torch.Tensor): ground-truth labels; 255 marks ignored pixels.
        empty_iou (float): value returned when the union is empty (no
            positive pixel in either tensor after masking). Without this the
            result would be 0/0 = NaN.
    Returns:
        torch.Tensor: 0-d float tensor on the CPU holding intersection / union.
    """
    valid_pixel_mask = true.ne(255)  # drop "ignore" pixels from both sides
    true = true.masked_select(valid_pixel_mask).to("cpu")
    pred = pred.masked_select(valid_pixel_mask).to("cpu")

    intersection = torch.logical_and(true, pred).sum()
    union = torch.logical_or(true, pred).sum()

    if int(union) == 0:
        # Nothing positive on either side: IoU is 0/0. Returning NaN would
        # poison any average taken over batches, so use the agreed value.
        return torch.tensor(float(empty_iou))
    return intersection.float() / union.float()
def train_fn(train_loader, model, optimizer, loss_fn, scaler):
    """
    Run one pass of mixed-precision training over ``train_loader``.

    Each batch's "chip" entry is cast to a float tensor and used as the model
    input. Its "label" entry is cast to a long tensor and used as the target.
    The forward pass and loss run under ``torch.cuda.amp.autocast``. The
    backward pass and optimizer step go through ``scaler``. A tqdm progress
    bar shows the latest batch loss.

    Args:
        train_loader: iterable yielding dicts with "chip" and "label" tensors.
        model: network to train; its parameters are updated by ``optimizer``.
        optimizer: optimizer stepped once per batch.
        loss_fn: callable ``loss_fn(predictions, targets)`` returning a loss
            tensor (presumably scalar, since ``.item()`` is called on it).
        scaler: gradient scaler exposing ``scale``/``step``/``update``
            (presumably a ``GradScaler`` — confirm with caller).
    Returns:
        None
    """
    progress = tqdm(train_loader, desc="batches")
    for batch in progress:
        inputs = batch["chip"].type(torch.FloatTensor).to(config.device)
        labels = batch["label"].type(torch.LongTensor).to(config.device)

        # Forward pass in reduced precision.
        with torch.cuda.amp.autocast():
            outputs = model(inputs)
            batch_loss = loss_fn(outputs, labels)
        batch_loss = batch_loss.to(config.device)

        # Scaled backward pass + optimizer step.
        optimizer.zero_grad()
        scaler.scale(batch_loss).backward()
        scaler.step(optimizer)
        scaler.update()

        progress.set_postfix(loss=batch_loss.item())
def val_fn(val_loader, model):
    """
    Evaluate ``model`` on ``val_loader`` and return the mean per-batch IoU.

    The model is switched to eval mode and run under ``torch.no_grad()``. Each
    prediction is the argmax over dim 1 of the model output, compared against
    the batch's "label" with ``intersection_over_union``. The model is always
    put back in train mode afterwards, even if evaluation raises.

    Note: this averages per-batch IoU values, so every batch carries equal
    weight regardless of how many valid pixels it contains. That is not the
    same as one IoU computed over the whole dataset.

    Args:
        val_loader: iterable yielding dicts with "chip" and "label" tensors.
        model: segmentation network; output assumed (batch, classes, ...) —
            confirm.
    Returns:
        The average of the values returned by ``intersection_over_union``.
    Raises:
        ValueError: if ``val_loader`` yields no batches (the mean is undefined).
    """
    progress = tqdm(val_loader)
    iou_total = 0.0
    n_batches = 0
    model.eval()
    try:
        with torch.no_grad():
            for batch in progress:
                input_image = batch["chip"].type(torch.FloatTensor).to(config.device)
                # squeeze() drops size-1 dims; IoU only uses masked_select, so
                # predictions and labels just need matching shapes.
                true_mask = batch["label"].squeeze()
                logits = model(input_image)
                predicted_mask = torch.argmax(logits, dim=1).squeeze()
                batch_iou = intersection_over_union(
                    predicted_mask.detach().to("cpu"), true_mask
                )
                # Running total keeps the per-iteration mean O(1) instead of
                # re-summing every previous batch.
                iou_total += batch_iou
                n_batches += 1
                progress.set_postfix(iou=float(iou_total / n_batches))
    finally:
        # Always restore training mode, even if evaluation raised.
        model.train()
    if n_batches == 0:
        raise ValueError("val_loader yielded no batches; mean IoU is undefined")
    return iou_total / n_batches