From 7360346c42333d1908690e19d15597a1cd1670f1 Mon Sep 17 00:00:00 2001
From: Dawid Laszuk
Date: Tue, 24 Aug 2021 14:39:41 -0700
Subject: [PATCH] Fix: add __init__ to runners

---
 ai_traineree/__init__.py                          |  2 +-
 ai_traineree/runners/__init__.py                  |  0
 ai_traineree/tasks.py                             | 10 ++++++++--
 examples/petting_zoo/rps.py                       |  1 -
 examples/snek_rainbow.py                          |  2 +-
 examples/space_invaders_dqn.py                    |  2 +-
 examples/space_invaders_pixel_rainbow.py          |  2 +-
 setup.cfg                                         |  3 ++-
 tests/{ => runners}/test_env_runner.py            |  0
 tests/{ => runners}/test_multiagent_env_runner.py |  0
 10 files changed, 14 insertions(+), 8 deletions(-)
 create mode 100644 ai_traineree/runners/__init__.py
 rename tests/{ => runners}/test_env_runner.py (100%)
 rename tests/{ => runners}/test_multiagent_env_runner.py (100%)

diff --git a/ai_traineree/__init__.py b/ai_traineree/__init__.py
index 1b5acf4..70837ee 100644
--- a/ai_traineree/__init__.py
+++ b/ai_traineree/__init__.py
@@ -1,7 +1,7 @@
 import numpy
 import torch
 
-__version__ = "0.3.4"
+__version__ = "0.3.5"
 
 
 # This is expected to be safe, although in PyTorch 1.7 it comes as a warning,
diff --git a/ai_traineree/runners/__init__.py b/ai_traineree/runners/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/ai_traineree/tasks.py b/ai_traineree/tasks.py
index b5b7db3..2218f27 100644
--- a/ai_traineree/tasks.py
+++ b/ai_traineree/tasks.py
@@ -11,6 +11,7 @@
 
 try:
     import gym
+    BaseEnv = gym.Env  # To satisfy parser on MultiAgentUnityTask import
 except ImportError:
     logging.warning("Coulnd't import `gym`. Please install `pip install -e .[gym]` if you intend to use it.")
 
@@ -19,7 +20,7 @@
     from gym_unity.envs import ActionFlattener
     from mlagents_envs.base_env import BaseEnv, DecisionSteps, TerminalSteps
 except (ImportError, ModuleNotFoundError):
-    BaseEnv = gym.Env
+    logging.warning("Couldn't import `gym_unity` and/or `mlagents`. MultiAgentUnityTask won't work.")
 
 GymStepResult = Tuple[np.ndarray, float, bool, Dict]
 
@@ -235,7 +236,12 @@ def dones(self):
     def last(self, agent_name: Optional[str] = None) -> Tuple[Any, float, bool, Any]:
         if agent_name is None:
             return self.env.last()
-        return (self.env.observe(agent_name), self.env.rewards[agent_name], self.env.dones[agent_name], self.env.infos[agent_name])
+        return (
+            self.env.observe(agent_name),
+            self.env.rewards[agent_name],
+            self.env.dones[agent_name],
+            self.env.infos[agent_name],
+        )
 
     def reset(self):
         self.env.reset()
diff --git a/examples/petting_zoo/rps.py b/examples/petting_zoo/rps.py
index 45b68c8..11f4000 100644
--- a/examples/petting_zoo/rps.py
+++ b/examples/petting_zoo/rps.py
@@ -29,4 +29,3 @@
             break
 
     task.step(action)
-
\ No newline at end of file
diff --git a/examples/snek_rainbow.py b/examples/snek_rainbow.py
index 5e6966b..06783f4 100644
--- a/examples/snek_rainbow.py
+++ b/examples/snek_rainbow.py
@@ -3,7 +3,7 @@
 import sneks  # noqa
 
 from ai_traineree.agents.rainbow import RainbowAgent
-from ai_traineree.env_runner import EnvRunner
+from ai_traineree.runners.env_runner import EnvRunner
 from ai_traineree.loggers import TensorboardLogger
 from ai_traineree.tasks import GymTask
 
diff --git a/examples/space_invaders_dqn.py b/examples/space_invaders_dqn.py
index 0a17306..e5b4976 100644
--- a/examples/space_invaders_dqn.py
+++ b/examples/space_invaders_dqn.py
@@ -4,7 +4,7 @@
 from collections import deque
 
 from ai_traineree.agents.dqn import DQNAgent
-from ai_traineree.env_runner import EnvRunner
+from ai_traineree.runners.env_runner import EnvRunner
 from ai_traineree.loggers import TensorboardLogger
 from ai_traineree.networks.heads import NetChainer
 from ai_traineree.networks.bodies import ConvNet, FlattenNet, FcNet, ScaleNet
diff --git a/examples/space_invaders_pixel_rainbow.py b/examples/space_invaders_pixel_rainbow.py
index 70c9922..82efcbf 100644
--- a/examples/space_invaders_pixel_rainbow.py
+++ b/examples/space_invaders_pixel_rainbow.py
@@ -6,7 +6,7 @@
 from ai_traineree.loggers import TensorboardLogger
 from ai_traineree.networks.bodies import ConvNet, FcNet
 from ai_traineree.networks.heads import NetChainer
-from ai_traineree.env_runner import EnvRunner
+from ai_traineree.runners.env_runner import EnvRunner
 from ai_traineree.tasks import GymTask
 
diff --git a/setup.cfg b/setup.cfg
index 72a78ae..be5119f 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,6 +1,6 @@
 [metadata]
 name = ai-traineree
-version = 0.3.4
+version = 0.3.5
 author = Dawid Laszuk
 author_email = ai-traineree@dawid.lasz.uk
 description = Yet another zoo of (Deep) Reinforcment Learning methods in Python using PyTorch
@@ -77,6 +77,7 @@ exclude_lines =
 ignore = E203,
     E226,  # I like to group operations. What are you going to do about it, huh?
     E252,  # Ain't nobody tell me how to type arguments
+    W293,  # Lines within code with whitespace
     W503
 max_line_length = 120
 
diff --git a/tests/test_env_runner.py b/tests/runners/test_env_runner.py
similarity index 100%
rename from tests/test_env_runner.py
rename to tests/runners/test_env_runner.py
diff --git a/tests/test_multiagent_env_runner.py b/tests/runners/test_multiagent_env_runner.py
similarity index 100%
rename from tests/test_multiagent_env_runner.py
rename to tests/runners/test_multiagent_env_runner.py