train.py
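Training entry point for the Flatland PPO agent: it seeds all RNGs, assembles the environment, observation, reward-shaping, and Judge configurations, and runs a PPOLearner until the optimization-step or episode budget is reached.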
import torch
import numpy as np
import random
from env.observations import SimpleObservation
from agent.PPO.PPOLearner import PPOLearner
from configs import (Experiment, AdamConfig, FlatlandConfig,
                     SimpleObservationConfig, EnvCurriculumConfig,
                     EnvCurriculumSampleConfig, SimpleRewardConfig, SparseRewardConfig, NearRewardConfig,
                     DeadlockPunishmentConfig, RewardsComposerConfig,
                     NotStopShaperConfig, FinishRewardConfig, JudgeConfig)
from env.Flatland import Flatland
from agent.judge.Judge import ConstWindowSizeGenerator, LinearOnAgentNumberSizeGenerator
from logger import log, init_logger
from params import PPOParams
from params import test_env, PackOfAgents
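# Seed PyTorch (CPU and CUDA), NumPy, and Python's random module so runs are reproducible.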
def init_random_seeds(RANDOM_SEED, cuda_deterministic):
    torch.manual_seed(RANDOM_SEED)
    torch.cuda.manual_seed_all(RANDOM_SEED)
    np.random.seed(RANDOM_SEED)
    random.seed(RANDOM_SEED)
    if cuda_deterministic:
        # Use deterministic cuDNN kernels; benchmark autotuning is disabled because
        # it selects kernels non-deterministically across runs.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

def train_ppo(exp, n_workers):
    # Build the PPO learner from the experiment config, collect rollouts until the
    # step/episode budget is reached, then save the trained controller.
    init_random_seeds(exp.random_seed, cuda_deterministic=False)
    log().update_params(exp)
    learner = PPOLearner(exp.env_config, exp.controller_config, n_workers, exp.device)
    learner.rollouts(max_opt_steps=exp.opt_steps, max_episodes=exp.episodes)
    learner.controller.save_controller(log().get_log_path())
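
# Default experiment configuration used when the script is run directly.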
if __name__ == "__main__":
    RANDOM_SEED = 23
    torch.set_printoptions(precision=6, sci_mode=False)

    logname = "tmp"
    init_logger("logdir", logname, use_wandb=False)

    # Judge (timetable) configuration.
    timetable_config = JudgeConfig(
        window_size_generator=LinearOnAgentNumberSizeGenerator(0.0, 10**10),
        lr=1e-4,
        batch_size=8,
        optimization_epochs=3,
    )

    obs_builder_config = SimpleObservationConfig(max_depth=3, neighbours_depth=3, timetable_config=timetable_config)

    # Composite shaped reward: finish bonus, near-goal shaping, deadlock punishment,
    # and not-stop shaping (zeroed here).
    reward_config = RewardsComposerConfig((
        FinishRewardConfig(finish_value=10),
        NearRewardConfig(coeff=0.01),
        DeadlockPunishmentConfig(value=-5),
        NotStopShaperConfig(on_switch_value=0, other_value=0),
    ))

    envs = [(PackOfAgents(RANDOM_SEED), 1)]
    workers = 1

    exp = Experiment(
        opt_steps=10**10,
        episodes=100000,
        device=torch.device("cuda"),
        logname=logname,
        random_seed=RANDOM_SEED,
        env_config=EnvCurriculumSampleConfig(*zip(*envs),
                                             obs_builder_config=obs_builder_config,
                                             reward_config=reward_config),
        controller_config=PPOParams(),
    )

    train_ppo(exp, workers)
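Running the file directly with "python train.py" executes this default configuration; it expects a CUDA device (the experiment is pinned to torch.device("cuda")) and has Weights & Biases logging disabled (use_wandb=False).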