-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdataset.py
28 lines (28 loc) · 1.07 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
# import the necessary packages
from torch.utils.data import Dataset
import cv2
class SegmentationDataset(Dataset):
    """Map-style dataset that pairs images with their segmentation masks.

    Each item is loaded from disk with OpenCV: the image as a 3-channel
    color array converted from BGR to RGB, and the mask as a single-channel
    grayscale array. If ``transforms`` is given, it is applied to the image
    and to the mask in two separate calls.

    Args:
        imagePaths: Indexable sequence of image file paths.
        maskPaths: Indexable sequence of mask file paths; ``maskPaths[i]``
            is the mask belonging to ``imagePaths[i]``.
        transforms: Optional callable applied to both the image and the
            mask. Defaults to ``None`` (no transformation).

    Raises:
        ValueError: If ``imagePaths`` and ``maskPaths`` differ in length.
    """

    def __init__(self, imagePaths, maskPaths, transforms=None):
        # Fail fast on mismatched inputs: __len__ is driven by imagePaths
        # alone, so a shorter maskPaths would only surface as an IndexError
        # partway through an epoch, and a longer one would silently drop
        # masks.
        if len(imagePaths) != len(maskPaths):
            raise ValueError(
                f"got {len(imagePaths)} image paths but "
                f"{len(maskPaths)} mask paths"
            )
        # store the image and mask filepaths, and augmentation transforms
        self.imagePaths = imagePaths
        self.maskPaths = maskPaths
        self.transforms = transforms

    def __len__(self):
        """Return the number of image/mask pairs in the dataset."""
        return len(self.imagePaths)

    @staticmethod
    def _read(path, flags):
        """Read ``path`` with ``cv2.imread``, raising instead of returning None."""
        # cv2.imread reports a missing or undecodable file by returning None
        # rather than raising; without this check the failure shows up later
        # as an opaque cvtColor error, or (for masks) as None being passed on.
        array = cv2.imread(path, flags)
        if array is None:
            raise OSError(f"could not read image file: {path}")
        return array

    def __getitem__(self, idx):
        """Load and return the ``(image, mask)`` pair at position ``idx``.

        Raises:
            OSError: If either file cannot be read by OpenCV.
        """
        # load the image (OpenCV decodes color as BGR) and swap to RGB
        image = self._read(self.imagePaths[idx], cv2.IMREAD_COLOR)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # read the associated mask in grayscale mode
        mask = self._read(self.maskPaths[idx], cv2.IMREAD_GRAYSCALE)
        if self.transforms is not None:
            # NOTE(review): image and mask go through two independent calls,
            # so any random augmentation (crop/flip/rotate) would be sampled
            # separately and misalign them. This assumes `transforms` is
            # deterministic — confirm with callers.
            image = self.transforms(image)
            mask = self.transforms(mask)
        return (image, mask)