From e15f1eb083034c5b65bd3b32a65198266a9ac25b Mon Sep 17 00:00:00 2001 From: devoxin Date: Sat, 9 Mar 2024 00:01:45 +0000 Subject: [PATCH 01/14] Fix typing for request overload, yield -> yield from, do_exit -> exit for pylint test --- lavalink/__init__.py | 2 +- lavalink/node.py | 4 ++-- lavalink/nodemanager.py | 3 +-- lavalink/playermanager.py | 6 ++---- run_tests.py | 2 +- 5 files changed, 7 insertions(+), 10 deletions(-) diff --git a/lavalink/__init__.py b/lavalink/__init__.py index 973e28b0..b70bd525 100644 --- a/lavalink/__init__.py +++ b/lavalink/__init__.py @@ -4,7 +4,7 @@ __author__ = 'Devoxin' __license__ = 'MIT' __copyright__ = 'Copyright 2017-present Devoxin' -__version__ = '5.3.0' +__version__ = '5.3.1' from typing import Type diff --git a/lavalink/node.py b/lavalink/node.py index 1e95b9ef..1a5d4b01 100644 --- a/lavalink/node.py +++ b/lavalink/node.py @@ -576,11 +576,11 @@ async def update_session(self, resuming: bool = MISSING, timeout: int = MISSING) return await self.request('PATCH', f'sessions/{session_id}', json=json) # type: ignore @overload - async def request(self, method: str, path: str, *, to: Type[T], trace: bool = ..., versioned: bool = ..., **kwargs) -> T: + async def request(self, method: str, path: str, *, to: Type[str], trace: bool = ..., versioned: bool = ..., **kwargs) -> str: ... @overload - async def request(self, method: str, path: str, *, to: str, trace: bool = ..., versioned: bool = ..., **kwargs) -> str: + async def request(self, method: str, path: str, *, to: Type[T], trace: bool = ..., versioned: bool = ..., **kwargs) -> T: ... @overload diff --git a/lavalink/nodemanager.py b/lavalink/nodemanager.py index 85e85329..6cb0289d 100644 --- a/lavalink/nodemanager.py +++ b/lavalink/nodemanager.py @@ -68,8 +68,7 @@ def __len__(self) -> int: return len(self.nodes) def __iter__(self) -> Iterator[Node]: - for node in self.nodes: - yield node + yield from self.nodes @property def available_nodes(self) -> List[Node]: diff --git a/lavalink/playermanager.py b/lavalink/playermanager.py index 628c16b3..7216f667 100644 --- a/lavalink/playermanager.py +++ b/lavalink/playermanager.py @@ -69,13 +69,11 @@ def __len__(self) -> int: def __iter__(self) -> Iterator[Tuple[int, PlayerT]]: """ Returns an iterator that yields a tuple of (guild_id, player). """ - for guild_id, player in self.players.items(): - yield guild_id, player + yield from self.players.items() def values(self) -> Iterator[PlayerT]: """ Returns an iterator that yields only values. """ - for player in self.players.values(): - yield player + yield from self.players.values() def find_all(self, predicate: Optional[Callable[[PlayerT], bool]] = None): """ diff --git a/run_tests.py b/run_tests.py index 6772f871..f8221dc9 100644 --- a/run_tests.py +++ b/run_tests.py @@ -26,7 +26,7 @@ def test_pylint(): 'too-many-instance-attributes,protected-access,' 'too-many-arguments,too-many-public-methods,too-many-branches,' 'consider-using-with', 'lavalink'] - pylint.Run(opts, reporter=reporter, do_exit=False) + pylint.Run(opts, reporter=reporter, exit=False) out = reporter.out.getvalue() failed = bool(out) From 76a22e83226b7d1888798a262e1f78b151e6fc94 Mon Sep 17 00:00:00 2001 From: devoxin Date: Sat, 9 Mar 2024 00:35:22 +0000 Subject: [PATCH 02/14] Document DataIO classes --- lavalink/dataio.py | 148 +++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 142 insertions(+), 6 deletions(-) diff --git a/lavalink/dataio.py b/lavalink/dataio.py index f841d7f8..3e354b77 100644 --- a/lavalink/dataio.py +++ b/lavalink/dataio.py @@ -23,6 +23,7 @@ """ import struct from base64 import b64decode +from collections.abc import Buffer from io import BytesIO from typing import Optional @@ -42,25 +43,76 @@ def _read(self, count: int): return self._buf.read(count) def read_byte(self) -> bytes: + """ + Reads a single byte from the stream. + + Returns + ------- + :class:`bytes` + """ return self._read(1) def read_boolean(self) -> bool: + """ + Reads a bool from the stream. + + Returns + ------- + :class:`bool` + """ result, = struct.unpack('B', self.read_byte()) return result != 0 def read_unsigned_short(self) -> int: + """ + Reads an unsigned short from the stream. + + Returns + ------- + :class:`int` + """ result, = struct.unpack('>H', self._read(2)) return result def read_int(self) -> int: + """ + Reads an int from the stream. + + Returns + ------- + :class:`int` + """ result, = struct.unpack('>i', self._read(4)) return result def read_long(self) -> int: + """ + Reads a long from the stream. + + Returns + ------- + :class:`int` + """ result, = struct.unpack('>Q', self._read(8)) return result def read_nullable_utf(self, utfm: bool = False) -> Optional[str]: + """ + .. _modified UTF: https://en.wikipedia.org/wiki/UTF-8#Modified_UTF-8 + + Reads an optional UTF string from the stream. + + Internally, this just reads a bool and then a string if the bool is ``True``. + + Parameters + ---------- + utfm: :class:`bool` + Whether to read the string as `modified UTF`_. + + Returns + ------- + Optional[:class:`str`] + """ exists = self.read_boolean() if not exists: @@ -69,10 +121,30 @@ def read_nullable_utf(self, utfm: bool = False) -> Optional[str]: return self.read_utfm() if utfm else self.read_utf().decode() def read_utf(self) -> bytes: + """ + Reads a UTF string from the stream. + + Returns + ------- + :class:`bytes` + """ text_length = self.read_unsigned_short() return self._read(text_length) def read_utfm(self) -> str: + """ + .. _modified UTF: https://en.wikipedia.org/wiki/UTF-8#Modified_UTF-8 + + Reads a UTF string from the stream. + + This method is different to :func:`read_utf` as it accounts for + different encoding methods utilised by Java's streams, which uses `modified UTF`_ + for character encoding. + + Returns + ------- + :class:`str` + """ text_length = self.read_unsigned_short() utf_string = self._read(text_length) return read_utfm(text_length, utf_string) @@ -86,31 +158,87 @@ def _write(self, data): self._buf.write(data) def write_byte(self, byte): + """ + Writes a single byte to the stream. + + Parameters + ---------- + byte: Any + This can be anything ``BytesIO.write()`` accepts. + """ self._buf.write(byte) - def write_boolean(self, boolean): + def write_boolean(self, boolean: bool): + """ + Writes a bool to the stream. + + Parameters + ---------- + boolean: :class:`bool` + The bool to write. + """ enc = struct.pack('B', 1 if boolean else 0) self.write_byte(enc) - def write_unsigned_short(self, short): + def write_unsigned_short(self, short: int): + """ + Writes an unsigned short to the stream. + + Parameters + ---------- + short: :class:`int` + The unsigned short to write. + """ enc = struct.pack('>H', short) self._write(enc) - def write_int(self, integer): + def write_int(self, integer: int): + """ + Writes an int to the stream. + + Parameters + ---------- + integer: :class:`int` + The integer to write. + """ enc = struct.pack('>i', integer) self._write(enc) - def write_long(self, long_value): + def write_long(self, long_value: int): + """ + Writes a long to the stream. + + Parameters + ---------- + long_value: :class:`int` + The long to write. + """ enc = struct.pack('>Q', long_value) self._write(enc) - def write_nullable_utf(self, utf_string): + def write_nullable_utf(self, utf_string: Optional[str]): + """ + Writes an optional string to the stream. + + Parameters + ---------- + utf_string: Optional[:class:`str`] + The optional string to write. + """ self.write_boolean(bool(utf_string)) if utf_string: self.write_utf(utf_string) - def write_utf(self, utf_string): + def write_utf(self, utf_string: str): + """ + Writes a utf string to the stream. + + Parameters + ---------- + utf_string: :class:`str` + The string to write. + """ utf = utf_string.encode('utf8') byte_len = len(utf) @@ -121,6 +249,14 @@ def write_utf(self, utf_string): self._write(utf) def finish(self) -> bytes: + """ + Finalizes the stream by writing the necessary flags, byte length etc. + + Returns + ---------- + :class:`bytes` + The finalized stream. + """ with BytesIO() as track_buf: byte_len = self._buf.getbuffer().nbytes flags = byte_len | (1 << 30) From d38e5f93a0789cc923a45985480c497b3b7d7eaa Mon Sep 17 00:00:00 2001 From: devoxin Date: Sat, 9 Mar 2024 00:35:50 +0000 Subject: [PATCH 03/14] remove unused import --- lavalink/dataio.py | 1 - 1 file changed, 1 deletion(-) diff --git a/lavalink/dataio.py b/lavalink/dataio.py index 3e354b77..9e89f6e0 100644 --- a/lavalink/dataio.py +++ b/lavalink/dataio.py @@ -23,7 +23,6 @@ """ import struct from base64 import b64decode -from collections.abc import Buffer from io import BytesIO from typing import Optional From 61655f532f2542de8a106837351b45c4c766e9bd Mon Sep 17 00:00:00 2001 From: devoxin Date: Sat, 9 Mar 2024 00:43:49 +0000 Subject: [PATCH 04/14] lint --- lavalink/dataio.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/lavalink/dataio.py b/lavalink/dataio.py index 9e89f6e0..668d023a 100644 --- a/lavalink/dataio.py +++ b/lavalink/dataio.py @@ -65,7 +65,7 @@ def read_boolean(self) -> bool: def read_unsigned_short(self) -> int: """ Reads an unsigned short from the stream. - + Returns ------- :class:`int` @@ -76,7 +76,7 @@ def read_unsigned_short(self) -> int: def read_int(self) -> int: """ Reads an int from the stream. - + Returns ------- :class:`int` @@ -87,7 +87,7 @@ def read_int(self) -> int: def read_long(self) -> int: """ Reads a long from the stream. - + Returns ------- :class:`int` @@ -107,7 +107,7 @@ def read_nullable_utf(self, utfm: bool = False) -> Optional[str]: ---------- utfm: :class:`bool` Whether to read the string as `modified UTF`_. - + Returns ------- Optional[:class:`str`] @@ -122,7 +122,7 @@ def read_nullable_utf(self, utfm: bool = False) -> Optional[str]: def read_utf(self) -> bytes: """ Reads a UTF string from the stream. - + Returns ------- :class:`bytes` From ed61580ea7698b6cd9f8b95dc94b3117c3c24f82 Mon Sep 17 00:00:00 2001 From: Devoxin Date: Wed, 20 Mar 2024 22:18:47 +0000 Subject: [PATCH 05/14] Add some notes about pitfalls. --- lavalink/abc.py | 5 +++++ lavalink/events.py | 20 ++++++++++++++++++++ lavalink/player.py | 10 ++++++++++ 3 files changed, 35 insertions(+) diff --git a/lavalink/abc.py b/lavalink/abc.py index 93f0123a..ee44bfb6 100644 --- a/lavalink/abc.py +++ b/lavalink/abc.py @@ -113,6 +113,11 @@ async def play_track(self, Plays the given track. + Warning + ------- + Multiple calls to this method short timeframe could cause issues with the player's internal state, + which can cause errors when processing a :class:`TrackStartEvent`. + Parameters ---------- track: Union[:class:`AudioTrack`, :class:`DeferredAudioTrack`] diff --git a/lavalink/events.py b/lavalink/events.py index 6557679c..0f26a874 100644 --- a/lavalink/events.py +++ b/lavalink/events.py @@ -58,6 +58,11 @@ class TrackStuckEvent(Event): This event is emitted when the currently playing track is stuck (i.e. has not provided any audio). This is typically a fault of the track's underlying audio stream, and not Lavalink itself. + Note + ---- + You do not need to manually trigger the start of the next track in the queue within + this event when using the :class:`DefaultPlayer`. This is handled for you. + Attributes ---------- player: :class:`BasePlayer` @@ -80,6 +85,11 @@ class TrackExceptionEvent(Event): """ This event is emitted when a track encounters an exception during playback. + Note + ---- + You do not need to manually trigger the start of the next track in the queue within + this event when using the :class:`DefaultPlayer`. This is handled for you. + Attributes ---------- player: :class:`BasePlayer` @@ -108,6 +118,11 @@ class TrackEndEvent(Event): """ This event is emitted when the player finished playing a track. + Note + ---- + You do not need to manually trigger the start of the next track in the queue within + this event when using the :class:`DefaultPlayer`. This is handled for you. + Attributes ---------- player: :class:`BasePlayer` @@ -132,6 +147,11 @@ class TrackLoadFailedEvent(Event): produce a playable track. The player will not do anything by itself, so it is up to you to skip the broken track. + Note + ---- + This event will not automatically trigger the start of the next track in the queue, + so you must ensure that you do this if you want the player to continue playing from the queue. + Attributes ---------- player: :class:`BasePlayer` diff --git a/lavalink/player.py b/lavalink/player.py index ce5bf49a..05709479 100644 --- a/lavalink/player.py +++ b/lavalink/player.py @@ -231,6 +231,11 @@ async def play(self, This method differs from :func:`BasePlayer.play_track` in that it contains additional logic to handle certain attributes, such as ``loop``, ``shuffle``, and loading a base64 string from :class:`DeferredAudioTrack`. + Warning + ------- + Multiple calls to this method short timeframe could cause issues with the player's internal state, + which can cause errors when processing a :class:`TrackStartEvent`. + Parameters ---------- track: Optional[Union[:class:`DeferredAudioTrack`, :class:`AudioTrack`, Dict[str, Union[Optional[str], bool, int]]]] @@ -321,6 +326,11 @@ async def skip(self): """|coro| Plays the next track in the queue, if any. + + Warning + ------- + Multiple calls to this method short timeframe could cause issues with the player's internal state, + which can cause errors when processing a :class:`TrackStartEvent`. """ await self.play() From 571a2a3d0fdcd61b67fb782a2d7111c00ca1e08d Mon Sep 17 00:00:00 2001 From: Devoxin Date: Wed, 20 Mar 2024 22:20:03 +0000 Subject: [PATCH 06/14] grammar. --- lavalink/abc.py | 4 ++-- lavalink/player.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/lavalink/abc.py b/lavalink/abc.py index ee44bfb6..9cea81f0 100644 --- a/lavalink/abc.py +++ b/lavalink/abc.py @@ -115,8 +115,8 @@ async def play_track(self, Warning ------- - Multiple calls to this method short timeframe could cause issues with the player's internal state, - which can cause errors when processing a :class:`TrackStartEvent`. + Multiple calls to this method within a short timeframe could cause issues with the player's + internal state, which can cause errors when processing a :class:`TrackStartEvent`. Parameters ---------- diff --git a/lavalink/player.py b/lavalink/player.py index 05709479..9c79ba16 100644 --- a/lavalink/player.py +++ b/lavalink/player.py @@ -233,8 +233,8 @@ async def play(self, Warning ------- - Multiple calls to this method short timeframe could cause issues with the player's internal state, - which can cause errors when processing a :class:`TrackStartEvent`. + Multiple calls to this method within a short timeframe could cause issues with the player's + internal state, which can cause errors when processing a :class:`TrackStartEvent`. Parameters ---------- @@ -329,8 +329,8 @@ async def skip(self): Warning ------- - Multiple calls to this method short timeframe could cause issues with the player's internal state, - which can cause errors when processing a :class:`TrackStartEvent`. + Multiple calls to this method within a short timeframe could cause issues with the player's + internal state, which can cause errors when processing a :class:`TrackStartEvent`. """ await self.play() From 4c2341b918c646dc9a41e2424b8612b2d9c0f46c Mon Sep 17 00:00:00 2001 From: devoxin Date: Thu, 21 Mar 2024 12:05:56 +0000 Subject: [PATCH 07/14] simplify V3_KEYSET construction --- lavalink/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lavalink/utils.py b/lavalink/utils.py index cb780a4a..516e911c 100644 --- a/lavalink/utils.py +++ b/lavalink/utils.py @@ -32,7 +32,7 @@ from .source_decoders import DEFAULT_DECODER_MAPPING V2_KEYSET = {'title', 'author', 'length', 'identifier', 'isStream', 'uri', 'sourceName', 'position'} -V3_KEYSET = {'title', 'author', 'length', 'identifier', 'isStream', 'uri', 'artworkUrl', 'isrc', 'sourceName', 'position'} +V3_KEYSET = V2_KEYSET | {'artworkUrl', 'isrc'} def timestamp_to_millis(timestamp: str) -> int: From e9fd0c1f553c2b5a24c3b4fafa98dfed432b69ca Mon Sep 17 00:00:00 2001 From: devoxin Date: Thu, 21 Mar 2024 16:58:14 +0000 Subject: [PATCH 08/14] (utils/encode_track) allow specifying source-specific encoders --- lavalink/utils.py | 56 ++++++++++++++++++++++++++++++++--------------- 1 file changed, 38 insertions(+), 18 deletions(-) diff --git a/lavalink/utils.py b/lavalink/utils.py index 516e911c..59168252 100644 --- a/lavalink/utils.py +++ b/lavalink/utils.py @@ -141,12 +141,12 @@ def _read_track_common(reader: DataReader) -> Tuple[str, str, int, str, bool, Op def _write_track_common(track: Dict[str, Union[Optional[str], bool, int]], writer: DataWriter): - writer.write_utf(track['title']) - writer.write_utf(track['author']) - writer.write_long(track['length']) - writer.write_utf(track['identifier']) - writer.write_boolean(track['isStream']) - writer.write_nullable_utf(track['uri']) + writer.write_utf(track['title']) # type: ignore + writer.write_utf(track['author']) # type: ignore + writer.write_long(track['length']) # type: ignore + writer.write_utf(track['identifier']) # type: ignore + writer.write_boolean(track['isStream']) # type: ignore + writer.write_nullable_utf(track['uri']) # type: ignore def decode_track(track: str, # pylint: disable=R0914 @@ -160,7 +160,7 @@ def decode_track(track: str, # pylint: disable=R0914 The base64 track string. source_decoders: Mapping[:class:`str`, Callable[[:class:`DataReader`], Dict[:class:`str`, Any]]] A mapping of source-specific decoders to use. - Some Lavaplayer sources have additional fields encoded on a per-sourcemanager basis, so you can + Some Lavaplayer sources have additional fields encoded on a per-source manager basis, so you can specify a mapping of decoders that will handle decoding these additional fields. You can find some example decoders within the ``source_decoders`` file. This isn't required for all sources, so ensure that you need them before specifying. @@ -216,7 +216,8 @@ def decode_track(track: str, # pylint: disable=R0914 source_specific=source_specific_fields) -def encode_track(track: Dict[str, Union[Optional[str], int, bool]]) -> Tuple[int, str]: +def encode_track(track: Dict[str, Union[Optional[str], int, bool]], + source_encoders: Mapping[str, Callable[[DataWriter]]] = MISSING) -> Tuple[int, str]: """ Encodes a track dict into a base64 string, readable by the Lavalink server. @@ -230,6 +231,15 @@ def encode_track(track: Dict[str, Union[Optional[str], int, bool]]) -> Tuple[int ---------- track: Dict[str, Union[Optional[str], int, bool]] The track dict to serialize. + source_encoders: Mapping[:class:`str`, Callable[[:class:`DataWriter`]] + A mapping of source-specific encoders to use. + Some Lavaplayer sources have additional fields encoded on a per-source manager basis, so you can + specify a mapping of encoders that will handle encoding these additional fields. This isn't required + for all sources, so ensure that you need them before specifying. + + The mapping must be in the format of something like ``{'http': http_encoder_function}``, where the + key ``str`` is the name of the source. These functions will only be called if track's ``sourceName`` + field matches. Raises ------ @@ -250,12 +260,13 @@ def encode_track(track: Dict[str, Union[Optional[str], int, bool]]) -> Tuple[int raise InvalidTrack(f'Track object is missing keys required for serialization: {", ".join(missing_keys)}') if V3_KEYSET <= track_keys: - return (3, encode_track_v3(track)) + return (3, encode_track_v3(track, source_encoders)) - return (2, encode_track_v2(track)) + return (2, encode_track_v2(track, source_encoders)) -def encode_track_v2(track: Dict[str, Union[Optional[str], bool, int]]) -> str: +def encode_track_v2(track: Dict[str, Union[Optional[str], bool, int]], + source_encoders: Mapping[str, Callable[[DataWriter]]] = MISSING) -> str: assert V2_KEYSET <= track.keys() writer = DataWriter() @@ -263,24 +274,33 @@ def encode_track_v2(track: Dict[str, Union[Optional[str], bool, int]]) -> str: version = struct.pack('B', 2) writer.write_byte(version) _write_track_common(track, writer) - writer.write_utf(track['sourceName']) - writer.write_long(track['position']) + writer.write_utf(track['sourceName']) # type: ignore + + if source_encoders is not MISSING and track['sourceName'] in source_encoders: + source_encoders[track['sourceName']](writer) # type: ignore + + writer.write_long(track['position']) # type: ignore enc = writer.finish() return b64encode(enc).decode() -def encode_track_v3(track: Dict[str, Union[Optional[str], bool, int]]) -> str: +def encode_track_v3(track: Dict[str, Union[Optional[str], bool, int]], + source_encoders: Mapping[str, Callable[[DataWriter]]] = MISSING) -> str: assert V3_KEYSET <= track.keys() writer = DataWriter() version = struct.pack('B', 3) writer.write_byte(version) _write_track_common(track, writer) - writer.write_nullable_utf(track['artworkUrl']) - writer.write_nullable_utf(track['isrc']) - writer.write_utf(track['sourceName']) - writer.write_long(track['position']) + writer.write_nullable_utf(track['artworkUrl']) # type: ignore + writer.write_nullable_utf(track['isrc']) # type: ignore + writer.write_utf(track['sourceName']) # type: ignore + + if source_encoders is not MISSING and track['sourceName'] in source_encoders: + source_encoders[track['sourceName']](writer) # type: ignore + + writer.write_long(track['position']) # type: ignore enc = writer.finish() return b64encode(enc).decode() From 96d380ae71291e8d0de1d55527f0bae7cf54c2d7 Mon Sep 17 00:00:00 2001 From: devoxin Date: Thu, 21 Mar 2024 17:02:17 +0000 Subject: [PATCH 09/14] (utils/encode_track) Loosen typings, pass track object to encoder --- lavalink/__init__.py | 2 +- lavalink/utils.py | 44 ++++++++++++++++++++++---------------------- 2 files changed, 23 insertions(+), 23 deletions(-) diff --git a/lavalink/__init__.py b/lavalink/__init__.py index b70bd525..abe9f2e8 100644 --- a/lavalink/__init__.py +++ b/lavalink/__init__.py @@ -4,7 +4,7 @@ __author__ = 'Devoxin' __license__ = 'MIT' __copyright__ = 'Copyright 2017-present Devoxin' -__version__ = '5.3.1' +__version__ = '5.4.0' from typing import Type diff --git a/lavalink/utils.py b/lavalink/utils.py index 59168252..c91df549 100644 --- a/lavalink/utils.py +++ b/lavalink/utils.py @@ -23,7 +23,7 @@ """ import struct from base64 import b64encode -from typing import Any, Callable, Dict, Mapping, Optional, Tuple, Union +from typing import Any, Callable, Dict, Mapping, Optional, Tuple from .common import MISSING from .dataio import DataReader, DataWriter @@ -140,13 +140,13 @@ def _read_track_common(reader: DataReader) -> Tuple[str, str, int, str, bool, Op return (title, author, length, identifier, is_stream, uri) -def _write_track_common(track: Dict[str, Union[Optional[str], bool, int]], writer: DataWriter): - writer.write_utf(track['title']) # type: ignore - writer.write_utf(track['author']) # type: ignore - writer.write_long(track['length']) # type: ignore - writer.write_utf(track['identifier']) # type: ignore - writer.write_boolean(track['isStream']) # type: ignore - writer.write_nullable_utf(track['uri']) # type: ignore +def _write_track_common(track: Dict[str, Any], writer: DataWriter): + writer.write_utf(track['title']) + writer.write_utf(track['author']) + writer.write_long(track['length']) + writer.write_utf(track['identifier']) + writer.write_boolean(track['isStream']) + writer.write_nullable_utf(track['uri']) def decode_track(track: str, # pylint: disable=R0914 @@ -216,8 +216,8 @@ def decode_track(track: str, # pylint: disable=R0914 source_specific=source_specific_fields) -def encode_track(track: Dict[str, Union[Optional[str], int, bool]], - source_encoders: Mapping[str, Callable[[DataWriter]]] = MISSING) -> Tuple[int, str]: +def encode_track(track: Dict[str, Any], + source_encoders: Mapping[str, Callable[[DataWriter, Dict[str, Any]]]] = MISSING) -> Tuple[int, str]: """ Encodes a track dict into a base64 string, readable by the Lavalink server. @@ -265,8 +265,8 @@ def encode_track(track: Dict[str, Union[Optional[str], int, bool]], return (2, encode_track_v2(track, source_encoders)) -def encode_track_v2(track: Dict[str, Union[Optional[str], bool, int]], - source_encoders: Mapping[str, Callable[[DataWriter]]] = MISSING) -> str: +def encode_track_v2(track: Dict[str, Any], + source_encoders: Mapping[str, Callable[[DataWriter, Dict[str, Any]]]] = MISSING) -> str: assert V2_KEYSET <= track.keys() writer = DataWriter() @@ -274,33 +274,33 @@ def encode_track_v2(track: Dict[str, Union[Optional[str], bool, int]], version = struct.pack('B', 2) writer.write_byte(version) _write_track_common(track, writer) - writer.write_utf(track['sourceName']) # type: ignore + writer.write_utf(track['sourceName']) if source_encoders is not MISSING and track['sourceName'] in source_encoders: - source_encoders[track['sourceName']](writer) # type: ignore + source_encoders[track['sourceName']](writer, track) - writer.write_long(track['position']) # type: ignore + writer.write_long(track['position']) enc = writer.finish() return b64encode(enc).decode() -def encode_track_v3(track: Dict[str, Union[Optional[str], bool, int]], - source_encoders: Mapping[str, Callable[[DataWriter]]] = MISSING) -> str: +def encode_track_v3(track: Dict[str, Any], + source_encoders: Mapping[str, Callable[[DataWriter, Dict[str, Any]]]] = MISSING) -> str: assert V3_KEYSET <= track.keys() writer = DataWriter() version = struct.pack('B', 3) writer.write_byte(version) _write_track_common(track, writer) - writer.write_nullable_utf(track['artworkUrl']) # type: ignore - writer.write_nullable_utf(track['isrc']) # type: ignore - writer.write_utf(track['sourceName']) # type: ignore + writer.write_nullable_utf(track['artworkUrl']) + writer.write_nullable_utf(track['isrc']) + writer.write_utf(track['sourceName']) if source_encoders is not MISSING and track['sourceName'] in source_encoders: - source_encoders[track['sourceName']](writer) # type: ignore + source_encoders[track['sourceName']](writer, track) - writer.write_long(track['position']) # type: ignore + writer.write_long(track['position']) enc = writer.finish() return b64encode(enc).decode() From 6db178b96311b637f1f6bbb32a76d399dd1390d1 Mon Sep 17 00:00:00 2001 From: devoxin Date: Sun, 24 Mar 2024 09:35:24 +0000 Subject: [PATCH 10/14] annotate return type of source_encoders callable --- lavalink/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lavalink/utils.py b/lavalink/utils.py index c91df549..11a0629f 100644 --- a/lavalink/utils.py +++ b/lavalink/utils.py @@ -217,7 +217,7 @@ def decode_track(track: str, # pylint: disable=R0914 def encode_track(track: Dict[str, Any], - source_encoders: Mapping[str, Callable[[DataWriter, Dict[str, Any]]]] = MISSING) -> Tuple[int, str]: + source_encoders: Mapping[str, Callable[[DataWriter, Dict[str, Any]], None]] = MISSING) -> Tuple[int, str]: """ Encodes a track dict into a base64 string, readable by the Lavalink server. @@ -266,7 +266,7 @@ def encode_track(track: Dict[str, Any], def encode_track_v2(track: Dict[str, Any], - source_encoders: Mapping[str, Callable[[DataWriter, Dict[str, Any]]]] = MISSING) -> str: + source_encoders: Mapping[str, Callable[[DataWriter, Dict[str, Any]], None]] = MISSING) -> str: assert V2_KEYSET <= track.keys() writer = DataWriter() @@ -286,7 +286,7 @@ def encode_track_v2(track: Dict[str, Any], def encode_track_v3(track: Dict[str, Any], - source_encoders: Mapping[str, Callable[[DataWriter, Dict[str, Any]]]] = MISSING) -> str: + source_encoders: Mapping[str, Callable[[DataWriter, Dict[str, Any]], None]] = MISSING) -> str: assert V3_KEYSET <= track.keys() writer = DataWriter() From 39943479b046c67c216235d099f3bbbc91ec2ae6 Mon Sep 17 00:00:00 2001 From: Devoxin Date: Thu, 28 Mar 2024 21:25:33 +0000 Subject: [PATCH 11/14] Tentative fix for incorrect position during player node migration. --- lavalink/player.py | 13 +++++++++---- lavalink/transport.py | 10 ++++++++-- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/lavalink/player.py b/lavalink/player.py index 9c79ba16..14967d67 100644 --- a/lavalink/player.py +++ b/lavalink/player.py @@ -641,6 +641,7 @@ async def node_unavailable(self): Called when a player's node becomes unavailable. Useful for changing player state before it's moved to another node. """ + self._last_position = self.position self._internal_pause = True async def change_node(self, node: 'Node'): @@ -653,14 +654,16 @@ async def change_node(self, node: 'Node'): node: :class:`Node` The node the player is changed to. """ + old_node = self.node + self.node = node + + last_position = self.position + try: await self.node.destroy_player(self._internal_id) except (ClientError, RequestError): pass - old_node = self.node - self.node = node - if self._voice_state: await self._dispatch_voice_update() @@ -670,7 +673,9 @@ async def change_node(self, node: 'Node'): if isinstance(self.current, DeferredAudioTrack) and playable_track is None: playable_track = await self.current.load(self.client) - await self.node.update_player(guild_id=self._internal_id, encoded_track=playable_track, position=self.position, + self._last_position = last_position # Ensure that _last_position is correctly set, in case a node sends us bad data. + + await self.node.update_player(guild_id=self._internal_id, encoded_track=playable_track, position=last_position, paused=self.paused, volume=self.volume) self._last_update = int(time() * 1000) diff --git a/lavalink/transport.py b/lavalink/transport.py index 70c39769..bef3260e 100644 --- a/lavalink/transport.py +++ b/lavalink/transport.py @@ -36,6 +36,7 @@ from .stats import Stats if TYPE_CHECKING: + from .abc import BasePlayer from .client import Client from .node import Node @@ -241,12 +242,17 @@ async def _handle_message(self, data: Union[Dict[Any, Any], List[Any]]): await self.client._dispatch_event(NodeReadyEvent(self._node, data['sessionId'], data['resumed'])) elif op == 'playerUpdate': guild_id = int(data['guildId']) - player = self.client.player_manager.get(guild_id) + player: 'BasePlayer' = self.client.player_manager.get(guild_id) # type: ignore if not player: _log.debug('[Node:%s] Received playerUpdate for non-existent player! GuildId: %d', self._node.name, guild_id) return + if player.node != self._node: + _log.debug('[Node:%s] Received playerUpdate for a player that doesn\'t belong to this node (player is moving?) GuildId: %d', + self._node.name, guild_id) + return + state = data['state'] await player.update_state(state) await self.client._dispatch_event(PlayerUpdateEvent(player, state)) @@ -266,7 +272,7 @@ async def _handle_event(self, data: dict): data: :class:`dict` The data given from Lavalink. """ - player = self.client.player_manager.get(int(data['guildId'])) + player: 'BasePlayer' = self.client.player_manager.get(int(data['guildId'])) # type: ignore event_type = data['type'] if not player: From 12cafae9cd570e04e1766ba1791066e266644b83 Mon Sep 17 00:00:00 2001 From: Devoxin Date: Thu, 28 Mar 2024 21:41:26 +0000 Subject: [PATCH 12/14] Add a method to connect to node, and a kwarg to not immediately connect. --- lavalink/client.py | 10 +++++++--- lavalink/node.py | 31 +++++++++++++++++++++++++++++-- lavalink/nodemanager.py | 13 +++++++------ lavalink/transport.py | 9 +++++---- 4 files changed, 48 insertions(+), 15 deletions(-) diff --git a/lavalink/client.py b/lavalink/client.py index fda9370d..a4941ad6 100644 --- a/lavalink/client.py +++ b/lavalink/client.py @@ -265,7 +265,7 @@ def get_source(self, source_name: str) -> Optional[Source]: return next((source for source in self.sources if source.name == source_name), None) def add_node(self, host: str, port: int, password: str, region: str, name: Optional[str] = None, - ssl: bool = False, session_id: Optional[str] = None) -> Node: + ssl: bool = False, session_id: Optional[str] = None, connect: bool = True) -> Node: """ Shortcut for :func:`NodeManager.add_node`. @@ -283,20 +283,24 @@ def add_node(self, host: str, port: int, password: str, region: str, name: Optio The region to assign this node to. name: Optional[:class:`str`] An identifier for the node that will show in logs. Defaults to ``None``. - ssl: Optional[:class:`bool`] + ssl: :class:`bool` Whether to use SSL for the node. SSL will use ``wss`` and ``https``, instead of ``ws`` and ``http``, respectively. Your node should support SSL if you intend to enable this, either via reverse proxy or other methods. Only enable this if you know what you're doing. session_id: Optional[:class:`str`] The ID of the session to resume. Defaults to ``None``. Only specify this if you have the ID of the session you want to resume. + connect: :class:`bool` + Whether to immediately connect to the node after creating it. + If ``False``, you must call :func:`Node.connect` if you require WebSocket functionality. Returns ------- :class:`Node` The created Node instance. """ - return self.node_manager.add_node(host, port, password, region, name, ssl, session_id) + return self.node_manager.add_node(host, port, password, region, name, ssl, + session_id, connect) async def get_local_tracks(self, query: str) -> LoadResult: """|coro| diff --git a/lavalink/node.py b/lavalink/node.py index 1a5d4b01..cabb0901 100644 --- a/lavalink/node.py +++ b/lavalink/node.py @@ -21,6 +21,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +from asyncio import Task from collections import defaultdict from time import time from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Type, TypeVar, @@ -62,10 +63,10 @@ class Node: __slots__ = ('client', 'manager', '_transport', 'region', 'name', 'stats') def __init__(self, manager, host: str, port: int, password: str, region: str, name: Optional[str] = None, - ssl: bool = False, session_id: Optional[str] = None): + ssl: bool = False, session_id: Optional[str] = None, connect: bool = True): self.client: 'Client' = manager.client self.manager: 'NodeManager' = manager - self._transport = Transport(self, host, port, password, ssl, session_id) + self._transport = Transport(self, host, port, password, ssl, session_id, connect) self.region: str = region self.name: str = name or f'{region}-{host}:{port}' @@ -139,6 +140,32 @@ async def get_rest_latency(self) -> float: return (time() - start) * 1000 + async def connect(self, force: bool = False) -> Optional[Task[Any]]: + """|coro| + + Initiates a WebSocket connection to this node. + If a connection already exists, and ``force`` is ``False``, this will not do anything. + + Parameters + ---------- + force: :class:`bool` + Whether to close any existing WebSocket connections and re-establish a connection to + the node. + + Returns + ------- + Optional[:class:`asyncio.Task`[Any]] + The WebSocket connection task, or ``None`` if a WebSocket connection already exists and force + is ``False``. + """ + if self._transport.ws_connected: + if not force: + return None + + await self._transport.close() + + return self._transport.connect() + async def destroy(self): """|coro| diff --git a/lavalink/nodemanager.py b/lavalink/nodemanager.py index 6cb0289d..8a1db232 100644 --- a/lavalink/nodemanager.py +++ b/lavalink/nodemanager.py @@ -83,7 +83,7 @@ def available_nodes(self) -> List[Node]: return [n for n in self.nodes if n.available] def add_node(self, host: str, port: int, password: str, region: str, name: Optional[str] = None, - ssl: bool = False, session_id: Optional[str] = None) -> Node: + ssl: bool = False, session_id: Optional[str] = None, connect: bool = True) -> Node: """ Adds a node to this node manager. @@ -99,23 +99,24 @@ def add_node(self, host: str, port: int, password: str, region: str, name: Optio The region to assign this node to. name: Optional[:class:`str`] An identifier for the node that will show in logs. Defaults to ``None``. - reconnect_attempts: Optional[:class:`int`] - The amount of times connection with the node will be reattempted before giving up. - Set to `-1` for infinite. Defaults to ``3``. - ssl: Optional[:class:`bool`] + ssl: :class:`bool` Whether to use SSL for the node. SSL will use ``wss`` and ``https``, instead of ``ws`` and ``http``, respectively. Your node should support SSL if you intend to enable this, either via reverse proxy or other methods. Only enable this if you know what you're doing. session_id: Optional[:class:`str`] The ID of the session to resume. Defaults to ``None``. Only specify this if you have the ID of the session you want to resume. + connect: :class:`bool` + Whether to immediately connect to the node after creating it. + If ``False``, you must call :func:`Node.connect` if you require WebSocket functionality. Returns ------- :class:`Node` The created Node instance. """ - node = Node(self, host, port, password, region, name, ssl, session_id) + node = Node(self, host, port, password, region, name, ssl, + session_id, connect) self.nodes.append(node) return node diff --git a/lavalink/transport.py b/lavalink/transport.py index bef3260e..4a0dc7df 100644 --- a/lavalink/transport.py +++ b/lavalink/transport.py @@ -51,11 +51,11 @@ class Transport: - """ The class responsible for dealing with connections to Lavalink. """ + """ The class responsible for handling connections to a Lavalink server. """ __slots__ = ('client', '_node', '_session', '_ws', '_message_queue', 'trace_requests', '_host', '_port', '_password', '_ssl', 'session_id', '_destroyed') - def __init__(self, node, host: str, port: int, password: str, ssl: bool, session_id: Optional[str]): + def __init__(self, node, host: str, port: int, password: str, ssl: bool, session_id: Optional[str], connect: bool = True): self.client: 'Client' = node.client self._node: 'Node' = node @@ -72,7 +72,8 @@ def __init__(self, node, host: str, port: int, password: str, ssl: bool, session self.session_id: Optional[str] = session_id self._destroyed: bool = False - self.connect() + if connect: + self.connect() @property def ws_connected(self): @@ -93,7 +94,7 @@ async def close(self, code=aiohttp.WSCloseCode.OK): await self._ws.close(code=code) self._ws = None - def connect(self) -> asyncio.Task: + def connect(self) -> asyncio.Task[Any]: """ Attempts to establish a connection to Lavalink. """ loop = asyncio.get_event_loop() return loop.create_task(self._connect()) From 27dd0b59e42b5b72495fe2d64f7cda73fbecac29 Mon Sep 17 00:00:00 2001 From: Devoxin Date: Thu, 28 Mar 2024 21:57:10 +0000 Subject: [PATCH 13/14] Correctly re-raise AuthenticationError as such --- lavalink/transport.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lavalink/transport.py b/lavalink/transport.py index 4a0dc7df..f7d43823 100644 --- a/lavalink/transport.py +++ b/lavalink/transport.py @@ -379,7 +379,7 @@ async def _request(self, method: str, path: str, to=None, trace: bool = False, v raise RequestError('An invalid response was received from the node.', status=res.status, response=await res.json(), params=kwargs.get('params', {})) - except RequestError: - raise + except (AuthenticationError, RequestError): + raise # Pass the caught errors back to the caller in their 'original' form. except Exception as original: # It's not pretty but aiohttp doesn't specify what exceptions can be thrown. raise ClientError from original From 8adadabccff58878d2018d5c52e1e57b7cf91080 Mon Sep 17 00:00:00 2001 From: devoxin Date: Sun, 21 Apr 2024 23:24:22 +0100 Subject: [PATCH 14/14] asyncio.Task does not support documenting inner type --- lavalink/node.py | 4 ++-- lavalink/transport.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/lavalink/node.py b/lavalink/node.py index cabb0901..e9ee1147 100644 --- a/lavalink/node.py +++ b/lavalink/node.py @@ -140,7 +140,7 @@ async def get_rest_latency(self) -> float: return (time() - start) * 1000 - async def connect(self, force: bool = False) -> Optional[Task[Any]]: + async def connect(self, force: bool = False) -> Optional[Task]: """|coro| Initiates a WebSocket connection to this node. @@ -154,7 +154,7 @@ async def connect(self, force: bool = False) -> Optional[Task[Any]]: Returns ------- - Optional[:class:`asyncio.Task`[Any]] + Optional[:class:`asyncio.Task`] The WebSocket connection task, or ``None`` if a WebSocket connection already exists and force is ``False``. """ diff --git a/lavalink/transport.py b/lavalink/transport.py index f7d43823..cd875614 100644 --- a/lavalink/transport.py +++ b/lavalink/transport.py @@ -94,7 +94,7 @@ async def close(self, code=aiohttp.WSCloseCode.OK): await self._ws.close(code=code) self._ws = None - def connect(self) -> asyncio.Task[Any]: + def connect(self) -> asyncio.Task: """ Attempts to establish a connection to Lavalink. """ loop = asyncio.get_event_loop() return loop.create_task(self._connect())