Skip to content

Commit

Permalink
fix(ws): Fix JSON decode error on invalid commands
Browse files Browse the repository at this point in the history
  • Loading branch information
msbrogli committed Jan 14, 2025
1 parent 508030c commit 52aa340
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 5 deletions.
6 changes: 6 additions & 0 deletions hathor/websocket/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@ class WebSocketMessage(BaseModel):
pass


class WebSocketErrorMessage(BaseModel):
type: str = Field('error', const=True)
success: bool = Field(False, const=True)
errmsg: str


class CapabilitiesMessage(WebSocketMessage):
type: str = Field('capabilities', const=True)
capabilities: list[str]
Expand Down
18 changes: 13 additions & 5 deletions hathor/websocket/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from json import JSONDecodeError
from typing import TYPE_CHECKING, Any, Union

from autobahn.twisted.websocket import WebSocketServerProtocol
Expand All @@ -28,7 +29,7 @@
aiter_xpub_addresses,
gap_limit_search,
)
from hathor.websocket.messages import CapabilitiesMessage, StreamErrorMessage, WebSocketMessage
from hathor.websocket.messages import CapabilitiesMessage, StreamErrorMessage, WebSocketErrorMessage, WebSocketMessage
from hathor.websocket.streamer import HistoryStreamer

if TYPE_CHECKING:
Expand Down Expand Up @@ -103,10 +104,17 @@ def onClose(self, wasClean, code, reason):
def onMessage(self, payload: Union[bytes, str], isBinary: bool) -> None:
"""Called by the websocket protocol when a new message is received."""
self.log.debug('new message', payload=payload.hex() if isinstance(payload, bytes) else payload)
if isinstance(payload, bytes):
message = json_loadb(payload)
else:
message = json_loads(payload)

try:
if isinstance(payload, bytes):
message = json_loadb(payload)
else:
message = json_loads(payload)
except JSONDecodeError:
self.send_message(WebSocketErrorMessage(
errmsg='Malformed command'
))
return

_type = message.get('type')

Expand Down

0 comments on commit 52aa340

Please sign in to comment.