diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index 4706358..dfdbda9 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -14,6 +14,9 @@ ] } }, + "otherPortsAttributes": { + "onAutoForward": "ignore" + }, "postStartCommand": "pip3 install --user -r requirements-dev.txt", "postAttachCommand": "python3 -m pytest tests", "remoteUser": "vscode" diff --git a/.github/workflows/run-unit-tests.yml b/.github/workflows/run-unit-tests.yml index 2cc3ae8..ffc53b7 100644 --- a/.github/workflows/run-unit-tests.yml +++ b/.github/workflows/run-unit-tests.yml @@ -45,8 +45,8 @@ jobs: - name: Install dependencies run: | - pip install -r requirements.txt - pip install pytest pytest-cov + pip install -r requirements-dev.txt + pip install pytest-cov - name: Run Pytest (Linux/macOS) if: runner.os != 'Windows' diff --git a/.github/workflows/suggest-version-bump.yml b/.github/workflows/suggest-version-bump.yml index f9b247a..b53e15f 100644 --- a/.github/workflows/suggest-version-bump.yml +++ b/.github/workflows/suggest-version-bump.yml @@ -51,7 +51,7 @@ jobs: BUMP="patch" echo "$LABELS" | grep -q 'type: feature' && BUMP="minor" echo "$LABELS" | grep -q 'type: security' && BUMP="minor" - echo "$LABELS" | grep -q 'type: breaking' && BUMP="major" + echo "$LABELS" | grep -q 'special: breaking change' && BUMP="major" echo "bump=$BUMP" >> "$GITHUB_OUTPUT" - name: Get latest tag diff --git a/.github/workflows/super-linter.yml b/.github/workflows/super-linter.yml index 3b5b228..8209037 100644 --- a/.github/workflows/super-linter.yml +++ b/.github/workflows/super-linter.yml @@ -2,7 +2,7 @@ name: Lint permissions: - contents: read + contents: write packages: read statuses: write @@ -27,10 +27,7 @@ jobs: uses: super-linter/super-linter/slim@v7 env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - VALIDATE_ALL_CODEBASE: false - FILTER_REGEX_EXCLUDE: '(.devcontainer/Dockerfile|.github/pull_request_template.md|.github/ISSUE_TEMPLATE/*.md)' - VALIDATE_PYTHON_ISORT: false - VALIDATE_PYTHON_MYPY: false + DISABLE_ERRORS: true fix-lint: name: Fix Lint @@ -48,13 +45,18 @@ jobs: VALIDATE_ALL_CODEBASE: false FILTER_REGEX_EXCLUDE: '(.github/pull_request_template.md|.github/ISSUE_TEMPLATE/*.md)' GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + VALIDATE_DOCKERFILE_HADOLINT: false VALIDATE_PYTHON_ISORT: false VALIDATE_PYTHON_MYPY: false + VALIDATE_PYTHON_PYLINT: false + FIX_HTML_PRETTIER: true + FIX_JSON: true FIX_JSON_PRETTIER: true FIX_MARKDOWN: true FIX_MARKDOWN_PRETTIER: true + FIX_PYTHON_BLACK: true + FIX_PYTHON_RUFF: true FIX_YAML_PRETTIER: true - VALIDATE_DOCKERFILE_HADOLINT: false - name: Commit and push linting fixes if: > diff --git a/.gitignore b/.gitignore index 1cc48e5..49267a4 100644 --- a/.gitignore +++ b/.gitignore @@ -187,3 +187,7 @@ cython_debug/ # refer to https://docs.cursor.com/context/ignore-files .cursorignore .cursorindexingignore + +# # Super Linter output +github_conf/branch_protection_rules.json +super-linter-output/super-linter-summary.md diff --git a/.prettierrc b/.prettierrc index 5ea96ce..5c1606c 100644 --- a/.prettierrc +++ b/.prettierrc @@ -7,7 +7,11 @@ "singleQuote": true, "overrides": [ { - "files": ["*.yml", "*.yaml", "*.md"], + "files": [ + "*.yml", + "*.yaml", + "*.md" + ], "options": { "tabWidth": 2 } diff --git a/.vscode/settings.json b/.vscode/settings.json index 000bbd4..1511dea 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -7,19 +7,19 @@ "files.trimTrailingWhitespace": true, "files.exclude": { "**/__pycache__": true, - "**/.pytest_cache": true + "**/.pytest_cache": true, + "**/*.egg-info": true }, "[python]": { - "editor.rulers": [88], + "editor.rulers": [80], "editor.defaultFormatter": "ms-python.black-formatter", - "editor.formatOnSave": true, "editor.codeActionsOnSave": { - "source.organizeImports": "explicit" + "source.organizeImports": "always" } }, "isort.args": ["--profile", "black"], "triggerTaskOnSave.tasks": { - "Run on test file": ["tests/**/test_*.py"], + "Run file tests": ["tests/**/test_*.py"], "Run all tests": ["app/**/*.py", "!tests/**"] }, "python.testing.unittestEnabled": false, diff --git a/.vscode/tasks.json b/.vscode/tasks.json index 0af07d9..d98fe43 100644 --- a/.vscode/tasks.json +++ b/.vscode/tasks.json @@ -5,64 +5,52 @@ "label": "Delete old Git branches", "type": "shell", "command": "git fetch --prune && git fetch -p ; git branch -r | awk '{print $1}' | egrep -v -f /dev/fd/0 <(git branch -vv | grep origin) | awk '{print $1}' | xargs git branch -D", - "problemMatcher": [], "presentation": { "showReuseMessage": false - } + }, + "problemMatcher": [] }, { "label": "Package app", "type": "shell", "command": "rm -rf build dist *.egg-info && python -m build", - "problemMatcher": [], - "runOptions": { - "runOn": "default" + "group": { + "kind": "build", + "isDefault": true }, "presentation": { "panel": "dedicated" }, - "group": { - "kind": "build", - "isDefault": true - } + "problemMatcher": [] }, { "label": "Run all tests", "type": "shell", "command": "python3 -m pytest tests", "group": "build", - "problemMatcher": [], - "runOptions": { - "runOn": "default" - } + "presentation": { + "close": true + }, + "problemMatcher": [] }, { - "label": "Run on test file", + "label": "Run file tests", "type": "shell", - "command": "python3 -m pytest '${relativeFile}' -v -x", + "command": "python3 -m pytest '${file}' -v -x", "group": { "kind": "test" }, - "presentation": { - "close": true - }, - "problemMatcher": [], - "runOptions": { - "runOn": "default", - } + "problemMatcher": [] }, { "label": "Start FastAPI server", "type": "shell", "command": "uvicorn app.main:app --reload", - "problemMatcher": [], - "runOptions": { - "runOn": "default" - }, "presentation": { "panel": "dedicated", "close": true - } + }, + "problemMatcher": [] } ] } diff --git a/app/routes/v1/__init__.py b/app/routes/v1/__init__.py index ec0b86f..d11db38 100644 --- a/app/routes/v1/__init__.py +++ b/app/routes/v1/__init__.py @@ -5,13 +5,15 @@ from fastapi import FastAPI from .endpoints.authentication import router as auth_router +from .endpoints.email import router as email_router from .endpoints.orders import router as order_router from .endpoints.products import router as product_router -__version__ = "1.1.0" +__version__ = "1.2.0" api = FastAPI(title="ChocoMax Shop API", version=__version__) api.include_router(auth_router, prefix="/auth", tags=["Authentication"]) +api.include_router(email_router, prefix="/email", tags=["Email"]) api.include_router(product_router, prefix="/products", tags=["Products"]) api.include_router(order_router, prefix="/orders", tags=["Orders"]) diff --git a/app/routes/v1/endpoints/authentication.py b/app/routes/v1/endpoints/authentication.py index d0925d9..e82b4a0 100644 --- a/app/routes/v1/endpoints/authentication.py +++ b/app/routes/v1/endpoints/authentication.py @@ -1,69 +1,418 @@ -from fastapi import APIRouter, Depends, HTTPException +""" +Authentication endpoints and utilities for user login, registration, and 2FA. + +This module provides FastAPI endpoints for user authentication, including login, +two-factor authentication (2FA), and registration. It also includes utility +functions for interacting with the database and handling authentication logic. +""" + +import json +import random +import secrets +import time + +from fastapi import APIRouter, Depends, HTTPException, Request +from pyotp import random_base32 as generate_otp_secret from sqlalchemy import text from sqlalchemy.ext.asyncio import AsyncSession +from user_agents import parse as parse_user_agent +from app.routes.v1.schemas.user.login import UserLogin, UserLogin2FA from app.routes.v1.schemas.user.register import UserRegister from app.utility.database import get_db -from app.utility.security import ( - encrypt_email, - encrypt_phone, - hash_email, - hash_password, - hash_phone, -) +from app.utility.security import hash_email, hash_password, verify_otp, verify_password from app.utility.string_utils import sanitize_username router = APIRouter() +_2fa_sessions = ( + {} +) # Temporary in-memory store for 2FA sessions TODO: (replace with Redis or DB in production) + + +# --- Common utility functions --- + + +async def get_password_hash_by_email_hash( + db: AsyncSession, email_hash: str +) -> str | None: + """ + Retrieve the password hash for a user by their email hash. + + Args: + db (AsyncSession): The database session. + email_hash (str): The hashed email address. + + Returns: + str | None: The password hash if found, otherwise None. + """ + result = await db.execute( + text("SELECT get_password_hash_by_email_hash(:email_hash)"), + {"email_hash": email_hash}, + ) + return result.scalar() + + +async def get_2fa_secret(db: AsyncSession, email_hash: str, method: str = "TOTP"): + """ + Retrieve the 2FA secret for a user by their email hash and authentication method. + + Args: + db (AsyncSession): The database session. + email_hash (str): The hashed email address. + method (str): The authentication method (default: "TOTP"). + + Returns: + Row or None: The database row containing the 2FA secret, or None if not found. + """ + result = await db.execute( + text( + "SELECT * FROM get_user_2fa_secret_by_email_hash(:email_hash, :auth_method)" + ), + {"email_hash": email_hash, "auth_method": method}, + ) + return result.fetchone() + + +async def get_user_info(db: AsyncSession, email_hash: str): + """ + Retrieve user information by their email hash. + + Args: + db (AsyncSession): The database session. + email_hash (str): The hashed email address. + + Returns: + Row or None: The database row containing user information, or None if not found. + """ + result = await db.execute( + text("SELECT * FROM get_user_info_by_email_hash(:email_hash)"), + {"email_hash": email_hash}, + ) + return result.fetchone() + + +async def save_session_token( + db: AsyncSession, + user_id: int, + session_token: str, + device_info: str, + ip_address: str, + user_agent: str, +): + """ + Save a session token for a user with device and IP info. + + Args: + db (AsyncSession): The database session. + user_id (int): The user's ID. + session_token (str): The session token to save. + device_info (str): Information about the user's device. + ip_address (str): The user's IP address. + """ + await db.execute( + text( + """ + CALL create_user_session_token( + :p_user_id, + :p_session_token, + :p_device_info, + :p_ip_address, + :p_user_agent + ) + """ + ), + { + "p_user_id": user_id, + "p_session_token": session_token, + "p_device_info": device_info, + "p_ip_address": ip_address, + "p_user_agent": user_agent, + }, + ) + await db.commit() + + +async def save_refresh_token( + db: AsyncSession, + user_id: int, + session_token: str, + device_info: str, + ip_address: str, + user_agent: str, +): + """ + Save a session token for a user with device and IP info. + + Args: + db (AsyncSession): The database session. + user_id (int): The user's ID. + session_token (str): The session token to save. + device_info (str): Information about the user's device. + ip_address (str): The user's IP address. + """ + await db.execute( + text( + """ + CALL create_user_refresh_token( + :p_user_id, + :p_session_token, + :p_device_info, + :p_ip_address, + :p_user_agent + ) + """ + ), + { + "p_user_id": user_id, + "p_session_token": session_token, + "p_device_info": device_info, + "p_ip_address": ip_address, + "p_user_agent": user_agent, + }, + ) + await db.commit() + + +def get_device_info_and_ip(request: Request): + """Extract device info and IP address from the request.""" + user_agent_str = request.headers.get("User-Agent", "") + ua = parse_user_agent(user_agent_str) + device_info = json.dumps( + { + "family": ua.device.family, + "brand": ua.device.brand, + "model": ua.device.model, + "is_mobile": ua.is_mobile, + "is_tablet": ua.is_tablet, + "is_pc": ua.is_pc, + "is_bot": ua.is_bot, + } + ) + ip_address = request.headers.get("X-Real-IP") or request.client.host + return device_info, ip_address, user_agent_str + + +def filter_user_fields(user_dict, fields): + return {k: user_dict[k] for k in fields if k in user_dict} + + +async def create_and_return_session( + db, user_info, device_info, ip_address, user_agent_str +): + """Create session and refresh tokens, save them, and return selected user info with tokens.""" + session_token = secrets.token_urlsafe(32) + refresh_token = secrets.token_urlsafe(32) + + await save_session_token( + db, user_info.user_id, session_token, device_info, ip_address, user_agent_str + ) + await save_refresh_token( + db, user_info.user_id, refresh_token, device_info, ip_address, user_agent_str + ) + + user_dict = dict(user_info._mapping) + selected_fields = [ + "username", + "discriminator", + "language_id", + "display_role", + "created_at", + ] + + return { + **filter_user_fields(user_dict, selected_fields), + "session_token": session_token, + "refresh_token": refresh_token, + } + + +# --- Endpoints --- + + +@router.post("/login") +async def login(data: UserLogin, request: Request, db: AsyncSession = Depends(get_db)): + """ + Step 1: Verify email and password, check if 2FA is required. + + Args: + data (UserLogin): The login request payload. + db (AsyncSession): The database session. + + Returns: + dict: If 2FA is required, returns a dict with 2FA info and a temporary token. + Otherwise, returns user info/session. + Raises: + HTTPException: If credentials are invalid. + """ + email_hash = hash_email(data.email) + password = data.password + + # Parse user agent for device info + device_info, ip_address, user_agent_str = get_device_info_and_ip(request) + + # Verify password + password_hash = await get_password_hash_by_email_hash(db, email_hash) + if not password_hash or not verify_password(password, password_hash): + raise HTTPException(401, "Invalid credentials") + + # Check if 2FA is enabled before fetching user info + row = await get_2fa_secret(db, email_hash) + user_2fa_secret = row.authentication_secret if row else None + + if user_2fa_secret: + temp_token = secrets.token_urlsafe(32) + + # Store mapping with expiry (5 minutes) + _2fa_sessions[temp_token] = { + "email_hash": email_hash, + "expires_at": time.time() + 300, + } + + result = await db.execute( + text("SELECT * FROM get_user_2fa_methods_by_email_hash(:email_hash)"), + {"email_hash": email_hash}, + ) + methods_rows = await result.fetchall() + methods = [row.authentication_method for row in methods_rows] + preferred_method = next( + (row.authentication_method for row in methods_rows if row.is_preferred), + None, + ) + + if methods: + return { + "2fa_required": True, + "token": temp_token, + "methods": methods, + "preferred_method": preferred_method, + } + + user_info = await get_user_info(db, email_hash) + return await create_and_return_session( + db, user_info, device_info, ip_address, user_agent_str + ) + + +@router.post("/login/otp") +async def login_otp( + data: UserLogin2FA, request: Request, db: AsyncSession = Depends(get_db) +): + """ + Step 2: Verify OTP code and return user info/session. + + Args: + data (UserLogin2FA): The 2FA login request payload. + db (AsyncSession): The database session. + + Returns: + dict: User info/session if OTP is valid. + Raises: + HTTPException: If the session token is invalid/expired, 2FA is not enabled, or OTP is invalid. + """ + device_info, ip_address, user_agent_str = get_device_info_and_ip(request) + + session = _2fa_sessions.get(data.token) + if not session or session["expires_at"] < time.time(): + raise HTTPException(401, "Invalid or expired 2FA session token") + + email_hash = session["email_hash"] + + row = await get_2fa_secret(db, email_hash) + secret = row.authentication_secret if row else None + + if not secret: + raise HTTPException(400, "2FA is not enabled for this user") + + if not verify_otp(secret, data.otp_code): + raise HTTPException(401, "Invalid 2FA code") + + user_info = await get_user_info(db, email_hash) + return await create_and_return_session( + db, user_info, device_info, ip_address, user_agent_str + ) @router.post("/register") async def register(data: UserRegister, db: AsyncSession = Depends(get_db)): - """Endpoint for user registration.""" + """ + Endpoint for user registration. + + Args: + data (UserRegister): The registration request payload. + db (AsyncSession): The database session. + + Returns: + dict: Registration result with username and discriminator. + Raises: + HTTPException: If the token is missing/invalid, all discriminators are taken, or email exists. + """ + token = data.token username = sanitize_username(data.username) - email_encrypted = encrypt_email(data.email) - email_hash = hash_email(data.email) password_hash = hash_password(data.password) - phone_encrypted = encrypt_phone(data.phone) if data.phone else None - phone_hash = hash_phone(data.phone) if data.phone else None language_id = data.language_id + otp_secret = generate_otp_secret() + + if not token: + raise HTTPException(400, "Token is required for registration") - # Check if user is available + # Check if the token exists and is valid using the new function result = await db.execute( - text( - """ - SELECT is_user_available(:username, :email_hash, :phone_hash) AS available - """ - ), - {"username": username, "email_hash": email_hash, "phone_hash": phone_hash}, + text("SELECT is_verification_token_valid(:token)"), + {"token": token}, + ) + is_valid = result.scalar() + + if not is_valid: + raise HTTPException(400, "Invalid or expired verification token") + + # Retrieve the list of discriminators for the username + result = await db.execute( + text("SELECT get_used_discriminators(:username) AS discriminator"), + {"username": username}, + ) + used_discriminators = [row.discriminator for row in result.fetchall()] + available_discriminators = set(range(0, 10000)) - set(used_discriminators) + + if not available_discriminators: + raise HTTPException(409, "All discriminators taken for this username") + + # Choose a random discriminator from the available ones + discriminator = random.choice(list(available_discriminators)) + + # Check if email is available + result = await db.execute( + text("SELECT is_email_available(:token) AS available"), {"token": token} ) available = result.scalar() if not available: - raise HTTPException(409, "Username, email, or phone already exists") + raise HTTPException(409, "Email already exists") await db.execute( text( """ CALL register_user( + :token, :username, - :email_encrypted, - :email_hash, + :discriminator, :password_hash, - :phone_encrypted, - :phone_hash, - :preferred_language_id + :preferred_language_id, + :otp_secret ) """ ), { + "token": token, "username": username, - "email_encrypted": email_encrypted, - "email_hash": email_hash, + "discriminator": discriminator, "password_hash": password_hash, - "phone_encrypted": phone_encrypted, - "phone_hash": phone_hash, "preferred_language_id": language_id, + "otp_secret": otp_secret, }, ) await db.commit() - return {"message": "User registered successfully", "username": username} + return { + "message": "User registered successfully", + "username": username, + "discriminator": discriminator, + } diff --git a/app/routes/v1/endpoints/email.py b/app/routes/v1/endpoints/email.py new file mode 100644 index 0000000..aa8bc82 --- /dev/null +++ b/app/routes/v1/endpoints/email.py @@ -0,0 +1,65 @@ +""" +Email endpoints for user confirmation. + +This module provides FastAPI endpoints for sending confirmation emails to users. +It handles the creation of verification tokens, email encryption, and interaction +with the database to register pending users. +""" + +from fastapi import APIRouter, BackgroundTasks, Depends +from sqlalchemy import text +from sqlalchemy.ext.asyncio import AsyncSession + +from app.routes.v1.schemas.email.request import EmailRequest +from app.utility.database import get_db +from app.utility.email.schemas import RegistrationEmailSchema +from app.utility.email.sender import send_email_background +from app.utility.security import create_verification_token, encrypt_email, hash_email + +router = APIRouter() + + +@router.post("/confirmation") +async def send_confirmation_email( + data: EmailRequest, + background_tasks: BackgroundTasks, + db: AsyncSession = Depends(get_db), +): + """ + Endpoint to send a confirmation email to the user. + + This endpoint accepts a POST request with the user's email in the body, + generates a verification token, encrypts and hashes the email, sends a + confirmation email asynchronously, and stores the pending user in the database. + + Args: + data (EmailRequest): The request payload containing the user's email. + background_tasks (BackgroundTasks): FastAPI background task manager. + db (AsyncSession): The database session. + + Returns: + dict: A dictionary with a detail message and the confirmation token. + """ + email = data.email + token = create_verification_token() + email_encrypted = encrypt_email(email) + email_hash = hash_email(email) + + email_schema = RegistrationEmailSchema( + email=[email], + body={ + "title": "Welcome to ChocoMax", + "confirmation_url": f"http://?token={token}", + }, + ) + + send_email_background(background_tasks, email_schema) + + # Send the token to the Database + await db.execute( + text("CALL create_pending_user(:email_encrypted, :email_hash, :token)"), + {"email_encrypted": email_encrypted, "email_hash": email_hash, "token": token}, + ) + await db.commit() + + return {"detail": "Confirmation email sent"} diff --git a/app/utility/test.py b/app/routes/v1/schemas/email/__init__.py similarity index 100% rename from app/utility/test.py rename to app/routes/v1/schemas/email/__init__.py diff --git a/app/routes/v1/schemas/email/request.py b/app/routes/v1/schemas/email/request.py new file mode 100644 index 0000000..a546123 --- /dev/null +++ b/app/routes/v1/schemas/email/request.py @@ -0,0 +1,19 @@ +""" +Schemas for email-related API requests. + +This module defines the Pydantic model used for validating and serializing +email request payloads in the API v1 endpoints. +""" + +from pydantic import BaseModel, EmailStr + + +class EmailRequest(BaseModel): + """ + Schema for email-related API requests. + + Attributes: + email (EmailStr): The user's email address. + """ + + email: EmailStr diff --git a/app/routes/v1/schemas/user/login.py b/app/routes/v1/schemas/user/login.py new file mode 100644 index 0000000..76222ac --- /dev/null +++ b/app/routes/v1/schemas/user/login.py @@ -0,0 +1,34 @@ +""" +Schemas for user login and two-factor authentication (2FA) requests. + +This module defines Pydantic models used for validating and serializing +user login and 2FA payloads in the authentication endpoints. +""" + +from pydantic import BaseModel + + +class UserLogin(BaseModel): + """ + Schema for user login request. + + Attributes: + email (str): The user's email address. + password (str): The user's password. + """ + + email: str + password: str + + +class UserLogin2FA(BaseModel): + """ + Schema for user two-factor authentication (2FA) request. + + Attributes: + otp_code (int): The one-time password code for 2FA. + token (str): The temporary token issued after initial login. + """ + + otp_code: int + token: str diff --git a/app/routes/v1/schemas/user/register.py b/app/routes/v1/schemas/user/register.py index dd6228f..e21b779 100644 --- a/app/routes/v1/schemas/user/register.py +++ b/app/routes/v1/schemas/user/register.py @@ -1,9 +1,25 @@ -from pydantic import BaseModel, EmailStr +""" +Schemas for user registration requests. + +This module defines the Pydantic model used for validating and serializing +user registration payloads in the authentication endpoints. +""" + +from pydantic import BaseModel class UserRegister(BaseModel): + """ + Schema for user registration request. + + Attributes: + token (str): The registration or invitation token. + username (str): The desired username for the new user. + password (str): The user's password. + language_id (int | None): Optional language preference identifier. + """ + + token: str username: str - email: EmailStr password: str - phone: str | None = None language_id: int | None = None diff --git a/app/templates/email_confirmation.html b/app/templates/email_confirmation.html new file mode 100644 index 0000000..645dce7 --- /dev/null +++ b/app/templates/email_confirmation.html @@ -0,0 +1,107 @@ + + + + Email Confirmation + + + + + +
+

