diff --git a/src/token_status_list/__init__.py b/src/token_status_list/__init__.py index e203d08..bc1be68 100644 --- a/src/token_status_list/__init__.py +++ b/src/token_status_list/__init__.py @@ -7,7 +7,9 @@ """ import base64 -from typing import Literal, Union +import json +from time import time +from typing import Literal, Optional, Union import zlib @@ -25,6 +27,11 @@ def b64url_encode(value: bytes) -> bytes: return base64.urlsafe_b64encode(value).rstrip(b"=") +def dict_to_b64(value: dict) -> bytes: + """Transform a dictionary into base64url encoded json dump of dictionary.""" + return b64url_encode(json.dumps(value, separators=(",", ":")).encode()) + + VALID = 0x00 INVALID = 0x01 SUSPENDED = 0x02 @@ -77,6 +84,10 @@ def of_size(cls, bits: Bits, size: int) -> "TokenStatusList": @classmethod def with_at_least(cls, bits: Bits, size: int): """Create an empty list large enough to accommodate at least the given size.""" + # Determine minimum number of bytes to fit size + # This is essentially a fast ceil(n / 2^x) + length = (size + cls.PER_BYTE[bits] - 1) >> cls.SHIFT_BY[bits] + return cls(bits, bytearray(length)) def __getitem__(self, index: int): """Retrieve the status of an index.""" @@ -151,3 +162,80 @@ def deserialize(cls, value: dict) -> "TokenStatusList": parsed_lst = zlib.decompress(b64url_decode(lst.encode())) return cls(bits, parsed_lst) + + def sign_payload( + self, + *, + alg: str, + kid: str, + iss: str, + sub: str, + iat: Optional[int] = None, + exp: Optional[int] = None, + ttl: Optional[int] = None, + ) -> bytes: + """Create a Status List Token payload for signing. + + Signing is NOT performed by this function; only the payload to the signature is + prepared. The caller is responsible for producing a signature. + + Args: + alg: REQUIRED. The algorithm to be used to sign the payload. + + kid: REQUIRED. The kid used to sign the payload. + + iss: REQUIRED when also present in the Referenced Token. The iss (issuer) + claim MUST specify a unique string identifier for the entity that issued + the Status List Token. In the absence of an application profile specifying + otherwise, compliant applications MUST compare issuer values using the + Simple String Comparison method defined in Section 6.2.1 of [RFC3986]. + The value MUST be equal to that of the iss claim contained within the + Referenced Token. + + sub: REQUIRED. The sub (subject) claim MUST specify a unique string identifier + for the Status List Token. The value MUST be equal to that of the uri + claim contained in the status_list claim of the Referenced Token. + + iat: REQUIRED. The iat (issued at) claim MUST specify the time at which the + Status List Token was issued. + + exp: OPTIONAL. The exp (expiration time) claim, if present, MUST specify the + time at which the Status List Token is considered expired by its issuer. + + ttl: OPTIONAL. The ttl (time to live) claim, if present, MUST specify the + maximum amount of time, in seconds, that the Status List Token can be + cached by a consumer before a fresh copy SHOULD be retrieved. The value + of the claim MUST be a positive number. + """ + headers = { + "typ": "statuslist+jwt", + "alg": alg, + "kid": kid, + } + payload = { + "iss": iss, + "sub": sub, + "iat": iat or int(time()), + "status_list": self.serialize(), + } + if exp is not None: + payload["exp"] = exp + + if ttl is not None: + payload["ttl"] = ttl + + enc_headers = dict_to_b64(headers).decode() + enc_payload = dict_to_b64(payload).decode() + return f"{enc_headers}.{enc_payload}".encode() + + def signed_token(self, signed_payload: bytes, signature: bytes) -> str: + """Finish creating a signed token. + + Args: + signed_payload: The value returned from `sign_payload`. + signature: The signature over the signed_payload in bytes. + + Returns: + Finished Status List Token. + """ + return f"{signed_payload.decode()}.{b64url_encode(signature)}" diff --git a/tests/test_token_status_list.py b/tests/test_token_status_list.py index 033a1e4..eb39ed7 100644 --- a/tests/test_token_status_list.py +++ b/tests/test_token_status_list.py @@ -1,6 +1,8 @@ """Test TokenStatusList.""" +import json import pytest +from secrets import randbelow from token_status_list import INVALID, SUSPENDED, TokenStatusList, VALID, b64url_decode @@ -151,31 +153,44 @@ def test_performance(bits: int): print("Bits:", bits) # Create a large TokenStatusList - size = 1000000 # Number of tokens - token_list = TokenStatusList.of_size(bits, size) + size = 1000000 # Number of indices + status_list = TokenStatusList.of_size(bits, size) + + # Generate random statuses + statuses = [] + while len(statuses) < size: + run = randbelow(10) + status = randbelow(2) + statuses.extend([status] * run) + + diff = len(statuses) - size + if diff > 1: + for _ in range(diff + 1): + statuses.pop() # Test setting values start_time = time.time() - for i in range(size): - token_list[i] = VALID if i % 2 == 0 else INVALID + for i, status in enumerate(statuses): + status_list[i] = status end_time = time.time() - print(f"Time to set {size} tokens: {end_time - start_time} seconds") + print(f"Time to set {size} indices: {end_time - start_time:.3f} seconds") # Test getting values start_time = time.time() for i in range(size): - status = token_list[i] + status = status_list[i] end_time = time.time() - print(f"Time to get {size} tokens: {end_time - start_time} seconds") + print(f"Time to get {size} indices: {end_time - start_time:.3f} seconds") # Test compression start_time = time.time() - compressed_data = token_list.compressed() + compressed_data = status_list.compressed() end_time = time.time() - print(f"Time to compress: {end_time - start_time} seconds") - print(f"Original length: {len(token_list.lst)} bytes") + print(f"Time to compress: {end_time - start_time:.3f} seconds") + print(f"Original length: {len(status_list.lst)} bytes") print(f"Compressed length: {len(compressed_data)} bytes") - print(f"Compression ratio: {len(compressed_data) / len(token_list.lst) * 100:.3f}%") + print(f"Compression ratio: {len(compressed_data) / len(status_list.lst) * 100:.3f}%") + # print(f"List in hex: {status_list.lst.hex()}") def test_serde(): @@ -200,3 +215,65 @@ def test_invalid_to_valid(): assert status[7] == INVALID with pytest.raises(ValueError): status[7] = 0x00 + + +def test_of_size(): + with pytest.raises(ValueError): + status = TokenStatusList.of_size(1, 3) + with pytest.raises(ValueError): + status = TokenStatusList.of_size(2, 21) + with pytest.raises(ValueError): + status = TokenStatusList.of_size(4, 31) + + # Lists with bits 8 can have arbitrary size since there's no byte + # boundaries to worry about + status = TokenStatusList.of_size(8, 31) + assert len(status.lst) == 31 + + status = TokenStatusList.of_size(1, 8) + assert len(status.lst) == 1 + status = TokenStatusList.of_size(1, 16) + assert len(status.lst) == 2 + status = TokenStatusList.of_size(1, 24) + assert len(status.lst) == 3 + status = TokenStatusList.of_size(8, 24) + assert len(status.lst) == 24 + + +def test_with_at_least(): + status = TokenStatusList.with_at_least(1, 3) + assert len(status.lst) == 1 + status = TokenStatusList.with_at_least(2, 21) + assert len(status.lst) == 6 + status = TokenStatusList.with_at_least(4, 31) + assert len(status.lst) == 16 + + status = TokenStatusList.with_at_least(1, 8) + assert len(status.lst) == 1 + status = TokenStatusList.with_at_least(2, 24) + assert len(status.lst) == 6 + status = TokenStatusList.with_at_least(4, 32) + assert len(status.lst) == 16 + + +def test_sign_payload(): + status = TokenStatusList(1, b"\xb9\xa3") + payload = status.sign_payload( + alg="ES256", + kid="12", + iss="https://example.com", + sub="https://example.com/statuslists/1", + iat=1686920170, + exp=2291720170, + ) + headers, payload = payload.split(b".") + headers = json.loads(b64url_decode(headers)) + payload = json.loads(b64url_decode(payload)) + assert headers == {"alg": "ES256", "kid": "12", "typ": "statuslist+jwt"} + assert payload == { + "exp": 2291720170, + "iat": 1686920170, + "iss": "https://example.com", + "status_list": {"bits": 1, "lst": "eNrbuRgAAhcBXQ"}, + "sub": "https://example.com/statuslists/1", + }