Skip to content

Commit

Permalink
feat: add jwt framing
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Bluhm <dbluhm@pm.me>
  • Loading branch information
dbluhm committed Jun 30, 2024
1 parent 2cf75d2 commit ddc626d
Show file tree
Hide file tree
Showing 2 changed files with 177 additions and 12 deletions.
90 changes: 89 additions & 1 deletion src/token_status_list/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
"""

import base64
from typing import Literal, Union
import json
from time import time
from typing import Literal, Optional, Union
import zlib


Expand All @@ -25,6 +27,11 @@ def b64url_encode(value: bytes) -> bytes:
return base64.urlsafe_b64encode(value).rstrip(b"=")


def dict_to_b64(value: dict) -> bytes:
"""Transform a dictionary into base64url encoded json dump of dictionary."""
return b64url_encode(json.dumps(value, separators=(",", ":")).encode())


VALID = 0x00
INVALID = 0x01
SUSPENDED = 0x02
Expand Down Expand Up @@ -77,6 +84,10 @@ def of_size(cls, bits: Bits, size: int) -> "TokenStatusList":
@classmethod
def with_at_least(cls, bits: Bits, size: int):
"""Create an empty list large enough to accommodate at least the given size."""
# Determine minimum number of bytes to fit size
# This is essentially a fast ceil(n / 2^x)
length = (size + cls.PER_BYTE[bits] - 1) >> cls.SHIFT_BY[bits]
return cls(bits, bytearray(length))

def __getitem__(self, index: int):
"""Retrieve the status of an index."""
Expand Down Expand Up @@ -151,3 +162,80 @@ def deserialize(cls, value: dict) -> "TokenStatusList":

parsed_lst = zlib.decompress(b64url_decode(lst.encode()))
return cls(bits, parsed_lst)

def sign_payload(
self,
*,
alg: str,
kid: str,
iss: str,
sub: str,
iat: Optional[int] = None,
exp: Optional[int] = None,
ttl: Optional[int] = None,
) -> bytes:
"""Create a Status List Token payload for signing.
Signing is NOT performed by this function; only the payload to the signature is
prepared. The caller is responsible for producing a signature.
Args:
alg: REQUIRED. The algorithm to be used to sign the payload.
kid: REQUIRED. The kid used to sign the payload.
iss: REQUIRED when also present in the Referenced Token. The iss (issuer)
claim MUST specify a unique string identifier for the entity that issued
the Status List Token. In the absence of an application profile specifying
otherwise, compliant applications MUST compare issuer values using the
Simple String Comparison method defined in Section 6.2.1 of [RFC3986].
The value MUST be equal to that of the iss claim contained within the
Referenced Token.
sub: REQUIRED. The sub (subject) claim MUST specify a unique string identifier
for the Status List Token. The value MUST be equal to that of the uri
claim contained in the status_list claim of the Referenced Token.
iat: REQUIRED. The iat (issued at) claim MUST specify the time at which the
Status List Token was issued.
exp: OPTIONAL. The exp (expiration time) claim, if present, MUST specify the
time at which the Status List Token is considered expired by its issuer.
ttl: OPTIONAL. The ttl (time to live) claim, if present, MUST specify the
maximum amount of time, in seconds, that the Status List Token can be
cached by a consumer before a fresh copy SHOULD be retrieved. The value
of the claim MUST be a positive number.
"""
headers = {
"typ": "statuslist+jwt",
"alg": alg,
"kid": kid,
}
payload = {
"iss": iss,
"sub": sub,
"iat": iat or int(time()),
"status_list": self.serialize(),
}
if exp is not None:
payload["exp"] = exp

if ttl is not None:
payload["ttl"] = ttl

enc_headers = dict_to_b64(headers).decode()
enc_payload = dict_to_b64(payload).decode()
return f"{enc_headers}.{enc_payload}".encode()

def signed_token(self, signed_payload: bytes, signature: bytes) -> str:
"""Finish creating a signed token.
Args:
signed_payload: The value returned from `sign_payload`.
signature: The signature over the signed_payload in bytes.
Returns:
Finished Status List Token.
"""
return f"{signed_payload.decode()}.{b64url_encode(signature)}"
99 changes: 88 additions & 11 deletions tests/test_token_status_list.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""Test TokenStatusList."""

import json
import pytest
from secrets import randbelow

