diff --git a/jwcrypto/jwt.py b/jwcrypto/jwt.py index 157d519..714b184 100644 --- a/jwcrypto/jwt.py +++ b/jwcrypto/jwt.py @@ -479,6 +479,7 @@ def _check_check_claims(self, check_claims): self._check_string_claim('iss', check_claims) self._check_string_claim('sub', check_claims) self._check_array_or_string_claim('aud', check_claims) + self._check_string_claim('scope', check_claims) self._check_integer_claim('exp', check_claims) self._check_integer_claim('nbf', check_claims) self._check_integer_claim('iat', check_claims) @@ -556,7 +557,24 @@ def _check_provided_claims(self): "'%s'" % (name, claims[name], value)) + elif name == 'scope': + if value is not None: + if not isinstance(claims[name], str): + raise JWTInvalidClaimValue("Invalid '%s' value. Scope list has to be a string, " + "got a %s instead: %s" % ( + name, type(claims[name]), str(claims[name]))) + + found = False + got_scopes = claims[name].split() + for s in got_scopes: + if s == value: + found = True + break + if not found: + raise JWTInvalidClaimValue("Invalid '%s' value. Scope list '%s'" + " does not contain the required scope '%s'" % ( + name, claims[name], value)) else: if value is not None and value != claims[name]: raise JWTInvalidClaimValue( diff --git a/jwcrypto/tests.py b/jwcrypto/tests.py index ae3b140..647235b 100644 --- a/jwcrypto/tests.py +++ b/jwcrypto/tests.py @@ -1977,6 +1977,68 @@ def test_unexpected(self): jwt.JWT(jwt=enctok, key=key) key.key_ops = None + def test_claims_scope(self): + key = jwk.JWK().generate(kty='oct') + + string_header = '{"alg":"HS256"}' + + # no scopes provided + claims = '{}' + t = jwt.JWT(string_header, claims) + t.make_signed_token(key) + token = t.serialize() + self.assertRaises(jwt.JWTMissingClaim, jwt.JWT, jwt=token, + key=key, check_claims={"scope": "read"}) + + # non-string scopes + claims = '{"scope": 12345}' + t = jwt.JWT(string_header, claims) + t.make_signed_token(key) + token = t.serialize() + self.assertRaises(jwt.JWTInvalidClaimValue, jwt.JWT, jwt=token, + key=key, check_claims={"scope": "read"}) + + # empty scopes + claims = '{"scope": ""}' + t = jwt.JWT(string_header, claims) + t.make_signed_token(key) + token = t.serialize() + self.assertRaises(jwt.JWTInvalidClaimValue, jwt.JWT, jwt=token, + key=key, check_claims={"scope": "read"}) + + # one correct scope + claims = '{"scope":"read"}' + t = jwt.JWT(string_header, claims) + t.make_signed_token(key) + token = t.serialize() + jwt.JWT(jwt=token, key=key, check_claims={"scope": "read"}) + self.assertRaises(jwt.JWTInvalidClaimValue, jwt.JWT, jwt=token, + key=key, check_claims={"scope": "write"}) + + # multiple scopes including the correct one + claims = '{"scope":"view read write"}' + t = jwt.JWT(string_header, claims) + t.make_signed_token(key) + token = t.serialize() + jwt.JWT(jwt=token, key=key, check_claims={"scope": "view"}) + jwt.JWT(jwt=token, key=key, check_claims={"scope": "read"}) + jwt.JWT(jwt=token, key=key, check_claims={"scope": "write"}) + self.assertRaises(jwt.JWTInvalidClaimValue, jwt.JWT, jwt=token, + key=key, check_claims={"scope": "wrong"}) + + # one correct scope, invalid value + claims = '{"scope":"read"}' + t = jwt.JWT(string_header, claims) + t.make_signed_token(key) + token = t.serialize() + self.assertRaises(jwt.JWTInvalidClaimFormat, jwt.JWT, jwt=token, + key=key, check_claims={"scope": 123}) + self.assertRaises(jwt.JWTInvalidClaimFormat, jwt.JWT, jwt=token, + key=key, check_claims={"scope": ["test", "wrong"]}) + + # finally make sure it doesn't raise if not checked. + jwt.JWT(jwt=token, key=key) + class ConformanceTests(unittest.TestCase):