-
Notifications
You must be signed in to change notification settings - Fork 1
/
export.py
146 lines (128 loc) · 5.53 KB
/
export.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
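'''
Export the given datasets to image files and filename lists.

Example usage:

    python examples/fashionpose/export.py \
        --inputs data/fashionista-v1.0/test-1.h5 \
        --output public/fashionpose/fashionista-v1.0

    python examples/fashionpose/export.py \
        --inputs public/fashionpose/poseseg-32s-fashionista-v1.0-test.h5 \
        --output public/fashionpose/poseseg-32s-fashionista-v1.0-test

NOTE: --inputs is currently parsed but not used; the HDF5 files that are
actually read come from the module-level `files` list.
'''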
import argparse
import h5py
import logging
import numpy as np
import os
import PIL.Image
# HDF5 files read by export() and createListTxt(). This list is used instead
# of the --inputs command-line option. Lists used for the other splits:
#   files = ['TMM_test.h5']
#   files = ['TMM_val.h5']
files = [
    'TMM_train-2.h5',
    'TMM_train-3.h5',
    'TMM_train-4.h5',
    'TMM_train-5.h5',
]
def hdf5_reader(filename, fieldnames=None):
    '''
    Iterate over the records stored in an HDF5 file.

    The file is opened read-only and stays open while the generator is being
    consumed. Each yielded record is a dict mapping a dataset name to row `i`
    of that dataset (`dataset[i, :]`).

    filename:   path of the HDF5 file to read.
    fieldnames: names of the datasets to read; when empty or None, every
                top-level key of the file is used.

    The number of records is taken from the first axis of the first selected
    dataset; the other datasets are assumed to be at least that long
    (NOTE(review): not checked here - confirm against the data files).
    '''
    with h5py.File(filename, 'r') as h5_file:
        names = fieldnames if fieldnames else h5_file.keys()
        logging.info("Opening {0} for {1}".format(
            filename, ", ".join(names)))

        datasets = {}
        for name in names:
            datasets[name] = h5_file[name]

        # The first dataset decides how many records there are.
        num_records = list(datasets.values())[0].shape[0]
        for row in range(num_records):
            record = {}
            for name, dataset in datasets.items():
                record[name] = dataset[row, :]
            yield record
def progress(iterable, num):
    '''
    Pass items through unchanged while logging how many have been consumed.

    Yields every item of `iterable` in order. After each `num`-th item has
    been handed out and the consumer asks for more, the running count is
    logged at INFO level; the final count is logged once more when the
    iterable is exhausted (0 for an empty iterable).
    '''
    count = 0
    for count, item in enumerate(iterable, start=1):
        yield item
        if count % num == 0:
            logging.info("{0}".format(count))
    # Final total, logged even when it was already reported above.
    logging.info("{0}".format(count))
def createListTxt(options):
    '''
    Create text files containing the exported filenames.

    Three files are (re)written inside `options.output`:

      img_list.txt   one "<id>.jpg" line per record that has an 'image' field
      gt_list.txt    one "<id>.png" line per record that has a 'segmentation'
                     field
      mask_list.txt  one "<id>.png" line per record for each of the
                     'seg_prob' and 'score' fields it has (a record with
                     both is therefore listed twice)

    The records come from the module-level `files` list, not from
    `options.inputs`. `options.output` must already exist as a directory.

    Raises AssertionError if one of the input files does not exist.
    '''
    logging.info("createListTxt")
    img_list_path = os.path.join(options.output, 'img_list.txt')
    gt_list_path = os.path.join(options.output, 'gt_list.txt')
    mask_list_path = os.path.join(options.output, 'mask_list.txt')

    # Context managers guarantee the list files are flushed and closed, even
    # if reading one of the inputs fails part-way through.
    with open(img_list_path, 'w') as fimg_list, \
            open(gt_list_path, 'w') as fgt_list, \
            open(mask_list_path, 'w') as fmask_list:
        # (record field, list file to append to, extension of the image file)
        targets = (
            ('image', fimg_list, 'jpg'),
            ('segmentation', fgt_list, 'png'),
            ('seg_prob', fmask_list, 'png'),
            ('score', fmask_list, 'png'),
        )
        #for input_file in options.inputs:
        for input_file in files:
            assert os.path.exists(input_file), input_file
            logging.info("Input: {0}".format(input_file))
            for output in progress(hdf5_reader(input_file), 100):
                for field, list_file, extension in targets:
                    if field in output:
                        list_file.write("{0}.{1}\n".format(
                            output['id'][0], extension))
def _save_resized(array, filename, size):
    '''
    Save `array` as an image file resized to `size`.

    The array is wrapped with PIL.Image.fromarray, resized with
    nearest-neighbour resampling and written to `filename`.
    '''
    PIL.Image.fromarray(array).resize(
        size, resample=PIL.Image.NEAREST
    ).save(filename)


def export(options):
    '''
    Export results to image files.

    Creates `options.output` if it does not exist, then walks every record of
    every file in the module-level `files` list (not `options.inputs`) and
    writes, depending on which fields the record has:

      'image'         -> <id>.jpg  `options.image_mean` is added back, the
                                   first axis is reversed and moved last
                                   (presumably BGR channels-first to RGB
                                   channels-last - TODO confirm)
      'segmentation'  -> <id>.png  first row of the field, as uint8
      'seg_prob'      -> <id>.png  argmax over axis 0, as uint8
      'score'         -> <id>.png  argmax over axis 0, as uint8

    Every image is resized to `options.output_size` with nearest-neighbour
    resampling. NOTE: the three .png outputs share one filename, so when a
    record has more than one of those fields the later one overwrites the
    earlier one.

    Raises AssertionError if one of the input files does not exist.
    '''
    # Make the directory.
    logging.info("Exporting {0}".format(options.output))
    if os.path.exists(options.output):
        logging.info("Exists: {0}".format(options.output))
    else:
        logging.info("Creating {0}".format(options.output))
        os.makedirs(options.output)

    # Built once, on first use, instead of once per record. Kept lazy so that
    # inputs without an 'image' field never touch options.image_mean.
    image_mean = None

    # Process each input.
    #for input_file in options.inputs:
    for input_file in files:
        assert os.path.exists(input_file), input_file
        logging.info("Input: {0}".format(input_file))
        for output in progress(hdf5_reader(input_file), 100):
            if 'image' in output:
                filename = os.path.join(options.output, "{0}.{1}".format(
                    output['id'][0], "jpg"))
                if image_mean is None:
                    image_mean = np.array(options.image_mean).reshape(3, 1, 1)
                image = output['image'] + image_mean
                image = image[::-1, :, :].transpose((1, 2, 0)).astype('uint8')
                _save_resized(image, filename, options.output_size)
            if 'segmentation' in output:
                filename = os.path.join(options.output, "{0}.{1}".format(
                    output['id'][0], "png"))
                annotation = output['segmentation'][0, :].astype('uint8')
                _save_resized(annotation, filename, options.output_size)
            # Both prediction fields are reduced to a label map the same way.
            for field in ('seg_prob', 'score'):
                if field in output:
                    filename = os.path.join(options.output, "{0}.{1}".format(
                        output['id'][0], "png"))
                    prediction = output[field].argmax(axis=0).astype('uint8')
                    _save_resized(prediction, filename, options.output_size)
    logging.info("Done.")
if __name__ == "__main__":
    # Script entry point: parse the options once, export the images, then
    # write the filename lists into the same output directory.
    logging.basicConfig(format='[%(asctime)s] %(message)s',
                        level=logging.DEBUG)
    parser = argparse.ArgumentParser(description='Export the images.')
    # With nargs='+' the parsed value is a list, so the default must be a
    # list too (a bare string would be iterated character by character).
    # NOTE: export() and createListTxt() currently read the module-level
    # `files` list rather than this option.
    parser.add_argument('--inputs', nargs='+', default=['TMM_test.h5'],
                        help='Input data.')
    # Other output directories used for the remaining splits: 'testimages',
    # 'valimages'.
    parser.add_argument('--output', type=str, default='trainimages',
                        help='Output data.')
    parser.add_argument('--output_metadata', type=str, default=None,
                        help='Output metadata.')
    # nargs/type make values given on the command line arrive as numbers
    # instead of a single raw string.
    parser.add_argument('--output_size', nargs=2, type=int,
                        default=(400, 600), metavar=('WIDTH', 'HEIGHT'),
                        help='Output image size.')
    parser.add_argument('--image_mean', nargs=3, type=float, default=[
        104.00699, 116.66877, 122.67892],
        help='Per-channel mean added back to each image.')

    # Parse once and share the result between both steps.
    options = parser.parse_args()
    # argparse yields a list for nargs=2; keep the size a tuple like the
    # default.
    options.output_size = tuple(options.output_size)
    export(options)
    createListTxt(options)