-
Notifications
You must be signed in to change notification settings - Fork 111
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #181 from mkanoor/pg_listener
feat: added pg_listener source plugin
- Loading branch information
Showing
4 changed files
with
309 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,188 @@ | ||
"""pg_listener.py. | ||
An ansible-rulebook event source plugin for reading events from | ||
pg_pub_sub | ||
Arguments: | ||
--------- | ||
dsn: The connection string/dsn for Postgres | ||
channels: The list of channels to listen on | ||
Example: | ||
------- | ||
- ansible.eda.pg_listener: | ||
dsn: "host=localhost port=5432 dbname=mydb" | ||
channels: | ||
- my_events | ||
- my_alerts | ||
Chunking: | ||
--------- | ||
This is just informational; a user doesn't have to do anything | ||
special to enable chunking. The sender which is the pg_notify | ||
action from ansible rulebook will decide if chunking needs to | ||
happen based on the size of the payload. | ||
If the messages are over 7KB the sender will chunk the messages | ||
into separate payloads with each payload having the following | ||
keys | ||
* _message_chunked_uuid The unique message uuid | ||
* _message_chunk_count The number of chunks for the message | ||
* _message_chunk_sequence The sequence of the current chunk | ||
* _chunk The actual chunk | ||
* _message_length The total length of the message | ||
* _message_xx_hash A hash for the entire message | ||
The pg_listener source will assemble the chunks and once all the | ||
chunks have been received it will deliver the entire payload to the | ||
rulebook engine. Before the payload is delivered we validate that the entire | ||
message has been received by validating its computed hash. | ||
""" | ||
|
||
import asyncio | ||
import json | ||
import logging | ||
from typing import Any | ||
|
||
import xxhash | ||
from psycopg import AsyncConnection, OperationalError | ||
|
||
LOGGER = logging.getLogger(__name__) | ||
|
||
# Keys carried by every chunk of a chunked message (see module docstring). | ||
MESSAGE_CHUNKED_UUID = "_message_chunked_uuid" | ||
MESSAGE_CHUNK_COUNT = "_message_chunk_count" | ||
MESSAGE_CHUNK_SEQUENCE = "_message_chunk_sequence" | ||
MESSAGE_CHUNK = "_chunk" | ||
MESSAGE_LENGTH = "_message_length" | ||
MESSAGE_XX_HASH = "_message_xx_hash" | ||
# Source arguments that main() refuses to run without. | ||
REQUIRED_KEYS = ("dsn", "channels") | ||
|
||
# Keys _validate_chunked_payload() requires in a chunked payload. | ||
# MESSAGE_CHUNKED_UUID is not listed: main() only validates payloads | ||
# that already contain it. | ||
REQUIRED_CHUNK_KEYS = ( | ||
MESSAGE_CHUNK_COUNT, | ||
MESSAGE_CHUNK_SEQUENCE, | ||
MESSAGE_CHUNK, | ||
MESSAGE_LENGTH, | ||
MESSAGE_XX_HASH, | ||
) | ||
|
||
|
||
class MissingRequiredArgumentError(Exception):
    """Raised when a mandatory source argument was not supplied."""

    def __init__(self: "MissingRequiredArgumentError", key: str) -> None:
        """Build the error message naming the absent argument ``key``."""
        message = f"PG Listener {key} is a required argument"
        super().__init__(message)
|
||
|
||
class MissingChunkKeyError(Exception):
    """Raised when a chunked payload lacks one of its bookkeeping keys."""

    def __init__(self: "MissingChunkKeyError", key: str) -> None:
        """Build the error message naming the absent chunk ``key``."""
        message = f"Chunked payload is missing required {key}"
        super().__init__(message)
|
||
|
||
def _validate_chunked_payload(payload: dict) -> None:
    """Ensure ``payload`` carries every key listed in REQUIRED_CHUNK_KEYS.

    Raises:
        MissingChunkKeyError: For the first required key (in declaration
            order) that is not present in ``payload``.
    """
    absent = next(
        (name for name in REQUIRED_CHUNK_KEYS if name not in payload),
        None,
    )
    if absent is not None:
        raise MissingChunkKeyError(absent)
