@@ -64,8 +64,19 @@ class AuthJwtValidator(models.Model):
6464 ],
6565 default = "RS256" ,
6666 )
67+ audience_type = fields .Selection (
68+ [
69+ ("aud" , "Audience" ),
70+ ("group" , "Group" ),
71+ ("scope" , "Scope" ),
72+ ("custom" , "Custom" ),
73+ ],
74+ required = True ,
75+ default = "aud" ,
76+ )
77+ audience_type_custom = fields .Char (required = False , help = "payload key to validate" )
6778 audience = fields .Char (
68- required = True , help = "Comma separated list of audiences, to validate aud ."
79+ required = True , help = "Comma separated list of attribute needed ."
6980 )
7081 issuer = fields .Char (required = True , help = "To validate iss." )
7182 user_id_strategy = fields .Selection (
@@ -160,7 +171,7 @@ def _get_validator_by_name(self, validator_name):
160171
161172 @tools .ormcache ("self.public_key_jwk_uri" , "kid" )
162173 def _get_key (self , kid ):
163- jwks_client = PyJWKClient (self .public_key_jwk_uri , cache_keys = False )
174+ jwks_client = PyJWKClient (self .public_key_jwk_uri )
164175 return jwks_client .get_signing_key (kid ).key
165176
166177 def _encode (self , payload , secret , expire ):
@@ -194,20 +205,30 @@ def _decode(self, token, secret=None):
194205 raise UnauthorizedInvalidToken () from e
195206 key = self ._get_key (header .get ("kid" ))
196207 algorithm = self .public_key_algorithm
208+ aud = self .audience .split ("," ) if self .audience_type == "aud" else None
197209 try :
198210 payload = jwt .decode (
199211 token ,
200212 key = key ,
201213 algorithms = [algorithm ],
202214 options = dict (
203- require = ["exp" , "aud" , " iss" ],
215+ require = ["exp" , "iss" ],
204216 verify_exp = True ,
205- verify_aud = True ,
206217 verify_iss = True ,
207218 ),
208- audience = self . audience . split ( "," ) ,
219+ audience = aud ,
209220 issuer = self .issuer ,
210221 )
222+ payload_key = (
223+ self .audience_type_custom
224+ if self .audience_type == "custom"
225+ else self .audience_type
226+ )
227+ if len ((self .audience ).split ("," ) or []) > 0 :
228+ for key_value in (self .audience ).split ("," ):
229+ if key_value in (payload .get (payload_key )).split (" " ):
230+ return payload
231+ raise UnauthorizedInvalidToken ()
211232 except Exception as e :
212233 _logger .info ("Invalid token: %s" , e )
213234 raise UnauthorizedInvalidToken () from e
0 commit comments