Skip to content

Commit

Permalink
saving
Browse files Browse the repository at this point in the history
  • Loading branch information
KlemenSpruk committed Nov 20, 2023
1 parent 352eda4 commit 043a3b4
Show file tree
Hide file tree
Showing 7 changed files with 27 additions and 15 deletions.
1 change: 1 addition & 0 deletions django_project_base/celery/background_tasks/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def before_start(self, task_id, args, kwargs):
dw = backend.DatabaseWrapper(db_settings)
dw.connect()
connections.databases[NOTIFICATION_QUEUE_NAME] = dw.settings_dict
connections.databases["default"] = dw.settings_dict

def on_failure(self, exc, task_id, args, kwargs, einfo):
logging.getLogger(__name__).error(
Expand Down
5 changes: 3 additions & 2 deletions django_project_base/licensing/logic.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import List, Optional

from django.contrib.contenttypes.models import ContentType
from django.db import connections
from django.db.models import Model, Sum
from django.utils.translation import gettext
from dynamicforms import fields
Expand Down Expand Up @@ -80,8 +81,8 @@ def log(
on_sucess=None,
**kwargs,
) -> int:
content_type = ContentType.objects.using(self.db).get_for_model(model=record._meta.model)

connections["default"] = connections[self.db]
content_type = ContentType.objects.get_for_model(model=record._meta.model)
used = (
LicenseAccessUse.objects.using(self.db)
.filter(user_id=str(user_profile_pk), content_type=content_type)
Expand Down
4 changes: 2 additions & 2 deletions django_project_base/notifications/base/channels/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def _make_send(self, notification_obj, rec_obj, message_str, dlr_pk) -> Tuple[Op
logger.exception(de)
return dlr_obj, (sent and do_send)

def send(self, notification: DjangoProjectBaseNotification, extra_data, **kwargs) -> int:
def send(self, notification: DjangoProjectBaseNotification, extra_data, settings: Settings, **kwargs) -> int:
logger = logging.getLogger("django")
try:
message = self.provider.get_message(notification)
Expand Down Expand Up @@ -243,7 +243,7 @@ def send(self, notification: DjangoProjectBaseNotification, extra_data, **kwargs
else:
exclude_providers.append(f"{self.provider.__module__}.{self.provider.__class__.__name__}")
if next_provider := self._find_provider(
extra_settings=extra_data,
settings=settings,
setting_name=self.provider_setting_name,
exclude=exclude_providers,
):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import uuid
from typing import List

from django.conf import settings
from django.conf import settings, Settings
from django.core.exceptions import ValidationError
from django.core.validators import validate_email

Expand All @@ -19,9 +19,9 @@ class MailChannel(Channel):

provider_setting_name = "NOTIFICATIONS_EMAIL_PROVIDER"

def send(self, notification: DjangoProjectBaseNotification, extra_data, **kwargs) -> int:
def send(self, notification: DjangoProjectBaseNotification, extra_data, settings: Settings, **kwargs) -> int:
if getattr(settings, "TESTING", False):
return super().send(notification=notification, extra_data=extra_data)
return super().send(notification=notification, extra_data=extra_data, settings=settings)
message = self.provider.get_message(notification)
sender = self.sender(notification)
self.provider.client_send(
Expand All @@ -30,7 +30,7 @@ def send(self, notification: DjangoProjectBaseNotification, extra_data, **kwargs
message,
str(uuid.uuid4()),
)
return super().send(notification=notification, extra_data=extra_data)
return super().send(notification=notification, extra_data=extra_data, settings=settings)

def get_recipients(self, notification: DjangoProjectBaseNotification, unique_identifier=""):
return list(set(super().get_recipients(notification, unique_identifier="email")))
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import List

from django.conf import Settings
from django_project_base.notifications.base.channels.channel import Channel, Recipient
from django_project_base.notifications.base.enums import ChannelIdentifier
from django_project_base.notifications.models import DjangoProjectBaseNotification
Expand All @@ -17,8 +18,10 @@ class SmsChannel(Channel):
def get_recipients(self, notification: DjangoProjectBaseNotification, unique_identifier=""):
return list(set(super().get_recipients(notification, unique_identifier="phone_number")))

def send(self, notification: DjangoProjectBaseNotification, extra_data, **kwargs) -> int: # noqa: F821
return super().send(notification=notification, extra_data=extra_data)
def send(
self, notification: DjangoProjectBaseNotification, extra_data: dict, settings: Settings, **kwargs
) -> int: # noqa: F821
return super().send(notification=notification, extra_data=extra_data, settings=settings)

def clean_sms_recipients(self, recipients: List[Recipient]) -> List[Recipient]:
return list(filter(lambda r: r.phone_number and len(r.phone_number), self.clean_recipients(recipients)))
4 changes: 3 additions & 1 deletion django_project_base/notifications/base/notification.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,9 @@ def send(self) -> DjangoProjectBaseNotification:
return notification

if not self.send_at:
SendNotificationService(settings=settings).make_send(notification, self._extra_data, resend=False)
SendNotificationService(settings=settings, use_default_db_connection=True).make_send(
notification, self._extra_data, resend=False
)
else:
if not self.persist:
raise Exception("Delayed notification must be persisted")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@

class SendNotificationService(object):
settings: Settings
use_default_db_connection = False

def __init__(self, settings: Settings) -> None:
def __init__(self, settings: Settings, use_default_db_connection=False) -> None:
super().__init__()
self.settings = settings
self.use_default_db_connection = use_default_db_connection

def make_send(
self, notification: DjangoProjectBaseNotification, extra_data, resend=False
Expand All @@ -24,6 +26,7 @@ def make_send(

sent_channels: list = []
failed_channels: list = []
db_name = NOTIFICATION_QUEUE_NAME if not self.use_default_db_connection else "default"

exceptions = ""
from django_project_base.licensing.logic import LogAccessService
Expand Down Expand Up @@ -71,13 +74,15 @@ def make_send(
)
try:
# check license
any_sent = LogAccessService(db=NOTIFICATION_QUEUE_NAME).log(
any_sent = LogAccessService(db=db_name).log(
user_profile_pk=notification.user,
notifications_channels_state=sent_channels,
record=notification,
item_price=channel.notification_price,
comment=str(channel),
on_sucess=lambda: channel.send(notification, extra_data),
on_sucess=lambda: channel.send(
notification=notification, extra_data=extra_data, settings=self.settings
),
is_system_notification=extra_data.get("is_system_notification"),
sender=channel.sender(notification),
)
Expand Down Expand Up @@ -140,7 +145,7 @@ def make_send(
"failed_channels",
"exceptions",
],
using=NOTIFICATION_QUEUE_NAME,
using=db_name,
)
db.connections.close_all()
return notification

0 comments on commit 043a3b4

Please sign in to comment.