Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,14 @@ FastAPI native extension, easy and simple JWT auth
## Installation
You can access package [fastapi-jwt in pypi](https://pypi.org/project/fastapi-jwt/)
```shell
pip install fastapi-jwt
pip install fastapi-jwt[authlib]
# or
pip install fastapi-jwt[python_jose]
```

The fastapi-jwt will choose the backend automatically if library is installed with the following priority:
1. authlib
2. python_jose (deprecated)

## Usage
This library made in fastapi style, so it can be used as standard security features
Expand Down Expand Up @@ -81,7 +86,7 @@ There it is open and maintained [Pull Request #3305](https://github.com/tiangolo
## Requirements

* `fastapi`
* `python-jose[cryptography]`
* `authlib` or `python-jose[cryptography]` (deprecated)

## License
This project is licensed under the terms of the MIT license.
This project is licensed under the terms of the MIT license.
1 change: 1 addition & 0 deletions fastapi_jwt/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .jwt import * # noqa: F401, F403
from .jwt_backends import * # noqa: F401, F403
159 changes: 58 additions & 101 deletions fastapi_jwt/jwt.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from abc import ABC
from datetime import datetime, timedelta
from typing import Any, Dict, Optional, Set
from typing import Any, Dict, Optional, Set, Type
from uuid import uuid4

from fastapi.exceptions import HTTPException
Expand All @@ -9,13 +9,24 @@
from fastapi.security import APIKeyCookie, HTTPBearer
from starlette.status import HTTP_401_UNAUTHORIZED

try:
from jose import jwt
except ImportError: # pragma: nocover
jwt = None # type: ignore[assignment]
from .jwt_backends import AbstractJWTBackend, authlib_backend, python_jose_backend
from .jwt_backends.abstract_backend import BackendException

DEFAULT_JWT_BACKEND: Optional[Type[AbstractJWTBackend]] = None
if authlib_backend.authlib_jose is not None:
DEFAULT_JWT_BACKEND = authlib_backend.AuthlibJWTBackend
elif python_jose_backend.jose is not None:
DEFAULT_JWT_BACKEND = python_jose_backend.PythonJoseJWTBackend
else: # pragma: nocover
raise ImportError("No JWT backend found, please install 'python-jose' or 'authlib'")

def utcnow():

def force_jwt_backend(cls: Type[AbstractJWTBackend]) -> None:
global DEFAULT_JWT_BACKEND
DEFAULT_JWT_BACKEND = cls


def utcnow() -> datetime:
try:
from datetime import UTC
except ImportError: # pragma: nocover
Expand All @@ -27,6 +38,7 @@ def utcnow():


__all__ = [
"force_jwt_backend",
"JwtAuthorizationCredentials",
"JwtAccessBearer",
"JwtAccessCookie",
Expand All @@ -49,15 +61,11 @@ def __getitem__(self, item: str) -> Any:
class JwtAuthBase(ABC):
class JwtAccessCookie(APIKeyCookie):
def __init__(self, *args: Any, **kwargs: Any):
APIKeyCookie.__init__(
self, *args, name="access_token_cookie", auto_error=False, **kwargs
)
APIKeyCookie.__init__(self, *args, name="access_token_cookie", auto_error=False, **kwargs)

class JwtRefreshCookie(APIKeyCookie):
def __init__(self, *args: Any, **kwargs: Any):
APIKeyCookie.__init__(
self, *args, name="refresh_token_cookie", auto_error=False, **kwargs
)
APIKeyCookie.__init__(self, *args, name="refresh_token_cookie", auto_error=False, **kwargs)

class JwtAccessBearer(HTTPBearer):
def __init__(self, *args: Any, **kwargs: Any):
Expand All @@ -72,38 +80,35 @@ def __init__(
secret_key: str,
places: Optional[Set[str]] = None,
auto_error: bool = True,
algorithm: str = jwt.ALGORITHMS.HS256, # type: ignore[attr-defined]
algorithm: Optional[str] = None,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

default algorithm is now handled in the jwt backend directly.

access_expires_delta: Optional[timedelta] = None,
refresh_expires_delta: Optional[timedelta] = None,
):
assert jwt is not None, "python-jose must be installed to use JwtAuth"
if places:
assert places.issubset(
{"header", "cookie"}
), "only 'header'/'cookie' are supported"
algorithm = algorithm.upper()
assert (
hasattr(jwt.ALGORITHMS, algorithm) is True # type: ignore[attr-defined]
), f"{algorithm} algorithm is not supported by python-jose library"
assert DEFAULT_JWT_BACKEND is not None, "No JWT backend found, please install 'python-jose' or 'authlib'"

self.jwt_backend = DEFAULT_JWT_BACKEND(algorithm)
self.secret_key = secret_key

self.places = places or {"header"}
assert self.places.issubset({"header", "cookie"}), "only 'header' and/or 'cookie' places are supported"
self.auto_error = auto_error
self.algorithm = algorithm
self.access_expires_delta = access_expires_delta or timedelta(minutes=15)
self.refresh_expires_delta = refresh_expires_delta or timedelta(days=31)

@property
def algorithm(self) -> str:
return self.jwt_backend.algorithm

@classmethod
def from_other(
cls,
other: 'JwtAuthBase',
other: "JwtAuthBase",
secret_key: Optional[str] = None,
auto_error: Optional[bool] = None,
algorithm: Optional[str] = None,
access_expires_delta: Optional[timedelta] = None,
refresh_expires_delta: Optional[timedelta] = None,
) -> 'JwtAuthBase':
) -> "JwtAuthBase":
return cls(
secret_key=secret_key or other.secret_key,
auto_error=auto_error or other.auto_error,
Expand All @@ -112,30 +117,6 @@ def from_other(
refresh_expires_delta=refresh_expires_delta or other.refresh_expires_delta,
)

def _decode(self, token: str) -> Optional[Dict[str, Any]]:
try:
payload: Dict[str, Any] = jwt.decode(
token,
self.secret_key,
algorithms=[self.algorithm],
options={"leeway": 10},
)
return payload
except jwt.ExpiredSignatureError as e: # type: ignore[attr-defined]
if self.auto_error:
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED, detail=f"Token time expired: {e}"
)
else:
return None
except jwt.JWTError as e: # type: ignore[attr-defined]
if self.auto_error:
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED, detail=f"Wrong token: {e}"
)
else:
return None

Comment on lines -115 to -138
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was moved in the backend PythonJoseJWTBackend. Since the error strings are now in the backend, a dev could create its own backend to customize the error handling like thhis PR wanted to do: #7

def _generate_payload(
self,
subject: Dict[str, Any],
Expand All @@ -144,7 +125,6 @@ def _generate_payload(
token_type: str,
) -> Dict[str, Any]:
now = utcnow()

return {
"subject": subject.copy(), # main subject
"type": token_type, # 'access' or 'refresh' token
Expand All @@ -165,15 +145,18 @@ async def _get_payload(
# Check token exist
if not token:
if self.auto_error:
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED, detail="Credentials are not provided"
)
raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Credentials are not provided")
else:
return None

# Try to decode jwt token. auto_error on error
payload = self._decode(token)
return payload
try:
return self.jwt_backend.decode(token, self.secret_key)
except BackendException as e:
if self.auto_error:
raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail=str(e))
else:
return None

def create_access_token(
self,
Expand All @@ -183,14 +166,8 @@ def create_access_token(
) -> str:
expires_delta = expires_delta or self.access_expires_delta
unique_identifier = unique_identifier or str(uuid4())
to_encode = self._generate_payload(
subject, expires_delta, unique_identifier, "access"
)

jwt_encoded: str = jwt.encode(
to_encode, self.secret_key, algorithm=self.algorithm
)
return jwt_encoded
to_encode = self._generate_payload(subject, expires_delta, unique_identifier, "access")
return self.jwt_backend.encode(to_encode, self.secret_key)

def create_refresh_token(
self,
Expand All @@ -200,22 +177,12 @@ def create_refresh_token(
) -> str:
expires_delta = expires_delta or self.refresh_expires_delta
unique_identifier = unique_identifier or str(uuid4())
to_encode = self._generate_payload(
subject, expires_delta, unique_identifier, "refresh"
)

jwt_encoded: str = jwt.encode(
to_encode, self.secret_key, algorithm=self.algorithm
)
return jwt_encoded
to_encode = self._generate_payload(subject, expires_delta, unique_identifier, "refresh")
return self.jwt_backend.encode(to_encode, self.secret_key)

@staticmethod
def set_access_cookie(
response: Response, access_token: str, expires_delta: Optional[timedelta] = None
) -> None:
seconds_expires: Optional[int] = (
int(expires_delta.total_seconds()) if expires_delta else None
)
def set_access_cookie(response: Response, access_token: str, expires_delta: Optional[timedelta] = None) -> None:
seconds_expires: Optional[int] = int(expires_delta.total_seconds()) if expires_delta else None
response.set_cookie(
key="access_token_cookie",
value=access_token,
Expand All @@ -229,9 +196,7 @@ def set_refresh_cookie(
refresh_token: str,
expires_delta: Optional[timedelta] = None,
) -> None:
seconds_expires: Optional[int] = (
int(expires_delta.total_seconds()) if expires_delta else None
)
seconds_expires: Optional[int] = int(expires_delta.total_seconds()) if expires_delta else None
response.set_cookie(
key="refresh_token_cookie",
value=refresh_token,
Expand All @@ -241,15 +206,11 @@ def set_refresh_cookie(

@staticmethod
def unset_access_cookie(response: Response) -> None:
response.set_cookie(
key="access_token_cookie", value="", httponly=False, max_age=-1
)
response.set_cookie(key="access_token_cookie", value="", httponly=False, max_age=-1)

@staticmethod
def unset_refresh_cookie(response: Response) -> None:
response.set_cookie(
key="refresh_token_cookie", value="", httponly=True, max_age=-1
)
response.set_cookie(key="refresh_token_cookie", value="", httponly=True, max_age=-1)


class JwtAccess(JwtAuthBase):
Expand All @@ -261,7 +222,7 @@ def __init__(
secret_key: str,
places: Optional[Set[str]] = None,
auto_error: bool = True,
algorithm: str = jwt.ALGORITHMS.HS256, # type: ignore[attr-defined]
algorithm: Optional[str] = None,
access_expires_delta: Optional[timedelta] = None,
refresh_expires_delta: Optional[timedelta] = None,
):
Expand All @@ -282,9 +243,7 @@ async def _get_credentials(
payload = await self._get_payload(bearer, cookie)

if payload:
return JwtAuthorizationCredentials(
payload["subject"], payload.get("jti", None)
)
return JwtAuthorizationCredentials(payload["subject"], payload.get("jti", None))
return None


Expand All @@ -293,7 +252,7 @@ def __init__(
self,
secret_key: str,
auto_error: bool = True,
algorithm: str = jwt.ALGORITHMS.HS256, # type: ignore[attr-defined]
algorithm: Optional[str] = None,
access_expires_delta: Optional[timedelta] = None,
refresh_expires_delta: Optional[timedelta] = None,
):
Expand All @@ -317,7 +276,7 @@ def __init__(
self,
secret_key: str,
auto_error: bool = True,
algorithm: str = jwt.ALGORITHMS.HS256, # type: ignore[attr-defined]
algorithm: Optional[str] = None,
access_expires_delta: Optional[timedelta] = None,
refresh_expires_delta: Optional[timedelta] = None,
):
Expand All @@ -342,7 +301,7 @@ def __init__(
self,
secret_key: str,
auto_error: bool = True,
algorithm: str = jwt.ALGORITHMS.HS256, # type: ignore[attr-defined]
algorithm: Optional[str] = None,
access_expires_delta: Optional[timedelta] = None,
refresh_expires_delta: Optional[timedelta] = None,
):
Expand Down Expand Up @@ -372,7 +331,7 @@ def __init__(
secret_key: str,
places: Optional[Set[str]] = None,
auto_error: bool = True,
algorithm: str = jwt.ALGORITHMS.HS256, # type: ignore[attr-defined]
algorithm: Optional[str] = None,
access_expires_delta: Optional[timedelta] = None,
refresh_expires_delta: Optional[timedelta] = None,
):
Expand All @@ -399,22 +358,20 @@ async def _get_credentials(
if self.auto_error:
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Wrong token: 'type' is not 'refresh'",
detail="Invalid token: 'type' is not 'refresh'",
)
else:
return None

return JwtAuthorizationCredentials(
payload["subject"], payload.get("jti", None)
)
return JwtAuthorizationCredentials(payload["subject"], payload.get("jti", None))


class JwtRefreshBearer(JwtRefresh):
def __init__(
self,
secret_key: str,
auto_error: bool = True,
algorithm: str = jwt.ALGORITHMS.HS256, # type: ignore[attr-defined]
algorithm: Optional[str] = None,
access_expires_delta: Optional[timedelta] = None,
refresh_expires_delta: Optional[timedelta] = None,
):
Expand All @@ -438,7 +395,7 @@ def __init__(
self,
secret_key: str,
auto_error: bool = True,
algorithm: str = jwt.ALGORITHMS.HS256, # type: ignore[attr-defined]
algorithm: Optional[str] = None,
access_expires_delta: Optional[timedelta] = None,
refresh_expires_delta: Optional[timedelta] = None,
):
Expand All @@ -463,7 +420,7 @@ def __init__(
self,
secret_key: str,
auto_error: bool = True,
algorithm: str = jwt.ALGORITHMS.HS256, # type: ignore[attr-defined]
algorithm: Optional[str] = None,
access_expires_delta: Optional[timedelta] = None,
refresh_expires_delta: Optional[timedelta] = None,
):
Expand Down
4 changes: 4 additions & 0 deletions fastapi_jwt/jwt_backends/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from . import abstract_backend, authlib_backend, python_jose_backend # noqa: F401
from .abstract_backend import AbstractJWTBackend # noqa: F401
from .authlib_backend import AuthlibJWTBackend # noqa: F401
from .python_jose_backend import PythonJoseJWTBackend # noqa: F401
Loading