-
Notifications
You must be signed in to change notification settings - Fork 0
/
env_training.py
40 lines (32 loc) · 1.33 KB
/
env_training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from environments import EnvFactory
from environments.wrappers.DefaultWrappers import DefaultWrappers
from singletons.Logger import Logger
import hydra
from omegaconf import OmegaConf, open_dict
from hydra.utils import instantiate
import numpy as np
import random
import torch
from agents.save.Checkpoint import Checkpoint
@hydra.main(config_path="config", config_name="training")
def train(config):
    """Seed the random generators, build the environment and train an agent.

    Args:
        config: Configuration handed over by Hydra (loaded from the
            "training" config inside the "config" directory). The entries
            read here are "seed", "images.shape", "agent",
            "agent.tensorboard_dir" and "checkpoint.file"; the entry
            "env.n_actions" is written by this function.
    """
    # Use one user-provided seed for numpy, python's random module and torch.
    seed = config["seed"]
    for seed_generator in (np.random.seed, random.seed, torch.manual_seed):
        seed_generator(seed)

    # Log the full configuration as YAML so it is kept with the run's output.
    training_logger = Logger.get(name="Training")
    yaml_config = OmegaConf.to_yaml(config)
    training_logger.info("Configuration:\n{}".format(yaml_config))

    # Build the raw environment and store its number of actions in the
    # configuration; open_dict is used because "n_actions" is added here.
    raw_env = EnvFactory.make(config)
    with open_dict(config):
        config.env.n_actions = raw_env.action_space.n

    # Apply the standard wrappers, parameterised by the image shape.
    image_shape = config["images"]["shape"]
    wrapped_env = DefaultWrappers.apply(raw_env, image_shape)

    # Reload the agent from the checkpoint when one exists, otherwise
    # instantiate a new agent from the configuration.
    checkpoint = Checkpoint(
        config["agent"]["tensorboard_dir"],
        config["checkpoint"]["file"],
    )
    if checkpoint.exists():
        agent = checkpoint.load_model()
    else:
        agent = instantiate(config["agent"])

    # Train the agent.
    agent.train(wrapped_env, config)
if __name__ == "__main__":
    # Register a "tuple" resolver so that configuration files can build
    # tuples from a list of arguments.
    OmegaConf.register_new_resolver("tuple", lambda *items: tuple(items))

    # Start the training; hydra supplies the configuration argument.
    train()