update path to aitraineree
laszukdawid committed Jan 3, 2025
1 parent c94670d commit 609d40d
Showing 95 changed files with 665 additions and 502 deletions.
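Every hunk below applies the same mechanical change: the package directory `ai_traineree` becomes `aitraineree`, and every import or path referencing it is updated to match. As a rough illustration of how such a rename can be scripted (a sketch, not part of this commit; the helper name and file filters are assumptions):

```python
# Hypothetical helper, not from this repository: applies a package rename
# like ai_traineree -> aitraineree across a source tree.
from pathlib import Path

OLD, NEW = "ai_traineree", "aitraineree"
SUFFIXES = {".py", ".yml", ".yaml", ".md", ".toml", ".cfg"}  # assumed set

def rename_package(root: str = ".") -> None:
    # Move the package directory itself, if it still has the old name.
    old_dir = Path(root) / OLD
    if old_dir.is_dir():
        old_dir.rename(Path(root) / NEW)
    # Rewrite textual references in source and config files.
    for path in Path(root).rglob("*"):
        if path.is_file() and path.suffix in SUFFIXES:
            text = path.read_text(encoding="utf-8")
            if OLD in text:  # blunt textual replace; review the diff after
                path.write_text(text.replace(OLD, NEW), encoding="utf-8")

if __name__ == "__main__":
    rename_package()
```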
2 changes: 1 addition & 1 deletion .github/workflows/ci-test.yml
@@ -6,7 +6,7 @@ on:
       - $default-branch
   pull_request:
     paths:
-      - "ai_traineree/**.py"
+      - "aitraineree/**.py"
 
 jobs:
   build-n-test:
2 changes: 1 addition & 1 deletion .gitignore
@@ -1,7 +1,7 @@
 data/
 build/
 dist/
-ai_traineree.egg-info/
+aitraineree.egg-info/
 
 videos/
 /gifs
24 changes: 12 additions & 12 deletions README.md
@@ -26,9 +26,9 @@ That, and using PyTorch instead of Tensorflow or JAX.
 
 To get started with training your RL agent you need three things: an agent, an environment and a runner. Let's say you want to train a DQN agent on OpenAI CartPole-v1:
 ```python
-from ai_traineree.agents.dqn import DQNAgent
-from ai_traineree.runners.env_runner import EnvRunner
-from ai_traineree.tasks import GymTask
+from aitraineree.agents.dqn import DQNAgent
+from aitraineree.runners.env_runner import EnvRunner
+from aitraineree.tasks import GymTask
 
 task = GymTask('CartPole-v1')
 agent = DQNAgent(task.obs_space, task.action_space)
@@ -76,14 +76,14 @@ This is just a beginning and there will be more work on these interactions.

 | Short   | Progress                                      | Link                                                                                                            | Full name                                                 | Doc                                                                       |
 | ------- | --------------------------------------------- | --------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------- | --------------------------------------------------------------------------- |
-| DQN     | [Implemented](ai_traineree/agents/dqn.py)     | [DeepMind](https://deepmind.com/research/publications/human-level-control-through-deep-reinforcement-learning) | Deep Q-learning Network                                   | [Doc](https://ai-traineree.readthedocs.io/en/latest/agents.html#dqn)      |
-| DDPG    | [Implemented](ai_traineree/agents/ddpg.py)    | [arXiv](https://arxiv.org/abs/1509.02971)                                                                       | Deep Deterministic Policy Gradient                        | [Doc](https://ai-traineree.readthedocs.io/en/latest/agents.html#ddpg)     |
-| D4PG    | [Implemented](ai_traineree/agents/d4pg.py)    | [arXiv](https://arxiv.org/abs/1804.08617)                                                                       | Distributed Distributional Deterministic Policy Gradients | [Doc](https://ai-traineree.readthedocs.io/en/latest/agents.html#d4pg)     |
-| TD3     | [Implemented](ai_traineree/agents/td3.py)     | [arXiv](https://arxiv.org/abs/1802.09477)                                                                       | Twin Delayed Deep Deterministic Policy Gradient           | [Doc](https://ai-traineree.readthedocs.io/en/latest/agents.html#td3)      |
-| PPO     | [Implemented](ai_traineree/agents/ppo.py)     | [arXiv](https://arxiv.org/abs/1707.06347)                                                                       | Proximal Policy Optimization                              | [Doc](https://ai-traineree.readthedocs.io/en/latest/agents.html#ppo)      |
-| SAC     | [Implemented](ai_traineree/agents/sac.py)     | [arXiv](https://arxiv.org/abs/1801.01290)                                                                       | Soft Actor Critic                                         | [Doc](https://ai-traineree.readthedocs.io/en/latest/agents.html#sac)      |
+| DQN     | [Implemented](aitraineree/agents/dqn.py)      | [DeepMind](https://deepmind.com/research/publications/human-level-control-through-deep-reinforcement-learning) | Deep Q-learning Network                                   | [Doc](https://ai-traineree.readthedocs.io/en/latest/agents.html#dqn)      |
+| DDPG    | [Implemented](aitraineree/agents/ddpg.py)     | [arXiv](https://arxiv.org/abs/1509.02971)                                                                       | Deep Deterministic Policy Gradient                        | [Doc](https://ai-traineree.readthedocs.io/en/latest/agents.html#ddpg)     |
+| D4PG    | [Implemented](aitraineree/agents/d4pg.py)     | [arXiv](https://arxiv.org/abs/1804.08617)                                                                       | Distributed Distributional Deterministic Policy Gradients | [Doc](https://ai-traineree.readthedocs.io/en/latest/agents.html#d4pg)     |
+| TD3     | [Implemented](aitraineree/agents/td3.py)      | [arXiv](https://arxiv.org/abs/1802.09477)                                                                       | Twin Delayed Deep Deterministic Policy Gradient           | [Doc](https://ai-traineree.readthedocs.io/en/latest/agents.html#td3)      |
+| PPO     | [Implemented](aitraineree/agents/ppo.py)      | [arXiv](https://arxiv.org/abs/1707.06347)                                                                       | Proximal Policy Optimization                              | [Doc](https://ai-traineree.readthedocs.io/en/latest/agents.html#ppo)      |
+| SAC     | [Implemented](aitraineree/agents/sac.py)      | [arXiv](https://arxiv.org/abs/1801.01290)                                                                       | Soft Actor Critic                                         | [Doc](https://ai-traineree.readthedocs.io/en/latest/agents.html#sac)      |
 | TRPO    |                                               | [arXiv](https://arxiv.org/abs/1502.05477)                                                                       | Trust Region Policy Optimization                          |
-| RAINBOW | [Implemented](ai_traineree/agents/rainbow.py) | [arXiv](https://arxiv.org/abs/1710.02298)                                                                       | DQN with a few improvements                               | [Doc](https://ai-traineree.readthedocs.io/en/latest/agents.html#rainbow)  |
+| RAINBOW | [Implemented](aitraineree/agents/rainbow.py)  | [arXiv](https://arxiv.org/abs/1710.02298)                                                                       | DQN with a few improvements                               | [Doc](https://ai-traineree.readthedocs.io/en/latest/agents.html#rainbow)  |
 
 ### Multi agents
 
@@ -92,8 +92,8 @@ However, that doesn't mean one can be used without the other.
 
 | Short  | Progress                                          | Link                                      | Full name              | Doc                                                                          |
 | ------ | ------------------------------------------------- | ----------------------------------------- | ---------------------- | ---------------------------------------------------------------------------- |
-| IQL    | [Implemented](ai_traineree/multi_agent/iql.py)    |                                           | Independent Q-Learners | [Doc](https://ai-traineree.readthedocs.io/en/latest/multi_agent.html#iql)    |
-| MADDPG | [Implemented](ai_traineree/multi_agent/maddpg.py) | [arXiv](https://arxiv.org/abs/1706.02275) | Multi agent DDPG       | [Doc](https://ai-traineree.readthedocs.io/en/latest/multi_agent.html#maddpg) |
+| IQL    | [Implemented](aitraineree/multi_agent/iql.py)     |                                           | Independent Q-Learners | [Doc](https://ai-traineree.readthedocs.io/en/latest/multi_agent.html#iql)    |
+| MADDPG | [Implemented](aitraineree/multi_agent/maddpg.py)  | [arXiv](https://arxiv.org/abs/1706.02275) | Multi agent DDPG       | [Doc](https://ai-traineree.readthedocs.io/en/latest/multi_agent.html#maddpg) |
 
 ### Loggers
 
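The README's quick-start snippet above is cut off by the collapsed part of the diff. For completeness, a sketch of how it presumably continues: the `EnvRunner` construction and `run()` call are assumptions based on the imports shown, not lines from this commit.

```python
from aitraineree.agents.dqn import DQNAgent
from aitraineree.runners.env_runner import EnvRunner
from aitraineree.tasks import GymTask

task = GymTask('CartPole-v1')
agent = DQNAgent(task.obs_space, task.action_space)

# Assumed continuation: the runner drives agent-environment interaction
# and collects episode scores until training stops.
env_runner = EnvRunner(task, agent)
scores = env_runner.run()
```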
2 changes: 1 addition & 1 deletion aitraineree/agents/__init__.py
@@ -2,7 +2,7 @@
 
 import torch
 
-from ai_traineree.types import AgentType
+from aitraineree.types import AgentType
 
 
 class AgentBase(AgentType):
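Renaming the import root breaks any downstream code that still imports `ai_traineree`. Nothing in this commit provides a fallback; if one were wanted, a minimal transition shim could alias the old name at startup. A sketch, with the caveat noted in the comments:

```python
# Hypothetical compatibility shim, not shipped by this commit.
import importlib
import sys

# Make `import ai_traineree` resolve to the renamed package.
sys.modules["ai_traineree"] = importlib.import_module("aitraineree")

# Legacy-style imports now work, but note the caveat: submodules reached
# through the alias are re-executed under the old name, so classes imported
# both ways are distinct objects and isinstance checks across them fail.
from ai_traineree.agents.dqn import DQNAgent  # noqa: E402
```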
14 changes: 7 additions & 7 deletions aitraineree/agents/agent_factory.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from ai_traineree.agents import AgentBase
from ai_traineree.agents.ddpg import DDPGAgent
from ai_traineree.agents.dqn import DQNAgent
from ai_traineree.agents.ppo import PPOAgent
from ai_traineree.agents.rainbow import RainbowAgent
from ai_traineree.agents.sac import SACAgent
from ai_traineree.types import AgentState
from aitraineree.agents import AgentBase
from aitraineree.agents.ddpg import DDPGAgent
from aitraineree.agents.dqn import DQNAgent
from aitraineree.agents.ppo import PPOAgent
from aitraineree.agents.rainbow import RainbowAgent
from aitraineree.agents.sac import SACAgent
from aitraineree.types import AgentState


class AgentFactory:
Expand Down
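`agent_factory.py` imports every concrete agent plus `AgentState`, which suggests a name-to-class dispatch. The diff doesn't show the class body, so the following is only an illustrative guess at the pattern, not the library's actual API; signatures other than `DQNAgent`'s (shown in the README) are assumptions.

```python
# Illustrative guess at the dispatch behind AgentFactory; the real class
# in aitraineree may expose a different interface.
from aitraineree.agents.ddpg import DDPGAgent
from aitraineree.agents.dqn import DQNAgent
from aitraineree.agents.ppo import PPOAgent

_REGISTRY = {"DQN": DQNAgent, "DDPG": DDPGAgent, "PPO": PPOAgent}

def make_agent(name: str, obs_space, action_space, **kwargs):
    try:
        agent_cls = _REGISTRY[name.upper()]
    except KeyError:
        raise ValueError(f"Unknown agent type: {name!r}") from None
    return agent_cls(obs_space, action_space, **kwargs)
```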
24 changes: 12 additions & 12 deletions aitraineree/agents/d3pg.py
@@ -6,18 +6,18 @@
 import torch.nn.functional as F
 from torch.optim import Adam
 
-from ai_traineree import DEVICE
-from ai_traineree.agents import AgentBase
-from ai_traineree.agents.agent_utils import hard_update, soft_update
-from ai_traineree.buffers.nstep import NStepBuffer
-from ai_traineree.buffers.per import PERBuffer
-from ai_traineree.loggers import DataLogger
-from ai_traineree.networks.bodies import ActorBody, CriticBody
-from ai_traineree.networks.heads import CategoricalNet
-from ai_traineree.policies import MultivariateGaussianPolicy, MultivariateGaussianPolicySimple
-from ai_traineree.types.dataspace import DataSpace
-from ai_traineree.types.experience import Experience
-from ai_traineree.utils import to_numbers_seq, to_tensor
+from aitraineree import DEVICE
+from aitraineree.agents import AgentBase
+from aitraineree.agents.agent_utils import hard_update, soft_update
+from aitraineree.buffers.nstep import NStepBuffer
+from aitraineree.buffers.per import PERBuffer
+from aitraineree.loggers import DataLogger
+from aitraineree.networks.bodies import ActorBody, CriticBody
+from aitraineree.networks.heads import CategoricalNet
+from aitraineree.policies import MultivariateGaussianPolicy, MultivariateGaussianPolicySimple
+from aitraineree.types.dataspace import DataSpace
+from aitraineree.types.experience import Experience
+from aitraineree.utils import to_numbers_seq, to_tensor
 
 
 class D3PGAgent(AgentBase):
24 changes: 12 additions & 12 deletions aitraineree/agents/d4pg.py
@@ -11,18 +11,18 @@
 import torch.nn.functional as F
 from torch.optim import Adam
 
-from ai_traineree import DEVICE
-from ai_traineree.agents import AgentBase
-from ai_traineree.agents.agent_utils import hard_update, soft_update
-from ai_traineree.buffers.nstep import NStepBuffer
-from ai_traineree.buffers.per import PERBuffer
-from ai_traineree.loggers import DataLogger
-from ai_traineree.networks.bodies import ActorBody, CriticBody
-from ai_traineree.networks.heads import CategoricalNet
-from ai_traineree.policies import MultivariateGaussianPolicy, MultivariateGaussianPolicySimple
-from ai_traineree.types.dataspace import DataSpace
-from ai_traineree.types.experience import Experience
-from ai_traineree.utils import to_numbers_seq, to_tensor
+from aitraineree import DEVICE
+from aitraineree.agents import AgentBase
+from aitraineree.agents.agent_utils import hard_update, soft_update
+from aitraineree.buffers.nstep import NStepBuffer
+from aitraineree.buffers.per import PERBuffer
+from aitraineree.loggers import DataLogger
+from aitraineree.networks.bodies import ActorBody, CriticBody
+from aitraineree.networks.heads import CategoricalNet
+from aitraineree.policies import MultivariateGaussianPolicy, MultivariateGaussianPolicySimple
+from aitraineree.types.dataspace import DataSpace
+from aitraineree.types.experience import Experience
+from aitraineree.utils import to_numbers_seq, to_tensor
 
 
 class D4PGAgent(AgentBase):
24 changes: 12 additions & 12 deletions aitraineree/agents/ddpg.py
@@ -8,18 +8,18 @@
 from torch.nn.functional import mse_loss
 from torch.optim import Adam
 
-from ai_traineree import DEVICE
-from ai_traineree.agents import AgentBase
-from ai_traineree.agents.agent_utils import hard_update, soft_update
-from ai_traineree.buffers.buffer_factory import BufferFactory
-from ai_traineree.buffers.replay import ReplayBuffer
-from ai_traineree.loggers import DataLogger
-from ai_traineree.networks.bodies import CriticBody, FcNet
-from ai_traineree.noise import GaussianNoise
-from ai_traineree.types import AgentState, BufferState, NetworkState
-from ai_traineree.types.dataspace import DataSpace
-from ai_traineree.types.experience import Experience
-from ai_traineree.utils import to_numbers_seq, to_tensor
+from aitraineree import DEVICE
+from aitraineree.agents import AgentBase
+from aitraineree.agents.agent_utils import hard_update, soft_update
+from aitraineree.buffers.buffer_factory import BufferFactory
+from aitraineree.buffers.replay import ReplayBuffer
+from aitraineree.loggers import DataLogger
+from aitraineree.networks.bodies import CriticBody, FcNet
+from aitraineree.noise import GaussianNoise
+from aitraineree.types import AgentState, BufferState, NetworkState
+from aitraineree.types.dataspace import DataSpace
+from aitraineree.types.experience import Experience
+from aitraineree.utils import to_numbers_seq, to_tensor
 
 
 class DDPGAgent(AgentBase):
24 changes: 12 additions & 12 deletions aitraineree/agents/dqn.py
@@ -6,18 +6,18 @@
 import torch.nn.functional as F
 import torch.optim as optim
 
-from ai_traineree import DEVICE
-from ai_traineree.agents import AgentBase
-from ai_traineree.agents.agent_utils import soft_update
-from ai_traineree.buffers.buffer_factory import BufferFactory
-from ai_traineree.buffers.nstep import NStepBuffer
-from ai_traineree.buffers.per import PERBuffer
-from ai_traineree.loggers import DataLogger
-from ai_traineree.networks import NetworkType, NetworkTypeClass
-from ai_traineree.networks.heads import DuelingNet
-from ai_traineree.types import AgentState, BufferState, DataSpace, NetworkState
-from ai_traineree.types.experience import Experience
-from ai_traineree.utils import to_numbers_seq, to_tensor
+from aitraineree import DEVICE
+from aitraineree.agents import AgentBase
+from aitraineree.agents.agent_utils import soft_update
+from aitraineree.buffers.buffer_factory import BufferFactory
+from aitraineree.buffers.nstep import NStepBuffer
+from aitraineree.buffers.per import PERBuffer
+from aitraineree.loggers import DataLogger
+from aitraineree.networks import NetworkType, NetworkTypeClass
+from aitraineree.networks.heads import DuelingNet
+from aitraineree.types import AgentState, BufferState, DataSpace, NetworkState
+from aitraineree.types.experience import Experience
+from aitraineree.utils import to_numbers_seq, to_tensor
 
 
 class DQNAgent(AgentBase):
6 changes: 3 additions & 3 deletions aitraineree/agents/dummy.py
@@ -1,8 +1,8 @@
 import numpy as np
 
-from ai_traineree.agents import AgentBase
-from ai_traineree.types import DataSpace
-from ai_traineree.types.experience import Experience
+from aitraineree.agents import AgentBase
+from aitraineree.types import DataSpace
+from aitraineree.types.experience import Experience
 
 
 class DummyAgent(AgentBase):
24 changes: 12 additions & 12 deletions aitraineree/agents/ppo.py
@@ -7,18 +7,18 @@
 import torch.nn.functional as F
 import torch.optim as optim
 
-from ai_traineree import DEVICE
-from ai_traineree.agents import AgentBase
-from ai_traineree.agents.agent_utils import compute_gae, normalize, revert_norm_returns
-from ai_traineree.buffers.buffer_factory import BufferFactory
-from ai_traineree.buffers.rollout import RolloutBuffer
-from ai_traineree.loggers import DataLogger
-from ai_traineree.networks.bodies import ActorBody
-from ai_traineree.policies import MultivariateGaussianPolicy, MultivariateGaussianPolicySimple
-from ai_traineree.types import ActionType, AgentState, BufferState, NetworkState
-from ai_traineree.types.dataspace import DataSpace
-from ai_traineree.types.experience import Experience
-from ai_traineree.utils import to_numbers_seq, to_tensor
+from aitraineree import DEVICE
+from aitraineree.agents import AgentBase
+from aitraineree.agents.agent_utils import compute_gae, normalize, revert_norm_returns
+from aitraineree.buffers.buffer_factory import BufferFactory
+from aitraineree.buffers.rollout import RolloutBuffer
+from aitraineree.loggers import DataLogger
+from aitraineree.networks.bodies import ActorBody
+from aitraineree.policies import MultivariateGaussianPolicy, MultivariateGaussianPolicySimple
+from aitraineree.types import ActionType, AgentState, BufferState, NetworkState
+from aitraineree.types.dataspace import DataSpace
+from aitraineree.types.experience import Experience
+from aitraineree.utils import to_numbers_seq, to_tensor
 
 
 class PPOAgent(AgentBase):
24 changes: 12 additions & 12 deletions aitraineree/agents/rainbow.py
@@ -5,18 +5,18 @@
 import torch.nn as nn
 import torch.optim as optim
 
-from ai_traineree import DEVICE
-from ai_traineree.agents import AgentBase
-from ai_traineree.agents.agent_utils import soft_update
-from ai_traineree.buffers.buffer_factory import BufferFactory
-from ai_traineree.buffers.nstep import NStepBuffer
-from ai_traineree.buffers.per import PERBuffer
-from ai_traineree.loggers import DataLogger
-from ai_traineree.networks.heads import RainbowNet
-from ai_traineree.types import AgentState, BufferState, NetworkState
-from ai_traineree.types.dataspace import DataSpace
-from ai_traineree.types.experience import Experience
-from ai_traineree.utils import to_numbers_seq, to_tensor
+from aitraineree import DEVICE
+from aitraineree.agents import AgentBase
+from aitraineree.agents.agent_utils import soft_update
+from aitraineree.buffers.buffer_factory import BufferFactory
+from aitraineree.buffers.nstep import NStepBuffer
+from aitraineree.buffers.per import PERBuffer
+from aitraineree.loggers import DataLogger
+from aitraineree.networks.heads import RainbowNet
+from aitraineree.types import AgentState, BufferState, NetworkState
+from aitraineree.types.dataspace import DataSpace
+from aitraineree.types.experience import Experience
+from aitraineree.utils import to_numbers_seq, to_tensor
 
 
 class RainbowAgent(AgentBase):
26 changes: 13 additions & 13 deletions aitraineree/agents/sac.py
@@ -7,19 +7,19 @@
 import torch.nn as nn
 from torch import Tensor, optim
 
-from ai_traineree import DEVICE
-from ai_traineree.agents import AgentBase
-from ai_traineree.agents.agent_utils import hard_update, soft_update
-from ai_traineree.buffers import PERBuffer
-from ai_traineree.buffers.buffer_factory import BufferFactory
-from ai_traineree.loggers import DataLogger
-from ai_traineree.networks.bodies import ActorBody, CriticBody
-from ai_traineree.networks.heads import DoubleCritic
-from ai_traineree.policies import GaussianPolicy, MultivariateGaussianPolicySimple
-from ai_traineree.types.dataspace import DataSpace
-from ai_traineree.types.experience import Experience
-from ai_traineree.types.state import AgentState, BufferState, NetworkState
-from ai_traineree.utils import to_numbers_seq, to_tensor
+from aitraineree import DEVICE
+from aitraineree.agents import AgentBase
+from aitraineree.agents.agent_utils import hard_update, soft_update
+from aitraineree.buffers import PERBuffer
+from aitraineree.buffers.buffer_factory import BufferFactory
+from aitraineree.loggers import DataLogger
+from aitraineree.networks.bodies import ActorBody, CriticBody
+from aitraineree.networks.heads import DoubleCritic
+from aitraineree.policies import GaussianPolicy, MultivariateGaussianPolicySimple
+from aitraineree.types.dataspace import DataSpace
+from aitraineree.types.experience import Experience
+from aitraineree.types.state import AgentState, BufferState, NetworkState
+from aitraineree.utils import to_numbers_seq, to_tensor
 
 
 class SACAgent(AgentBase):
22 changes: 11 additions & 11 deletions aitraineree/agents/td3.py
@@ -5,17 +5,17 @@
 from torch.nn.functional import mse_loss
 from torch.optim import Adam
 
-from ai_traineree import DEVICE
-from ai_traineree.agents import AgentBase
-from ai_traineree.agents.agent_utils import hard_update, soft_update
-from ai_traineree.buffers.replay import ReplayBuffer
-from ai_traineree.loggers import DataLogger
-from ai_traineree.networks.bodies import ActorBody, CriticBody
-from ai_traineree.networks.heads import DoubleCritic
-from ai_traineree.noise import GaussianNoise
-from ai_traineree.types.dataspace import DataSpace
-from ai_traineree.types.experience import Experience
-from ai_traineree.utils import to_numbers_seq, to_tensor
+from aitraineree import DEVICE
+from aitraineree.agents import AgentBase
+from aitraineree.agents.agent_utils import hard_update, soft_update
+from aitraineree.buffers.replay import ReplayBuffer
+from aitraineree.loggers import DataLogger
+from aitraineree.networks.bodies import ActorBody, CriticBody
+from aitraineree.networks.heads import DoubleCritic
+from aitraineree.noise import GaussianNoise
+from aitraineree.types.dataspace import DataSpace
+from aitraineree.types.experience import Experience
+from aitraineree.utils import to_numbers_seq, to_tensor
 
 
 class TD3Agent(AgentBase):
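With 95 files touched by a pure rename, a repository-wide check that no stale references remain is worth running. A sketch (the check itself is not part of this commit; the helper name is hypothetical):

```python
# Post-rename sanity check, not from this repository: fail if any tracked
# file still mentions the old package name.
import subprocess
import sys

def check_no_stale_references(old_name: str = "ai_traineree") -> int:
    result = subprocess.run(
        ["git", "grep", "-l", old_name],
        capture_output=True,
        text=True,
    )
    if result.stdout.strip():
        print(f"Stale references to {old_name!r} in:\n{result.stdout}")
        return 1
    return 0

if __name__ == "__main__":
    sys.exit(check_no_stale_references())
```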
[Diff truncated: the remaining 81 of the 95 changed files are not shown on this page.]
