forked from aiff22/DPED
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_model.py
109 lines (73 loc) · 4.27 KB
/
test_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
# python test_model.py model=iphone_orig dped_dir=dped/ test_subset=full iteration=all resolution=orig use_gpu=true
import imageio
from PIL import Image
import numpy as np
import tensorflow as tf
from models import resnet
import utils
import os
import sys
tf.compat.v1.disable_v2_behavior()
# process command arguments
phone, dped_dir, test_subset, iteration, resolution, use_gpu = utils.process_test_model_args(sys.argv)
# get all available image resolutions
res_sizes = utils.get_resolutions()
# get the specified image resolution
IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_SIZE = utils.get_specified_res(res_sizes, phone, resolution)
# disable gpu if specified
config = tf.compat.v1.ConfigProto(device_count={'GPU': 0}) if use_gpu == "false" else None
# create placeholders for input images
x_ = tf.compat.v1.placeholder(tf.float32, [None, IMAGE_SIZE])
x_image = tf.reshape(x_, [-1, IMAGE_HEIGHT, IMAGE_WIDTH, 3])
# generate enhanced image
enhanced = resnet(x_image)
with tf.compat.v1.Session(config=config) as sess:
test_dir = dped_dir + phone.replace("_orig", "") + "/test_data/full_size_test_images/"
test_photos = [f for f in os.listdir(test_dir) if os.path.isfile(test_dir + f)]
if test_subset == "small":
# use five first images only
test_photos = test_photos[0:5]
if phone.endswith("_orig"):
# load pre-trained model
saver = tf.compat.v1.train.Saver()
saver.restore(sess, "models_orig/" + phone)
for photo in test_photos:
# load training image and crop it if necessary
print("Testing original " + phone.replace("_orig", "") + " model, processing image " + photo)
image = np.float16(np.array(Image.fromarray(imageio.imread(test_dir + photo))
.resize([res_sizes[phone][1], res_sizes[phone][0]]))) / 255
image_crop = utils.extract_crop(image, resolution, phone, res_sizes)
image_crop_2d = np.reshape(image_crop, [1, IMAGE_SIZE])
# get enhanced image
enhanced_2d = sess.run(enhanced, feed_dict={x_: image_crop_2d})
enhanced_image = np.reshape(enhanced_2d, [IMAGE_HEIGHT, IMAGE_WIDTH, 3])
before_after = np.hstack((image_crop, enhanced_image))
photo_name = photo.rsplit(".", 1)[0]
# save the results as .png images
imageio.imwrite("visual_results/" + phone + "_" + photo_name + "_enhanced.png", enhanced_image)
imageio.imwrite("visual_results/" + phone + "_" + photo_name + "_before_after.png", before_after)
else:
num_saved_models = int(len([f for f in os.listdir("models/") if f.startswith(phone + "_iteration")]) / 2)
if iteration == "all":
iteration = np.arange(1, num_saved_models) * 1000
else:
iteration = [int(iteration)]
for i in iteration:
# load pre-trained model
saver = tf.compat.v1.train.Saver()
saver.restore(sess, "models/" + phone + "_iteration_" + str(i) + ".ckpt")
for photo in test_photos:
# load training image and crop it if necessary
print("iteration " + str(i) + ", processing image " + photo)
image = np.float16(np.array(Image.fromarray(imageio.imread(test_dir + photo))
.resize([res_sizes[phone][1], res_sizes[phone][0]]))) / 255
image_crop = utils.extract_crop(image, resolution, phone, res_sizes)
image_crop_2d = np.reshape(image_crop, [1, IMAGE_SIZE])
# get enhanced image
enhanced_2d = sess.run(enhanced, feed_dict={x_: image_crop_2d})
enhanced_image = np.reshape(enhanced_2d, [IMAGE_HEIGHT, IMAGE_WIDTH, 3])
before_after = np.hstack((image_crop, enhanced_image))
photo_name = photo.rsplit(".", 1)[0]
# save the results as .png images
imageio.imwrite("visual_results/" + phone + "_" + photo_name + "_iteration_" + str(i) + "_enhanced.png", enhanced_image)
imageio.imwrite("visual_results/" + phone + "_" + photo_name + "_iteration_" + str(i) + "_before_after.png", before_after)