-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_GAN.py
43 lines (32 loc) · 1.42 KB
/
test_GAN.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
import random
from train_GAN_4 import generator, one_hot_to_index, plot_img
# --- Sampling configuration -------------------------------------------------
LATENT_DIM = 100    # length of each noise vector fed to the generator
# Generator output size. 3844 == 62 * 62, presumably a flattened 62x62 level
# grid — TODO confirm against train_GAN_4.
OUTPUT_DIM = 3844
NUM_SAMPLES = 40    # number of levels generated in one batch
NUM_PLOTTED = 20    # how many of the generated levels are plotted

# Best model obtained so far: b4 after 2500 batches.
CHECKPOINT = 'b4/b4_weights_batches_2500.pth'
# Checkpoints from earlier experiments (assign one to CHECKPOINT to compare):
#   'gb_weights_batches_2250.pth'
#   'control_batches/control_weights_control_batches_2250.pth'
#   'constraints1&2_0.1/generator_weights_batches_2750.pth'
#   'constraint2_0.1/c2g_weights_batches_1500.pth'
#   'c1g_weights_batches_1200.pth'
#   'control_model/generator_weights_epoch_9.pth'  # this is the control GAN
#   'trained_model_weights/generator_weights_150_epoch.pth'
#   'nieuwe opmaak 33 epochs/generator_weights_epoch_19.pth'
#   '3.4_corrected/generator_weights_280.pth'

# Generate some noise. The seed is itself drawn at random, so print it: that
# is the only way to regenerate a level that turns out to be interesting.
manualSeed = np.random.randint(1, 10000)
print(f"Random seed: {manualSeed}")
random.seed(manualSeed)
torch.manual_seed(manualSeed)
# NOTE(review): torch.rand draws uniform noise in [0, 1); many GANs sample from
# torch.randn instead. This must match the noise used in train_GAN_4 — confirm.
noise = torch.rand(NUM_SAMPLES, LATENT_DIM)

# Build the generator and load the learned weights. Nothing in this script
# moves the model or the noise to a GPU, so map the checkpoint to the CPU; this
# also lets weights saved from a GPU run load on a CPU-only machine.
# torch.load unpickles the file, so only load checkpoints you trust.
gen = generator(LATENT_DIM, OUTPUT_DIM)
gen.load_state_dict(torch.load(CHECKPOINT, map_location='cpu'))
gen.eval()  # inference mode (affects dropout/batch-norm layers, if any)

# Generate the levels. no_grad() skips building the autograd graph, which
# inference does not need; without it the samples would require grad and
# could not be converted with .numpy().
with torch.no_grad():
    samples = gen(noise)

# Plot a subset of the generated levels.
for sample in samples[:NUM_PLOTTED]:
    plot_img(sample)

# Debug helper: uncomment to dump the raw generator output.
# torch.set_printoptions(threshold=10_000)
# print(samples[0])
# print(samples[1])
# print(samples[2])