from token_status_list import INVALID, SUSPENDED, TokenStatusList, VALID, b64url_decode

Expand Down Expand Up @@ -151,31 +153,44 @@ def test_performance(bits: int):
print("Bits:", bits)

# Create a large TokenStatusList
size = 1000000 # Number of tokens
token_list = TokenStatusList.of_size(bits, size)
size = 1000000 # Number of indices
status_list = TokenStatusList.of_size(bits, size)

# Generate random statuses
statuses = []
while len(statuses) < size:
run = randbelow(10)
status = randbelow(2)
statuses.extend([status] * run)

diff = len(statuses) - size
if diff > 1:
for _ in range(diff + 1):
statuses.pop()

# Test setting values
start_time = time.time()
for i in range(size):
token_list[i] = VALID if i % 2 == 0 else INVALID
for i, status in enumerate(statuses):
status_list[i] = status
end_time = time.time()
print(f"Time to set {size} tokens: {end_time - start_time} seconds")
print(f"Time to set {size} indices: {end_time - start_time:.3f} seconds")

# Test getting values
start_time = time.time()
for i in range(size):
status = token_list[i]
status = status_list[i]
end_time = time.time()
print(f"Time to get {size} tokens: {end_time - start_time} seconds")
print(f"Time to get {size} indices: {end_time - start_time:.3f} seconds")

# Test compression
start_time = time.time()
compressed_data = token_list.compressed()
compressed_data = status_list.compressed()
end_time = time.time()
print(f"Time to compress: {end_time - start_time} seconds")
print(f"Original length: {len(token_list.lst)} bytes")
print(f"Time to compress: {end_time - start_time:.3f} seconds")
print(f"Original length: {len(status_list.lst)} bytes")
print(f"Compressed length: {len(compressed_data)} bytes")
print(f"Compression ratio: {len(compressed_data) / len(token_list.lst) * 100:.3f}%")
print(f"Compression ratio: {len(compressed_data) / len(status_list.lst) * 100:.3f}%")
# print(f"List in hex: {status_list.lst.hex()}")


def test_serde():
Expand All @@ -200,3 +215,65 @@ def test_invalid_to_valid():
assert status[7] == INVALID
with pytest.raises(ValueError):
status[7] = 0x00


def test_of_size():
with pytest.raises(ValueError):
status = TokenStatusList.of_size(1, 3)
with pytest.raises(ValueError):
status = TokenStatusList.of_size(2, 21)
with pytest.raises(ValueError):
status = TokenStatusList.of_size(4, 31)

# Lists with bits 8 can have arbitrary size since there's no byte
# boundaries to worry about
status = TokenStatusList.of_size(8, 31)
assert len(status.lst) == 31

status = TokenStatusList.of_size(1, 8)
assert len(status.lst) == 1
status = TokenStatusList.of_size(1, 16)
assert len(status.lst) == 2
status = TokenStatusList.of_size(1, 24)
assert len(status.lst) == 3
status = TokenStatusList.of_size(8, 24)
assert len(status.lst) == 24


def test_with_at_least():
status = TokenStatusList.with_at_least(1, 3)
assert len(status.lst) == 1
status = TokenStatusList.with_at_least(2, 21)
assert len(status.lst) == 6
status = TokenStatusList.with_at_least(4, 31)
assert len(status.lst) == 16

status = TokenStatusList.with_at_least(1, 8)
assert len(status.lst) == 1
status = TokenStatusList.with_at_least(2, 24)
assert len(status.lst) == 6
status = TokenStatusList.with_at_least(4, 32)
assert len(status.lst) == 16


def test_sign_payload():
status = TokenStatusList(1, b"\xb9\xa3")
payload = status.sign_payload(
alg="ES256",
kid="12",
iss="https://example.com",
sub="https://example.com/statuslists/1",
iat=1686920170,
exp=2291720170,
)
headers, payload = payload.split(b".")
headers = json.loads(b64url_decode(headers))
payload = json.loads(b64url_decode(payload))
assert headers == {"alg": "ES256", "kid": "12", "typ": "statuslist+jwt"}
assert payload == {
"exp": 2291720170,
"iat": 1686920170,
"iss": "https://example.com",
"status_list": {"bits": 1, "lst": "eNrbuRgAAhcBXQ"},
"sub": "https://example.com/statuslists/1",
}

0 comments on commit ddc626d

Please sign in to comment.