From 52aa340eac729b3fa484803620d80272c3823f29 Mon Sep 17 00:00:00 2001 From: Marcelo Salhab Brogliato Date: Tue, 14 Jan 2025 11:04:51 -0600 Subject: [PATCH] fix(ws): Fix JSON decode error on invalid commands --- hathor/websocket/messages.py | 6 ++++++ hathor/websocket/protocol.py | 18 +++++++++++++----- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/hathor/websocket/messages.py b/hathor/websocket/messages.py index 86058759b..524b36b79 100644 --- a/hathor/websocket/messages.py +++ b/hathor/websocket/messages.py @@ -23,6 +23,12 @@ class WebSocketMessage(BaseModel): pass +class WebSocketErrorMessage(BaseModel): + type: str = Field('error', const=True) + success: bool = Field(False, const=True) + errmsg: str + + class CapabilitiesMessage(WebSocketMessage): type: str = Field('capabilities', const=True) capabilities: list[str] diff --git a/hathor/websocket/protocol.py b/hathor/websocket/protocol.py index e23d2b60a..5f173151e 100644 --- a/hathor/websocket/protocol.py +++ b/hathor/websocket/protocol.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from json import JSONDecodeError from typing import TYPE_CHECKING, Any, Union from autobahn.twisted.websocket import WebSocketServerProtocol @@ -28,7 +29,7 @@ aiter_xpub_addresses, gap_limit_search, ) -from hathor.websocket.messages import CapabilitiesMessage, StreamErrorMessage, WebSocketMessage +from hathor.websocket.messages import CapabilitiesMessage, StreamErrorMessage, WebSocketErrorMessage, WebSocketMessage from hathor.websocket.streamer import HistoryStreamer if TYPE_CHECKING: @@ -103,10 +104,17 @@ def onClose(self, wasClean, code, reason): def onMessage(self, payload: Union[bytes, str], isBinary: bool) -> None: """Called by the websocket protocol when a new message is received.""" self.log.debug('new message', payload=payload.hex() if isinstance(payload, bytes) else payload) - if isinstance(payload, bytes): - message = json_loadb(payload) - else: - message = json_loads(payload) + + try: + if isinstance(payload, bytes): + message = json_loadb(payload) + else: + message = json_loads(payload) + except JSONDecodeError: + self.send_message(WebSocketErrorMessage( + errmsg='Malformed command' + )) + return _type = message.get('type')