Skip to content

Commit

Permalink
ENG-524: convert depletion intervals to be a list (#207)
Browse files Browse the repository at this point in the history
ENG-524: make notification intervals more flexible, supporting multiple different scenarios
  • Loading branch information
parikls authored Jan 3, 2025
1 parent 0adeeb9 commit f90a36d
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 34 deletions.
73 changes: 54 additions & 19 deletions neuro_admin_client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from abc import abstractmethod
from collections.abc import AsyncIterator, Sequence
from contextlib import AbstractAsyncContextManager, asynccontextmanager
from dataclasses import dataclass, field
from dataclasses import asdict, dataclass, field
from datetime import datetime
from decimal import Decimal
from typing import Any, List, Tuple, Union, overload
Expand All @@ -22,6 +22,7 @@
ClusterUserWithInfo,
Org,
OrgCluster,
OrgNotificationIntervals,
OrgUser,
OrgUserRoleType,
OrgUserWithInfo,
Expand Down Expand Up @@ -471,16 +472,29 @@ async def update_org_defaults(
self,
org_name: str,
user_default_credits: Decimal | None,
notification_balance_depletion_seconds: int | None = None,
notification_intervals: OrgNotificationIntervals | None = None,
) -> Org: ...

@abstractmethod
async def update_org(
self,
org_name: str,
user_default_credits: Decimal | None,
notification_balance_depletion_seconds: int | None = None,
) -> Org: ...
notification_intervals: OrgNotificationIntervals | None = None,
) -> Org:
"""
Updates an organizations.
:param org_name:
Will be used to identify an org.
A name itself won't be updated
:param user_default_credits:
A decimal value which will be used as a default balance
for all the users created in the future.
Can be `None` to disable such functionality for the organization
:param notification_intervals:
An instance of notification intervals object.
See a docstring of `OrgNotificationIntervals` for more details
"""

# org user

Expand Down Expand Up @@ -1115,14 +1129,23 @@ def _parse_quota(self, payload: dict[str, Any] | None) -> Quota:
return Quota()
return Quota(total_running_jobs=payload.get("total_running_jobs"))

def _parse_balance(self, payload: dict[str, Any] | None) -> Balance:
@staticmethod
def _parse_balance(payload: dict[str, Any] | None) -> Balance:
if payload is None:
return Balance()
return Balance(
spent_credits=Decimal(payload["spent_credits"]),
credits=Decimal(payload["credits"]) if payload.get("credits") else None,
)

@staticmethod
def _parse_notification_intervals(
payload: dict[str, Any] | None
) -> OrgNotificationIntervals | None:
if payload is None:
return None
return OrgNotificationIntervals(**payload)

def _parse_cluster_user(
self, cluster_name: str, payload: dict[str, Any]
) -> ClusterUser | ClusterUserWithInfo:
Expand Down Expand Up @@ -1729,8 +1752,8 @@ def _parse_org_payload(self, payload: dict[str, Any]) -> Org:
if payload.get("user_default_credits")
else None
),
notification_balance_depletion_seconds=payload.get(
"notification_balance_depletion_seconds"
notification_intervals=self._parse_notification_intervals(
payload.get("notification_intervals")
),
)

