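"""Colorization via GAN inversion.

Embeds each input image into the latent space of a pretrained StyleGAN2 (FFHQ)
generator by optimizing a latent w so that the conditioned (grayscale)
synthesis matches the conditioned reference under a pixel-wise MSE loss, an
LPIPS perceptual loss, and a P_N+ regularizer that keeps w close to the
generator's latent distribution.
"""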
import torch
import torch.nn.functional as F
import torch.optim as optim
import torch.nn as nn
from torchvision import transforms
import model.model_utils as mu
import model.p_space as p_space
import utils.data as u_data
import utils.images as u_images
import os
import numpy as np
import lpips
DEVICE = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
ITERATIONS = 2000
SAVE_STEP = 100
# OPTIMIZER
LEARNING_RATE = 0.01
BETA_1 = 0.9
BETA_2 = 0.999
EPSILON = 1e-8
PATH_DIR = "stuff/data/input/"
EXPECTED_RESULTS = "stuff/data/Peihao_result/"
SAVING_DIR = 'stuff/results/improved_embedding/'
# Loading the pretrained model
G = mu.load_pretrained_model(file_name="ffhq.pkl", space='w', device=DEVICE)
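# G.synthesis maps a (batched) w latent to a 1024x1024 RGB image in roughly [-1, 1]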
# Pixel-Wise MSE Loss
MSE_Loss = nn.MSELoss(reduction="mean").to(DEVICE)
# Load the LPIPS perceptual metric backed by VGG16 (the StyleGAN2 flavor of the metric)
perceptual_vgg16 = lpips.LPIPS(net='vgg').to(DEVICE)
# affine transformation to P_N+
C, E, mean, S = p_space.get_PCA_results(G, DEVICE, load=True)
map2PN = p_space.mapping_P_N(C, S, mean)
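# map2PN whitens a latent against the PCA statistics (components C, scales S,
# mean) of the generator's latent distribution; penalizing the squared norm of
# the mapped latent is what keeps the optimized w close to that distribution.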
def calculate_loss(big_gen_img, big_ref_image, small_ref_image, w_opt, perceptual_net,
                   MSE_Loss, map2PN, condition_function=None, lambda_v=0.001):
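    """Return the three loss terms (MSE, LPIPS, regularizer) for one step.

    big_gen_img / big_ref_image are full-resolution tensors in [-1, 1];
    small_ref_image is the downsampled (256x256) reference used for LPIPS.
    condition_function, if given, degrades images into the conditioned domain
    (e.g. grayscale for colorization); lambda_v weights the P_N+ regularizer.
    """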
    # G.synthesis and the normalized references are both in [-1, 1], the range
    # LPIPS expects, so no extra rescaling is needed here.
    # Downsample the 1024x1024 synthesis to 256x256 for the perceptual loss;
    # the VGG16 backbone of LPIPS was trained on small (224x224) inputs.
    down_transform = u_images.BicubicDownSample(factor=1024 // 256, device=DEVICE)
    small_gen_img = down_transform(big_gen_img)
    # degrade the generated images into the conditioned domain (e.g. grayscale)
    if condition_function is not None:
        big_gen_img = condition_function(big_gen_img)
        small_gen_img = condition_function(small_gen_img)
    # LPIPS perceptual loss on the downsampled pair
    perceptual_loss = perceptual_net(small_ref_image, small_gen_img)
    # pixel-wise MSE loss on the full-resolution pair
    mse_loss = MSE_Loss(big_gen_img, big_ref_image)
    # regularizer: mean squared norm of the latent mapped into P_N+ space
    regularizer = lambda_v * (map2PN(w_opt) ** 2).mean()
    return mse_loss, perceptual_loss, regularizer
def run_optimization(data, img_id, init,
                     condition_function=None,
                     suffix="",
                     save_loss=False,
                     lambda_v=0.001):
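    """Embed one image from `data` into the latent space of G.

    `init` selects the starting latent (e.g. 'w_mean'), `suffix` is appended
    to the output basename so different runs do not overwrite each other, and
    `lambda_v` weights the P_N+ regularizer. The latent is checkpointed every
    SAVE_STEP iterations into a single .npz file; if `save_loss` is set, the
    per-iteration total loss is also stored as a .npy file. Returns the list
    of total-loss values.
    """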
    # get the reference image and normalize it to [-1, 1]
    basename = data[img_id]['name'].split(".")[0] + suffix
    reference = data[img_id]['img']
    big_image = transforms.ToTensor()(reference).unsqueeze(0).to(DEVICE)
    small_image = u_images.lanczos_transform(reference, DEVICE)
    big_image = 2 * big_image - 1
    small_image = 2 * small_image - 1
    # degrade the references into the conditioned domain as well
    if condition_function is not None:
        big_image = condition_function(big_image)
        small_image = condition_function(small_image)
    print("small_image", small_image.size())
    # define the initial latent and the Adam optimizer over it
    w_opt = mu.get_initial_latent(init, G, DEVICE)
    optimizer = optim.Adam([w_opt], lr=LEARNING_RATE, betas=(BETA_1, BETA_2), eps=EPSILON)
    print("Starting Embedding: id: {} name: {}".format(img_id, basename))
    loss_list = []
    loss_mse = []
    loss_perceptual = []
    latent_list = {}
    for i in range(ITERATIONS):
        # reset the gradients
        optimizer.zero_grad()
        # get the synthetic image
        synth_img = G.synthesis(w_opt, noise_mode='const')
        # get the loss terms and backpropagate the gradients
        mse_loss, perceptual_loss, regularizer_term = calculate_loss(synth_img,
                                                                     big_image,
                                                                     small_image,
                                                                     w_opt,
                                                                     perceptual_vgg16,
                                                                     MSE_Loss,
                                                                     map2PN,
                                                                     condition_function,
                                                                     lambda_v)
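        # total objective: pixel MSE + LPIPS perceptual term + weighted P_N+ regularizer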
        loss = mse_loss + perceptual_loss + regularizer_term
        loss.backward()
        optimizer.step()
        # store the loss metrics
        loss_list.append(loss.item())
        loss_mse.append(mse_loss.item())
        loss_perceptual.append(perceptual_loss.item())
        # every SAVE_STEP iterations, log progress and checkpoint the current latent
        if (i + 1) % SAVE_STEP == 0:
            print('iter[{}]:\t loss: {:.4f}\t mse_loss: {:.4f}\tpercep_loss: {:.4f}\tregularizer: {:.4f}'.format(
                i + 1, loss.item(), mse_loss.item(),
                perceptual_loss.item(), regularizer_term.item()))
            latent_list[str(i + 1)] = w_opt.detach().cpu().numpy()
    # store all the latents created during optimization in a single .npz file
    path_embedding_latent = os.path.join(SAVING_DIR,
                                         "latents/{}_latents_iters_{}_step_{}_{}.npz".format(
                                             basename,
                                             str(ITERATIONS).zfill(6),
                                             str(SAVE_STEP).zfill(4),
                                             init))
    # make sure the output directory exists before writing
    os.makedirs(os.path.dirname(path_embedding_latent), exist_ok=True)
    print("Saving: {}".format(path_embedding_latent))
    np.savez(path_embedding_latent, **latent_list)
    if save_loss:
        loss_file = "loss_plots/{}_loss_iters_{}_step_{}_{}.npy".format(
            basename,
            str(ITERATIONS).zfill(6),
            str(SAVE_STEP).zfill(4),
            init)
        path_loss = os.path.join(SAVING_DIR, loss_file)
        os.makedirs(os.path.dirname(path_loss), exist_ok=True)
        print("Saving Loss: {}".format(path_loss))
        np.save(path_loss, np.array(loss_list))
    return loss_list
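# A minimal sketch of how a saved checkpoint can be reloaded later (the .npz
# keys are the iteration numbers written above):
#   latents = np.load(path_to_npz)  # path printed during the run
#   w_final = torch.from_numpy(latents[str(ITERATIONS)]).to(DEVICE)
#   img = G.synthesis(w_final, noise_mode='const')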
# load input images and define the conditioning (degradation) per task
data = u_data.load_data(PATH_DIR)
condition_function_options = {
    "colorization": mu.convert2grayscale_tensor,
    # "super_resolution": u_images.BicubicDownSample(factor=1024 // 256)
}
options_lambdas = [0.001, 0.005, 0.01]
for name, condition_function in condition_function_options.items():
    for i in range(len(data)):
        for lambda_v in options_lambdas:
            loss_list = run_optimization(data, img_id=i,
                                         init='w_mean',
                                         suffix=f"_{name}_lambda_{lambda_v}",
                                         save_loss=True,
                                         lambda_v=lambda_v,
                                         condition_function=condition_function)