# create_dataset.py
from random import shuffle
import glob
import sys
import cv2
import numpy as np
#import skimage.io as io
import tensorflow as tf
def _int64_feature(value):
    """Wrap a single integer in a ``tf.train.Feature`` backed by an int64 list.

    Args:
        value: The integer to store; it becomes the only element of the list.

    Returns:
        A ``tf.train.Feature`` whose ``int64_list`` holds ``[value]``.
    """
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)
def _bytes_feature(value):
    """Wrap a single bytes value in a ``tf.train.Feature`` backed by a bytes list.

    Args:
        value: The bytes object to store; it becomes the only element of the
            list.

    Returns:
        A ``tf.train.Feature`` whose ``bytes_list`` holds ``[value]``.
    """
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)
def load_image(addr):
    """Read the image at ``addr``, resize it to 224x224 and convert it to RGB.

    Args:
        addr: Path of the image file, passed straight to ``cv2.imread``.

    Returns:
        The resized image in RGB channel order, or ``None`` when
        ``cv2.imread`` returns ``None`` for this path.
    """
    bgr = cv2.imread(addr)
    if bgr is None:
        return None
    resized = cv2.resize(bgr, (224, 224), interpolation=cv2.INTER_CUBIC)
    # cv2 loads images as BGR, so swap the channel order to RGB.
    return cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
def createDataRecord(out_filename, addrs, labels):
    """Write the images at ``addrs`` and their ``labels`` to a TFRecords file.

    Each image is loaded with ``load_image`` and stored as one
    ``tf.train.Example`` with two features: ``image_raw`` (the raw bytes of
    the image array) and ``label`` (an int64). Images for which
    ``load_image`` returns ``None`` are skipped.

    Args:
        out_filename: Path of the TFRecords file to create.
        addrs: Sequence of image file paths.
        labels: Sequence of integer labels; ``labels[i]`` belongs to
            ``addrs[i]``.
    """
    total = len(addrs)
    # The context manager closes the writer even when an image fails midway,
    # so the file handle is never leaked.
    with tf.python_io.TFRecordWriter(out_filename) as writer:
        for i, addr in enumerate(addrs):
            # Report progress every 1000 images, naming the file being
            # written (this function is used for every split, not only train).
            if i % 1000 == 0:
                print('{}: {}/{}'.format(out_filename, i, total))
                sys.stdout.flush()
            img = load_image(addr)
            label = labels[i]
            if img is None:
                # Unreadable image: skip it rather than abort the whole file.
                continue
            feature = {
                # tobytes() replaces the deprecated ndarray.tostring() alias.
                'image_raw': _bytes_feature(img.tobytes()),
                'label': _int64_feature(label),
            }
            example = tf.train.Example(
                features=tf.train.Features(feature=feature))
            # Serialize the example to a string and append it to the file.
            writer.write(example.SerializeToString())
    sys.stdout.flush()
# Glob pattern for the input images: every .jpg one directory below PetImages/.
cat_dog_train_path = 'PetImages/*/*.jpg'
# Read the image addresses matched by the pattern.
addrs = glob.glob(cat_dog_train_path)
if not addrs:
    # Fail with a clear message; otherwise the unpacking after the shuffle
    # below raises an opaque "not enough values to unpack" ValueError.
    raise ValueError(
        'No images found matching {!r}'.format(cat_dog_train_path))
# NOTE(review): the label comes from a substring match on the whole path, so
# every path that does not contain 'Cat' is labelled 1.
labels = [0 if 'Cat' in addr else 1 for addr in addrs]  # 0 = Cat, 1 = Dog
# Shuffle addresses and labels together so each pair stays aligned.
c = list(zip(addrs, labels))
shuffle(c)
addrs, labels = zip(*c)
# Divide the data into 60% train, 20% validation, and 20% test. The
# boundaries are computed once so addresses and labels are always cut at the
# same indices.
train_end = int(0.6 * len(addrs))
val_end = int(0.8 * len(addrs))
train_addrs = addrs[:train_end]
train_labels = labels[:train_end]
val_addrs = addrs[train_end:val_end]
val_labels = labels[train_end:val_end]
test_addrs = addrs[val_end:]
test_labels = labels[val_end:]
createDataRecord('train.tfrecords', train_addrs, train_labels)
createDataRecord('val.tfrecords', val_addrs, val_labels)
createDataRecord('test.tfrecords', test_addrs, test_labels)