{{ title }}

+
+

Hello!

+

+ Thank you for signing up! Please confirm your email address + to complete your account creation. +

+ + Confirm Your Email + + + +
+
+ + diff --git a/app/utility/database.py b/app/utility/database.py index c91b1e8..50b5fcb 100644 --- a/app/utility/database.py +++ b/app/utility/database.py @@ -1,13 +1,30 @@ +""" +Database utility module for asynchronous SQLAlchemy sessions. + +This module sets up the asynchronous SQLAlchemy engine and sessionmaker for +database interactions. It provides a dependency function for obtaining a +database session in FastAPI endpoints. +""" + from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from sqlalchemy.orm import sessionmaker DATABASE_URL = ( "postgresql+asyncpg://postgres:S3cur3Str0ngP%40ss@172.17.0.1:5432/chocomax" ) -engine = create_async_engine(DATABASE_URL, echo=True) +engine = create_async_engine(DATABASE_URL, echo=False, pool_pre_ping=True) SessionLocal = sessionmaker(bind=engine, class_=AsyncSession, expire_on_commit=False) async def get_db(): + """ + Dependency that provides a SQLAlchemy asynchronous database session. + + Yields: + AsyncSession: An active SQLAlchemy async session for database operations. + + Usage: + Use as a dependency in FastAPI endpoints to access the database. + """ async with SessionLocal() as session: yield session diff --git a/app/utility/email/__init__.py b/app/utility/email/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/utility/email/config.py b/app/utility/email/config.py new file mode 100644 index 0000000..6e8ca41 --- /dev/null +++ b/app/utility/email/config.py @@ -0,0 +1,27 @@ +""" +Email configuration module. + +This module loads environment variables and sets up the email connection +configuration for sending emails through the application. +""" + +import os +from pathlib import Path + +from dotenv import load_dotenv +from fastapi_mail import ConnectionConfig + +load_dotenv(".env") + +conf = ConnectionConfig( + MAIL_USERNAME=os.getenv("MAIL_USERNAME", "user@example.com"), + MAIL_PASSWORD=os.getenv("MAIL_PASSWORD", "password"), + MAIL_FROM=os.getenv("MAIL_FROM", "noreply@chocomax.com"), + MAIL_PORT=os.getenv("MAIL_PORT", "587"), + MAIL_SERVER=os.getenv("MAIL_SERVER", "smtp.example.com"), + MAIL_FROM_NAME=os.getenv("MAIL_FROM_NAME", "ChocoMax"), + MAIL_STARTTLS=True, + MAIL_SSL_TLS=False, + USE_CREDENTIALS=True, + TEMPLATE_FOLDER=Path(__file__).parent.parent.parent / "templates", +) diff --git a/app/utility/email/schemas.py b/app/utility/email/schemas.py new file mode 100644 index 0000000..66ff4e6 --- /dev/null +++ b/app/utility/email/schemas.py @@ -0,0 +1,51 @@ +""" +Schemas for email-related payloads. + +This module defines Pydantic models for validating and serializing +email payloads, such as registration and password reset emails, +used by the application's email utility functions. +""" + +from typing import Any, Dict, List + +from pydantic import BaseModel, EmailStr + + +class BaseEmailSchema(BaseModel): + """ + Base schema for email payloads. + + Attributes: + email (List[EmailStr]): List of recipient email addresses. + subject (str): Subject line of the email. + template_name (str): Name of the template to use for the email body. + body (Dict[str, Any]): Data to render within the email template. + """ + + email: List[EmailStr] + subject: str + template_name: str + body: Dict[str, Any] + + +class RegistrationEmailSchema(BaseEmailSchema): + """ + Schema for registration confirmation emails. + + Attributes: + subject (str): Default subject for registration emails. + template_name (str): Default template for registration emails. + """ + + subject: str = "ChocoMax - Email Confirmation" + template_name: str = "email_confirmation.html" + + +class PasswordResetEmailSchema(BaseEmailSchema): + """ + Schema for password reset emails. + + Inherits all fields from BaseEmailSchema. + """ + + pass # TODO diff --git a/app/utility/email/sender.py b/app/utility/email/sender.py new file mode 100644 index 0000000..1b69148 --- /dev/null +++ b/app/utility/email/sender.py @@ -0,0 +1,36 @@ +""" +Email sending utilities. + +This module provides functions for sending emails using FastAPI background tasks +and the FastMail library. It is used to send templated emails asynchronously +throughout the application. +""" + +from fastapi import BackgroundTasks +from fastapi_mail import FastMail, MessageSchema, MessageType + +from .config import conf +from .schemas import BaseEmailSchema + + +def send_email_background(background_tasks: BackgroundTasks, email: BaseEmailSchema): + """ + Send an email in the background using FastAPI's BackgroundTasks. + + Args: + background_tasks (BackgroundTasks): The FastAPI background task manager. + email (BaseEmailSchema): The email payload containing recipients, subject, template, and body. + + This function creates a message from the provided schema and schedules it to be sent + asynchronously using FastMail and the specified template. + """ + message = MessageSchema( + subject=email.subject, + recipients=email.email, + template_body=email.body, + subtype=MessageType.html, + ) + fm = FastMail(conf) + background_tasks.add_task( + fm.send_message, message, template_name=email.template_name + ) diff --git a/app/utility/security.py b/app/utility/security.py index 4412e34..c13bee6 100644 --- a/app/utility/security.py +++ b/app/utility/security.py @@ -1,7 +1,18 @@ +""" +Security utility module for encryption, hashing, and authentication. + +This module provides functions and constants for handling password hashing, +field encryption/decryption, OTP verification, and normalization of sensitive +fields such as email and phone numbers. It is used throughout the application +to ensure secure handling of user credentials and sensitive data. +""" + import base64 import hashlib import os +import secrets +import pyotp from argon2 import PasswordHasher from cryptography.hazmat.primitives.ciphers.aead import AESGCM @@ -11,18 +22,55 @@ PEPPER = os.getenv("PEPPER", "SuperSecretPepper").encode("utf-8") -def normalize_email(email: str) -> str: - """Normalize email by stripping spaces and converting to lowercase.""" - return email.strip().lower() +def create_token(length: int) -> str: + """ + Generate a secure random token for session management. + Returns: + str: A URL-safe, random token string. + """ + return secrets.token_urlsafe(length) + + +def create_verification_token() -> str: + """ + Generate a secure random token for email verification. + + Returns: + str: A URL-safe, random token string. + """ + return create_token(32) + + +def create_access_token() -> str: + """ + Generate a secure random token for access control. + + Returns: + str: A URL-safe, random token string. + """ + return create_token(32) -def normalize_phone(phone: str) -> str: - """Normalize phone number by stripping spaces and removing non-numeric characters.""" - return phone.strip().replace(" ", "") +def create_refresh_token() -> str: + """ + Generate a secure random token for refresh operations. + + Returns: + str: A URL-safe, random token string. + """ + return create_token(64) def encrypt_field(value: str) -> str: - """Encrypt a value using AES-256-GCM (returns base64 of IV + ciphertext + tag).""" + """ + Encrypt a value using AES-256-GCM. + + Args: + value (str): The value to encrypt. + + Returns: + str: Base64-encoded string of IV + ciphertext + tag. + """ aesgcm = AESGCM(AES_KEY) iv = os.urandom(12) # 96-bit IV recommended for AES-GCM ciphertext = aesgcm.encrypt(iv, value.encode("utf-8"), associated_data=None) @@ -30,7 +78,15 @@ def encrypt_field(value: str) -> str: def decrypt_field(encrypted_base64: str) -> str: - """Decrypt a value encrypted with AES-256-GCM.""" + """ + Decrypt a value encrypted with AES-256-GCM. + + Args: + encrypted_base64 (str): The base64-encoded encrypted value. + + Returns: + str: The decrypted string. + """ encrypted_data = base64.b64decode(encrypted_base64) iv, ciphertext = encrypted_data[:12], encrypted_data[12:] aesgcm = AESGCM(AES_KEY) @@ -39,34 +95,137 @@ def decrypt_field(encrypted_base64: str) -> str: def hash_field(value: str) -> str: - """Generate a SHA-256 hash of a field (used for fast lookup).""" + """ + Generate a SHA-256 hash of a field (used for fast lookup). + + Args: + value (str): The value to hash. + + Returns: + str: The SHA-256 hash as a hexadecimal string. + """ return hashlib.sha256(value.encode("utf-8")).hexdigest() def hash_password(password: str) -> str: - """Hash a password using Argon2 + pepper.""" + """ + Hash a password using Argon2 and a pepper. + + Args: + password (str): The plaintext password. + + Returns: + str: The Argon2 hash of the peppered password. + """ peppered_password = password.encode("utf-8") + PEPPER return ph.hash(peppered_password) -# Wrappers for email and phone to ensure normalization and hashing +def verify_otp( + secret: str, otp_code: str, otp_method: str = "TOTP", counter: int = 0 +) -> bool: + """ + Verify a one-time password (OTP) against a secret using TOTP or HOTP. + + Args: + secret (str): The OTP secret. + otp_code (str): The OTP code to verify. + otp_method (str): The OTP method ("TOTP" or "HOTP"). + counter (int): The HOTP counter (required for HOTP). + + Returns: + bool: True if the OTP is valid, False otherwise. + """ + try: + match otp_method: + case "TOTP": + return pyotp.TOTP(secret).verify(otp_code) + case "HOTP": + return pyotp.HOTP(secret).verify(otp_code, counter) + case _: + raise ValueError("Unsupported OTP method") + except Exception: + return False + + +def verify_password(password: str, hashed_password: str) -> bool: + """ + Verify a password against a hashed password. + + Args: + hashed_password (str): The Argon2 hashed password. + password (str): The plaintext password to verify. + + Returns: + bool: True if the password matches, False otherwise. + """ + peppered_password = password.encode("utf-8") + PEPPER + try: + return ph.verify(hashed_password, peppered_password) + except Exception: + return False def hash_email(email: str) -> str: - """Generate a SHA-256 hash of the email (used for fast lookup).""" - return hash_field(normalize_email(email)) + """ + Generate a SHA-256 hash of the email (used for fast lookup). + + Args: + email (str): The email address. + + Returns: + str: The SHA-256 hash of the email. + """ + return hash_field(email) def hash_phone(phone: str) -> str: - """Generate a SHA-256 hash of the phone number (used for fast lookup).""" - return hash_field(normalize_phone(phone)) + """ + Generate a SHA-256 hash of the phone number (used for fast lookup). + + Args: + phone (str): The phone number. + + Returns: + str: The SHA-256 hash of the phone number. + """ + return hash_field(phone) + + +def hash_token(token: str) -> str: + """ + Generate a SHA-256 hash of the token (used for fast lookup). + + Args: + token (str): The token to hash. + + Returns: + str: The SHA-256 hash of the token. + """ + return hash_field(token) def encrypt_email(email: str) -> str: - """Encrypt the email after normalizing it.""" - return encrypt_field(normalize_email(email)) + """ + Encrypt the email address. + + Args: + email (str): The email address to encrypt. + + Returns: + str: The encrypted email. + """ + return encrypt_field(email) def encrypt_phone(phone: str) -> str: - """Encrypt the phone number after normalizing it.""" - return encrypt_field(normalize_phone(phone)) + """ + Encrypt the normalized phone number. + + Args: + phone (str): The phone number to encrypt. + + Returns: + str: The encrypted phone number. + """ + return encrypt_field(phone) diff --git a/app/utility/string_utils.py b/app/utility/string_utils.py index abd323b..8ecd937 100644 --- a/app/utility/string_utils.py +++ b/app/utility/string_utils.py @@ -1,3 +1,9 @@ +""" +string_utils.py + +Utility functions for string manipulation and sanitization used throughout the API application. +""" + import re diff --git a/app/version.py b/app/version.py index f93bed8..0e4c943 100644 --- a/app/version.py +++ b/app/version.py @@ -4,4 +4,4 @@ It is used to track changes and updates to the codebase. """ -__version__ = "0.2.1" +__version__ = "0.3.0" diff --git a/requirements-dev.txt b/requirements-dev.txt index 5893078..b236181 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,3 +1,4 @@ -r requirements.txt -pytest>=8.3.5 -build>=1.2.0 +build>=1.2.0,<2 +pytest>=8.3.5,<9 +pytest-asyncio>=1.0.0,<2 diff --git a/requirements.txt b/requirements.txt index 5c69585..d3c332e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,14 @@ -argon2-cffi>=25.1.0 -asyncpg>=0.30.0 -cryptography>=45.0.3 -fastapi>=0.115.12 -httpx>=0.28.1 -pydantic[email] -sqlalchemy>=2.0.41 -uvicorn>=0.34.2 +argon2-cffi>=25.1.0,<26 +asyncpg>=0.30.0,<1 +cryptography>=45.0.3,<46 +dotenv>=0.9.9,<1 +fastapi>=0.115.12,<1 +fastapi_mail>=1.5.0,<2 +greenlet>=3.2.3,<4 +httpx>=0.28.1,<1 +pydantic[email]>=2.11.6,<3 +pyotp>=2.9.0,<3 +requests>=2.32.4,<3 +sqlalchemy>=2.0.41,<3 +user-agents>=2.2.0,<3 +uvicorn>=0.34.2,<1 diff --git a/tests/test_home.py b/tests/test_home.py index 75fbedc..5390b36 100644 --- a/tests/test_home.py +++ b/tests/test_home.py @@ -2,14 +2,22 @@ Test the home endpoint of the API. """ +import pytest +from fastapi import FastAPI from fastapi.testclient import TestClient -from app.main import app +from app.routes.home import router as home_router -client = TestClient(app) +@pytest.fixture +def client(): + """Fixture to create a test client for the FastAPI app with only the home router.""" + app = FastAPI() + app.include_router(home_router) + return TestClient(app) -def test_home(): + +def test_home(client): """Test the home endpoint.""" response = client.get("/") assert response.status_code == 200 diff --git a/tests/test_routes/v1/test_authentication.py b/tests/test_routes/v1/test_authentication.py new file mode 100644 index 0000000..93169d1 --- /dev/null +++ b/tests/test_routes/v1/test_authentication.py @@ -0,0 +1,278 @@ +import sys +from unittest.mock import AsyncMock, patch + +import pytest +from fastapi.testclient import TestClient + +from app.routes.v1.endpoints import authentication as auth_module + +AUTH_PATH = "app.routes.v1.endpoints.authentication" + + +@pytest.fixture +def client(): + """ + Returns a FastAPI TestClient with the authentication router included. + """ + from fastapi import FastAPI + + app = FastAPI() + app.include_router(auth_module.router, prefix="/v1/auth") + return TestClient(app) + + +@pytest.fixture(autouse=True) +def patch_auth_dependencies(): + """ + Automatically patches authentication dependencies for all tests. + Provides default mock return values for password hash, password verification, + 2FA secret, user info, and email hashing. + """ + with ( + patch( + f"{AUTH_PATH}.get_password_hash_by_email_hash", new_callable=AsyncMock + ) as get_pw_hash_mock, + patch(f"{AUTH_PATH}.verify_password") as verify_pw_mock, + patch( + f"{AUTH_PATH}.get_2fa_secret", new_callable=AsyncMock + ) as get_2fa_secret_mock, + patch( + f"{AUTH_PATH}.get_user_info", new_callable=AsyncMock + ) as get_user_info_mock, + patch(f"{AUTH_PATH}.hash_email") as hash_email_mock, + ): + get_pw_hash_mock.return_value = "hashed-password" + verify_pw_mock.return_value = True + get_2fa_secret_mock.return_value = None + get_user_info_mock.return_value = AsyncMock(_mapping={"username": "testuser"}) + hash_email_mock.return_value = "dummy-email-hash" + yield { + "get_pw_hash": get_pw_hash_mock, + "verify_pw": verify_pw_mock, + "get_2fa_secret": get_2fa_secret_mock, + "get_user_info": get_user_info_mock, + "hash_email": hash_email_mock, + } + + +@pytest.fixture +def mock_db_and_override(client): + """ + Provides a mock database session and overrides the get_db dependency. + """ + mock_db = AsyncMock() + + async def override_get_db(): + yield mock_db + + client.app.dependency_overrides[auth_module.get_db] = override_get_db + return mock_db + + +@pytest.fixture +def login_payload(): + """Returns a function to generate login payloads.""" + + def _payload(email="testuser", password="password123"): + return { + "email": email, + "password": password, + } + + return _payload + + +@pytest.fixture +def otp_payload(): + """Returns a function to generate OTP login payloads.""" + + def _payload(token="validtoken", otp_code=123456): + return { + "token": token, + "otp_code": otp_code, + } + + return _payload + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "username, password", + [ + ("testuser", "password123"), + ("admin", "adminpass"), + ("user123", "userpass"), + ("test.user", "testpass"), + ("first.last", "flpass"), + ], +) +async def test_login_returns_session_tokens( + client, + mock_db_and_override, + patch_auth_dependencies, + username, + password, + login_payload, +): + """ + Test successful login returns opaque session and refresh tokens when 2FA is not required. + """ + # Patch user info to include user_id for token generation simulation + patch_auth_dependencies["get_user_info"].return_value = AsyncMock( + _mapping={"username": "testuser"} + ) + + response = client.post("/v1/auth/login", json=login_payload()) + assert response.status_code == 200 + data = response.json() + # Expect opaque tokens in response + assert "session_token" in data + assert "refresh_token" in data + assert data["username"] == "testuser" + + +@pytest.mark.asyncio +async def test_login_2fa_required_returns_2fa_token( + client, mock_db_and_override, patch_auth_dependencies, login_payload +): + """ + Test login returns a 2FA-required response with a temporary token if 2FA is enabled. + """ + + # Simulate 2FA enabled + class Dummy2FASecret: + authentication_secret = "dummysecret" + + patch_auth_dependencies["get_2fa_secret"].return_value = Dummy2FASecret() + + # Simulate available 2FA methods + with ( + patch(f"{AUTH_PATH}.text"), + patch(f"{AUTH_PATH}.time"), + ): + # Patch DB call for 2FA methods + mock_methods = [ + type("Row", (), {"authentication_method": "TOTP", "is_preferred": True})(), + type("Row", (), {"authentication_method": "SMS", "is_preferred": False})(), + ] + mock_db = mock_db_and_override + mock_execute = AsyncMock() + mock_execute.fetchall.return_value = mock_methods + mock_db.execute.return_value = mock_execute + + response = client.post("/v1/auth/login", json=login_payload()) + + assert response.status_code == 200 + data = response.json() + assert data["2fa_required"] is True + assert "token" in data + assert set(data["methods"]) == {"TOTP", "SMS"} + assert data["preferred_method"] == "TOTP" + + +@pytest.mark.asyncio +async def test_login_invalid_credentials( + client, mock_db_and_override, patch_auth_dependencies, login_payload +): + """ + Test login with invalid credentials returns 401 and no tokens. + """ + patch_auth_dependencies["verify_pw"].return_value = False + + response = client.post( + "/v1/auth/login", + json=login_payload(email="invaliduser", password="wrongpassword"), + ) + + assert response.status_code == 401 + data = response.json() + assert data["detail"] == "Invalid credentials" + + +@pytest.mark.asyncio +async def test_login_missing_fields(client, mock_db_and_override): + """ + Test login with missing fields returns 422 and no tokens. + """ + response = client.post("/v1/auth/login", json={}) + assert response.status_code == 422 + data = response.json() + assert "detail" in data + + +@pytest.mark.asyncio +async def test_login_otp_success_returns_tokens( + client, mock_db_and_override, patch_auth_dependencies, otp_payload +): + """ + Test /login/otp returns session and refresh tokens on successful OTP verification. + """ + # Simulate valid 2FA session and OTP + patch_auth_dependencies["get_2fa_secret"].return_value = AsyncMock( + authentication_secret="dummysecret" + ) + with patch(f"{AUTH_PATH}.verify_otp") as verify_otp_mock: + verify_otp_mock.return_value = True + patch_auth_dependencies["get_user_info"].return_value = AsyncMock( + _mapping={"username": "testuser"} + ) + # Simulate valid token in _2fa_sessions + with patch.object( + auth_module, + "_2fa_sessions", + { + "validtoken": { + "email_hash": "dummy-email-hash", + "expires_at": 9999999999, + } + }, + ): + response = client.post("/v1/auth/login/otp", json=otp_payload()) + + assert response.status_code == 200 + data = response.json() + assert "session_token" in data + assert "refresh_token" in data + assert data["username"] == "testuser" + + +@pytest.mark.asyncio +async def test_login_otp_invalid_token(client, otp_payload): + """ + Test /login/otp with an invalid or expired token returns 401. + """ + with patch.object(auth_module, "_2fa_sessions", {}): + response = client.post( + "/v1/auth/login/otp", json=otp_payload(token="invalidtoken") + ) + + assert response.status_code == 401 + data = response.json() + assert "2FA session token" in data["detail"] + + +@pytest.mark.asyncio +async def test_login_otp_invalid_otp(client, patch_auth_dependencies, otp_payload): + """ + Test /login/otp with an invalid OTP code returns 401. + """ + patch_auth_dependencies["get_2fa_secret"].return_value = AsyncMock( + authentication_secret="dummysecret" + ) + with patch(f"{AUTH_PATH}.verify_otp") as verify_otp_mock: + verify_otp_mock.return_value = False + with patch.object( + auth_module, + "_2fa_sessions", + { + "validtoken": { + "email_hash": "dummy-email-hash", + "expires_at": sys.maxsize, + } + }, + ): + response = client.post("/v1/auth/login/otp", json=otp_payload()) + + assert response.status_code == 401 + data = response.json() + assert "Invalid 2FA code" in data["detail"] diff --git a/tests/test_routes/v1/test_email.py b/tests/test_routes/v1/test_email.py new file mode 100644 index 0000000..d87cb64 --- /dev/null +++ b/tests/test_routes/v1/test_email.py @@ -0,0 +1,116 @@ +from unittest.mock import AsyncMock, patch + +import pytest +from fastapi.testclient import TestClient + +from app.routes.v1.endpoints import email as email_module + +EMAIL_PATH = "app.routes.v1.endpoints.email" + + +@pytest.fixture +def client(): + from fastapi import FastAPI + + app = FastAPI() + app.include_router(email_module.router, prefix="/v1/email") + return TestClient(app) + + +@pytest.fixture(autouse=True) +def patch_email_dependencies(): + with ( + patch(f"{EMAIL_PATH}.send_email_background") as send_mock, + patch(f"{EMAIL_PATH}.hash_email", return_value="hashed-email") as hash_mock, + patch( + f"{EMAIL_PATH}.encrypt_email", return_value="encrypted-email" + ) as enc_mock, + patch( + f"{EMAIL_PATH}.create_verification_token", return_value="test-token" + ) as token_mock, + ): + yield { + "send": send_mock, + "hash": hash_mock, + "enc": enc_mock, + "token": token_mock, + } + + +@pytest.fixture +def mock_db_and_override(client): + mock_db = AsyncMock() + + async def override_get_db(): + yield mock_db + + client.app.dependency_overrides[email_module.get_db] = override_get_db + return mock_db + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "email", + [ + "user@example.com", + "USER@domain.io", + "test.user+alias@domain.co.uk", + "first.last@sub.domain.com", + "user123@domain.io", + "user_name@domain.org", + ], +) +async def test_send_confirmation_email_success( + client: TestClient, + patch_email_dependencies: dict, + mock_db_and_override: AsyncMock, + email: str, +): + mock_db = mock_db_and_override + + response = client.post("/v1/email/confirmation", json={"email": email}) + + assert response.status_code == 200 + data = response.json() + assert data["detail"] == "Confirmation email sent" + + patch_email_dependencies["token"].assert_called_once() + patch_email_dependencies["enc"].assert_called_once_with(email) + patch_email_dependencies["hash"].assert_called_once_with(email) + patch_email_dependencies["send"].assert_called_once() + mock_db.execute.assert_awaited_once() + mock_db.commit.assert_awaited_once() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "email", + [ + "", + "not-an-email", + "user@.com", + "user@domain", + "userdomain.com", + "@domain.com", + "user@domain..com", + "user@domain,com", + "user@domain.?com", + ], +) +async def test_send_confirmation_email_failure( + client: TestClient, + patch_email_dependencies: dict, + mock_db_and_override: AsyncMock, + email: str, +): + mock_db = mock_db_and_override + + response = client.post("/v1/email/confirmation", json={"email": email}) + assert response.status_code == 422 # Unprocessable Entity + + patch_email_dependencies["token"].assert_not_called() + patch_email_dependencies["enc"].assert_not_called() + patch_email_dependencies["hash"].assert_not_called() + patch_email_dependencies["send"].assert_not_called() + mock_db.execute.assert_not_awaited() + mock_db.commit.assert_not_awaited() diff --git a/tests/test_routes/v1/test_orders.py b/tests/test_routes/v1/test_orders.py index fc294b6..bc90801 100644 --- a/tests/test_routes/v1/test_orders.py +++ b/tests/test_routes/v1/test_orders.py @@ -4,19 +4,26 @@ This module uses the `v1_get` utility to avoid repeating the API version path. """ +import pytest from fastapi.testclient import TestClient -from app.main import app -from tests.utils.request import v1_get +from app.routes.v1.endpoints import orders as orders_module -client = TestClient(app) +@pytest.fixture +def client(): + from fastapi import FastAPI -def test_orders(): + app = FastAPI() + app.include_router(orders_module.router, prefix="/v1/orders") + return TestClient(app) + + +def test_orders(client: TestClient): """ Test that the `/api/v1/orders` endpoint returns a 200 status and responds with a JSON list. """ - response = v1_get(client, "/orders") + response = client.get("/v1/orders") assert response.status_code == 200 assert isinstance(response.json(), list) diff --git a/tests/test_routes/v1/test_products.py b/tests/test_routes/v1/test_products.py index 1f4d4d9..50d400d 100644 --- a/tests/test_routes/v1/test_products.py +++ b/tests/test_routes/v1/test_products.py @@ -4,19 +4,26 @@ This module uses the `v1_get` utility to avoid repeating the API version path. """ +import pytest from fastapi.testclient import TestClient -from app.main import app -from tests.utils.request import v1_get +from app.routes.v1.endpoints import products as products_module -client = TestClient(app) +@pytest.fixture +def client(): + from fastapi import FastAPI -def test_products(): + app = FastAPI() + app.include_router(products_module.router, prefix="/v1/products") + return TestClient(app) + + +def test_products(client: TestClient): """ Test that the `/api/v1/products` endpoint returns a 200 status and responds with a JSON list. """ - response = v1_get(client, "/products") + response = client.get("/v1/products") assert response.status_code == 200 assert isinstance(response.json(), list) diff --git a/tests/test_utility/conftest.py b/tests/test_utility/conftest.py new file mode 100644 index 0000000..53eae16 --- /dev/null +++ b/tests/test_utility/conftest.py @@ -0,0 +1,13 @@ +import pytest + + +@pytest.fixture +def sample_email(): + """Fixture to provide a sample email for testing.""" + return "user@example.com" + + +@pytest.fixture +def sample_phone(): + """Fixture to provide a sample phone number for testing.""" + return "+1234567890" diff --git a/tests/test_utility/test_create_verification_token.py b/tests/test_utility/test_create_verification_token.py new file mode 100644 index 0000000..bda3804 --- /dev/null +++ b/tests/test_utility/test_create_verification_token.py @@ -0,0 +1,35 @@ +import re + +import pytest + +from app.utility.security import create_verification_token + + +@pytest.fixture +def verification_token(): + """Fixture to provide a verification token for testing.""" + return create_verification_token() + + +def test_create_verification_token_type(verification_token): + """ + Test the `create_verification_token` function to ensure it generates a token + with the expected structure and content. + """ + assert isinstance(verification_token, str) + + +def test_create_verification_token_length(verification_token): + """ + Test the length of the token generated by `create_verification_token`. + The token should be URL-safe and typically 43 characters long. + """ + assert len(verification_token) == 43 + + +def test_create_verification_token_format(verification_token): + """ + Test the format of the token generated by `create_verification_token`. + The token should be URL-safe, containing alphanumeric characters and hyphens. + """ + assert re.match(r"^[A-Za-z0-9_-]+$", verification_token) diff --git a/tests/test_utility/test_encrypt_field.py b/tests/test_utility/test_encrypt_field.py new file mode 100644 index 0000000..93fe647 --- /dev/null +++ b/tests/test_utility/test_encrypt_field.py @@ -0,0 +1,52 @@ +import re + +import pytest + +from app.utility.security import decrypt_field, encrypt_email, encrypt_phone + + +@pytest.fixture +def encrypted_email(sample_email): + """Fixture to provide an encrypted email for testing.""" + return encrypt_email(sample_email) + + +@pytest.fixture +def encrypted_phone(sample_phone): + """Fixture to provide an encrypted phone number for testing.""" + return encrypt_phone(sample_phone) + + +def test_encrypt_field_type(encrypted_email, encrypted_phone): + """ + Test that the encrypted fields are of type str. + This ensures that the encryption function returns a string. + """ + assert encrypted_email.isascii() + assert encrypted_phone.isascii() + + +def test_encrypt_field_length(encrypted_email, encrypted_phone): + """ + Test that the encrypted fields are not empty. + This ensures that the encryption does not produce empty strings. + """ + assert len(encrypted_email) > 0 + assert len(encrypted_phone) > 0 + + +def test_encrypt_field_format(encrypted_email, encrypted_phone): + """ + Test that the encrypted fields are in the expected format. + Encrypted fields should be base64 encoded strings. + """ + assert re.match(r"^[A-Za-z0-9+/=]+$", encrypted_email) + assert re.match(r"^[A-Za-z0-9+/=]+$", encrypted_phone) + + +def test_decrypt_email(sample_email, encrypted_email): + assert decrypt_field(encrypted_email) == sample_email + + +def test_decrypt_phone(sample_phone, encrypted_phone): + assert decrypt_field(encrypted_phone) == sample_phone diff --git a/tests/test_utility/test_hash_field.py b/tests/test_utility/test_hash_field.py new file mode 100644 index 0000000..a5fe4a3 --- /dev/null +++ b/tests/test_utility/test_hash_field.py @@ -0,0 +1,44 @@ +import re + +import pytest + +from app.utility.security import hash_email, hash_phone + + +@pytest.fixture +def hashed_email(sample_email): + """Fixture to provide a hashed email for testing.""" + return hash_email(sample_email) + + +@pytest.fixture +def hashed_phone(sample_phone): + """Fixture to provide a hashed phone number for testing.""" + return hash_phone(sample_phone) + + +def test_hash_field_type(hashed_email, hashed_phone): + """ + Test that the hashed fields are of type str. + This ensures that the hash function returns a string. + """ + assert isinstance(hashed_email, str) + assert isinstance(hashed_phone, str) + + +def test_hash_field_length(hashed_email, hashed_phone): + """ + Test that the hashed fields are precisely 64 characters long. + This is the expected length for SHA-256 hashes. + """ + assert len(hashed_email) == 64 + assert len(hashed_phone) == 64 + + +def test_hash_field_format(hashed_email, hashed_phone): + """ + Test that the hashed fields are in the expected format. + Hashed fields should be hexadecimal strings. + """ + assert re.match(r"^[0-9a-f]{64}$", hashed_email) + assert re.match(r"^[0-9a-f]{64}$", hashed_phone) diff --git a/tests/test_utility/test_sanitize_username.py b/tests/test_utility/test_sanitize_username.py new file mode 100644 index 0000000..9306a74 --- /dev/null +++ b/tests/test_utility/test_sanitize_username.py @@ -0,0 +1,24 @@ +import pytest + +from app.utility.string_utils import sanitize_username + + +@pytest.mark.parametrize( + "input_username,expected", + [ + ("validUser_123", "validUser_123"), + ("user.name", "user_name"), + ("user-name", "user_name"), + ("user name", "user_name"), + ("user@domain.com", "user_domain_com"), + ("user!$%^&*()", "user________"), + ("", ""), + ("___", "___"), + ("user__name", "user__name"), + ("user\nname", "user_name"), + ("user\tname", "user_name"), + ("user/\\name", "user__name"), + ], +) +def test_sanitize_username(input_username, expected): + assert sanitize_username(input_username) == expected diff --git a/tests/test_utility/test_verify_otp.py b/tests/test_utility/test_verify_otp.py new file mode 100644 index 0000000..ef73562 --- /dev/null +++ b/tests/test_utility/test_verify_otp.py @@ -0,0 +1,59 @@ +import pyotp +import pytest + +from app.utility.security import verify_otp + + +@pytest.fixture +def totp_secret(): + return pyotp.random_base32() + + +@pytest.fixture +def hotp_secret(): + return pyotp.random_base32() + + +@pytest.fixture +def unsupported_secret(): + return pyotp.random_base32() + + +def test_verify_otp_valid_totp(totp_secret): + totp = pyotp.TOTP(totp_secret) + otp_code = totp.now() + assert verify_otp(totp_secret, otp_code, "TOTP") is True + + +def test_verify_otp_invalid_totp(totp_secret): + otp_code = "000000" + assert verify_otp(totp_secret, otp_code, "TOTP") is False + + +def test_verify_otp_valid_hotp(hotp_secret): + hotp = pyotp.HOTP(hotp_secret) + counter = 0 + otp_code = hotp.at(counter) + assert verify_otp(hotp_secret, otp_code, "HOTP", counter) is True + + +def test_verify_otp_invalid_hotp(hotp_secret): + counter = 0 + otp_code = "000000" + assert verify_otp(hotp_secret, otp_code, "HOTP", counter) is False + + +def test_verify_otp_unsupported_method(unsupported_secret): + otp_code = "123456" + assert verify_otp(unsupported_secret, otp_code, "SMS") is False + + +def test_verify_otp_invalid_secret(): + secret = "not_a_valid_secret" + otp_code = "123456" + assert verify_otp(secret, otp_code, "TOTP") is False + + +def test_verify_otp_invalid_code_type(totp_secret): + otp_code = None + assert verify_otp(totp_secret, otp_code, "TOTP") is False diff --git a/tests/test_utility/test_verify_password.py b/tests/test_utility/test_verify_password.py new file mode 100644 index 0000000..81e97f8 --- /dev/null +++ b/tests/test_utility/test_verify_password.py @@ -0,0 +1,57 @@ +import pytest +from argon2 import PasswordHasher + +from app.utility.security import hash_password, verify_password + +ph = PasswordHasher() + + +@pytest.fixture +def sample_password(): + """Fixture to provide a sample password for testing.""" + return "TestPass123!" + + +@pytest.fixture +def hashed_password(sample_password): + """Fixture to provide a hashed password for testing.""" + return hash_password(sample_password) + + +@pytest.fixture +def wrong_password(): + """Fixture to provide a wrong password for testing.""" + return "WrongPass123!" + + +def test_verify_password(sample_password, hashed_password): + """ + Test that the password verification works correctly. + This ensures that the password can be hashed and then verified successfully. + """ + assert verify_password(sample_password, hashed_password) is True + + +def test_verify_wrong_password(hashed_password, wrong_password): + """ + Test that the password verification fails for a wrong password. + This ensures that the verification function does not falsely accept incorrect passwords. + """ + assert verify_password(wrong_password, hashed_password) is False + + +def test_verify_empty_password(hashed_password): + """ + Test that the password verification fails for an empty password. + This ensures that the verification function does not accept empty strings as valid passwords. + """ + assert verify_password("", hashed_password) is False + + +def test_verify_password_tampered(sample_password, hashed_password): + """ + Test that the password verification fails if the hash is tampered with. + This ensures that the verification function detects modifications to the hash. + """ + tampered_hash = hashed_password[:-5] + "xyz" # Modify the hash slightly + assert verify_password(sample_password, tampered_hash) is False diff --git a/tests/utils/request.py b/tests/utils/request.py deleted file mode 100644 index d685c0b..0000000 --- a/tests/utils/request.py +++ /dev/null @@ -1,51 +0,0 @@ -""" -Utility functions to simplify versioned API requests in tests. - -These helpers reduce duplication of versioned API paths like `/api/v1/...`, -improving readability and consistency in test files. -""" - -from fastapi.testclient import TestClient - - -def api_get(client: TestClient, version: str, path: str) -> TestClient: - """ - Perform a GET request to a versioned API path. - - Args: - client (TestClient): FastAPI test client. - version (str): API version, e.g., 'v1' or 'v2'. - path (str): Path to append after the version, e.g., '/orders'. - - Returns: - Response: FastAPI test client response. - """ - return client.get(f"/api/{version}{path}") - - -def v1_get(client: TestClient, path: str) -> TestClient: - """ - Perform a GET request to a v1 API endpoint. - - Args: - client (TestClient): FastAPI test client. - path (str): Path to append after `/api/v1`. - - Returns: - Response: FastAPI test client response. - """ - return api_get(client, "v1", path) - - -def v2_get(client: TestClient, path: str) -> TestClient: - """ - Perform a GET request to a v2 API endpoint. - - Args: - client (TestClient): FastAPI test client. - path (str): Path to append after `/api/v2`. - - Returns: - Response: FastAPI test client response. - """ - return api_get(client, "v2", path)