chore: resolve lint + update typing to 3.10+
laszukdawid committed Dec 28, 2024
1 parent de99429 commit 9d9d340
Showing 39 changed files with 266 additions and 276 deletions.
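Apart from the version bump, the hunks below follow one pattern: typing-module generics (`Dict`, `List`, `Tuple`) and `Optional`/`Union` annotations are replaced with the builtin generics of PEP 585 and the `|` unions of PEP 604 that Python 3.10+ supports natively, alongside small lint fixes such as trailing commas after `**kwargs`. A minimal before/after sketch of that pattern (illustrative only, not code from this repository):

```python
# Old style (pre-3.10): generics and optionals imported from typing
from typing import Dict, List, Optional


def summarize_old(values: List[float], label: Optional[str] = None) -> Dict[str, float]:
    name = label if label is not None else "mean"
    return {name: sum(values) / len(values)}


# New style (3.10+): builtin generics and PEP 604 unions, no typing import needed
def summarize_new(values: list[float], label: str | None = None) -> dict[str, float]:
    name = label if label is not None else "mean"
    return {name: sum(values) / len(values)}
```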
2 changes: 1 addition & 1 deletion ai_traineree/__init__.py
@@ -1,6 +1,6 @@
import torch

- __version__ = "0.6.0"
+ __version__ = "0.7.1"


try:
5 changes: 2 additions & 3 deletions ai_traineree/agents/d3pg.py
@@ -1,6 +1,5 @@
import itertools
from functools import cached_property
- from typing import Dict

import torch
import torch.nn as nn
@@ -170,7 +169,7 @@ def __init__(self, obs_space: DataSpace, action_space: DataSpace, **kwargs):
self._metric_batch_value_dist = torch.zeros(self.batch_size)

@property
- def loss(self) -> Dict[str, float]:
+ def loss(self) -> dict[str, float]:
return {"actor": self._loss_actor, "critic": self._loss_critic}

@loss.setter
@@ -322,7 +321,7 @@ def learn(self, experiences):
soft_update(self.target_actor, self.actor, self.tau)
soft_update(self.target_critic, self.critic, self.tau)

- def state_dict(self) -> Dict[str, dict]:
+ def state_dict(self) -> dict[str, dict]:
"""Describes agent's networks.
Returns:
5 changes: 2 additions & 3 deletions ai_traineree/agents/d4pg.py
@@ -5,7 +5,6 @@

import itertools
from functools import cached_property
- from typing import Dict

import torch
import torch.nn as nn
@@ -175,7 +174,7 @@ def __init__(self, obs_space: DataSpace, action_space: DataSpace, **kwargs):
self._loss_critic = float("nan")

@property
- def loss(self) -> Dict[str, float]:
+ def loss(self) -> dict[str, float]:
return {"actor": self._loss_actor, "critic": self._loss_critic}

@cached_property
Expand Down Expand Up @@ -337,7 +336,7 @@ def learn(self, experiences):
soft_update(self.target_actor, self.actor, self.tau)
soft_update(self.target_critic, self.critic, self.tau)

- def state_dict(self) -> Dict[str, dict]:
+ def state_dict(self) -> dict[str, dict]:
"""Describes agent's networks.
Returns:
8 changes: 3 additions & 5 deletions ai_traineree/agents/ddpg.py
@@ -1,7 +1,6 @@
import copy
import operator
from functools import cached_property, reduce
- from typing import Dict, Optional

import torch
import torch.nn as nn
@@ -125,7 +124,7 @@ def reset_agent(self) -> None:
hard_update(self.target_critic, self.critic)

@property
- def loss(self) -> Dict[str, float]:
+ def loss(self) -> dict[str, float]:
return {"actor": self._loss_actor, "critic": self._loss_critic}

@loss.setter
@@ -201,7 +200,6 @@ def step(self, experience: Experience) -> None:
soft_update(self.target_critic, self.critic, self.tau)

def compute_value_loss(self, states, actions, next_states, rewards, dones):

Q_expected = self.critic(states, actions)

with torch.no_grad():
@@ -254,7 +252,7 @@ def learn(self, experiences) -> None:

self.critic.requires_grad_ = True

- def state_dict(self) -> Dict[str, dict]:
+ def state_dict(self) -> dict[str, dict]:
"""Describes agent's networks.
Returns:
@@ -328,7 +326,7 @@ def save_state(self, path: str) -> None:
agent_state = self.get_state()
torch.save(agent_state, path)

- def load_state(self, *, path: Optional[str] = None, agent_state: Optional[dict] = None):
+ def load_state(self, *, path: str | None = None, agent_state: dict | None = None):
if path is None and agent_state:
raise ValueError("Either `path` or `agent_state` must be provided to load agent's state.")
if path is not None and agent_state is None:
16 changes: 8 additions & 8 deletions ai_traineree/agents/dqn.py
@@ -1,5 +1,5 @@
import copy
- from typing import Callable, Dict, Optional, Type
+ from typing import Callable, Type

import torch
import torch.nn as nn
@@ -40,9 +40,9 @@ def __init__(
action_space: DataSpace,
network_fn: Callable[[], NetworkType] = None,
network_class: Type[NetworkTypeClass] = None,
- state_transform: Optional[Callable] = None,
- reward_transform: Optional[Callable] = None,
- **kwargs
+ state_transform: Callable | None = None,
+ reward_transform: Callable | None = None,
+ **kwargs,
):
"""Initiates the DQN agent.
@@ -113,7 +113,7 @@ def __init__(
self._loss: float = float("nan")

@property
- def loss(self) -> Dict[str, float]:
+ def loss(self) -> dict[str, float]:
return {"loss": self._loss}

@loss.setter
@@ -202,7 +202,7 @@ def act(self, experience: Experience, eps: float = 0.0) -> Experience:
action = int(torch.argmax(action_values.cpu()))
return experience.update(action=action)

- def learn(self, experiences: Dict[str, list]) -> None:
+ def learn(self, experiences: dict[str, list]) -> None:
"""Updates agent's networks based on provided experience.
Parameters:
@@ -237,7 +237,7 @@ def learn(self, experiences: Dict[str, list]) -> None:
assert any(~torch.isnan(error))
self.buffer.priority_update(experiences["index"], error.abs())

- def state_dict(self) -> Dict[str, dict]:
+ def state_dict(self) -> dict[str, dict]:
"""Describes agent's networks.
Returns:
@@ -301,7 +301,7 @@ def save_state(self, path: str):
agent_state = self.get_state()
torch.save(agent_state, path)

- def load_state(self, *, path: Optional[str] = None, state: Optional[AgentState] = None) -> None:
+ def load_state(self, *, path: str | None = None, state: AgentState | None = None) -> None:
"""Loads state from a file under provided path.
Parameters:
7 changes: 3 additions & 4 deletions ai_traineree/agents/ppo.py
@@ -1,7 +1,6 @@
import copy
import itertools
import logging
- from typing import Dict, List

import torch
import torch.nn as nn
@@ -124,10 +123,10 @@ def __init__(self, obs_space: DataSpace, action_space: DataSpace, **kwargs):
self.critic_opt = optim.Adam(self.critic_params, lr=self.critic_lr)
self._loss_actor = float("nan")
self._loss_critic = float("nan")
- self._metrics: Dict[str, float] = {}
+ self._metrics: dict[str, float] = {}

@property
- def loss(self) -> Dict[str, float]:
+ def loss(self) -> dict[str, float]:
return {"actor": self._loss_actor, "critic": self._loss_critic}

@loss.setter
@@ -163,7 +162,7 @@ def act(self, experience: Experience, noise: float = 0.0) -> Experience:
Experience updated with action taken.
"""
- actions: List[ActionType] = []
+ actions: list[ActionType] = []
logprobs = []
values = []
t_obs = to_tensor(experience.obs).view((self.num_workers,) + self.obs_space.shape).float().to(self.device)
10 changes: 5 additions & 5 deletions ai_traineree/agents/rainbow.py
@@ -1,5 +1,5 @@
import copy
- from typing import Callable, Dict, List, Optional
+ from typing import Callable

import torch
import torch.nn as nn
@@ -45,8 +45,8 @@ def __init__(
self,
obs_space: DataSpace,
action_space: DataSpace,
- state_transform: Optional[Callable] = None,
- reward_transform: Optional[Callable] = None,
+ state_transform: Callable | None = None,
+ reward_transform: Callable | None = None,
**kwargs,
):
"""
@@ -201,7 +201,7 @@ def act(self, experience: Experience, eps: float = 0.0) -> Experience:
action = int(q_values.argmax(-1)) # Action maximizes state-action value Q(s, a)
return experience.update(action=action)

- def learn(self, experiences: Dict[str, List]) -> None:
+ def learn(self, experiences: dict[str, list]) -> None:
"""
Parameters:
experiences: Contains all experiences for the agent. Typically sampled from the memory buffer.
@@ -256,7 +256,7 @@ def learn(self, experiences: Dict[str, List]) -> None:
# Update networks - sync local & target
soft_update(self.target_net, self.net, self.tau)

- def state_dict(self) -> Dict[str, dict]:
+ def state_dict(self) -> dict[str, dict]:
"""Returns agent's state dictionary.
Returns:
9 changes: 4 additions & 5 deletions ai_traineree/agents/sac.py
@@ -1,7 +1,6 @@
import copy
import itertools
from functools import cached_property
- from typing import Dict, Tuple, Union

import numpy as np
import torch
@@ -162,14 +161,14 @@ def __init__(self, obs_space: DataSpace, action_space: DataSpace, **kwargs):

self._loss_actor = float("nan")
self._loss_critic = float("nan")
- self._metrics: Dict[str, Union[float, Dict[str, float]]] = {}
+ self._metrics: dict[str, float | dict[str, float]] = {}

@property
def alpha(self):
return self.log_alpha.exp()

@property
- def loss(self) -> Dict[str, float]:
+ def loss(self) -> dict[str, float]:
return {"actor": self._loss_actor, "critic": self._loss_critic}

@loss.setter
@@ -204,7 +203,7 @@ def action_min(self):
def action_max(self):
return to_tensor(self.action_space.high)

- def state_dict(self) -> Dict[str, dict]:
+ def state_dict(self) -> dict[str, dict]:
"""
Returns network's weights in order:
Actor, TargetActor, Critic, TargetCritic
@@ -270,7 +269,7 @@ def step(self, experience: Experience) -> None:
for _ in range(self.number_updates):
self.learn(self.buffer.sample())

- def compute_value_loss(self, states, actions, rewards, next_states, dones) -> Tuple[Tensor, Tensor]:
+ def compute_value_loss(self, states, actions, rewards, next_states, dones) -> tuple[Tensor, Tensor]:
Q1_expected, Q2_expected = self.double_critic(states, actions)

with torch.no_grad():
7 changes: 3 additions & 4 deletions ai_traineree/agents/td3.py
@@ -1,5 +1,4 @@
from functools import cached_property
- from typing import Dict

import torch
import torch.nn as nn
@@ -35,7 +34,7 @@ def __init__(
action_space: DataSpace,
noise_scale: float = 0.5,
noise_sigma: float = 1.0,
- **kwargs
+ **kwargs,
):
"""
Parameters:
@@ -129,7 +128,7 @@ def __init__(
self._loss_critic = float("nan")

@property
- def loss(self) -> Dict[str, float]:
+ def loss(self) -> dict[str, float]:
return {"actor": self._loss_actor, "critic": self._loss_critic}

@loss.setter
@@ -262,7 +261,7 @@ def _update_policy(self, states):

self.critic.requires_grad_ = True

- def state_dict(self) -> Dict[str, dict]:
+ def state_dict(self) -> dict[str, dict]:
"""Describes agent's networks.
Returns:
14 changes: 7 additions & 7 deletions ai_traineree/buffers/__init__.py
@@ -1,6 +1,6 @@
import abc
from collections import defaultdict
- from typing import Dict, List, Optional, Sequence, Union
+ from typing import Sequence

import numpy as np
import torch
@@ -34,15 +34,15 @@ def add(self, **kwargs):
"""Add samples to the buffer."""
raise NotImplementedError("You shouldn't see this. Look away. Or fix it.")

- def sample(self, *args, **kwargs) -> Optional[List[Experience]]:
+ def sample(self, *args, **kwargs) -> list[Experience] | None:
"""Sample buffer for a set of experience."""
raise NotImplementedError("You shouldn't see this. Look away. Or fix it.")

- def dump_buffer(self, serialize: bool = False) -> List[Dict]:
+ def dump_buffer(self, serialize: bool = False) -> list[dict]:
"""Return the whole buffer, e.g. for storing."""
raise NotImplementedError("You shouldn't see this. Look away. Or fix it.")

- def load_buffer(self, buffer: List[Experience]) -> None:
+ def load_buffer(self, buffer: list[Experience]) -> None:
"""Loads provided data into the buffer."""
raise NotImplementedError("You shouldn't see this. Look away. Or fix it.")

@@ -67,22 +67,22 @@ def __len__(self) -> int:
return len(self.buffer)

@staticmethod
- def _hash_element(el) -> Union[int, str]:
+ def _hash_element(el) -> int | str:
if isinstance(el, np.ndarray):
return hash(el.data.tobytes())
elif isinstance(el, torch.Tensor):
return hash(str(el))
else:
return str(el)

- def add(self, el) -> Union[int, str]:
+ def add(self, el) -> int | str:
idx = self._hash_element(el)
self.counter[idx] += 1
if self.counter[idx] < 2:
self.buffer[idx] = el
return idx

- def get(self, idx: Union[int, str]):
+ def get(self, idx: int | str):
return self.buffer[idx]

def remove(self, idx: str):