|
||
|
||
async def main(queue: asyncio.Queue, args: dict[str, Any]) -> None:
    """Listen for notifications on Postgres channels and queue them as events.

    Connects with ``args["dsn"]``, issues a ``LISTEN`` for every entry in
    ``args["channels"]`` and then forwards each notification payload
    (decoded from JSON) to ``queue``. Payloads that contain the chunk uuid
    key are validated and handed to the chunk re-assembly helper instead of
    being queued directly.

    Args:
        queue: Queue that receives every decoded event.
        args: Source arguments; must contain ``dsn`` and ``channels``.

    Raises:
        MissingRequiredArgumentError: If ``dsn`` or ``channels`` is absent.
        MissingChunkKeyError: If a chunked payload lacks a required key.
    """
    for key in REQUIRED_KEYS:
        if key not in args:
            raise MissingRequiredArgumentError(key)

    try:
        async with await AsyncConnection.connect(
            conninfo=args["dsn"],
            autocommit=True,
        ) as conn:
            # Partially received chunked messages, keyed by message uuid.
            chunked_cache: dict[str, list[dict]] = {}
            cursor = conn.cursor()
            for channel in args["channels"]:
                # NOTE(review): the channel name is interpolated into the
                # statement unquoted, so it must be a trusted, valid SQL
                # identifier.
                await cursor.execute(f"LISTEN {channel};")
                LOGGER.debug("Waiting for notifications on channel %s", channel)
            async for event in conn.notifies():
                # Handle decode failures per notification: a single malformed
                # payload is logged and skipped instead of ending the
                # listener, which is what the log message promises.
                try:
                    data = json.loads(event.payload)
                    if MESSAGE_CHUNKED_UUID in data:
                        _validate_chunked_payload(data)
                        await _handle_chunked_message(data, chunked_cache, queue)
                    else:
                        await queue.put(data)
                except json.decoder.JSONDecodeError:
                    LOGGER.exception("Error decoding data, ignoring it")
    except OperationalError:
        LOGGER.exception("PG Listen operational error")
|
||
|
||
async def _handle_chunked_message(
    data: dict,
    chunked_cache: dict,
    queue: asyncio.Queue,
) -> None:
    """Store one chunk and deliver the message once all chunks have arrived.

    Chunks are collected in ``chunked_cache`` under their message uuid. When
    the number of collected chunks equals the chunk count announced by the
    first chunk received, the chunks are ordered by their sequence number,
    concatenated, and checked against the announced xxh32 hash. On a match
    the re-assembled JSON document is decoded and put on ``queue``; on a
    mismatch an error is logged and the message is dropped. In both cases
    the cache entry is removed.

    Args:
        data: A validated chunk payload.
        chunked_cache: Mutable mapping of message uuid to received chunks.
        queue: Queue that receives the re-assembled event.
    """
    message_uuid = data[MESSAGE_CHUNKED_UUID]
    number_of_chunks = data[MESSAGE_CHUNK_COUNT]
    message_length = data[MESSAGE_LENGTH]
    LOGGER.debug(
        "Received chunked message %s total chunks %d message length %d",
        message_uuid,
        number_of_chunks,
        message_length,
    )
    received = chunked_cache.setdefault(message_uuid, [])
    received.append(data)

    if len(received) != received[0][MESSAGE_CHUNK_COUNT]:
        LOGGER.debug(
            "Received %d chunks for message %s",
            len(received),
            message_uuid,
        )
        return

    LOGGER.debug(
        "Received all chunks for message %s",
        message_uuid,
    )
    # Re-assemble by the sender's sequence number rather than arrival order
    # so an out-of-order chunk cannot corrupt the payload. sorted() is
    # stable, so chunks with equal sequence numbers keep arrival order.
    ordered = sorted(received, key=lambda chunk: chunk[MESSAGE_CHUNK_SEQUENCE])
    all_data = "".join(chunk[MESSAGE_CHUNK] for chunk in ordered)
    chunks = chunked_cache.pop(message_uuid)
    expected_hash = chunks[0][MESSAGE_XX_HASH]
    xx_hash = xxhash.xxh32(all_data.encode("utf-8")).hexdigest()
    LOGGER.debug("Computed XX Hash is %s", xx_hash)
    LOGGER.debug(
        "XX Hash expected %s",
        expected_hash,
    )
    if xx_hash == expected_hash:
        await queue.put(json.loads(all_data))
    else:
        LOGGER.error("XX Hash of chunked payload doesn't match")
|
||
|
||
if __name__ == "__main__":
    # When executed directly, run the listener against a local database and
    # print every received event instead of queueing it.

    class MockQueue:
        """Stand-in queue that prints whatever is put on it."""

        async def put(self: "MockQueue", event: dict) -> None:
            """Print the event."""
            print(event)  # noqa: T201

    demo_args = {
        "dsn": "host=localhost port=5432 dbname=eda "
        "user=postgres password=secret",
        "channels": ["my_channel"],
    }
    asyncio.run(main(MockQueue(), demo_args))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,3 +7,5 @@ kafka-python | |
