-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathda_datasets.py
173 lines (145 loc) · 5.8 KB
/
da_datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
import glob
import random
import os
import sys
import numpy as np
from PIL import Image
import torch
import torch.nn.functional as F
from torch.distributions import Beta
from utils.augmentations import *
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from torchvision.utils import save_image
def pad_to_square(img, pad_value):
    """Pad the shorter spatial side of ``img`` so that height equals width.

    The padding is split as evenly as possible between the two sides; when
    the difference is odd the extra pixel goes to the lower / right side.

    Args:
        img: Tensor whose last two dimensions are (height, width), e.g.
            (C, H, W) or (N, C, H, W).
        pad_value: Constant used to fill the padded area.

    Returns:
        A tuple ``(padded_img, pad)`` where ``pad`` is the 4-tuple handed to
        ``F.pad``: (left, right, top, bottom).
    """
    # Only the trailing two dims matter, so batched tensors are accepted too.
    h, w = img.shape[-2:]
    dim_diff = abs(h - w)
    # (upper / left) padding and (lower / right) padding
    pad1 = dim_diff // 2
    pad2 = dim_diff - pad1
    # F.pad orders the tuple from the last dimension backwards:
    # (left, right, top, bottom). Pad the height when the image is wide,
    # the width when it is tall.
    pad = (0, 0, pad1, pad2) if h <= w else (pad1, pad2, 0, 0)
    img = F.pad(img, pad, "constant", value=pad_value)
    return img, pad
def resize(image, size):
    """Rescale one unbatched image tensor with nearest-neighbour sampling.

    Args:
        image: Image tensor without a batch dimension (callers in this file
            pass (C, H, W) tensors).
        size: Target spatial size, forwarded to ``F.interpolate``.

    Returns:
        The resized tensor, again without a batch dimension.
    """
    # F.interpolate works on batched input: add a leading batch axis of
    # length one, interpolate, then take that single element back out.
    batched = image[None]
    scaled = F.interpolate(batched, size=size, mode="nearest")
    return scaled[0]
def random_resize(images, min_size=288, max_size=448):
    """Resize a batch of images to a randomly chosen square size.

    The size is drawn uniformly from ``min_size, min_size + 32, ...`` up to
    and including ``max_size`` (when reachable in steps of 32).

    Args:
        images: Batched image tensor, as accepted by ``F.interpolate``.
        min_size: Smallest candidate size.
        max_size: Upper bound for the candidate sizes.

    Returns:
        The batch resized with nearest-neighbour interpolation.

    Raises:
        ValueError: If ``min_size`` is greater than ``max_size`` so that
            there is no candidate size to pick from.
    """
    candidate_sizes = range(min_size, max_size + 1, 32)
    if not candidate_sizes:
        raise ValueError(
            "no candidate size between min_size=%r and max_size=%r"
            % (min_size, max_size)
        )
    new_size = random.choice(candidate_sizes)
    return F.interpolate(images, size=new_size, mode="nearest")
class ImageFolder(Dataset):
    """Dataset over every file matching ``<folder_path>/*.*``.

    Each item is ``(img_path, img)`` where ``img`` is the image as a tensor,
    zero-padded to a square and resized to ``img_size``.
    """

    def __init__(self, folder_path, img_size=416):
        """Collect the files in ``folder_path`` in sorted order.

        Args:
            folder_path: Directory to scan; only names containing a dot
                match the ``*.*`` pattern.
            img_size: Size passed to ``resize`` for every image.
        """
        self.files = sorted(glob.glob("%s/*.*" % folder_path))
        self.img_size = img_size

    def __getitem__(self, index):
        """Return ``(img_path, img)`` for ``index`` (wrapped modulo the file count)."""
        img_path = self.files[index % len(self.files)]
        # Extract image as PyTorch tensor. Converting to RGB guarantees three
        # channels for grayscale / RGBA / palette files (ListDataset does the
        # same); the context manager closes the file once the pixels are read.
        with Image.open(img_path) as pil_img:
            img = transforms.ToTensor()(pil_img.convert('RGB'))
        # Pad to square resolution
        img, _ = pad_to_square(img, 0)
        # Resize
        img = resize(img, self.img_size)
        return img_path, img

    def __len__(self):
        """Number of files found in the folder."""
        return len(self.files)
