-
Notifications
You must be signed in to change notification settings - Fork 1
/
main_eval.py
84 lines (67 loc) · 2 KB
/
main_eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import os
from utils.data_utils import loading_scene_list
from utils.model_util import ScalarMeanTracker
os.environ["OMP_NUM_THREADS"] = "1"
import torch.multiprocessing as mp
import time
import json
from tqdm import tqdm
from runners import a3c_val
def main_eval(args, saved_model, outdir, device):
    """Run A3C evaluation in parallel, one worker process per scene type.

    One ``a3c_val.a3c_val`` worker is started for each of the scene types
    ``kitchen``, ``living_room``, ``bedroom`` and ``bathroom``. Workers push
    result dicts onto a shared queue; a dict containing the key ``"END"``
    marks that a worker has finished. All other results are fed to a
    ``ScalarMeanTracker``; their ``"tools"`` entries are also collected.

    Outputs:
        * the tracker's means, written as JSON to ``args.results_json``;
        * the collected ``"tools"`` entries, written as JSON to
          ``<outdir>/visualization_files/visualize.json``.

    Args:
        args: Run configuration namespace. NOTE: mutated in place —
            ``learned_loss`` is set to False and ``num_steps`` to 50 before
            it is handed to the workers.
        saved_model: Forwarded unchanged to every worker.
        outdir: Directory under which ``visualization_files/`` is created
            (created, with parents, if missing).
        device: Forwarded unchanged to every worker.
    """
    # The start method must be set before ANY multiprocessing object is
    # created. Building mp.Queue() first pins the platform default context
    # (fork on Linux); set_start_method would then raise RuntimeError and
    # "spawn" would be silently lost.
    try:
        mp.set_start_method("spawn")
    except RuntimeError:
        # The start method was already fixed earlier in this process.
        pass

    scenes = loading_scene_list(phase="eval")
    res_queue = mp.Queue()

    args.learned_loss = False
    args.num_steps = 50

    # Forwarded to every worker and used to size the progress bar;
    # presumably the number of evaluation episodes per scene type (confirm
    # against a3c_val's signature).
    episodes_per_worker = 250
    scene_types = ["kitchen", "living_room", "bedroom", "bathroom"]
    num_workers = len(scene_types)

    processes = []
    # NOTE(review): assumes loading_scene_list returns one entry per scene
    # type, in the same order as scene_types.
    for rank, scene_type in enumerate(scene_types):
        p = mp.Process(
            target=a3c_val.a3c_val,
            args=(
                args,
                saved_model,
                res_queue,
                episodes_per_worker,
                scene_type,
                scenes[rank],
                device,
            ),
        )
        p.start()
        processes.append(p)
        # Stagger worker start-up slightly.
        time.sleep(0.1)

    train_scalars = ScalarMeanTracker()
    visualizations = []
    finished_workers = 0
    try:
        with tqdm(total=episodes_per_worker * num_workers) as pbar:
            while finished_workers < num_workers:
                result = res_queue.get()
                if "END" in result:
                    # Sentinel from a finished worker; not an episode result.
                    finished_workers += 1
                    continue
                pbar.update(1)
                train_scalars.add_scalars(result)
                visualizations.append(result["tools"])
        tracked_means = train_scalars.pop_and_reset()
    finally:
        # Reap every worker even if result collection failed part-way.
        for p in processes:
            time.sleep(0.1)
            p.join()

    with open(args.results_json, "w") as fp:
        json.dump(tracked_means, fp, sort_keys=True, indent=4)

    visualization_dir = os.path.join(outdir, "visualization_files")
    # makedirs also creates a missing outdir and tolerates an existing dir.
    os.makedirs(visualization_dir, exist_ok=True)
    with open(os.path.join(visualization_dir, "visualize.json"), "w") as wf:
        json.dump(visualizations, wf)