pyyaml | ||
systemd-python; sys_platform != 'darwin' | ||
watchdog | ||
psycopg | ||
xxhash |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
""" Tests for pg_listener source plugin """ | ||
|
||
import asyncio | ||
import json | ||
import uuid | ||
from unittest.mock import AsyncMock, MagicMock, patch | ||
|
||
import pytest | ||
import xxhash | ||
|
||
from extensions.eda.plugins.event_source.pg_listener import ( | ||
MESSAGE_CHUNK, | ||
MESSAGE_CHUNK_COUNT, | ||
MESSAGE_CHUNK_SEQUENCE, | ||
MESSAGE_CHUNKED_UUID, | ||
MESSAGE_LENGTH, | ||
MESSAGE_XX_HASH, | ||
) | ||
from extensions.eda.plugins.event_source.pg_listener import main as pg_listener_main | ||
|
||
# Size threshold and chunk size (in characters) used by _to_chunks(). | ||
MAX_LENGTH = 7 * 1024 | ||
|
||
|
||
class _MockQueue:
    """In-memory stand-in for the rulebook queue that records every event."""

    def __init__(self):
        self.queue = list()

    async def put(self, event):
        """Record ``event`` at the end of the internal list."""
        self.queue += [event]
|
||
|
||
class _AsyncIterator:
    """Async iterator yielding one mock notification per item of ``data``.

    Each yielded object is a ``MagicMock`` whose ``payload`` attribute is the
    corresponding item. Iterating with ``async for`` always starts from a
    fresh iterator, so the instance iterated over keeps its own position.
    """

    def __init__(self, data):
        self.data = data
        self.count = 0

    def __aiter__(self):
        return _AsyncIterator(self.data)

    async def __aenter__(self):
        return self

    async def __anext__(self):
        position = self.count
        if position >= len(self.data):
            raise StopAsyncIteration

        notification = MagicMock()
        notification.payload = self.data[position]
        self.count = position + 1
        return notification
|
||
|
||
def _to_chunks(payload: str, result: list[str]) -> None:
    """Append ``payload`` to ``result``, chunked when it reaches MAX_LENGTH.

    A payload shorter than MAX_LENGTH is appended unchanged. Otherwise it is
    cut into MAX_LENGTH-sized slices and one JSON document per slice is
    appended, each carrying the shared message uuid, chunk count, message
    length and xxh32 hash plus its own slice and 1-based sequence number.

    Args:
        payload: The serialized event to send.
        result: List that receives the notification payload(s).
    """
    message_length = len(payload)
    if message_length < MAX_LENGTH:
        result.append(payload)
        return

    # Ceiling division: the count must equal the number of slices produced
    # below, including when the length is an exact multiple of MAX_LENGTH
    # (int(length / MAX_LENGTH) + 1 would announce one chunk too many).
    number_of_chunks = -(-message_length // MAX_LENGTH)
    header = {
        MESSAGE_CHUNKED_UUID: str(uuid.uuid4()),
        MESSAGE_CHUNK_COUNT: number_of_chunks,
        MESSAGE_LENGTH: message_length,
        MESSAGE_XX_HASH: xxhash.xxh32(payload.encode("utf-8")).hexdigest(),
    }
    offsets = range(0, message_length, MAX_LENGTH)
    for sequence, start in enumerate(offsets, start=1):
        chunked = {
            **header,
            MESSAGE_CHUNK: payload[start : start + MAX_LENGTH],
            MESSAGE_CHUNK_SEQUENCE: sequence,
        }
        result.append(json.dumps(chunked))
|
||
|
||
# One entry per test run; each entry is the list of events to send. | ||
TEST_PAYLOADS = [ | ||
# Two small events, each sent as a single notification. | ||
[{"a": 1, "b": 2}, {"name": "Fred", "kids": ["Pebbles"]}], | ||
# One event larger than MAX_LENGTH, so it is sent in chunks. | ||
[{"blob": "x" * 9000, "huge": "h" * 9000}], | ||
# A small event followed by one that has to be chunked. | ||
[{"a": 1, "x": 2}, {"x": "y" * 20000, "fail": False, "pi": 3.14159}], | ||
] | ||
|
||
|
||
@pytest.mark.parametrize("events", TEST_PAYLOADS)
def test_receive_from_pg_listener(events):
    """Check that every sent event, chunked or not, reaches the queue intact."""
    myqueue = _MockQueue()
    notify_payload = []
    for event in events:
        _to_chunks(json.dumps(event), notify_payload)

    def my_iterator():
        return _AsyncIterator(notify_payload)

    source_args = {
        "dsn": "host=localhost dbname=mydb user=postgres password=password",
        "channels": ["test"],
    }
    connect_target = (
        "extensions.eda.plugins.event_source.pg_listener.AsyncConnection.connect"
    )

    with patch(connect_target) as conn:
        # The same mock serves as the awaited connection and as the object
        # bound by ``async with``.
        mock_object = AsyncMock()
        conn.return_value = mock_object
        mock_object.__aenter__.return_value = mock_object
        mock_object.cursor = AsyncMock
        mock_object.notifies = my_iterator

        asyncio.run(pg_listener_main(myqueue, source_args))

        assert len(myqueue.queue) == len(events)
        for position, expected in enumerate(events):
            assert myqueue.queue[position] == expected
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,3 +6,5 @@ asyncmock | |
azure-servicebus | ||
dpath | ||
pytest-asyncio | ||
psycopg | ||
xxhash |