diff --git a/docs/lavalink.rst b/docs/lavalink.rst index 6b47ac33..e12a1708 100644 --- a/docs/lavalink.rst +++ b/docs/lavalink.rst @@ -24,6 +24,14 @@ Client .. autoclass:: Client :members: +DataIO +------ +.. autoclass:: DataReader + :members: + +.. autoclass:: DataWriter + :members: + Errors ------ .. autoclass:: ClientError diff --git a/lavalink/__init__.py b/lavalink/__init__.py index 5abf1407..973e28b0 100644 --- a/lavalink/__init__.py +++ b/lavalink/__init__.py @@ -4,11 +4,14 @@ __author__ = 'Devoxin' __license__ = 'MIT' __copyright__ = 'Copyright 2017-present Devoxin' -__version__ = '5.2.0' +__version__ = '5.3.0' +from typing import Type + from .abc import BasePlayer, DeferredAudioTrack, Source from .client import Client +from .dataio import DataReader, DataWriter from .errors import (AuthenticationError, ClientError, InvalidTrack, LoadError, RequestError) from .events import (Event, IncomingWebSocketMessage, NodeChangedEvent, @@ -24,12 +27,13 @@ from .playermanager import PlayerManager from .server import (AudioTrack, EndReason, LoadResult, LoadResultError, LoadType, PlaylistInfo, Plugin, Severity) +from .source_decoders import DEFAULT_DECODER_MAPPING from .stats import Penalty, Stats from .utils import (decode_track, encode_track, format_time, parse_time, timestamp_to_millis) -def listener(*events: Event): +def listener(*events: Type[Event]): """ Marks this function as an event listener for Lavalink.py. This **must** be used on class methods, and you must ensure that you register diff --git a/lavalink/abc.py b/lavalink/abc.py index 9e442c6b..93f0123a 100644 --- a/lavalink/abc.py +++ b/lavalink/abc.py @@ -274,7 +274,7 @@ class DeferredAudioTrack(ABC, AudioTrack): for example. """ @abstractmethod - async def load(self, client: 'Client'): + async def load(self, client: 'Client') -> Optional[str]: """|coro| Retrieves a base64 string that's playable by Lavalink. @@ -288,8 +288,9 @@ async def load(self, client: 'Client'): Returns ------- - :class:`str` + Optional[:class:`str`] A Lavalink-compatible base64-encoded string containing track metadata. + If a track string cannot be returned, you may return ``None`` or throw a :class:`LoadError`. """ raise NotImplementedError diff --git a/lavalink/client.py b/lavalink/client.py index b2a59912..fda9370d 100644 --- a/lavalink/client.py +++ b/lavalink/client.py @@ -28,7 +28,7 @@ import random from collections import defaultdict from inspect import getmembers, ismethod -from typing import (Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, +from typing import (Any, Callable, Dict, Generic, List, Optional, Sequence, Set, Tuple, Type, TypeVar, Union) import aiohttp @@ -47,7 +47,7 @@ EventT = TypeVar('EventT', bound=Event) -class Client: +class Client(Generic[PlayerT]): """ Represents a Lavalink client used to manage nodes and connections. @@ -102,7 +102,7 @@ def __init__(self, user_id: Union[int, str], player: Type[PlayerT] = DefaultPlay self._user_id: int = int(user_id) self._event_hooks = defaultdict(list) self.node_manager: NodeManager = NodeManager(self, regions, connect_back) - self.player_manager: PlayerManager = PlayerManager(self, player) + self.player_manager: PlayerManager[PlayerT] = PlayerManager(self, player) self.sources: Set[Source] = set() @property @@ -113,7 +113,7 @@ def nodes(self) -> List[Node]: return self.node_manager.nodes @property - def players(self) -> Dict[int, BasePlayer]: + def players(self) -> Dict[int, PlayerT]: """ Convenience shortcut for :attr:`PlayerManager.players`. """ @@ -207,7 +207,7 @@ def remove_event_hooks(self, *, events: Optional[Sequence[EventT]] = None, hooks ---------- events: Sequence[:class:`Event`] The events to remove the hooks from. This parameter can be omitted, - and the events registered on the function via :meth:`listener` will be used instead, if applicable. + and the events registered on the function via :func:`listener` will be used instead, if applicable. Otherwise, a default value of ``Generic`` is used instead. hooks: Sequence[Callable] A list of hook methods to remove. diff --git a/lavalink/dataio.py b/lavalink/dataio.py index 62d4ad86..f841d7f8 100644 --- a/lavalink/dataio.py +++ b/lavalink/dataio.py @@ -31,9 +31,14 @@ class DataReader: def __init__(self, base64_str: str): - self._buf: BytesIO = BytesIO(b64decode(base64_str)) + self._buf = BytesIO(b64decode(base64_str)) - def _read(self, count): + @property + def remaining(self) -> int: + """ The amount of bytes left to be read. """ + return self._buf.getbuffer().nbytes - self._buf.tell() + + def _read(self, count: int): return self._buf.read(count) def read_byte(self) -> bytes: @@ -55,8 +60,13 @@ def read_long(self) -> int: result, = struct.unpack('>Q', self._read(8)) return result - def read_nullable_utf(self) -> Optional[str]: - return self.read_utf().decode() if self.read_boolean() else None + def read_nullable_utf(self, utfm: bool = False) -> Optional[str]: + exists = self.read_boolean() + + if not exists: + return None + + return self.read_utfm() if utfm else self.read_utf().decode() def read_utf(self) -> bytes: text_length = self.read_unsigned_short() @@ -110,7 +120,7 @@ def write_utf(self, utf_string): self.write_unsigned_short(byte_len) self._write(utf) - def finish(self): + def finish(self) -> bytes: with BytesIO() as track_buf: byte_len = self._buf.getbuffer().nbytes flags = byte_len | (1 << 30) diff --git a/lavalink/filters.py b/lavalink/filters.py index ed3ed71b..cba8f68a 100644 --- a/lavalink/filters.py +++ b/lavalink/filters.py @@ -485,7 +485,7 @@ def update(self, *, smoothing: float): ------ :class:`ValueError` """ - smoothing = float('smoothing') + smoothing = float(smoothing) if smoothing <= 1: raise ValueError('smoothing must be bigger than 1') diff --git a/lavalink/player.py b/lavalink/player.py index 1061f7b9..ce5bf49a 100644 --- a/lavalink/player.py +++ b/lavalink/player.py @@ -533,6 +533,20 @@ def get_filter(self, _filter: Union[Type[FilterT], str]): return self.filters.get(filter_name.lower(), None) + async def remove_filters(self, *filters: Union[Type[FilterT], str]): + """|coro| + + Removes multiple filters from the player, undoing any effects applied to the audio. + This is similar to :func:`remove_filter` but instead allows you to remove multiple filters with one call. + + Parameters + ---------- + filters: Union[Type[:class:`Filter`], :class:`str`] + The filters to remove. Can be filter name, or filter class (**not** an instance of). + """ + for fltr in filters: + await self.remove_filter(fltr) + async def remove_filter(self, _filter: Union[Type[FilterT], str]): """|coro| diff --git a/lavalink/playermanager.py b/lavalink/playermanager.py index beb44373..628c16b3 100644 --- a/lavalink/playermanager.py +++ b/lavalink/playermanager.py @@ -22,8 +22,8 @@ SOFTWARE. """ import logging -from typing import (TYPE_CHECKING, Callable, Dict, Iterator, Optional, Tuple, - Type, TypeVar) +from typing import (TYPE_CHECKING, Callable, Dict, Generic, Iterator, Optional, + Tuple, Type, TypeVar, Union, overload) from .errors import ClientError from .node import Node @@ -35,9 +35,10 @@ _log = logging.getLogger(__name__) PlayerT = TypeVar('PlayerT', bound=BasePlayer) +CustomPlayerT = TypeVar('CustomPlayerT', bound=BasePlayer) -class PlayerManager: +class PlayerManager(Generic[PlayerT]): """ Represents the player manager that contains all the players. @@ -61,22 +62,22 @@ def __init__(self, client, player: Type[PlayerT]): self.client: 'Client' = client self._player_cls: Type[PlayerT] = player - self.players: Dict[int, BasePlayer] = {} + self.players: Dict[int, PlayerT] = {} def __len__(self) -> int: return len(self.players) - def __iter__(self) -> Iterator[Tuple[int, BasePlayer]]: + def __iter__(self) -> Iterator[Tuple[int, PlayerT]]: """ Returns an iterator that yields a tuple of (guild_id, player). """ for guild_id, player in self.players.items(): yield guild_id, player - def values(self) -> Iterator[BasePlayer]: + def values(self) -> Iterator[PlayerT]: """ Returns an iterator that yields only values. """ for player in self.players.values(): yield player - def find_all(self, predicate: Optional[Callable[[BasePlayer], bool]] = None): + def find_all(self, predicate: Optional[Callable[[PlayerT], bool]] = None): """ Returns a list of players that match the given predicate. @@ -96,7 +97,7 @@ def find_all(self, predicate: Optional[Callable[[BasePlayer], bool]] = None): return [p for p in self.players.values() if bool(predicate(p))] - def get(self, guild_id: int) -> Optional[BasePlayer]: + def get(self, guild_id: int) -> Optional[PlayerT]: """ Gets a player from cache. @@ -126,13 +127,32 @@ def remove(self, guild_id: int): player = self.players.pop(guild_id) player.cleanup() + @overload + def create(self, + guild_id: int, + *, + region: Optional[str] = ..., + endpoint: Optional[str] = ..., + node: Optional[Node] = ...) -> PlayerT: + ... + + @overload + def create(self, + guild_id: int, + *, + region: Optional[str] = ..., + endpoint: Optional[str] = ..., + node: Optional[Node] = ..., + cls: Type[CustomPlayerT]) -> CustomPlayerT: + ... + def create(self, guild_id: int, *, region: Optional[str] = None, endpoint: Optional[str] = None, node: Optional[Node] = None, - cls: Optional[Type[PlayerT]] = None) -> BasePlayer: + cls: Optional[Type[CustomPlayerT]] = None) -> Union[CustomPlayerT, PlayerT]: """ Creates a player if one doesn't exist with the given information. diff --git a/lavalink/server.py b/lavalink/server.py index 9744d84b..608d0cbb 100644 --- a/lavalink/server.py +++ b/lavalink/server.py @@ -86,7 +86,7 @@ class AudioTrack: The track's uploader. duration: :class:`int` The duration of the track, in milliseconds. - stream: :class:`bool` + is_stream: :class:`bool` Whether the track is a live-stream. title: :class:`str` The title of the track. @@ -110,7 +110,7 @@ class AudioTrack: extra: Dict[str, Any] Any extra properties given to this AudioTrack will be stored here. """ - __slots__ = ('raw', 'track', 'identifier', 'is_seekable', 'author', 'duration', 'stream', 'title', 'uri', + __slots__ = ('raw', 'track', 'identifier', 'is_seekable', 'author', 'duration', 'is_stream', 'title', 'uri', 'artwork_url', 'isrc', 'position', 'source_name', 'plugin_info', 'user_data', 'extra') def __init__(self, data: dict, requester: int = 0, **extra): @@ -127,7 +127,7 @@ def __init__(self, data: dict, requester: int = 0, **extra): self.is_seekable: bool = info['isSeekable'] self.author: str = info['author'] self.duration: int = info['length'] - self.stream: bool = info['isStream'] + self.is_stream: bool = info['isStream'] self.title: str = info['title'] self.uri: str = info['uri'] self.artwork_url: Optional[str] = info.get('artworkUrl') @@ -150,6 +150,17 @@ def __getitem__(self, name): def from_dict(cls, mapping: dict): return cls(mapping) + @property + def stream(self) -> bool: + """ + Property indicating whether this track is a stream. + + .. deprecated:: 5.3.0 + To be consistent with attribute naming, this property has been deprecated + in favour of ``is_stream``. + """ + return self.is_stream + @property def requester(self) -> int: return self.extra['requester'] diff --git a/lavalink/source_decoders.py b/lavalink/source_decoders.py new file mode 100644 index 00000000..a8a50f36 --- /dev/null +++ b/lavalink/source_decoders.py @@ -0,0 +1,61 @@ +""" +MIT License + +Copyright (c) 2017-present Devoxin + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +""" +from typing import Any, Callable, Dict, Mapping + +from .dataio import DataReader + + +def decode_probe_info(reader: DataReader) -> Mapping[str, Any]: + probe_info = reader.read_utf().decode() + return {'probe_info': probe_info} + + +def decode_lavasrc_fields(reader: DataReader) -> Mapping[str, Any]: + if reader.remaining <= 8: # 8 bytes (long) = position field + return {} + + album_name = reader.read_nullable_utf() + album_url = reader.read_nullable_utf() + artist_url = reader.read_nullable_utf() + artist_artwork_url = reader.read_nullable_utf() + preview_url = reader.read_nullable_utf() + is_preview = reader.read_boolean() + + return { + 'album_name': album_name, + 'album_url': album_url, + 'artist_url': artist_url, + 'artist_artwork_url': artist_artwork_url, + 'preview_url': preview_url, + 'is_preview': is_preview + } + + +DEFAULT_DECODER_MAPPING: Dict[str, Callable[[DataReader], Mapping[str, Any]]] = { + 'http': decode_probe_info, + 'local': decode_probe_info, + 'deezer': decode_lavasrc_fields, + 'spotify': decode_lavasrc_fields, + 'applemusic': decode_lavasrc_fields +} diff --git a/lavalink/transport.py b/lavalink/transport.py index 530a728a..70c39769 100644 --- a/lavalink/transport.py +++ b/lavalink/transport.py @@ -181,7 +181,7 @@ async def _listen(self): try: await self._handle_message(msg.json()) except Exception: # pylint: disable=W0718 - _log.error('[Node:%s] Unexpected error occurred whilst processing websocket message', self._node.name) + _log.exception('[Node:%s] Unexpected error occurred whilst processing websocket message', self._node.name) elif msg.type == aiohttp.WSMsgType.ERROR: exc = self._ws.exception() _log.error('[Node:%s] Exception in WebSocket!', self._node.name, exc_info=exc) diff --git a/lavalink/utils.py b/lavalink/utils.py index 903c8de8..cb780a4a 100644 --- a/lavalink/utils.py +++ b/lavalink/utils.py @@ -23,11 +23,13 @@ """ import struct from base64 import b64encode -from typing import Dict, Optional, Tuple, Union +from typing import Any, Callable, Dict, Mapping, Optional, Tuple, Union +from .common import MISSING from .dataio import DataReader, DataWriter from .errors import InvalidTrack from .player import AudioTrack +from .source_decoders import DEFAULT_DECODER_MAPPING V2_KEYSET = {'title', 'author', 'length', 'identifier', 'isStream', 'uri', 'sourceName', 'position'} V3_KEYSET = {'title', 'author', 'length', 'identifier', 'isStream', 'uri', 'artworkUrl', 'isrc', 'sourceName', 'position'} @@ -147,7 +149,8 @@ def _write_track_common(track: Dict[str, Union[Optional[str], bool, int]], write writer.write_nullable_utf(track['uri']) -def decode_track(track: str) -> AudioTrack: +def decode_track(track: str, # pylint: disable=R0914 + source_decoders: Mapping[str, Callable[[DataReader], Mapping[str, Any]]] = MISSING) -> AudioTrack: """ Decodes a base64 track string into an AudioTrack object. @@ -155,11 +158,25 @@ def decode_track(track: str) -> AudioTrack: ---------- track: :class:`str` The base64 track string. + source_decoders: Mapping[:class:`str`, Callable[[:class:`DataReader`], Dict[:class:`str`, Any]]] + A mapping of source-specific decoders to use. + Some Lavaplayer sources have additional fields encoded on a per-sourcemanager basis, so you can + specify a mapping of decoders that will handle decoding these additional fields. You can find some + example decoders within the ``source_decoders`` file. This isn't required for all sources, so ensure + that you need them before specifying. + + To overwrite library-provided decoders, just specify them within the mapping and the new decoders will + be used. Returns ------- :class:`AudioTrack` """ + decoders = DEFAULT_DECODER_MAPPING.copy() + + if source_decoders is not MISSING: + decoders.update(source_decoders) + reader = DataReader(track) flags = (reader.read_int() & 0xC0000000) >> 30 @@ -173,6 +190,11 @@ def decode_track(track: str) -> AudioTrack: extra_fields['isrc'] = reader.read_nullable_utf() source = reader.read_utf().decode() + source_specific_fields = {} + + if source in decoders: + source_specific_fields.update(decoders[source](reader)) + position = reader.read_long() track_object = { @@ -190,7 +212,8 @@ def decode_track(track: str) -> AudioTrack: } } - return AudioTrack(track_object, 0, position=position, encoder_version=version) + return AudioTrack(track_object, 0, position=position, encoder_version=version, + source_specific=source_specific_fields) def encode_track(track: Dict[str, Union[Optional[str], int, bool]]) -> Tuple[int, str]: