diff --git a/pyproject.toml b/pyproject.toml index c008976..82c23ea 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,10 @@ ignore = [ ] per-file-ignores = {"**/{tests}/*" = ["F841", "D", "E501"]} +[tool.pytest.ini_options] +addopts="-m 'not performance'" +markers = "performance: performance tests" + [tool.coverage.report] exclude_lines = ["pragma: no cover", "@abstract"] precision = 2 diff --git a/src/token_status_list/__init__.py b/src/token_status_list/__init__.py index 3789c03..e203d08 100644 --- a/src/token_status_list/__init__.py +++ b/src/token_status_list/__init__.py @@ -5,3 +5,149 @@ This implementation is based on draft 2, found here: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-status-list-02 """ + +import base64 +from typing import Literal, Union +import zlib + + +def b64url_decode(value: bytes) -> bytes: + """Return the base64 url encoded value, without padding.""" + padding_needed = 4 - (len(value) % 4) + if padding_needed != 4: + value += b"=" * padding_needed + + return base64.urlsafe_b64decode(value) + + +def b64url_encode(value: bytes) -> bytes: + """Return the decoded base64 url encoded value, without padding.""" + return base64.urlsafe_b64encode(value).rstrip(b"=") + + +VALID = 0x00 +INVALID = 0x01 +SUSPENDED = 0x02 + +Bits = Union[Literal[1, 2, 4, 8], int] +StatusTypes = Union[Literal[0x00, 0x01, 0x02], int] + + +class TokenStatusList: + """Token Status List.""" + + SHIFT_BY = {1: 3, 2: 2, 4: 1, 8: 0} + # Number of elements that fit in a byte for a number of bits + PER_BYTE = {1: 8, 2: 4, 4: 2, 8: 1} + MASK = {1: 0b1, 2: 0b11, 4: 0b1111, 8: 0b11111111} + MAX = {1: 1, 2: 3, 4: 15, 8: 255} + + def __init__( + self, + bits: Bits, + lst: bytes, + ): + """Initialize the list.""" + if bits not in (1, 2, 4, 8): + raise ValueError("Invalid bits value, must be one of: 1, 2, 4, 8") + + self.bits = bits + self.per_byte = self.PER_BYTE[bits] + self.shift = self.SHIFT_BY[bits] + self.mask = self.MASK[bits] + self.max = self.MAX[bits] + + # len * indexes per byte + self.size = len(lst) << self.shift + self.lst = bytearray(lst) + + @classmethod + def of_size(cls, bits: Bits, size: int) -> "TokenStatusList": + """Create empty list of a given size.""" + per_byte = cls.PER_BYTE[bits] + if size < 1: + raise ValueError("size must be greater than 1") + # size mod per_byte + if size & (per_byte - 1) != 0: + raise ValueError(f"size must be multiple of {per_byte}") + + length = size >> cls.SHIFT_BY[bits] + return cls(bits, bytearray(length)) + + @classmethod + def with_at_least(cls, bits: Bits, size: int): + """Create an empty list large enough to accommodate at least the given size.""" + + def __getitem__(self, index: int): + """Retrieve the status of an index.""" + return self.get(index) + + def __setitem__(self, index: int, status: StatusTypes): + """Set the status of an index.""" + return self.set(index, status) + + def get(self, index: int): + """Retrieve the status of an index.""" + # index / indexes per byte + byte_idx = index >> self.shift + # index mod indexes per byte * bits + # Determines the number of shifts to move relevant bits all the way right + bit_idx = (index & (self.per_byte - 1)) * self.bits + # Shift relevant bits all the way right and mask out irrelevant bits + return self.mask & (self.lst[byte_idx] >> bit_idx) + + def set(self, index: int, status: StatusTypes): + """Set the status of an index.""" + if status > self.max: + raise ValueError(f"status {status} too large for list with bits {self.bits}") + if index >= self.size: + raise ValueError("Invalid index; out of range") + + # index / indexes per byte + byte_idx = index >> self.shift + # index mod indexes per byte * bits + # Determines the number of shifts to move relevant bits all the way right + bit_idx = (index & (self.per_byte - 1)) * self.bits + byte = self.lst[byte_idx] + # Shift relevant bits all the way right and mask out irrelevant bits + current = self.mask & (byte >> bit_idx) + if current == 0x01 and status != 0x01: + raise ValueError("Cannot change status of index previously set to invalid") + + # Shift status to relevant position + status <<= bit_idx + # Create mask to clear bits getting reset + # (0 where the bits will be, 1 everywhere else) + clear_mask = ~(self.mask << bit_idx) + # Reset bits to zero + byte &= clear_mask + # Set status bits + self.lst[byte_idx] = byte | status + + def compressed(self) -> bytes: + """Return compressed list.""" + return zlib.compress(self.lst, level=9) + + def serialize(self) -> dict: + """Return json serializable representation of status list.""" + return {"bits": self.bits, "lst": b64url_encode(self.compressed()).decode()} + + @classmethod + def deserialize(cls, value: dict) -> "TokenStatusList": + """Parse status list from dictionary.""" + bits = value.get("bits") + if not bits: + raise ValueError("bits missing from status list dictionary") + + if not isinstance(bits, int): + raise TypeError("bits must be int") + + lst = value.get("lst") + if not lst: + raise ValueError("lst missing from status list dictionary") + + if not isinstance(lst, str): + raise TypeError("lst must be str") + + parsed_lst = zlib.decompress(b64url_decode(lst.encode())) + return cls(bits, parsed_lst) diff --git a/tests/test_token_status_list.py b/tests/test_token_status_list.py new file mode 100644 index 0000000..033a1e4 --- /dev/null +++ b/tests/test_token_status_list.py @@ -0,0 +1,202 @@ +"""Test TokenStatusList.""" + +import pytest + +from token_status_list import INVALID, SUSPENDED, TokenStatusList, VALID, b64url_decode + + +def test_get_1_bits(): + status = TokenStatusList(1, b"\xb9\xa3") + assert status.size == 16 + assert len(status.lst) == 2 + assert status[0] == 1 + assert status[1] == 0 + assert status[2] == 0 + assert status[3] == 1 + assert status[4] == 1 + assert status[5] == 1 + assert status[6] == 0 + assert status[7] == 1 + assert status[8] == 1 + assert status[9] == 1 + assert status[10] == 0 + assert status[11] == 0 + assert status[12] == 0 + assert status[13] == 1 + assert status[14] == 0 + assert status[15] == 1 + + +def test_set_1_bits(): + status = TokenStatusList.of_size(1, 16) + assert len(status.lst) == 2 + status[0] = 1 + status[1] = 0 + status[2] = 0 + status[3] = 1 + status[4] = 1 + status[5] = 1 + status[6] = 0 + status[7] = 1 + status[8] = 1 + status[9] = 1 + status[10] = 0 + status[11] = 0 + status[12] = 0 + status[13] = 1 + status[14] = 0 + status[15] = 1 + assert status[0] == 1 + assert status[1] == 0 + assert status[2] == 0 + assert status[3] == 1 + assert status[4] == 1 + assert status[5] == 1 + assert status[6] == 0 + assert status[7] == 1 + assert status[8] == 1 + assert status[9] == 1 + assert status[10] == 0 + assert status[11] == 0 + assert status[12] == 0 + assert status[13] == 1 + assert status[14] == 0 + assert status[15] == 1 + + +def test_get_2_bits(): + status = TokenStatusList(2, b"\xc9\x44\xf9") + assert status.size == 12 + assert len(status.lst) == 3 + assert status[0] == 1 + assert status[1] == 2 + assert status[2] == 0 + assert status[3] == 3 + assert status[4] == 0 + assert status[5] == 1 + assert status[6] == 0 + assert status[7] == 1 + assert status[8] == 1 + assert status[9] == 2 + assert status[10] == 3 + assert status[11] == 3 + + +def test_set_2_bits(): + status = TokenStatusList.of_size(2, 12) + assert len(status.lst) == 3 + status[0] = 1 + status[1] = 2 + status[2] = 0 + status[3] = 3 + status[4] = 0 + status[5] = 1 + status[6] = 0 + status[7] = 1 + status[8] = 1 + status[9] = 2 + status[10] = 3 + status[11] = 3 + assert status[0] == 1 + assert status[1] == 2 + assert status[2] == 0 + assert status[3] == 3 + assert status[4] == 0 + assert status[5] == 1 + assert status[6] == 0 + assert status[7] == 1 + assert status[8] == 1 + assert status[9] == 2 + assert status[10] == 3 + assert status[11] == 3 + + +def test_get_4_bits(): + status = TokenStatusList(4, b"\x11\x22\x33\x44") + assert status.size == 8 + assert len(status.lst) == 4 + assert status[0] == 1 + assert status[1] == 1 + assert status[2] == 2 + assert status[3] == 2 + assert status[4] == 3 + assert status[5] == 3 + assert status[6] == 4 + assert status[7] == 4 + + +def test_get_8_bits(): + status = TokenStatusList(8, b"\x01\x02\x03\x04") + assert status.size == 4 + assert len(status.lst) == 4 + assert status[0] == 1 + assert status[1] == 2 + assert status[2] == 3 + assert status[3] == 4 + + +def test_compression(): + status = TokenStatusList(1, b"\xb9\xa3") + compressed = status.compressed() + assert compressed == b64url_decode(b"eNrbuRgAAhcBXQ") + + +@pytest.mark.performance +@pytest.mark.parametrize("bits", (1, 2, 4, 8)) +def test_performance(bits: int): + import time + + print() + print() + print("Bits:", bits) + + # Create a large TokenStatusList + size = 1000000 # Number of tokens + token_list = TokenStatusList.of_size(bits, size) + + # Test setting values + start_time = time.time() + for i in range(size): + token_list[i] = VALID if i % 2 == 0 else INVALID + end_time = time.time() + print(f"Time to set {size} tokens: {end_time - start_time} seconds") + + # Test getting values + start_time = time.time() + for i in range(size): + status = token_list[i] + end_time = time.time() + print(f"Time to get {size} tokens: {end_time - start_time} seconds") + + # Test compression + start_time = time.time() + compressed_data = token_list.compressed() + end_time = time.time() + print(f"Time to compress: {end_time - start_time} seconds") + print(f"Original length: {len(token_list.lst)} bytes") + print(f"Compressed length: {len(compressed_data)} bytes") + print(f"Compression ratio: {len(compressed_data) / len(token_list.lst) * 100:.3f}%") + + +def test_serde(): + expected = TokenStatusList(1, b"\xb9\xa3") + actual = TokenStatusList.deserialize(expected.serialize()) + assert len(expected.lst) == 2 + assert len(actual.lst) == 2 + assert expected.lst == actual.lst + assert expected.bits == actual.bits + assert expected.size == actual.size + + +def test_suspend_to_valid(): + status = TokenStatusList(2, b"\x80") + assert status[3] == SUSPENDED + status[3] = 0x00 + assert status[3] == VALID + + +def test_invalid_to_valid(): + status = TokenStatusList(1, b"\x80") + assert status[7] == INVALID + with pytest.raises(ValueError): + status[7] = 0x00