Skip to content

Commit

Permalink
fix: set
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Bluhm <dbluhm@pm.me>
  • Loading branch information
dbluhm committed Jun 30, 2024
1 parent e405b7a commit 2cf75d2
Show file tree
Hide file tree
Showing 3 changed files with 352 additions and 0 deletions.
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ ignore = [
]
per-file-ignores = {"**/{tests}/*" = ["F841", "D", "E501"]}

[tool.pytest.ini_options]
addopts="-m 'not performance'"
markers = "performance: performance tests"

[tool.coverage.report]
exclude_lines = ["pragma: no cover", "@abstract"]
precision = 2
Expand Down
146 changes: 146 additions & 0 deletions src/token_status_list/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,149 @@
This implementation is based on draft 2, found here:
https://datatracker.ietf.org/doc/html/draft-ietf-oauth-status-list-02
"""

import base64
from typing import Literal, Union
import zlib


def b64url_decode(value: Union[bytes, str]) -> bytes:
    """Decode a base64url value whose trailing ``=`` padding may be missing.

    Args:
        value: base64url text, with or without padding. ``str`` input is
            encoded as ASCII first.

    Returns:
        The decoded raw bytes.

    Raises:
        binascii.Error: propagated from ``base64.urlsafe_b64decode`` when the
            padded value is still not valid base64.
    """
    if isinstance(value, str):
        value = value.encode("ascii")

    # Number of "=" needed to reach a multiple of four characters (0-3).
    padding_needed = -len(value) % 4

    # Concatenate instead of using ``+=`` so that a bytearray passed in by the
    # caller is never modified in place.
    return base64.urlsafe_b64decode(value + b"=" * padding_needed)


def b64url_encode(value: bytes) -> bytes:
    """Return the base64url encoding of ``value`` with trailing ``=`` padding removed."""
    padded = base64.urlsafe_b64encode(value)
    return padded.rstrip(b"=")


# Status values stored in a list entry.
VALID = 0x00
INVALID = 0x01
SUSPENDED = 0x02

# Entry width in bits and status value aliases; ``int`` keeps both permissive
# for callers that pass plain integers.
Bits = Union[Literal[1, 2, 4, 8], int]
StatusTypes = Union[Literal[0x00, 0x01, 0x02], int]


class TokenStatusList:
    """Token Status List.

    A packed array of fixed-width status values. Each entry occupies ``bits``
    bits (1, 2, 4 or 8); entries are packed into bytes starting from the least
    significant bits of each byte.
    """

    # log2 of the number of entries per byte; used as a shift in place of
    # multiplying/dividing by PER_BYTE.
    SHIFT_BY = {1: 3, 2: 2, 4: 1, 8: 0}
    # Number of elements that fit in a byte for a number of bits
    PER_BYTE = {1: 8, 2: 4, 4: 2, 8: 1}
    # Bit mask selecting a single entry once shifted to the low bits
    MASK = {1: 0b1, 2: 0b11, 4: 0b1111, 8: 0b11111111}
    # Largest status value representable for a number of bits
    MAX = {1: 1, 2: 3, 4: 15, 8: 255}

    @staticmethod
    def _validate_bits(bits: Bits) -> None:
        """Raise ValueError unless ``bits`` is a supported entry width."""
        if bits not in (1, 2, 4, 8):
            raise ValueError("Invalid bits value, must be one of: 1, 2, 4, 8")

    def __init__(
        self,
        bits: Bits,
        lst: bytes,
    ):
        """Initialize the list.

        Args:
            bits: width of each entry; one of 1, 2, 4, 8.
            lst: packed status bytes. The contents are copied into a
                ``bytearray`` so the list can be modified.

        Raises:
            ValueError: if ``bits`` is not a supported width.
        """
        self._validate_bits(bits)

        self.bits = bits
        self.per_byte = self.PER_BYTE[bits]
        self.shift = self.SHIFT_BY[bits]
        self.mask = self.MASK[bits]
        self.max = self.MAX[bits]

        # len * indexes per byte
        self.size = len(lst) << self.shift
        self.lst = bytearray(lst)

    @classmethod
    def of_size(cls, bits: Bits, size: int) -> "TokenStatusList":
        """Create an empty (all zero) list holding exactly ``size`` entries.

        Raises:
            ValueError: if ``bits`` is unsupported, ``size`` is less than 1, or
                ``size`` does not fill a whole number of bytes.
        """
        cls._validate_bits(bits)
        per_byte = cls.PER_BYTE[bits]
        if size < 1:
            raise ValueError("size must be at least 1")
        # size mod per_byte (per_byte is always a power of two)
        if size & (per_byte - 1) != 0:
            raise ValueError(f"size must be multiple of {per_byte}")

        length = size >> cls.SHIFT_BY[bits]
        return cls(bits, bytearray(length))

    @classmethod
    def with_at_least(cls, bits: Bits, size: int) -> "TokenStatusList":
        """Create an empty list large enough to accommodate at least the given size.

        The entry count is rounded up to the next whole byte, so the resulting
        ``size`` attribute may exceed the requested ``size``.

        Raises:
            ValueError: if ``bits`` is unsupported or ``size`` is less than 1.
        """
        cls._validate_bits(bits)
        if size < 1:
            raise ValueError("size must be at least 1")

        # ceil(size / per_byte) using shifts
        length = (size + cls.PER_BYTE[bits] - 1) >> cls.SHIFT_BY[bits]
        return cls(bits, bytearray(length))

    def __getitem__(self, index: int):
        """Retrieve the status of an index."""
        return self.get(index)

    def __setitem__(self, index: int, status: StatusTypes):
        """Set the status of an index."""
        self.set(index, status)

    def get(self, index: int):
        """Retrieve the status of an index.

        Raises:
            IndexError: from the underlying bytearray when the computed byte
                position is outside the list.
        """
        # index / indexes per byte
        byte_idx = index >> self.shift
        # index mod indexes per byte * bits
        # Determines the number of shifts to move relevant bits all the way right
        bit_idx = (index & (self.per_byte - 1)) * self.bits
        # Shift relevant bits all the way right and mask out irrelevant bits
        return self.mask & (self.lst[byte_idx] >> bit_idx)

    def set(self, index: int, status: StatusTypes):
        """Set the status of an index.

        An entry currently holding ``INVALID`` can only be re-set to
        ``INVALID``; any other transition is rejected.

        Raises:
            ValueError: if ``status`` is negative or exceeds the maximum for
                this list's bit width, if ``index`` is not below ``size``, or
                if the entry is currently ``INVALID`` and ``status`` differs.
        """
        if status < 0:
            raise ValueError(f"status {status} must not be negative")
        if status > self.max:
            raise ValueError(f"status {status} too large for list with bits {self.bits}")
        if index >= self.size:
            raise ValueError("Invalid index; out of range")

        # index / indexes per byte
        byte_idx = index >> self.shift
        # index mod indexes per byte * bits
        # Determines the number of shifts to move relevant bits all the way right
        bit_idx = (index & (self.per_byte - 1)) * self.bits
        byte = self.lst[byte_idx]
        # Shift relevant bits all the way right and mask out irrelevant bits
        current = self.mask & (byte >> bit_idx)
        if current == INVALID and status != INVALID:
            raise ValueError("Cannot change status of index previously set to invalid")

        # Mask with 0 where this entry's bits live and 1 everywhere else
        clear_mask = ~(self.mask << bit_idx)
        # Clear the entry, then write the status shifted into position
        self.lst[byte_idx] = (byte & clear_mask) | (status << bit_idx)

    def compressed(self) -> bytes:
        """Return the packed list compressed with zlib at level 9."""
        return zlib.compress(self.lst, level=9)

    def serialize(self) -> dict:
        """Return json serializable representation of status list.

        The result has the keys ``bits`` (entry width) and ``lst`` (the
        compressed list, base64url encoded without padding, as ``str``).
        """
        return {"bits": self.bits, "lst": b64url_encode(self.compressed()).decode()}

    @classmethod
    def deserialize(cls, value: dict) -> "TokenStatusList":
        """Parse status list from dictionary.

        Args:
            value: mapping with the keys ``bits`` and ``lst`` as produced by
                ``serialize``.

        Raises:
            ValueError: if ``bits`` or ``lst`` is missing or falsy, or ``bits``
                is not a supported width.
            TypeError: if ``bits`` is not an int or ``lst`` is not a str.
        """
        bits = value.get("bits")
        if not bits:
            raise ValueError("bits missing from status list dictionary")

        if not isinstance(bits, int):
            raise TypeError("bits must be int")

        lst = value.get("lst")
        if not lst:
            raise ValueError("lst missing from status list dictionary")

        if not isinstance(lst, str):
            raise TypeError("lst must be str")

        # NOTE(review): the payload is fully decompressed in memory; callers
        # handling untrusted input may want to bound its size beforehand.
        parsed_lst = zlib.decompress(b64url_decode(lst.encode()))
        return cls(bits, parsed_lst)
202 changes: 202 additions & 0 deletions tests/test_token_status_list.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
"""Test TokenStatusList."""

import pytest

from token_status_list import INVALID, SUSPENDED, TokenStatusList, VALID, b64url_decode


def test_get_1_bits():
    """Read every 1-bit entry out of a two-byte list."""
    expected = (1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0, 1)
    status = TokenStatusList(1, b"\xb9\xa3")
    assert status.size == 16
    assert len(status.lst) == 2
    for position, value in enumerate(expected):
        assert status[position] == value


def test_set_1_bits():
    """Write sixteen 1-bit entries in order, then read each one back."""
    values = (1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0, 1)
    status = TokenStatusList.of_size(1, 16)
    assert len(status.lst) == 2
    for position, value in enumerate(values):
        status[position] = value
    for position, value in enumerate(values):
        assert status[position] == value


def test_get_2_bits():
    """Read every 2-bit entry out of a three-byte list."""
    expected = (1, 2, 0, 3, 0, 1, 0, 1, 1, 2, 3, 3)
    status = TokenStatusList(2, b"\xc9\x44\xf9")
    assert status.size == 12
    assert len(status.lst) == 3
    for position, value in enumerate(expected):
        assert status[position] == value


def test_set_2_bits():
    """Write twelve 2-bit entries in order, then read each one back."""
    values = (1, 2, 0, 3, 0, 1, 0, 1, 1, 2, 3, 3)
    status = TokenStatusList.of_size(2, 12)
    assert len(status.lst) == 3
    for position, value in enumerate(values):
        status[position] = value
    for position, value in enumerate(values):
        assert status[position] == value


def test_get_4_bits():
    """Read every 4-bit entry out of a four-byte list."""
    expected = (1, 1, 2, 2, 3, 3, 4, 4)
    status = TokenStatusList(4, b"\x11\x22\x33\x44")
    assert status.size == 8
    assert len(status.lst) == 4
    for position, value in enumerate(expected):
        assert status[position] == value


def test_get_8_bits():
    """Read every 8-bit entry out of a four-byte list."""
    expected = (1, 2, 3, 4)
    status = TokenStatusList(8, b"\x01\x02\x03\x04")
    assert status.size == 4
    assert len(status.lst) == 4
    for position, value in enumerate(expected):
        assert status[position] == value


def test_compression():
    """Compressed output of a known list matches the expected base64url vector."""
    expected = b64url_decode(b"eNrbuRgAAhcBXQ")
    assert TokenStatusList(1, b"\xb9\xa3").compressed() == expected


@pytest.mark.performance
@pytest.mark.parametrize("bits", (1, 2, 4, 8))
def test_performance(bits: int):
    """Print timings for setting, getting and compressing a large list.

    Carries the ``performance`` marker. Timing uses ``time.perf_counter``,
    which is monotonic and intended for measuring elapsed intervals.
    """
    import time

    print()
    print()
    print("Bits:", bits)

    # Create a large TokenStatusList
    size = 1000000  # Number of tokens
    token_list = TokenStatusList.of_size(bits, size)

    # Test setting values
    start_time = time.perf_counter()
    for i in range(size):
        token_list[i] = VALID if i % 2 == 0 else INVALID
    end_time = time.perf_counter()
    print(f"Time to set {size} tokens: {end_time - start_time} seconds")

    # Test getting values; the result is discarded, only the lookup is timed
    start_time = time.perf_counter()
    for i in range(size):
        token_list[i]
    end_time = time.perf_counter()
    print(f"Time to get {size} tokens: {end_time - start_time} seconds")

    # Test compression
    start_time = time.perf_counter()
    compressed_data = token_list.compressed()
    end_time = time.perf_counter()
    print(f"Time to compress: {end_time - start_time} seconds")
    print(f"Original length: {len(token_list.lst)} bytes")
    print(f"Compressed length: {len(compressed_data)} bytes")
    print(f"Compression ratio: {len(compressed_data) / len(token_list.lst) * 100:.3f}%")


def test_serde():
    """A list survives a serialize/deserialize round trip unchanged."""
    expected = TokenStatusList(1, b"\xb9\xa3")
    payload = expected.serialize()
    actual = TokenStatusList.deserialize(payload)
    for candidate in (expected, actual):
        assert len(candidate.lst) == 2
    assert expected.lst == actual.lst
    assert expected.bits == actual.bits
    assert expected.size == actual.size


def test_suspend_to_valid():
    """A suspended entry may be changed back to valid."""
    entries = TokenStatusList(2, b"\x80")
    target = 3
    assert entries[target] == SUSPENDED
    entries[target] = 0x00
    assert entries[target] == VALID


def test_invalid_to_valid():
    """An invalid entry cannot be changed back to valid."""
    entries = TokenStatusList(1, b"\x80")
    target = 7
    assert entries[target] == INVALID
    with pytest.raises(ValueError):
        entries[target] = 0x00

0 comments on commit 2cf75d2

Please sign in to comment.