diff --git a/ai_traineree/__init__.py b/ai_traineree/__init__.py index ad85d5a..a235332 100644 --- a/ai_traineree/__init__.py +++ b/ai_traineree/__init__.py @@ -1,6 +1,6 @@ import torch -__version__ = "0.6.0" +__version__ = "0.7.1" try: diff --git a/ai_traineree/agents/d3pg.py b/ai_traineree/agents/d3pg.py index ffcd842..841cc3b 100644 --- a/ai_traineree/agents/d3pg.py +++ b/ai_traineree/agents/d3pg.py @@ -1,6 +1,5 @@ import itertools from functools import cached_property -from typing import Dict import torch import torch.nn as nn @@ -170,7 +169,7 @@ def __init__(self, obs_space: DataSpace, action_space: DataSpace, **kwargs): self._metric_batch_value_dist = torch.zeros(self.batch_size) @property - def loss(self) -> Dict[str, float]: + def loss(self) -> dict[str, float]: return {"actor": self._loss_actor, "critic": self._loss_critic} @loss.setter @@ -322,7 +321,7 @@ def learn(self, experiences): soft_update(self.target_actor, self.actor, self.tau) soft_update(self.target_critic, self.critic, self.tau) - def state_dict(self) -> Dict[str, dict]: + def state_dict(self) -> dict[str, dict]: """Describes agent's networks. Returns: diff --git a/ai_traineree/agents/d4pg.py b/ai_traineree/agents/d4pg.py index bcad69c..255b581 100644 --- a/ai_traineree/agents/d4pg.py +++ b/ai_traineree/agents/d4pg.py @@ -5,7 +5,6 @@ import itertools from functools import cached_property -from typing import Dict import torch import torch.nn as nn @@ -175,7 +174,7 @@ def __init__(self, obs_space: DataSpace, action_space: DataSpace, **kwargs): self._loss_critic = float("nan") @property - def loss(self) -> Dict[str, float]: + def loss(self) -> dict[str, float]: return {"actor": self._loss_actor, "critic": self._loss_critic} @cached_property @@ -337,7 +336,7 @@ def learn(self, experiences): soft_update(self.target_actor, self.actor, self.tau) soft_update(self.target_critic, self.critic, self.tau) - def state_dict(self) -> Dict[str, dict]: + def state_dict(self) -> dict[str, dict]: """Describes agent's networks. Returns: diff --git a/ai_traineree/agents/ddpg.py b/ai_traineree/agents/ddpg.py index 19b2eae..ddccc6b 100644 --- a/ai_traineree/agents/ddpg.py +++ b/ai_traineree/agents/ddpg.py @@ -1,7 +1,6 @@ import copy import operator from functools import cached_property, reduce -from typing import Dict, Optional import torch import torch.nn as nn @@ -125,7 +124,7 @@ def reset_agent(self) -> None: hard_update(self.target_critic, self.critic) @property - def loss(self) -> Dict[str, float]: + def loss(self) -> dict[str, float]: return {"actor": self._loss_actor, "critic": self._loss_critic} @loss.setter @@ -201,7 +200,6 @@ def step(self, experience: Experience) -> None: soft_update(self.target_critic, self.critic, self.tau) def compute_value_loss(self, states, actions, next_states, rewards, dones): - Q_expected = self.critic(states, actions) with torch.no_grad(): @@ -254,7 +252,7 @@ def learn(self, experiences) -> None: self.critic.requires_grad_ = True - def state_dict(self) -> Dict[str, dict]: + def state_dict(self) -> dict[str, dict]: """Describes agent's networks. 
Returns: @@ -328,7 +326,7 @@ def save_state(self, path: str) -> None: agent_state = self.get_state() torch.save(agent_state, path) - def load_state(self, *, path: Optional[str] = None, agent_state: Optional[dict] = None): + def load_state(self, *, path: str | None = None, agent_state: dict | None = None): if path is None and agent_state: raise ValueError("Either `path` or `agent_state` must be provided to load agent's state.") if path is not None and agent_state is None: diff --git a/ai_traineree/agents/dqn.py b/ai_traineree/agents/dqn.py index 1dcff1e..aab7580 100644 --- a/ai_traineree/agents/dqn.py +++ b/ai_traineree/agents/dqn.py @@ -1,5 +1,5 @@ import copy -from typing import Callable, Dict, Optional, Type +from typing import Callable, Type import torch import torch.nn as nn @@ -40,9 +40,9 @@ def __init__( action_space: DataSpace, network_fn: Callable[[], NetworkType] = None, network_class: Type[NetworkTypeClass] = None, - state_transform: Optional[Callable] = None, - reward_transform: Optional[Callable] = None, - **kwargs + state_transform: Callable | None = None, + reward_transform: Callable | None = None, + **kwargs, ): """Initiates the DQN agent. @@ -113,7 +113,7 @@ def __init__( self._loss: float = float("nan") @property - def loss(self) -> Dict[str, float]: + def loss(self) -> dict[str, float]: return {"loss": self._loss} @loss.setter @@ -202,7 +202,7 @@ def act(self, experience: Experience, eps: float = 0.0) -> Experience: action = int(torch.argmax(action_values.cpu())) return experience.update(action=action) - def learn(self, experiences: Dict[str, list]) -> None: + def learn(self, experiences: dict[str, list]) -> None: """Updates agent's networks based on provided experience. Parameters: @@ -237,7 +237,7 @@ def learn(self, experiences: Dict[str, list]) -> None: assert any(~torch.isnan(error)) self.buffer.priority_update(experiences["index"], error.abs()) - def state_dict(self) -> Dict[str, dict]: + def state_dict(self) -> dict[str, dict]: """Describes agent's networks. Returns: @@ -301,7 +301,7 @@ def save_state(self, path: str): agent_state = self.get_state() torch.save(agent_state, path) - def load_state(self, *, path: Optional[str] = None, state: Optional[AgentState] = None) -> None: + def load_state(self, *, path: str | None = None, state: AgentState | None = None) -> None: """Loads state from a file under provided path. Parameters: diff --git a/ai_traineree/agents/ppo.py b/ai_traineree/agents/ppo.py index 43360cb..e77d668 100644 --- a/ai_traineree/agents/ppo.py +++ b/ai_traineree/agents/ppo.py @@ -1,7 +1,6 @@ import copy import itertools import logging -from typing import Dict, List import torch import torch.nn as nn @@ -124,10 +123,10 @@ def __init__(self, obs_space: DataSpace, action_space: DataSpace, **kwargs): self.critic_opt = optim.Adam(self.critic_params, lr=self.critic_lr) self._loss_actor = float("nan") self._loss_critic = float("nan") - self._metrics: Dict[str, float] = {} + self._metrics: dict[str, float] = {} @property - def loss(self) -> Dict[str, float]: + def loss(self) -> dict[str, float]: return {"actor": self._loss_actor, "critic": self._loss_critic} @loss.setter @@ -163,7 +162,7 @@ def act(self, experience: Experience, noise: float = 0.0) -> Experience: Experience updated with action taken. 
""" - actions: List[ActionType] = [] + actions: list[ActionType] = [] logprobs = [] values = [] t_obs = to_tensor(experience.obs).view((self.num_workers,) + self.obs_space.shape).float().to(self.device) diff --git a/ai_traineree/agents/rainbow.py b/ai_traineree/agents/rainbow.py index 41a5d94..5ed8f6a 100644 --- a/ai_traineree/agents/rainbow.py +++ b/ai_traineree/agents/rainbow.py @@ -1,5 +1,5 @@ import copy -from typing import Callable, Dict, List, Optional +from typing import Callable import torch import torch.nn as nn @@ -45,8 +45,8 @@ def __init__( self, obs_space: DataSpace, action_space: DataSpace, - state_transform: Optional[Callable] = None, - reward_transform: Optional[Callable] = None, + state_transform: Callable | None = None, + reward_transform: Callable | None = None, **kwargs, ): """ @@ -201,7 +201,7 @@ def act(self, experience: Experience, eps: float = 0.0) -> Experience: action = int(q_values.argmax(-1)) # Action maximizes state-action value Q(s, a) return experience.update(action=action) - def learn(self, experiences: Dict[str, List]) -> None: + def learn(self, experiences: dict[str, list]) -> None: """ Parameters: experiences: Contains all experiences for the agent. Typically sampled from the memory buffer. @@ -256,7 +256,7 @@ def learn(self, experiences: Dict[str, List]) -> None: # Update networks - sync local & target soft_update(self.target_net, self.net, self.tau) - def state_dict(self) -> Dict[str, dict]: + def state_dict(self) -> dict[str, dict]: """Returns agent's state dictionary. Returns: diff --git a/ai_traineree/agents/sac.py b/ai_traineree/agents/sac.py index 9826e4e..55bc5ae 100644 --- a/ai_traineree/agents/sac.py +++ b/ai_traineree/agents/sac.py @@ -1,7 +1,6 @@ import copy import itertools from functools import cached_property -from typing import Dict, Tuple, Union import numpy as np import torch @@ -162,14 +161,14 @@ def __init__(self, obs_space: DataSpace, action_space: DataSpace, **kwargs): self._loss_actor = float("nan") self._loss_critic = float("nan") - self._metrics: Dict[str, Union[float, Dict[str, float]]] = {} + self._metrics: dict[str, float | dict[str, float]] = {} @property def alpha(self): return self.log_alpha.exp() @property - def loss(self) -> Dict[str, float]: + def loss(self) -> dict[str, float]: return {"actor": self._loss_actor, "critic": self._loss_critic} @loss.setter @@ -204,7 +203,7 @@ def action_min(self): def action_max(self): return to_tensor(self.action_space.high) - def state_dict(self) -> Dict[str, dict]: + def state_dict(self) -> dict[str, dict]: """ Returns network's weights in order: Actor, TargetActor, Critic, TargetCritic @@ -270,7 +269,7 @@ def step(self, experience: Experience) -> None: for _ in range(self.number_updates): self.learn(self.buffer.sample()) - def compute_value_loss(self, states, actions, rewards, next_states, dones) -> Tuple[Tensor, Tensor]: + def compute_value_loss(self, states, actions, rewards, next_states, dones) -> tuple[Tensor, Tensor]: Q1_expected, Q2_expected = self.double_critic(states, actions) with torch.no_grad(): diff --git a/ai_traineree/agents/td3.py b/ai_traineree/agents/td3.py index 3e75328..928c336 100644 --- a/ai_traineree/agents/td3.py +++ b/ai_traineree/agents/td3.py @@ -1,5 +1,4 @@ from functools import cached_property -from typing import Dict import torch import torch.nn as nn @@ -35,7 +34,7 @@ def __init__( action_space: DataSpace, noise_scale: float = 0.5, noise_sigma: float = 1.0, - **kwargs + **kwargs, ): """ Parameters: @@ -129,7 +128,7 @@ def __init__( self._loss_critic = 
float("nan") @property - def loss(self) -> Dict[str, float]: + def loss(self) -> dict[str, float]: return {"actor": self._loss_actor, "critic": self._loss_critic} @loss.setter @@ -262,7 +261,7 @@ def _update_policy(self, states): self.critic.requires_grad_ = True - def state_dict(self) -> Dict[str, dict]: + def state_dict(self) -> dict[str, dict]: """Describes agent's networks. Returns: diff --git a/ai_traineree/buffers/__init__.py b/ai_traineree/buffers/__init__.py index edc3490..5656fd3 100644 --- a/ai_traineree/buffers/__init__.py +++ b/ai_traineree/buffers/__init__.py @@ -1,6 +1,6 @@ import abc from collections import defaultdict -from typing import Dict, List, Optional, Sequence, Union +from typing import Sequence import numpy as np import torch @@ -34,15 +34,15 @@ def add(self, **kwargs): """Add samples to the buffer.""" raise NotImplementedError("You shouldn't see this. Look away. Or fix it.") - def sample(self, *args, **kwargs) -> Optional[List[Experience]]: + def sample(self, *args, **kwargs) -> list[Experience | None]: """Sample buffer for a set of experience.""" raise NotImplementedError("You shouldn't see this. Look away. Or fix it.") - def dump_buffer(self, serialize: bool = False) -> List[Dict]: + def dump_buffer(self, serialize: bool = False) -> list[dict]: """Return the whole buffer, e.g. for storing.""" raise NotImplementedError("You shouldn't see this. Look away. Or fix it.") - def load_buffer(self, buffer: List[Experience]) -> None: + def load_buffer(self, buffer: list[Experience]) -> None: """Loads provided data into the buffer.""" raise NotImplementedError("You shouldn't see this. Look away. Or fix it.") @@ -67,7 +67,7 @@ def __len__(self) -> int: return len(self.buffer) @staticmethod - def _hash_element(el) -> Union[int, str]: + def _hash_element(el) -> int | str: if isinstance(el, np.ndarray): return hash(el.data.tobytes()) elif isinstance(el, torch.Tensor): @@ -75,14 +75,14 @@ def _hash_element(el) -> Union[int, str]: else: return str(el) - def add(self, el) -> Union[int, str]: + def add(self, el) -> int | str: idx = self._hash_element(el) self.counter[idx] += 1 if self.counter[idx] < 2: self.buffer[idx] = el return idx - def get(self, idx: Union[int, str]): + def get(self, idx: int | str): return self.buffer[idx] def remove(self, idx: str): diff --git a/ai_traineree/buffers/per.py b/ai_traineree/buffers/per.py index ce97ef7..cdee91a 100644 --- a/ai_traineree/buffers/per.py +++ b/ai_traineree/buffers/per.py @@ -2,7 +2,7 @@ import math import random from collections import defaultdict -from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple +from typing import Any, Iterator, Sequence import numpy @@ -84,7 +84,7 @@ def add(self, *, priority: float = 0, **kwargs): self._states.remove(old_data["obs_idx"]) self._states.remove(old_data["next_obs_idx"]) - def _sample_list(self, beta: float = 1, **kwargs) -> List[Experience]: + def _sample_list(self, beta: float = 1, **kwargs) -> list[Experience]: """The method return samples randomly without duplicates""" if len(self.tree) < self.batch_size: return [] @@ -107,7 +107,7 @@ def _sample_list(self, beta: float = 1, **kwargs) -> List[Experience]: self.priority_update(indices, priorities) # Revert priorities weights = weights / max(weights) - for (experience, weight, index) in zip(samples, weights, indices): + for experience, weight, index in zip(samples, weights, indices): experience.weight = weight experience.index = index if self._states_mng: @@ -117,7 +117,7 @@ def _sample_list(self, beta: float = 1, **kwargs) 
-> List[Experience]: return experiences - def sample(self, beta: float = 0.5) -> Optional[Dict[str, List]]: + def sample(self, beta: float = 0.5) -> dict[str, list] | None: all_experiences = defaultdict(lambda: []) sampled_exp = self._sample_list(beta=beta) if len(sampled_exp) == 0: @@ -132,7 +132,7 @@ def sample(self, beta: float = 0.5) -> Optional[Dict[str, List]]: all_experiences[key].append(value) return all_experiences - def priority_update(self, indices: Sequence[int], priorities: List) -> None: + def priority_update(self, indices: Sequence[int], priorities: list) -> None: """Updates prioprities for elements on provided indices.""" for i, p in zip(indices, priorities): self.tree.weight_update(i, math.pow(p, self.alpha)) @@ -155,12 +155,12 @@ def from_state(state: BufferState): buffer.load_buffer(state.data) return buffer - def dump_buffer(self, serialize: bool = False) -> Iterator[Dict[str, List]]: + def dump_buffer(self, serialize: bool = False) -> Iterator[dict[str, list]]: for exp in self.tree.data[: len(self.tree)]: # yield Experience(**exp).get_dict(serialize=serialize) yield exp.get_dict(serialize=serialize) - def load_buffer(self, buffer: List[Experience]): + def load_buffer(self, buffer: list[Experience]): for experience in buffer: self.add(**experience.data) @@ -182,7 +182,7 @@ def __init__(self, leafs_num: int): self.leaf_offset = 2 ** (self.tree_height - 1) - 1 self.tree_size = 2**self.tree_height - 1 self.tree = numpy.zeros(self.tree_size) - self.data: List[Optional[Dict]] = [None] * self.leafs_num + self.data: list[dict | None] = [None] * self.leafs_num self.size = 0 self.cursor = 0 @@ -216,12 +216,12 @@ def _tree_update(self, tindex, diff): self.tree[tindex] += diff tindex = (tindex - 1) // 2 - def find(self, weight) -> Tuple[Any, float, int]: + def find(self, weight) -> tuple[Any, float, int]: """Returns (data, weight, index)""" assert 0 <= weight <= 1, "Expecting weight to be sampling weight [0, 1]" return self._find(weight * self.tree[0], 0) - def _find(self, weight, index) -> Tuple[Any, float, int]: + def _find(self, weight, index) -> tuple[Any, float, int]: """Recursively finds a data by the weight. Returns: diff --git a/ai_traineree/buffers/replay.py b/ai_traineree/buffers/replay.py index befd3a5..742df4a 100644 --- a/ai_traineree/buffers/replay.py +++ b/ai_traineree/buffers/replay.py @@ -1,5 +1,5 @@ import random -from typing import Dict, Iterator, List, Optional, Sequence +from typing import Iterator, Sequence from ai_traineree.buffers import ReferenceBuffer from ai_traineree.types.state import BufferState @@ -8,7 +8,6 @@ class ReplayBuffer(BufferBase): - type = "Replay" keys = ["states", "actions", "rewards", "next_states", "dones"] @@ -25,7 +24,7 @@ def __init__(self, batch_size: int, buffer_size=int(1e6), **kwargs): self.batch_size = batch_size self.buffer_size = buffer_size self.indices = range(batch_size) - self.data: List[Experience] = [] + self.data: list[Experience] = [] self._states_mng = kwargs.get("compress_state", False) self._states = ReferenceBuffer(buffer_size + 20) @@ -57,7 +56,7 @@ def add(self, **kwargs): self._states.remove(drop_exp.state_idx) self._states.remove(drop_exp.next_state_idx) - def sample(self, keys: Optional[Sequence[str]] = None) -> Dict[str, List]: + def sample(self, keys: Sequence[str] | None = None) -> dict[str, list]: """ Parameters: keys: A list of keys which limit the return. @@ -66,7 +65,7 @@ def sample(self, keys: Optional[Sequence[str]] = None) -> Dict[str, List]: Returns: Returns all values for asked keys. 
""" - sampled_exp: List[Experience] = self._rng.sample(self.data, self.batch_size) + sampled_exp: list[Experience] = self._rng.sample(self.data, self.batch_size) keys = keys if keys is not None else list(self.data[0].__dict__.keys()) all_experiences = {k: [] for k in keys} for data in sampled_exp: @@ -88,11 +87,11 @@ def from_state(state: BufferState): buffer.load_buffer(state.data) return buffer - def dump_buffer(self, serialize: bool = False) -> Iterator[Dict[str, List]]: + def dump_buffer(self, serialize: bool = False) -> Iterator[dict[str, list]]: for data in self.data: yield data.get_dict(serialize=serialize) - def load_buffer(self, buffer: List[Experience]): + def load_buffer(self, buffer: list[Experience]): for experience in buffer: # self.add(**experience) self.add(**experience.data) diff --git a/ai_traineree/buffers/rollout.py b/ai_traineree/buffers/rollout.py index fe95569..693ad54 100644 --- a/ai_traineree/buffers/rollout.py +++ b/ai_traineree/buffers/rollout.py @@ -1,5 +1,5 @@ from collections import defaultdict, deque -from typing import Deque, Dict, Iterator, List, Optional +from typing import Deque, Iterator from ai_traineree.buffers import ReferenceBuffer from ai_traineree.types.state import BufferState @@ -8,7 +8,6 @@ class RolloutBuffer(BufferBase): - type = "Rollout" def __init__(self, batch_size: int, buffer_size=int(1e6), **kwargs): @@ -56,7 +55,7 @@ def add(self, **kwargs): self._states.remove(drop_exp.state_idx) self._states.remove(drop_exp.next_state_idx) - def sample(self, batch_size: Optional[int] = None) -> Iterator[Dict[str, list]]: + def sample(self, batch_size: int | None = None) -> Iterator[dict[str, list]]: """ Samples the whole buffer. Iterates all gathered data. Note that sampling doesn't clear the buffer. @@ -94,11 +93,11 @@ def from_state(state: BufferState): buffer.load_buffer(state.data) return buffer - def dump_buffer(self, serialize: bool = False) -> Iterator[Dict[str, List]]: + def dump_buffer(self, serialize: bool = False) -> Iterator[dict[str, list]]: for data in self.data: yield data.get_dict(serialize=serialize) - def load_buffer(self, buffer: List[Experience]): + def load_buffer(self, buffer: list[Experience]): for experience in buffer: self.add(**experience.data) diff --git a/ai_traineree/loggers/data_logger.py b/ai_traineree/loggers/data_logger.py index 3e04095..67adafb 100644 --- a/ai_traineree/loggers/data_logger.py +++ b/ai_traineree/loggers/data_logger.py @@ -1,5 +1,4 @@ import abc -from typing import Dict class DataLogger(abc.ABC): @@ -7,24 +6,19 @@ def __dell__(self): self.close() @abc.abstractmethod - def close(self) -> None: - ... + def close(self) -> None: ... @abc.abstractmethod - def set_hparams(self, *, hparams: Dict) -> None: - ... + def set_hparams(self, *, hparams: dict) -> None: ... @abc.abstractmethod - def log_value(self, name, value, step) -> None: - ... + def log_value(self, name, value, step) -> None: ... @abc.abstractmethod - def log_values_dict(self, name, values, step) -> None: - ... + def log_values_dict(self, name, values, step) -> None: ... @abc.abstractmethod - def add_histogram(self, *args, **kwargs) -> None: - ... + def add_histogram(self, *args, **kwargs) -> None: ... 
@abc.abstractmethod def create_histogram(self, name, values, step) -> None: diff --git a/ai_traineree/loggers/file_logger.py b/ai_traineree/loggers/file_logger.py index 430a202..a07a115 100644 --- a/ai_traineree/loggers/file_logger.py +++ b/ai_traineree/loggers/file_logger.py @@ -1,7 +1,6 @@ import json import time from os.path import getsize, splitext -from typing import Dict from .data_logger import DataLogger @@ -59,7 +58,7 @@ def _check_and_trim(self): all_rows = f.readlines() f.writelines(all_rows[len(all_rows) // 2 :]) - def set_hparams(self, *, hparams: Dict): + def set_hparams(self, *, hparams: dict): filepath = splitext(self.filepath)[0] with open(filepath + "_hparams.json", "w") as f: @@ -70,7 +69,7 @@ def log_value(self, name: str, value, step: int) -> None: with open(self.filepath, "a") as f: f.write(f"{self._timestamp()},step,{step},{name},{value}\n") - def log_values_dict(self, name: str, values: Dict[str, float], step: int) -> None: + def log_values_dict(self, name: str, values: dict[str, float], step: int) -> None: self._check_and_trim() log = ",".join([f"{name}_{key},{value}" for (key, value) in values.items()]) with open(self.filepath, "a") as f: diff --git a/ai_traineree/loggers/neptune_logger.py b/ai_traineree/loggers/neptune_logger.py index d8ed049..5d08ed5 100644 --- a/ai_traineree/loggers/neptune_logger.py +++ b/ai_traineree/loggers/neptune_logger.py @@ -1,5 +1,4 @@ import logging -from typing import Dict from .data_logger import DataLogger @@ -34,8 +33,7 @@ def __str__(self) -> str: def close(self): self.experiment.stop() - def set_hparams(self, *, hparams: Dict): - ... + def set_hparams(self, *, hparams: dict): ... def log_value(self, name: str, value, step: int) -> None: self.experiment.log_metric(name, x=step, y=value) diff --git a/ai_traineree/loggers/tensorboard_logger.py b/ai_traineree/loggers/tensorboard_logger.py index e3e8ebd..e5d10ad 100644 --- a/ai_traineree/loggers/tensorboard_logger.py +++ b/ai_traineree/loggers/tensorboard_logger.py @@ -1,5 +1,4 @@ import logging -from typing import Dict from .data_logger import DataLogger @@ -44,13 +43,13 @@ def __str__(self): def close(self): self.writer.close() - def set_hparams(self, *, hparams: Dict): + def set_hparams(self, *, hparams: dict): self.writer.add_hparams(hparam_dict=hparams, metric_dict={}) def log_value(self, name: str, value, step: int) -> None: self.writer.add_scalar(name, value, step) - def log_values_dict(self, name: str, values: Dict[str, float], step: int) -> None: + def log_values_dict(self, name: str, values: dict[str, float], step: int) -> None: self.writer.add_scalars(name, values, step) def add_histogram(self, *args, **kwargs): diff --git a/ai_traineree/multi_agents/independent.py b/ai_traineree/multi_agents/independent.py index 1fb2941..da5aa13 100644 --- a/ai_traineree/multi_agents/independent.py +++ b/ai_traineree/multi_agents/independent.py @@ -1,5 +1,3 @@ -from typing import Dict, List, Optional - import torch from ai_traineree.loggers import DataLogger @@ -8,10 +6,9 @@ class IndependentAgents(MultiAgentType): - model = "IMA" - def __init__(self, agents: List[AgentType], agent_names: Optional[List[str]] = None, **kwargs): + def __init__(self, agents: list[AgentType], agent_names: list[str] | None = None, **kwargs): """Independent agents. An abstraction to manage multiple agents. It assumes no interaction between agents. 
@@ -29,12 +26,12 @@ def __init__(self, agents: List[AgentType], agent_names: Optional[List[str]] = N assert len(agent_names) == len(agents), "Expecting `agents` and `agent_names` to have the same lengths" self.num_agents = len(agents) - self.agents: Dict[str, AgentType] = {agent_name: agent for (agent_name, agent) in zip(agent_names, agents)} + self.agents: dict[str, AgentType] = {agent_name: agent for (agent_name, agent) in zip(agent_names, agents)} self.reset() @property - def loss(self) -> Dict[str, float]: + def loss(self) -> dict[str, float]: out = {} for agent_name, agent in self.agents.items(): for loss_name, loss_value in agent.loss.items(): @@ -94,5 +91,5 @@ def load_state(self, path: str): agent._config = agent_state.get("config", {}) agent.__dict__.update(**agent._config) - def state_dict(self) -> Dict[str, dict]: + def state_dict(self) -> dict[str, dict]: return {name: agent.state_dict() for (name, agent) in self.agents.items()} diff --git a/ai_traineree/multi_agents/iql.py b/ai_traineree/multi_agents/iql.py index 0eb5b4b..c300935 100644 --- a/ai_traineree/multi_agents/iql.py +++ b/ai_traineree/multi_agents/iql.py @@ -1,5 +1,3 @@ -from typing import Dict - import torch from ai_traineree import DEVICE @@ -11,7 +9,6 @@ class IQLAgents(MultiAgentType): - model = "IQL" def __init__(self, obs_space: DataSpace, action_space: DataSpace, num_agents: int, **kwargs): @@ -58,14 +55,14 @@ def __init__(self, obs_space: DataSpace, action_space: DataSpace, num_agents: in kwargs["update_freq"] = int(self._register_param(kwargs, "update_freq", 1)) kwargs["number_updates"] = int(self._register_param(kwargs, "number_updates", 1)) - self.agents: Dict[str, DQNAgent] = { + self.agents: dict[str, DQNAgent] = { agent_name: DQNAgent(obs_space, action_space, name=agent_name, **kwargs) for agent_name in self.agent_names } self.reset() @property - def loss(self) -> Dict[str, float]: + def loss(self) -> dict[str, float]: out = {} for agent_name, agent in self.agents.items(): for loss_name, loss_value in agent.loss.items(): @@ -128,5 +125,5 @@ def load_state(self, path: str): agent._config = agent_state.get("config", {}) agent.__dict__.update(**agent._config) - def state_dict(self) -> Dict[str, dict]: + def state_dict(self) -> dict[str, dict]: return {name: agent.state_dict() for (name, agent) in self.agents.items()} diff --git a/ai_traineree/multi_agents/maddpg.py b/ai_traineree/multi_agents/maddpg.py index 84e5766..c0ea701 100644 --- a/ai_traineree/multi_agents/maddpg.py +++ b/ai_traineree/multi_agents/maddpg.py @@ -1,5 +1,5 @@ from collections import OrderedDict, defaultdict -from typing import Any, Dict, List, Optional +from typing import Any import torch import torch.nn as nn @@ -19,7 +19,6 @@ class MADDPGAgent(MultiAgentType): - model = "MADDPG" def __init__(self, obs_space: DataSpace, action_space: DataSpace, num_agents: int, **kwargs): @@ -53,7 +52,7 @@ def __init__(self, obs_space: DataSpace, action_space: DataSpace, num_agents: in self.obs_space = obs_space self.action_space = action_space self.num_agents: int = num_agents - self.agent_names: List[str] = kwargs.get("agent_names", map(str, range(self.num_agents))) + self.agent_names: list[str] = kwargs.get("agent_names", map(str, range(self.num_agents))) hidden_layers = to_numbers_seq(self._register_param(kwargs, "hidden_layers", (100, 100), update=True)) noise_scale = float(self._register_param(kwargs, "noise_scale", 0.5)) @@ -61,7 +60,7 @@ def __init__(self, obs_space: DataSpace, action_space: DataSpace, num_agents: in actor_lr = 
float(self._register_param(kwargs, "actor_lr", 3e-4)) critic_lr = float(self._register_param(kwargs, "critic_lr", 3e-4)) - self.agents: Dict[str, DDPGAgent] = OrderedDict( + self.agents: dict[str, DDPGAgent] = OrderedDict( { agent_name: DDPGAgent( obs_space, @@ -78,7 +77,7 @@ def __init__(self, obs_space: DataSpace, action_space: DataSpace, num_agents: in self.gamma = float(self._register_param(kwargs, "gamma", 0.99)) self.tau = float(self._register_param(kwargs, "tau", 0.02)) - self.gradient_clip: Optional[float] = self._register_param(kwargs, "gradient_clip") + self.gradient_clip: float | None = self._register_param(kwargs, "gradient_clip") self.batch_size = int(self._register_param(kwargs, "batch_size", 64)) self.buffer_size = int(self._register_param(kwargs, "buffer_size", int(1e6))) @@ -99,11 +98,11 @@ def __init__(self, obs_space: DataSpace, action_space: DataSpace, num_agents: in self._step_data = {} self._loss_critic: float = float("nan") - self._loss_actor: Dict[str, float] = {name: float("nan") for name in self.agent_names} + self._loss_actor: dict[str, float] = {name: float("nan") for name in self.agent_names} self.reset() @property - def loss(self) -> Dict[str, float]: + def loss(self) -> dict[str, float]: out = {} for agent_name, agent in self.agents.items(): for loss_name, loss_value in agent.loss.items(): @@ -243,7 +242,7 @@ def log_metrics(self, data_logger: DataLogger, step: int, full_log: bool = False for agent_name, agent in self.agents.items(): data_logger.log_values_dict(f"{agent_name}/loss", agent.loss, step) - def get_state(self) -> Dict[str, dict]: + def get_state(self) -> dict[str, dict]: """Returns agents' internal states""" agents_state = {} agents_state["config"] = self._config @@ -263,7 +262,7 @@ def save_state(self, path: str): agents_state = self.get_state() torch.save(agents_state, path) - def load_state(self, *, path: Optional[str] = None, agent_state: Optional[dict] = None) -> None: + def load_state(self, *, path: str | None = None, agent_state: dict | None = None) -> None: """Loads the state into the Multi Agent. 
The state can be provided either via path to a file that contains the state, @@ -292,5 +291,5 @@ def seed(self, seed: int) -> None: for agent in self.agents.values(): agent.seed(seed) - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: return {name: agent.state_dict() for (name, agent) in self.agents.items()} diff --git a/ai_traineree/networks/bodies.py b/ai_traineree/networks/bodies.py index dc4e6ad..f7a6b31 100644 --- a/ai_traineree/networks/bodies.py +++ b/ai_traineree/networks/bodies.py @@ -1,6 +1,6 @@ from functools import reduce from math import sqrt -from typing import Any, Optional, Sequence, Tuple, Union +from typing import Any, Sequence import torch import torch.nn as nn @@ -16,7 +16,7 @@ def hidden_init(layer: nn.Module): return (-lim, lim) -def layer_init(layer: nn.Module, range_value: Optional[Tuple[float, float]] = None, remove_mean=True): +def layer_init(layer: nn.Module, range_value: tuple[float, float] | None = None, remove_mean=True): if not (isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear)): return if range_value is not None: @@ -28,7 +28,7 @@ def layer_init(layer: nn.Module, range_value: Optional[Tuple[float, float]] = No class ScaleNet(NetworkType): - def __init__(self, scale: Union[float, int]) -> None: + def __init__(self, scale: float | int) -> None: super(ScaleNet, self).__init__() self.scale = scale @@ -106,7 +106,7 @@ def __init__(self, input_dim: Sequence[int], **kwargs): self.to(self.device) @staticmethod - def _expand_to_seq(o: Union[Any, Sequence[Any]], size) -> Sequence[Any]: + def _expand_to_seq(o: Any | Sequence[Any], size) -> Sequence[Any]: return o if isinstance(o, Sequence) else (o,) * size @property @@ -144,7 +144,7 @@ def __init__( self, in_features: FeatureType, out_features: FeatureType, - hidden_layers: Optional[Sequence[int]] = (200, 100), + hidden_layers: Sequence[int] | None = (200, 100), last_layer_range=(-3e-4, 3e-4), bias: bool = True, **kwargs, @@ -214,7 +214,7 @@ def __init__( in_features: FeatureType, inj_action_size: int, out_features: FeatureType = (1,), - hidden_layers: Optional[Sequence[int]] = (100, 100), + hidden_layers: Sequence[int] | None = (100, 100), inj_actions_layer: int = 1, **kwargs, ): @@ -373,7 +373,7 @@ def __init__( self, in_features: FeatureType, out_features: FeatureType, - hidden_layers: Optional[Sequence[int]] = (100, 100), + hidden_layers: Sequence[int] | None = (100, 100), sigma=0.4, factorised=True, **kwargs, diff --git a/ai_traineree/networks/heads.py b/ai_traineree/networks/heads.py index 7f8a4c8..8391275 100644 --- a/ai_traineree/networks/heads.py +++ b/ai_traineree/networks/heads.py @@ -13,9 +13,10 @@ Heads are "special" in that each is built on networks/brains and will likely need some special pipeping when attaching to your agent. """ + from functools import lru_cache, reduce from operator import mul -from typing import Callable, List, Optional, Sequence +from typing import Callable, Sequence import torch import torch.nn as nn @@ -34,7 +35,7 @@ class NetChainer(NetworkType): The need for wrapper comes from unified API to reset properties. 
""" - def __init__(self, net_classes: List[NetworkTypeClass], **kwargs): + def __init__(self, net_classes: list[NetworkTypeClass], **kwargs): super(NetChainer, self).__init__() self.nets = nn.ModuleList(net_classes) @@ -90,9 +91,9 @@ def __init__( in_features: Sequence[int], out_features: Sequence[int], hidden_layers: Sequence[int], - net_fn: Optional[Callable[..., NetworkType]] = None, - net_class: Optional[NetworkTypeClass] = None, - **kwargs + net_fn: Callable[..., NetworkType] | None = None, + net_class: NetworkTypeClass | None = None, + **kwargs, ): """ Parameters: @@ -161,11 +162,11 @@ def __init__( num_atoms: int = 21, v_min: float = -20.0, v_max: float = 20.0, - in_features: Optional[FeatureType] = None, - out_features: Optional[FeatureType] = None, + in_features: FeatureType | None = None, + out_features: FeatureType | None = None, hidden_layers: Sequence[int] = (200, 200), - net: Optional[NetworkType] = None, - device: Optional[torch.device] = None, + net: NetworkType | None = None, + device: torch.device | None = None, ): """ Parameters: diff --git a/ai_traineree/policies.py b/ai_traineree/policies.py index cc25c13..a4a9958 100644 --- a/ai_traineree/policies.py +++ b/ai_traineree/policies.py @@ -1,5 +1,4 @@ from functools import lru_cache -from typing import Optional, Tuple import torch import torch.nn as nn @@ -9,6 +8,7 @@ from ai_traineree.networks import NetworkType from ai_traineree.networks.bodies import FcNet from ai_traineree.types import FeatureType +from ai_traineree.types.dataspace import DataSpace class PolicyType(NetworkType): @@ -192,10 +192,10 @@ def __init__(self, in_features: FeatureType, out_features: FeatureType, out_scal self.log_std_min = -10 self.log_std_max = 2 - self._last_dist: Optional[Distribution] = None + self._last_dist: Distribution | None = None self._last_samples = None - def log_prob(self, samples) -> Optional[torch.Tensor]: + def log_prob(self, samples) -> torch.Tensor | None: if self._last_dist is None: return None return self._last_dist.log_prob(samples).sum(axis=-1) @@ -229,7 +229,8 @@ class BetaPolicy(PolicyType): param_dim = 2 - def __init__(self, size: int, bounds: Tuple[float, float] = (1, float("inf"))): + # def __init__(self, size: int, bounds: Tuple[float, float] = (1, float("inf"))): + def __init__(self, size: int, bound_space: DataSpace, out_scale: float = 1, **kwargs): """ Parameters: size: Observation's dimensionality upon sampling. @@ -238,39 +239,58 @@ def __init__(self, size: int, bounds: Tuple[float, float] = (1, float("inf"))): """ super(BetaPolicy, self).__init__() - self.bounds = bounds + if bound_space.low is None or bound_space.high is None: + raise ValueError( + "Bound space needs to have both low and high boundaries. 
" + f"Provided: low={bound_space.low}, high={bound_space.high}" + ) + self.bound_space = bound_space self.action_size = size self.dist = Beta if size > 1 else Dirichlet + self._last_dist: Distribution | None = None + self._last_samples = None - def forward(self, x) -> Distribution: + def forward(self, x, deterministic: bool = False) -> torch.Tensor: x = x.view(-1, self.action_size, self.param_dim) - x = torch.clamp(x, self.bounds[0], self.bounds[1]) - dist = self.dist(x[..., 0], x[..., 1]) - return dist + loc = x[..., 0] + if deterministic: + return loc - @staticmethod - def log_prob(dist, samples): - return dist.log_prob(samples).mean(dim=-1) + # x = torch.clamp(x, self.bound_space.low, self.bound_space.high) + # self._last_dist = self.dist(x[..., 0], x[..., 1]) + self._last_dist = self.dist() + return self._last_dist.rsample() + def log_prob(self, samples): + if self._last_dist is None: + return None + return self._last_dist.log_prob(samples).mean(dim=-1) -class DirichletPolicy(PolicyType): - param_dim = 1 +class DirichletPolicy(PolicyType): + param_dim = 2 - def __init__(self, *, alpha_min: float = 0.05): + def __init__(self, size: int, *, alpha_min: float = 0.05): super(DirichletPolicy, self).__init__() + self.size = size self.alpha_min = alpha_min - def forward(self, x) -> Distribution: - x = torch.clamp(x, self.alpha_min) - return Dirichlet(x) + def forward(self, x, deterministic: bool = False) -> torch.Tensor: + _x = x.view(-1, self.param_dim, self.size) + loc = _x[..., 0] + if deterministic: + return loc + alpha = torch.clamp(_x[..., 1], self.alpha_min) + self._last_dist = Dirichlet(alpha) + self._last_loc = loc.detach() + return loc + self._last_dist.rsample() - def log_prob(self, dist: Dirichlet, samples) -> torch.Tensor: - return dist.log_prob(samples) + def log_prob(self, samples) -> torch.Tensor: + _sampled = samples - self._last_loc + return self._last_dist.log_prob(_sampled) class DeterministicPolicy(PolicyType): - param_dim = 1 def __init__(self, action_size): diff --git a/ai_traineree/runners/env_runner.py b/ai_traineree/runners/env_runner.py index c9ba975..9c573ed 100644 --- a/ai_traineree/runners/env_runner.py +++ b/ai_traineree/runners/env_runner.py @@ -5,7 +5,7 @@ import time from collections import deque from pathlib import Path -from typing import Any, Iterable, List, Optional, Tuple +from typing import Any, Iterable from ai_traineree.agents import AgentBase from ai_traineree.loggers import DataLogger @@ -58,18 +58,18 @@ def __init__(self, task: TaskType, agent: AgentBase, max_iterations: int = int(1 self.__images = [] self.logger.setLevel(kwargs.get("logger_level", logging.INFO)) - self.data_logger: Optional[DataLogger] = kwargs.get("data_logger") + self.data_logger: DataLogger | None = kwargs.get("data_logger") if self.data_logger: self.logger.info("DataLogger: %s", str(self.data_logger)) self.data_logger.set_hparams(hparams=self.agent.hparams) self._debug_log: bool = bool(kwargs.get("debug_log", False)) - self._exp: List[Tuple[int, Experience]] = [] - self._actions: List[Any] = [] - self._states: List[Any] = [] - self._rewards: List[Any] = [] - self._dones: List[Any] = [] - self._noises: List[Any] = [] + self._exp: list[tuple[int, Experience]] = [] + self._actions: list[Any] = [] + self._states: list[Any] = [] + self._rewards: list[Any] = [] + self._dones: list[Any] = [] + self._noises: list[Any] = [] self.seed(kwargs.get("seed")) @@ -92,12 +92,12 @@ def interact_episode( self, train: bool = False, eps: float = 0, - max_iterations: Optional[int] = None, + 
max_iterations: int | None = None, render: bool = False, render_gif: bool = False, - log_interaction_freq: Optional[int] = 10, - full_log_interaction_freq: Optional[int] = 1000, - ) -> Tuple[RewardType, int]: + log_interaction_freq: int | None = 10, + full_log_interaction_freq: int | None = 1000, + ) -> tuple[RewardType, int]: score = 0 obs = self.task.reset() iterations = 0 @@ -171,10 +171,10 @@ def run( eps_decay: float = 0.995, log_episode_freq: int = 1, log_interaction_freq: int = 10, - gif_every_episodes: Optional[int] = None, - checkpoint_every: Optional[int] = 200, + gif_every_episodes: int | None = None, + checkpoint_every: int | None = 200, force_new: bool = False, - ) -> List[float]: + ) -> list[float]: """ Evaluates the agent in the environment. The evaluation will stop when the agent reaches the `reward_goal` in the averaged last `self.window_len`, or @@ -317,7 +317,7 @@ def log_logger(self, **kwargs): def log_episode_metrics(self, **kwargs): """Uses DataLogger, e.g. TensorboardLogger, to store env metrics.""" - episodes: List[int] = kwargs.get("episodes", []) + episodes: list[int] = kwargs.get("episodes", []) for episode, epsilon in zip(episodes, kwargs.get("epsilons", [])): self.data_logger.log_value("episode/epsilon", epsilon, episode) diff --git a/ai_traineree/runners/multi_sync_env_runner.py b/ai_traineree/runners/multi_sync_env_runner.py index 58b04ff..9fa06eb 100644 --- a/ai_traineree/runners/multi_sync_env_runner.py +++ b/ai_traineree/runners/multi_sync_env_runner.py @@ -4,7 +4,6 @@ import sys from collections import deque from pathlib import Path -from typing import List, Optional, Tuple import numpy as np import torch.multiprocessing as mp @@ -37,7 +36,7 @@ class MultiSyncEnvRunner: logger = logging.getLogger("MultiSyncEnvRunner") - def __init__(self, tasks: List[TaskType], agent: AgentBase, max_iterations: int = int(1e5), **kwargs): + def __init__(self, tasks: list[TaskType], agent: AgentBase, max_iterations: int = int(1e5), **kwargs): """ Expects the environment to come as the TaskType and the agent as the AgentBase. 
@@ -65,7 +64,7 @@ def __init__(self, tasks: List[TaskType], agent: AgentBase, max_iterations: int self.window_len = kwargs.get("window_len", 100) self.scores_window = deque(maxlen=self.window_len) - self.data_logger: Optional[DataLogger] = kwargs.get("data_logger") + self.data_logger: DataLogger | None = kwargs.get("data_logger") self.logger.info("DataLogger: %s", str(self.data_logger)) def __str__(self) -> str: @@ -139,7 +138,7 @@ def run( eps_end: float = 0.01, eps_decay: float = 0.995, log_episode_freq: int = 1, - checkpoint_every: Optional[int] = 200, + checkpoint_every: int | None = 200, force_new=False, ): """ @@ -200,7 +199,7 @@ def _run( eps_end: float = 0.01, eps_decay: float = 0.995, log_episode_freq: int = 1, - checkpoint_every: Optional[int] = 200, + checkpoint_every: int | None = 200, force_new: bool = False, ): # Initiate variables @@ -255,7 +254,6 @@ def _run( # Training part for idx in range(self.task_num): - # Update Episode number if any agent is DONE or enough ITERATIONS if not (experience.done[idx] or _iterations[idx] >= max_iterations): continue @@ -304,7 +302,7 @@ def _step_all_tasks(self, obs, actions): for t_idx in range(self.task_num): self.parent_conns[t_idx].send((t_idx, obs[t_idx], actions[t_idx])) - def _collect_all_tasks(self) -> Tuple[Experience, np.ndarray]: + def _collect_all_tasks(self) -> tuple[Experience, np.ndarray]: obs = np.empty((len(self.tasks),) + self.tasks[0].obs_space.shape, dtype=np.float32) next_obs = obs.copy() actions = np.empty((len(self.tasks),) + self.tasks[0].action_space.shape, dtype=np.float32) @@ -391,7 +389,7 @@ def log_logger(self, **kwargs): def log_episode_metrics(self, **kwargs): """Uses data_logger, e.g. Tensorboard, to store env metrics.""" - episodes: List[int] = kwargs.get("episodes", []) + episodes: list[int] = kwargs.get("episodes", []) for episode, epsilon in zip(episodes, kwargs.get("epsilons", [])): self.data_logger.log_value("episode/epsilon", epsilon, episode) diff --git a/ai_traineree/runners/multiagent_env_runner.py b/ai_traineree/runners/multiagent_env_runner.py index ece7fa5..8a8cd37 100644 --- a/ai_traineree/runners/multiagent_env_runner.py +++ b/ai_traineree/runners/multiagent_env_runner.py @@ -5,7 +5,7 @@ import time from collections import defaultdict, deque from pathlib import Path -from typing import Dict, Iterable, List, Optional, Tuple +from typing import Iterable from ai_traineree.loggers import DataLogger from ai_traineree.tasks import PettingZooTask @@ -67,8 +67,8 @@ def __init__( self.mode = mode self.episode = 0 self.iteration = 0 - self.all_scores: List[List[RewardType]] = [] - self.all_iterations: List[int] = [] + self.all_scores: list[list[RewardType]] = [] + self.all_iterations: list[int] = [] self.scores_window = deque(maxlen=self.window_len) self._images = [] self._debug_log = kwargs.get("debug_log", False) @@ -76,7 +76,7 @@ def __init__( self._rewards = [] self._dones = [] - self.data_logger: Optional[DataLogger] = kwargs.get("data_logger") + self.data_logger: DataLogger | None = kwargs.get("data_logger") self.logger.info("DataLogger: %s", str(self.data_logger)) def __str__(self) -> str: @@ -89,20 +89,20 @@ def seed(self, seed: int): def reset(self): """Resets the EnvRunner. 
The task env and the agent are preserved.""" self.episode = 0 - self.all_scores: List[List[RewardType]] = [] + self.all_scores: list[list[RewardType]] = [] self.all_iterations = [] self.scores_window = deque(maxlen=self.window_len) def interact_episode( self, eps: float = 0, - max_iterations: Optional[int] = None, + max_iterations: int | None = None, render: bool = False, render_gif: bool = False, - log_interaction_freq: Optional[int] = None, - ) -> Tuple[List[RewardType], int]: - score: List[RewardType] = [0.0] * self.multi_agent.num_agents - states: List[StateType] = self.task.reset() + log_interaction_freq: int | None = None, + ) -> tuple[list[RewardType], int]: + score: list[RewardType] = [0.0] * self.multi_agent.num_agents + states: list[StateType] = self.task.reset() iterations = 0 max_iterations = max_iterations if max_iterations is not None else self.max_iterations @@ -118,10 +118,10 @@ def interact_episode( self.task.render("human") time.sleep(1.0 / FRAMES_PER_SEC) - next_states: List[StateType] = [] - rewards: List[RewardType] = [] - dones: List[DoneType] = [] - actions: List[ActionType] = [] + next_states: list[StateType] = [] + rewards: list[RewardType] = [] + dones: list[DoneType] = [] + actions: list[ActionType] = [] for agent_id in range(self.multi_agent.num_agents): experience = Experience(obs=states[agent_id]) experience = self.multi_agent.act(str(agent_id), experience, eps) @@ -162,10 +162,10 @@ def run( eps_end=0.01, eps_decay=0.995, log_episode_freq=1, - gif_every_episodes: Optional[int] = None, + gif_every_episodes: int | None = None, checkpoint_every=200, force_new=False, - ) -> List[List[RewardType]]: + ) -> list[list[RewardType]]: """ Evaluates the multi_agent in the environment. The evaluation will stop when the agent reaches the `reward_goal` in the averaged last `self.window_len`, or @@ -267,7 +267,7 @@ def log_logger(self, **kwargs): def log_episode_metrics(self, **kwargs): """Uses data_logger, e.g. Tensorboard, to store env metrics.""" assert self.data_logger, "Cannot log without DataLogger" - episodes: List[int] = kwargs.get("episodes", []) + episodes: list[int] = kwargs.get("episodes", []) for episode, epsilon in zip(episodes, kwargs.get("epsilons", [])): self.data_logger.log_value("episode/epsilon", epsilon, episode) @@ -407,8 +407,8 @@ def __init__( self.mode = mode self.episode: float = 0 self.iteration = 0 - self.all_scores: List[Dict[str, RewardType]] = [] - self.all_iterations: List[int] = [] + self.all_scores: list[dict[str, RewardType]] = [] + self.all_iterations: list[int] = [] self.window_len = kwargs.get("window_len", 100) self.scores_window = deque(maxlen=self.window_len) @@ -432,18 +432,18 @@ def seed(self, seed: int) -> None: def reset(self) -> None: """Resets instance. 
Preserves everything about task and agent.""" self.episode: float = 0 - self.all_scores: List[Dict[str, RewardType]] = [] + self.all_scores: list[dict[str, RewardType]] = [] self.all_iterations = [] self.scores_window = deque(maxlen=self.window_len) def interact_episode( self, eps: float = 0, - max_iterations: Optional[int] = None, + max_iterations: int | None = None, render: bool = False, render_gif: bool = False, - log_interaction_freq: Optional[int] = None, - ) -> Tuple[Dict[str, RewardType], int]: + log_interaction_freq: int | None = None, + ) -> tuple[dict[str, RewardType], int]: score = defaultdict(float) iterations = 0 max_iterations = max_iterations if max_iterations is not None else self.max_iterations @@ -460,9 +460,9 @@ def interact_episode( self.task.render("human") time.sleep(1.0 / FRAMES_PER_SEC) - # next_states: Dict[str, StateType] = {} - # rewards: Dict[str, RewardType] = {} - dones: Dict[str, DoneType] = {} + # next_states: dict[str, StateType] = {} + # rewards: dict[str, RewardType] = {} + dones: dict[str, DoneType] = {} # TODO: Iterate over distinc agents in a single cycle. This `for` doesn't guarantee that. for agent_name in self.task.agent_iter(max_iter=self.multi_agent.num_agents): @@ -509,10 +509,10 @@ def run( eps_end=0.01, eps_decay=0.995, log_episode_freq=1, - gif_every_episodes: Optional[int] = None, + gif_every_episodes: int | None = None, checkpoint_every=200, force_new=False, - ) -> List[Dict[str, RewardType]]: + ) -> list[dict[str, RewardType]]: """ Evaluates the Multi Agent in the environment. The evaluation will stop when the agent reaches the `reward_goal` in the averaged last `self.window_len`, or @@ -614,7 +614,7 @@ def log_logger(self, **kwargs): def log_episode_metrics(self, **kwargs): """Uses data_logger, e.g. Tensorboard, to store env metrics.""" assert self.data_logger, "Cannot log without DataLogger" - episodes: List[int] = kwargs.get("episodes", []) + episodes: list[int] = kwargs.get("episodes", []) for episode, epsilon in zip(episodes, kwargs.get("epsilons", [])): self.data_logger.log_value("episode/epsilon", epsilon, episode) diff --git a/ai_traineree/tasks.py b/ai_traineree/tasks.py index 897c089..d95ad48 100644 --- a/ai_traineree/tasks.py +++ b/ai_traineree/tasks.py @@ -2,7 +2,7 @@ from collections import deque from functools import cached_property, reduce from operator import mul -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Sequence import numpy as np import torch @@ -24,7 +24,7 @@ logging.warning("Couldn't import `gym_unity` and/or `mlagents`. 
MultiAgentUnityTask won't work.") -GymStepResult = Tuple[np.ndarray, float, bool, Dict] +GymStepResult = tuple[np.ndarray, float, bool, dict] class TerminationMode: @@ -38,9 +38,9 @@ class GymTask(TaskType): def __init__( self, - env: Union[str, gym.Env], - state_transform: Optional[Callable] = None, - reward_transform: Optional[Callable] = None, + env: str | gym.Env, + state_transform: Callable | None = None, + reward_transform: Callable | None = None, can_render=True, stack_frames: int = 1, skip_start_frames: int = 0, @@ -123,7 +123,7 @@ def seed(self, seed): if isinstance(seed, (int, float)): return self.env.reset(seed=seed) - def reset(self) -> Union[torch.Tensor, np.ndarray]: + def reset(self) -> torch.Tensor | np.ndarray: # TODO: info is currently ignored state, info = self.env.reset() # state = self.env.reset() @@ -140,14 +140,14 @@ def render(self, mode="rgb_array"): self.logger.warning("Asked for rendering but it's not available in this environment") return - def step(self, action: ActionType) -> Tuple: + def step(self, action: ActionType) -> tuple: """Each action results in a new state, reward, done flag, and info about env. Parameters: action: An action that the agent is taking in current environment step. Returns: - step_tuple (Tuple[torch.Tensor, float, bool, Any]): + step_tuple (tuple[torch.Tensor, float, bool, Any]): The return consists of a next state, a reward in that state, a flag whether the next state is terminal and additional information provided by the environment regarding that state. @@ -197,7 +197,7 @@ def agents(self): return self.env.agents @cached_property - def observation_spaces(self) -> Dict[str, DataSpace]: + def observation_spaces(self) -> dict[str, DataSpace]: spaces = {} for unit, space in self.env.observation_spaces.items(): if type(space).__name__ == "Dict": @@ -206,10 +206,10 @@ def observation_spaces(self) -> Dict[str, DataSpace]: return spaces @cached_property - def action_spaces(self) -> Dict[str, DataSpace]: + def action_spaces(self) -> dict[str, DataSpace]: return {unit: DataSpace.from_gym_space(space) for (unit, space) in self.env.action_spaces.items()} - def action_mask_spaces(self) -> Optional[Dict[str, DataSpace]]: + def action_mask_spaces(self) -> dict[str, DataSpace] | None: spaces = {} for unit, space in self.env.observation_spaces.items(): if not type(space).__name__ == "Dict": @@ -237,7 +237,7 @@ def is_all_done(self): def dones(self): return self.env.dones - def last(self, agent_name: Optional[str] = None) -> Tuple[Any, float, bool, Any]: + def last(self, agent_name: str | None = None) -> tuple[Any, float, bool, Any]: if agent_name is None: return self.env.last() return ( @@ -365,7 +365,7 @@ def __init__( self._action_space = spaces.Box(-high, high, dtype=np.float32) # Set observations space - list_spaces: List[gym.Space] = [] + list_spaces: list[gym.Space] = [] shapes = self._get_vis_obs_shape() for shape in shapes: if uint8_visual: @@ -382,7 +382,7 @@ def __init__( self._observation_space = list_spaces[0] # only return the first one # def reset(self) -> Union[List[np.ndarray], np.ndarray]: - def reset(self) -> List[StateType]: + def reset(self) -> list[StateType]: """Resets the state of the environment and returns an initial observation. Returns: observation (object/list): the initial observation of the space. 
@@ -398,7 +398,7 @@ def reset(self) -> List[StateType]: return states # return res[0] - def step(self, action: List[Any], agent_id: int) -> GymStepResult: + def step(self, action: list[Any], agent_id: int) -> GymStepResult: """Run one timestep of the environment's dynamics. When end of episode is reached, you are responsible for calling `reset()` to reset this environment's state. @@ -430,7 +430,7 @@ def step(self, action: List[Any], agent_id: int) -> GymStepResult: return self._single_step(decision_step) # def detect_game_over(self, termianl_steps: List[TerminalSteps]) -> bool: - def detect_game_over(self, termianl_steps: List) -> bool: + def detect_game_over(self, termianl_steps: list) -> bool: """Determine whether the episode has finished. Expects the `terminal_steps` to contain only steps that terminated. Note that other steps @@ -484,8 +484,8 @@ def _get_n_vis_obs(self) -> int: result += 1 return result - def _get_vis_obs_shape(self) -> List[Tuple]: - result: List[Tuple] = [] + def _get_vis_obs_shape(self) -> list[tuple]: + result: list[tuple] = [] for shape in self.group_spec.observation_shapes: if len(shape) == 3: result.append(shape) @@ -493,8 +493,8 @@ def _get_vis_obs_shape(self) -> List[Tuple]: @staticmethod # def _get_vis_obs_list(step_result: Union[DecisionSteps, TerminalSteps]) -> List[np.ndarray]: - def _get_vis_obs_list(step_result) -> List[np.ndarray]: - result: List[np.ndarray] = [] + def _get_vis_obs_list(step_result) -> list[np.ndarray]: + result: list[np.ndarray] = [] for obs in step_result.obs: if len(obs.shape) == 4: result.append(obs) @@ -503,7 +503,7 @@ def _get_vis_obs_list(step_result) -> List[np.ndarray]: @staticmethod # def _get_vector_obs(step_result: Union[DecisionSteps, TerminalSteps]) -> np.ndarray: def _get_vector_obs(step_result) -> np.ndarray: - result: List[np.ndarray] = [] + result: list[np.ndarray] = [] for obs in step_result.obs: if len(obs.shape) == 2: result.append(obs) @@ -550,7 +550,7 @@ def metadata(self): return {"render.modes": ["rgb_array"]} @property - def reward_range(self) -> Tuple[float, float]: + def reward_range(self) -> tuple[float, float]: return -float("inf"), float("inf") @property diff --git a/ai_traineree/types/agent.py b/ai_traineree/types/agent.py index 13ea5cc..f3521f4 100644 --- a/ai_traineree/types/agent.py +++ b/ai_traineree/types/agent.py @@ -1,5 +1,5 @@ import abc -from typing import Any, Dict, List +from typing import Any from ai_traineree.loggers import DataLogger from ai_traineree.types.experience import Experience @@ -9,13 +9,12 @@ class AgentType(abc.ABC): - model: str obs_space: DataSpace action_space: DataSpace - loss: Dict[str, float] + loss: dict[str, float] train: bool = True - _config: Dict = {} + _config: dict = {} @property def hparams(self): @@ -24,7 +23,7 @@ def make_strings_out_of_things_that_are_not_obvious_numbers(v): return {k: make_strings_out_of_things_that_are_not_obvious_numbers(v) for (k, v) in self._config.items()} - def _register_param(self, source: Dict[str, Any], name: str, default_value=None, update=False, drop=False) -> Any: + def _register_param(self, source: dict[str, Any], name: str, default_value=None, update=False, drop=False) -> Any: self._config[name] = value = source.get(name, default_value) if drop and name in source: del source[name] @@ -61,15 +60,14 @@ def load_state(self, path: str): class MultiAgentType(abc.ABC): - model: str obs_space: DataSpace action_space: DataSpace - loss: Dict[str, float] - agents: List[AgentType] - agent_names: List[str] + loss: dict[str, float] + agents: 
list[AgentType] + agent_names: list[str] num_agents: int - _config: Dict = {} + _config: dict = {} @property def hparams(self): @@ -78,7 +76,7 @@ def make_strings_out_of_things_that_are_not_obvious_numbers(v): return {k: make_strings_out_of_things_that_are_not_obvious_numbers(v) for (k, v) in self._config.items()} - def _register_param(self, source: Dict[str, Any], name: str, default_value=None, update=False, drop=False) -> Any: + def _register_param(self, source: dict[str, Any], name: str, default_value=None, update=False, drop=False) -> Any: self._config[name] = value = source.get(name, default_value) if drop: del source[name] @@ -98,7 +96,7 @@ def log_metrics(self, data_logger: DataLogger, step: int, full_log: bool = False pass @abc.abstractmethod - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: """Returns description of all agent's components.""" pass diff --git a/ai_traineree/types/dataspace.py b/ai_traineree/types/dataspace.py index 5629dca..2a3b7e2 100644 --- a/ai_traineree/types/dataspace.py +++ b/ai_traineree/types/dataspace.py @@ -1,7 +1,6 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Optional, Tuple, Union import gymnasium as gym from torch import Tensor @@ -13,9 +12,9 @@ @dataclass class DataSpace: dtype: str - shape: Tuple[int] - low: Optional[Union[Numeric, Tensor]] = None - high: Optional[Union[Numeric, Tensor]] = None + shape: tuple[int] + low: Numeric | Tensor | None = None + high: Numeric | Tensor | None = None @staticmethod def from_int(size: int): diff --git a/ai_traineree/types/experience.py b/ai_traineree/types/experience.py index a4aff8b..26a41c6 100644 --- a/ai_traineree/types/experience.py +++ b/ai_traineree/types/experience.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional +from typing import Any import jsons @@ -38,7 +38,7 @@ class Experience: action: ActionType reward: RewardType done: DoneType - next_obs: Optional[ObsType] + next_obs: ObsType | None state: ObsType next_state: ObsType @@ -53,7 +53,7 @@ def get(self, key: str): return self.data.get(key) def update(self, **kwargs): - for (key, value) in kwargs.items(): + for key, value in kwargs.items(): if key in Experience.whitelist: self.data[key] = value self.__dict__[key] = value # TODO: Delete after checking that everything is updated @@ -62,13 +62,13 @@ def update(self, **kwargs): def __add__(self, o_exp): return self.update(**o_exp.get_dict()) - def get_dict(self, serialize=False) -> Dict[str, Any]: + def get_dict(self, serialize=False) -> dict[str, Any]: if serialize: return {k: to_list(v) for (k, v) in self.data.items()} return self.data -def exprience_serialization(obj: Experience, **kwargs) -> Dict[str, Any]: +def exprience_serialization(obj: Experience, **kwargs) -> dict[str, Any]: # return {k: to_list(v) for (k, v) in obj.data.items() if v is not None} return {k: jsons.dumps(v) for (k, v) in obj.data.items()} diff --git a/ai_traineree/types/primitive.py b/ai_traineree/types/primitive.py index 51be281..594a5f0 100644 --- a/ai_traineree/types/primitive.py +++ b/ai_traineree/types/primitive.py @@ -1,11 +1,11 @@ -from typing import Dict, List, Sequence, Union +from typing import Sequence -Numeric = Union[int, float] -ObsType = ObservationType = Union[List[int], List[float]] -StateType = Union[int, List[float]] -ActionType = Union[int, float, List] +Numeric = int | float +ObsType = ObservationType = list[int] | list[float] +StateType = int | list[float] +ActionType = int | float | list DoneType = bool -RewardType = 
diff --git a/ai_traineree/types/experience.py b/ai_traineree/types/experience.py
index a4aff8b..26a41c6 100644
--- a/ai_traineree/types/experience.py
+++ b/ai_traineree/types/experience.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional
+from typing import Any
 
 import jsons
 
@@ -38,7 +38,7 @@ class Experience:
     action: ActionType
     reward: RewardType
     done: DoneType
-    next_obs: Optional[ObsType]
+    next_obs: ObsType | None
     state: ObsType
     next_state: ObsType
 
@@ -53,7 +53,7 @@ def get(self, key: str):
         return self.data.get(key)
 
     def update(self, **kwargs):
-        for (key, value) in kwargs.items():
+        for key, value in kwargs.items():
             if key in Experience.whitelist:
                 self.data[key] = value
                 self.__dict__[key] = value  # TODO: Delete after checking that everything is updated
@@ -62,13 +62,13 @@ def __add__(self, o_exp):
         return self.update(**o_exp.get_dict())
 
-    def get_dict(self, serialize=False) -> Dict[str, Any]:
+    def get_dict(self, serialize=False) -> dict[str, Any]:
         if serialize:
             return {k: to_list(v) for (k, v) in self.data.items()}
         return self.data
 
 
-def exprience_serialization(obj: Experience, **kwargs) -> Dict[str, Any]:
+def exprience_serialization(obj: Experience, **kwargs) -> dict[str, Any]:
     # return {k: to_list(v) for (k, v) in obj.data.items() if v is not None}
     return {k: jsons.dumps(v) for (k, v) in obj.data.items()}
diff --git a/ai_traineree/types/primitive.py b/ai_traineree/types/primitive.py
index 51be281..594a5f0 100644
--- a/ai_traineree/types/primitive.py
+++ b/ai_traineree/types/primitive.py
@@ -1,11 +1,11 @@
-from typing import Dict, List, Sequence, Union
+from typing import Sequence
 
-Numeric = Union[int, float]
-ObsType = ObservationType = Union[List[int], List[float]]
-StateType = Union[int, List[float]]
-ActionType = Union[int, float, List]
+Numeric = int | float
+ObsType = ObservationType = list[int] | list[float]
+StateType = int | list[float]
+ActionType = int | float | list
 DoneType = bool
-RewardType = Union[int, float]
+RewardType = int | float
 
-HyperparameterType = Dict[str, str]
+HyperparameterType = dict[str, str]
 FeatureType = Sequence[int]
diff --git a/ai_traineree/types/state.py b/ai_traineree/types/state.py
index e55d2d7..a8382df 100644
--- a/ai_traineree/types/state.py
+++ b/ai_traineree/types/state.py
@@ -1,5 +1,5 @@
 from dataclasses import dataclass, field
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
 
 import torch
 
@@ -11,13 +11,13 @@ class BufferState:
     type: str
     buffer_size: int
     batch_size: int
-    data: Optional[List] = field(default=None, init=False)
-    extra: Optional[Dict[str, Any]] = field(default=None, init=False, repr=False)
+    data: list | None = field(default=None, init=False)
+    extra: dict[str, Any] | None = field(default=None, init=False, repr=False)
 
 
 @dataclass
 class NetworkState:
-    net: Dict[str, Any]
+    net: dict[str, Any]
 
     def __eq__(self, other):
         for key, value in other.net.items():
@@ -34,6 +34,6 @@ class AgentState:
     model: str
     obs_space: DataSpace
     action_space: DataSpace
-    config: Dict[str, Any]
+    config: dict[str, Any]
     network: Optional[NetworkState]
     buffer: Optional[BufferState]
diff --git a/ai_traineree/types/task.py b/ai_traineree/types/task.py
index 5b3da73..71d019c 100644
--- a/ai_traineree/types/task.py
+++ b/ai_traineree/types/task.py
@@ -1,10 +1,10 @@
 import abc
-from typing import Any, List, Optional, Tuple
+from typing import Any
 
 from .dataspace import DataSpace
 from .primitive import ActionType, DoneType, RewardType, StateType
 
-TaskStepType = Tuple[StateType, RewardType, DoneType, Any]
+TaskStepType = tuple[StateType, RewardType, DoneType, Any]
 
 
 class TaskType(abc.ABC):
@@ -27,7 +27,7 @@ def step(self, action: ActionType, **kwargs) -> TaskStepType:
         pass
 
     @abc.abstractmethod
-    def render(self, mode: Optional[str] = None) -> None:
+    def render(self, mode: str | None = None) -> None:
         pass
 
     @abc.abstractmethod
@@ -37,5 +37,5 @@ def reset(self) -> StateType:
         pass
 
 class MultiAgentTaskType(TaskType):
     @abc.abstractmethod
-    def reset(self) -> List[StateType]:
+    def reset(self) -> list[StateType]:
         pass
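A small usage sketch for Experience under the updated hints. The whitelist consulted by update() is not visible in these hunks, so treating the keys below as accepted is an assumption based on the dataclass fields and the tests further down; update() appears to return the updated experience, since __add__ is written as return self.update(...).

from ai_traineree.types.experience import Experience

exp = Experience(obs=[0.0, 1.0], action=1, reward=0.5, done=False)
exp = exp.update(next_obs=[1.0, 0.0])      # only whitelisted keys are stored
raw = exp.get_dict()                       # the underlying dict of stored values
as_lists = exp.get_dict(serialize=True)    # values passed through to_list(...)
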
diff --git a/ai_traineree/utils.py b/ai_traineree/utils.py
index 836b1ee..b36eed3 100644
--- a/ai_traineree/utils.py
+++ b/ai_traineree/utils.py
@@ -1,7 +1,7 @@
 import logging
 import os
 from pathlib import Path
-from typing import Any, List, Tuple, Union
+from typing import Any
 
 import jsons
 import numpy as np
@@ -37,7 +37,7 @@ def to_tensor(x) -> torch.Tensor:
     return torch.tensor(x)
 
 
-def save_gif(path, images: List[np.ndarray]) -> None:
+def save_gif(path, images: list[np.ndarray]) -> None:
     logging.debug(f"Saving as a gif to {path}")
     from PIL import Image
 
@@ -47,12 +47,12 @@ def save_gif(path, images: List[np.ndarray]) -> None:
     imgs[0].save(path, save_all=True, append_images=imgs[1:], optimize=True, quality=85)
 
 
-def str_to_number(s: str) -> Union[int, float]:
+def str_to_number(s: str) -> int | float:
     "Smartly converts string either to an int or float"
     return int(s) if "." not in s else float(s)
 
 
-def str_to_list(s: str) -> List:
+def str_to_list(s: str) -> list:
     """Converts a string list of numbers into a evaluated list.
 
     Example:
@@ -70,7 +70,7 @@ def str_to_list(s: str) -> List:
     return [str_to_number(num) for num in s.split(",")]
 
 
-def str_to_tuple(s: str) -> Tuple:
+def str_to_tuple(s: str) -> tuple:
     """Converts a string tuple of numbers into a evaluated tuple.
 
     Example:
@@ -90,7 +90,7 @@ def str_to_tuple(s: str) -> Tuple:
     return tuple(map(str_to_number, s.split(",")))
 
 
-def str_to_seq(s: str) -> Union[Tuple, List]:
+def str_to_seq(s: str) -> tuple | list:
     """Converts a string sequence of number into tuple or list.
 
     The distnction is based on the surrounding brackets. If no brackets detected then it attempts to cast to tuple.
@@ -109,7 +109,7 @@ def str_to_seq(s: str) -> Union[Tuple, List]:
     return str_to_tuple(s)
 
 
-def to_numbers_seq(x: Any) -> Union[Tuple, List]:
+def to_numbers_seq(x: Any) -> tuple | list:
     """Tries to convert an object into a sequence of numbers."""
     if isinstance(x, (tuple, list)):
         return x
@@ -126,7 +126,7 @@ def serialize(obj) -> str:
     return jsons.dumps(obj)
 
 
-def condens_ndarray(a: np.ndarray) -> Union[int, float, np.ndarray]:
+def condens_ndarray(a: np.ndarray) -> int | float | np.ndarray:
     """Condense ndarray to a common value.
 
     Returns:
diff --git a/conftest.py b/conftest.py
index d281f94..d13eb79 100644
--- a/conftest.py
+++ b/conftest.py
@@ -1,6 +1,6 @@
 import copy
 import random
-from typing import Any, List, Sequence, Tuple
+from typing import Any, Sequence
 
 import mock
 import numpy as np
diff --git a/pyproject.toml b/pyproject.toml
index ef24fd2..c71da2d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "ai-traineree"
-version = "0.6.0"
+version = "0.7.1"
 description = "Yet another zoo of (Deep) Reinforcement Learning methods in Python using PyTorch"
 readme = "README.md"
 requires-python = ">=3.12"
diff --git a/scripts/multi_exp_runner.py b/scripts/multi_exp_runner.py
index 2f8371a..91ca26f 100755
--- a/scripts/multi_exp_runner.py
+++ b/scripts/multi_exp_runner.py
@@ -1,5 +1,5 @@
 from pprint import pprint
-from typing import Any, Dict
+from typing import Any
 
 import torch
 
@@ -15,7 +15,7 @@
 seeds = [32167, 1, 999, 2833700, 13]
 
 for idx, config_update in enumerate(config_updates):
-    config: Dict[str, Any] = config_default.copy()
+    config: dict[str, Any] = config_default.copy()
     config.update(config_update)
 
     for seed in seeds:
diff --git a/tests/types/test_experience.py b/tests/types/test_experience.py
index e616012..c47c031 100644
--- a/tests/types/test_experience.py
+++ b/tests/types/test_experience.py
@@ -1,6 +1,6 @@
 import random
 import string
-from typing import Any, Dict, List
+from typing import Any
 
 from ai_traineree.types.experience import Experience
 
@@ -9,7 +9,7 @@ def r_string(n: int) -> str:
     return "".join(random.choices(string.printable, k=n))
 
 
-def r_float(n: int) -> List[float]:
+def r_float(n: int) -> list[float]:
     return [random.random() for _ in range(n)]
 
 
@@ -111,7 +111,7 @@ def test_experience_get_dict():
     import torch
 
     # Assign
-    init_data: Dict[str, Any] = {k: r_float(5) for k in ["obs", "action", "reward", "done", "next_obs"]}
+    init_data: dict[str, Any] = {k: r_float(5) for k in ["obs", "action", "reward", "done", "next_obs"]}
     init_data["state"] = numpy.random.random(10)
     init_data["advantage"] = torch.rand(10)
     exp = Experience(**init_data)
@@ -128,7 +128,7 @@ def test_experience_get_dict_serialize():
     import torch
 
     # Assign
-    init_data: Dict[str, Any] = {k: r_float(5) for k in ["action", "reward", "done"]}
+    init_data: dict[str, Any] = {k: r_float(5) for k in ["action", "reward", "done"]}
     init_data["obs"] = numpy.random.random(5)
     init_data["next_obs"] = torch.rand(10)
     exp = Experience(**init_data)
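A usage sketch for the retyped string helpers in ai_traineree/utils.py. str_to_number is shown in full above, while the bracket handling inside str_to_tuple is elided from the hunk, so the bracket-free call is the only behaviour these lines actually show.

from ai_traineree.utils import str_to_number, str_to_tuple

assert str_to_number("5") == 5                 # no "." -> int
assert str_to_number("2.5") == 2.5             # "." present -> float
assert str_to_tuple("1,2.5,3") == (1, 2.5, 3)  # comma-separated numbers -> tuple
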
diff --git a/uv.lock b/uv.lock
index e18397c..d123e40 100644
--- a/uv.lock
+++ b/uv.lock
@@ -20,7 +20,7 @@ wheels = [
 
 [[package]]
 name = "ai-traineree"
-version = "0.6.0"
+version = "0.7.1"
 source = { virtual = "." }
 dependencies = [
     { name = "gymnasium" },