Expand Down Expand Up @@ -1846,18 +1869,16 @@ async def update_org(
self,
org_name: str,
user_default_credits: Decimal | None,
notification_balance_depletion_seconds: int | None = None,
notification_intervals: OrgNotificationIntervals | None = None,
) -> Org:
credits = (
str(user_default_credits) if user_default_credits is not None else None
)
payload: dict[str, str | int | None] = {
payload: dict[str, Any] = {
"credits": credits,
}
if notification_balance_depletion_seconds is not None:
payload["notification_balance_depletion_seconds"] = (
notification_balance_depletion_seconds
)
if notification_intervals is not None:
payload["notification_intervals"] = asdict(notification_intervals)

async with self._request(
"PATCH",
Expand All @@ -1872,12 +1893,12 @@ async def update_org_defaults(
self,
org_name: str,
user_default_credits: Decimal | None,
notification_balance_depletion_seconds: int | None = None,
notification_intervals: OrgNotificationIntervals | None = None,
) -> Org:
return await self.update_org(
org_name=org_name,
user_default_credits=user_default_credits,
notification_balance_depletion_seconds=notification_balance_depletion_seconds,
notification_intervals=notification_intervals,
)

# org user
Expand Down Expand Up @@ -2639,7 +2660,21 @@ class AdminClientDummy(AdminClientABC):
name="org",
balance=Balance(),
user_default_credits=None,
notification_balance_depletion_seconds=60 * 60 * 24,
notification_intervals=OrgNotificationIntervals(
balance_projection_seconds=[
60 * 60 * 24 * 7,
60 * 60 * 24 * 3,
60 * 60 * 24 * 1,
],
balance_amount=[
-100,
-500,
],
negative_balance_seconds=[
60 * 60 * 24 * 1,
60 * 60 * 24 * 7,
],
),
)
DUMMY_ORG_CLUSTER = OrgCluster(
org_name="org",
Expand Down Expand Up @@ -3105,20 +3140,20 @@ async def update_org(
self,
org_name: str,
default_credits: Decimal | None,
notification_balance_depletion_seconds: int | None = None,
notification_intervals: OrgNotificationIntervals | None = None,
) -> Org:
return self.DUMMY_ORG

async def update_org_defaults(
self,
org_name: str,
default_credits: Decimal | None,
notification_balance_depletion_seconds: int | None = None,
notification_intervals: OrgNotificationIntervals | None = None,
) -> Org:
return await self.update_org(
org_name=org_name,
default_credits=default_credits,
notification_balance_depletion_seconds=notification_balance_depletion_seconds,
notification_intervals=notification_intervals,
)

# org user
Expand Down
22 changes: 20 additions & 2 deletions neuro_admin_client/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from datetime import datetime
from decimal import Decimal
from enum import Enum, unique
from typing import Optional
from typing import List, Optional


class FullNameMixin:
Expand Down Expand Up @@ -75,12 +75,30 @@ class Cluster:
maintenance: bool = False


@dataclass(frozen=True)
class OrgNotificationIntervals:
balance_projection_seconds: Optional[List[int]]
"""How many seconds left till the balance reaches zero?
A list of integers, where each number represents a seconds-based interval,
at which the organization management team will receive a notification,
if the projected usage will lead to a reaching of a zero-balance in that
amount of seconds.
"""
balance_amount: Optional[List[int]]
"""What exact balance amounts should trigger a notification?
"""
negative_balance_seconds: Optional[List[int]]
"""If a balance is negative, when we should send a notification?
e.g. 86_400 means 1 day after org reaches a zero balance.
"""


@dataclass(frozen=True)
class Org:
name: str
balance: Balance = Balance()
user_default_credits: Optional[Decimal] = None
notification_balance_depletion_seconds: Optional[int] = None
notification_intervals: Optional[OrgNotificationIntervals] = None


@unique
Expand Down
20 changes: 10 additions & 10 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import datetime
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from dataclasses import dataclass, field, replace
from dataclasses import asdict, dataclass, field, replace
from decimal import Decimal
from typing import Any

Expand All @@ -20,6 +20,7 @@
ClusterUserRoleType,
Org,
OrgCluster,
OrgNotificationIntervals,
OrgUser,
OrgUserRoleType,
Project,
Expand Down Expand Up @@ -157,10 +158,8 @@ def _serialize_org(self, org: Org) -> dict[str, Any]:
res["balance"]["credits"] = str(org.balance.credits)
if org.user_default_credits:
res["user_default_credits"] = str(org.user_default_credits)
if org.notification_balance_depletion_seconds:
res["notification_balance_depletion_seconds"] = (
org.notification_balance_depletion_seconds
)
if org.notification_intervals:
res["notification_intervals"] = asdict(org.notification_intervals)
return res

async def handle_org_post(
Expand Down Expand Up @@ -266,13 +265,14 @@ async def handle_org_patch_defaults(
else None
),
)
notification_balance_depletion_seconds = payload.get(
"notification_balance_depletion_seconds"
)
if notification_balance_depletion_seconds:
notification_intervals = payload.get("notification_intervals")
if notification_intervals:
notification_intervals = OrgNotificationIntervals(
**payload["notification_intervals"]
)
org = replace(
org,
notification_balance_depletion_seconds=notification_balance_depletion_seconds,
notification_intervals=notification_intervals,
)
self.orgs[index] = org
return aiohttp.web.json_response(self._serialize_org(org))
Expand Down
28 changes: 25 additions & 3 deletions tests/test_admin_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
GetUserResponse,
Org,
OrgCluster,
OrgNotificationIntervals,
OrgUser,
OrgUserRoleType,
Project,
Expand Down Expand Up @@ -287,19 +288,40 @@ async def test_patch_org(self, mock_admin_server: AdminServer) -> None:
Org(
name="org",
user_default_credits=Decimal(100),
notification_balance_depletion_seconds=60 * 60 * 24,
notification_intervals=OrgNotificationIntervals(
balance_projection_seconds=[60 * 60 * 24],
balance_amount=None,
negative_balance_seconds=None,
),
),
]

async with AdminClient(base_url=mock_admin_server.url) as client:
org = await client.update_org(
org_name="org",
user_default_credits=Decimal(200),
notification_balance_depletion_seconds=60 * 60 * 24 * 2,
notification_intervals=OrgNotificationIntervals(
balance_projection_seconds=[60 * 60 * 24 * 2],
balance_amount=[
-100,
],
negative_balance_seconds=[
60,
],
),
)
assert org.user_default_credits == Decimal(200)
assert org.notification_balance_depletion_seconds == 60 * 60 * 24 * 2
assert org == mock_admin_server.orgs[0]
intervals = t.cast(OrgNotificationIntervals, org.notification_intervals)

assert intervals.balance_projection_seconds is not None
assert intervals.balance_projection_seconds[0] == 60 * 60 * 24 * 2

assert intervals.balance_amount is not None
assert intervals.balance_amount[0] == -100

assert intervals.negative_balance_seconds is not None
assert intervals.negative_balance_seconds[0] == 60

async def test_create_org_with_defaults(
self, mock_admin_server: AdminServer
Expand Down

0 comments on commit f90a36d

Please sign in to comment.