class ListDataset(Dataset):
    """Detection dataset driven by a text file with one image path per line.

    The label file of an image is found by replacing ``images`` with
    ``labels`` in its path and the ``.png`` / ``.jpg`` extension with
    ``.txt``. With ``augment`` enabled a sample may be swapped for the
    same-named file in one of the ``./data/URPC2019/type<k>`` directories;
    a one-hot domain label reports which variant was loaded, index 0 being
    the original image.
    """

    # Directory prefix of the alternative-domain copies; the domain index
    # (1..NUM_AUG_DOMAINS) is appended to it.
    AUG_PATH_PREFIX = './data/URPC2019/type'
    # Number of alternative domains; the domain label has one extra slot
    # (index 0) for the original image.
    NUM_AUG_DOMAINS = 7
    # Probability of trying an alternative-domain copy when augmenting.
    AUG_PROB = 0.875

    def __init__(self, list_path, img_size=416, augment=True, multiscale=True, normalized_labels=True):
        """Read the image list and derive the label paths.

        Args:
            list_path: Text file listing one image path per line.
            img_size: Size every image is resized to.
            augment: Whether alternative-domain copies may be loaded.
            multiscale: Whether ``collate_fn`` re-draws ``img_size``
                every tenth batch.
            normalized_labels: If True, label coordinates are multiplied by
                the image height / width before the padding adjustment.
        """
        with open(list_path, "r") as file:
            # Entries keep their trailing newline (stripped on access);
            # blank lines are skipped so they do not become bogus samples.
            self.img_files = [line for line in file if line.strip()]

        self.label_files = [
            path.replace("images", "labels").replace(".png", ".txt").replace(".jpg", ".txt")
            for path in self.img_files
        ]
        self.img_size = img_size
        self.max_objects = 100
        self.augment = augment
        self.multiscale = multiscale
        self.normalized_labels = normalized_labels
        # Multiscale range: three steps of 32 either side of img_size.
        self.min_size = self.img_size - 3 * 32
        self.max_size = self.img_size + 3 * 32
        self.batch_count = 0

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            ``(img_path, img, targets, domain_label)``. ``targets`` is
            ``None`` when no label file exists, otherwise a tensor with one
            row per box laid out as (sample index placeholder, class, x, y,
            w, h). ``domain_label`` is a one-hot vector of length
            ``NUM_AUG_DOMAINS + 1``.
        """
        slot = index % len(self.img_files)

        # ---------
        #  Image
        # ---------
        img_path = self.img_files[slot].rstrip()

        # The candidate domain is drawn for every sample, before the
        # augmentation coin flip.
        candidate = random.randint(1, self.NUM_AUG_DOMAINS)
        # Domain 0 (original image) unless an alternative copy is loaded.
        domain_idx = 0
        if self.augment and np.random.random() < self.AUG_PROB:
            img_name = img_path.split('/')[-1]
            aug_file = os.path.join(self.AUG_PATH_PREFIX + str(candidate), img_name)
            if os.path.exists(aug_file):
                img_path = aug_file
                domain_idx = candidate

        domain_label = torch.zeros(self.NUM_AUG_DOMAINS + 1)
        domain_label[domain_idx] = 1

        # Extract image as PyTorch tensor; convert('RGB') guarantees a
        # three-channel (3, H, W) result, and the context manager closes the
        # file once the pixels have been read.
        with Image.open(img_path) as pil_img:
            img = transforms.ToTensor()(pil_img.convert('RGB'))
        # NOTE(review): with an int img_size this already yields a square
        # image, so the padding computed below is normally all zeros.
        img = resize(img, self.img_size)

        _, h, w = img.shape
        h_factor, w_factor = (h, w) if self.normalized_labels else (1, 1)
        # Pad to square resolution
        img, pad = pad_to_square(img, 0)
        _, padded_h, padded_w = img.shape

        # ---------
        #  Label
        # ---------
        label_path = self.label_files[slot].rstrip()

        targets = None
        if os.path.exists(label_path):
            boxes = torch.from_numpy(np.loadtxt(label_path).reshape(-1, 5))
            # Shift the class column down by one (presumably the label
            # files use 1-based class ids -- confirm against the data).
            boxes[:, 0] -= 1
            # Extract coordinates for unpadded + unscaled image
            x1 = w_factor * (boxes[:, 1] - boxes[:, 3] / 2)
            y1 = h_factor * (boxes[:, 2] - boxes[:, 4] / 2)
            x2 = w_factor * (boxes[:, 1] + boxes[:, 3] / 2)
            y2 = h_factor * (boxes[:, 2] + boxes[:, 4] / 2)
            # Adjust for added padding; pad is (left, right, top, bottom)
            x1 += pad[0]
            y1 += pad[2]
            x2 += pad[1]
            y2 += pad[3]
            # Back to (x, y, w, h), relative to the padded image
            boxes[:, 1] = ((x1 + x2) / 2) / padded_w
            boxes[:, 2] = ((y1 + y2) / 2) / padded_h
            boxes[:, 3] *= w_factor / padded_w
            boxes[:, 4] *= h_factor / padded_h

            # Column 0 is left at zero here; collate_fn fills in the
            # sample's position within the batch.
            targets = torch.zeros((len(boxes), 6))
            targets[:, 1:] = boxes

        # Horizontal-flip augmentation (horisontal_flip) is currently
        # disabled for this dataset.

        return img_path, img, targets, domain_label

    def collate_fn(self, batch):
        """Merge samples from ``__getitem__`` into one batch.

        Returns:
            ``(paths, imgs, targets, domain_labels)``: the tuple of image
            paths, the stacked images, all boxes concatenated into one
            tensor whose column 0 is the index of the owning image within
            ``imgs`` (shape (0, 6) when no sample has labels), and the
            stacked domain labels.
        """
        paths, imgs, targets, domain_labels = zip(*batch)

        # Tag every box with the batch position of its image. The index is
        # taken BEFORE unlabeled samples are dropped so that it keeps
        # pointing at the right row of ``imgs``.
        labelled = []
        for sample_idx, boxes in enumerate(targets):
            if boxes is None:
                continue
            boxes[:, 0] = sample_idx
            labelled.append(boxes)
        # torch.cat rejects an empty list, so return an empty target tensor
        # when nothing in the batch is labeled.
        targets = torch.cat(labelled, 0) if labelled else torch.zeros((0, 6))

        # Selects new image size every tenth batch
        if self.multiscale and self.batch_count % 10 == 0:
            self.img_size = random.choice(range(self.min_size, self.max_size + 1, 32))

        # Resize images to input shape. F.interpolate keeps the batch axis
        # added by unsqueeze, so the results can be concatenated directly.
        imgs = torch.cat([
            F.interpolate(img.unsqueeze(0), size=self.img_size, mode="nearest")
            for img in imgs
        ])
        domain_labels = torch.stack(domain_labels)
        self.batch_count += 1
        return paths, imgs, targets, domain_labels

    def __len__(self):
        """Number of listed images."""
        return len(self.img_files)